diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java index 42566bdaec0..28f059926f9 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java @@ -485,6 +485,9 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport * @since 5.2 */ protected boolean hasCorsConfigurationSource(Object handler) { + if (handler instanceof HandlerExecutionChain) { + handler = ((HandlerExecutionChain) handler).getHandler(); + } return (handler instanceof CorsConfigurationSource || this.corsConfigurationSource != null); } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java index 9cb12f794b2..b8213381b02 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java @@ -48,7 +48,7 @@ import static org.mockito.Mockito.mock; * @author Sebastien Deleuze * @author Rossen Stoyanchev */ -public class CorsAbstractHandlerMappingTests { +class CorsAbstractHandlerMappingTests { private MockHttpServletRequest request; @@ -56,7 +56,7 @@ public class CorsAbstractHandlerMappingTests { @BeforeEach - public void setup() { + void setup() { StaticWebApplicationContext context = new StaticWebApplicationContext(); this.handlerMapping = new TestHandlerMapping(); this.handlerMapping.setInterceptors(mock(HandlerInterceptor.class)); @@ -66,7 +66,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void actualRequestWithoutCorsConfigurationProvider() throws Exception { + void actualRequestWithoutCorsConfigurationProvider() throws Exception { this.request.setMethod(RequestMethod.GET.name()); this.request.setRequestURI("/foo"); this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); @@ -79,7 +79,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void preflightRequestWithoutCorsConfigurationProvider() throws Exception { + void preflightRequestWithoutCorsConfigurationProvider() throws Exception { this.request.setMethod(RequestMethod.OPTIONS.name()); this.request.setRequestURI("/foo"); this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); @@ -92,7 +92,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void actualRequestWithCorsConfigurationProvider() throws Exception { + void actualRequestWithCorsConfigurationProvider() throws Exception { this.request.setMethod(RequestMethod.GET.name()); this.request.setRequestURI("/cors"); this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); @@ -105,8 +105,22 @@ public class CorsAbstractHandlerMappingTests { assertThat(getRequiredCorsConfiguration(chain, false).getAllowedOrigins()).isEqualTo(Collections.singletonList("*")); } + @Test // see gh-23843 + void actualRequestWithCorsConfigurationProviderForHandlerChain() throws Exception { + this.request.setMethod(RequestMethod.GET.name()); + this.request.setRequestURI("/chain"); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + + assertThat(chain).isNotNull(); + boolean condition = chain.getHandler() instanceof CorsAwareHandler; + assertThat(condition).isTrue(); + assertThat(getRequiredCorsConfiguration(chain, false).getAllowedOrigins()).isEqualTo(Collections.singletonList("*")); + } + @Test - public void preflightRequestWithCorsConfigurationProvider() throws Exception { + void preflightRequestWithCorsConfigurationProvider() throws Exception { this.request.setMethod(RequestMethod.OPTIONS.name()); this.request.setRequestURI("/cors"); this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); @@ -120,7 +134,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void actualRequestWithMappedCorsConfiguration() throws Exception { + void actualRequestWithMappedCorsConfiguration() throws Exception { CorsConfiguration config = new CorsConfiguration(); config.addAllowedOrigin("*"); this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/foo", config)); @@ -137,7 +151,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void preflightRequestWithMappedCorsConfiguration() throws Exception { + void preflightRequestWithMappedCorsConfiguration() throws Exception { CorsConfiguration config = new CorsConfiguration(); config.addAllowedOrigin("*"); this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/foo", config)); @@ -154,7 +168,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void actualRequestWithCorsConfigurationSource() throws Exception { + void actualRequestWithCorsConfigurationSource() throws Exception { this.handlerMapping.setCorsConfigurationSource(new CustomCorsConfigurationSource()); this.request.setMethod(RequestMethod.GET.name()); this.request.setRequestURI("/foo"); @@ -172,7 +186,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void preflightRequestWithCorsConfigurationSource() throws Exception { + void preflightRequestWithCorsConfigurationSource() throws Exception { this.handlerMapping.setCorsConfigurationSource(new CustomCorsConfigurationSource()); this.request.setMethod(RequestMethod.OPTIONS.name()); this.request.setRequestURI("/foo"); @@ -217,6 +231,9 @@ public class CorsAbstractHandlerMappingTests { if (request.getRequestURI().equals("/cors")) { return new CorsAwareHandler(); } + else if (request.getRequestURI().equals("/chain")) { + return new HandlerExecutionChain(new CorsAwareHandler()); + } return new SimpleHandler(); } }