@ -222,6 +222,24 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -222,6 +222,24 @@ public class OAuth2AuthorizationEndpointFilterTests {
request - > request . addParameter ( OAuth2ParameterNames . STATE , "state2" ) ) ;
}
@Test
public void doFilterWhenAuthorizationConsentRequestMissingStateThenInvalidRequestError ( ) throws Exception {
doFilterWhenAuthorizationConsentRequestInvalidParameterThenError (
TestRegisteredClients . registeredClient ( ) . build ( ) ,
OAuth2ParameterNames . STATE ,
OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . removeParameter ( OAuth2ParameterNames . STATE ) ) ;
}
@Test
public void doFilterWhenAuthorizationConsentRequestMultipleStateThenInvalidRequestError ( ) throws Exception {
doFilterWhenAuthorizationConsentRequestInvalidParameterThenError (
TestRegisteredClients . registeredClient ( ) . build ( ) ,
OAuth2ParameterNames . STATE ,
OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . addParameter ( OAuth2ParameterNames . STATE , "state2" ) ) ;
}
@Test
public void doFilterWhenAuthorizationRequestMultipleCodeChallengeThenInvalidRequestError ( ) throws Exception {
doFilterWhenAuthorizationRequestInvalidParameterThenError (
@ -534,6 +552,13 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -534,6 +552,13 @@ public class OAuth2AuthorizationEndpointFilterTests {
parameterName , errorCode , requestConsumer ) ;
}
private void doFilterWhenAuthorizationConsentRequestInvalidParameterThenError ( RegisteredClient registeredClient ,
String parameterName , String errorCode , Consumer < MockHttpServletRequest > requestConsumer ) throws Exception {
doFilterWhenRequestInvalidParameterThenError ( createAuthorizationConsentRequest ( registeredClient ) ,
parameterName , errorCode , requestConsumer ) ;
}
private void doFilterWhenRequestInvalidParameterThenError ( MockHttpServletRequest request ,
String parameterName , String errorCode , Consumer < MockHttpServletRequest > requestConsumer ) throws Exception {
@ -564,6 +589,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -564,6 +589,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
return request ;
}
private static MockHttpServletRequest createAuthorizationConsentRequest ( RegisteredClient registeredClient ) {
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI ;
MockHttpServletRequest request = new MockHttpServletRequest ( "POST" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CLIENT_ID , registeredClient . getClientId ( ) ) ;
registeredClient . getScopes ( ) . forEach ( ( scope ) - > request . addParameter ( OAuth2ParameterNames . SCOPE , scope ) ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
return request ;
}
private static OAuth2AuthorizationCodeRequestAuthenticationToken . Builder authorizationCodeRequestAuthentication (
RegisteredClient registeredClient , Authentication principal ) {
return OAuth2AuthorizationCodeRequestAuthenticationToken . with ( registeredClient . getClientId ( ) , principal )