@ -106,37 +106,47 @@ public class WebUtilsTests {
}
}
@Test
@Test
public void isValidOriginSuccess ( ) {
public void isValidOrigin ( ) {
List < String > allowed = Collections . emptyList ( ) ;
List < String > allowed = Collections . emptyList ( ) ;
assertTrue ( checkOrigin ( "mydomain1.com" , - 1 , "http://mydomain1.com" , allowed ) ) ;
assertTrue ( checkValidOrigin ( "mydomain1.com" , - 1 , "http://mydomain1.com" , allowed ) ) ;
assertTrue ( checkOrigin ( "mydomain1.com" , - 1 , "http://mydomain1.com:80" , allowed ) ) ;
assertFalse ( checkValidOrigin ( "mydomain1.com" , - 1 , "http://mydomain2.com" , allowed ) ) ;
assertTrue ( checkOrigin ( "mydomain1.com" , 443 , "https://mydomain1.com" , allowed ) ) ;
assertTrue ( checkOrigin ( "mydomain1.com" , 443 , "https://mydomain1.com:443" , allowed ) ) ;
assertTrue ( checkOrigin ( "mydomain1.com" , 123 , "http://mydomain1.com:123" , allowed ) ) ;
assertTrue ( checkOrigin ( "mydomain1.com" , - 1 , "ws://mydomain1.com" , allowed ) ) ;
assertTrue ( checkOrigin ( "mydomain1.com" , 443 , "wss://mydomain1.com" , allowed ) ) ;
allowed = Collections . singletonList ( "*" ) ;
allowed = Collections . singletonList ( "*" ) ;
assertTrue ( checkOrigin ( "mydomain1.com" , - 1 , "http://mydomain2.com" , allowed ) ) ;
assertTrue ( checkValidOrigin ( "mydomain1.com" , - 1 , "http://mydomain2.com" , allowed ) ) ;
allowed = Collections . singletonList ( "http://mydomain1.com" ) ;
allowed = Collections . singletonList ( "http://mydomain1.com" ) ;
assertTrue ( checkOrigin ( "mydomain2.com" , - 1 , "http://mydomain1.com" , allowed ) ) ;
assertTrue ( checkValidOrigin ( "mydomain2.com" , - 1 , "http://mydomain1.com" , allowed ) ) ;
assertFalse ( checkValidOrigin ( "mydomain2.com" , - 1 , "http://mydomain3.com" , allowed ) ) ;
}
}
@Test
@Test
public void isValidOriginFailure ( ) {
public void isSameOrigin ( ) {
assertTrue ( checkSameOrigin ( "mydomain1.com" , - 1 , "http://mydomain1.com" ) ) ;
assertTrue ( checkSameOrigin ( "mydomain1.com" , - 1 , "http://mydomain1.com:80" ) ) ;
assertTrue ( checkSameOrigin ( "mydomain1.com" , 443 , "https://mydomain1.com" ) ) ;
assertTrue ( checkSameOrigin ( "mydomain1.com" , 443 , "https://mydomain1.com:443" ) ) ;
assertTrue ( checkSameOrigin ( "mydomain1.com" , 123 , "http://mydomain1.com:123" ) ) ;
assertTrue ( checkSameOrigin ( "mydomain1.com" , - 1 , "ws://mydomain1.com" ) ) ;
assertTrue ( checkSameOrigin ( "mydomain1.com" , 443 , "wss://mydomain1.com" ) ) ;
assertFalse ( checkSameOrigin ( "mydomain1.com" , - 1 , "http://mydomain2.com" ) ) ;
assertFalse ( checkSameOrigin ( "mydomain1.com" , - 1 , "https://mydomain1.com" ) ) ;
assertFalse ( checkSameOrigin ( "mydomain1.com" , - 1 , "invalid-origin" ) ) ;
}
List < String > allowed = Collections . emptyList ( ) ;
assertFalse ( checkOrigin ( "mydomain1.com" , - 1 , "http://mydomain2.com" , allowed ) ) ;
assertFalse ( checkOrigin ( "mydomain1.com" , - 1 , "https://mydomain1.com" , allowed ) ) ;
assertFalse ( checkOrigin ( "mydomain1.com" , - 1 , "invalid-origin" , allowed ) ) ;
allowed = Collections . singletonList ( "http://mydomain1.com" ) ;
private boolean checkValidOrigin ( String serverName , int port , String originHeader , List < String > allowed ) {
assertFalse ( checkOrigin ( "mydomain2.com" , - 1 , "http://mydomain3.com" , allowed ) ) ;
MockHttpServletRequest servletRequest = new MockHttpServletRequest ( ) ;
ServerHttpRequest request = new ServletServerHttpRequest ( servletRequest ) ;
servletRequest . setServerName ( serverName ) ;
if ( port ! = - 1 ) {
servletRequest . setServerPort ( port ) ;
}
request . getHeaders ( ) . set ( HttpHeaders . ORIGIN , originHeader ) ;
return WebUtils . isValidOrigin ( request , allowed ) ;
}
}
private boolean checkOrigin ( String serverName , int port , String originHeader , List < String > allowed ) {
private boolean checkSame Origin ( String serverName , int port , String originHeader ) {
MockHttpServletRequest servletRequest = new MockHttpServletRequest ( ) ;
MockHttpServletRequest servletRequest = new MockHttpServletRequest ( ) ;
ServerHttpRequest request = new ServletServerHttpRequest ( servletRequest ) ;
ServerHttpRequest request = new ServletServerHttpRequest ( servletRequest ) ;
servletRequest . setServerName ( serverName ) ;
servletRequest . setServerName ( serverName ) ;
@ -144,7 +154,7 @@ public class WebUtilsTests {
servletRequest . setServerPort ( port ) ;
servletRequest . setServerPort ( port ) ;
}
}
request . getHeaders ( ) . set ( HttpHeaders . ORIGIN , originHeader ) ;
request . getHeaders ( ) . set ( HttpHeaders . ORIGIN , originHeader ) ;
return WebUtils . isValid Origin ( request , allowed ) ;
return WebUtils . isSame Origin ( request ) ;
}
}
}
}