diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java
index 56db12a0..006d9da2 100644
--- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java
+++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java
@@ -19,6 +19,9 @@ import java.security.Principal;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Set;
import org.springframework.beans.factory.annotation.Autowired;
@@ -35,17 +38,20 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
import org.springframework.security.oauth2.core.OAuth2TokenType;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+import org.springframework.security.oauth2.core.oidc.OidcScopes;
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.jwt.JoseHeader;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
-import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
-import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
import org.springframework.util.Assert;
import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient;
@@ -55,6 +61,7 @@ import static org.springframework.security.oauth2.server.authorization.authentic
*
* @author Alexey Nesterov
* @author Joe Grandja
+ * @author Anoop Garlapati
* @since 0.0.3
* @see OAuth2RefreshTokenAuthenticationToken
* @see OAuth2AccessTokenAuthenticationToken
@@ -66,6 +73,7 @@ import static org.springframework.security.oauth2.server.authorization.authentic
* @see Section 6 Refreshing an Access Token
*/
public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider {
+ private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN);
private static final StringKeyGenerator TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
private final OAuth2AuthorizationService authorizationService;
private final JwtEncoder jwtEncoder;
@@ -174,19 +182,64 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
currentRefreshToken = generateRefreshToken(tokenSettings.refreshTokenTimeToLive());
}
+ Jwt jwtIdToken = null;
+ if (authorizedScopes.contains(OidcScopes.OPENID)) {
+ headersBuilder = JwtUtils.headers();
+ claimsBuilder = JwtUtils.idTokenClaims(
+ registeredClient, issuer, authorization.getPrincipalName(), null);
+
+ // @formatter:off
+ context = JwtEncodingContext.with(headersBuilder, claimsBuilder)
+ .registeredClient(registeredClient)
+ .principal(authorization.getAttribute(Principal.class.getName()))
+ .authorization(authorization)
+ .authorizedScopes(authorizedScopes)
+ .tokenType(ID_TOKEN_TOKEN_TYPE)
+ .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
+ .authorizationGrant(refreshTokenAuthentication)
+ .build();
+ // @formatter:on
+
+ this.jwtCustomizer.customize(context);
+
+ headers = context.getHeaders().build();
+ claims = context.getClaims().build();
+ jwtIdToken = this.jwtEncoder.encode(headers, claims);
+ }
+
+ OidcIdToken idToken;
+ if (jwtIdToken != null) {
+ idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(),
+ jwtIdToken.getExpiresAt(), jwtIdToken.getClaims());
+ } else {
+ idToken = null;
+ }
+
// @formatter:off
- authorization = OAuth2Authorization.from(authorization)
+ OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
.token(accessToken,
(metadata) ->
metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, jwtAccessToken.getClaims()))
- .refreshToken(currentRefreshToken)
- .build();
+ .refreshToken(currentRefreshToken);
+ if (idToken != null) {
+ authorizationBuilder
+ .token(idToken,
+ (metadata) ->
+ metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims()));
+ }
+ authorization = authorizationBuilder.build();
// @formatter:on
this.authorizationService.save(authorization);
+ Map additionalParameters = Collections.emptyMap();
+ if (idToken != null) {
+ additionalParameters = new HashMap<>();
+ additionalParameters.put(OidcParameterNames.ID_TOKEN, idToken.getTokenValue());
+ }
+
return new OAuth2AccessTokenAuthenticationToken(
- registeredClient, clientPrincipal, accessToken, currentRefreshToken);
+ registeredClient, clientPrincipal, accessToken, currentRefreshToken, additionalParameters);
}
@Override
diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java
index 65a1fc1e..6823b0b3 100644
--- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java
+++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java
@@ -19,7 +19,9 @@ import java.security.Principal;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
+import java.util.HashMap;
import java.util.HashSet;
+import java.util.Map;
import java.util.Set;
import org.junit.Before;
@@ -36,23 +38,28 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
import org.springframework.security.oauth2.core.OAuth2TokenType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+import org.springframework.security.oauth2.core.oidc.OidcScopes;
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JoseHeaderNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
-import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
-import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
+import static org.assertj.core.api.Assertions.entry;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -61,6 +68,7 @@ import static org.mockito.Mockito.when;
*
* @author Alexey Nesterov
* @author Joe Grandja
+ * @author Anoop Garlapati
* @since 0.0.3
*/
public class OAuth2RefreshTokenAuthenticationProviderTests {
@@ -156,6 +164,72 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken());
}
+ @Test
+ public void authenticateWhenValidRefreshTokenThenReturnIdToken() {
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
+ OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+ when(this.authorizationService.findByToken(
+ eq(authorization.getRefreshToken().getToken().getTokenValue()),
+ eq(OAuth2TokenType.REFRESH_TOKEN)))
+ .thenReturn(authorization);
+
+ OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
+ OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
+ authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
+
+ OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+ (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+ ArgumentCaptor jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
+ verify(this.jwtCustomizer, times(2)).customize(jwtEncodingContextCaptor.capture());
+ // Access Token context
+ JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor.getAllValues().get(0);
+ assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
+ assertThat(accessTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
+ assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization);
+ assertThat(accessTokenContext.getAuthorizedScopes())
+ .isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME));
+ assertThat(accessTokenContext.getTokenType()).isEqualTo(OAuth2TokenType.ACCESS_TOKEN);
+ assertThat(accessTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
+ assertThat(accessTokenContext.getAuthorizationGrant()).isEqualTo(authentication);
+ assertThat(accessTokenContext.getHeaders()).isNotNull();
+ assertThat(accessTokenContext.getClaims()).isNotNull();
+ Map claims = new HashMap<>();
+ accessTokenContext.getClaims().claims(claims::putAll);
+ assertThat(claims).flatExtracting(OAuth2ParameterNames.SCOPE)
+ .containsExactlyInAnyOrder(OidcScopes.OPENID, "scope1");
+ // ID Token context
+ JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1);
+ assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
+ assertThat(idTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
+ assertThat(idTokenContext.getAuthorization()).isEqualTo(authorization);
+ assertThat(idTokenContext.getAuthorizedScopes())
+ .isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME));
+ assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN);
+ assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
+ assertThat(idTokenContext.getAuthorizationGrant()).isEqualTo(authentication);
+ assertThat(idTokenContext.getHeaders()).isNotNull();
+ assertThat(idTokenContext.getClaims()).isNotNull();
+
+ verify(this.jwtEncoder, times(2)).encode(any(), any()); // Access token and ID Token
+
+ ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+ verify(this.authorizationService).save(authorizationCaptor.capture());
+ OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+
+ assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
+ assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
+ assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
+ assertThat(updatedAuthorization.getAccessToken()).isNotEqualTo(authorization.getAccessToken());
+ OAuth2Authorization.Token idToken = updatedAuthorization.getToken(OidcIdToken.class);
+ assertThat(idToken).isNotNull();
+ assertThat(accessTokenAuthentication.getAdditionalParameters())
+ .containsExactly(entry(OidcParameterNames.ID_TOKEN, idToken.getToken().getTokenValue()));
+ assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
+ // By default, refresh token is reused
+ assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken());
+ }
+
@Test
public void authenticateWhenReuseRefreshTokensFalseThenReturnNewRefreshToken() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()