diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClient.java b/spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClient.java index 900bd402f0b..fe01246f775 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClient.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClient.java @@ -16,12 +16,14 @@ package org.springframework.test.web.servlet.client; +import java.io.IOException; import java.net.URI; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.time.ZonedDateTime; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; @@ -33,13 +35,18 @@ import org.jspecify.annotations.Nullable; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRequest; import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpRequestExecution; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.ClientHttpResponse; import org.springframework.test.json.JsonAssert; import org.springframework.test.json.JsonComparator; import org.springframework.test.json.JsonCompareMode; import org.springframework.test.util.AssertionErrors; import org.springframework.test.util.ExceptionCollector; import org.springframework.test.util.XmlExpectationsHelper; +import org.springframework.util.Assert; import org.springframework.util.MimeType; import org.springframework.util.MultiValueMap; import org.springframework.web.client.RestClient; @@ -56,6 +63,8 @@ class DefaultRestTestClient implements RestTestClient { private final RestClient restClient; + private final WiretapInterceptor wiretapInterceptor = new WiretapInterceptor(); + private final Consumer> entityResultConsumer; private final DefaultRestTestClientBuilder restTestClientBuilder; @@ -67,7 +76,7 @@ class DefaultRestTestClient implements RestTestClient { RestClient.Builder builder, Consumer> entityResultConsumer, DefaultRestTestClientBuilder restTestClientBuilder) { - this.restClient = builder.build(); + this.restClient = builder.requestInterceptor(this.wiretapInterceptor).build(); this.entityResultConsumer = entityResultConsumer; this.restTestClientBuilder = restTestClientBuilder; } @@ -128,12 +137,14 @@ class DefaultRestTestClient implements RestTestClient { private final RestClient.RequestBodyUriSpec requestHeadersUriSpec; + private final String requestId; + private @Nullable String uriTemplate; DefaultRequestBodyUriSpec(RestClient.RequestBodyUriSpec spec) { this.requestHeadersUriSpec = spec; - String requestId = String.valueOf(requestIndex.incrementAndGet()); - this.requestHeadersUriSpec.header(RESTTESTCLIENT_REQUEST_ID, requestId); + this.requestId = String.valueOf(requestIndex.incrementAndGet()); + this.requestHeadersUriSpec.header(RESTTESTCLIENT_REQUEST_ID, this.requestId); } @Override @@ -252,7 +263,10 @@ class DefaultRestTestClient implements RestTestClient { public ResponseSpec exchange() { return new DefaultResponseSpec( this.requestHeadersUriSpec.exchangeForRequiredValue( - (request, response) -> new ExchangeResult(request, response, this.uriTemplate), false), + (request, response) -> { + byte[] requestBody = wiretapInterceptor.getRequestContent(this.requestId); + return new ExchangeResult(request, response, this.uriTemplate, requestBody); + }, false), DefaultRestTestClient.this.entityResultConsumer); } } @@ -476,4 +490,29 @@ class DefaultRestTestClient implements RestTestClient { return this.result; } } + + + private static class WiretapInterceptor implements ClientHttpRequestInterceptor { + + private final Map requestContentMap = new ConcurrentHashMap<>(); + + @Override + public ClientHttpResponse intercept( + HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException { + + String header = RestTestClient.RESTTESTCLIENT_REQUEST_ID; + String requestId = request.getHeaders().getFirst(header); + Assert.state(requestId != null, () -> "No \"" + header + "\" header"); + this.requestContentMap.put(requestId, body); + return execution.execute(request, body); + } + + public byte[] getRequestContent(String requestId) { + byte[] bytes = this.requestContentMap.remove(requestId); + Assert.state(bytes != null, () -> + "No match for %s=%s".formatted(RestTestClient.RESTTESTCLIENT_REQUEST_ID, requestId)); + return bytes; + } + } + } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClientBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClientBuilder.java index 53168bf60d3..484ebbf698a 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClientBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClientBuilder.java @@ -61,7 +61,7 @@ class DefaultRestTestClientBuilder> implemen } DefaultRestTestClientBuilder(RestClient.Builder restClientBuilder) { - this.restClientBuilder = restClientBuilder; + this.restClientBuilder = restClientBuilder.bufferContent((uri, httpMethod) -> true); } DefaultRestTestClientBuilder(DefaultRestTestClientBuilder other) { diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/client/ExchangeResult.java b/spring-test/src/main/java/org/springframework/test/web/servlet/client/ExchangeResult.java index 771a9eb1e0a..99cf979c0fa 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/client/ExchangeResult.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/client/ExchangeResult.java @@ -19,6 +19,8 @@ package org.springframework.test.web.servlet.client; import java.io.IOException; import java.net.HttpCookie; import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Optional; import java.util.regex.Matcher; @@ -34,10 +36,12 @@ import org.springframework.http.HttpMethod; import org.springframework.http.HttpRequest; import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatusCode; +import org.springframework.http.MediaType; import org.springframework.http.ResponseCookie; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.util.StreamUtils; import org.springframework.web.client.RestClient.RequestHeadersSpec.ConvertibleClientHttpResponse; /** @@ -54,6 +58,10 @@ public class ExchangeResult { private static final Pattern PARTITIONED_PATTERN = Pattern.compile("(?i).*;\\s*Partitioned(\\s*;.*|\\s*)$"); + private static final List PRINTABLE_MEDIA_TYPES = List.of( + MediaType.parseMediaType("application/*+json"), MediaType.APPLICATION_XML, + MediaType.parseMediaType("text/*"), MediaType.APPLICATION_FORM_URLENCODED); + private static final Log logger = LogFactory.getLog(ExchangeResult.class); @@ -64,22 +72,26 @@ public class ExchangeResult { private final @Nullable String uriTemplate; + private final byte[] requestBody; + /** Ensure single logging; for example, for expectAll. */ private boolean diagnosticsLogged; ExchangeResult( - HttpRequest request, ConvertibleClientHttpResponse response, @Nullable String uriTemplate) { + HttpRequest request, ConvertibleClientHttpResponse response, @Nullable String uriTemplate, + byte[] requestBody) { Assert.notNull(request, "HttpRequest must not be null"); Assert.notNull(response, "ClientHttpResponse must not be null"); this.request = request; this.clientResponse = response; this.uriTemplate = uriTemplate; + this.requestBody = requestBody; } ExchangeResult(ExchangeResult result) { - this(result.request, result.clientResponse, result.uriTemplate); + this(result.request, result.clientResponse, result.uriTemplate, result.requestBody); this.diagnosticsLogged = result.diagnosticsLogged; } @@ -159,6 +171,13 @@ public class ExchangeResult { .build(); } + /** + * Return the raw request body content written through the request. + */ + public byte[] getRequestBodyContent() { + return this.requestBody; + } + /** * Provide access to the response. For internal use to decode the body. */ @@ -166,6 +185,18 @@ public class ExchangeResult { return this.clientResponse; } + /** + * Return the raw response body read through the response. + */ + public byte[] getResponseBodyContent() { + try { + return StreamUtils.copyToByteArray(this.clientResponse.getBody()); + } + catch (IOException ex) { + throw new IllegalStateException("Failed to get response content: " + ex); + } + } + /** * Execute the given Runnable, catch any {@link AssertionError}, log details * about the request and response at ERROR level under the class log @@ -190,8 +221,12 @@ public class ExchangeResult { "> " + getMethod() + " " + getUrl() + "\n" + "> " + formatHeaders(getRequestHeaders(), "\n> ") + "\n" + "\n" + + formatBody(getRequestHeaders().getContentType(), this.requestBody) + "\n" + + "\n" + "< " + formatStatus(getStatus()) + "\n" + - "< " + formatHeaders(getResponseHeaders(), "\n< ") + "\n"; + "< " + formatHeaders(getResponseHeaders(), "\n< ") + "\n" + + "\n" + + formatBody(getResponseHeaders().getContentType(), getResponseBodyContent()) +"\n"; } private String formatStatus(HttpStatusCode statusCode) { @@ -208,4 +243,18 @@ public class ExchangeResult { .collect(Collectors.joining(delimiter)); } + private String formatBody(@Nullable MediaType contentType, byte[] bytes) { + if (contentType == null) { + return bytes.length + " bytes of content (unknown content-type)."; + } + Charset charset = contentType.getCharset(); + if (charset != null) { + return new String(bytes, charset); + } + if (PRINTABLE_MEDIA_TYPES.stream().anyMatch(contentType::isCompatibleWith)) { + return new String(bytes, StandardCharsets.UTF_8); + } + return bytes.length + " bytes of content."; + } + } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/client/RestTestClientTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/client/RestTestClientTests.java index 6c72769348f..ee1393c16b1 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/client/RestTestClientTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/client/RestTestClientTests.java @@ -36,6 +36,8 @@ import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestHeader; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; @@ -317,6 +319,19 @@ class RestTestClientTests { }); assertThat(result.getResponseBody().get("uri")).isEqualTo("/test"); } + + @Test + void testResultContent() { + String body = "body-in"; + EntityExchangeResult result = RestTestClientTests.this.client.post().uri("/body") + .body(body) + .exchange() + .expectStatus().isOk() + .expectBody(String.class) + .returnResult(); + assertThat(result.getRequestBodyContent()).isEqualTo(body.getBytes(StandardCharsets.UTF_8)); + assertThat(result.getResponseBodyContent()).isEqualTo((body + "-out").getBytes(StandardCharsets.UTF_8)); + } } @@ -325,14 +340,20 @@ class RestTestClientTests { @RequestMapping(path = {"/test", "/test/*"}, produces = "application/json") public Map handle( - @RequestHeader HttpHeaders headers, - HttpServletRequest request, HttpServletResponse response) { + @RequestHeader HttpHeaders headers, HttpServletRequest request, HttpServletResponse response) { + response.addCookie(new Cookie("session", "abc")); + return Map.of( "method", request.getMethod(), "uri", request.getRequestURI(), "headers", headers.toSingleValueMap() ); } + + @PostMapping("/body") + public String echoBody(@RequestBody String body) { + return body + "-out"; + } } }