Browse Source

Add PreFlightRequestHandler for Spring MVC

This is equivalent of the same contract for WebFlux. It is implemented
by HandlerMappingIntrospector, and may be called directly by Spring
Security to handle a pre-flight request without delegate to the rest
of the filter chain.

HandlerMappingIntrospector also has the boolean method
allHandlerMappingsUsePathPatternParser that checks whether all handler
mappings are configured to use parsed PathPattern's.

See gh-31823
pull/32703/head
rstoyanchev 2 years ago
parent
commit
75a5409c97
  1. 40
      spring-web/src/main/java/org/springframework/web/cors/PreFlightRequestHandler.java
  2. 34
      spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java
  3. 50
      spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java
  4. 9
      spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java
  5. 60
      spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java
  6. 7
      spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java

40
spring-web/src/main/java/org/springframework/web/cors/PreFlightRequestHandler.java

@ -0,0 +1,40 @@ @@ -0,0 +1,40 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.cors;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
/**
* Handler for CORS pre-flight requests.
*
* @author Rossen Stoyanchev
* @since 6.2
*/
public interface PreFlightRequestHandler {
/**
* Handle a pre-flight request by finding and applying the CORS configuration
* that matches the expected actual request. As a result of handling, the
* response should be updated with CORS headers or rejected with
* {@link org.springframework.http.HttpStatus#FORBIDDEN}.
* @param request current HTTP request
* @param response current HTTP response
*/
void handlePreFlight(HttpServletRequest request, HttpServletResponse response) throws Exception;
}

34
spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java

