diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java index 62b1d16e..83a85651 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.server.authorization.authentication; import java.security.Principal; +import java.time.Instant; import java.util.Arrays; import java.util.Base64; import java.util.Collections; @@ -27,6 +28,7 @@ import java.util.function.Predicate; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.log.LogMessage; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; @@ -353,6 +355,10 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen throwError(OAuth2ErrorCodes.INVALID_REQUEST, "request_uri", authorizationCodeRequestAuthentication, null); } + if (this.logger.isTraceEnabled()) { + this.logger.trace("Retrieved authorization with pushed authorization request"); + } + OAuth2AuthorizationRequest authorizationRequest = authorization .getAttribute(OAuth2AuthorizationRequest.class.getName()); @@ -361,6 +367,16 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen authorizationCodeRequestAuthentication, null); } + if (Instant.now().isAfter(pushedAuthorizationRequestUri.getExpiresAt())) { + // Remove (effectively invalidating) the pushed authorization request + this.authorizationService.remove(authorization); + if (this.logger.isWarnEnabled()) { + this.logger.warn(LogMessage.format("Removed expired pushed authorization request for client id '%s'", + authorizationRequest.getClientId())); + } + throwError(OAuth2ErrorCodes.INVALID_REQUEST, "request_uri", authorizationCodeRequestAuthentication, null); + } + return new OAuth2AuthorizationCodeRequestAuthenticationToken( authorizationCodeRequestAuthentication.getAuthorizationUri(), authorizationRequest.getClientId(), (Authentication) authorizationCodeRequestAuthentication.getPrincipal(), diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2PushedAuthorizationRequestUri.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2PushedAuthorizationRequestUri.java index 538cc8c5..084e8791 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2PushedAuthorizationRequestUri.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2PushedAuthorizationRequestUri.java @@ -41,8 +41,11 @@ final class OAuth2PushedAuthorizationRequestUri { private Instant expiresAt; static OAuth2PushedAuthorizationRequestUri create() { + return create(Instant.now().plusSeconds(30)); + } + + static OAuth2PushedAuthorizationRequestUri create(Instant expiresAt) { String state = DEFAULT_STATE_GENERATOR.generateKey(); - Instant expiresAt = Instant.now().plusSeconds(30); OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri = new OAuth2PushedAuthorizationRequestUri(); pushedAuthorizationRequestUri.requestUri = REQUEST_URI_PREFIX + state + REQUEST_URI_DELIMITER + expiresAt.toEpochMilli(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java index def23a36..425d6e1f 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.server.authorization.authentication; import java.security.Principal; +import java.time.Instant; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -690,6 +691,31 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests { OAuth2ErrorCodes.INVALID_REQUEST, "client_id", null)); } + @Test + public void authenticateWhenAuthorizationCodeRequestWithExpiredRequestUriThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + + OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri = OAuth2PushedAuthorizationRequestUri + .create(Instant.now().minusSeconds(5)); + Map additionalParameters = new HashMap<>(); + additionalParameters.put("request_uri", pushedAuthorizationRequestUri.getRequestUri()); + OAuth2Authorization authorization = TestOAuth2Authorizations + .authorization(registeredClient, additionalParameters) + .build(); + given(this.authorizationService.findByToken(eq(pushedAuthorizationRequestUri.getState()), eq(STATE_TOKEN_TYPE))) + .willReturn(authorization); + + OAuth2AuthorizationCodeRequestAuthenticationToken authentication = new OAuth2AuthorizationCodeRequestAuthenticationToken( + AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, null, null, null, + additionalParameters); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthorizationCodeRequestAuthenticationException.class) + .satisfies((ex) -> assertAuthenticationException((OAuth2AuthorizationCodeRequestAuthenticationException) ex, + OAuth2ErrorCodes.INVALID_REQUEST, "request_uri", null)); + verify(this.authorizationService).remove(eq(authorization)); + } + @Test public void authenticateWhenAuthorizationCodeNotGeneratedThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();