diff --git a/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java
index a8ba24df5e..e7abefa6fd 100644
--- a/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java
+++ b/web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java
@@ -42,6 +42,8 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.session.NullAuthenticatedSessionStrategy;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
@@ -134,6 +136,8 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
private AuthenticationFailureHandler failureHandler = new SimpleUrlAuthenticationFailureHandler();
+ private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
/**
* @param defaultFilterProcessesUrl the default value for filterProcessesUrl.
*/
@@ -314,6 +318,7 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authResult);
SecurityContextHolder.setContext(context);
+ this.securityContextRepository.saveContext(context, request, response);
if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult));
}
@@ -435,6 +440,18 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
this.failureHandler = failureHandler;
}
+ /**
+ * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+ * authentication success. The default action is not to save the
+ * {@link SecurityContext}.
+ * @param securityContextRepository the {@link SecurityContextRepository} to use.
+ * Cannot be null.
+ */
+ public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+ Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+ this.securityContextRepository = securityContextRepository;
+ }
+
protected AuthenticationSuccessHandler getSuccessHandler() {
return this.successHandler;
}
diff --git a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java
index f7cdabce5c..c8b5816381 100644
--- a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java
+++ b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java
@@ -27,6 +27,7 @@ import org.apache.commons.logging.Log;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest;
@@ -34,14 +35,17 @@ import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.InternalAuthenticationServiceException;
+import org.springframework.security.authentication.TestAuthentication;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServicesTests;
import org.springframework.security.web.authentication.rememberme.TokenBasedRememberMeServices;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
+import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.firewall.DefaultHttpFirewall;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -322,6 +326,37 @@ public class AbstractAuthenticationProcessingFilterTests {
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
}
+ @Test
+ public void testSuccessfulAuthenticationThenDefaultDoesNotCreateSession() throws Exception {
+ Authentication authentication = TestAuthentication.authenticatedUser();
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ MockFilterChain chain = new MockFilterChain(false);
+ MockAuthenticationFilter filter = new MockAuthenticationFilter();
+
+ filter.successfulAuthentication(request, response, chain, authentication);
+
+ assertThat(request.getSession(false)).isNull();
+ }
+
+ @Test
+ public void testSuccessfulAuthenticationWhenCustomSecurityContextRepositoryThenAuthenticationSaved()
+ throws Exception {
+ ArgumentCaptor contextCaptor = ArgumentCaptor.forClass(SecurityContext.class);
+ SecurityContextRepository repository = mock(SecurityContextRepository.class);
+ Authentication authentication = TestAuthentication.authenticatedUser();
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ MockFilterChain chain = new MockFilterChain(false);
+ MockAuthenticationFilter filter = new MockAuthenticationFilter();
+ filter.setSecurityContextRepository(repository);
+
+ filter.successfulAuthentication(request, response, chain, authentication);
+
+ verify(repository).saveContext(contextCaptor.capture(), eq(request), eq(response));
+ assertThat(contextCaptor.getValue().getAuthentication()).isEqualTo(authentication);
+ }
+
@Test
public void testFailedAuthenticationInvokesFailureHandler() throws Exception {
// Setup our HTTP request