Browse Source

Add missing CORS headers defined in SockJS CORS config

Prior to this commit and following changes done in d27b5d0, the CORS
response headers would not be added for SockJS-related requests, even
though a CORS configuration had been applied to SockJS/WebSocket.
This was due to a missing case in our implementation: calling
`AbstractHandlerMapping#getHandlerInternal` can return a Handler
directly, but also a `HandlerExecutionChain` in some cases, as explained
in the Javadoc.

This commit ensures that, when checking for existing CORS configuration,
the `AbstractHandlerMapping` class also considers the
`HandlerExecutionChain` case and unwraps it to get the CORS
configuration from the actual Handler.

Fixes gh-23843
pull/23891/head
Brian Clozel 6 years ago
parent
commit
7d02ba0694
  1. 3
      spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java
  2. 37
      spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java

3
spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java

@ -485,6 +485,9 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport @@ -485,6 +485,9 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
* @since 5.2
*/
protected boolean hasCorsConfigurationSource(Object handler) {
if (handler instanceof HandlerExecutionChain) {
handler = ((HandlerExecutionChain) handler).getHandler();
}
return (handler instanceof CorsConfigurationSource || this.corsConfigurationSource != null);
}

37
spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java

@ -48,7 +48,7 @@ import static org.mockito.Mockito.mock; @@ -48,7 +48,7 @@ import static org.mockito.Mockito.mock;
* @author Sebastien Deleuze
* @author Rossen Stoyanchev
*/
public class CorsAbstractHandlerMappingTests {
class CorsAbstractHandlerMappingTests {
private MockHttpServletRequest request;
@ -56,7 +56,7 @@ public class CorsAbstractHandlerMappingTests { @@ -56,7 +56,7 @@ public class CorsAbstractHandlerMappingTests {
@BeforeEach
public void setup() {
void setup() {
StaticWebApplicationContext context = new StaticWebApplicationContext();
this.handlerMapping = new TestHandlerMapping();
this.handlerMapping.setInterceptors(mock(HandlerInterceptor.class));
@ -66,7 +66,7 @@ public class CorsAbstractHandlerMappingTests { @@ -66,7 +66,7 @@ public class CorsAbstractHandlerMappingTests {
}
@Test
public void actualRequestWithoutCorsConfigurationProvider() throws Exception {
void actualRequestWithoutCorsConfigurationProvider() throws Exception {
this.request.setMethod(RequestMethod.GET.name());
this.request.setRequestURI("/foo");
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
@ -79,7 +79,7 @@ public class CorsAbstractHandlerMappingTests { @@ -79,7 +79,7 @@ public class CorsAbstractHandlerMappingTests {
}
@Test
public void preflightRequestWithoutCorsConfigurationProvider() throws Exception {
void preflightRequestWithoutCorsConfigurationProvider() throws Exception {
this.request.setMethod(RequestMethod.OPTIONS.name());
this.request.setRequestURI("/foo");
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
@ -92,7 +92,7 @@ public class CorsAbstractHandlerMappingTests { @@ -92,7 +92,7 @@ public class CorsAbstractHandlerMappingTests {
}
@Test
public void actualRequestWithCorsConfigurationProvider() throws Exception {
void actualRequestWithCorsConfigurationProvider() throws Exception {
this.request.setMethod(RequestMethod.GET.name());
this.request.setRequestURI("/cors");
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
@ -105,8 +105,22 @@ public class CorsAbstractHandlerMappingTests { @@ -105,8 +105,22 @@ public class CorsAbstractHandlerMappingTests {
assertThat(getRequiredCorsConfiguration(chain, false).getAllowedOrigins()).isEqualTo(Collections.singletonList("*"));
}
@Test // see gh-23843
void actualRequestWithCorsConfigurationProviderForHandlerChain() throws Exception {
this.request.setMethod(RequestMethod.GET.name());
this.request.setRequestURI("/chain");
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
assertThat(chain).isNotNull();
boolean condition = chain.getHandler() instanceof CorsAwareHandler;
assertThat(condition).isTrue();
assertThat(getRequiredCorsConfiguration(chain, false).getAllowedOrigins()).isEqualTo(Collections.singletonList("*"));
}
@Test
public void preflightRequestWithCorsConfigurationProvider() throws Exception {
void preflightRequestWithCorsConfigurationProvider() throws Exception {
this.request.setMethod(RequestMethod.OPTIONS.name());
this.request.setRequestURI("/cors");
this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com");
@ -120,7 +134,7 @@ public class CorsAbstractHandlerMappingTests { @@ -120,7 +134,7 @@ public class CorsAbstractHandlerMappingTests {
}
@Test
public void actualRequestWithMappedCorsConfiguration() throws Exception {
void actualRequestWithMappedCorsConfiguration() throws Exception {
CorsConfiguration config = new CorsConfiguration();
config.addAllowedOrigin("*");
this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/foo", config));
@ -137,7 +151,7 @@ public class CorsAbstractHandlerMappingTests { @@ -137,7 +151,7 @@ public class CorsAbstractHandlerMappingTests {
}
@Test
public void preflightRequestWithMappedCorsConfiguration() throws Exception {
void preflightRequestWithMappedCorsConfiguration() throws Exception {
CorsConfiguration config = new CorsConfiguration();
config.addAllowedOrigin("*");
this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/foo", config));
@ -154,7 +168,7 @@ public class CorsAbstractHandlerMappingTests { @@ -154,7 +168,7 @@ public class CorsAbstractHandlerMappingTests {
}
@Test
public void actualRequestWithCorsConfigurationSource() throws Exception {
void actualRequestWithCorsConfigurationSource() throws Exception {
this.handlerMapping.setCorsConfigurationSource(new CustomCorsConfigurationSource());
this.request.setMethod(RequestMethod.GET.name());
this.request.setRequestURI("/foo");
@ -172,7 +186,7 @@ public class CorsAbstractHandlerMappingTests { @@ -172,7 +186,7 @@ public class CorsAbstractHandlerMappingTests {
}
@Test
public void preflightRequestWithCorsConfigurationSource() throws Exception {
void preflightRequestWithCorsConfigurationSource() throws Exception {
this.handlerMapping.setCorsConfigurationSource(new CustomCorsConfigurationSource());
this.request.setMethod(RequestMethod.OPTIONS.name());
this.request.setRequestURI("/foo");
@ -217,6 +231,9 @@ public class CorsAbstractHandlerMappingTests { @@ -217,6 +231,9 @@ public class CorsAbstractHandlerMappingTests {
if (request.getRequestURI().equals("/cors")) {
return new CorsAwareHandler();
}
else if (request.getRequestURI().equals("/chain")) {
return new HandlerExecutionChain(new CorsAwareHandler());
}
return new SimpleHandler();
}
}

Loading…
Cancel
Save