From 4555384528d2d1870da330ffc9043da7555a4680 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Mon, 1 Jul 2024 14:51:46 +0200 Subject: [PATCH] Introduce SmartHttpMessageConverter SmartHttpMessageConverter is similar to GenericHttpMessageConverter, but more consistent with WebFlux Encoder and Decoder contracts, with the following differences: - A ResolvableType parameter is used instead of the Type one - The MethodParameter can be retrieved via the ResolvableType source - No contextClass parameter - `@Nullable Map hints` additional parameter for write and read methods This commit also refines RestTemplate#canReadResponse in order to use the most specific converter contract when possible. Closes gh-33118 --- .../AbstractSmartHttpMessageConverter.java | 148 ++++++++++++++++++ .../GenericHttpMessageConverter.java | 3 +- .../http/converter/HttpMessageConverter.java | 1 + .../converter/SmartHttpMessageConverter.java | 132 ++++++++++++++++ .../web/client/DefaultRestClient.java | 28 +++- .../client/HttpMessageConverterExtractor.java | 18 ++- .../web/client/RestTemplate.java | 25 ++- .../HttpMessageConverterExtractorTests.java | 23 +++ .../web/client/RestTemplateTests.java | 37 +++++ .../DefaultEntityResponseBuilder.java | 11 +- .../function/DefaultServerRequest.java | 9 +- .../function/DefaultServerRequestBuilder.java | 9 +- ...essageConverterMethodArgumentResolver.java | 66 ++++++-- ...stractMessageConverterMethodProcessor.java | 51 ++++-- 14 files changed, 513 insertions(+), 48 deletions(-) create mode 100644 spring-web/src/main/java/org/springframework/http/converter/AbstractSmartHttpMessageConverter.java create mode 100644 spring-web/src/main/java/org/springframework/http/converter/SmartHttpMessageConverter.java diff --git a/spring-web/src/main/java/org/springframework/http/converter/AbstractSmartHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/AbstractSmartHttpMessageConverter.java new file mode 100644 index 00000000000..bcd9494ac7b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/AbstractSmartHttpMessageConverter.java @@ -0,0 +1,148 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.Map; + +import org.springframework.core.ResolvableType; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.lang.Nullable; + +/** + * Abstract base class for most {@link SmartHttpMessageConverter} implementations. + * + * @author Sebastien Deleuze + * @since 6.2 + * @param the converted object type + */ +public abstract class AbstractSmartHttpMessageConverter extends AbstractHttpMessageConverter + implements SmartHttpMessageConverter { + + /** + * Construct an {@code AbstractSmartHttpMessageConverter} with no supported media types. + * @see #setSupportedMediaTypes + */ + protected AbstractSmartHttpMessageConverter() { + } + + /** + * Construct an {@code AbstractSmartHttpMessageConverter} with one supported media type. + * @param supportedMediaType the supported media type + */ + protected AbstractSmartHttpMessageConverter(MediaType supportedMediaType) { + super(supportedMediaType); + } + + /** + * Construct an {@code AbstractSmartHttpMessageConverter} with multiple supported media type. + * @param supportedMediaTypes the supported media types + */ + protected AbstractSmartHttpMessageConverter(MediaType... supportedMediaTypes) { + super(supportedMediaTypes); + } + + + @Override + protected boolean supports(Class clazz) { + return true; + } + + @Override + public boolean canRead(ResolvableType type, @Nullable MediaType mediaType) { + Class clazz = type.resolve(); + return (clazz != null ? canRead(clazz, mediaType) : canRead(mediaType)); + } + + @Override + public boolean canWrite(ResolvableType type, Class clazz, @Nullable MediaType mediaType) { + return canWrite(clazz, mediaType); + } + + /** + * This implementation sets the default headers by calling {@link #addDefaultHeaders}, + * and then calls {@link #writeInternal}. + */ + @Override + public final void write(T t, ResolvableType type, @Nullable MediaType contentType, + HttpOutputMessage outputMessage, @Nullable Map hints) + throws IOException, HttpMessageNotWritableException { + + HttpHeaders headers = outputMessage.getHeaders(); + addDefaultHeaders(headers, t, contentType); + + if (outputMessage instanceof StreamingHttpOutputMessage streamingOutputMessage) { + streamingOutputMessage.setBody(new StreamingHttpOutputMessage.Body() { + @Override + public void writeTo(OutputStream outputStream) throws IOException { + writeInternal(t, type, new HttpOutputMessage() { + @Override + public OutputStream getBody() { + return outputStream; + } + + @Override + public HttpHeaders getHeaders() { + return headers; + } + }, hints); + } + + @Override + public boolean repeatable() { + return supportsRepeatableWrites(t); + } + }); + } + else { + writeInternal(t, type, outputMessage, hints); + outputMessage.getBody().flush(); + } + } + + @Override + protected void writeInternal(T t, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + + writeInternal(t, ResolvableType.NONE, outputMessage, null); + } + + /** + * Abstract template method that writes the actual body. Invoked from + * {@link #write(Object, ResolvableType, MediaType, HttpOutputMessage, Map)}. + * @param t the object to write to the output message + * @param type the type of object to write + * @param outputMessage the HTTP output message to write to + * @param hints additional information about how to encode + * @throws IOException in case of I/O errors + * @throws HttpMessageNotWritableException in case of conversion errors + */ + protected abstract void writeInternal(T t, ResolvableType type, HttpOutputMessage outputMessage, + @Nullable Map hints) throws IOException, HttpMessageNotWritableException; + + @Override + protected T readInternal(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + return read(ResolvableType.forClass(clazz), inputMessage, null); + } +} diff --git a/spring-web/src/main/java/org/springframework/http/converter/GenericHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/GenericHttpMessageConverter.java index 905434e6e03..760c92ace65 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/GenericHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/GenericHttpMessageConverter.java @@ -35,6 +35,7 @@ import org.springframework.lang.Nullable; * @since 3.2 * @param the converted object type * @see org.springframework.core.ParameterizedTypeReference + * @see SmartHttpMessageConverter */ public interface GenericHttpMessageConverter extends HttpMessageConverter { @@ -53,7 +54,7 @@ public interface GenericHttpMessageConverter extends HttpMessageConverter boolean canRead(Type type, @Nullable Class contextClass, @Nullable MediaType mediaType); /** - * Read an object of the given type form the given input message, and returns it. + * Read an object of the given type from the given input message, and returns it. * @param type the (potentially generic) type of object to return. This type must have * previously been passed to the {@link #canRead canRead} method of this interface, * which must have returned {@code true}. diff --git a/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverter.java index 306a1b93a8f..5c0ed6fc659 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/HttpMessageConverter.java @@ -33,6 +33,7 @@ import org.springframework.lang.Nullable; * @author Rossen Stoyanchev * @since 3.0 * @param the converted object type + * @see SmartHttpMessageConverter */ public interface HttpMessageConverter { diff --git a/spring-web/src/main/java/org/springframework/http/converter/SmartHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/SmartHttpMessageConverter.java new file mode 100644 index 00000000000..98b1d44448b --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/converter/SmartHttpMessageConverter.java @@ -0,0 +1,132 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.converter; + +import java.io.IOException; +import java.util.Map; + +import org.springframework.core.ResolvableType; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; + +/** + * A specialization of {@link HttpMessageConverter} that can convert an HTTP request + * into a target object of a specified {@link ResolvableType} and a source object of + * a specified {@link ResolvableType} into an HTTP response with optional hints. + * + *

