@ -37,7 +37,6 @@ import org.springframework.http.HttpStatus;
@@ -37,7 +37,6 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType ;
import org.springframework.mock.web.MockHttpServletRequest ;
import org.springframework.mock.web.MockHttpServletResponse ;
import org.springframework.mock.web.MockServletContext ;
import org.springframework.security.authentication.AuthenticationDetailsSource ;
import org.springframework.security.authentication.AuthenticationManager ;
import org.springframework.security.authentication.TestingAuthenticationToken ;
@ -59,8 +58,8 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
@@ -59,8 +58,8 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler ;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler ;
import org.springframework.security.web.authentication.WebAuthenticationDetails ;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders ;
import org.springframework.util.StringUtils ;
import org.springframework.web.util.UriComponentsBuilder ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatThrownBy ;
@ -173,7 +172,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -173,7 +172,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients . registeredClient ( ) . build ( ) ,
OAuth2ParameterNames . RESPONSE_TYPE ,
OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . removeParameter ( OAuth2ParameterNames . RESPONSE_TYPE ) ) ;
request - > {
request . removeParameter ( OAuth2ParameterNames . RESPONSE_TYPE ) ;
updateQueryString ( request ) ;
} ) ;
}
@Test
@ -182,7 +184,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -182,7 +184,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients . registeredClient ( ) . build ( ) ,
OAuth2ParameterNames . RESPONSE_TYPE ,
OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . addParameter ( OAuth2ParameterNames . RESPONSE_TYPE , "id_token" ) ) ;
request - > {
request . addParameter ( OAuth2ParameterNames . RESPONSE_TYPE , "id_token" ) ;
updateQueryString ( request ) ;
} ) ;
}
@Test
@ -191,7 +196,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -191,7 +196,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients . registeredClient ( ) . build ( ) ,
OAuth2ParameterNames . RESPONSE_TYPE ,
OAuth2ErrorCodes . UNSUPPORTED_RESPONSE_TYPE ,
request - > request . setParameter ( OAuth2ParameterNames . RESPONSE_TYPE , "id_token" ) ) ;
request - > {
request . setParameter ( OAuth2ParameterNames . RESPONSE_TYPE , "id_token" ) ;
updateQueryString ( request ) ;
} ) ;
}
@Test
@ -200,7 +208,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -200,7 +208,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients . registeredClient ( ) . build ( ) ,
OAuth2ParameterNames . CLIENT_ID ,
OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . removeParameter ( OAuth2ParameterNames . CLIENT_ID ) ) ;
request - > {
request . removeParameter ( OAuth2ParameterNames . CLIENT_ID ) ;
updateQueryString ( request ) ;
} ) ;
}
@Test
@ -209,7 +220,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -209,7 +220,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients . registeredClient ( ) . build ( ) ,
OAuth2ParameterNames . CLIENT_ID ,
OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . addParameter ( OAuth2ParameterNames . CLIENT_ID , "client-2" ) ) ;
request - > {
request . addParameter ( OAuth2ParameterNames . CLIENT_ID , "client-2" ) ;
updateQueryString ( request ) ;
} ) ;
}
@Test
@ -218,7 +232,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -218,7 +232,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients . registeredClient ( ) . build ( ) ,
OAuth2ParameterNames . REDIRECT_URI ,
OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . addParameter ( OAuth2ParameterNames . REDIRECT_URI , "https://example2.com" ) ) ;
request - > {
request . addParameter ( OAuth2ParameterNames . REDIRECT_URI , "https://example2.com" ) ;
updateQueryString ( request ) ;
} ) ;
}
@Test
@ -227,7 +244,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -227,7 +244,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients . registeredClient ( ) . build ( ) ,
OAuth2ParameterNames . SCOPE ,
OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . addParameter ( OAuth2ParameterNames . SCOPE , "scope2" ) ) ;
request - > {
request . addParameter ( OAuth2ParameterNames . SCOPE , "scope2" ) ;
updateQueryString ( request ) ;
} ) ;
}
@Test
@ -236,7 +256,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -236,7 +256,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients . registeredClient ( ) . build ( ) ,
OAuth2ParameterNames . STATE ,
OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . addParameter ( OAuth2ParameterNames . STATE , "state2" ) ) ;
request - > {
request . addParameter ( OAuth2ParameterNames . STATE , "state2" ) ;
updateQueryString ( request ) ;
} ) ;
}
@Test
@ -266,13 +289,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -266,13 +289,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
request - > {
request . addParameter ( PkceParameterNames . CODE_CHALLENGE , "code-challenge" ) ;
request . addParameter ( PkceParameterNames . CODE_CHALLENGE , "another-code-challenge" ) ;
String originalQueryString = request . getQueryString ( ) ;
if ( StringUtils . hasText ( originalQueryString ) ) {
String newQueryString = originalQueryString . concat ( PkceParameterNames . CODE_CHALLENGE )
. concat ( "=code-challenge" ) . concat ( "&" )
. concat ( PkceParameterNames . CODE_CHALLENGE ) . concat ( "=another-code-challenge" ) ;
request . setQueryString ( newQueryString ) ;
}
updateQueryString ( request ) ;
} ) ;
}
@ -285,13 +302,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -285,13 +302,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
request - > {
request . addParameter ( PkceParameterNames . CODE_CHALLENGE_METHOD , "S256" ) ;
request . addParameter ( PkceParameterNames . CODE_CHALLENGE_METHOD , "S256" ) ;
String originalQueryString = request . getQueryString ( ) ;
if ( StringUtils . hasText ( originalQueryString ) ) {
String newQueryString = originalQueryString . concat ( PkceParameterNames . CODE_CHALLENGE_METHOD )
. concat ( "=S256" ) . concat ( "&" )
. concat ( PkceParameterNames . CODE_CHALLENGE_METHOD ) . concat ( "=S256" ) ;
request . setQueryString ( newQueryString ) ;
}
updateQueryString ( request ) ;
} ) ;
}
@ -574,10 +585,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -574,10 +585,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
MockHttpServletRequest request = createAuthorizationRequest ( registeredClient ) ;
request . addParameter ( "custom-param" , "custom-value-1" , "custom-value-2" ) ;
String newQueryString = request . getQueryString ( ) . concat ( "custom-param" )
. concat ( "=custom-value-1" ) . concat ( "&" )
. concat ( "custom-param" ) . concat ( "=custom-value-2" ) ;
request . setQueryString ( newQueryString ) ;
updateQueryString ( request ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
@ -623,6 +631,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -623,6 +631,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
MockHttpServletRequest request = createAuthorizationRequest ( registeredClient ) ;
request . setMethod ( "POST" ) ; // OpenID Connect supports POST method
request . setQueryString ( null ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
@ -667,15 +676,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -667,15 +676,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
private static MockHttpServletRequest createAuthorizationRequest ( RegisteredClient registeredClient ) {
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI ;
MockHttpServletRequest request = MockMvcRequestBuilders . get ( requestUri )
. queryParam ( OAuth2ParameterNames . RESPONSE_TYPE , OAuth2AuthorizationResponseType . CODE . getValue ( ) )
. queryParam ( OAuth2ParameterNames . CLIENT_ID , registeredClient . getClientId ( ) )
. queryParam ( OAuth2ParameterNames . REDIRECT_URI , registeredClient . getRedirectUris ( ) . iterator ( ) . next ( ) )
. queryParam ( OAuth2ParameterNames . SCOPE ,
StringUtils . collectionToDelimitedString ( registeredClient . getScopes ( ) , " " ) )
. queryParam ( OAuth2ParameterNames . STATE , "state" )
. buildRequest ( new MockServletContext ( ) ) ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . setRemoteAddr ( REMOTE_ADDRESS ) ;
request . addParameter ( OAuth2ParameterNames . RESPONSE_TYPE , OAuth2AuthorizationResponseType . CODE . getValue ( ) ) ;
request . addParameter ( OAuth2ParameterNames . CLIENT_ID , registeredClient . getClientId ( ) ) ;
request . addParameter ( OAuth2ParameterNames . REDIRECT_URI , registeredClient . getRedirectUris ( ) . iterator ( ) . next ( ) ) ;
request . addParameter ( OAuth2ParameterNames . SCOPE ,
StringUtils . collectionToDelimitedString ( registeredClient . getScopes ( ) , " " ) ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
updateQueryString ( request ) ;
return request ;
}
@ -692,6 +704,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -692,6 +704,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
return request ;
}
private static void updateQueryString ( MockHttpServletRequest request ) {
UriComponentsBuilder uriBuilder = UriComponentsBuilder . fromUriString ( request . getRequestURI ( ) ) ;
request . getParameterMap ( ) . forEach ( ( key , values ) - > {
if ( values . length > 0 ) {
for ( String value : values ) {
uriBuilder . queryParam ( key , value ) ;
}
}
} ) ;
request . setQueryString ( uriBuilder . build ( ) . getQuery ( ) ) ;
}
private static String scopeCheckbox ( String scope ) {
return MessageFormat . format (
"<input class=\"form-check-input\" type=\"checkbox\" name=\"scope\" value=\"{0}\" id=\"{0}\">" ,