diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index bed80d38d7..5b17404efe 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -181,6 +181,8 @@ public final class OAuth2LoginConfigurer> private OAuth2AuthorizedClientRepository authorizedClientRepository; + private SecurityContextRepository securityContextRepository; + /** * Sets the repository of client registrations. * @param clientRegistrationRepository the repository of client registrations @@ -234,6 +236,17 @@ public final class OAuth2LoginConfigurer> return this; } + /** + * Sets the {@link SecurityContextRepository} to use. + * @param securityContextRepository the {@link SecurityContextRepository} to use + * @return the {@link OAuth2LoginConfigurer} for further configuration + */ + @Override + public OAuth2LoginConfigurer securityContextRepository(SecurityContextRepository securityContextRepository) { + this.securityContextRepository = securityContextRepository; + return this; + } + /** * Sets the registry for managing the OIDC client-provider session link * @param oidcSessionRegistry the {@link OidcSessionRegistry} to use @@ -354,6 +367,9 @@ public final class OAuth2LoginConfigurer> RequestMatcher processUri = RequestMatcherFactory.matcher(this.loginProcessingUrl); authenticationFilter.setRequiresAuthenticationRequestMatcher(processUri); authenticationFilter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy()); + if (this.securityContextRepository != null) { + authenticationFilter.setSecurityContextRepository(this.securityContextRepository); + } this.setAuthenticationFilter(authenticationFilter); super.loginProcessingUrl(this.loginProcessingUrl); if (this.loginPage != null) { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index d277fc52a3..fef7913122 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -105,6 +105,7 @@ import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.HttpStatusEntryPoint; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; +import org.springframework.security.web.context.NullSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.session.HttpSessionDestroyedEvent; import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher; @@ -114,6 +115,7 @@ import org.springframework.web.context.support.AnnotationConfigWebApplicationCon import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; @@ -729,6 +731,12 @@ public class OAuth2LoginConfigurerTests { verify(this.context.getBean(SpyObjectPostProcessor.class).spy).authenticate(any()); } + // gh-16623 + @Test + public void oauth2LoginWithCustomSecurityContextRepository() { + assertThatNoException().isThrownBy(() -> loadConfig(OAuth2LoginConfigSecurityContextRepository.class)); + } + private void loadConfig(Class... configs) { AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext(); applicationContext.register(configs); @@ -977,6 +985,24 @@ public class OAuth2LoginConfigurerTests { } + @Configuration + @EnableWebSecurity + static class OAuth2LoginConfigSecurityContextRepository extends CommonSecurityFilterChainConfig { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2Login((login) -> login + .clientRegistrationRepository( + new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)) + .securityContextRepository(new NullSecurityContextRepository())); + // @formatter:on + return super.configureFilterChain(http); + } + + } + @Configuration @EnableWebSecurity static class OAuth2LoginConfigCustomAuthorizationRequestResolver extends CommonSecurityFilterChainConfig {