It provides default methods for {@link HttpMessageConverter} in order to allow + * subclasses to only have to implement the smart APIs. + * + * @author Sebastien Deleuze + * @since 6.2 + * @param the converted object type + */ +public interface SmartHttpMessageConverter extends HttpMessageConverter { + + /** + * Indicates whether the given type can be read by this converter. + * This method should perform the same checks as + * {@link HttpMessageConverter#canRead(Class, MediaType)} with additional ones + * related to the generic type. + * @param type the (potentially generic) type to test for readability. The + * {@linkplain ResolvableType#getSource() type source} may be used for retrieving + * additional information (the related method signature for example) when relevant. + * @param mediaType the media type to read, can be {@code null} if not specified. + * Typically, the value of a {@code Content-Type} header. + * @return {@code true} if readable; {@code false} otherwise + */ + boolean canRead(ResolvableType type, @Nullable MediaType mediaType); + + @Override + default boolean canRead(Class clazz, @Nullable MediaType mediaType) { + return canRead(ResolvableType.forClass(clazz), mediaType); + } + + /** + * Read an object of the given type from the given input message, and returns it. + * @param type the (potentially generic) type of object to return. This type must have + * previously been passed to the {@link #canRead(ResolvableType, MediaType) canRead} + * method of this interface, which must have returned {@code true}. The + * {@linkplain ResolvableType#getSource() type source} may be used for retrieving + * additional information (the related method signature for example) when relevant. + * @param inputMessage the HTTP input message to read from + * @param hints additional information about how to encode + * @return the converted object + * @throws IOException in case of I/O errors + * @throws HttpMessageNotReadableException in case of conversion errors + */ + T read(ResolvableType type, HttpInputMessage inputMessage, @Nullable Map hints) + throws IOException, HttpMessageNotReadableException; + + @Override + default T read(Class clazz, HttpInputMessage inputMessage) + throws IOException, HttpMessageNotReadableException { + + return read(ResolvableType.forClass(clazz), inputMessage, null); + } + + /** + * Indicates whether the given class can be written by this converter. + *

This method should perform the same checks as + * {@link HttpMessageConverter#canWrite(Class, MediaType)} with additional ones + * related to the generic type. + * @param targetType the (potentially generic) target type to test for writability + * (can be {@link ResolvableType#NONE} if not specified). The {@linkplain ResolvableType#getSource() type source} + * may be used for retrieving additional information (the related method signature for example) when relevant. + * @param valueClass the source object class to test for writability + * @param mediaType the media type to write (can be {@code null} if not specified); + * typically the value of an {@code Accept} header. + * @return {@code true} if writable; {@code false} otherwise + */ + boolean canWrite(ResolvableType targetType, Class valueClass, @Nullable MediaType mediaType); + + @Override + default boolean canWrite(Class clazz, @Nullable MediaType mediaType) { + return canWrite(ResolvableType.forClass(clazz), clazz, mediaType); + } + + /** + * Write a given object to the given output message. + * @param t the object to write to the output message. The type of this object must + * have previously been passed to the {@link #canWrite canWrite} method of this + * interface, which must have returned {@code true}. + * @param type the (potentially generic) type of object to write. This type must have + * previously been passed to the {@link #canWrite canWrite} method of this interface, + * which must have returned {@code true}. Can be {@link ResolvableType#NONE} if not specified. + * The {@linkplain ResolvableType#getSource() type source} may be used for retrieving additional + * information (the related method signature for example) when relevant. + * @param contentType the content type to use when writing. May be {@code null} to + * indicate that the default content type of the converter must be used. If not + * {@code null}, this media type must have previously been passed to the + * {@link #canWrite canWrite} method of this interface, which must have returned + * {@code true}. + * @param outputMessage the message to write to + * @param hints additional information about how to encode + * @throws IOException in case of I/O errors + * @throws HttpMessageNotWritableException in case of conversion errors + */ + void write(T t, ResolvableType type, @Nullable MediaType contentType, HttpOutputMessage outputMessage, + @Nullable Map hints) throws IOException, HttpMessageNotWritableException; + + @Override + default void write(T t, @Nullable MediaType contentType, HttpOutputMessage outputMessage) + throws IOException, HttpMessageNotWritableException { + write(t, ResolvableType.forInstance(t), contentType, outputMessage, null); + } +} diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index c29858d2aa8..39a9585bdb8 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -60,6 +60,7 @@ import org.springframework.http.client.observation.DefaultClientRequestObservati import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.SmartHttpMessageConverter; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -212,15 +213,24 @@ final class DefaultRestClient implements RestClient { } for (HttpMessageConverter messageConverter : this.messageConverters) { - if (messageConverter instanceof GenericHttpMessageConverter genericHttpMessageConverter) { - if (genericHttpMessageConverter.canRead(bodyType, null, contentType)) { + if (messageConverter instanceof GenericHttpMessageConverter genericMessageConverter) { + if (genericMessageConverter.canRead(bodyType, null, contentType)) { if (logger.isDebugEnabled()) { logger.debug("Reading to [" + ResolvableType.forType(bodyType) + "]"); } - return (T) genericHttpMessageConverter.read(bodyType, null, responseWrapper); + return (T) genericMessageConverter.read(bodyType, null, responseWrapper); + } + } + else if (messageConverter instanceof SmartHttpMessageConverter smartMessageConverter) { + ResolvableType resolvableType = ResolvableType.forType(bodyType); + if (smartMessageConverter.canRead(resolvableType, contentType)) { + if (logger.isDebugEnabled()) { + logger.debug("Reading to [" + resolvableType + "]"); + } + return (T) smartMessageConverter.read(resolvableType, responseWrapper, null); } } - if (messageConverter.canRead(bodyClass, contentType)) { + else if (messageConverter.canRead(bodyClass, contentType)) { if (logger.isDebugEnabled()) { logger.debug("Reading to [" + bodyClass.getName() + "] as \"" + contentType + "\""); } @@ -453,7 +463,15 @@ final class DefaultRestClient implements RestClient { return; } } - if (messageConverter.canWrite(bodyClass, contentType)) { + else if (messageConverter instanceof SmartHttpMessageConverter smartMessageConverter) { + ResolvableType resolvableType = ResolvableType.forType(bodyType); + if (smartMessageConverter.canWrite(resolvableType, bodyClass, contentType)) { + logBody(body, contentType, smartMessageConverter); + smartMessageConverter.write(body, resolvableType, contentType, clientRequest, null); + return; + } + } + else if (messageConverter.canWrite(bodyClass, contentType)) { logBody(body, contentType, messageConverter); messageConverter.write(body, contentType, clientRequest); return; diff --git a/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java b/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java index 0efb72375d4..3f79cfa7ece 100644 --- a/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java +++ b/spring-web/src/main/java/org/springframework/web/client/HttpMessageConverterExtractor.java @@ -29,6 +29,7 @@ import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.SmartHttpMessageConverter; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.FileCopyUtils; @@ -104,14 +105,21 @@ public class HttpMessageConverterExtractor implements ResponseExtractor { return (T) genericMessageConverter.read(this.responseType, null, responseWrapper); } } - if (this.responseClass != null) { - if (messageConverter.canRead(this.responseClass, contentType)) { + else if (messageConverter instanceof SmartHttpMessageConverter smartMessageConverter) { + ResolvableType resolvableType = ResolvableType.forType(this.responseType); + if (smartMessageConverter.canRead(resolvableType, contentType)) { if (logger.isDebugEnabled()) { - String className = this.responseClass.getName(); - logger.debug("Reading to [" + className + "] as \"" + contentType + "\""); + logger.debug("Reading to [" + resolvableType + "]"); } - return (T) messageConverter.read((Class) this.responseClass, responseWrapper); + return (T) smartMessageConverter.read(resolvableType, responseWrapper, null); + } + } + else if (this.responseClass != null && messageConverter.canRead(this.responseClass, contentType)) { + if (logger.isDebugEnabled()) { + String className = this.responseClass.getName(); + logger.debug("Reading to [" + className + "] as \"" + contentType + "\""); } + return (T) messageConverter.read((Class) this.responseClass, responseWrapper); } } } diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java index 343b42c5f2a..e9784a5e434 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -33,6 +33,7 @@ import io.micrometer.observation.ObservationConvention; import io.micrometer.observation.ObservationRegistry; import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -52,6 +53,7 @@ import org.springframework.http.converter.ByteArrayHttpMessageConverter; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.ResourceHttpMessageConverter; +import org.springframework.http.converter.SmartHttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.http.converter.cbor.KotlinSerializationCborHttpMessageConverter; import org.springframework.http.converter.cbor.MappingJackson2CborHttpMessageConverter; @@ -1030,13 +1032,15 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat } private boolean canReadResponse(Type responseType, HttpMessageConverter converter) { - Class responseClass = (responseType instanceof Class clazz ? clazz : null); - if (responseClass != null) { - return converter.canRead(responseClass, null); - } - else if (converter instanceof GenericHttpMessageConverter genericConverter) { + if (converter instanceof GenericHttpMessageConverter genericConverter) { return genericConverter.canRead(responseType, null, null); } + else if (converter instanceof SmartHttpMessageConverter smartConverter) { + return smartConverter.canRead(ResolvableType.forType(responseType), null); + } + else if (responseType instanceof Class responseClass) { + return converter.canRead(responseClass, null); + } return false; } @@ -1114,6 +1118,17 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat return; } } + else if (messageConverter instanceof SmartHttpMessageConverter smartConverter) { + ResolvableType resolvableType = ResolvableType.forType(requestBodyType); + if (smartConverter.canWrite(resolvableType, requestBodyClass, requestContentType)) { + if (!requestHeaders.isEmpty()) { + requestHeaders.forEach((key, values) -> httpHeaders.put(key, new ArrayList<>(values))); + } + logBody(requestBody, requestContentType, smartConverter); + smartConverter.write(requestBody, resolvableType, requestContentType, httpRequest, null); + return; + } + } else if (messageConverter.canWrite(requestBodyClass, requestContentType)) { if (!requestHeaders.isEmpty()) { requestHeaders.forEach((key, values) -> httpHeaders.put(key, new ArrayList<>(values))); diff --git a/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java b/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java index d5de3c8b0c8..ba26b2109f9 100644 --- a/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/HttpMessageConverterExtractorTests.java @@ -25,6 +25,7 @@ import java.util.List; import org.junit.jupiter.api.Test; import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpStatus; @@ -33,12 +34,14 @@ import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.SmartHttpMessageConverter; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -166,6 +169,26 @@ class HttpMessageConverterExtractorTests { assertThat(result).isEqualTo(expected); } + @Test + void smartConverter() throws IOException { + responseHeaders.setContentType(contentType); + String expected = "Foo"; + ParameterizedTypeReference> reference = new ParameterizedTypeReference<>() {}; + ResolvableType resolvableType = ResolvableType.forType(reference.getType()); + + SmartHttpMessageConverter converter = mock(); + HttpMessageConverterExtractor extractor = new HttpMessageConverterExtractor>(resolvableType.getType(), List.of(converter)); + + given(response.getStatusCode()).willReturn(HttpStatus.OK); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(new ByteArrayInputStream(expected.getBytes())); + given(converter.canRead(resolvableType, contentType)).willReturn(true); + given(converter.read(eq(resolvableType), any(HttpInputMessage.class), isNull())).willReturn(expected); + + Object result = extractor.extractData(response); + assertThat(result).isEqualTo(expected); + } + @Test // SPR-13592 void converterThrowsIOException() throws IOException { responseHeaders.setContentType(contentType); diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java index b22d3056426..cf7f6fadd44 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -35,6 +35,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpInputMessage; @@ -50,6 +51,7 @@ import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.SmartHttpMessageConverter; import org.springframework.http.converter.json.KotlinSerializationJsonHttpMessageConverter; import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; import org.springframework.web.util.DefaultUriBuilderFactory; @@ -683,6 +685,41 @@ class RestTemplateTests { verify(response).close(); } + @Test + @SuppressWarnings("rawtypes") + void exchangeParameterizedTypeWithSmartConverter() throws Exception { + SmartHttpMessageConverter converter = mock(); + template.setMessageConverters(Collections.singletonList(converter)); + ParameterizedTypeReference> intList = new ParameterizedTypeReference<>() {}; + given(converter.canRead(ResolvableType.forType(intList.getType()), null)).willReturn(true); + given(converter.getSupportedMediaTypes(any())).willReturn(Collections.singletonList(MediaType.TEXT_PLAIN)); + given(converter.canWrite(ResolvableType.forClass(String.class), String.class, null)).willReturn(true); + + HttpHeaders requestHeaders = new HttpHeaders(); + mockSentRequest(POST, "https://example.com", requestHeaders); + List expected = Collections.singletonList(42); + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(MediaType.TEXT_PLAIN); + responseHeaders.setContentLength(10); + mockResponseStatus(HttpStatus.OK); + given(response.getHeaders()).willReturn(responseHeaders); + given(response.getBody()).willReturn(new ByteArrayInputStream(Integer.toString(42).getBytes())); + given(converter.canRead(ResolvableType.forType(intList.getType()), MediaType.TEXT_PLAIN)).willReturn(true); + given(converter.read(eq(ResolvableType.forType(intList.getType())), any(HttpInputMessage.class), eq(null))).willReturn(expected); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.set("MyHeader", "MyValue"); + HttpEntity requestEntity = new HttpEntity<>("Hello World", entityHeaders); + ResponseEntity> result = template.exchange("https://example.com", POST, requestEntity, intList); + assertThat(result.getBody()).as("Invalid POST result").isEqualTo(expected); + assertThat(result.getHeaders().getContentType()).as("Invalid Content-Type").isEqualTo(MediaType.TEXT_PLAIN); + assertThat(requestHeaders.getFirst("Accept")).as("Invalid Accept header").isEqualTo(MediaType.TEXT_PLAIN_VALUE); + assertThat(requestHeaders.getFirst("MyHeader")).as("Invalid custom header").isEqualTo("MyValue"); + assertThat(result.getStatusCode()).as("Invalid status code").isEqualTo(HttpStatus.OK); + + verify(response).close(); + } + @Test // SPR-15066 void requestInterceptorCanAddExistingHeaderValueWithoutBody() throws Exception { ClientHttpRequestInterceptor interceptor = (request, body, execution) -> { diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java index 2ab3f6e043f..135792c052f 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java @@ -41,6 +41,7 @@ import org.reactivestreams.Subscription; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ReactiveAdapter; import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.ResolvableType; import org.springframework.core.io.InputStreamResource; import org.springframework.core.io.Resource; import org.springframework.core.io.support.ResourceRegion; @@ -54,6 +55,7 @@ import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.MediaType; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.SmartHttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -309,7 +311,14 @@ final class DefaultEntityResponseBuilder implements EntityResponse.Builder return; } } - if (messageConverter.canWrite(entityClass, contentType)) { + else if (messageConverter instanceof SmartHttpMessageConverter smartMessageConverter) { + ResolvableType resolvableType = ResolvableType.forType(entityType); + if (smartMessageConverter.canWrite(resolvableType, entityClass, contentType)) { + smartMessageConverter.write(entity, resolvableType, contentType, serverResponse, null); + return; + } + } + else if (messageConverter.canWrite(entityClass, contentType)) { ((HttpMessageConverter) messageConverter).write(entity, contentType, serverResponse); return; } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java index f4d57f6ace3..6b073a707c2 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java @@ -58,6 +58,7 @@ import org.springframework.http.HttpRange; import org.springframework.http.MediaType; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.SmartHttpMessageConverter; import org.springframework.http.server.RequestPath; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.lang.Nullable; @@ -210,7 +211,13 @@ class DefaultServerRequest implements ServerRequest { return (T) genericMessageConverter.read(bodyType, bodyClass, this.serverHttpRequest); } } - if (messageConverter.canRead(bodyClass, contentType)) { + else if (messageConverter instanceof SmartHttpMessageConverter smartMessageConverter) { + ResolvableType resolvableType = ResolvableType.forType(bodyType); + if (smartMessageConverter.canRead(resolvableType, contentType)) { + return (T) smartMessageConverter.read(resolvableType, this.serverHttpRequest, null); + } + } + else if (messageConverter.canRead(bodyClass, contentType)) { HttpMessageConverter theConverter = (HttpMessageConverter) messageConverter; Class clazz = (Class) bodyClass; diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java index 2705b712bc1..04340081177 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequestBuilder.java @@ -49,6 +49,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.SmartHttpMessageConverter; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; @@ -318,7 +319,13 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder { return (T) genericMessageConverter.read(bodyType, bodyClass, inputMessage); } } - if (messageConverter.canRead(bodyClass, contentType)) { + else if (messageConverter instanceof SmartHttpMessageConverter smartMessageConverter) { + ResolvableType resolvableType = ResolvableType.forType(bodyType); + if (smartMessageConverter.canRead(resolvableType, contentType)) { + return (T) smartMessageConverter.read(resolvableType, inputMessage, null); + } + } + else if (messageConverter.canRead(bodyClass, contentType)) { HttpMessageConverter theConverter = (HttpMessageConverter) messageConverter; Class clazz = (Class) bodyClass; diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java index f1d65922816..93cb61212cc 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodArgumentResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,6 +36,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; import org.springframework.core.log.LogFormatUtils; +import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpMethod; @@ -46,6 +47,7 @@ import org.springframework.http.MediaType; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.SmartHttpMessageConverter; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -64,6 +66,7 @@ import org.springframework.web.method.support.HandlerMethodArgumentResolver; * @author Arjen Poutsma * @author Rossen Stoyanchev * @author Juergen Hoeller + * @author Sebastien Deleuze * @since 3.1 */ public abstract class AbstractMessageConverterMethodArgumentResolver implements HandlerMethodArgumentResolver { @@ -77,6 +80,8 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements protected final List> messageConverters; + protected enum ConverterType { BASE, GENERIC, SMART }; + private final RequestResponseBodyAdviceChain advice; @@ -99,7 +104,6 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements this.advice = new RequestResponseBodyAdviceChain(requestResponseBodyAdvice); } - /** * Return the configured {@link RequestBodyAdvice} and * {@link RequestBodyAdvice} where each instance may be wrapped as a @@ -147,8 +151,8 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements Class contextClass = parameter.getContainingClass(); Class targetClass = (targetType instanceof Class clazz ? clazz : null); + ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); if (targetClass == null) { - ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); targetClass = (Class) resolvableType.resolve(); } @@ -171,26 +175,46 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements EmptyBodyCheckingHttpInputMessage message = null; try { + ResolvableType targetResolvableType = null; message = new EmptyBodyCheckingHttpInputMessage(inputMessage); for (HttpMessageConverter converter : this.messageConverters) { - Class> converterType = (Class>) converter.getClass(); - GenericHttpMessageConverter genericConverter = - (converter instanceof GenericHttpMessageConverter ghmc ? ghmc : null); - if (genericConverter != null ? genericConverter.canRead(targetType, contextClass, contentType) : - (targetClass != null && converter.canRead(targetClass, contentType))) { + Class> converterClass = (Class>) converter.getClass(); + ConverterType converterTypeToUse = null; + if (converter instanceof GenericHttpMessageConverter genericConverter) { + if (genericConverter.canRead(targetType, contextClass, contentType)) { + converterTypeToUse = ConverterType.GENERIC; + } + } + else if (converter instanceof SmartHttpMessageConverter smartConverter) { + if (targetResolvableType == null) { + targetResolvableType = getNestedTypeIfNeeded(resolvableType); + } + if (smartConverter.canRead(targetResolvableType, contentType)) { + converterTypeToUse = ConverterType.SMART; + } + } + else if (targetClass != null && converter.canRead(targetClass, contentType)) { + converterTypeToUse = ConverterType.BASE; + } + if (converterTypeToUse != null) { if (message.hasBody()) { HttpInputMessage msgToUse = - getAdvice().beforeBodyRead(message, parameter, targetType, converterType); - body = (genericConverter != null ? genericConverter.read(targetType, contextClass, msgToUse) : - ((HttpMessageConverter) converter).read(targetClass, msgToUse)); - body = getAdvice().afterBodyRead(body, msgToUse, parameter, targetType, converterType); + getAdvice().beforeBodyRead(message, parameter, targetType, converterClass); + body = switch (converterTypeToUse) { + case BASE -> ((HttpMessageConverter) converter).read(targetClass, msgToUse); + case GENERIC -> ((GenericHttpMessageConverter) converter).read(targetType, contextClass, msgToUse); + case SMART -> ((SmartHttpMessageConverter) converter).read(targetResolvableType, msgToUse, null); + }; + body = getAdvice().afterBodyRead(body, msgToUse, parameter, targetType, converterClass); } else { - body = getAdvice().handleEmptyBody(null, message, parameter, targetType, converterType); + body = getAdvice().handleEmptyBody(null, message, parameter, targetType, converterClass); } break; } + } + if (body == NO_VALUE && noContentType && !message.hasBody()) { body = getAdvice().handleEmptyBody( null, message, parameter, targetType, NoContentTypeHttpMessageConverter.class); @@ -223,6 +247,22 @@ public abstract class AbstractMessageConverterMethodArgumentResolver implements return body; } + /** + * Return the generic type of the {@code returnType} (or of the nested type + * if it is an {@link HttpEntity} or/and an {@link Optional}). + */ + protected ResolvableType getNestedTypeIfNeeded(ResolvableType type) { + ResolvableType genericType = type; + if (Optional.class.isAssignableFrom(genericType.toClass())) { + genericType = genericType.getNested(2); + } + if (HttpEntity.class.isAssignableFrom(genericType.toClass())) { + genericType = genericType.getNested(2); + } + return genericType; + } + + /** * Create a new {@link HttpInputMessage} from the given {@link NativeWebRequest}. * @param webRequest the web request to create an input message from diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodProcessor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodProcessor.java index 0bf9f42c10b..b7f06e87d82 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodProcessor.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodProcessor.java @@ -50,6 +50,7 @@ import org.springframework.http.ProblemDetail; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.http.converter.SmartHttpMessageConverter; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.lang.Nullable; @@ -74,6 +75,7 @@ import org.springframework.web.util.UrlPathHelper; * @author Rossen Stoyanchev * @author Brian Clozel * @author Juergen Hoeller + * @author Sebastien Deleuze * @since 3.1 */ public abstract class AbstractMessageConverterMethodProcessor extends AbstractMessageConverterMethodArgumentResolver @@ -202,7 +204,7 @@ public abstract class AbstractMessageConverterMethodProcessor extends AbstractMe * be written by a converter, or if the content-type chosen by the server * has no compatible converter. */ - @SuppressWarnings({"rawtypes", "unchecked"}) + @SuppressWarnings({"rawtypes", "unchecked", "NullAway"}) protected void writeWithMessageConverters(@Nullable T value, MethodParameter returnType, ServletServerHttpRequest inputMessage, ServletServerHttpResponse outputMessage) throws IOException, HttpMediaTypeNotAcceptableException, HttpMessageNotWritableException { @@ -312,25 +314,36 @@ public abstract class AbstractMessageConverterMethodProcessor extends AbstractMe if (selectedMediaType != null) { selectedMediaType = selectedMediaType.removeQualityValue(); - for (HttpMessageConverter converter : this.messageConverters) { - GenericHttpMessageConverter genericConverter = - (converter instanceof GenericHttpMessageConverter ghmc ? ghmc : null); - if (genericConverter != null ? - ((GenericHttpMessageConverter) converter).canWrite(targetType, valueType, selectedMediaType) : - converter.canWrite(valueType, selectedMediaType)) { + + ResolvableType targetResolvableType = null; + for (HttpMessageConverter converter : this.messageConverters) { + ConverterType converterTypeToUse = null; + if (converter instanceof GenericHttpMessageConverter genericConverter) { + if (genericConverter.canWrite(targetType, valueType, selectedMediaType)) { + converterTypeToUse = ConverterType.GENERIC; + } + } + else if (converter instanceof SmartHttpMessageConverter smartConverter) { + targetResolvableType = getNestedTypeIfNeeded(ResolvableType.forMethodParameter(returnType)); + if (smartConverter.canWrite(targetResolvableType, valueType, selectedMediaType)) { + converterTypeToUse = ConverterType.SMART; + } + } + else if (converter.canWrite(valueType, selectedMediaType)){ + converterTypeToUse = ConverterType.BASE; + } + if (converterTypeToUse != null) { body = getAdvice().beforeBodyWrite(body, returnType, selectedMediaType, - (Class>) converter.getClass(), - inputMessage, outputMessage); + (Class>) converter.getClass(), inputMessage, outputMessage); if (body != null) { Object theBody = body; LogFormatUtils.traceDebug(logger, traceOn -> "Writing [" + LogFormatUtils.formatValue(theBody, !traceOn) + "]"); addContentDispositionHeader(inputMessage, outputMessage); - if (genericConverter != null) { - genericConverter.write(body, targetType, selectedMediaType, outputMessage); - } - else { - ((HttpMessageConverter) converter).write(body, selectedMediaType, outputMessage); + switch (converterTypeToUse) { + case BASE -> converter.write(body, selectedMediaType, outputMessage); + case GENERIC -> ((GenericHttpMessageConverter) converter).write(body, targetType, selectedMediaType, outputMessage); + case SMART -> ((SmartHttpMessageConverter) converter).write(body, targetResolvableType, selectedMediaType, outputMessage, null); } } else { @@ -416,8 +429,14 @@ public abstract class AbstractMessageConverterMethodProcessor extends AbstractMe } Set result = new LinkedHashSet<>(); for (HttpMessageConverter converter : this.messageConverters) { - if (converter instanceof GenericHttpMessageConverter ghmc && targetType != null) { - if (ghmc.canWrite(targetType, valueClass, null)) { + if (converter instanceof GenericHttpMessageConverter genericConverter && targetType != null) { + if (genericConverter.canWrite(targetType, valueClass, null)) { + result.addAll(converter.getSupportedMediaTypes(valueClass)); + } + } + else if (converter instanceof SmartHttpMessageConverter smartConverter && targetType != null) { + ResolvableType resolvableType = ResolvableType.forType(targetType); + if (smartConverter.canWrite(resolvableType, valueClass, null)) { result.addAll(converter.getSupportedMediaTypes(valueClass)); } }