@ -16,12 +16,16 @@
@@ -16,12 +16,16 @@
package org.springframework.security.web.context ;
import java.io.IOException ;
import java.lang.annotation.ElementType ;
import java.lang.annotation.Retention ;
import java.lang.annotation.RetentionPolicy ;
import java.lang.annotation.Target ;
import javax.servlet.Filter ;
import javax.servlet.ServletException ;
import javax.servlet.ServletOutputStream ;
import javax.servlet.http.HttpServlet ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletRequestWrapper ;
import javax.servlet.http.HttpServletResponse ;
@ -31,6 +35,7 @@ import javax.servlet.http.HttpSession;
@@ -31,6 +35,7 @@ import javax.servlet.http.HttpSession;
import org.junit.After ;
import org.junit.Test ;
import org.springframework.mock.web.MockFilterChain ;
import org.springframework.mock.web.MockHttpServletRequest ;
import org.springframework.mock.web.MockHttpServletResponse ;
import org.springframework.mock.web.MockHttpSession ;
@ -38,10 +43,14 @@ import org.springframework.security.authentication.AbstractAuthenticationToken;
@@ -38,10 +43,14 @@ import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AnonymousAuthenticationToken ;
import org.springframework.security.authentication.AuthenticationTrustResolver ;
import org.springframework.security.authentication.TestingAuthenticationToken ;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken ;
import org.springframework.security.core.Transient ;
import org.springframework.security.core.authority.AuthorityUtils ;
import org.springframework.security.core.context.SecurityContext ;
import org.springframework.security.core.context.SecurityContextHolder ;
import org.springframework.security.core.context.SecurityContextImpl ;
import org.springframework.security.core.userdetails.User ;
import org.springframework.security.core.userdetails.UserDetails ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException ;
@ -162,6 +171,48 @@ public class HttpSessionSecurityContextRepositoryTests {
@@ -162,6 +171,48 @@ public class HttpSessionSecurityContextRepositoryTests {
verify ( session ) . setAttribute ( HttpSessionSecurityContextRepository . SPRING_SECURITY_CONTEXT_KEY , ctx ) ;
}
@Test
public void saveContextWhenSaveNewContextThenOriginalContextThenOriginalContextSaved ( ) throws Exception {
HttpSessionSecurityContextRepository repository = new HttpSessionSecurityContextRepository ( ) ;
SecurityContextPersistenceFilter securityContextPersistenceFilter = new SecurityContextPersistenceFilter (
repository ) ;
UserDetails original = User . withUsername ( "user" ) . password ( "password" ) . roles ( "USER" ) . build ( ) ;
SecurityContext originalContext = createSecurityContext ( original ) ;
UserDetails impersonate = User . withUserDetails ( original ) . username ( "impersonate" ) . build ( ) ;
SecurityContext impersonateContext = createSecurityContext ( impersonate ) ;
MockHttpServletRequest mockRequest = new MockHttpServletRequest ( ) ;
MockHttpServletResponse mockResponse = new MockHttpServletResponse ( ) ;
Filter saveImpersonateContext = ( request , response , chain ) - > {
SecurityContextHolder . setContext ( impersonateContext ) ;
// ensure the response is committed to trigger save
response . flushBuffer ( ) ;
chain . doFilter ( request , response ) ;
} ;
Filter saveOriginalContext = ( request , response , chain ) - > {
SecurityContextHolder . setContext ( originalContext ) ;
chain . doFilter ( request , response ) ;
} ;
HttpServlet servlet = new HttpServlet ( ) {
@Override
protected void service ( HttpServletRequest req , HttpServletResponse resp )
throws ServletException , IOException {
resp . getWriter ( ) . write ( "Hi" ) ;
}
} ;
SecurityContextHolder . setContext ( originalContext ) ;
MockFilterChain chain = new MockFilterChain ( servlet , saveImpersonateContext , saveOriginalContext ) ;
securityContextPersistenceFilter . doFilter ( mockRequest , mockResponse , chain ) ;
assertThat (
mockRequest . getSession ( ) . getAttribute ( HttpSessionSecurityContextRepository . SPRING_SECURITY_CONTEXT_KEY ) )
. isEqualTo ( originalContext ) ;
}
@Test
public void nonSecurityContextInSessionIsIgnored ( ) {
HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository ( ) ;
@ -577,6 +628,13 @@ public class HttpSessionSecurityContextRepositoryTests {
@@ -577,6 +628,13 @@ public class HttpSessionSecurityContextRepositoryTests {
assertThat ( session ) . isNull ( ) ;
}
private SecurityContext createSecurityContext ( UserDetails userDetails ) {
UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken ( userDetails ,
userDetails . getPassword ( ) , userDetails . getAuthorities ( ) ) ;
SecurityContext securityContext = new SecurityContextImpl ( token ) ;
return securityContext ;
}
@Transient
private static class SomeTransientAuthentication extends AbstractAuthenticationToken {