diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java index 222a97ca..1288d200 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java @@ -118,11 +118,20 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme } } - // state (RECOMMENDED) + // state + // RECOMMENDED for Authorization Request String state = parameters.getFirst(OAuth2ParameterNames.STATE); - if (StringUtils.hasText(state) && - parameters.get(OAuth2ParameterNames.STATE).size() != 1) { - throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE); + if (authorizationRequest) { + if (StringUtils.hasText(state) && + parameters.get(OAuth2ParameterNames.STATE).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE); + } + } else { + // REQUIRED for Authorization Consent Request + if (!StringUtils.hasText(state) || + parameters.get(OAuth2ParameterNames.STATE).size() != 1) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE); + } } // code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE) diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index f8f3cf09..b34e9573 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -222,6 +222,24 @@ public class OAuth2AuthorizationEndpointFilterTests { request -> request.addParameter(OAuth2ParameterNames.STATE, "state2")); } + @Test + public void doFilterWhenAuthorizationConsentRequestMissingStateThenInvalidRequestError() throws Exception { + doFilterWhenAuthorizationConsentRequestInvalidParameterThenError( + TestRegisteredClients.registeredClient().build(), + OAuth2ParameterNames.STATE, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.removeParameter(OAuth2ParameterNames.STATE)); + } + + @Test + public void doFilterWhenAuthorizationConsentRequestMultipleStateThenInvalidRequestError() throws Exception { + doFilterWhenAuthorizationConsentRequestInvalidParameterThenError( + TestRegisteredClients.registeredClient().build(), + OAuth2ParameterNames.STATE, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(OAuth2ParameterNames.STATE, "state2")); + } + @Test public void doFilterWhenAuthorizationRequestMultipleCodeChallengeThenInvalidRequestError() throws Exception { doFilterWhenAuthorizationRequestInvalidParameterThenError( @@ -534,6 +552,13 @@ public class OAuth2AuthorizationEndpointFilterTests { parameterName, errorCode, requestConsumer); } + private void doFilterWhenAuthorizationConsentRequestInvalidParameterThenError(RegisteredClient registeredClient, + String parameterName, String errorCode, Consumer requestConsumer) throws Exception { + + doFilterWhenRequestInvalidParameterThenError(createAuthorizationConsentRequest(registeredClient), + parameterName, errorCode, requestConsumer); + } + private void doFilterWhenRequestInvalidParameterThenError(MockHttpServletRequest request, String parameterName, String errorCode, Consumer requestConsumer) throws Exception { @@ -564,6 +589,18 @@ public class OAuth2AuthorizationEndpointFilterTests { return request; } + private static MockHttpServletRequest createAuthorizationConsentRequest(RegisteredClient registeredClient) { + String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + + request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); + registeredClient.getScopes().forEach((scope) -> request.addParameter(OAuth2ParameterNames.SCOPE, scope)); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + return request; + } + private static OAuth2AuthorizationCodeRequestAuthenticationToken.Builder authorizationCodeRequestAuthentication( RegisteredClient registeredClient, Authentication principal) { return OAuth2AuthorizationCodeRequestAuthenticationToken.with(registeredClient.getClientId(), principal)