@ -16,11 +16,8 @@
@@ -16,11 +16,8 @@
package org.springframework.web.servlet.handler ;
import static org.junit.Assert.* ;
import java.io.IOException ;
import java.util.Collections ;
import javax.servlet.ServletException ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
@ -32,6 +29,7 @@ import org.springframework.beans.DirectFieldAccessor;
@@ -32,6 +29,7 @@ import org.springframework.beans.DirectFieldAccessor;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.HttpStatus ;
import org.springframework.mock.web.test.MockHttpServletRequest ;
import org.springframework.util.ObjectUtils ;
import org.springframework.web.HttpRequestHandler ;
import org.springframework.web.bind.annotation.RequestMethod ;
import org.springframework.web.context.support.StaticWebApplicationContext ;
@ -41,6 +39,9 @@ import org.springframework.web.servlet.HandlerExecutionChain;
@@ -41,6 +39,9 @@ import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerInterceptor ;
import org.springframework.web.servlet.support.WebContentGenerator ;
import static org.junit.Assert.* ;
import static org.mockito.Mockito.* ;
/ * *
* Unit tests for CORS - related handling in { @link AbstractHandlerMapping } .
* @author Sebastien Deleuze
@ -57,6 +58,7 @@ public class CorsAbstractHandlerMappingTests {
@@ -57,6 +58,7 @@ public class CorsAbstractHandlerMappingTests {
public void setup ( ) {
StaticWebApplicationContext context = new StaticWebApplicationContext ( ) ;
this . handlerMapping = new TestHandlerMapping ( ) ;
this . handlerMapping . setInterceptors ( mock ( HandlerInterceptor . class ) ) ;
this . handlerMapping . setApplicationContext ( context ) ;
this . request = new MockHttpServletRequest ( ) ;
this . request . setRemoteHost ( "domain1.com" ) ;
@ -69,6 +71,7 @@ public class CorsAbstractHandlerMappingTests {
@@ -69,6 +71,7 @@ public class CorsAbstractHandlerMappingTests {
this . request . addHeader ( HttpHeaders . ORIGIN , "https://domain2.com" ) ;
this . request . addHeader ( HttpHeaders . ACCESS_CONTROL_REQUEST_METHOD , "GET" ) ;
HandlerExecutionChain chain = handlerMapping . getHandler ( this . request ) ;
assertNotNull ( chain ) ;
assertTrue ( chain . getHandler ( ) instanceof SimpleHandler ) ;
}
@ -80,6 +83,7 @@ public class CorsAbstractHandlerMappingTests {
@@ -80,6 +83,7 @@ public class CorsAbstractHandlerMappingTests {
this . request . addHeader ( HttpHeaders . ORIGIN , "https://domain2.com" ) ;
this . request . addHeader ( HttpHeaders . ACCESS_CONTROL_REQUEST_METHOD , "GET" ) ;
HandlerExecutionChain chain = handlerMapping . getHandler ( this . request ) ;
assertNotNull ( chain ) ;
assertTrue ( chain . getHandler ( ) instanceof SimpleHandler ) ;
}
@ -91,11 +95,10 @@ public class CorsAbstractHandlerMappingTests {
@@ -91,11 +95,10 @@ public class CorsAbstractHandlerMappingTests {
this . request . addHeader ( HttpHeaders . ORIGIN , "https://domain2.com" ) ;
this . request . addHeader ( HttpHeaders . ACCESS_CONTROL_REQUEST_METHOD , "GET" ) ;
HandlerExecutionChain chain = handlerMapping . getHandler ( this . request ) ;
assertNotNull ( chain ) ;
assertTrue ( chain . getHandler ( ) instanceof CorsAwareHandler ) ;
CorsConfiguration config = getCorsConfiguration ( chain , false ) ;
assertNotNull ( config ) ;
assertArrayEquals ( config . getAllowedOrigins ( ) . toArray ( ) , new String [ ] { "*" } ) ;
assertEquals ( Collections . singletonList ( "*" ) , getRequiredCorsConfiguration ( chain , false ) . getAllowedOrigins ( ) ) ;
}
@Test
@ -105,12 +108,11 @@ public class CorsAbstractHandlerMappingTests {
@@ -105,12 +108,11 @@ public class CorsAbstractHandlerMappingTests {
this . request . addHeader ( HttpHeaders . ORIGIN , "https://domain2.com" ) ;
this . request . addHeader ( HttpHeaders . ACCESS_CONTROL_REQUEST_METHOD , "GET" ) ;
HandlerExecutionChain chain = handlerMapping . getHandler ( this . request ) ;
assertNotNull ( chain ) ;
assertNotNull ( chain . getHandler ( ) ) ;
assertTrue ( chain . getHandler ( ) . getClass ( ) . getSimpleName ( ) . equals ( "PreFlightHandler" ) ) ;
CorsConfiguration config = getCorsConfiguration ( chain , true ) ;
assertNotNull ( config ) ;
assertArrayEquals ( config . getAllowedOrigins ( ) . toArray ( ) , new String [ ] { "*" } ) ;
assertEquals ( "PreFlightHandler" , chain . getHandler ( ) . getClass ( ) . getSimpleName ( ) ) ;
assertEquals ( Collections . singletonList ( "*" ) , getRequiredCorsConfiguration ( chain , true ) . getAllowedOrigins ( ) ) ;
}
@Test
@ -123,11 +125,10 @@ public class CorsAbstractHandlerMappingTests {
@@ -123,11 +125,10 @@ public class CorsAbstractHandlerMappingTests {
this . request . addHeader ( HttpHeaders . ORIGIN , "https://domain2.com" ) ;
this . request . addHeader ( HttpHeaders . ACCESS_CONTROL_REQUEST_METHOD , "GET" ) ;
HandlerExecutionChain chain = handlerMapping . getHandler ( this . request ) ;
assertNotNull ( chain ) ;
assertTrue ( chain . getHandler ( ) instanceof SimpleHandler ) ;
config = getCorsConfiguration ( chain , false ) ;
assertNotNull ( config ) ;
assertArrayEquals ( config . getAllowedOrigins ( ) . toArray ( ) , new String [ ] { "*" } ) ;
assertEquals ( Collections . singletonList ( "*" ) , getRequiredCorsConfiguration ( chain , false ) . getAllowedOrigins ( ) ) ;
}
@Test
@ -140,12 +141,11 @@ public class CorsAbstractHandlerMappingTests {
@@ -140,12 +141,11 @@ public class CorsAbstractHandlerMappingTests {
this . request . addHeader ( HttpHeaders . ORIGIN , "https://domain2.com" ) ;
this . request . addHeader ( HttpHeaders . ACCESS_CONTROL_REQUEST_METHOD , "GET" ) ;
HandlerExecutionChain chain = handlerMapping . getHandler ( this . request ) ;
assertNotNull ( chain ) ;
assertNotNull ( chain . getHandler ( ) ) ;
assertTrue ( chain . getHandler ( ) . getClass ( ) . getSimpleName ( ) . equals ( "PreFlightHandler" ) ) ;
config = getCorsConfiguration ( chain , true ) ;
assertNotNull ( config ) ;
assertArrayEquals ( config . getAllowedOrigins ( ) . toArray ( ) , new String [ ] { "*" } ) ;
assertEquals ( "PreFlightHandler" , chain . getHandler ( ) . getClass ( ) . getSimpleName ( ) ) ;
assertEquals ( Collections . singletonList ( "*" ) , getRequiredCorsConfiguration ( chain , true ) . getAllowedOrigins ( ) ) ;
}
@Test
@ -156,11 +156,12 @@ public class CorsAbstractHandlerMappingTests {
@@ -156,11 +156,12 @@ public class CorsAbstractHandlerMappingTests {
this . request . addHeader ( HttpHeaders . ORIGIN , "https://domain2.com" ) ;
this . request . addHeader ( HttpHeaders . ACCESS_CONTROL_REQUEST_METHOD , "GET" ) ;
HandlerExecutionChain chain = handlerMapping . getHandler ( this . request ) ;
assertNotNull ( chain ) ;
assertTrue ( chain . getHandler ( ) instanceof SimpleHandler ) ;
CorsConfiguration config = getCorsConfiguration ( chain , false ) ;
CorsConfiguration config = getRequired CorsConfiguration ( chain , false ) ;
assertNotNull ( config ) ;
assertArray Equals ( new String [ ] { "*" } , config . getAllowedOrigins ( ) . toArray ( ) ) ;
assertEquals ( Collections . singletonList ( "*" ) , config . getAllowedOrigins ( ) ) ;
assertEquals ( true , config . getAllowCredentials ( ) ) ;
}
@ -172,35 +173,35 @@ public class CorsAbstractHandlerMappingTests {
@@ -172,35 +173,35 @@ public class CorsAbstractHandlerMappingTests {
this . request . addHeader ( HttpHeaders . ORIGIN , "https://domain2.com" ) ;
this . request . addHeader ( HttpHeaders . ACCESS_CONTROL_REQUEST_METHOD , "GET" ) ;
HandlerExecutionChain chain = handlerMapping . getHandler ( this . request ) ;
assertNotNull ( chain ) ;
assertNotNull ( chain . getHandler ( ) ) ;
assertTrue ( chain . getHandler ( ) . getClass ( ) . getSimpleName ( ) . equals ( "PreFlightHandler" ) ) ;
CorsConfiguration config = getCorsConfiguration ( chain , true ) ;
assertEquals ( "PreFlightHandler" , chain . getHandler ( ) . getClass ( ) . getSimpleName ( ) ) ;
CorsConfiguration config = getRequired CorsConfiguration ( chain , true ) ;
assertNotNull ( config ) ;
assertArray Equals ( new String [ ] { "*" } , config . getAllowedOrigins ( ) . toArray ( ) ) ;
assertEquals ( Collections . singletonList ( "*" ) , config . getAllowedOrigins ( ) ) ;
assertEquals ( true , config . getAllowCredentials ( ) ) ;
}
private CorsConfiguration getCorsConfiguration ( HandlerExecutionChain chain , boolean isPreFlightRequest ) {
@SuppressWarnings ( "ConstantConditions" )
private CorsConfiguration getRequiredCorsConfiguration ( HandlerExecutionChain chain , boolean isPreFlightRequest ) {
CorsConfiguration corsConfig = null ;
if ( isPreFlightRequest ) {
Object handler = chain . getHandler ( ) ;
assertTrue ( handler . getClass ( ) . getSimpleName ( ) . equals ( "PreFlightHandler" ) ) ;
assertEquals ( "PreFlightHandler" , handler . getClass ( ) . getSimpleName ( ) ) ;
DirectFieldAccessor accessor = new DirectFieldAccessor ( handler ) ;
return ( CorsConfiguration ) accessor . getPropertyValue ( "config" ) ;
corsConfig = ( CorsConfiguration ) accessor . getPropertyValue ( "config" ) ;
}
else {
HandlerInterceptor [ ] interceptors = chain . getInterceptors ( ) ;
if ( interceptors ! = null ) {
for ( HandlerInterceptor interceptor : interceptors ) {
if ( interceptor . getClass ( ) . getSimpleName ( ) . equals ( "CorsInterceptor" ) ) {
DirectFieldAccessor accessor = new DirectFieldAccessor ( interceptor ) ;
return ( CorsConfiguration ) accessor . getPropertyValue ( "config" ) ;
}
}
if ( ! ObjectUtils . isEmpty ( interceptors ) ) {
DirectFieldAccessor accessor = new DirectFieldAccessor ( interceptors [ 0 ] ) ;
corsConfig = ( CorsConfiguration ) accessor . getPropertyValue ( "config" ) ;
}
}
return null ;
assertNotNull ( corsConfig ) ;
return corsConfig ;
}
public class TestHandlerMapping extends AbstractHandlerMapping {