From fdf0a2f94ccd50d8502a981bbd4697f5dd002878 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 16 Jun 2022 05:44:43 -0400 Subject: [PATCH] Access token is available when customizing ID Token Closes gh-744 --- ...thorizationCodeAuthenticationProvider.java | 7 +++- ...th2RefreshTokenAuthenticationProvider.java | 7 +++- .../TestOAuth2Authorizations.java | 42 ++++++++++++------- ...zationCodeAuthenticationProviderTests.java | 12 ++++-- ...freshTokenAuthenticationProviderTests.java | 3 +- 5 files changed, 49 insertions(+), 22 deletions(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index 9e720cc4..244ffc8b 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -179,7 +179,12 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth // ----- ID token ----- OidcIdToken idToken; if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) { - tokenContext = tokenContextBuilder.tokenType(ID_TOKEN_TOKEN_TYPE).build(); + // @formatter:off + tokenContext = tokenContextBuilder + .tokenType(ID_TOKEN_TOKEN_TYPE) + .authorization(authorizationBuilder.build()) // ID token customizer may need access to the access token and/or refresh token + .build(); + // @formatter:on OAuth2Token generatedIdToken = this.tokenGenerator.generate(tokenContext); if (!(generatedIdToken instanceof Jwt)) { OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java index f2d766ff..877d876a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java @@ -176,7 +176,12 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic // ----- ID token ----- OidcIdToken idToken; if (authorizedScopes.contains(OidcScopes.OPENID)) { - tokenContext = tokenContextBuilder.tokenType(ID_TOKEN_TOKEN_TYPE).build(); + // @formatter:off + tokenContext = tokenContextBuilder + .tokenType(ID_TOKEN_TOKEN_TYPE) + .authorization(authorizationBuilder.build()) // ID token customizer may need access to the access token and/or refresh token + .build(); + // @formatter:on OAuth2Token generatedIdToken = this.tokenGenerator.generate(tokenContext); if (!(generatedIdToken instanceof Jwt)) { OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 014df2bb..fac50b0f 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2021 the original author or authors. + * Copyright 2020-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,25 +47,30 @@ public class TestOAuth2Authorizations { return authorization(registeredClient, Collections.emptyMap()); } - public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, - OAuth2AccessToken accessToken, Map accessTokenClaims) { - return authorization(registeredClient, accessToken, accessTokenClaims, Collections.emptyMap()); - } - public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, Map authorizationRequestAdditionalParameters) { + OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( + "code", Instant.now(), Instant.now().plusSeconds(120)); OAuth2AccessToken accessToken = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300)); - return authorization(registeredClient, accessToken, Collections.emptyMap(), authorizationRequestAdditionalParameters); + return authorization(registeredClient, authorizationCode, accessToken, Collections.emptyMap(), authorizationRequestAdditionalParameters); } - private static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, - OAuth2AccessToken accessToken, Map accessTokenClaims, - Map authorizationRequestAdditionalParameters) { + public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, + OAuth2AuthorizationCode authorizationCode) { + return authorization(registeredClient, authorizationCode, null, Collections.emptyMap(), Collections.emptyMap()); + } + + public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, + OAuth2AccessToken accessToken, Map accessTokenClaims) { OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( "code", Instant.now(), Instant.now().plusSeconds(120)); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken( - "refresh-token", Instant.now(), Instant.now().plus(1, ChronoUnit.HOURS)); + return authorization(registeredClient, authorizationCode, accessToken, accessTokenClaims, Collections.emptyMap()); + } + + private static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, + OAuth2AuthorizationCode authorizationCode, OAuth2AccessToken accessToken, + Map accessTokenClaims, Map authorizationRequestAdditionalParameters) { OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .authorizationUri("https://provider.com/oauth2/authorize") .clientId(registeredClient.getClientId()) @@ -74,18 +79,25 @@ public class TestOAuth2Authorizations { .additionalParameters(authorizationRequestAdditionalParameters) .state("state") .build(); - return OAuth2Authorization.withRegisteredClient(registeredClient) + OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient) .id("id") .principalName("principal") .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .token(authorizationCode) - .token(accessToken, (metadata) -> metadata.putAll(tokenMetadata(accessTokenClaims))) - .refreshToken(refreshToken) .attribute(OAuth2ParameterNames.STATE, "state") .attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest) .attribute(Principal.class.getName(), new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B")) .attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizationRequest.getScopes()); + if (accessToken != null) { + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken( + "refresh-token", Instant.now(), Instant.now().plus(1, ChronoUnit.HOURS)); + builder + .token(accessToken, (metadata) -> metadata.putAll(tokenMetadata(accessTokenClaims))) + .refreshToken(refreshToken); + } + + return builder; } private static Map tokenMetadata(Map tokenClaims) { diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index c712a0b3..33b05d6b 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -443,7 +443,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenValidCodeAndAuthenticationRequestThenReturnIdToken() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build(); - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( + "code", Instant.now(), Instant.now().plusSeconds(120)); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, authorizationCode).build(); when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) .thenReturn(authorization); @@ -466,6 +468,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient); assertThat(accessTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName())); assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization); + assertThat(accessTokenContext.getAuthorization().getAccessToken()).isNull(); assertThat(accessTokenContext.getAuthorizedScopes()) .isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME)); assertThat(accessTokenContext.getTokenType()).isEqualTo(OAuth2TokenType.ACCESS_TOKEN); @@ -481,7 +484,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1); assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient); assertThat(idTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName())); - assertThat(idTokenContext.getAuthorization()).isEqualTo(authorization); + assertThat(idTokenContext.getAuthorization()).isNotEqualTo(authorization); + assertThat(idTokenContext.getAuthorization().getAccessToken()).isNotNull(); assertThat(idTokenContext.getAuthorizedScopes()) .isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME)); assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN); @@ -503,8 +507,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(accessTokenScopes); assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull(); assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken()); - OAuth2Authorization.Token authorizationCode = updatedAuthorization.getToken(OAuth2AuthorizationCode.class); - assertThat(authorizationCode.isInvalidated()).isTrue(); + OAuth2Authorization.Token authorizationCodeToken = updatedAuthorization.getToken(OAuth2AuthorizationCode.class); + assertThat(authorizationCodeToken.isInvalidated()).isTrue(); OAuth2Authorization.Token idToken = updatedAuthorization.getToken(OidcIdToken.class); assertThat(idToken).isNotNull(); assertThat(accessTokenAuthentication.getAdditionalParameters()) diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java index c1d1cb63..269dd054 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java @@ -233,7 +233,8 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1); assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient); assertThat(idTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName())); - assertThat(idTokenContext.getAuthorization()).isEqualTo(authorization); + assertThat(idTokenContext.getAuthorization()).isNotEqualTo(authorization); + assertThat(idTokenContext.getAuthorization().getAccessToken()).isNotEqualTo(authorization.getAccessToken()); assertThat(idTokenContext.getAuthorizedScopes()) .isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME)); assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN);