Browse Source

Fix response stream handling in `RestClient`

Prior to this commit, the `DefaultRestClient` implementation would
look at the type of the returned value to decide whether the HTTP
response stream needs to be closed before returning.
As of gh-36380, this checks for `InputStream` and `InputStreamResource`
return types, effectively not considering `ResponseEntity<InputStream>`
and variants.

This commit ensures that `ResponseEntity` types are unwrapped before
checking for streaming types.

Fixes gh-36492
7.0.x
Brian Clozel 3 days ago
parent
commit
d1950f598a
  1. 3
      spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java
  2. 29
      spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java

3
spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java

@ -756,6 +756,9 @@ final class DefaultRestClient implements RestClient {
} }
private static boolean isStreamingResult(@Nullable Object result) { private static boolean isStreamingResult(@Nullable Object result) {
if (result instanceof ResponseEntity<?> entity) {
result = entity.getBody();
}
return (result instanceof InputStream || result instanceof InputStreamResource); return (result instanceof InputStream || result instanceof InputStreamResource);
} }

29
spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java

@ -20,15 +20,21 @@ import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.URI; import java.net.URI;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.ClientHttpResponse;
@ -115,18 +121,31 @@ class DefaultRestClientTests {
); );
} }
@Test @ParameterizedTest(name = "{0}")
void inputStreamBody() throws IOException { @MethodSource("streamResponseBodies")
void streamingBody(String typeName, Consumer<RestClient> clientConsumer) throws IOException {
mockSentRequest(HttpMethod.GET, URL); mockSentRequest(HttpMethod.GET, URL);
mockResponseStatus(HttpStatus.OK); mockResponseStatus(HttpStatus.OK);
mockResponseBody(BODY, MediaType.TEXT_PLAIN); mockResponseBody(BODY, MediaType.APPLICATION_OCTET_STREAM);
InputStream result = this.client.get().uri(URL).retrieve().requiredBody(InputStream.class); clientConsumer.accept(this.client);
assertThat(result).isInstanceOf(InputStream.class);
verify(this.response, times(0)).close(); verify(this.response, times(0)).close();
} }
static Stream<Arguments> streamResponseBodies() {
return Stream.of(
Arguments.of("InputStream", (Consumer<RestClient>) client -> {
InputStream result = client.get().uri(URL).retrieve().requiredBody(InputStream.class);
assertThat(result).isInstanceOf(InputStream.class);
}),
Arguments.of("ResponseEntity<Inpustream>", (Consumer<RestClient>) client -> {
ResponseEntity<InputStream> result = client.get().uri(URL).retrieve().toEntity(InputStream.class);
assertThat(result).isInstanceOf(ResponseEntity.class);
})
);
}
private void mockSentRequest(HttpMethod method, String uri) throws IOException { private void mockSentRequest(HttpMethod method, String uri) throws IOException {
given(this.requestFactory.createRequest(URI.create(uri), method)).willReturn(this.request); given(this.requestFactory.createRequest(URI.create(uri), method)).willReturn(this.request);

Loading…
Cancel
Save