Evgeniy Cheban 1 month ago committed by GitHub
parent
commit
71dee7a760
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 20
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java
  2. 302
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizationSuccessHandler.java
  3. 31
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java
  4. 5
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java
  5. 36
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java
  6. 412
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizationSuccessHandlerTests.java
  7. 10
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java
  8. 19
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/MockExchangeFunction.java
  9. 28
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

20
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizedClientProviderBuilder.java

@ -254,6 +254,8 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { @@ -254,6 +254,8 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder {
private ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;
private ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler;
private Duration clockSkew;
private Clock clock;
@ -274,6 +276,21 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { @@ -274,6 +276,21 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder {
return this;
}
/**
* Sets a {@link ReactiveOAuth2AuthorizationSuccessHandler} that is called after
* the client is re-authorized, defaults to
* {@link RefreshTokenReactiveOAuth2AuthorizationSuccessHandler}.
* @param refreshTokenSuccessHandler the
* {@link ReactiveOAuth2AuthorizationSuccessHandler} to use
* @return the {@link RefreshTokenGrantBuilder}
* @since 7.0
*/
public RefreshTokenGrantBuilder refreshTokenSuccessHandler(
ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler) {
this.refreshTokenSuccessHandler = refreshTokenSuccessHandler;
return this;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the access
* token expiry. An access token is considered expired if
@ -310,6 +327,9 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder { @@ -310,6 +327,9 @@ public final class ReactiveOAuth2AuthorizedClientProviderBuilder {
if (this.accessTokenResponseClient != null) {
authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
}
if (this.refreshTokenSuccessHandler != null) {
authorizedClientProvider.setRefreshTokenSuccessHandler(this.refreshTokenSuccessHandler);
}
if (this.clockSkew != null) {
authorizedClientProvider.setClockSkew(this.clockSkew);
}

302
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizationSuccessHandler.java

@ -0,0 +1,302 @@ @@ -0,0 +1,302 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import java.time.Duration;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import reactor.core.publisher.Mono;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.oidc.authentication.ReactiveOidcIdTokenDecoderFactory;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
/**
* A {@link ReactiveOAuth2AuthorizationSuccessHandler} that refreshes an {@link OidcUser}
* in the {@link SecurityContext} if the refreshed {@link OidcIdToken} is valid according
* to <a href=
* "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse">OpenID
* Connect Core 1.0 - Section 12.2 Successful Refresh Response</a>
*
* @author Evgeniy Cheban
* @since 7.0
*/
public final class RefreshTokenReactiveOAuth2AuthorizationSuccessHandler
implements ReactiveOAuth2AuthorizationSuccessHandler {
private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";
private static final String INVALID_NONCE_ERROR_CODE = "invalid_nonce";
private static final String REFRESH_TOKEN_RESPONSE_ERROR_URI = "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse";
// @formatter:off
private static final Mono<ServerWebExchange> currentServerWebExchangeMono = Mono.deferContextual(Mono::just)
.filter((c) -> c.hasKey(ServerWebExchange.class))
.map((c) -> c.get(ServerWebExchange.class));
// @formatter:on
private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
private ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new ReactiveOidcIdTokenDecoderFactory();
private ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = new OidcReactiveOAuth2UserService();
private GrantedAuthoritiesMapper authoritiesMapper = (authorities) -> authorities;
private Duration clockSkew = Duration.ofSeconds(60);
@Override
public Mono<Void> onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal,
Map<String, Object> attributes) {
if (!(principal instanceof OAuth2AuthenticationToken authenticationToken)
|| authenticationToken.getClass() != OAuth2AuthenticationToken.class) {
// If the application customizes the authentication result, then a custom
// handler should be provided.
return Mono.empty();
}
// The current principal must be an OidcUser.
if (!(authenticationToken.getPrincipal() instanceof OidcUser existingOidcUser)) {
return Mono.empty();
}
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
// The registrationId must match the one used to log in.
if (!authenticationToken.getAuthorizedClientRegistrationId().equals(clientRegistration.getRegistrationId())) {
return Mono.empty();
}
// Create, validate OidcIdToken and refresh OidcUser in the SecurityContext.
return Mono.zip(serverWebExchange(attributes), accessTokenResponse(attributes)).flatMap((t2) -> {
ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
Map<String, Object> additionalParameters = t2.getT2().getAdditionalParameters();
return jwtDecoder.decode((String) additionalParameters.get(OidcParameterNames.ID_TOKEN))
.onErrorMap(JwtException.class, (ex) -> {
OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(),
null);
return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);
})
.map((jwt) -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(),
jwt.getClaims()))
.doOnNext((idToken) -> validateIdToken(existingOidcUser, idToken))
.flatMap((idToken) -> {
OidcUserRequest userRequest = new OidcUserRequest(clientRegistration,
authorizedClient.getAccessToken(), idToken);
return this.userService.loadUser(userRequest);
})
.flatMap((oidcUser) -> refreshSecurityContext(t2.getT1(), clientRegistration, authenticationToken,
oidcUser));
});
}
private Mono<ServerWebExchange> serverWebExchange(Map<String, Object> attributes) {
if (attributes.get(ServerWebExchange.class.getName()) instanceof ServerWebExchange exchange) {
return Mono.just(exchange);
}
return currentServerWebExchangeMono;
}
private Mono<OAuth2AccessTokenResponse> accessTokenResponse(Map<String, Object> attributes) {
if (attributes.get(OAuth2AccessTokenResponse.class.getName()) instanceof OAuth2AccessTokenResponse response) {
return Mono.just(response);
}
return Mono.empty();
}
private void validateIdToken(OidcUser existingOidcUser, OidcIdToken idToken) {
// OpenID Connect Core 1.0 - Section 12.2 Successful Refresh Response
// If an ID Token is returned as a result of a token refresh request, the
// following requirements apply:
// its iss Claim Value MUST be the same as in the ID Token issued when the
// original authentication occurred,
validateIssuer(existingOidcUser, idToken);
// its sub Claim Value MUST be the same as in the ID Token issued when the
// original authentication occurred,
validateSubject(existingOidcUser, idToken);
// its iat Claim MUST represent the time that the new ID Token is issued,
validateIssuedAt(existingOidcUser, idToken);
// its aud Claim Value MUST be the same as in the ID Token issued when the
// original authentication occurred,
validateAudience(existingOidcUser, idToken);
// if the ID Token contains an auth_time Claim, its value MUST represent the time
// of the original authentication - not the time that the new ID token is issued,
validateAuthenticatedAt(existingOidcUser, idToken);
// it SHOULD NOT have a nonce Claim, even when the ID Token issued at the time of
// the original authentication contained nonce; however, if it is present, its
// value MUST be the same as in the ID Token issued at the time of the original
// authentication,
validateNonce(existingOidcUser, idToken);
}
private void validateIssuer(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!idToken.getIssuer().toString().equals(existingOidcUser.getIdToken().getIssuer().toString())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issuer",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private void validateSubject(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!idToken.getSubject().equals(existingOidcUser.getIdToken().getSubject())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid subject",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private void validateIssuedAt(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!idToken.getIssuedAt().isAfter(existingOidcUser.getIdToken().getIssuedAt().minus(this.clockSkew))) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issued at time",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private void validateAudience(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!isValidAudience(existingOidcUser, idToken)) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid audience",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private boolean isValidAudience(OidcUser existingOidcUser, OidcIdToken idToken) {
List<String> idTokenAudiences = idToken.getAudience();
Set<String> oidcUserAudiences = new HashSet<>(existingOidcUser.getIdToken().getAudience());
if (idTokenAudiences.size() != oidcUserAudiences.size()) {
return false;
}
for (String audience : idTokenAudiences) {
if (!oidcUserAudiences.contains(audience)) {
return false;
}
}
return true;
}
private void validateAuthenticatedAt(OidcUser existingOidcUser, OidcIdToken idToken) {
if (idToken.getAuthenticatedAt() == null) {
return;
}
if (!idToken.getAuthenticatedAt().equals(existingOidcUser.getIdToken().getAuthenticatedAt())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid authenticated at time",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private void validateNonce(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!StringUtils.hasText(idToken.getNonce())) {
return;
}
if (!idToken.getNonce().equals(existingOidcUser.getIdToken().getNonce())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE, "Invalid nonce",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}
private Mono<Void> refreshSecurityContext(ServerWebExchange exchange, ClientRegistration clientRegistration,
OAuth2AuthenticationToken authenticationToken, OidcUser oidcUser) {
Collection<? extends GrantedAuthority> mappedAuthorities = this.authoritiesMapper
.mapAuthorities(oidcUser.getAuthorities());
OAuth2AuthenticationToken authenticationResult = new OAuth2AuthenticationToken(oidcUser, mappedAuthorities,
clientRegistration.getRegistrationId());
authenticationResult.setDetails(authenticationToken.getDetails());
SecurityContextImpl securityContext = new SecurityContextImpl(authenticationResult);
return this.serverSecurityContextRepository.save(exchange, securityContext);
}
/**
* Sets a {@link ServerSecurityContextRepository} to use for refreshing a
* {@link SecurityContext}, defaults to
* {@link WebSessionServerSecurityContextRepository}.
* @param serverSecurityContextRepository the {@link ServerSecurityContextRepository}
* to use
*/
public void setServerSecurityContextRepository(ServerSecurityContextRepository serverSecurityContextRepository) {
Assert.notNull(serverSecurityContextRepository, "serverSecurityContextRepository cannot be null");
this.serverSecurityContextRepository = serverSecurityContextRepository;
}
/**
* Sets a {@link ReactiveJwtDecoderFactory} to use for decoding refreshed oidc
* id-token, defaults to {@link ReactiveOidcIdTokenDecoderFactory}.
* @param jwtDecoderFactory the {@link ReactiveJwtDecoderFactory} to use
*/
public void setJwtDecoderFactory(ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
this.jwtDecoderFactory = jwtDecoderFactory;
}
/**
* Sets a {@link GrantedAuthoritiesMapper} to use for mapping
* {@link GrantedAuthority}s, defaults to no-op implementation.
* @param authoritiesMapper the {@link GrantedAuthoritiesMapper} to use
*/
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null");
this.authoritiesMapper = authoritiesMapper;
}
/**
* Sets a {@link ReactiveOAuth2UserService} to use for loading an {@link OidcUser}
* from refreshed oidc id-token, defaults to {@link OidcReactiveOAuth2UserService}.
* @param userService the {@link ReactiveOAuth2UserService} to use
*/
public void setUserService(ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService) {
Assert.notNull(userService, "userService cannot be null");
this.userService = userService;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OidcIdToken#getIssuedAt()} to match the existing
* {@link OidcUser#getIdToken()}'s issuedAt time, defaults to 60 seconds.
* @param clockSkew the maximum acceptable clock skew to use
*/
public void setClockSkew(Duration clockSkew) {
Assert.notNull(clockSkew, "clockSkew cannot be null");
Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0");
this.clockSkew = clockSkew;
}
}

31
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java

@ -21,7 +21,9 @@ import java.time.Duration; @@ -21,7 +21,9 @@ import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import reactor.core.publisher.Mono;
@ -33,6 +35,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio @@ -33,6 +35,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.Assert;
/**
@ -40,6 +43,7 @@ import org.springframework.util.Assert; @@ -40,6 +43,7 @@ import org.springframework.util.Assert;
* {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant.
*
* @author Joe Grandja
* @author Evgeniy Cheban
* @since 5.2
* @see ReactiveOAuth2AuthorizedClientProvider
* @see WebClientReactiveRefreshTokenTokenResponseClient
@ -49,6 +53,8 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider @@ -49,6 +53,8 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider
private ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = new WebClientReactiveRefreshTokenTokenResponseClient();
private ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();
private Duration clockSkew = Duration.ofSeconds(60);
private Clock clock = Clock.systemUTC();
@ -96,8 +102,16 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider @@ -96,8 +102,16 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider
.flatMap(this.accessTokenResponseClient::getTokenResponse)
.onErrorMap(OAuth2AuthorizationException.class,
(e) -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), e))
.map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()));
.flatMap((tokenResponse) -> {
OAuth2AuthorizedClient refreshedAuthorizedClient = new OAuth2AuthorizedClient(clientRegistration,
context.getPrincipal().getName(), tokenResponse.getAccessToken(),
tokenResponse.getRefreshToken());
Map<String, Object> attributes = new HashMap<>(context.getAttributes());
attributes.put(OAuth2AccessTokenResponse.class.getName(), tokenResponse);
return this.refreshTokenSuccessHandler
.onAuthorizationSuccess(refreshedAuthorizedClient, context.getPrincipal(), attributes)
.then(Mono.just(refreshedAuthorizedClient));
});
}
private boolean hasTokenExpired(OAuth2Token token) {
@ -116,6 +130,19 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider @@ -116,6 +130,19 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider
this.accessTokenResponseClient = accessTokenResponseClient;
}
/**
* Sets a {@link ReactiveOAuth2AuthorizationSuccessHandler} that is called after the
* client is re-authorized, defaults to
* {@link RefreshTokenReactiveOAuth2AuthorizationSuccessHandler}.
* @param refreshTokenSuccessHandler the
* {@link ReactiveOAuth2AuthorizationSuccessHandler} to use
* @since 7.0
*/
public void setRefreshTokenSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler) {
Assert.notNull(refreshTokenSuccessHandler, "refreshTokenSuccessHandler cannot be null");
this.refreshTokenSuccessHandler = refreshTokenSuccessHandler;
}
/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is

5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java

@ -85,6 +85,7 @@ import org.springframework.web.server.ServerWebExchange; @@ -85,6 +85,7 @@ import org.springframework.web.server.ServerWebExchange;
*
* @author Joe Grandja
* @author Phil Clay
* @author Evgeniy Cheban
* @since 5.2
* @see ReactiveOAuth2AuthorizedClientManager
* @see ReactiveOAuth2AuthorizedClientProvider
@ -318,10 +319,10 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React @@ -318,10 +319,10 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(currentServerWebExchangeMono)
.flatMap((exchange) -> {
Map<String, Object> contextAttributes = Collections.emptyMap();
Map<String, Object> contextAttributes = new HashMap<>();
contextAttributes.put(ServerWebExchange.class.getName(), serverWebExchange);
String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope)) {
contextAttributes = new HashMap<>();
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
StringUtils.delimitedListToStringArray(scope, " "));
}

36
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@ -51,6 +51,8 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; @@ -51,6 +51,8 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.client.ClientRequest;
@ -96,6 +98,7 @@ import org.springframework.web.server.ServerWebExchange; @@ -96,6 +98,7 @@ import org.springframework.web.server.ServerWebExchange;
* @author Rob Winch
* @author Joe Grandja
* @author Phil Clay
* @author Evgeniy Cheban
* @since 5.1
*/
public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
@ -139,6 +142,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @@ -139,6 +142,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private ClientResponseHandler clientResponseHandler;
private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
/**
* Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the
* provided parameters.
@ -330,8 +335,11 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @@ -330,8 +335,11 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
}
private Mono<ClientResponse> exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) {
return next.exchange(request)
.transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono));
// Re-request an Authentication from serverSecurityContextRepository since it
// might have been changed during provider invocation.
return effectiveAuthentication(request).flatMap((authentication) -> next.exchange(request)
.transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono))
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication)));
}
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
@ -362,6 +370,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @@ -362,6 +370,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
// @formatter:on
}
private Mono<Authentication> effectiveAuthentication(ClientRequest request) {
// @formatter:off
return effectiveServerWebExchange(request)
.filter(Optional::isPresent)
.map(Optional::get)
.flatMap(this.serverSecurityContextRepository::load)
.map(SecurityContext::getAuthentication)
.switchIfEmpty(this.currentAuthenticationMono);
// @formatter:on
}
/**
* Returns a {@link Mono} the emits the {@code clientRegistrationId} that is active
* for the given request.
@ -445,6 +464,19 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @@ -445,6 +464,19 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
}
/**
* Sets a {@link ServerSecurityContextRepository} to use for re-obtaining a
* {@link SecurityContext} if it has been refreshed during provider invocation,
* defaults to {@link WebSessionServerSecurityContextRepository}.
* @param serverSecurityContextRepository the {@link ServerSecurityContextRepository}
* to use
* @since 7.0
*/
public void setServerSecurityContextRepository(ServerSecurityContextRepository serverSecurityContextRepository) {
Assert.notNull(serverSecurityContextRepository, "serverSecurityContextRepository cannot be null");
this.serverSecurityContextRepository = serverSecurityContextRepository;
}
@FunctionalInterface
private interface ClientResponseHandler {

412
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizationSuccessHandlerTests.java

@ -0,0 +1,412 @@ @@ -0,0 +1,412 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.web.server.ServerWebExchange;
import static org.assertj.core.api.Assertions.assertThatException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link RefreshTokenReactiveOAuth2AuthorizationSuccessHandler}.
*
* @author Evgeniy Cheban
*/
class RefreshTokenReactiveOAuth2AuthorizationSuccessHandlerTests {
@Test
void onAuthorizationSuccessWhenIdTokenValidThenSecurityContextRefreshed() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken);
OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse();
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse,
ServerWebExchange.class.getName(), exchange);
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", principal.getAudience());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes).block();
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyComplete();
StepVerifier.create(serverSecurityContextRepository.load(exchange).map(SecurityContext::getAuthentication))
.expectNext(authenticationToken)
.verifyComplete();
}
@Test
void onAuthorizationSuccessWhenIdTokenIssuerNotSameThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken);
OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse();
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse,
ServerWebExchange.class.getName(), exchange);
Map<String, Object> claims = new HashMap<>();
claims.put("iss", "https://issuer.com");
claims.put("sub", principal.getSubject());
claims.put("aud", principal.getAudience());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid issuer");
}
@Test
void onAuthorizationSuccessWhenIdTokenSubNotSameThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken);
OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse();
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse,
ServerWebExchange.class.getName(), exchange);
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", "invalid_sub");
claims.put("aud", principal.getAudience());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid subject");
}
@Test
void onAuthorizationSuccessWhenIdTokenIatNotAfterThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken);
OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse();
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse,
ServerWebExchange.class.getName(), exchange);
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", principal.getAudience());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt().minus(Duration.ofDays(1)));
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid issued at time");
}
@Test
void onAuthorizationSuccessWhenIdTokenAudEmptyThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken);
OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse();
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse,
ServerWebExchange.class.getName(), exchange);
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", Collections.emptyList());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid audience");
}
@Test
void onAuthorizationSuccessWhenIdTokenAudNotContainThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken);
OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse();
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse,
ServerWebExchange.class.getName(), exchange);
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", List.of("invalid_client-id"));
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid audience");
}
@Test
void onAuthorizationSuccessWhenIdTokenAuthTimeNotSameThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken);
OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse();
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse,
ServerWebExchange.class.getName(), exchange);
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", principal.getAudience());
claims.put("auth_time", principal.getIssuedAt());
claims.put("nonce", principal.getNonce());
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_id_token] Invalid authenticated at time");
}
@Test
void onAuthorizationSuccessWhenIdTokenNonceNotSameThenException() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
DefaultOidcUser principal = TestOidcUsers.create();
OAuth2AuthenticationToken authenticationToken = new OAuth2AuthenticationToken(principal,
principal.getAuthorities(), clientRegistration.getRegistrationId());
OAuth2AccessToken accessToken = createAccessToken();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principal.getName(),
accessToken);
OAuth2AccessTokenResponse tokenResponse = createAccessTokenResponse();
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
Map<String, Object> attributes = Map.of(OAuth2AccessTokenResponse.class.getName(), tokenResponse,
ServerWebExchange.class.getName(), exchange);
Map<String, Object> claims = new HashMap<>();
claims.put("iss", principal.getIssuer());
claims.put("sub", principal.getSubject());
claims.put("aud", principal.getAudience());
claims.put("nonce", "invalid_nonce");
Jwt jwt = mock(Jwt.class);
given(jwt.getTokenValue()).willReturn("id-token-1234");
given(jwt.getIssuedAt()).willReturn(principal.getIssuedAt());
given(jwt.getClaims()).willReturn(claims);
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
given(jwtDecoder.decode(any())).willReturn(Mono.just(jwt));
ReactiveJwtDecoderFactory<ClientRegistration> reactiveJwtDecoderFactory = mock(ReactiveJwtDecoderFactory.class);
given(reactiveJwtDecoderFactory.createDecoder(any())).willReturn(jwtDecoder);
ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
given(userService.loadUser(any())).willReturn(Mono.just(principal));
WebSessionServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
RefreshTokenReactiveOAuth2AuthorizationSuccessHandler handler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();
handler.setJwtDecoderFactory(reactiveJwtDecoderFactory);
handler.setUserService(userService);
handler.setServerSecurityContextRepository(serverSecurityContextRepository);
StepVerifier.create(handler.onAuthorizationSuccess(authorizedClient, authenticationToken, attributes))
.verifyErrorMessage("[invalid_nonce] Invalid nonce");
}
@Test
void setServerSecurityContextRepositoryWhenNullThenException() {
assertThatException()
.isThrownBy(() -> new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler()
.setServerSecurityContextRepository(null))
.withMessage("serverSecurityContextRepository cannot be null");
}
@Test
void setJwtDecoderFactoryWhenNullThenException() {
assertThatException()
.isThrownBy(() -> new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler().setJwtDecoderFactory(null))
.withMessage("jwtDecoderFactory cannot be null");
}
@Test
void setAuthoritiesMapperWhenNullThenException() {
assertThatException()
.isThrownBy(() -> new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler().setAuthoritiesMapper(null))
.withMessage("authoritiesMapper cannot be null");
}
@Test
void setUserServiceWhenNullThenException() {
assertThatException()
.isThrownBy(() -> new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler().setUserService(null))
.withMessage("userService cannot be null");
}
@Test
void setClockSkewWhenNullThenException() {
assertThatException()
.isThrownBy(() -> new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler().setClockSkew(null))
.withMessage("clockSkew cannot be null");
}
private static OAuth2AccessToken createAccessToken() {
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60));
return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt);
}
private static OAuth2AccessTokenResponse createAccessTokenResponse() {
return OAuth2AccessTokenResponse.withToken("access-token-1234")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, "id-token-1234"))
.build();
}
}

