Browse Source

Provide access to raw content in RestTestClient

Closes gh-35399
pull/35503/head
rstoyanchev 3 months ago
parent
commit
11dd0d6118
  1. 47
      spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClient.java
  2. 2
      spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClientBuilder.java
  3. 55
      spring-test/src/main/java/org/springframework/test/web/servlet/client/ExchangeResult.java
  4. 25
      spring-test/src/test/java/org/springframework/test/web/servlet/client/RestTestClientTests.java

47
spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClient.java

@ -16,12 +16,14 @@ @@ -16,12 +16,14 @@
package org.springframework.test.web.servlet.client;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.time.ZonedDateTime;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Function;
@ -33,13 +35,18 @@ import org.jspecify.annotations.Nullable; @@ -33,13 +35,18 @@ import org.jspecify.annotations.Nullable;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRequest;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpRequestExecution;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.test.json.JsonAssert;
import org.springframework.test.json.JsonComparator;
import org.springframework.test.json.JsonCompareMode;
import org.springframework.test.util.AssertionErrors;
import org.springframework.test.util.ExceptionCollector;
import org.springframework.test.util.XmlExpectationsHelper;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestClient;
@ -56,6 +63,8 @@ class DefaultRestTestClient implements RestTestClient { @@ -56,6 +63,8 @@ class DefaultRestTestClient implements RestTestClient {
private final RestClient restClient;
private final WiretapInterceptor wiretapInterceptor = new WiretapInterceptor();
private final Consumer<EntityExchangeResult<?>> entityResultConsumer;
private final DefaultRestTestClientBuilder<?> restTestClientBuilder;
@ -67,7 +76,7 @@ class DefaultRestTestClient implements RestTestClient { @@ -67,7 +76,7 @@ class DefaultRestTestClient implements RestTestClient {
RestClient.Builder builder, Consumer<EntityExchangeResult<?>> entityResultConsumer,
DefaultRestTestClientBuilder<?> restTestClientBuilder) {
this.restClient = builder.build();
this.restClient = builder.requestInterceptor(this.wiretapInterceptor).build();
this.entityResultConsumer = entityResultConsumer;
this.restTestClientBuilder = restTestClientBuilder;
}
@ -128,12 +137,14 @@ class DefaultRestTestClient implements RestTestClient { @@ -128,12 +137,14 @@ class DefaultRestTestClient implements RestTestClient {
private final RestClient.RequestBodyUriSpec requestHeadersUriSpec;
private final String requestId;
private @Nullable String uriTemplate;
DefaultRequestBodyUriSpec(RestClient.RequestBodyUriSpec spec) {
this.requestHeadersUriSpec = spec;
String requestId = String.valueOf(requestIndex.incrementAndGet());
this.requestHeadersUriSpec.header(RESTTESTCLIENT_REQUEST_ID, requestId);
this.requestId = String.valueOf(requestIndex.incrementAndGet());
this.requestHeadersUriSpec.header(RESTTESTCLIENT_REQUEST_ID, this.requestId);
}
@Override
@ -252,7 +263,10 @@ class DefaultRestTestClient implements RestTestClient { @@ -252,7 +263,10 @@ class DefaultRestTestClient implements RestTestClient {
public ResponseSpec exchange() {
return new DefaultResponseSpec(
this.requestHeadersUriSpec.exchangeForRequiredValue(
(request, response) -> new ExchangeResult(request, response, this.uriTemplate), false),
(request, response) -> {
byte[] requestBody = wiretapInterceptor.getRequestContent(this.requestId);
return new ExchangeResult(request, response, this.uriTemplate, requestBody);
}, false),
DefaultRestTestClient.this.entityResultConsumer);
}
}
@ -476,4 +490,29 @@ class DefaultRestTestClient implements RestTestClient { @@ -476,4 +490,29 @@ class DefaultRestTestClient implements RestTestClient {
return this.result;
}
}
private static class WiretapInterceptor implements ClientHttpRequestInterceptor {
private final Map<String, byte[]> requestContentMap = new ConcurrentHashMap<>();
@Override
public ClientHttpResponse intercept(
HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException {
String header = RestTestClient.RESTTESTCLIENT_REQUEST_ID;
String requestId = request.getHeaders().getFirst(header);
Assert.state(requestId != null, () -> "No \"" + header + "\" header");
this.requestContentMap.put(requestId, body);
return execution.execute(request, body);
}
public byte[] getRequestContent(String requestId) {
byte[] bytes = this.requestContentMap.remove(requestId);
Assert.state(bytes != null, () ->
"No match for %s=%s".formatted(RestTestClient.RESTTESTCLIENT_REQUEST_ID, requestId));
return bytes;
}
}
}

2
spring-test/src/main/java/org/springframework/test/web/servlet/client/DefaultRestTestClientBuilder.java

@ -61,7 +61,7 @@ class DefaultRestTestClientBuilder<B extends RestTestClient.Builder<B>> implemen @@ -61,7 +61,7 @@ class DefaultRestTestClientBuilder<B extends RestTestClient.Builder<B>> implemen
}
DefaultRestTestClientBuilder(RestClient.Builder restClientBuilder) {
this.restClientBuilder = restClientBuilder;
this.restClientBuilder = restClientBuilder.bufferContent((uri, httpMethod) -> true);
}
DefaultRestTestClientBuilder(DefaultRestTestClientBuilder<B> other) {

55
spring-test/src/main/java/org/springframework/test/web/servlet/client/ExchangeResult.java

@ -19,6 +19,8 @@ package org.springframework.test.web.servlet.client; @@ -19,6 +19,8 @@ package org.springframework.test.web.servlet.client;
import java.io.IOException;
import java.net.HttpCookie;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
@ -34,10 +36,12 @@ import org.springframework.http.HttpMethod; @@ -34,10 +36,12 @@ import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRequest;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseCookie;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.RestClient.RequestHeadersSpec.ConvertibleClientHttpResponse;
/**
@ -54,6 +58,10 @@ public class ExchangeResult { @@ -54,6 +58,10 @@ public class ExchangeResult {
private static final Pattern PARTITIONED_PATTERN = Pattern.compile("(?i).*;\\s*Partitioned(\\s*;.*|\\s*)$");
private static final List<MediaType> PRINTABLE_MEDIA_TYPES = List.of(
MediaType.parseMediaType("application/*+json"), MediaType.APPLICATION_XML,
MediaType.parseMediaType("text/*"), MediaType.APPLICATION_FORM_URLENCODED);
private static final Log logger = LogFactory.getLog(ExchangeResult.class);
@ -64,22 +72,26 @@ public class ExchangeResult { @@ -64,22 +72,26 @@ public class ExchangeResult {
private final @Nullable String uriTemplate;
private final byte[] requestBody;
/** Ensure single logging; for example, for expectAll. */
private boolean diagnosticsLogged;
ExchangeResult(
HttpRequest request, ConvertibleClientHttpResponse response, @Nullable String uriTemplate) {
HttpRequest request, ConvertibleClientHttpResponse response, @Nullable String uriTemplate,
byte[] requestBody) {
Assert.notNull(request, "HttpRequest must not be null");
Assert.notNull(response, "ClientHttpResponse must not be null");
this.request = request;
this.clientResponse = response;
this.uriTemplate = uriTemplate;
this.requestBody = requestBody;
}
ExchangeResult(ExchangeResult result) {
this(result.request, result.clientResponse, result.uriTemplate);
this(result.request, result.clientResponse, result.uriTemplate, result.requestBody);
this.diagnosticsLogged = result.diagnosticsLogged;
}
@ -159,6 +171,13 @@ public class ExchangeResult { @@ -159,6 +171,13 @@ public class ExchangeResult {
.build();
}
/**
* Return the raw request body content written through the request.
*/
public byte[] getRequestBodyContent() {
return this.requestBody;
}
/**
* Provide access to the response. For internal use to decode the body.
*/
@ -166,6 +185,18 @@ public class ExchangeResult { @@ -166,6 +185,18 @@ public class ExchangeResult {
return this.clientResponse;
}
/**
* Return the raw response body read through the response.
*/
public byte[] getResponseBodyContent() {
try {
return StreamUtils.copyToByteArray(this.clientResponse.getBody());
}
catch (IOException ex) {
throw new IllegalStateException("Failed to get response content: " + ex);
}
}
/**
* Execute the given Runnable, catch any {@link AssertionError}, log details
* about the request and response at ERROR level under the class log
@ -190,8 +221,12 @@ public class ExchangeResult { @@ -190,8 +221,12 @@ public class ExchangeResult {
"> " + getMethod() + " " + getUrl() + "\n" +
"> " + formatHeaders(getRequestHeaders(), "\n> ") + "\n" +
"\n" +
formatBody(getRequestHeaders().getContentType(), this.requestBody) + "\n" +
"\n" +
"< " + formatStatus(getStatus()) + "\n" +
"< " + formatHeaders(getResponseHeaders(), "\n< ") + "\n";
"< " + formatHeaders(getResponseHeaders(), "\n< ") + "\n" +
"\n" +
formatBody(getResponseHeaders().getContentType(), getResponseBodyContent()) +"\n";
}
private String formatStatus(HttpStatusCode statusCode) {
@ -208,4 +243,18 @@ public class ExchangeResult { @@ -208,4 +243,18 @@ public class ExchangeResult {
.collect(Collectors.joining(delimiter));
}
private String formatBody(@Nullable MediaType contentType, byte[] bytes) {
if (contentType == null) {
return bytes.length + " bytes of content (unknown content-type).";
}
Charset charset = contentType.getCharset();
if (charset != null) {
return new String(bytes, charset);
}
if (PRINTABLE_MEDIA_TYPES.stream().anyMatch(contentType::isCompatibleWith)) {
return new String(bytes, StandardCharsets.UTF_8);
}
return bytes.length + " bytes of content.";
}
}

25
spring-test/src/test/java/org/springframework/test/web/servlet/client/RestTestClientTests.java

@ -36,6 +36,8 @@ import org.springframework.core.ParameterizedTypeReference; @@ -36,6 +36,8 @@ import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@ -317,6 +319,19 @@ class RestTestClientTests { @@ -317,6 +319,19 @@ class RestTestClientTests {
});
assertThat(result.getResponseBody().get("uri")).isEqualTo("/test");
}
@Test
void testResultContent() {
String body = "body-in";
EntityExchangeResult<String> result = RestTestClientTests.this.client.post().uri("/body")
.body(body)
.exchange()
.expectStatus().isOk()
.expectBody(String.class)
.returnResult();
assertThat(result.getRequestBodyContent()).isEqualTo(body.getBytes(StandardCharsets.UTF_8));
assertThat(result.getResponseBodyContent()).isEqualTo((body + "-out").getBytes(StandardCharsets.UTF_8));
}
}
@ -325,14 +340,20 @@ class RestTestClientTests { @@ -325,14 +340,20 @@ class RestTestClientTests {
@RequestMapping(path = {"/test", "/test/*"}, produces = "application/json")
public Map<String, Object> handle(
@RequestHeader HttpHeaders headers,
HttpServletRequest request, HttpServletResponse response) {
@RequestHeader HttpHeaders headers, HttpServletRequest request, HttpServletResponse response) {
response.addCookie(new Cookie("session", "abc"));
return Map.of(
"method", request.getMethod(),
"uri", request.getRequestURI(),
"headers", headers.toSingleValueMap()
);
}
@PostMapping("/body")
public String echoBody(@RequestBody String body) {
return body + "-out";
}
}
}

Loading…
Cancel
Save