diff --git a/spring-web/src/main/java/org/springframework/http/HttpEntity.java b/spring-web/src/main/java/org/springframework/http/HttpEntity.java index a709bf78603..38c1d3fddd0 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpEntity.java +++ b/spring-web/src/main/java/org/springframework/http/HttpEntity.java @@ -16,12 +16,7 @@ package org.springframework.http; -import org.reactivestreams.Publisher; - -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.core.ResolvableType; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; import org.springframework.util.ObjectUtils; @@ -72,9 +67,6 @@ public class HttpEntity { @Nullable private final T body; - @Nullable - private final ResolvableType bodyType; - /** * Create a new, empty {@code HttpEntity}. @@ -105,18 +97,7 @@ public class HttpEntity { * @param headers the entity headers */ public HttpEntity(@Nullable T body, @Nullable MultiValueMap headers) { - this(body, null, headers); - } - - private HttpEntity(@Nullable T body, @Nullable ResolvableType bodyType, - @Nullable MultiValueMap headers) { this.body = body; - - if (bodyType == null && body != null) { - bodyType = ResolvableType.forClass(body.getClass()); - } - this.bodyType = bodyType ; - HttpHeaders tempHeaders = new HttpHeaders(); if (headers != null) { tempHeaders.putAll(headers); @@ -147,13 +128,6 @@ public class HttpEntity { return (this.body != null); } - /** - * Returns the type of the body. - */ - @Nullable - public ResolvableType getBodyType() { - return this.bodyType; - } @Override public boolean equals(@Nullable Object other) { @@ -185,44 +159,4 @@ public class HttpEntity { return builder.toString(); } - - // Static builder methods - - /** - * Create a new {@code HttpEntity} with the given {@link Publisher} as body, class contained in - * {@code publisher}, and headers. - * @param publisher the publisher to use as body - * @param elementClass the class of elements contained in the publisher - * @param headers the entity headers - * @param the type of the elements contained in the publisher - * @param

the type of the {@code Publisher} - * @return the created entity - */ - public static > HttpEntity

fromPublisher(P publisher, - Class elementClass, @Nullable MultiValueMap headers) { - - Assert.notNull(publisher, "'publisher' must not be null"); - Assert.notNull(elementClass, "'elementClass' must not be null"); - return new HttpEntity<>(publisher, ResolvableType.forClass(elementClass), headers); - } - - /** - * Create a new {@code HttpEntity} with the given {@link Publisher} as body, type contained in - * {@code publisher}, and headers. - * @param publisher the publisher to use as body - * @param typeReference the type of elements contained in the publisher - * @param headers the entity headers - * @param the type of the elements contained in the publisher - * @param

the type of the {@code Publisher} - * @return the created entity - */ - public static > HttpEntity

fromPublisher(P publisher, - ParameterizedTypeReference typeReference, - @Nullable MultiValueMap headers) { - - Assert.notNull(publisher, "'publisher' must not be null"); - Assert.notNull(typeReference, "'typeReference' must not be null"); - return new HttpEntity<>(publisher, ResolvableType.forType(typeReference), headers); - } - } diff --git a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java index 231a7a52e29..80289e4fd87 100644 --- a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java +++ b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java @@ -143,8 +143,8 @@ public final class MultipartBodyBuilder { Assert.notNull(elementType, "'elementType' must not be null"); HttpHeaders partHeaders = new HttpHeaders(); - PublisherClassPartBuilder builder = - new PublisherClassPartBuilder<>(publisher, elementClass, partHeaders); + PublisherPartBuilder builder = + new PublisherPartBuilder<>(publisher, elementClass, partHeaders); this.parts.add(name, builder); return builder; @@ -155,21 +155,21 @@ public final class MultipartBodyBuilder { * the returned {@link PartBuilder}. * @param name the name of the part to add (may not be empty) * @param publisher the contents of the part to add - * @param elementType the type of elements contained in the publisher + * @param typeReference the type of elements contained in the publisher * @return a builder that allows for further header customization */ public > PartBuilder asyncPart(String name, P publisher, - ParameterizedTypeReference elementType) { + ParameterizedTypeReference typeReference) { - Assert.notNull(elementType, "'elementType' must not be null"); - ResolvableType elementType1 = ResolvableType.forType(elementType); + Assert.notNull(typeReference, "'typeReference' must not be null"); + ResolvableType elementType1 = ResolvableType.forType(typeReference); Assert.hasLength(name, "'name' must not be empty"); Assert.notNull(publisher, "'publisher' must not be null"); - Assert.notNull(elementType1, "'elementType' must not be null"); + Assert.notNull(elementType1, "'typeReference' must not be null"); HttpHeaders partHeaders = new HttpHeaders(); - PublisherTypReferencePartBuilder builder = - new PublisherTypReferencePartBuilder<>(publisher, elementType, partHeaders); + PublisherPartBuilder builder = + new PublisherPartBuilder<>(publisher, typeReference, partHeaders); this.parts.add(name, builder); return builder; } @@ -213,43 +213,57 @@ public final class MultipartBodyBuilder { } } - private static class PublisherClassPartBuilder> + private static class PublisherPartBuilder> extends DefaultPartBuilder { - private final Class bodyType; + private final ResolvableType resolvableType; - public PublisherClassPartBuilder(P body, Class bodyType, HttpHeaders headers) { + public PublisherPartBuilder(P body, Class elementClass, HttpHeaders headers) { super(body, headers); - this.bodyType = bodyType; + this.resolvableType = ResolvableType.forClass(elementClass); + } + + public PublisherPartBuilder(P body, ParameterizedTypeReference typeReference, + HttpHeaders headers) { + + super(body, headers); + this.resolvableType = ResolvableType.forType(typeReference); } @Override @SuppressWarnings("unchecked") public HttpEntity build() { - P body = (P) this.body; - Assert.state(body != null, "'body' must not be null"); - return HttpEntity.fromPublisher(body, this.bodyType, this.headers); + P publisher = (P) this.body; + Assert.state(publisher != null, "'publisher' must not be null"); + return new PublisherEntity<>(publisher, this.resolvableType, this.headers); } } - private static class PublisherTypReferencePartBuilder> - extends DefaultPartBuilder { - private final ParameterizedTypeReference bodyType; + /** + * Specific subtype of {@link HttpEntity} for containing {@link Publisher}s as body. + * Exposes the type contained in the publisher through {@link #getResolvableType()}. + * @param The type contained in the publisher + * @param

The publisher + */ + public static final class PublisherEntity> extends HttpEntity

{ - public PublisherTypReferencePartBuilder(P body, ParameterizedTypeReference bodyType, - HttpHeaders headers) { + private final ResolvableType resolvableType; - super(body, headers); - this.bodyType = bodyType; + + PublisherEntity(P publisher, ResolvableType resolvableType, + @Nullable MultiValueMap headers) { + super(publisher, headers); + Assert.notNull(publisher, "'publisher' must not be null"); + Assert.notNull(resolvableType, "'resolvableType' must not be null"); + this.resolvableType = resolvableType; } - @Override - @SuppressWarnings("unchecked") - public HttpEntity build() { - P body = (P) this.body; - Assert.state(body != null, "'body' must not be null"); - return HttpEntity.fromPublisher(body, this.bodyType, this.headers); + /** + * Return the resolvable type for this entry. + */ + public ResolvableType getResolvableType() { + return this.resolvableType; } } diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java index 5923ac7eb0b..aef9cdfae77 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java @@ -44,6 +44,7 @@ import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.client.MultipartBodyBuilder; import org.springframework.http.codec.EncoderHttpMessageWriter; import org.springframework.http.codec.FormHttpMessageWriter; import org.springframework.http.codec.HttpMessageWriter; @@ -230,20 +231,25 @@ public class MultipartHttpMessageWriter implements HttpMessageWriter httpEntity = (HttpEntity) value; outputMessage.getHeaders().putAll(httpEntity.getHeaders()); body = httpEntity.getBody(); Assert.state(body != null, "MultipartHttpMessageWriter only supports HttpEntity with body"); - bodyType = httpEntity.getBodyType(); + + if (httpEntity instanceof MultipartBodyBuilder.PublisherEntity) { + MultipartBodyBuilder.PublisherEntity publisherEntity = + (MultipartBodyBuilder.PublisherEntity) httpEntity; + resolvableType = publisherEntity.getResolvableType(); + } } else { body = value; } - if (bodyType == null) { - bodyType = ResolvableType.forClass(body.getClass()); + if (resolvableType == null) { + resolvableType = ResolvableType.forClass(body.getClass()); } String filename = (body instanceof Resource ? ((Resource) body).getFilename() : null); @@ -251,7 +257,7 @@ public class MultipartHttpMessageWriter implements HttpMessageWriter> writer = this.partWriters.stream() .filter(partWriter -> partWriter.canWrite(finalBodyType, contentType)) .findFirst(); @@ -264,7 +270,7 @@ public class MultipartHttpMessageWriter implements HttpMessageWriter) body : Mono.just(body); Mono partWritten = ((HttpMessageWriter) writer.get()) - .write(bodyPublisher, bodyType, contentType, outputMessage, Collections.emptyMap()); + .write(bodyPublisher, resolvableType, contentType, outputMessage, Collections.emptyMap()); // partWritten.subscribe() is required in order to make sure MultipartHttpOutputMessage#getBody() // returns a non-null value (occurs with ResourceHttpMessageWriter that invokes diff --git a/spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java b/spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java index ed42d016c61..0fc970b8f9c 100644 --- a/spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java @@ -20,6 +20,7 @@ import org.junit.Test; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; @@ -49,11 +50,12 @@ public class MultipartBodyBuilderTests { builder.part("key", multipartData).header("foo", "bar"); builder.part("logo", logo).header("baz", "qux"); builder.part("entity", entity).header("baz", "qux"); - builder.asyncPart("publisher", publisher, String.class).header("baz", "qux"); + builder.asyncPart("publisherClass", publisher, String.class).header("baz", "qux"); + builder.asyncPart("publisherPtr", publisher, new ParameterizedTypeReference() {}).header("baz", "qux"); MultiValueMap> result = builder.build(); - assertEquals(4, result.size()); + assertEquals(5, result.size()); assertNotNull(result.getFirst("key")); assertEquals(multipartData, result.getFirst("key").getBody()); assertEquals("bar", result.getFirst("key").getHeaders().getFirst("foo")); @@ -67,11 +69,15 @@ public class MultipartBodyBuilderTests { assertEquals("bar", result.getFirst("entity").getHeaders().getFirst("foo")); assertEquals("qux", result.getFirst("entity").getHeaders().getFirst("baz")); - assertNotNull(result.getFirst("publisher")); - assertEquals(publisher, result.getFirst("publisher").getBody()); - assertEquals(ResolvableType.forClass(String.class), result.getFirst("publisher").getBodyType()); - assertEquals("bar", result.getFirst("entity").getHeaders().getFirst("foo")); - assertEquals("qux", result.getFirst("entity").getHeaders().getFirst("baz")); + assertNotNull(result.getFirst("publisherClass")); + assertEquals(publisher, result.getFirst("publisherClass").getBody()); + assertEquals(ResolvableType.forClass(String.class), ((MultipartBodyBuilder.PublisherEntity) result.getFirst("publisherClass")).getResolvableType()); + assertEquals("qux", result.getFirst("publisherClass").getHeaders().getFirst("baz")); + + assertNotNull(result.getFirst("publisherPtr")); + assertEquals(publisher, result.getFirst("publisherPtr").getBody()); + assertEquals(ResolvableType.forClass(String.class), ((MultipartBodyBuilder.PublisherEntity) result.getFirst("publisherPtr")).getResolvableType()); + assertEquals("qux", result.getFirst("publisherPtr").getHeaders().getFirst("baz")); }