diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java index 83a85651..18e192c3 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java @@ -122,11 +122,55 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen public Authentication authenticate(Authentication authentication) throws AuthenticationException { OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = (OAuth2AuthorizationCodeRequestAuthenticationToken) authentication; + OAuth2Authorization pushedAuthorization = null; String requestUri = (String) authorizationCodeRequestAuthentication.getAdditionalParameters() .get("request_uri"); if (StringUtils.hasText(requestUri)) { - authorizationCodeRequestAuthentication = fromPushedAuthorizationRequest( - authorizationCodeRequestAuthentication); + OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri = null; + try { + pushedAuthorizationRequestUri = OAuth2PushedAuthorizationRequestUri.parse(requestUri); + } + catch (Exception ex) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, "request_uri", authorizationCodeRequestAuthentication, + null); + } + + pushedAuthorization = this.authorizationService.findByToken(pushedAuthorizationRequestUri.getState(), + STATE_TOKEN_TYPE); + if (pushedAuthorization == null) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, "request_uri", authorizationCodeRequestAuthentication, + null); + } + + if (this.logger.isTraceEnabled()) { + this.logger.trace("Retrieved authorization with pushed authorization request"); + } + + OAuth2AuthorizationRequest authorizationRequest = pushedAuthorization + .getAttribute(OAuth2AuthorizationRequest.class.getName()); + + if (!authorizationCodeRequestAuthentication.getClientId().equals(authorizationRequest.getClientId())) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID, + authorizationCodeRequestAuthentication, null); + } + + if (Instant.now().isAfter(pushedAuthorizationRequestUri.getExpiresAt())) { + // Remove (effectively invalidating) the pushed authorization request + this.authorizationService.remove(pushedAuthorization); + if (this.logger.isWarnEnabled()) { + this.logger + .warn(LogMessage.format("Removed expired pushed authorization request for client id '%s'", + authorizationRequest.getClientId())); + } + throwError(OAuth2ErrorCodes.INVALID_REQUEST, "request_uri", authorizationCodeRequestAuthentication, + null); + } + + authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken( + authorizationCodeRequestAuthentication.getAuthorizationUri(), authorizationRequest.getClientId(), + (Authentication) authorizationCodeRequestAuthentication.getPrincipal(), + authorizationRequest.getRedirectUri(), authorizationRequest.getState(), + authorizationRequest.getScopes(), authorizationRequest.getAdditionalParameters()); } RegisteredClient registeredClient = this.registeredClientRepository @@ -223,13 +267,21 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen this.authorizationService.save(authorization); - Set currentAuthorizedScopes = (currentAuthorizationConsent != null) - ? currentAuthorizationConsent.getScopes() : null; - if (this.logger.isTraceEnabled()) { this.logger.trace("Saved authorization"); } + if (pushedAuthorization != null) { + // Enforce one-time use by removing the pushed authorization request + this.authorizationService.remove(pushedAuthorization); + if (this.logger.isTraceEnabled()) { + this.logger.trace("Removed authorization with pushed authorization request"); + } + } + + Set currentAuthorizedScopes = (currentAuthorizationConsent != null) + ? currentAuthorizationConsent.getScopes() : null; + return new OAuth2AuthorizationConsentAuthenticationToken(authorizationRequest.getAuthorizationUri(), registeredClient.getClientId(), principal, state, currentAuthorizedScopes, null); } @@ -257,6 +309,14 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen this.logger.trace("Saved authorization"); } + if (pushedAuthorization != null) { + // Enforce one-time use by removing the pushed authorization request + this.authorizationService.remove(pushedAuthorization); + if (this.logger.isTraceEnabled()) { + this.logger.trace("Removed authorization with pushed authorization request"); + } + } + String redirectUri = authorizationRequest.getRedirectUri(); if (!StringUtils.hasText(redirectUri)) { redirectUri = registeredClient.getRedirectUris().iterator().next(); @@ -335,55 +395,6 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen this.authorizationConsentRequired = authorizationConsentRequired; } - private OAuth2AuthorizationCodeRequestAuthenticationToken fromPushedAuthorizationRequest( - OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication) { - - String requestUri = (String) authorizationCodeRequestAuthentication.getAdditionalParameters() - .get("request_uri"); - - OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri = null; - try { - pushedAuthorizationRequestUri = OAuth2PushedAuthorizationRequestUri.parse(requestUri); - } - catch (Exception ex) { - throwError(OAuth2ErrorCodes.INVALID_REQUEST, "request_uri", authorizationCodeRequestAuthentication, null); - } - - OAuth2Authorization authorization = this.authorizationService - .findByToken(pushedAuthorizationRequestUri.getState(), STATE_TOKEN_TYPE); - if (authorization == null) { - throwError(OAuth2ErrorCodes.INVALID_REQUEST, "request_uri", authorizationCodeRequestAuthentication, null); - } - - if (this.logger.isTraceEnabled()) { - this.logger.trace("Retrieved authorization with pushed authorization request"); - } - - OAuth2AuthorizationRequest authorizationRequest = authorization - .getAttribute(OAuth2AuthorizationRequest.class.getName()); - - if (!authorizationCodeRequestAuthentication.getClientId().equals(authorizationRequest.getClientId())) { - throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID, - authorizationCodeRequestAuthentication, null); - } - - if (Instant.now().isAfter(pushedAuthorizationRequestUri.getExpiresAt())) { - // Remove (effectively invalidating) the pushed authorization request - this.authorizationService.remove(authorization); - if (this.logger.isWarnEnabled()) { - this.logger.warn(LogMessage.format("Removed expired pushed authorization request for client id '%s'", - authorizationRequest.getClientId())); - } - throwError(OAuth2ErrorCodes.INVALID_REQUEST, "request_uri", authorizationCodeRequestAuthentication, null); - } - - return new OAuth2AuthorizationCodeRequestAuthenticationToken( - authorizationCodeRequestAuthentication.getAuthorizationUri(), authorizationRequest.getClientId(), - (Authentication) authorizationCodeRequestAuthentication.getPrincipal(), - authorizationRequest.getRedirectUri(), authorizationRequest.getState(), - authorizationRequest.getScopes(), authorizationRequest.getAdditionalParameters()); - } - private static boolean isAuthorizationConsentRequired( OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext) { if (!authenticationContext.getRegisteredClient().getClientSettings().isRequireAuthorizationConsent()) { diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java index 61ddeddf..7945a21c 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java @@ -633,6 +633,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests { assertAuthorizationCodeRequestWithAuthorizationCodeResult(registeredClient, authentication, authenticationResult); + verify(this.authorizationService).remove(eq(authorization)); } @Test