diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index 95e485d5b10..36ce64db342 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -466,7 +466,14 @@ final class DefaultRestClient implements RestClient { @Override public RequestBodySpec body(StreamingHttpOutputMessage.Body body) { - this.body = request -> body.writeTo(request.getBody()); + this.body = request -> { + if (request instanceof StreamingHttpOutputMessage streamingMessage) { + streamingMessage.setBody(body); + } + else { + body.writeTo(request.getBody()); + } + }; return this; } diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java index c1a726181d6..4fa39300121 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java @@ -16,6 +16,7 @@ package org.springframework.web.client; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -529,6 +530,27 @@ class RestClientIntegrationTests { }); } + @ParameterizedRestClientTest + void postStreamingMessageBody(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setResponseCode(200)); + + ResponseEntity result = this.restClient.post() + .uri("/streaming/body") + .body(new ByteArrayInputStream("test-data".getBytes(UTF_8))::transferTo) + .retrieve() + .toBodilessEntity(); + + assertThat(result.getStatusCode()).isEqualTo(HttpStatus.OK); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getPath()).isEqualTo("/streaming/body"); + assertThat(request.getBody().readUtf8()).isEqualTo("test-data"); + }); + } + @ParameterizedRestClientTest // gh-31361 public void postForm(ClientHttpRequestFactory requestFactory) { startServer(requestFactory);