@ -15,10 +15,12 @@
@@ -15,10 +15,12 @@
* /
package org.springframework.web.filter ;
import java.io.IOException ;
import javax.servlet.ServletException ;
import javax.servlet.http.HttpServlet ;
import javax.servlet.http.HttpServletRequest ;
import org.junit.Before ;
import org.junit.Test ;
import org.springframework.mock.web.test.MockFilterChain ;
@ -38,6 +40,98 @@ public class ForwardedHeaderFilterTests {
@@ -38,6 +40,98 @@ public class ForwardedHeaderFilterTests {
private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter ( ) ;
private MockHttpServletRequest request ;
private MockFilterChain filterChain ;
@Before
public void setUp ( ) throws Exception {
this . request = new MockHttpServletRequest ( ) ;
this . request . setScheme ( "http" ) ;
this . request . setServerName ( "localhost" ) ;
this . request . setServerPort ( 80 ) ;
this . filterChain = new MockFilterChain ( new HttpServlet ( ) { } ) ;
}
@Test ( expected = IllegalArgumentException . class )
public void contextPathNull ( ) {
this . filter . setContextPath ( null ) ;
}
@Test
public void contextPathEmpty ( ) throws Exception {
this . filter . setContextPath ( "" ) ;
assertEquals ( "" , filterAndGetContextPath ( ) ) ;
}
@Test
public void contextPathWithExtraSpaces ( ) throws Exception {
this . filter . setContextPath ( " /foo " ) ;
assertEquals ( "/foo" , filterAndGetContextPath ( ) ) ;
}
@Test
public void contextPathWithNoLeadingSlash ( ) throws Exception {
this . filter . setContextPath ( "foo" ) ;
assertEquals ( "/foo" , filterAndGetContextPath ( ) ) ;
}
@Test
public void contextPathWithTrailingSlash ( ) throws Exception {
this . filter . setContextPath ( "/foo/bar/" ) ;
assertEquals ( "/foo/bar" , filterAndGetContextPath ( ) ) ;
}
@Test
public void contextPathWithTrailingSlashes ( ) throws Exception {
this . filter . setContextPath ( "/foo/bar/baz///" ) ;
assertEquals ( "/foo/bar/baz" , filterAndGetContextPath ( ) ) ;
}
@Test
public void requestUri ( ) throws Exception {
this . filter . setContextPath ( "/" ) ;
this . request . setContextPath ( "/app" ) ;
this . request . setRequestURI ( "/app/path" ) ;
HttpServletRequest actual = filterAndGetWrappedRequest ( ) ;
assertEquals ( "/" , actual . getContextPath ( ) ) ;
assertEquals ( "/path" , actual . getRequestURI ( ) ) ;
}
@Test
public void requestUriWithTrailingSlash ( ) throws Exception {
this . filter . setContextPath ( "/" ) ;
this . request . setContextPath ( "/app" ) ;
this . request . setRequestURI ( "/app/path/" ) ;
HttpServletRequest actual = filterAndGetWrappedRequest ( ) ;
assertEquals ( "/" , actual . getContextPath ( ) ) ;
assertEquals ( "/path/" , actual . getRequestURI ( ) ) ;
}
@Test
public void requestUriEqualsContextPath ( ) throws Exception {
this . filter . setContextPath ( "/" ) ;
this . request . setContextPath ( "/app" ) ;
this . request . setRequestURI ( "/app" ) ;
HttpServletRequest actual = filterAndGetWrappedRequest ( ) ;
assertEquals ( "/" , actual . getContextPath ( ) ) ;
assertEquals ( "/" , actual . getRequestURI ( ) ) ;
}
@Test
public void requestUriRootUrl ( ) throws Exception {
this . filter . setContextPath ( "/" ) ;
this . request . setContextPath ( "/app" ) ;
this . request . setRequestURI ( "/app/" ) ;
HttpServletRequest actual = filterAndGetWrappedRequest ( ) ;
assertEquals ( "/" , actual . getContextPath ( ) ) ;
assertEquals ( "/" , actual . getRequestURI ( ) ) ;
}
@Test
public void shouldFilter ( ) throws Exception {
@ -54,19 +148,14 @@ public class ForwardedHeaderFilterTests {
@@ -54,19 +148,14 @@ public class ForwardedHeaderFilterTests {
@Test
public void forwardedRequest ( ) throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest ( ) ;
request . setScheme ( "http" ) ;
request . setServerName ( "localhost" ) ;
request . setServerPort ( 80 ) ;
request . setRequestURI ( "/mvc-showcase" ) ;
request . addHeader ( "X-Forwarded-Proto" , "https" ) ;
request . addHeader ( "X-Forwarded-Host" , "84.198.58.199" ) ;
request . addHeader ( "X-Forwarded-Port" , "443" ) ;
request . addHeader ( "foo" , "bar" ) ;
MockFilterChain chain = new MockFilterChain ( new HttpServlet ( ) { } ) ;
this . filter . doFilter ( request , new MockHttpServletResponse ( ) , chain ) ;
HttpServletRequest actual = ( HttpServletRequest ) chain . getRequest ( ) ;
this . request . setRequestURI ( "/mvc-showcase" ) ;
this . request . addHeader ( "X-Forwarded-Proto" , "https" ) ;
this . request . addHeader ( "X-Forwarded-Host" , "84.198.58.199" ) ;
this . request . addHeader ( "X-Forwarded-Port" , "443" ) ;
this . request . addHeader ( "foo" , "bar" ) ;
this . filter . doFilter ( this . request , new MockHttpServletResponse ( ) , this . filterChain ) ;
HttpServletRequest actual = ( HttpServletRequest ) this . filterChain . getRequest ( ) ;
assertEquals ( "https://84.198.58.199/mvc-showcase" , actual . getRequestURL ( ) . toString ( ) ) ;
assertEquals ( "https" , actual . getScheme ( ) ) ;
@ -81,11 +170,20 @@ public class ForwardedHeaderFilterTests {
@@ -81,11 +170,20 @@ public class ForwardedHeaderFilterTests {
}
private String filterAndGetContextPath ( ) throws ServletException , IOException {
return filterAndGetWrappedRequest ( ) . getContextPath ( ) ;
}
private HttpServletRequest filterAndGetWrappedRequest ( ) throws ServletException , IOException {
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
this . filter . doFilterInternal ( this . request , response , this . filterChain ) ;
return ( HttpServletRequest ) this . filterChain . getRequest ( ) ;
}
private void testShouldFilter ( String headerName ) throws ServletException {
MockHttpServletRequest request = new MockHttpServletRequest ( ) ;
request . addHeader ( headerName , "1" ) ;
assertFalse ( this . filter . shouldNotFilter ( request ) ) ;
}
}