diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index a5af8fe2c6c..af7ea9133e0 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -756,6 +756,9 @@ final class DefaultRestClient implements RestClient { } private static boolean isStreamingResult(@Nullable Object result) { + if (result instanceof ResponseEntity entity) { + result = entity.getBody(); + } return (result instanceof InputStream || result instanceof InputStreamResource); } diff --git a/spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java b/spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java index 814f3c3a2eb..12646929c67 100644 --- a/spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java @@ -20,15 +20,21 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.net.URI; +import java.util.function.Consumer; +import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpResponse; @@ -115,18 +121,31 @@ class DefaultRestClientTests { ); } - @Test - void inputStreamBody() throws IOException { + @ParameterizedTest(name = "{0}") + @MethodSource("streamResponseBodies") + void streamingBody(String typeName, Consumer clientConsumer) throws IOException { mockSentRequest(HttpMethod.GET, URL); mockResponseStatus(HttpStatus.OK); - mockResponseBody(BODY, MediaType.TEXT_PLAIN); + mockResponseBody(BODY, MediaType.APPLICATION_OCTET_STREAM); - InputStream result = this.client.get().uri(URL).retrieve().requiredBody(InputStream.class); + clientConsumer.accept(this.client); - assertThat(result).isInstanceOf(InputStream.class); verify(this.response, times(0)).close(); } + static Stream streamResponseBodies() { + return Stream.of( + Arguments.of("InputStream", (Consumer) client -> { + InputStream result = client.get().uri(URL).retrieve().requiredBody(InputStream.class); + assertThat(result).isInstanceOf(InputStream.class); + }), + Arguments.of("ResponseEntity", (Consumer) client -> { + ResponseEntity result = client.get().uri(URL).retrieve().toEntity(InputStream.class); + assertThat(result).isInstanceOf(ResponseEntity.class); + }) + ); + } + private void mockSentRequest(HttpMethod method, String uri) throws IOException { given(this.requestFactory.createRequest(URI.create(uri), method)).willReturn(this.request);