@ -27,6 +27,7 @@ import org.mockito.ArgumentCaptor;
@@ -27,6 +27,7 @@ import org.mockito.ArgumentCaptor;
import org.springframework.mock.web.MockHttpServletRequest ;
import org.springframework.mock.web.MockHttpServletResponse ;
import org.springframework.security.authentication.AuthenticationDetailsSource ;
import org.springframework.security.authentication.AuthenticationManager ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.core.AuthenticationException ;
@ -50,6 +51,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp
@@ -50,6 +51,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames ;
import org.springframework.security.oauth2.core.user.OAuth2User ;
import org.springframework.security.web.authentication.AuthenticationFailureHandler ;
import org.springframework.security.web.authentication.WebAuthenticationDetails ;
import org.springframework.security.web.util.UrlUtils ;
import org.springframework.web.util.UriComponentsBuilder ;
@ -79,6 +81,7 @@ public class OAuth2LoginAuthenticationFilterTests {
@@ -79,6 +81,7 @@ public class OAuth2LoginAuthenticationFilterTests {
private AuthorizationRequestRepository < OAuth2AuthorizationRequest > authorizationRequestRepository ;
private AuthenticationFailureHandler failureHandler ;
private AuthenticationManager authenticationManager ;
private AuthenticationDetailsSource authenticationDetailsSource ;
private OAuth2LoginAuthenticationToken loginAuthentication ;
private OAuth2LoginAuthenticationFilter filter ;
@ -93,11 +96,13 @@ public class OAuth2LoginAuthenticationFilterTests {
@@ -93,11 +96,13 @@ public class OAuth2LoginAuthenticationFilterTests {
this . authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository ( ) ;
this . failureHandler = mock ( AuthenticationFailureHandler . class ) ;
this . authenticationManager = mock ( AuthenticationManager . class ) ;
this . authenticationDetailsSource = mock ( AuthenticationDetailsSource . class ) ;
this . filter = spy ( new OAuth2LoginAuthenticationFilter ( this . clientRegistrationRepository ,
this . authorizedClientRepository , OAuth2LoginAuthenticationFilter . DEFAULT_FILTER_PROCESSES_URI ) ) ;
this . filter . setAuthorizationRequestRepository ( this . authorizationRequestRepository ) ;
this . filter . setAuthenticationFailureHandler ( this . failureHandler ) ;
this . filter . setAuthenticationManager ( this . authenticationManager ) ;
this . filter . setAuthenticationDetailsSource ( this . authenticationDetailsSource ) ;
}
@Test
@ -400,6 +405,29 @@ public class OAuth2LoginAuthenticationFilterTests {
@@ -400,6 +405,29 @@ public class OAuth2LoginAuthenticationFilterTests {
assertThat ( authorizationResponse . getRedirectUri ( ) ) . isEqualTo ( expectedRedirectUri ) ;
}
// gh-6866
@Test
public void attemptAuthenticationShouldSetAuthenticationDetailsOnAuthenticationResult ( ) throws Exception {
String requestUri = "/login/oauth2/code/" + this . registration1 . getRegistrationId ( ) ;
String state = "state" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , state ) ;
WebAuthenticationDetails webAuthenticationDetails = mock ( WebAuthenticationDetails . class ) ;
when ( authenticationDetailsSource . buildDetails ( any ( ) ) ) . thenReturn ( webAuthenticationDetails ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
this . setUpAuthorizationRequest ( request , response , this . registration2 , state ) ;
this . setUpAuthenticationResult ( this . registration2 ) ;
Authentication result = this . filter . attemptAuthentication ( request , response ) ;
assertThat ( result . getDetails ( ) ) . isEqualTo ( webAuthenticationDetails ) ;
}
private void setUpAuthorizationRequest ( HttpServletRequest request , HttpServletResponse response ,
ClientRegistration registration , String state ) {
Map < String , Object > attributes = new HashMap < > ( ) ;