diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index 434916de5f4..a5af8fe2c6c 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -42,6 +42,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; +import org.springframework.core.io.InputStreamResource; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpRequest; @@ -253,6 +254,10 @@ final class DefaultRestClient implements RestClient { } } + if (bodyClass.equals(InputStream.class)) { + return (T) responseWrapper.getBody(); + } + throw new UnknownContentTypeException(bodyType, contentType, responseWrapper.getStatusCode(), responseWrapper.getStatusText(), responseWrapper.getHeaders(), RestClientUtils.getBody(responseWrapper)); @@ -609,7 +614,11 @@ final class DefaultRestClient implements RestClient { clientResponse = clientRequest.execute(); observationContext.setResponse(clientResponse); ConvertibleClientHttpResponse convertibleWrapper = new DefaultConvertibleClientHttpResponse(clientResponse, this.hints); - return exchangeFunction.exchange(clientRequest, convertibleWrapper); + T result = exchangeFunction.exchange(clientRequest, convertibleWrapper); + if (close && isStreamingResult(result)) { + close = false; + } + return result; } catch (IOException ex) { ResourceAccessException resourceAccessException = createResourceAccessException(uri, this.httpMethod, ex); @@ -746,6 +755,10 @@ final class DefaultRestClient implements RestClient { return request; } + private static boolean isStreamingResult(@Nullable Object result) { + return (result instanceof InputStream || result instanceof InputStreamResource); + } + private static ResourceAccessException createResourceAccessException(URI url, HttpMethod method, IOException ex) { StringBuilder msg = new StringBuilder("I/O error on "); msg.append(method.name()); diff --git a/spring-web/src/main/java/org/springframework/web/client/support/RestClientAdapter.java b/spring-web/src/main/java/org/springframework/web/client/support/RestClientAdapter.java index d58fac3a0a5..c98fc680a54 100644 --- a/spring-web/src/main/java/org/springframework/web/client/support/RestClientAdapter.java +++ b/spring-web/src/main/java/org/springframework/web/client/support/RestClientAdapter.java @@ -16,8 +16,6 @@ package org.springframework.web.client.support; -import java.io.IOException; -import java.io.InputStream; import java.net.URI; import java.util.ArrayList; import java.util.List; @@ -75,8 +73,7 @@ public final class RestClientAdapter implements HttpExchangeAdapter { @Override public @Nullable T exchangeForBody(HttpRequestValues values, ParameterizedTypeReference bodyType) { - return (bodyType.getType().equals(InputStream.class) ? - exchangeForInputStream(values) : newRequest(values).retrieve().body(bodyType)); + return newRequest(values).retrieve().body(bodyType); } @Override @@ -86,21 +83,7 @@ public final class RestClientAdapter implements HttpExchangeAdapter { @Override public ResponseEntity exchangeForEntity(HttpRequestValues values, ParameterizedTypeReference bodyType) { - return (bodyType.getType().equals(InputStream.class) ? - exchangeForEntityInputStream(values) : newRequest(values).retrieve().toEntity(bodyType)); - } - - @SuppressWarnings("unchecked") - private T exchangeForInputStream(HttpRequestValues values) { - return (T) newRequest(values).exchange((request, response) -> getInputStream(response), false); - } - - @SuppressWarnings("unchecked") - private ResponseEntity exchangeForEntityInputStream(HttpRequestValues values) { - return (ResponseEntity) newRequest(values).exchangeForRequiredValue((request, response) -> - ResponseEntity.status(response.getStatusCode()) - .headers(response.getHeaders()) - .body(getInputStream(response)), false); + return newRequest(values).retrieve().toEntity(bodyType); } @SuppressWarnings("unchecked") @@ -162,16 +145,6 @@ public final class RestClientAdapter implements HttpExchangeAdapter { return bodySpec; } - private static InputStream getInputStream( - RestClient.RequestHeadersSpec.ConvertibleClientHttpResponse response) throws IOException { - - if (response.getStatusCode().isError()) { - throw response.createException(); - } - return response.getBody(); - } - - /** * Create a {@link RestClientAdapter} for the given {@link RestClient}. diff --git a/spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java b/spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java index 45847f586f9..44cfc95742d 100644 --- a/spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java @@ -18,6 +18,7 @@ package org.springframework.web.client; import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; import java.net.URI; import org.junit.jupiter.api.BeforeEach; @@ -36,6 +37,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * Tests for {@link DefaultRestClient}. @@ -117,6 +120,21 @@ class DefaultRestClientTests { ); } + @Test + void inputStreamBody() throws IOException { + mockSentRequest(HttpMethod.GET, "https://example.org"); + mockResponseStatus(HttpStatus.OK); + mockResponseBody("Hello World", MediaType.TEXT_PLAIN); + + InputStream result = this.client.get() + .uri("https://example.org") + .retrieve() + .requiredBody(InputStream.class); + + assertThat(result).isInstanceOf(InputStream.class); + verify(this.response, times(0)).close(); + } + private void mockSentRequest(HttpMethod method, String uri) throws IOException { given(this.requestFactory.createRequest(URI.create(uri), method)).willReturn(this.request);