Browse Source

Fix generating ID token with null sid when refresh_token grant

Closes gh-1283
pull/1349/head
cbilodeau 3 years ago committed by Joe Grandja
parent
commit
b6f3b5cc45
  1. 4
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java
  2. 82
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

4
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java

@ -134,7 +134,9 @@ public final class JwtGenerator implements OAuth2TokenGenerator<Jwt> { @@ -134,7 +134,9 @@ public final class JwtGenerator implements OAuth2TokenGenerator<Jwt> {
}
} else if (AuthorizationGrantType.REFRESH_TOKEN.equals(context.getAuthorizationGrantType())) {
OidcIdToken currentIdToken = context.getAuthorization().getToken(OidcIdToken.class).getToken();
claimsBuilder.claim("sid", currentIdToken.getClaim("sid"));
if (currentIdToken.hasClaim("sid")) {
claimsBuilder.claim("sid", currentIdToken.getClaim("sid"));
}
claimsBuilder.claim(IdTokenClaimNames.AUTH_TIME, currentIdToken.<Date>getClaim(IdTokenClaimNames.AUTH_TIME));
}
}

82
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

@ -47,6 +47,7 @@ import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; @@ -47,6 +47,7 @@ import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JoseHeaderNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.jwt.JwtEncoderParameters;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
@ -250,7 +251,86 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { @@ -250,7 +251,86 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
assertThat(idTokenContext.getJwsHeader()).isNotNull();
assertThat(idTokenContext.getClaims()).isNotNull();
verify(this.jwtEncoder, times(2)).encode(any()); // Access token and ID Token
ArgumentCaptor<JwtEncoderParameters> jwtEncoderParametersArgumentCaptor = ArgumentCaptor.forClass(JwtEncoderParameters.class);
verify(this.jwtEncoder, times(2)).encode(jwtEncoderParametersArgumentCaptor.capture()); // Access token and ID Token
JwtEncoderParameters jwtEncoderParameters = jwtEncoderParametersArgumentCaptor.getValue();
assertThat(jwtEncoderParameters.getClaims().getClaims().get("sid")).isNotNull();
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
assertThat(updatedAuthorization.getAccessToken()).isNotEqualTo(authorization.getAccessToken());
OAuth2Authorization.Token<OidcIdToken> idToken = updatedAuthorization.getToken(OidcIdToken.class);
assertThat(idToken).isNotNull();
assertThat(accessTokenAuthentication.getAdditionalParameters())
.containsExactly(entry(OidcParameterNames.ID_TOKEN, idToken.getToken().getTokenValue()));
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
// By default, refresh token is reused
assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken());
}
@Test
public void authenticateWhenValidRefreshTokenThenReturnIdTokenWithoutSid() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
OidcIdToken authorizedIdToken = OidcIdToken.withTokenValue("id-token")
.issuer("https://provider.com")
.subject("subject")
.issuedAt(Instant.now())
.expiresAt(Instant.now().plusSeconds(60))
.claim(IdTokenClaimNames.AUTH_TIME, Date.from(Instant.now()))
.build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).token(authorizedIdToken).build();
when(this.authorizationService.findByToken(
eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(OAuth2TokenType.REFRESH_TOKEN)))
.thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
verify(this.jwtCustomizer, times(2)).customize(jwtEncodingContextCaptor.capture());
// Access Token context
JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor.getAllValues().get(0);
assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
assertThat(accessTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization);
assertThat(accessTokenContext.getAuthorizedScopes()).isEqualTo(authorization.getAuthorizedScopes());
assertThat(accessTokenContext.getTokenType()).isEqualTo(OAuth2TokenType.ACCESS_TOKEN);
assertThat(accessTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
assertThat(accessTokenContext.<OAuth2AuthorizationGrantAuthenticationToken>getAuthorizationGrant()).isEqualTo(authentication);
assertThat(accessTokenContext.getJwsHeader()).isNotNull();
assertThat(accessTokenContext.getClaims()).isNotNull();
Map<String, Object> claims = new HashMap<>();
accessTokenContext.getClaims().claims(claims::putAll);
assertThat(claims).flatExtracting(OAuth2ParameterNames.SCOPE)
.containsExactlyInAnyOrder(OidcScopes.OPENID, "scope1");
// ID Token context
JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1);
assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
assertThat(idTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
assertThat(idTokenContext.getAuthorization()).isNotEqualTo(authorization);
assertThat(idTokenContext.getAuthorization().getAccessToken()).isNotEqualTo(authorization.getAccessToken());
assertThat(idTokenContext.getAuthorizedScopes()).isEqualTo(authorization.getAuthorizedScopes());
assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN);
assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
assertThat(idTokenContext.<OAuth2AuthorizationGrantAuthenticationToken>getAuthorizationGrant()).isEqualTo(authentication);
assertThat(idTokenContext.getJwsHeader()).isNotNull();
assertThat(idTokenContext.getClaims()).isNotNull();
ArgumentCaptor<JwtEncoderParameters> jwtEncoderParametersArgumentCaptor = ArgumentCaptor.forClass(JwtEncoderParameters.class);
verify(this.jwtEncoder, times(2)).encode(jwtEncoderParametersArgumentCaptor.capture()); // Access token and ID Token
JwtEncoderParameters jwtEncoderParameters = jwtEncoderParametersArgumentCaptor.getValue();
assertThat(jwtEncoderParameters.getClaims().getClaims().get("sid")).isNull();
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());

Loading…
Cancel
Save