diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java index df030d5be7c..e254fddac2a 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -91,6 +91,10 @@ public class HandlerExecutionChain { initInterceptorList().add(interceptor); } + public void addInterceptor(int index, HandlerInterceptor interceptor) { + initInterceptorList().add(index, interceptor); + } + public void addInterceptors(HandlerInterceptor... interceptors) { if (!ObjectUtils.isEmpty(interceptors)) { CollectionUtils.mergeArrayIntoCollection(interceptors, initInterceptorList()); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java index bdb0600650b..4d2a9558b18 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java @@ -526,7 +526,7 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport chain = new HandlerExecutionChain(new PreFlightHandler(config), interceptors); } else { - chain.addInterceptor(new CorsInterceptor(config)); + chain.addInterceptor(0, new CorsInterceptor(config)); } return chain; } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java index f037784ba68..c54b1640847 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java @@ -86,16 +86,11 @@ import org.springframework.web.servlet.view.ViewResolverComposite; import org.springframework.web.servlet.view.json.MappingJackson2JsonView; import org.springframework.web.util.UrlPathHelper; -import static com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES; -import static com.fasterxml.jackson.databind.MapperFeature.DEFAULT_VIEW_INCLUSION; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertSame; +import static com.fasterxml.jackson.databind.DeserializationFeature.*; +import static com.fasterxml.jackson.databind.MapperFeature.*; +import static org.junit.Assert.*; import static org.mockito.Mockito.*; -import static org.springframework.http.MediaType.APPLICATION_ATOM_XML; -import static org.springframework.http.MediaType.APPLICATION_JSON; -import static org.springframework.http.MediaType.APPLICATION_XML; +import static org.springframework.http.MediaType.*; /** * A test fixture with a sub-class of {@link WebMvcConfigurationSupport} that also @@ -141,9 +136,10 @@ public class WebMvcConfigurationSupportExtensionTests { assertNotNull(chain); assertNotNull(chain.getInterceptors()); assertEquals(4, chain.getInterceptors().length); - assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[0].getClass()); - assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[1].getClass()); - assertEquals(ResourceUrlProviderExposingInterceptor.class, chain.getInterceptors()[2].getClass()); + assertEquals("CorsInterceptor", chain.getInterceptors()[0].getClass().getSimpleName()); + assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[1].getClass()); + assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[2].getClass()); + assertEquals(ResourceUrlProviderExposingInterceptor.class, chain.getInterceptors()[3].getClass()); Map map = rmHandlerMapping.getHandlerMethods(); assertEquals(2, map.size()); @@ -185,10 +181,11 @@ public class WebMvcConfigurationSupportExtensionTests { assertNotNull(chain); assertNotNull(chain.getHandler()); assertEquals(Arrays.toString(chain.getInterceptors()), 5, chain.getInterceptors().length); - // PathExposingHandlerInterceptor at chain.getInterceptors()[0] - assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[1].getClass()); - assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[2].getClass()); - assertEquals(ResourceUrlProviderExposingInterceptor.class, chain.getInterceptors()[3].getClass()); + assertEquals("CorsInterceptor", chain.getInterceptors()[0].getClass().getSimpleName()); + // PathExposingHandlerInterceptor at chain.getInterceptors()[1] + assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[2].getClass()); + assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[3].getClass()); + assertEquals(ResourceUrlProviderExposingInterceptor.class, chain.getInterceptors()[4].getClass()); handlerMapping = (AbstractHandlerMapping) this.config.defaultServletHandlerMapping(); handlerMapping.setApplicationContext(this.context); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java index d400c4ca3ec..df97aeda198 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java @@ -16,11 +16,8 @@ package org.springframework.web.servlet.handler; -import static org.junit.Assert.*; - import java.io.IOException; import java.util.Collections; - import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -32,6 +29,7 @@ import org.springframework.beans.DirectFieldAccessor; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.util.ObjectUtils; import org.springframework.web.HttpRequestHandler; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.context.support.StaticWebApplicationContext; @@ -41,6 +39,9 @@ import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.support.WebContentGenerator; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + /** * Unit tests for CORS-related handling in {@link AbstractHandlerMapping}. * @author Sebastien Deleuze @@ -57,6 +58,7 @@ public class CorsAbstractHandlerMappingTests { public void setup() { StaticWebApplicationContext context = new StaticWebApplicationContext(); this.handlerMapping = new TestHandlerMapping(); + this.handlerMapping.setInterceptors(mock(HandlerInterceptor.class)); this.handlerMapping.setApplicationContext(context); this.request = new MockHttpServletRequest(); this.request.setRemoteHost("domain1.com"); @@ -69,6 +71,7 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertTrue(chain.getHandler() instanceof SimpleHandler); } @@ -80,6 +83,7 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertTrue(chain.getHandler() instanceof SimpleHandler); } @@ -91,11 +95,10 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertTrue(chain.getHandler() instanceof CorsAwareHandler); - CorsConfiguration config = getCorsConfiguration(chain, false); - assertNotNull(config); - assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"}); + assertEquals(Collections.singletonList("*"), getRequiredCorsConfiguration(chain, false).getAllowedOrigins()); } @Test @@ -105,12 +108,11 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertNotNull(chain.getHandler()); - assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler")); - CorsConfiguration config = getCorsConfiguration(chain, true); - assertNotNull(config); - assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"}); + assertEquals("PreFlightHandler", chain.getHandler().getClass().getSimpleName()); + assertEquals(Collections.singletonList("*"), getRequiredCorsConfiguration(chain, true).getAllowedOrigins()); } @Test @@ -123,11 +125,10 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertTrue(chain.getHandler() instanceof SimpleHandler); - config = getCorsConfiguration(chain, false); - assertNotNull(config); - assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"}); + assertEquals(Collections.singletonList("*"), getRequiredCorsConfiguration(chain, false).getAllowedOrigins()); } @Test @@ -140,12 +141,11 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertNotNull(chain.getHandler()); - assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler")); - config = getCorsConfiguration(chain, true); - assertNotNull(config); - assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"}); + assertEquals("PreFlightHandler", chain.getHandler().getClass().getSimpleName()); + assertEquals(Collections.singletonList("*"), getRequiredCorsConfiguration(chain, true).getAllowedOrigins()); } @Test @@ -156,11 +156,12 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertTrue(chain.getHandler() instanceof SimpleHandler); - CorsConfiguration config = getCorsConfiguration(chain, false); + CorsConfiguration config = getRequiredCorsConfiguration(chain, false); assertNotNull(config); - assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray()); + assertEquals(Collections.singletonList("*"), config.getAllowedOrigins()); assertEquals(true, config.getAllowCredentials()); } @@ -172,35 +173,35 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + assertNotNull(chain); assertNotNull(chain.getHandler()); - assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler")); - CorsConfiguration config = getCorsConfiguration(chain, true); + assertEquals("PreFlightHandler", chain.getHandler().getClass().getSimpleName()); + CorsConfiguration config = getRequiredCorsConfiguration(chain, true); assertNotNull(config); - assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray()); + assertEquals(Collections.singletonList("*"), config.getAllowedOrigins()); assertEquals(true, config.getAllowCredentials()); } - private CorsConfiguration getCorsConfiguration(HandlerExecutionChain chain, boolean isPreFlightRequest) { + @SuppressWarnings("ConstantConditions") + private CorsConfiguration getRequiredCorsConfiguration(HandlerExecutionChain chain, boolean isPreFlightRequest) { + CorsConfiguration corsConfig = null; if (isPreFlightRequest) { Object handler = chain.getHandler(); - assertTrue(handler.getClass().getSimpleName().equals("PreFlightHandler")); + assertEquals("PreFlightHandler", handler.getClass().getSimpleName()); DirectFieldAccessor accessor = new DirectFieldAccessor(handler); - return (CorsConfiguration)accessor.getPropertyValue("config"); + corsConfig = (CorsConfiguration) accessor.getPropertyValue("config"); } else { HandlerInterceptor[] interceptors = chain.getInterceptors(); - if (interceptors != null) { - for (HandlerInterceptor interceptor : interceptors) { - if (interceptor.getClass().getSimpleName().equals("CorsInterceptor")) { - DirectFieldAccessor accessor = new DirectFieldAccessor(interceptor); - return (CorsConfiguration) accessor.getPropertyValue("config"); - } - } + if (!ObjectUtils.isEmpty(interceptors)) { + DirectFieldAccessor accessor = new DirectFieldAccessor(interceptors[0]); + corsConfig = (CorsConfiguration) accessor.getPropertyValue("config"); } } - return null; + assertNotNull(corsConfig); + return corsConfig; } public class TestHandlerMapping extends AbstractHandlerMapping {