From 66f33a82651b6cfaec8e85a5f5af9a8bc285d5ca Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Mon, 9 Dec 2024 15:21:55 +0000 Subject: [PATCH 1/6] MapMethodProcessor supportsParameter is more specific Closes gh-33160 --- .../method/annotation/MapMethodProcessor.java | 7 +++++-- .../annotation/MapMethodProcessorTests.java | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/MapMethodProcessor.java b/spring-web/src/main/java/org/springframework/web/method/annotation/MapMethodProcessor.java index ba47fe3d425..8f8cb85489c 100644 --- a/spring-web/src/main/java/org/springframework/web/method/annotation/MapMethodProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/MapMethodProcessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import java.util.Map; import org.springframework.core.MethodParameter; import org.springframework.lang.Nullable; +import org.springframework.ui.ModelMap; import org.springframework.util.Assert; import org.springframework.web.bind.support.WebDataBinderFactory; import org.springframework.web.context.request.NativeWebRequest; @@ -42,7 +43,9 @@ public class MapMethodProcessor implements HandlerMethodArgumentResolver, Handle @Override public boolean supportsParameter(MethodParameter parameter) { - return (Map.class.isAssignableFrom(parameter.getParameterType()) && + // We don't support any type of Map + Class type = parameter.getParameterType(); + return ((type.isAssignableFrom(Map.class) || ModelMap.class.isAssignableFrom(type)) && parameter.getParameterAnnotations().length == 0); } diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/MapMethodProcessorTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/MapMethodProcessorTests.java index 483f88d80f3..2f7c40dec9e 100644 --- a/spring-web/src/test/java/org/springframework/web/method/annotation/MapMethodProcessorTests.java +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/MapMethodProcessorTests.java @@ -16,6 +16,7 @@ package org.springframework.web.method.annotation; +import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.BeforeEach; @@ -63,8 +64,13 @@ class MapMethodProcessorTests { void supportsParameter() { assertThat(this.processor.supportsParameter( this.resolvable.annotNotPresent().arg(Map.class, String.class, Object.class))).isTrue(); + assertThat(this.processor.supportsParameter( this.resolvable.annotPresent(RequestBody.class).arg(Map.class, String.class, Object.class))).isFalse(); + + // gh-33160 + assertThat(this.processor.supportsParameter( + ResolvableMethod.on(getClass()).argTypes(ExtendedMap.class).build().arg(ExtendedMap.class))).isFalse(); } @Test @@ -100,4 +106,15 @@ class MapMethodProcessorTests { return null; } + + @SuppressWarnings("unused") + private Map handle(ExtendedMap extendedMap) { + return null; + } + + + @SuppressWarnings("serial") + private static final class ExtendedMap extends HashMap { + } + } From 640e5705831beed32fd5c0490ef19c194ed62a95 Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Tue, 10 Dec 2024 16:10:42 +0000 Subject: [PATCH 2/6] Minor refactoring in ServerSentEvent Extract re-usable method to serialize SSE fields. See gh-33975 --- .../http/codec/ServerSentEvent.java | 29 ++++++++++++++ .../ServerSentEventHttpMessageWriter.java | 38 ++++--------------- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEvent.java b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEvent.java index 8c988ee04f6..39752442789 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEvent.java +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEvent.java @@ -20,6 +20,7 @@ import java.time.Duration; import org.springframework.lang.Nullable; import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; /** * Representation for a Server-Sent Event for use with Spring's reactive Web support. @@ -102,6 +103,34 @@ public final class ServerSentEvent { return this.data; } + /** + * Return a StringBuilder with the id, event, retry, and comment fields fully + * serialized, and also appending "data:" if there is data. + * @since 6.2.1 + */ + public String format() { + StringBuilder sb = new StringBuilder(); + if (this.id != null) { + appendAttribute("id", this.id, sb); + } + if (this.event != null) { + appendAttribute("event", this.event, sb); + } + if (this.retry != null) { + appendAttribute("retry", this.retry.toMillis(), sb); + } + if (this.comment != null) { + sb.append(':').append(StringUtils.replace(this.comment, "\n", "\n:")).append('\n'); + } + if (this.data != null) { + sb.append("data:"); + } + return sb.toString(); + } + + private void appendAttribute(String fieldName, Object fieldValue, StringBuilder sb) { + sb.append(fieldName).append(':').append(fieldValue).append('\n'); + } @Override public boolean equals(@Nullable Object other) { diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java index e23937943a9..28aac85286a 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ package org.springframework.http.codec; import java.nio.charset.StandardCharsets; -import java.time.Duration; import java.util.Collections; import java.util.List; import java.util.Map; @@ -124,38 +123,19 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter sse = (element instanceof ServerSentEvent serverSentEvent ? serverSentEvent : ServerSentEvent.builder().data(element).build()); - StringBuilder sb = new StringBuilder(); - String id = sse.id(); - String event = sse.event(); - Duration retry = sse.retry(); - String comment = sse.comment(); + String sseText = sse.format(); Object data = sse.data(); - if (id != null) { - writeField("id", id, sb); - } - if (event != null) { - writeField("event", event, sb); - } - if (retry != null) { - writeField("retry", retry.toMillis(), sb); - } - if (comment != null) { - sb.append(':').append(StringUtils.replace(comment, "\n", "\n:")).append('\n'); - } - if (data != null) { - sb.append("data:"); - } Flux result; if (data == null) { - result = Flux.just(encodeText(sb + "\n", mediaType, factory)); + result = Flux.just(encodeText(sseText + "\n", mediaType, factory)); } else if (data instanceof String text) { text = StringUtils.replace(text, "\n", "\ndata:"); - result = Flux.just(encodeText(sb + text + "\n\n", mediaType, factory)); + result = Flux.just(encodeText(sseText + text + "\n\n", mediaType, factory)); } else { - result = encodeEvent(sb, data, dataType, mediaType, factory, hints); + result = encodeEvent(sseText, data, dataType, mediaType, factory, hints); } return result.doOnDiscard(DataBuffer.class, DataBufferUtils::release); @@ -163,7 +143,7 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter Flux encodeEvent(StringBuilder eventContent, T data, ResolvableType dataType, + private Flux encodeEvent(CharSequence sseText, T data, ResolvableType dataType, MediaType mediaType, DataBufferFactory factory, Map hints) { if (this.encoder == null) { @@ -171,7 +151,7 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter { - DataBuffer startBuffer = encodeText(eventContent, mediaType, factory); + DataBuffer startBuffer = encodeText(sseText, mediaType, factory); DataBuffer endBuffer = encodeText("\n\n", mediaType, factory); DataBuffer dataBuffer = ((Encoder) this.encoder).encodeValue(data, factory, dataType, mediaType, hints); Hints.touchDataBuffer(dataBuffer, hints, logger); @@ -179,10 +159,6 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter Date: Tue, 10 Dec 2024 17:31:39 +0000 Subject: [PATCH 3/6] Support Flux> in WebFlux Closes gh-33975 --- .../view/ViewResolutionResultHandler.java | 85 +++++++++++++------ ...gmentViewResolutionResultHandlerTests.java | 67 +++++++++++---- .../ViewResolutionResultHandlerTests.java | 5 ++ 3 files changed, 116 insertions(+), 41 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java index 51866ba9dce..c567a6b6d0d 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java @@ -45,6 +45,7 @@ import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatusCode; import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.ServerHttpResponseDecorator; @@ -101,7 +102,7 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp private final List defaultViews = new ArrayList<>(4); - private final List streamHandlers = List.of(new SseStreamHandler()); + private final SseStreamHandler sseHandler = new SseStreamHandler(); /** @@ -175,7 +176,7 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp returnType = returnType.getNested(2); if (adapter.isMultiValue()) { - return Fragment.class.isAssignableFrom(type); + return (Fragment.class.isAssignableFrom(type) || isSseFragmentStream(returnType)); } } @@ -194,8 +195,13 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp } private static boolean isFragmentCollection(ResolvableType returnType) { - Class clazz = returnType.resolve(Object.class); - return (Collection.class.isAssignableFrom(clazz) && Fragment.class.equals(returnType.getNested(2).resolve())); + return (Collection.class.isAssignableFrom(returnType.resolve(Object.class)) && + Fragment.class.equals(returnType.getNested(2).resolve())); + } + + private static boolean isSseFragmentStream(ResolvableType returnType) { + return (ServerSentEvent.class.equals(returnType.resolve()) && + Fragment.class.equals(returnType.getNested(2).resolve())); } @Override @@ -204,9 +210,15 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp Mono valueMono; ResolvableType valueType; ReactiveAdapter adapter = getAdapter(result); + BindingContext bindingContext = result.getBindingContext(); + Locale locale = LocaleContextHolder.getLocale(exchange.getLocaleContext()); if (adapter != null) { if (adapter.isMultiValue()) { + if (isSseFragmentStream(result.getReturnType().getNested(2))) { + return handleSseFragmentStream(exchange, result, adapter, locale, bindingContext); + } + valueMono = (result.getReturnValue() != null ? Mono.just(FragmentsRendering.withPublisher(adapter.toPublisher(result.getReturnValue())).build()) : Mono.empty()); @@ -233,8 +245,6 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp Mono> viewsMono; Model model = result.getModel(); MethodParameter parameter = result.getReturnTypeSource(); - BindingContext bindingContext = result.getBindingContext(); - Locale locale = LocaleContextHolder.getLocale(exchange.getLocaleContext()); Class clazz = valueType.toClass(); if (clazz == Object.class) { @@ -277,13 +287,15 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp response.getHeaders().putAll(render.headers()); bindingContext.updateModel(exchange); - StreamHandler streamHandler = getStreamHandler(exchange); + StreamHandler streamHandler = + (this.sseHandler.supports(exchange.getRequest()) ? this.sseHandler : null); + if (streamHandler != null) { streamHandler.updateResponse(exchange); } Flux> renderFlux = render.fragments() - .concatMap(fragment -> renderFragment(fragment, streamHandler, locale, bindingContext, exchange)) + .concatMap(fragment -> renderFragment(fragment, null, streamHandler, locale, bindingContext, exchange)) .doOnDiscard(DataBuffer.class, DataBufferUtils::release); return response.writeAndFlushWith(renderFlux); @@ -338,9 +350,29 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp }); } + private Mono handleSseFragmentStream( + ServerWebExchange exchange, HandlerResult result, ReactiveAdapter adapter, Locale locale, + BindingContext bindingContext) { + + this.sseHandler.updateResponse(exchange); + + Flux> eventFlux = + Flux.from(adapter.toPublisher(result.getReturnValue())); + + Flux> dataBufferFlux = eventFlux + .concatMap(event -> renderFragment(event.data(), event, this.sseHandler, locale, bindingContext, exchange)) + .doOnDiscard(DataBuffer.class, DataBufferUtils::release); + + return exchange.getResponse().writeAndFlushWith(dataBufferFlux); + } + private Mono> renderFragment( - Fragment fragment, @Nullable StreamHandler streamHandler, Locale locale, - BindingContext bindingContext, ServerWebExchange exchange) { + @Nullable Fragment fragment, @Nullable Object streamingHints, @Nullable StreamHandler streamHandler, + Locale locale, BindingContext bindingContext, ServerWebExchange exchange) { + + if (fragment == null) { + return Mono.empty(); + } // Merge attributes from top-level model fragment.mergeAttributes(bindingContext.getModel()); @@ -355,8 +387,11 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp Map model = fragment.model(); if (streamHandler != null) { - return selectedViews.flatMap(views -> render(views, model, MediaType.TEXT_HTML, bindingContext, mutatedExchange)) - .then(Mono.fromSupplier(() -> streamHandler.format(response.getBodyFlux(), fragment, exchange))); + return selectedViews + .flatMap(views -> + render(views, model, MediaType.TEXT_HTML, bindingContext, mutatedExchange)) + .then(Mono.fromSupplier(() -> streamHandler.format( + response.getBodyFlux(), fragment, streamingHints, exchange))); } else { return selectedViews.flatMap(views -> render(views, model, null, bindingContext, mutatedExchange)) @@ -364,16 +399,6 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp } } - @Nullable - private StreamHandler getStreamHandler(ServerWebExchange exchange) { - for (StreamHandler handler : this.streamHandlers) { - if (handler.supports(exchange.getRequest())) { - return handler; - } - } - return null; - } - private String getNameForReturnValue(MethodParameter returnType) { return Optional.ofNullable(returnType.getMethodAnnotation(ModelAttribute.class)) .filter(ann -> StringUtils.hasText(ann.value())) @@ -499,10 +524,13 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp * Format the given fragment. * @param fragmentContent the fragment serialized to data buffers * @param fragment the fragment being rendered + * @param streamingHints extra hints for the stream format (e.g. ServerSentEvent wrapper) * @param exchange the current exchange * @return the formatted fragment */ - Flux format(Flux fragmentContent, Fragment fragment, ServerWebExchange exchange); + Flux format( + Flux fragmentContent, Fragment fragment, @Nullable Object streamingHints, + ServerWebExchange exchange); } @@ -540,16 +568,21 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp @Override public Flux format( - Flux fragmentFlux, Fragment fragment, ServerWebExchange exchange) { + Flux fragmentFlux, Fragment fragment, @Nullable Object hints, + ServerWebExchange exchange) { MediaType mediaType = exchange.getResponse().getHeaders().getContentType(); Charset charset = (mediaType != null && mediaType.getCharset() != null ? mediaType.getCharset() : StandardCharsets.UTF_8); + Assert.state(hints == null || hints instanceof ServerSentEvent, "Expected ServerSentEvent"); DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory(); - String eventLine = (fragment.viewName() != null ? "event:" + fragment.viewName() + "\n" : ""); - DataBuffer prefix = encodeText(eventLine + "data:", charset, bufferFactory); + ServerSentEvent sse = (ServerSentEvent) hints; + CharSequence eventText = (sse != null ? sse.format() : + (fragment.viewName() != null ? "event:" + fragment.viewName() + "\n" : "") + "data:"); + + DataBuffer prefix = encodeText(eventText.toString(), charset, bufferFactory); DataBuffer suffix = encodeText("\n\n", charset, bufferFactory); Mono content = DataBufferUtils.join(fragmentFlux) diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/FragmentViewResolutionResultHandlerTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/FragmentViewResolutionResultHandlerTests.java index 40122c3f7a2..236e85e3d70 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/FragmentViewResolutionResultHandlerTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/FragmentViewResolutionResultHandlerTests.java @@ -35,7 +35,9 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.support.ResourceBundleMessageSource; import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.accept.HeaderContentTypeResolver; @@ -99,7 +101,51 @@ public class FragmentViewResolutionResultHandlerTests { } @Test - void renderSse() { + void renderFragmentStream() { + + testSse(Flux.just(fragment1, fragment2), + on(Handler.class).resolveReturnType(Flux.class, Fragment.class), + """ + event:fragment1 + data:

+ data: Hello Foo + data:

+ + event:fragment2 + data:

+ data: Hello Bar + data:

+ + """); + } + + @Test + void renderServerSentEventFragmentStream() { + + ServerSentEvent event1 = ServerSentEvent.builder(fragment1).id("id1").event("event1").build(); + ServerSentEvent event2 = ServerSentEvent.builder(fragment2).id("id2").event("event2").build(); + + MethodParameter returnType = on(Handler.class).resolveReturnType( + Flux.class, ResolvableType.forClassWithGenerics(ServerSentEvent.class, Fragment.class)); + + testSse(Flux.just(event1, event2), returnType, + """ + id:id1 + event:event1 + data:

+ data: Hello Foo + data:

+ + id:id2 + event:event2 + data:

+ data: Hello Bar + data:

+ + """); + } + + private void testSse(Flux dataFlux, MethodParameter returnType, String output) { MockServerHttpRequest request = MockServerHttpRequest.get("/") .accept(MediaType.TEXT_EVENT_STREAM) .acceptLanguageAsLocales(Locale.ENGLISH) @@ -110,8 +156,8 @@ public class FragmentViewResolutionResultHandlerTests { HandlerResult result = new HandlerResult( new Handler(), - Flux.just(fragment1, fragment2).subscribeOn(Schedulers.boundedElastic()), - on(Handler.class).resolveReturnType(Flux.class, Fragment.class), + dataFlux.subscribeOn(Schedulers.boundedElastic()), + returnType, new BindingContext()); String body = initHandler().handleResult(exchange, result) @@ -119,18 +165,7 @@ public class FragmentViewResolutionResultHandlerTests { .block(Duration.ofSeconds(60)); assertThat(response.getHeaders().getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM); - assertThat(body).isEqualTo(""" - event:fragment1 - data:

- data: Hello Foo - data:

- - event:fragment2 - data:

- data: Hello Bar - data:

- - """); + assertThat(body).isEqualTo(output); } private ViewResolutionResultHandler initHandler() { @@ -155,6 +190,8 @@ public class FragmentViewResolutionResultHandlerTests { Flux renderFlux() { return null; } + Flux> renderSseFlux() { return null; } + List renderList() { return null; } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java index 2ad247317b4..bc4249cb621 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java @@ -41,6 +41,7 @@ import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.lang.Nullable; import org.springframework.ui.ConcurrentModel; @@ -84,6 +85,9 @@ class ViewResolutionResultHandlerTests { testSupports(on(Handler.class).resolveReturnType(FragmentsRendering.class)); testSupports(on(Handler.class).resolveReturnType(Flux.class, Fragment.class)); + testSupports(on(Handler.class).resolveReturnType( + Flux.class, ResolvableType.forClassWithGenerics(ServerSentEvent.class, Fragment.class))); + testSupports(on(Handler.class).resolveReturnType(List.class, Fragment.class)); testSupports(on(Handler.class).resolveReturnType( Mono.class, ResolvableType.forClassWithGenerics(List.class, Fragment.class))); @@ -457,6 +461,7 @@ class ViewResolutionResultHandlerTests { FragmentsRendering fragmentsRendering() { return null; } Flux fragmentFlux() { return null; } + Flux> fragmentServerSentEventFlux() { return null; } Mono> monoFragmentList() { return null; } List fragmentList() { return null; } From 7b4e19c69bc15d3018740d72f962c112cbd1772e Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Wed, 11 Dec 2024 14:23:41 +0000 Subject: [PATCH 4/6] Make ExtendedServletRequestDataBinder public Make it public and move it down to the annotations package alongside InitBinderBindingContext. This is mirrors the hierarchy in Spring MVC with the ExtendedServletRequestDataBinder. The change will allow customization of the header names to include/exclude in data binding. See gh-34039 --- .../web/reactive/BindingContext.java | 64 +++------------ .../ExtendedWebExchangeDataBinder.java | 82 +++++++++++++++++++ .../annotation/InitBinderBindingContext.java | 11 ++- .../web/reactive/BindingContextTests.java | 52 ------------ .../InitBinderBindingContextTests.java | 51 ++++++++++++ 5 files changed, 156 insertions(+), 104 deletions(-) create mode 100644 spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/BindingContext.java b/spring-webflux/src/main/java/org/springframework/web/reactive/BindingContext.java index 2a70b3f2d3e..6f9a4b95a56 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/BindingContext.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/BindingContext.java @@ -18,19 +18,14 @@ package org.springframework.web.reactive; import java.lang.annotation.Annotation; import java.util.Collection; -import java.util.List; import java.util.Map; -import reactor.core.publisher.Mono; - import org.springframework.beans.BeanUtils; import org.springframework.core.MethodParameter; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.ResolvableType; -import org.springframework.http.HttpHeaders; import org.springframework.lang.Nullable; import org.springframework.ui.Model; -import org.springframework.util.CollectionUtils; import org.springframework.validation.BindingResult; import org.springframework.validation.DataBinder; import org.springframework.validation.SmartValidator; @@ -141,7 +136,7 @@ public class BindingContext { public WebExchangeDataBinder createDataBinder( ServerWebExchange exchange, @Nullable Object target, String name, @Nullable ResolvableType targetType) { - WebExchangeDataBinder dataBinder = new ExtendedWebExchangeDataBinder(target, name); + WebExchangeDataBinder dataBinder = createBinderInstance(target, name); dataBinder.setNameResolver(new BindParamNameResolver()); if (target == null && targetType != null) { @@ -163,6 +158,18 @@ public class BindingContext { return dataBinder; } + /** + * Extension point to create the WebDataBinder instance. + * By default, this is {@code WebRequestDataBinder}. + * @param target the binding target or {@code null} for type conversion only + * @param name the binding target object name + * @return the created {@link WebExchangeDataBinder} instance + * @since 6.2.1 + */ + protected WebExchangeDataBinder createBinderInstance(@Nullable Object target, String name) { + return new WebExchangeDataBinder(target, name); + } + /** * Initialize the data binder instance for the given exchange. * @throws ServerErrorException if {@code @InitBinder} method invocation fails @@ -200,51 +207,6 @@ public class BindingContext { } - /** - * Extended variant of {@link WebExchangeDataBinder}, adding path variables. - */ - private static class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder { - - public ExtendedWebExchangeDataBinder(@Nullable Object target, String objectName) { - super(target, objectName); - } - - @Override - public Mono> getValuesToBind(ServerWebExchange exchange) { - return super.getValuesToBind(exchange).doOnNext(map -> { - Map vars = exchange.getAttribute(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE); - if (!CollectionUtils.isEmpty(vars)) { - vars.forEach((key, value) -> addValueIfNotPresent(map, "URI variable", key, value)); - } - HttpHeaders headers = exchange.getRequest().getHeaders(); - for (Map.Entry> entry : headers.entrySet()) { - List values = entry.getValue(); - if (!CollectionUtils.isEmpty(values)) { - String name = entry.getKey().replace("-", ""); - addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values)); - } - } - }); - } - - private static void addValueIfNotPresent( - Map map, String label, String name, @Nullable Object value) { - - if (value != null) { - if (map.containsKey(name)) { - if (logger.isDebugEnabled()) { - logger.debug(label + " '" + name + "' overridden by request bind value."); - } - } - else { - map.put(name, value); - } - } - } - - } - - /** * Excludes Bean Validation if the method parameter has {@code @Valid}. */ diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java new file mode 100644 index 00000000000..0863499d867 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.result.method.annotation; + +import java.util.List; +import java.util.Map; + +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.util.CollectionUtils; +import org.springframework.web.bind.support.WebExchangeDataBinder; +import org.springframework.web.reactive.HandlerMapping; +import org.springframework.web.server.ServerWebExchange; + +/** + * Extended variant of {@link WebExchangeDataBinder} that adds URI path variables + * and request headers to the bind values map. + * + *

Note: This class has existed since 5.0, but only as a private class within + * {@link org.springframework.web.reactive.BindingContext}. + * + * @author Rossen Stoyanchev + * @since 6.2.1 + */ +public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder { + + + public ExtendedWebExchangeDataBinder(@Nullable Object target, String objectName) { + super(target, objectName); + } + + + @Override + public Mono> getValuesToBind(ServerWebExchange exchange) { + return super.getValuesToBind(exchange).doOnNext(map -> { + Map vars = exchange.getAttribute(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE); + if (!CollectionUtils.isEmpty(vars)) { + vars.forEach((key, value) -> addValueIfNotPresent(map, "URI variable", key, value)); + } + HttpHeaders headers = exchange.getRequest().getHeaders(); + for (Map.Entry> entry : headers.entrySet()) { + List values = entry.getValue(); + if (!CollectionUtils.isEmpty(values)) { + String name = entry.getKey().replace("-", ""); + addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values)); + } + } + }); + } + + private static void addValueIfNotPresent( + Map map, String label, String name, @Nullable Object value) { + + if (value != null) { + if (map.containsKey(name)) { + if (logger.isDebugEnabled()) { + logger.debug(label + " '" + name + "' overridden by request bind value."); + } + } + else { + map.put(name, value); + } + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContext.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContext.java index 8fa00f33d5e..c8d67038c64 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContext.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -71,6 +71,15 @@ class InitBinderBindingContext extends BindingContext { } + /** + * Returns an instance of {@link ExtendedWebExchangeDataBinder}. + * @since 6.2.1 + */ + @Override + protected WebExchangeDataBinder createBinderInstance(@Nullable Object target, String name) { + return new ExtendedWebExchangeDataBinder(target, name); + } + @Override protected WebExchangeDataBinder initDataBinder(WebExchangeDataBinder dataBinder, ServerWebExchange exchange) { this.binderMethods.stream() diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/BindingContextTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/BindingContextTests.java index 995cc00648d..dab582a98c2 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/BindingContextTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/BindingContextTests.java @@ -17,20 +17,16 @@ package org.springframework.web.reactive; import java.lang.reflect.Method; -import java.util.Map; import jakarta.validation.Valid; import org.junit.jupiter.api.Test; -import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.core.ResolvableType; -import org.springframework.http.MediaType; import org.springframework.validation.Errors; import org.springframework.validation.SmartValidator; import org.springframework.validation.Validator; import org.springframework.validation.beanvalidation.LocalValidatorFactoryBean; import org.springframework.web.bind.WebDataBinder; -import org.springframework.web.bind.support.WebExchangeDataBinder; import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest; import org.springframework.web.testfixture.server.MockServerWebExchange; @@ -68,54 +64,6 @@ class BindingContextTests { assertThat(binder.getValidatorsToApply()).containsExactly(springValidator); } - @Test - void bindUriVariablesAndHeaders() { - - MockServerHttpRequest request = MockServerHttpRequest.get("/path") - .header("Some-Int-Array", "1") - .header("Some-Int-Array", "2") - .build(); - - MockServerWebExchange exchange = MockServerWebExchange.from(request); - exchange.getAttributes().put( - HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, - Map.of("name", "John", "age", "25")); - - TestBean target = new TestBean(); - - BindingContext bindingContext = new BindingContext(null); - WebExchangeDataBinder binder = bindingContext.createDataBinder(exchange, target, "testBean", null); - - binder.bind(exchange).block(); - - assertThat(target.getName()).isEqualTo("John"); - assertThat(target.getAge()).isEqualTo(25); - assertThat(target.getSomeIntArray()).containsExactly(1, 2); - } - - @Test - void bindUriVarsAndHeadersAddedConditionally() { - - MockServerHttpRequest request = MockServerHttpRequest.post("/path") - .header("name", "Johnny") - .contentType(MediaType.APPLICATION_FORM_URLENCODED) - .body("name=John&age=25"); - - MockServerWebExchange exchange = MockServerWebExchange.from(request); - exchange.getAttributes().put(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, Map.of("age", "26")); - - TestBean target = new TestBean(); - - BindingContext bindingContext = new BindingContext(null); - WebExchangeDataBinder binder = bindingContext.createDataBinder(exchange, target, "testBean", null); - - binder.bind(exchange).block(); - - assertThat(target.getName()).isEqualTo("John"); - assertThat(target.getAge()).isEqualTo(25); - } - - @SuppressWarnings("unused") private void handleValidObject(@Valid Foo foo) { } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java index 52557547f01..b12f755ec6a 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java @@ -20,18 +20,23 @@ import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.Test; +import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.convert.ConversionService; import org.springframework.format.support.DefaultFormattingConversionService; +import org.springframework.http.MediaType; import org.springframework.web.bind.WebDataBinder; import org.springframework.web.bind.annotation.InitBinder; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.support.ConfigurableWebBindingInitializer; +import org.springframework.web.bind.support.WebExchangeDataBinder; import org.springframework.web.reactive.BindingContext; +import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.result.method.SyncHandlerMethodArgumentResolver; import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod; import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest; @@ -123,6 +128,52 @@ class InitBinderBindingContextTests { assertThat(dataBinder.getDisallowedFields()[0]).isEqualToIgnoringCase("requestParam-22"); } + @Test + void bindUriVariablesAndHeaders() throws Exception { + + MockServerHttpRequest request = MockServerHttpRequest.get("/path") + .header("Some-Int-Array", "1") + .header("Some-Int-Array", "2") + .build(); + + MockServerWebExchange exchange = MockServerWebExchange.from(request); + exchange.getAttributes().put( + HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, + Map.of("name", "John", "age", "25")); + + TestBean target = new TestBean(); + + BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class); + WebExchangeDataBinder binder = context.createDataBinder(exchange, target, "testBean", null); + + binder.bind(exchange).block(); + + assertThat(target.getName()).isEqualTo("John"); + assertThat(target.getAge()).isEqualTo(25); + assertThat(target.getSomeIntArray()).containsExactly(1, 2); + } + + @Test + void bindUriVarsAndHeadersAddedConditionally() throws Exception { + + MockServerHttpRequest request = MockServerHttpRequest.post("/path") + .header("name", "Johnny") + .contentType(MediaType.APPLICATION_FORM_URLENCODED) + .body("name=John&age=25"); + + MockServerWebExchange exchange = MockServerWebExchange.from(request); + exchange.getAttributes().put(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, Map.of("age", "26")); + + TestBean target = new TestBean(); + + BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class); + WebExchangeDataBinder binder = context.createDataBinder(exchange, target, "testBean", null); + + binder.bind(exchange).block(); + + assertThat(target.getName()).isEqualTo("John"); + assertThat(target.getAge()).isEqualTo(25); + } private BindingContext createBindingContext(String methodName, Class... parameterTypes) throws Exception { Object handler = new InitBinderHandler(); From 70c326ed30a0264ef38c59d76b3dbfa6745a4827 Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Wed, 11 Dec 2024 15:10:09 +0000 Subject: [PATCH 5/6] Support headers in DataBinding via constructor args Closes gh-34073 --- .../ExtendedWebExchangeDataBinder.java | 7 +++- .../InitBinderBindingContextTests.java | 33 ++++++++++++++++++- .../ExtendedServletRequestDataBinder.java | 10 ++++++ ...ExtendedServletRequestDataBinderTests.java | 30 ++++++++++++++++- 4 files changed, 77 insertions(+), 3 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java index 0863499d867..3bba7da3d88 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java @@ -24,6 +24,7 @@ import reactor.core.publisher.Mono; import org.springframework.http.HttpHeaders; import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; import org.springframework.web.bind.support.WebExchangeDataBinder; import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.server.ServerWebExchange; @@ -57,7 +58,11 @@ public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder { for (Map.Entry> entry : headers.entrySet()) { List values = entry.getValue(); if (!CollectionUtils.isEmpty(values)) { - String name = entry.getKey().replace("-", ""); + // For constructor args with @BindParam mapped to the actual header name + String name = entry.getKey(); + addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values)); + // Also adapt to Java conventions for setters + name = StringUtils.uncapitalize(entry.getKey().replace("-", "")); addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values)); } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java index b12f755ec6a..159c687d572 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java @@ -27,10 +27,12 @@ import org.junit.jupiter.api.Test; import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.ResolvableType; import org.springframework.core.convert.ConversionService; import org.springframework.format.support.DefaultFormattingConversionService; import org.springframework.http.MediaType; import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.BindParam; import org.springframework.web.bind.annotation.InitBinder; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.support.ConfigurableWebBindingInitializer; @@ -129,7 +131,7 @@ class InitBinderBindingContextTests { } @Test - void bindUriVariablesAndHeaders() throws Exception { + void bindUriVariablesAndHeadersViaSetters() throws Exception { MockServerHttpRequest request = MockServerHttpRequest.get("/path") .header("Some-Int-Array", "1") @@ -153,6 +155,31 @@ class InitBinderBindingContextTests { assertThat(target.getSomeIntArray()).containsExactly(1, 2); } + @Test + void bindUriVariablesAndHeadersViaConstructor() throws Exception { + + MockServerHttpRequest request = MockServerHttpRequest.get("/path") + .header("Some-Int-Array", "1") + .header("Some-Int-Array", "2") + .build(); + + MockServerWebExchange exchange = MockServerWebExchange.from(request); + exchange.getAttributes().put( + HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, + Map.of("name", "John", "age", "25")); + + BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class); + WebExchangeDataBinder binder = context.createDataBinder(exchange, null, "dataBean", null); + binder.setTargetType(ResolvableType.forClass(DataBean.class)); + binder.construct(exchange).block(); + + DataBean bean = (DataBean) binder.getTarget(); + + assertThat(bean.name()).isEqualTo("John"); + assertThat(bean.age()).isEqualTo(25); + assertThat(bean.someIntArray()).containsExactly(1, 2); + } + @Test void bindUriVarsAndHeadersAddedConditionally() throws Exception { @@ -212,4 +239,8 @@ class InitBinderBindingContextTests { } } + + private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) { + } + } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java index 4d4e26a131a..c74b37fab8f 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java @@ -156,6 +156,9 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder { if (uriVars != null) { value = uriVars.get(name); } + if (value == null && getRequest() instanceof HttpServletRequest httpServletRequest) { + value = getHeaderValue(httpServletRequest, name); + } } return value; } @@ -167,6 +170,13 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder { if (uriVars != null) { set.addAll(uriVars.keySet()); } + if (request instanceof HttpServletRequest httpServletRequest) { + Enumeration enumeration = httpServletRequest.getHeaderNames(); + while (enumeration.hasMoreElements()) { + String headerName = enumeration.nextElement(); + set.add(headerName.replaceAll("-", "")); + } + } return set; } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java index 83f64ca1b8c..36fd05508cd 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java @@ -22,7 +22,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.beans.testfixture.beans.TestBean; +import org.springframework.core.ResolvableType; import org.springframework.web.bind.ServletRequestDataBinder; +import org.springframework.web.bind.annotation.BindParam; +import org.springframework.web.bind.support.BindParamNameResolver; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; @@ -45,7 +48,7 @@ class ExtendedServletRequestDataBinderTests { @Test - void createBinder() { + void createBinderViaSetters() { request.setAttribute( HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, Map.of("name", "John", "age", "25")); @@ -62,6 +65,27 @@ class ExtendedServletRequestDataBinderTests { assertThat(target.getSomeIntArray()).containsExactly(1, 2); } + @Test + void createBinderViaConstructor() { + request.setAttribute( + HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, + Map.of("name", "John", "age", "25")); + + request.addHeader("Some-Int-Array", "1"); + request.addHeader("Some-Int-Array", "2"); + + ServletRequestDataBinder binder = new ExtendedServletRequestDataBinder(null); + binder.setTargetType(ResolvableType.forClass(DataBean.class)); + binder.setNameResolver(new BindParamNameResolver()); + binder.construct(request); + + DataBean bean = (DataBean) binder.getTarget(); + + assertThat(bean.name()).isEqualTo("John"); + assertThat(bean.age()).isEqualTo(25); + assertThat(bean.someIntArray()).containsExactly(1, 2); + } + @Test void uriVarsAndHeadersAddedConditionally() { request.addParameter("name", "John"); @@ -88,4 +112,8 @@ class ExtendedServletRequestDataBinderTests { assertThat(target.getAge()).isEqualTo(0); } + + private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) { + } + } From 8aeced9f8060561bd20351446df452e0100912d7 Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Wed, 11 Dec 2024 16:06:56 +0000 Subject: [PATCH 6/6] Support header filtering in web data binding Closes gh-34039 --- .../ExtendedWebExchangeDataBinder.java | 35 +++++++++++++++- .../InitBinderBindingContextTests.java | 18 ++++++++ .../ExtendedServletRequestDataBinder.java | 41 +++++++++++++++++-- ...ExtendedServletRequestDataBinderTests.java | 31 ++++++++++++++ 4 files changed, 121 insertions(+), 4 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java index 3bba7da3d88..ec3f998573a 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java @@ -18,6 +18,8 @@ package org.springframework.web.reactive.result.method.annotation; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; import reactor.core.publisher.Mono; @@ -41,12 +43,40 @@ import org.springframework.web.server.ServerWebExchange; */ public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder { + private static final Set FILTERED_HEADER_NAMES = Set.of("Priority"); + + + private Predicate headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name); + public ExtendedWebExchangeDataBinder(@Nullable Object target, String objectName) { super(target, objectName); } + /** + * Add a Predicate that filters the header names to use for data binding. + * Multiple predicates are combined with {@code AND}. + * @param headerPredicate the predicate to add + * @since 6.2.1 + */ + public void addHeaderPredicate(Predicate headerPredicate) { + this.headerPredicate = this.headerPredicate.and(headerPredicate); + } + + /** + * Set the Predicate that filters the header names to use for data binding. + *

Note that this method resets any previous predicates that may have been + * set, including headers excluded by default such as the RFC 9218 defined + * "Priority" header. + * @param headerPredicate the predicate to add + * @since 6.2.1 + */ + public void setHeaderPredicate(Predicate headerPredicate) { + this.headerPredicate = headerPredicate; + } + + @Override public Mono> getValuesToBind(ServerWebExchange exchange) { return super.getValuesToBind(exchange).doOnNext(map -> { @@ -56,10 +86,13 @@ public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder { } HttpHeaders headers = exchange.getRequest().getHeaders(); for (Map.Entry> entry : headers.entrySet()) { + String name = entry.getKey(); + if (!this.headerPredicate.test(entry.getKey())) { + continue; + } List values = entry.getValue(); if (!CollectionUtils.isEmpty(values)) { // For constructor args with @BindParam mapped to the actual header name - String name = entry.getKey(); addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values)); // Also adapt to Java conventions for setters name = StringUtils.uncapitalize(entry.getKey().replace("-", "")); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java index 159c687d572..ad876a9ad95 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java @@ -202,6 +202,24 @@ class InitBinderBindingContextTests { assertThat(target.getAge()).isEqualTo(25); } + @Test + void headerPredicate() throws Exception { + MockServerHttpRequest request = MockServerHttpRequest.get("/path") + .header("Priority", "u1") + .header("Some-Int-Array", "1") + .header("Another-Int-Array", "1") + .build(); + + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class); + ExtendedWebExchangeDataBinder binder = (ExtendedWebExchangeDataBinder) context.createDataBinder(exchange, null, "", null); + binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array")); + + Map map = binder.getValuesToBind(exchange).block(); + assertThat(map).containsExactlyInAnyOrderEntriesOf(Map.of("someIntArray", "1", "Some-Int-Array", "1")); + } + private BindingContext createBindingContext(String methodName, Class... parameterTypes) throws Exception { Object handler = new InitBinderHandler(); Method method = handler.getClass().getMethod(methodName, parameterTypes); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java index c74b37fab8f..019bf9a75b1 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java @@ -21,12 +21,14 @@ import java.util.Enumeration; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Predicate; import jakarta.servlet.ServletRequest; import jakarta.servlet.http.HttpServletRequest; import org.springframework.beans.MutablePropertyValues; import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; import org.springframework.web.bind.ServletRequestDataBinder; import org.springframework.web.bind.WebDataBinder; import org.springframework.web.servlet.HandlerMapping; @@ -51,6 +53,12 @@ import org.springframework.web.servlet.HandlerMapping; */ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder { + private static final Set FILTERED_HEADER_NAMES = Set.of("Priority"); + + + private Predicate headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name); + + /** * Create a new instance, with default object name. * @param target the target object to bind onto (or {@code null} @@ -73,6 +81,29 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder { } + /** + * Add a Predicate that filters the header names to use for data binding. + * Multiple predicates are combined with {@code AND}. + * @param headerPredicate the predicate to add + * @since 6.2.1 + */ + public void addHeaderPredicate(Predicate headerPredicate) { + this.headerPredicate = this.headerPredicate.and(headerPredicate); + } + + /** + * Set the Predicate that filters the header names to use for data binding. + *

Note that this method resets any previous predicates that may have been + * set, including headers excluded by default such as the RFC 9218 defined + * "Priority" header. + * @param headerPredicate the predicate to add + * @since 6.2.1 + */ + public void setHeaderPredicate(Predicate headerPredicate) { + this.headerPredicate = headerPredicate; + } + + @Override protected ServletRequestValueResolver createValueResolver(ServletRequest request) { return new ExtendedServletRequestValueResolver(request, this); @@ -93,7 +124,7 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder { String name = names.nextElement(); Object value = getHeaderValue(httpRequest, name); if (value != null) { - name = name.replace("-", ""); + name = StringUtils.uncapitalize(name.replace("-", "")); addValueIfNotPresent(mpvs, "Header", name, value); } } @@ -118,7 +149,11 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder { } @Nullable - private static Object getHeaderValue(HttpServletRequest request, String name) { + private Object getHeaderValue(HttpServletRequest request, String name) { + if (!this.headerPredicate.test(name)) { + return null; + } + Enumeration valuesEnum = request.getHeaders(name); if (!valuesEnum.hasMoreElements()) { return null; @@ -141,7 +176,7 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder { /** * Resolver of values that looks up URI path variables. */ - private static class ExtendedServletRequestValueResolver extends ServletRequestValueResolver { + private class ExtendedServletRequestValueResolver extends ServletRequestValueResolver { ExtendedServletRequestValueResolver(ServletRequest request, WebDataBinder dataBinder) { super(request, dataBinder); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java index 36fd05508cd..1d7653bdf5a 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java @@ -18,9 +18,11 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.util.Map; +import jakarta.servlet.ServletRequest; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.core.ResolvableType; import org.springframework.web.bind.ServletRequestDataBinder; @@ -102,6 +104,22 @@ class ExtendedServletRequestDataBinderTests { assertThat(target.getAge()).isEqualTo(25); } + @Test + void headerPredicate() { + TestBinder binder = new TestBinder(); + binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array")); + + MutablePropertyValues mpvs = new MutablePropertyValues(); + request.addHeader("Priority", "u1"); + request.addHeader("Some-Int-Array", "1"); + request.addHeader("Another-Int-Array", "1"); + + binder.addBindValues(mpvs, request); + + assertThat(mpvs.size()).isEqualTo(1); + assertThat(mpvs.get("someIntArray")).isEqualTo("1"); + } + @Test void noUriTemplateVars() { TestBean target = new TestBean(); @@ -116,4 +134,17 @@ class ExtendedServletRequestDataBinderTests { private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) { } + + private static class TestBinder extends ExtendedServletRequestDataBinder { + + public TestBinder() { + super(null); + } + + @Override + public void addBindValues(MutablePropertyValues mpvs, ServletRequest request) { + super.addBindValues(mpvs, request); + } + } + }