diff --git a/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceProxyFactory.java b/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceProxyFactory.java index a442fa58e48..e4fcb5e4af8 100644 --- a/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceProxyFactory.java +++ b/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceProxyFactory.java @@ -187,9 +187,9 @@ public final class HttpServiceProxyFactory { private List initArgumentResolvers(ConversionService conversionService) { List resolvers = new ArrayList<>(this.customResolvers); - resolvers.add(new HttpMethodArgumentResolver()); - resolvers.add(new PathVariableArgumentResolver(conversionService)); resolvers.add(new RequestHeaderArgumentResolver(conversionService)); + resolvers.add(new PathVariableArgumentResolver(conversionService)); + resolvers.add(new HttpMethodArgumentResolver()); return resolvers; } diff --git a/spring-web/src/main/java/org/springframework/web/service/invoker/PathVariableArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/service/invoker/PathVariableArgumentResolver.java index 371439239ac..a9e6b5f79b6 100644 --- a/spring-web/src/main/java/org/springframework/web/service/invoker/PathVariableArgumentResolver.java +++ b/spring-web/src/main/java/org/springframework/web/service/invoker/PathVariableArgumentResolver.java @@ -63,28 +63,28 @@ public class PathVariableArgumentResolver implements HttpServiceArgumentResolver return false; } - if (Map.class.isAssignableFrom(parameter.getParameterType())) { + Class parameterType = parameter.getParameterType(); + boolean required = (annotation.required() && !Optional.class.isAssignableFrom(parameterType)); + + if (Map.class.isAssignableFrom(parameterType)) { if (argument != null) { Assert.isInstanceOf(Map.class, argument); - ((Map) argument).forEach((key, value) -> - addUriParameter(key, value, annotation.required(), requestValues)); + ((Map) argument).forEach((key, value) -> + addUriParameter(key, value, required, requestValues)); } } else { String name = StringUtils.hasText(annotation.value()) ? annotation.value() : annotation.name(); name = StringUtils.hasText(name) ? name : parameter.getParameterName(); Assert.notNull(name, "Failed to determine path variable name for parameter: " + parameter); - addUriParameter(name, argument, annotation.required(), requestValues); + addUriParameter(name, argument, required, requestValues); } return true; } private void addUriParameter( - Object name, @Nullable Object value, boolean required, HttpRequestValues.Builder requestValues) { - - String stringName = this.conversionService.convert(name, String.class); - Assert.notNull(stringName, "Missing path variable name"); + String name, @Nullable Object value, boolean required, HttpRequestValues.Builder requestValues) { if (value instanceof Optional) { value = ((Optional) value).orElse(null); @@ -95,15 +95,15 @@ public class PathVariableArgumentResolver implements HttpServiceArgumentResolver } if (value == null) { - Assert.isTrue(!required, "Missing required path variable '" + stringName + "'"); + Assert.isTrue(!required, "Missing required path variable '" + name + "'"); return; } if (logger.isTraceEnabled()) { - logger.trace("Resolved path variable '" + stringName + "' to " + value); + logger.trace("Resolved path variable '" + name + "' to " + value); } - requestValues.setUriVariable(stringName, (String) value); + requestValues.setUriVariable(name, (String) value); } } diff --git a/spring-web/src/main/java/org/springframework/web/service/invoker/RequestHeaderArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/service/invoker/RequestHeaderArgumentResolver.java index 98b9889d8cb..a00955c344e 100644 --- a/spring-web/src/main/java/org/springframework/web/service/invoker/RequestHeaderArgumentResolver.java +++ b/spring-web/src/main/java/org/springframework/web/service/invoker/RequestHeaderArgumentResolver.java @@ -19,9 +19,7 @@ package org.springframework.web.service.invoker; import java.util.Arrays; import java.util.Collection; import java.util.Map; -import java.util.Objects; import java.util.Optional; -import java.util.stream.Stream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,17 +28,26 @@ import org.springframework.core.MethodParameter; import org.springframework.core.convert.ConversionService; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; import org.springframework.web.bind.annotation.RequestHeader; import org.springframework.web.bind.annotation.ValueConstants; + /** - * An implementation of {@link HttpServiceArgumentResolver} that resolves - * request headers based on method arguments annotated - * with {@link RequestHeader}. {@code null} values are allowed only - * if {@link RequestHeader#required()} is {@code true}. {@code null} - * values are replaced with {@link RequestHeader#defaultValue()} if it - * is not equal to {@link ValueConstants#DEFAULT_NONE}. + * {@link HttpServiceArgumentResolver} for {@link RequestHeader @RequestHeader} + * annotated arguments. + * + *

The argument may be: + *

    + *
  • {@code Map} or {@link org.springframework.util.MultiValueMap} with + * multiple headers and value(s). + *
  • {@code Collection} or an array of header values. + *
  • An individual header value. + *
+ * + *

Individual header values may be Strings or Objects to be converted to + * String values through the configured {@link ConversionService}. * * @author Olga Maciaszek-Sharma * @since 6.0 @@ -51,89 +58,91 @@ public class RequestHeaderArgumentResolver implements HttpServiceArgumentResolve private final ConversionService conversionService; + public RequestHeaderArgumentResolver(ConversionService conversionService) { Assert.notNull(conversionService, "ConversionService is required"); this.conversionService = conversionService; } + + @SuppressWarnings("unchecked") @Override - public boolean resolve(@Nullable Object argument, MethodParameter parameter, - HttpRequestValues.Builder requestValues) { - RequestHeader annotation = parameter.getParameterAnnotation(RequestHeader.class); + public boolean resolve( + @Nullable Object argument, MethodParameter parameter, HttpRequestValues.Builder requestValues) { - if (annotation == null) { + RequestHeader annot = parameter.getParameterAnnotation(RequestHeader.class); + if (annot == null) { return false; } - if (Map.class.isAssignableFrom(parameter.getParameterType())) { + Class parameterType = parameter.getParameterType(); + boolean required = (annot.required() && !Optional.class.isAssignableFrom(parameterType)); + Object defaultValue = (ValueConstants.DEFAULT_NONE.equals(annot.defaultValue()) ? null : annot.defaultValue()); + + if (Map.class.isAssignableFrom(parameterType)) { if (argument != null) { Assert.isInstanceOf(Map.class, argument); - ((Map) argument).forEach((key, value) -> - addRequestHeader(key, value, annotation.required(), annotation.defaultValue(), - requestValues)); + ((Map) argument).forEach((key, value) -> + addHeader(key, value, false, defaultValue, requestValues)); } } else { - String name = StringUtils.hasText(annotation.value()) ? - annotation.value() : annotation.name(); + String name = StringUtils.hasText(annot.value()) ? annot.value() : annot.name(); name = StringUtils.hasText(name) ? name : parameter.getParameterName(); Assert.notNull(name, "Failed to determine request header name for parameter: " + parameter); - addRequestHeader(name, argument, annotation.required(), annotation.defaultValue(), - requestValues); + addHeader(name, argument, required, defaultValue, requestValues); } + return true; } - private void addRequestHeader( - Object name, @Nullable Object value, boolean required, String defaultValue, + private void addHeader( + String name, @Nullable Object value, boolean required, @Nullable Object defaultValue, HttpRequestValues.Builder requestValues) { - String stringName = this.conversionService.convert(name, String.class); - Assert.notNull(stringName, "Failed to convert request header name '" + - name + "' to String"); - - if (value instanceof Optional) { - value = ((Optional) value).orElse(null); - } - - if (value == null) { - if (!ValueConstants.DEFAULT_NONE.equals(defaultValue)) { - value = defaultValue; + value = (ObjectUtils.isArray(value) ? Arrays.asList((Object[]) value) : value); + if (value instanceof Collection elements) { + boolean hasValue = false; + for (Object element : elements) { + if (element != null) { + hasValue = true; + addHeaderValue(name, element, false, requestValues); + } } - else { - Assert.isTrue(!required, "Missing required request header '" + stringName + "'"); + if (hasValue) { return; } + value = null; } - String[] headerValues = toStringArray(value); + if (value instanceof Optional optionalValue) { + value = optionalValue.orElse(null); + } - if (logger.isTraceEnabled()) { - logger.trace("Resolved request header '" + stringName + "' to list of values: " + - String.join(", ", headerValues)); + if (value == null && defaultValue != null) { + value = defaultValue; } - requestValues.addHeader(stringName, headerValues); + addHeaderValue(name, value, required, requestValues); } - private String[] toStringArray(Object value) { - return toValueStream(value) - .filter(Objects::nonNull) - .map(headerElement -> headerElement instanceof String - ? (String) headerElement : - this.conversionService.convert(headerElement, String.class)) - .filter(Objects::nonNull) - .toArray(String[]::new); - } + private void addHeaderValue( + String name, @Nullable Object value, boolean required, HttpRequestValues.Builder requestValues) { - private Stream toValueStream(Object value) { - if (value instanceof Object[]) { - return Arrays.stream((Object[]) value); + if (!(value instanceof String)) { + value = this.conversionService.convert(value, String.class); } - if (value instanceof Collection) { - return ((Collection) value).stream(); + + if (value == null) { + Assert.isTrue(!required, "Missing required header '" + name + "'"); + return; } - return Stream.of(value); + + if (logger.isTraceEnabled()) { + logger.trace("Resolved header '" + name + ":" + value + "'"); + } + + requestValues.addHeader(name, (String) value); } } diff --git a/spring-web/src/test/java/org/springframework/web/service/invoker/PathVariableArgumentResolverTests.java b/spring-web/src/test/java/org/springframework/web/service/invoker/PathVariableArgumentResolverTests.java index d59deba6e7f..8db7873a553 100644 --- a/spring-web/src/test/java/org/springframework/web/service/invoker/PathVariableArgumentResolverTests.java +++ b/spring-web/src/test/java/org/springframework/web/service/invoker/PathVariableArgumentResolverTests.java @@ -135,12 +135,6 @@ class PathVariableArgumentResolverTests { .isThrownBy(() -> this.service.executeOptionalValueMap(Map.of("id", Optional.empty()))); } - @Test - void shouldResolvePathVariableNameFromObjectMapKey() { - this.service.executeValueMapWithObjectKey(Map.of(Boolean.TRUE, "true")); - assertPathVariable("true", "true"); - } - @SuppressWarnings("SameParameterValue") private void assertPathVariable(String name, @Nullable String expectedValue) { assertThat(getActualUriVariables().get(name)).isEqualTo(expectedValue); @@ -184,9 +178,6 @@ class PathVariableArgumentResolverTests { @GetExchange void executeValueMap(@Nullable @PathVariable Map map); - @GetExchange - void executeValueMapWithObjectKey(@Nullable @PathVariable Map map); - @GetExchange void executeOptionalValueMap(@PathVariable Map> map); } diff --git a/spring-web/src/test/java/org/springframework/web/service/invoker/RequestHeaderArgumentResolverTests.java b/spring-web/src/test/java/org/springframework/web/service/invoker/RequestHeaderArgumentResolverTests.java index d512b0c74ae..167a717dae4 100644 --- a/spring-web/src/test/java/org/springframework/web/service/invoker/RequestHeaderArgumentResolverTests.java +++ b/spring-web/src/test/java/org/springframework/web/service/invoker/RequestHeaderArgumentResolverTests.java @@ -31,8 +31,9 @@ import org.springframework.web.service.annotation.GetExchange; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + /** - * Tests for {@link RequestHeaderArgumentResolver}. + * Unit tests for {@link RequestHeaderArgumentResolver}. * * @author Olga Maciaszek-Sharma */ @@ -42,104 +43,94 @@ class RequestHeaderArgumentResolverTests { private final Service service = this.clientAdapter.createService(Service.class); - @Test - void shouldResolveSingleValueRequestHeader() { - this.service.executeString("test"); - assertRequestHeaders("id", "test"); - } @Test - void shouldResolveRequestHeaderWithNameFromAnnotationName() { - this.service.executeNamed("test"); + void stringHeader() { + this.service.executeString("test"); assertRequestHeaders("id", "test"); } @Test - void shouldResolveRequestHeaderNameFromValue() { - this.service.executeNamedWithValue("test"); - assertRequestHeaders("test", "test"); + void objectHeader() { + this.service.execute(Boolean.TRUE); + assertRequestHeaders("id", "true"); } @Test - void shouldResolveObjectValueRequestHeader() { - this.service.execute(Boolean.TRUE); - assertRequestHeaders("id", "true"); + void namedHeader() { + this.service.executeNamed("test"); + assertRequestHeaders("id", "test"); } @Test - void shouldResolveListRequestHeader() { + void listHeader() { this.service.execute(List.of("test1", Boolean.TRUE, "test3")); - assertRequestHeaders("id", "test1", "true", "test3"); + assertRequestHeaders("multiValueHeader", "test1", "true", "test3"); } @Test - void shouldResolveArrayRequestHeader() { + void arrayHeader() { this.service.execute("test1", Boolean.FALSE, "test3"); - assertRequestHeaders("id", "test1", "false", "test3"); + assertRequestHeaders("multiValueHeader", "test1", "false", "test3"); } @Test - void shouldResolveRequestHeadersFromMap() { - this.service.executeMap(Maps.of(Boolean.TRUE, "true", Boolean.FALSE, "false")); - assertRequestHeaders("true", "true"); - assertRequestHeaders("false", "false"); + void mapHeader() { + this.service.executeMap(Maps.of("header1", "true", "header2", "false")); + assertRequestHeaders("header1", "true"); + assertRequestHeaders("header2", "false"); } @Test - void shouldThrowExceptionWhenRequiredHeaderNull() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.service.executeString(null)); + void mapHeaderNull() { + this.service.executeMap(null); + assertThat(getActualHeaders()).isEmpty(); } @Test - void shouldIgnoreNullWhenHeaderNotRequired() { - this.service.executeNotRequired(null); - assertThat(getActualHeaders().get("id")).isNull(); + void mapWithOptional() { + this.service.executeOptionalMapValue(Map.of("id", Optional.of("test"))); + assertRequestHeaders("id", "test"); } @Test - void shouldIgnoreNullMapValue() { - this.service.executeMap(null); - assertThat(getActualHeaders()).isEmpty(); + void nullHeaderRequired() { + assertThatIllegalArgumentException().isThrownBy(() -> this.service.executeString(null)); } @Test - void shouldResolveRequestHeaderFromOptionalArgumentWithConversion() { - this.service.executeOptional(Optional.of(Boolean.TRUE)); - assertRequestHeaders("id", "true"); + void nullHeaderNotRequired() { + this.service.executeNotRequired(null); + assertThat(getActualHeaders().get("id")).isNull(); } + @Test - void shouldResolveRequestHeaderFromOptionalArgument() { + void optional() { this.service.executeOptional(Optional.of("test")); assertRequestHeaders("id", "test"); } @Test - void shouldThrowExceptionForEmptyOptional() { - assertThatIllegalArgumentException().isThrownBy(() -> this.service.execute(Optional.empty())); + void optionalWithConversion() { + this.service.executeOptional(Optional.of(Boolean.TRUE)); + assertRequestHeaders("id", "true"); } @Test - void shouldIgnoreEmptyOptionalWhenNotRequired() { - this.service.executeOptionalNotRequired(Optional.empty()); + void optionalEmpty() { + this.service.executeOptional(Optional.empty()); assertThat(getActualHeaders().get("id")).isNull(); } @Test - void shouldResolveRequestHeaderFromOptionalMapValue() { - this.service.executeOptionalMapValue(Map.of("id", Optional.of("test"))); - assertRequestHeaders("id", "test"); - } - - @Test - void shouldReplaceNullValueWithDefaultWhenAvailable() { + void defaultValueWithNull() { this.service.executeWithDefaultValue(null); assertRequestHeaders("id", "default"); } @Test - void shouldReplaceEmptyOptionalValueWithDefaultWhenAvailable() { + void defaultValueWithOptional() { this.service.executeOptionalWithDefaultValue(Optional.empty()); assertRequestHeaders("id", "default"); } @@ -152,47 +143,43 @@ class RequestHeaderArgumentResolverTests { return this.clientAdapter.getRequestValues().getHeaders(); } + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") private interface Service { @GetExchange void executeString(@Nullable @RequestHeader String id); - @GetExchange - void executeNotRequired(@Nullable @RequestHeader(required = false) String id); - @GetExchange void execute(@RequestHeader Object id); @GetExchange - void execute(@RequestHeader List id); + void executeNamed(@RequestHeader(name = "id") String employeeId); @GetExchange - void execute(@RequestHeader Object... id); + void execute(@RequestHeader List multiValueHeader); @GetExchange - void executeMap(@Nullable @RequestHeader Map id); + void execute(@RequestHeader Object... multiValueHeader); @GetExchange - void executeOptionalMapValue(@RequestHeader Map> id); + void executeMap(@Nullable @RequestHeader Map id); @GetExchange - void executeOptional(@RequestHeader Optional id); + void executeOptionalMapValue(@RequestHeader Map> headers); @GetExchange - void executeOptionalNotRequired(@RequestHeader(required = false) Optional id); - - @GetExchange - void executeNamedWithValue(@Nullable @RequestHeader(name = "id", value = "test") String employeeId); + void executeNotRequired(@Nullable @RequestHeader(required = false) String id); @GetExchange - void executeNamed(@RequestHeader(name = "id") String employeeId); + void executeOptional(@RequestHeader Optional id); @GetExchange void executeWithDefaultValue(@Nullable @RequestHeader(defaultValue = "default") String id); @GetExchange - void executeOptionalWithDefaultValue(@Nullable @RequestHeader(defaultValue = "default") Optional id); + void executeOptionalWithDefaultValue(@RequestHeader(defaultValue = "default") Optional id); + } }