@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2018 the original author or authors .
* Copyright 2002 - 2020 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -15,17 +15,9 @@
@@ -15,17 +15,9 @@
* /
package org.springframework.security.oauth2.client.web ;
import java.util.HashMap ;
import java.util.Map ;
import javax.servlet.FilterChain ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import javax.servlet.http.HttpSession ;
import org.junit.After ;
import org.junit.Before ;
import org.junit.Test ;
import org.springframework.mock.web.MockHttpServletRequest ;
import org.springframework.mock.web.MockHttpServletResponse ;
import org.springframework.security.authentication.AnonymousAuthenticationToken ;
@ -50,13 +42,26 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
@@ -50,13 +42,26 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames ;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache ;
import org.springframework.security.web.savedrequest.RequestCache ;
import org.springframework.security.web.util.UrlUtils ;
import org.springframework.util.CollectionUtils ;
import javax.servlet.FilterChain ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import javax.servlet.http.HttpSession ;
import java.util.HashMap ;
import java.util.LinkedHashMap ;
import java.util.Map ;
import java.util.stream.Collectors ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatThrownBy ;
import static org.mockito.Mockito.any ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.spy ;
import static org.mockito.Mockito.times ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.verifyNoInteractions ;
import static org.mockito.Mockito.when ;
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes ;
import static org.springframework.security.oauth2.core.TestOAuth2RefreshTokens.refreshToken ;
@ -131,8 +136,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
@@ -131,8 +136,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
// NOTE: A valid Authorization Response contains either a 'code' or 'error' parameter.
HttpServletResponse response = mock ( HttpServletResponse . class ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . filter . doFilter ( request , response , filterChain ) ;
@ -142,94 +146,142 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
@@ -142,94 +146,142 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
@Test
public void doFilterWhenAuthorizationRequestNotFoundThenNotProcessed ( ) throws Exception {
String requestUri = "/path" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
HttpServletResponse response = mock ( HttpServletResponse . class ) ;
MockHttpServletRequest authorizationRequest = createAuthorizationRequest ( "/path" ) ;
MockHttpServletRequest authorizationResponse = createAuthorizationResponse ( authorizationRequest ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . filter . doFilter ( request , response , filterChain ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
verify ( filterChain ) . doFilter ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
}
@Test
public void doFilterWhenAuthorizationResponseUrlDoesNotMatchAuthorizationRe questRedirectUriThenNotProcessed ( ) throws Exception {
public void doFilterWhenAuthorizationRequestRedirectUriDoesNotMatch ThenNotProcessed ( ) throws Exception {
String requestUri = "/callback/client-1" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
HttpServletResponse response = mock ( HttpServletResponse . class ) ;
MockHttpServletRequest authorizationRequest = createAuthorizationRequest ( requestUri ) ;
MockHttpServletRequest authorizationResponse = createAuthorizationResponse ( authorizationRequest ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
this . setUpAuthorizationRequest ( authorizationRequest , response , this . registration1 ) ;
authorizationResponse . setRequestURI ( requestUri + "-no-match" ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( request , response , this . registration1 ) ;
request . setRequestURI ( requestUri + "-no-match" ) ;
this . filter . doFilter ( request , response , filterChain ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
verify ( filterChain ) . doFilter ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
}
// gh-7963
@Test
public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved ( ) throws Exception {
public void doFilterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed ( ) throws Exception {
// 1) redirect_uri with query parameters
String requestUri = "/callback/client-1" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
Map < String , String > parameters = new LinkedHashMap < > ( ) ;
parameters . put ( "param1" , "value1" ) ;
parameters . put ( "param2" , "value2" ) ;
MockHttpServletRequest authorizationRequest = createAuthorizationRequest ( requestUri , parameters ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
this . setUpAuthorizationRequest ( authorizationRequest , response , this . registration1 ) ;
this . setUpAuthenticationResult ( this . registration1 ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
MockHttpServletRequest authorizationResponse = createAuthorizationResponse ( authorizationRequest ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
verifyNoInteractions ( filterChain ) ;
// 2) redirect_uri with query parameters AND authorization response additional parameters
Map < String , String > additionalParameters = new LinkedHashMap < > ( ) ;
additionalParameters . put ( "auth-param1" , "value1" ) ;
additionalParameters . put ( "auth-param2" , "value2" ) ;
response = new MockHttpServletResponse ( ) ;
this . setUpAuthorizationRequest ( authorizationRequest , response , this . registration1 ) ;
authorizationResponse = createAuthorizationResponse ( authorizationRequest , additionalParameters ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
verifyNoInteractions ( filterChain ) ;
}
// gh-7963
@Test
public void doFilterWhenAuthorizationRequestRedirectUriParametersDoesNotMatchThenNotProcessed ( ) throws Exception {
String requestUri = "/callback/client-1" ;
Map < String , String > parameters = new LinkedHashMap < > ( ) ;
parameters . put ( "param1" , "value1" ) ;
parameters . put ( "param2" , "value2" ) ;
MockHttpServletRequest authorizationRequest = createAuthorizationRequest ( requestUri , parameters ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
this . setUpAuthorizationRequest ( authorizationRequest , response , this . registration1 ) ;
this . setUpAuthenticationResult ( this . registration1 ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( request , response , this . registration1 ) ;
// 1) Parameter value
Map < String , String > parametersNotMatch = new LinkedHashMap < > ( parameters ) ;
parametersNotMatch . put ( "param2" , "value8" ) ;
MockHttpServletRequest authorizationResponse = createAuthorizationResponse (
createAuthorizationRequest ( requestUri , parametersNotMatch ) ) ;
authorizationResponse . setSession ( authorizationRequest . getSession ( ) ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
verify ( filterChain , times ( 1 ) ) . doFilter ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
// 2) Parameter order
parametersNotMatch = new LinkedHashMap < > ( ) ;
parametersNotMatch . put ( "param2" , "value2" ) ;
parametersNotMatch . put ( "param1" , "value1" ) ;
authorizationResponse = createAuthorizationResponse (
createAuthorizationRequest ( requestUri , parametersNotMatch ) ) ;
authorizationResponse . setSession ( authorizationRequest . getSession ( ) ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
verify ( filterChain , times ( 2 ) ) . doFilter ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
// 3) Parameter missing
parametersNotMatch = new LinkedHashMap < > ( parameters ) ;
parametersNotMatch . remove ( "param2" ) ;
authorizationResponse = createAuthorizationResponse (
createAuthorizationRequest ( requestUri , parametersNotMatch ) ) ;
authorizationResponse . setSession ( authorizationRequest . getSession ( ) ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
verify ( filterChain , times ( 3 ) ) . doFilter ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
}
@Test
public void doFilterWhenAuthorizationRequestMatchThenAuthorizationRequestRemoved ( ) throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest ( "/callback/client-1" ) ;
MockHttpServletRequest authorizationResponse = createAuthorizationResponse ( authorizationRequest ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( authorizationRequest , response , this . registration1 ) ;
this . setUpAuthenticationResult ( this . registration1 ) ;
this . filter . doFilter ( request , response , filterChain ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
assertThat ( this . authorizationRequestRepository . loadAuthorizationRequest ( request ) ) . isNull ( ) ;
assertThat ( this . authorizationRequestRepository . loadAuthorizationRequest ( authorizationResponse ) ) . isNull ( ) ;
}
@Test
public void doFilterWhenAuthorizationFailsThenHandleOAuth2AuthorizationException ( ) throws Exception {
String requestUri = "/callback/client-1" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
MockHttpServletRequest authorizationRequest = createAuthorizationRequest ( "/callback/client-1" ) ;
MockHttpServletRequest authorizationResponse = createAuthorizationResponse ( authorizationRequest ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( authorizationRequest , response , this . registration1 ) ;
this . setUpAuthorizationRequest ( request , response , this . registration1 ) ;
OAuth2Error error = new OAuth2Error ( OAuth2ErrorCodes . INVALID_GRANT ) ;
when ( this . authenticationManager . authenticate ( any ( Authentication . class ) ) )
. thenThrow ( new OAuth2AuthorizationException ( error ) ) ;
this . filter . doFilter ( request , response , filterChain ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
assertThat ( response . getRedirectedUrl ( ) ) . isEqualTo ( "http://localhost/callback/client-1?error=invalid_grant" ) ;
}
@Test
public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSavedToService ( ) throws Exception {
String requestUri = "/callback/client-1" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
public void doFilterWhenAuthorizationSucceedsThenAuthorizedClientSavedToService ( ) throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest ( "/callback/client-1" ) ;
MockHttpServletRequest authorizationResponse = createAuthorizationResponse ( authorizationRequest ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( request , response , this . registration1 ) ;
this . setUpAuthorizationRequest ( authorizationRequest , response , this . registration1 ) ;
this . setUpAuthenticationResult ( this . registration1 ) ;
this . filter . doFilter ( request , response , filterChain ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientService . loadAuthorizedClient (
this . registration1 . getRegistrationId ( ) , this . principalName1 ) ;
@ -241,40 +293,31 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
@@ -241,40 +293,31 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
}
@Test
public void doFilterWhenAuthorizationResponseSuccessThenRedirected ( ) throws Exception {
String requestUri = "/callback/client-1" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
public void doFilterWhenAuthorizationSucceedsThenRedirected ( ) throws Exception {
MockHttpServletRequest authorizationRequest = createAuthorizationRequest ( "/callback/client-1" ) ;
MockHttpServletRequest authorizationResponse = createAuthorizationResponse ( authorizationRequest ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( request , response , this . registration1 ) ;
this . setUpAuthorizationRequest ( authorizationRequest , response , this . registration1 ) ;
this . setUpAuthenticationResult ( this . registration1 ) ;
this . filter . doFilter ( request , response , filterChain ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
assertThat ( response . getRedirectedUrl ( ) ) . isEqualTo ( "http://localhost/callback/client-1" ) ;
}
@Test
public void doFilterWhenAuthorizationResponseSuccessHasSavedRequestThenRedirected ToSavedRequest ( ) throws Exception {
public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirect ToSavedRequest ( ) throws Exception {
String requestUri = "/saved-request" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
RequestCache requestCache = new HttpSessionRequestCache ( ) ;
requestCache . saveRequest ( request , response ) ;
requestUri = "/callback/client-1" ;
request . setRequestURI ( requestUri ) ;
request . setRequestURI ( "/callback/client-1" ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( request , response , this . registration1 ) ;
this . setUpAuthenticationResult ( this . registration1 ) ;
@ -284,36 +327,30 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
@@ -284,36 +327,30 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
}
@Test
public void doFilterWhenAuthorizationResponseSucces sAndAnonymousAccessThenAuthorizedClientSavedToHttpSession ( ) throws Exception {
public void doFilterWhenAuthorizationSucceed sAndAnonymousAccessThenAuthorizedClientSavedToHttpSession ( ) throws Exception {
AnonymousAuthenticationToken anonymousPrincipal =
new AnonymousAuthenticationToken ( "key-1234" , "anonymousUser" , AuthorityUtils . createAuthorityList ( "ROLE_ANONYMOUS" ) ) ;
SecurityContext securityContext = SecurityContextHolder . createEmptyContext ( ) ;
securityContext . setAuthentication ( anonymousPrincipal ) ;
SecurityContextHolder . setContext ( securityContext ) ;
String requestUri = "/callback/client-1" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
MockHttpServletRequest authorizationRequest = createAuthorizationRequest ( "/callback/client-1" ) ;
MockHttpServletRequest authorizationResponse = createAuthorizationResponse ( authorizationRequest ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( request , response , this . registration1 ) ;
this . setUpAuthorizationRequest ( authorizationRequest , response , this . registration1 ) ;
this . setUpAuthenticationResult ( this . registration1 ) ;
this . filter . doFilter ( request , response , filterChain ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientRepository . loadAuthorizedClient (
this . registration1 . getRegistrationId ( ) , anonymousPrincipal , request ) ;
this . registration1 . getRegistrationId ( ) , anonymousPrincipal , authorizationResponse ) ;
assertThat ( authorizedClient ) . isNotNull ( ) ;
assertThat ( authorizedClient . getClientRegistration ( ) ) . isEqualTo ( this . registration1 ) ;
assertThat ( authorizedClient . getPrincipalName ( ) ) . isEqualTo ( anonymousPrincipal . getName ( ) ) ;
assertThat ( authorizedClient . getAccessToken ( ) ) . isNotNull ( ) ;
HttpSession session = request . getSession ( false ) ;
HttpSession session = authorizationResponse . getSession ( false ) ;
assertThat ( session ) . isNotNull ( ) ;
@SuppressWarnings ( "unchecked" )
@ -325,33 +362,27 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
@@ -325,33 +362,27 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
}
@Test
public void doFilterWhenAuthorizationResponseSucces sAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession ( ) throws Exception {
public void doFilterWhenAuthorizationSucceed sAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession ( ) throws Exception {
SecurityContext securityContext = SecurityContextHolder . createEmptyContext ( ) ;
SecurityContextHolder . setContext ( securityContext ) ; // null Authentication
String requestUri = "/callback/client-1" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
MockHttpServletRequest authorizationRequest = createAuthorizationRequest ( "/callback/client-1" ) ;
MockHttpServletRequest authorizationResponse = createAuthorizationResponse ( authorizationRequest ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( request , response , this . registration1 ) ;
this . setUpAuthorizationRequest ( authorizationRequest , response , this . registration1 ) ;
this . setUpAuthenticationResult ( this . registration1 ) ;
this . filter . doFilter ( request , response , filterChain ) ;
this . filter . doFilter ( authorizationResponse , response , filterChain ) ;
OAuth2AuthorizedClient authorizedClient = this . authorizedClientRepository . loadAuthorizedClient (
this . registration1 . getRegistrationId ( ) , null , request ) ;
this . registration1 . getRegistrationId ( ) , null , authorizationResponse ) ;
assertThat ( authorizedClient ) . isNotNull ( ) ;
assertThat ( authorizedClient . getClientRegistration ( ) ) . isEqualTo ( this . registration1 ) ;
assertThat ( authorizedClient . getPrincipalName ( ) ) . isEqualTo ( "anonymousUser" ) ;
assertThat ( authorizedClient . getAccessToken ( ) ) . isNotNull ( ) ;
HttpSession session = request . getSession ( false ) ;
HttpSession session = authorizationResponse . getSession ( false ) ;
assertThat ( session ) . isNotNull ( ) ;
@SuppressWarnings ( "unchecked" )
@ -362,13 +393,51 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
@@ -362,13 +393,51 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
assertThat ( authorizedClients . values ( ) . iterator ( ) . next ( ) ) . isSameAs ( authorizedClient ) ;
}
private static MockHttpServletRequest createAuthorizationRequest ( String requestUri ) {
return createAuthorizationRequest ( requestUri , new LinkedHashMap < > ( ) ) ;
}
private static MockHttpServletRequest createAuthorizationRequest ( String requestUri , Map < String , String > parameters ) {
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
if ( ! CollectionUtils . isEmpty ( parameters ) ) {
parameters . forEach ( request : : addParameter ) ;
request . setQueryString (
parameters . entrySet ( ) . stream ( )
. map ( e - > e . getKey ( ) + "=" + e . getValue ( ) )
. collect ( Collectors . joining ( "&" ) ) ) ;
}
return request ;
}
private static MockHttpServletRequest createAuthorizationResponse ( MockHttpServletRequest authorizationRequest ) {
return createAuthorizationResponse ( authorizationRequest , new LinkedHashMap < > ( ) ) ;
}
private static MockHttpServletRequest createAuthorizationResponse (
MockHttpServletRequest authorizationRequest , Map < String , String > additionalParameters ) {
MockHttpServletRequest authorizationResponse = new MockHttpServletRequest (
authorizationRequest . getMethod ( ) , authorizationRequest . getRequestURI ( ) ) ;
authorizationResponse . setServletPath ( authorizationRequest . getRequestURI ( ) ) ;
authorizationRequest . getParameterMap ( ) . forEach ( authorizationResponse : : addParameter ) ;
authorizationResponse . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
authorizationResponse . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
additionalParameters . forEach ( authorizationResponse : : addParameter ) ;
authorizationResponse . setQueryString (
authorizationResponse . getParameterMap ( ) . entrySet ( ) . stream ( )
. map ( e - > e . getKey ( ) + "=" + e . getValue ( ) [ 0 ] )
. collect ( Collectors . joining ( "&" ) ) ) ;
authorizationResponse . setSession ( authorizationRequest . getSession ( ) ) ;
return authorizationResponse ;
}
private void setUpAuthorizationRequest ( HttpServletRequest request , HttpServletResponse response ,
ClientRegistration registration ) {
Map < String , Object > additionalParameters = new HashMap < > ( ) ;
additionalParameters . put ( OAuth2ParameterNames . REGISTRATION_ID , registration . getRegistrationId ( ) ) ;
Map < String , Object > attribute s = new HashMap < > ( ) ;
attribute s . put ( OAuth2ParameterNames . REGISTRATION_ID , registration . getRegistrationId ( ) ) ;
OAuth2AuthorizationRequest authorizationRequest = request ( )
. additionalParameters ( additionalParameters )
. redirectUri ( request . getRequestURL ( ) . toString ( ) ) . build ( ) ;
. attributes ( attribute s )
. redirectUri ( UrlUtils . buildFullRequestUrl ( request ) ) . build ( ) ;
this . authorizationRequestRepository . saveAuthorizationRequest ( authorizationRequest , request , response ) ;
}