@ -1,5 +1,5 @@
/ *
/ *
* Copyright 2002 - 2015 the original author or authors .
* Copyright 2002 - 2018 the original author or authors .
*
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
* you may not use this file except in compliance with the License .
@ -16,15 +16,19 @@
package org.springframework.web.cors.reactive ;
package org.springframework.web.cors.reactive ;
import java.util.concurrent.atomic.AtomicReference ;
import org.junit.Test ;
import org.junit.Test ;
import reactor.core.publisher.Mono ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.server.reactive.ServerHttpRequest ;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest ;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest ;
import org.springframework.mock.web.test.server.MockServerWebExchange ;
import org.springframework.web.filter.reactive.ForwardedHeaderFilter ;
import static org.junit.Assert.assertFalse ;
import static org.junit.Assert.* ;
import static org.junit.Assert.assertTrue ;
import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.* ;
import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.get ;
import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.options ;
/ * *
/ * *
* Test case for reactive { @link CorsUtils } .
* Test case for reactive { @link CorsUtils } .
@ -35,19 +39,19 @@ public class CorsUtilsTests {
@Test
@Test
public void isCorsRequest ( ) {
public void isCorsRequest ( ) {
Mock ServerHttpRequest request = get ( "/" ) . header ( HttpHeaders . ORIGIN , "http://domain.com" ) . build ( ) ;
ServerHttpRequest request = get ( "/" ) . header ( HttpHeaders . ORIGIN , "http://domain.com" ) . build ( ) ;
assertTrue ( CorsUtils . isCorsRequest ( request ) ) ;
assertTrue ( CorsUtils . isCorsRequest ( request ) ) ;
}
}
@Test
@Test
public void isNotCorsRequest ( ) {
public void isNotCorsRequest ( ) {
Mock ServerHttpRequest request = get ( "/" ) . build ( ) ;
ServerHttpRequest request = get ( "/" ) . build ( ) ;
assertFalse ( CorsUtils . isCorsRequest ( request ) ) ;
assertFalse ( CorsUtils . isCorsRequest ( request ) ) ;
}
}
@Test
@Test
public void isPreFlightRequest ( ) {
public void isPreFlightRequest ( ) {
Mock ServerHttpRequest request = options ( "/" )
ServerHttpRequest request = options ( "/" )
. header ( HttpHeaders . ORIGIN , "http://domain.com" )
. header ( HttpHeaders . ORIGIN , "http://domain.com" )
. header ( HttpHeaders . ACCESS_CONTROL_REQUEST_METHOD , "GET" )
. header ( HttpHeaders . ACCESS_CONTROL_REQUEST_METHOD , "GET" )
. build ( ) ;
. build ( ) ;
@ -56,7 +60,7 @@ public class CorsUtilsTests {
@Test
@Test
public void isNotPreFlightRequest ( ) {
public void isNotPreFlightRequest ( ) {
Mock ServerHttpRequest request = get ( "/" ) . build ( ) ;
ServerHttpRequest request = get ( "/" ) . build ( ) ;
assertFalse ( CorsUtils . isPreFlightRequest ( request ) ) ;
assertFalse ( CorsUtils . isPreFlightRequest ( request ) ) ;
request = options ( "/" ) . header ( HttpHeaders . ORIGIN , "http://domain.com" ) . build ( ) ;
request = options ( "/" ) . header ( HttpHeaders . ORIGIN , "http://domain.com" ) . build ( ) ;
@ -68,31 +72,35 @@ public class CorsUtilsTests {
@Test // SPR-16262
@Test // SPR-16262
public void isSameOriginWithXForwardedHeaders ( ) {
public void isSameOriginWithXForwardedHeaders ( ) {
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , - 1 , "https" , null , - 1 , "https://mydomain1.com" ) ) ;
String server = "mydomain1.com" ;
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , 123 , "https" , null , - 1 , "https://mydomain1.com" ) ) ;
testWithXForwardedHeaders ( server , - 1 , "https" , null , - 1 , "https://mydomain1.com" ) ;
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , - 1 , "https" , "mydomain2.com" , - 1 , "https://mydomain2.com" ) ) ;
testWithXForwardedHeaders ( server , 123 , "https" , null , - 1 , "https://mydomain1.com" ) ;
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , 123 , "https" , "mydomain2.com" , - 1 , "https://mydomain2.com" ) ) ;
testWithXForwardedHeaders ( server , - 1 , "https" , "mydomain2.com" , - 1 , "https://mydomain2.com" ) ;
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , - 1 , "https" , "mydomain2.com" , 456 , "https://mydomain2.com:456" ) ) ;
testWithXForwardedHeaders ( server , 123 , "https" , "mydomain2.com" , - 1 , "https://mydomain2.com" ) ;
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , 123 , "https" , "mydomain2.com" , 456 , "https://mydomain2.com:456" ) ) ;
testWithXForwardedHeaders ( server , - 1 , "https" , "mydomain2.com" , 456 , "https://mydomain2.com:456" ) ;
testWithXForwardedHeaders ( server , 123 , "https" , "mydomain2.com" , 456 , "https://mydomain2.com:456" ) ;
}
}
@Test // SPR-16262
@Test // SPR-16262
public void isSameOriginWithForwardedHeader ( ) {
public void isSameOriginWithForwardedHeader ( ) {
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , - 1 , "proto=https" , "https://mydomain1.com" ) ) ;
String server = "mydomain1.com" ;
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , 123 , "proto=https" , "https://mydomain1.com" ) ) ;
testWithForwardedHeader ( server , - 1 , "proto=https" , "https://mydomain1.com" ) ;
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , - 1 , "proto=https; host=mydomain2.com" , "https://mydomain2.com" ) ) ;
testWithForwardedHeader ( server , 123 , "proto=https" , "https://mydomain1.com" ) ;
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , 123 , "proto=https; host=mydomain2.com" , "https://mydomain2.com" ) ) ;
testWithForwardedHeader ( server , - 1 , "proto=https; host=mydomain2.com" , "https://mydomain2.com" ) ;
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , - 1 , "proto=https; host=mydomain2.com:456" , "https://mydomain2.com:456" ) ) ;
testWithForwardedHeader ( server , 123 , "proto=https; host=mydomain2.com" , "https://mydomain2.com" ) ;
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , 123 , "proto=https; host=mydomain2.com:456" , "https://mydomain2.com:456" ) ) ;
testWithForwardedHeader ( server , - 1 , "proto=https; host=mydomain2.com:456" , "https://mydomain2.com:456" ) ;
testWithForwardedHeader ( server , 123 , "proto=https; host=mydomain2.com:456" , "https://mydomain2.com:456" ) ;
}
}
private boolean checkSameOriginWithXForwardedHeaders ( String serverName , int port , String forwardedProto , String forwardedHost , int forwardedPort , String originHeader ) {
private void testWithXForwardedHeaders ( String serverName , int port ,
String forwardedProto , String forwardedHost , int forwardedPort , String originHeader ) {
String url = "http://" + serverName ;
String url = "http://" + serverName ;
if ( port ! = - 1 ) {
if ( port ! = - 1 ) {
url = url + ":" + port ;
url = url + ":" + port ;
}
}
MockServerHttpRequest . BaseBuilder < ? > builder = get ( url )
. header ( HttpHeaders . ORIGIN , originHeader ) ;
MockServerHttpRequest . BaseBuilder < ? > builder = get ( url ) . header ( HttpHeaders . ORIGIN , originHeader ) ;
if ( forwardedProto ! = null ) {
if ( forwardedProto ! = null ) {
builder . header ( "X-Forwarded-Proto" , forwardedProto ) ;
builder . header ( "X-Forwarded-Proto" , forwardedProto ) ;
}
}
@ -102,18 +110,36 @@ public class CorsUtilsTests {
if ( forwardedPort ! = - 1 ) {
if ( forwardedPort ! = - 1 ) {
builder . header ( "X-Forwarded-Port" , String . valueOf ( forwardedPort ) ) ;
builder . header ( "X-Forwarded-Port" , String . valueOf ( forwardedPort ) ) ;
}
}
return CorsUtils . isSameOrigin ( builder . build ( ) ) ;
ServerHttpRequest request = adaptFromForwardedHeaders ( builder ) ;
assertTrue ( CorsUtils . isSameOrigin ( request ) ) ;
}
}
private boolean checkSameOriginWithForwardedHeader ( String serverName , int port , String forwardedHeader , String originHeader ) {
private void testWithForwardedHeader ( String serverName , int port ,
String forwardedHeader , String originHeader ) {
String url = "http://" + serverName ;
String url = "http://" + serverName ;
if ( port ! = - 1 ) {
if ( port ! = - 1 ) {
url = url + ":" + port ;
url = url + ":" + port ;
}
}
MockServerHttpRequest . BaseBuilder < ? > builder = get ( url )
MockServerHttpRequest . BaseBuilder < ? > builder = get ( url )
. header ( "Forwarded" , forwardedHeader )
. header ( "Forwarded" , forwardedHeader )
. header ( HttpHeaders . ORIGIN , originHeader ) ;
. header ( HttpHeaders . ORIGIN , originHeader ) ;
return CorsUtils . isSameOrigin ( builder . build ( ) ) ;
ServerHttpRequest request = adaptFromForwardedHeaders ( builder ) ;
assertTrue ( CorsUtils . isSameOrigin ( request ) ) ;
}
// SPR-16668
private ServerHttpRequest adaptFromForwardedHeaders ( MockServerHttpRequest . BaseBuilder < ? > builder ) {
AtomicReference < ServerHttpRequest > requestRef = new AtomicReference < > ( ) ;
MockServerWebExchange exchange = MockServerWebExchange . from ( builder ) ;
new ForwardedHeaderFilter ( ) . filter ( exchange , exchange2 - > {
requestRef . set ( exchange2 . getRequest ( ) ) ;
return Mono . empty ( ) ;
} ) . block ( ) ;
return requestRef . get ( ) ;
}
}
}
}