diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index 29186d0a19a..d3d007bea96 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -477,7 +477,14 @@ final class DefaultRestClient implements RestClient { @Override public RequestBodySpec body(StreamingHttpOutputMessage.Body body) { - this.body = request -> body.writeTo(request.getBody()); + this.body = request -> { + if (request instanceof StreamingHttpOutputMessage streamingMessage) { + streamingMessage.setBody(body); + } + else { + body.writeTo(request.getBody()); + } + }; return this; } diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java index f3b05812da7..4e435902e8d 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java @@ -16,6 +16,7 @@ package org.springframework.web.client; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -44,6 +45,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatusCode; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; +import org.springframework.http.StreamingHttpOutputMessage; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpResponse; @@ -53,6 +55,7 @@ import org.springframework.http.client.JettyClientHttpRequestFactory; import org.springframework.http.client.ReactorClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.util.CollectionUtils; +import org.springframework.util.FastByteArrayOutputStream; import org.springframework.util.FileCopyUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -602,6 +605,30 @@ class RestClientIntegrationTests { }); } + @ParameterizedRestClientTest // gh-35102 + void postStreamingBody(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + prepareResponse(response -> response.setResponseCode(200)); + + StreamingHttpOutputMessage.Body testBody = out -> { + assertThat(out).as("Not a streaming response").isNotInstanceOf(FastByteArrayOutputStream.class); + new ByteArrayInputStream("test-data".getBytes(UTF_8)).transferTo(out); + }; + + ResponseEntity result = this.restClient.post() + .uri("/streaming/body") + .body(testBody) + .retrieve() + .toBodilessEntity(); + + assertThat(result.getStatusCode()).isEqualTo(HttpStatus.OK); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/streaming/body"); + assertThat(request.getBody().readUtf8()).isEqualTo("test-data"); + }); + } @ParameterizedRestClientTest void statusHandler(ClientHttpRequestFactory requestFactory) {