Browse Source

Fix response stream handling in `RestClient`

Prior to this commit, the `DefaultRestClient` implementation would
look at the type of the returned value to decide whether the HTTP
response stream needs to be closed before returning.
As of gh-36380, this checks for `InputStream` and `InputStreamResource`
return types, effectively not considering `ResponseEntity<InputStream>`
and variants.

This commit ensures that `ResponseEntity` types are unwrapped before
checking for streaming types.

Fixes gh-36492
7.0.x
Brian Clozel 2 days ago
parent
commit
d1950f598a
  1. 3
      spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java
  2. 29
      spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java

3
spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java

@ -756,6 +756,9 @@ final class DefaultRestClient implements RestClient { @@ -756,6 +756,9 @@ final class DefaultRestClient implements RestClient {
}
private static boolean isStreamingResult(@Nullable Object result) {
if (result instanceof ResponseEntity<?> entity) {
result = entity.getBody();
}
return (result instanceof InputStream || result instanceof InputStreamResource);
}

29
spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java

@ -20,15 +20,21 @@ import java.io.ByteArrayInputStream; @@ -20,15 +20,21 @@ import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse;
@ -115,18 +121,31 @@ class DefaultRestClientTests { @@ -115,18 +121,31 @@ class DefaultRestClientTests {
);
}
@Test
void inputStreamBody() throws IOException {
@ParameterizedTest(name = "{0}")
@MethodSource("streamResponseBodies")
void streamingBody(String typeName, Consumer<RestClient> clientConsumer) throws IOException {
mockSentRequest(HttpMethod.GET, URL);
mockResponseStatus(HttpStatus.OK);
mockResponseBody(BODY, MediaType.TEXT_PLAIN);
mockResponseBody(BODY, MediaType.APPLICATION_OCTET_STREAM);
InputStream result = this.client.get().uri(URL).retrieve().requiredBody(InputStream.class);
clientConsumer.accept(this.client);
assertThat(result).isInstanceOf(InputStream.class);
verify(this.response, times(0)).close();
}
static Stream<Arguments> streamResponseBodies() {
return Stream.of(
Arguments.of("InputStream", (Consumer<RestClient>) client -> {
InputStream result = client.get().uri(URL).retrieve().requiredBody(InputStream.class);
assertThat(result).isInstanceOf(InputStream.class);
}),
Arguments.of("ResponseEntity<Inpustream>", (Consumer<RestClient>) client -> {
ResponseEntity<InputStream> result = client.get().uri(URL).retrieve().toEntity(InputStream.class);
assertThat(result).isInstanceOf(ResponseEntity.class);
})
);
}
private void mockSentRequest(HttpMethod method, String uri) throws IOException {
given(this.requestFactory.createRequest(URI.create(uri), method)).willReturn(this.request);

Loading…
Cancel
Save