diff --git a/config/src/main/kotlin/org/springframework/security/config/web/servlet/OAuth2LoginDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/servlet/OAuth2LoginDsl.kt index b61fed4026..09668dcaa2 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/servlet/OAuth2LoginDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/servlet/OAuth2LoginDsl.kt @@ -16,6 +16,7 @@ package org.springframework.security.config.web.servlet +import org.springframework.security.authentication.AuthenticationDetailsSource import org.springframework.security.config.annotation.web.HttpSecurityBuilder import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.web.servlet.oauth2.login.AuthorizationEndpointDsl @@ -28,6 +29,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository import org.springframework.security.web.authentication.AuthenticationFailureHandler import org.springframework.security.web.authentication.AuthenticationSuccessHandler +import javax.servlet.http.HttpServletRequest /** * A Kotlin DSL to configure [HttpSecurity] OAuth 2.0 login using idiomatic Kotlin code. @@ -59,6 +61,7 @@ class OAuth2LoginDsl { var failureUrl: String? = null var loginProcessingUrl: String? = null var permitAll: Boolean? = null + var authenticationDetailsSource: AuthenticationDetailsSource? = null private var defaultSuccessUrlOption: Pair? = null private var authorizationEndpoint: ((OAuth2LoginConfigurer.AuthorizationEndpointConfig) -> Unit)? = null @@ -221,6 +224,7 @@ class OAuth2LoginDsl { tokenEndpoint?.also { oauth2Login.tokenEndpoint(tokenEndpoint) } redirectionEndpoint?.also { oauth2Login.redirectionEndpoint(redirectionEndpoint) } userInfoEndpoint?.also { oauth2Login.userInfoEndpoint(userInfoEndpoint) } + authenticationDetailsSource?.also { oauth2Login.authenticationDetailsSource(authenticationDetailsSource) } } } } diff --git a/config/src/test/kotlin/org/springframework/security/config/web/servlet/OAuth2LoginDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/servlet/OAuth2LoginDslTests.kt index 9521b4ccd8..f0b146a3c1 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/servlet/OAuth2LoginDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/servlet/OAuth2LoginDslTests.kt @@ -16,11 +16,15 @@ package org.springframework.security.config.web.servlet +import io.mockk.every +import io.mockk.mockkObject +import io.mockk.verify import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith import org.springframework.beans.factory.annotation.Autowired import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Configuration +import org.springframework.security.authentication.AuthenticationDetailsSource import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter @@ -29,10 +33,16 @@ import org.springframework.security.config.test.SpringTestContext import org.springframework.security.config.test.SpringTestContextExtension import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository +import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames +import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.get +import org.springframework.test.web.servlet.post import org.springframework.web.bind.annotation.GetMapping import org.springframework.web.bind.annotation.RestController +import javax.servlet.http.HttpServletRequest /** * Tests for [OAuth2LoginDsl] @@ -113,6 +123,58 @@ class OAuth2LoginDslTests { } } + @Test + fun `oauth2Login when custom authentication details source then used`() { + this.spring + .register(CustomAuthenticationDetailsSourceConfig::class.java, ClientConfig::class.java) + .autowire() + mockkObject(CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE) + every { + CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE.buildDetails(any()) + } returns Any() + mockkObject(CustomAuthenticationDetailsSourceConfig.AUTHORIZATION_REQUEST_REPOSITORY) + every { + CustomAuthenticationDetailsSourceConfig.AUTHORIZATION_REQUEST_REPOSITORY.removeAuthorizationRequest(any(), any()) + } returns OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri("/") + .clientId("clientId") + .redirectUri("/") + .attributes { attributes -> attributes[OAuth2ParameterNames.REGISTRATION_ID] = "google" } + .build() + + this.mockMvc.post("/login/oauth2/code/google") { + param(OAuth2ParameterNames.CODE, "code") + param(OAuth2ParameterNames.STATE, "state") + with(csrf()) + } + .andExpect { + status { is3xxRedirection() } + } + + verify(exactly = 1) { CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE.buildDetails(any()) } + } + + @EnableWebSecurity + open class CustomAuthenticationDetailsSourceConfig : WebSecurityConfigurerAdapter() { + + companion object { + val AUTHENTICATION_DETAILS_SOURCE: AuthenticationDetailsSource = + AuthenticationDetailsSource { Any() } + val AUTHORIZATION_REQUEST_REPOSITORY = HttpSessionOAuth2AuthorizationRequestRepository() + } + + override fun configure(http: HttpSecurity) { + http { + oauth2Login { + authenticationDetailsSource = AUTHENTICATION_DETAILS_SOURCE + authorizationEndpoint { + authorizationRequestRepository = AUTHORIZATION_REQUEST_REPOSITORY + } + } + } + } + } + @Configuration open class ClientConfig { @Bean