@ -160,6 +160,26 @@ public class WebUtilsTests {
@@ -160,6 +160,26 @@ public class WebUtilsTests {
assertFalse ( checkSameOrigin ( "[::1]" , 8080 , "http://[2001:0db8:0000:85a3:0000:0000:ac1f:8001]:8080" ) ) ;
}
@Test // SPR-16262
public void isSameOriginWithXForwardedHeaders ( ) {
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , - 1 , "https" , null , - 1 , "https://mydomain1.com" ) ) ;
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , 123 , "https" , null , - 1 , "https://mydomain1.com" ) ) ;
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , - 1 , "https" , "mydomain2.com" , - 1 , "https://mydomain2.com" ) ) ;
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , 123 , "https" , "mydomain2.com" , - 1 , "https://mydomain2.com" ) ) ;
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , - 1 , "https" , "mydomain2.com" , 456 , "https://mydomain2.com:456" ) ) ;
assertTrue ( checkSameOriginWithXForwardedHeaders ( "mydomain1.com" , 123 , "https" , "mydomain2.com" , 456 , "https://mydomain2.com:456" ) ) ;
}
@Test // SPR-16262
public void isSameOriginWithForwardedHeader ( ) {
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , - 1 , "proto=https" , "https://mydomain1.com" ) ) ;
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , 123 , "proto=https" , "https://mydomain1.com" ) ) ;
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , - 1 , "proto=https; host=mydomain2.com" , "https://mydomain2.com" ) ) ;
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , 123 , "proto=https; host=mydomain2.com" , "https://mydomain2.com" ) ) ;
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , - 1 , "proto=https; host=mydomain2.com:456" , "https://mydomain2.com:456" ) ) ;
assertTrue ( checkSameOriginWithForwardedHeader ( "mydomain1.com" , 123 , "proto=https; host=mydomain2.com:456" , "https://mydomain2.com:456" ) ) ;
}
private boolean checkValidOrigin ( String serverName , int port , String originHeader , List < String > allowed ) {
MockHttpServletRequest servletRequest = new MockHttpServletRequest ( ) ;
@ -183,4 +203,36 @@ public class WebUtilsTests {
@@ -183,4 +203,36 @@ public class WebUtilsTests {
return WebUtils . isSameOrigin ( request ) ;
}
private boolean checkSameOriginWithXForwardedHeaders ( String serverName , int port , String forwardedProto , String forwardedHost , int forwardedPort , String originHeader ) {
MockHttpServletRequest servletRequest = new MockHttpServletRequest ( ) ;
ServerHttpRequest request = new ServletServerHttpRequest ( servletRequest ) ;
servletRequest . setServerName ( serverName ) ;
if ( port ! = - 1 ) {
servletRequest . setServerPort ( port ) ;
}
if ( forwardedProto ! = null ) {
request . getHeaders ( ) . set ( "X-Forwarded-Proto" , forwardedProto ) ;
}
if ( forwardedHost ! = null ) {
request . getHeaders ( ) . set ( "X-Forwarded-Host" , forwardedHost ) ;
}
if ( forwardedPort ! = - 1 ) {
request . getHeaders ( ) . set ( "X-Forwarded-Port" , String . valueOf ( forwardedPort ) ) ;
}
request . getHeaders ( ) . set ( HttpHeaders . ORIGIN , originHeader ) ;
return WebUtils . isSameOrigin ( request ) ;
}
private boolean checkSameOriginWithForwardedHeader ( String serverName , int port , String forwardedHeader , String originHeader ) {
MockHttpServletRequest servletRequest = new MockHttpServletRequest ( ) ;
ServerHttpRequest request = new ServletServerHttpRequest ( servletRequest ) ;
servletRequest . setServerName ( serverName ) ;
if ( port ! = - 1 ) {
servletRequest . setServerPort ( port ) ;
}
request . getHeaders ( ) . set ( "Forwarded" , forwardedHeader ) ;
request . getHeaders ( ) . set ( HttpHeaders . ORIGIN , originHeader ) ;
return WebUtils . isSameOrigin ( request ) ;
}
}