10
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProviderTests.java

@ -49,6 +49,7 @@ import static org.mockito.Mockito.verify; @@ -49,6 +49,7 @@ import static org.mockito.Mockito.verify;
* Tests for {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}.
*
* @author Joe Grandja
* @author Evgeniy Cheban
*/
public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests {
@ -84,6 +85,15 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests { @@ -84,6 +85,15 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests {
.withMessage("accessTokenResponseClient cannot be null");
}
@Test
public void setRefreshTokenSuccessHandlerWhenHandlerIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setRefreshTokenSuccessHandler(null))
.withMessage("refreshTokenSuccessHandler cannot be null");
// @formatter:on
}
@Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
// @formatter:off

19
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/MockExchangeFunction.java

@ -18,9 +18,13 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; @@ -18,9 +18,13 @@ package org.springframework.security.oauth2.client.web.reactive.function.client;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import reactor.core.publisher.Mono;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFunction;
@ -29,14 +33,21 @@ import static org.mockito.Mockito.mock; @@ -29,14 +33,21 @@ import static org.mockito.Mockito.mock;
/**
* @author Rob Winch
* @author Evgeniy Cheban
* @since 5.1
*/
public class MockExchangeFunction implements ExchangeFunction {
private final AtomicReference<Authentication> authenticationCaptor = new AtomicReference<>();
private List<ClientRequest> requests = new ArrayList<>();
private ClientResponse response = mock(ClientResponse.class);
public Authentication getCapturedAuthentication() {
return this.authenticationCaptor.get();
}
public ClientRequest getRequest() {
return this.requests.get(this.requests.size() - 1);
}
@ -53,8 +64,14 @@ public class MockExchangeFunction implements ExchangeFunction { @@ -53,8 +64,14 @@ public class MockExchangeFunction implements ExchangeFunction {
public Mono<ClientResponse> exchange(ClientRequest request) {
return Mono.defer(() -> {
this.requests.add(request);
return Mono.just(this.response);
return captureAuthentication().then(Mono.just(this.response));
});
}
private Mono<Authentication> captureAuthentication() {
return ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.doOnNext(this.authenticationCaptor::set);
}
}

28
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@ -59,6 +59,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken; @@ -59,6 +59,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider;
@ -66,6 +67,7 @@ import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2Authoriz @@ -66,6 +67,7 @@ import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2Authoriz
import org.springframework.security.oauth2.client.JwtBearerReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationSuccessHandler;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
@ -89,8 +91,10 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken; @@ -89,8 +91,10 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.TestOAuth2Users;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
@ -113,6 +117,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -113,6 +117,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
/**
* @author Rob Winch
* @author Evgeniy Cheban
* @since 5.1
*/
@ExtendWith(MockitoExtension.class)
@ -136,6 +141,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @@ -136,6 +141,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler;
@Mock
private ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler;
@Captor
private ArgumentCaptor<OAuth2AuthorizationException> authorizationExceptionCaptor;
@ -170,7 +178,8 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @@ -170,7 +178,8 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.builder()
.authorizationCode()
.refreshToken(
(configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient))
(configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)
.refreshTokenSuccessHandler(this.refreshTokenSuccessHandler))
.clientCredentials(
(configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient))
.provider(jwtBearerAuthorizedClientProvider)
@ -201,6 +210,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @@ -201,6 +210,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(null));
}
@Test
public void setServerSecurityContextRepositoryWhenHandlerIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager)
.setServerSecurityContextRepository(null));
}
@Test
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
@ -326,14 +342,23 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @@ -326,14 +342,23 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
// @formatter:on
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
// @formatter:off
WebSessionServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository();
DefaultOAuth2User refreshedUser = TestOAuth2Users.create();
OAuth2AuthenticationToken refreshedAuthentication = new OAuth2AuthenticationToken(refreshedUser, refreshedUser.getAuthorities(), this.registration.getRegistrationId());
SecurityContextImpl securityContext = new SecurityContextImpl(refreshedAuthentication);
given(this.refreshTokenSuccessHandler.onAuthorizationSuccess(any(), eq(authentication), any()))
.willReturn(securityContextRepository.save(this.serverWebExchange, securityContext));
this.function.filter(request, this.exchange)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication))
.contextWrite(serverWebExchange())
.block();
Authentication currentAuthentication = this.exchange.getCapturedAuthentication();
assertThat(currentAuthentication).isSameAs(refreshedAuthentication);
// @formatter:on
verify(this.refreshTokenTokenResponseClient).getTokenResponse(any());
verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(),
eq(authentication), any());
verify(this.refreshTokenSuccessHandler).onAuthorizationSuccess(any(), eq(authentication), any());
OAuth2AuthorizedClient newAuthorizedClient = this.authorizedClientCaptor.getValue();
assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken());
assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken());
@ -355,6 +380,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @@ -355,6 +380,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.refreshToken("refresh-1")
.build();
given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(response));
given(this.refreshTokenSuccessHandler.onAuthorizationSuccess(any(), any(), any())).willReturn(Mono.empty());
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(),

Loading…
Cancel
Save