@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
package org.springframework.web.servlet.handler;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@ -47,6 +48,7 @@ import org.springframework.web.cors.CorsConfigurationSource; @@ -47,6 +48,7 @@ import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.CorsProcessor;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.cors.DefaultCorsProcessor;
import org.springframework.web.cors.PreFlightRequestHandler;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerExecutionChain;
@ -679,9 +681,9 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport @@ -679,9 +681,9 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
HttpServletRequest request, HandlerExecutionChain chain, @Nullable CorsConfiguration config) {
if (CorsUtils.isPreFlightRequest(request)) {
PreFlightHandler preFlightHandler = new PreFlightHandler(config);
chain.addInterceptor(0, preFlightHandler);
return new HandlerExecutionChain(preFlightHandler, chain.getInterceptors());
PreFlightHttpRequestHandler handler = new PreFlightHttpRequestHandler(config);
chain.addInterceptor(0, handler);
return new HandlerExecutionChain(handler, chain.getInterceptors());
}
else {
chain.addInterceptor(0, new CorsInterceptor(config));
@ -699,6 +701,12 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport @@ -699,6 +701,12 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
this.config = config;
}
@Override
@Nullable
public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
return this.config;
}
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
throws Exception {
@ -709,20 +717,21 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport @@ -709,20 +717,21 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
return true;
}
return corsProcessor.processRequest(this.config, request, response);
return invokeCorsProcessor(request, response);
}
@Override
@Nullable
public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
return this.config;
protected boolean invokeCorsProcessor(
HttpServletRequest request, HttpServletResponse response) throws IOException {
return corsProcessor.processRequest(this.config, request, response);
}
}
private class PreFlightHandler extends CorsInterceptor implements HttpRequestHandler {
private final class PreFlightHttpRequestHandler
extends CorsInterceptor implements HttpRequestHandler, PreFlightRequestHandler {
public PreFlightHandler(@Nullable CorsConfiguration config) {
public PreFlightHttpRequestHandler(@Nullable CorsConfiguration config) {
super(config);
}
@ -730,6 +739,11 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport @@ -730,6 +739,11 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
public void handleRequest(HttpServletRequest request, HttpServletResponse response) {
// no-op
}
@Override
public void handlePreFlight(HttpServletRequest request, HttpServletResponse response) throws IOException {
invokeCorsProcessor(request, response);
}
}
}

50
spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java

@ -33,6 +33,7 @@ import jakarta.servlet.Filter; @@ -33,6 +33,7 @@ import jakarta.servlet.Filter;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -45,16 +46,20 @@ import org.springframework.core.io.ClassPathResource; @@ -45,16 +46,20 @@ import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PropertiesLoaderUtils;
import org.springframework.http.server.RequestPath;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.cors.PreFlightRequestHandler;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.NoHandlerFoundException;
import org.springframework.web.util.ServletRequestPathUtils;
import org.springframework.web.util.UrlPathHelper;
import org.springframework.web.util.pattern.PathPatternParser;
@ -87,7 +92,7 @@ import org.springframework.web.util.pattern.PathPatternParser; @@ -87,7 +92,7 @@ import org.springframework.web.util.pattern.PathPatternParser;
* @since 4.3.1
*/
public class HandlerMappingIntrospector
implements CorsConfigurationSource, ApplicationContextAware, InitializingBean {
implements CorsConfigurationSource, PreFlightRequestHandler, ApplicationContextAware, InitializingBean {
private static final Log logger = LogFactory.getLog(HandlerMappingIntrospector.class.getName());
@ -172,6 +177,49 @@ public class HandlerMappingIntrospector @@ -172,6 +177,49 @@ public class HandlerMappingIntrospector
return (this.handlerMappings != null ? this.handlerMappings : Collections.emptyList());
}
/**
* Return {@code true} if all {@link HandlerMapping} beans
* {@link HandlerMapping#usesPathPatterns() use parsed PathPatterns},
* and {@code false} if any don't.
* @since 6.2
*/
public boolean allHandlerMappingsUsePathPatternParser() {
Assert.state(this.handlerMappings != null, "Not yet initialized via afterPropertiesSet.");
return getHandlerMappings().stream().allMatch(HandlerMapping::usesPathPatterns);
}
/**
* Find the matching {@link HandlerMapping} for the request, and invoke the
* handler it returns as a {@link PreFlightRequestHandler}.
* @throws NoHandlerFoundException if no handler matches the request
* @since 6.2
*/
public void handlePreFlight(HttpServletRequest request, HttpServletResponse response) throws Exception {
Assert.state(this.handlerMappings != null, "Not yet initialized via afterPropertiesSet.");
Assert.state(CorsUtils.isPreFlightRequest(request), "Not a pre-flight request.");
RequestPath previousPath = (RequestPath) request.getAttribute(ServletRequestPathUtils.PATH_ATTRIBUTE);
try {
ServletRequestPathUtils.parseAndCache(request);
for (HandlerMapping mapping : this.handlerMappings) {
HandlerExecutionChain chain = mapping.getHandler(request);
if (chain != null) {
Object handler = chain.getHandler();
if (handler instanceof PreFlightRequestHandler preFlightHandler) {
preFlightHandler.handlePreFlight(request, response);
return;
}
throw new IllegalStateException("Expected PreFlightRequestHandler: " + handler.getClass());
}
}
throw new NoHandlerFoundException(
request.getMethod(), request.getRequestURI(), new ServletServerHttpRequest(request).getHeaders());
}
finally {
ServletRequestPathUtils.setParsedRequestPath(previousPath, request);
}
}
/**
* {@link Filter} that looks up the {@code MatchableHandlerMapping} and

9
spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java

@ -31,6 +31,7 @@ import org.springframework.web.HttpRequestHandler; @@ -31,6 +31,7 @@ import org.springframework.web.HttpRequestHandler;
import org.springframework.web.context.support.StaticWebApplicationContext;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.PreFlightRequestHandler;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.support.WebContentGenerator;
@ -72,7 +73,7 @@ class CorsAbstractHandlerMappingTests { @@ -72,7 +73,7 @@ class CorsAbstractHandlerMappingTests {
assertThat(chain).isNotNull();
assertThat(chain.getHandler()).isNotNull();
assertThat(chain.getHandler().getClass().getSimpleName()).isEqualTo("PreFlightHandler");
assertThat(chain.getHandler()).isInstanceOf(PreFlightRequestHandler.class);
assertThat(mapping.hasSavedCorsConfig()).isFalse();
}
@ -103,7 +104,7 @@ class CorsAbstractHandlerMappingTests { @@ -103,7 +104,7 @@ class CorsAbstractHandlerMappingTests {
assertThat(chain).isNotNull();
assertThat(chain.getHandler()).isNotNull();
assertThat(chain.getHandler().getClass().getSimpleName()).isEqualTo("PreFlightHandler");
assertThat(chain.getHandler()).isInstanceOf(PreFlightRequestHandler.class);
assertThat(mapping.getRequiredCorsConfig().getAllowedOrigins()).containsExactly("*");
}
@ -144,7 +145,7 @@ class CorsAbstractHandlerMappingTests { @@ -144,7 +145,7 @@ class CorsAbstractHandlerMappingTests {
assertThat(chain).isNotNull();
assertThat(chain.getHandler()).isNotNull();
assertThat(chain.getHandler().getClass().getSimpleName()).isEqualTo("PreFlightHandler");
assertThat(chain.getHandler()).isInstanceOf(PreFlightRequestHandler.class);
assertThat(mapping.getRequiredCorsConfig().getAllowedOrigins()).containsExactly("*");
}
@ -172,7 +173,7 @@ class CorsAbstractHandlerMappingTests { @@ -172,7 +173,7 @@ class CorsAbstractHandlerMappingTests {
assertThat(chain).isNotNull();
assertThat(chain.getHandler()).isNotNull();
assertThat(chain.getHandler().getClass().getSimpleName()).isEqualTo("PreFlightHandler");
assertThat(chain.getHandler()).isInstanceOf(PreFlightRequestHandler.class);
CorsConfiguration config = mapping.getRequiredCorsConfig();
assertThat(config).isNotNull();

60
spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java

@ -48,6 +48,7 @@ import org.springframework.web.cors.CorsConfiguration; @@ -48,6 +48,7 @@ import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.NoHandlerFoundException;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.RouterFunctions;
import org.springframework.web.servlet.function.ServerResponse;
@ -63,6 +64,7 @@ import org.springframework.web.util.pattern.PatternParseException; @@ -63,6 +64,7 @@ import org.springframework.web.util.pattern.PatternParseException;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.springframework.web.servlet.HandlerMapping.BEST_MATCHING_PATTERN_ATTRIBUTE;
/**
@ -113,6 +115,29 @@ class HandlerMappingIntrospectorTests { @@ -113,6 +115,29 @@ class HandlerMappingIntrospectorTests {
assertThat(actual).isEqualTo(expected);
}
@Test
void useParsedPatternsOnly() {
GenericWebApplicationContext context = new GenericWebApplicationContext();
context.registerBean("A", SimpleUrlHandlerMapping.class);
context.registerBean("B", SimpleUrlHandlerMapping.class);
context.registerBean("C", SimpleUrlHandlerMapping.class);
context.refresh();
assertThat(initIntrospector(context).allHandlerMappingsUsePathPatternParser()).isTrue();
context = new GenericWebApplicationContext();
context.registerBean("A", SimpleUrlHandlerMapping.class);
context.registerBean("B", SimpleUrlHandlerMapping.class);
context.registerBean("C", SimpleUrlHandlerMapping.class, () -> {
SimpleUrlHandlerMapping mapping = new SimpleUrlHandlerMapping();
mapping.setPatternParser(null);
return mapping;
});
context.refresh();
assertThat(initIntrospector(context).allHandlerMappingsUsePathPatternParser()).isFalse();
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void getMatchable(boolean usePathPatterns) throws Exception {
@ -204,6 +229,41 @@ class HandlerMappingIntrospectorTests { @@ -204,6 +229,41 @@ class HandlerMappingIntrospectorTests {
assertThat(corsConfig.getAllowedMethods()).isEqualTo(Collections.singletonList("POST"));
}
@Test
void handlePreFlight() throws Exception {
AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext();
context.register(TestConfig.class);
context.refresh();
MockHttpServletRequest request = new MockHttpServletRequest("OPTIONS", "/path");
request.addHeader("Origin", "http://localhost:9000");
request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST");
MockHttpServletResponse response = new MockHttpServletResponse();
initIntrospector(context).handlePreFlight(request, response);
assertThat(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)).isEqualTo("http://localhost:9000");
assertThat(response.getHeaders(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)).containsExactly("POST");
}
@Test
void handlePreFlightWithNoHandlerFoundException() {
AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext();
context.register(TestConfig.class);
context.refresh();
MockHttpServletRequest request = new MockHttpServletRequest("OPTIONS", "/unknownPath");
request.addHeader("Origin", "http://localhost:9000");
request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST");
MockHttpServletResponse response = new MockHttpServletResponse();
assertThatThrownBy(() -> initIntrospector(context).handlePreFlight(request, response))
.isInstanceOf(NoHandlerFoundException.class);
assertThat(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)).isNull();
assertThat(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)).isNull();
}
@ParameterizedTest
@ValueSource(strings = {"/test", "/resource/1234****"}) // gh-31937
void cacheFilter(String uri) throws Exception {

7
spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -47,6 +47,7 @@ import org.springframework.web.bind.annotation.RequestMapping; @@ -47,6 +47,7 @@ import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.context.support.StaticWebApplicationContext;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.PreFlightRequestHandler;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.handler.PathPatternsParameterizedTest;
@ -127,7 +128,7 @@ class CrossOriginTests { @@ -127,7 +128,7 @@ class CrossOriginTests {
request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = mapping.getHandler(request);
assertThat(chain).isNotNull();
assertThat(chain.getHandler().getClass().getName()).endsWith("AbstractHandlerMapping$PreFlightHandler");
assertThat(chain.getHandler()).isInstanceOf(PreFlightRequestHandler.class);
}
@PathPatternsParameterizedTest // SPR-12931
@ -389,7 +390,7 @@ class CrossOriginTests { @@ -389,7 +390,7 @@ class CrossOriginTests {
assertThat(chain).isNotNull();
if (isPreFlightRequest) {
Object handler = chain.getHandler();
assertThat(handler.getClass().getSimpleName()).isEqualTo("PreFlightHandler");
assertThat(handler).isInstanceOf(PreFlightRequestHandler.class);
DirectFieldAccessor accessor = new DirectFieldAccessor(handler);
return (CorsConfiguration)accessor.getPropertyValue("config");
}

Loading…
Cancel
Save