From e537844a093d313595e20cc440506fe8b04bc7db Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Thu, 21 Jan 2021 16:24:59 +0100 Subject: [PATCH] Copy headers from part in MultipartBodyBuilder This commit makes sure that Part.headers() is copied over when adding a part in the MultipartBodyBuilder. Closes gh-26410 --- .../http/client/MultipartBodyBuilder.java | 10 +++++----- .../multipart/MultipartHttpMessageWriterTests.java | 8 +++++++- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java index f4031db2955..08f10762d89 100644 --- a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java +++ b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java @@ -30,7 +30,6 @@ import org.springframework.core.io.buffer.DataBuffer; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; -import org.springframework.http.codec.multipart.FilePart; import org.springframework.http.codec.multipart.Part; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; @@ -129,13 +128,14 @@ public final class MultipartBodyBuilder { Assert.notNull(part, "'part' must not be null"); if (part instanceof Part) { - PartBuilder builder = asyncPart(name, ((Part) part).content(), DataBuffer.class); + Part partObject = (Part) part; + PartBuilder builder = asyncPart(name, partObject.content(), DataBuffer.class); + if (!partObject.headers().isEmpty()) { + builder.headers(headers -> headers.putAll(partObject.headers())); + } if (contentType != null) { builder.contentType(contentType); } - if (part instanceof FilePart) { - builder.filename(((FilePart) part).filename()); - } return builder; } diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java index 167b300fd37..7219fab9280 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java @@ -37,6 +37,7 @@ import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.core.testfixture.io.buffer.AbstractLeakCheckingTests; import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.client.MultipartBodyBuilder; import org.springframework.http.codec.ClientCodecConfigurer; @@ -102,8 +103,12 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTests { this.bufferFactory.wrap("Cc".getBytes(StandardCharsets.UTF_8)) ); FilePart mockPart = mock(FilePart.class); + HttpHeaders partHeaders = new HttpHeaders(); + partHeaders.setContentType(MediaType.TEXT_PLAIN); + partHeaders.setContentDispositionFormData("filePublisher", "file.txt"); + partHeaders.add("foo", "bar"); + given(mockPart.headers()).willReturn(partHeaders); given(mockPart.content()).willReturn(bufferPublisher); - given(mockPart.filename()).willReturn("file.txt"); MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); bodyBuilder.part("name 1", "value 1"); @@ -166,6 +171,7 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTests { part = requestParts.getFirst("filePublisher"); assertThat(part.name()).isEqualTo("filePublisher"); + assertThat(part.headers()).containsEntry("foo", Collections.singletonList("bar")); assertThat(((FilePart) part).filename()).isEqualTo("file.txt"); value = decodeToString(part); assertThat(value).isEqualTo("AaBbCc");