diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java index 02d9530267..3274dcf82e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java @@ -31,8 +31,11 @@ import org.springframework.security.oauth2.client.authentication.jwt.ProviderJwt import org.springframework.security.oauth2.client.authentication.nimbus.NimbusAuthorizationCodeTokenExchanger; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.token.InMemoryAccessTokenRepository; +import org.springframework.security.oauth2.client.token.SecurityTokenRepository; import org.springframework.security.oauth2.client.user.OAuth2UserService; import org.springframework.security.oauth2.client.user.nimbus.NimbusOAuth2UserService; +import org.springframework.security.oauth2.core.AccessToken; import org.springframework.security.oauth2.core.http.HttpClientConfig; import org.springframework.security.oauth2.core.provider.DefaultProviderMetadata; import org.springframework.security.oauth2.core.provider.ProviderMetadata; @@ -57,6 +60,7 @@ final class AuthorizationCodeAuthenticationFilterConfigurer authorizationCodeTokenExchanger; + private SecurityTokenRepository accessTokenRepository; private OAuth2UserService userInfoService; private Map> customUserTypes = new HashMap<>(); private Map userNameAttributeNames = new HashMap<>(); @@ -80,6 +84,12 @@ final class AuthorizationCodeAuthenticationFilterConfigurer accessTokenRepository(SecurityTokenRepository accessTokenRepository) { + Assert.notNull(accessTokenRepository, "accessTokenRepository cannot be null"); + this.accessTokenRepository = accessTokenRepository; + return this; + } + AuthorizationCodeAuthenticationFilterConfigurer userInfoService(OAuth2UserService userInfoService) { Assert.notNull(userInfoService, "userInfoService cannot be null"); this.userInfoService = userInfoService; @@ -124,7 +134,8 @@ final class AuthorizationCodeAuthenticationFilterConfigurer getAccessTokenRepository() { + if (this.accessTokenRepository == null) { + this.accessTokenRepository = new InMemoryAccessTokenRepository(); + } + return this.accessTokenRepository; + } + private ProviderJwtDecoderRegistry getProviderJwtDecoderRegistry(H http) { HttpClientConfig httpClientConfig = this.getHttpClientConfig(http); Map jwtDecoders = new HashMap<>(); diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index 592acb1e2a..71b61c1182 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -26,7 +26,9 @@ import org.springframework.security.oauth2.client.authentication.AuthorizationRe import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; +import org.springframework.security.oauth2.client.token.SecurityTokenRepository; import org.springframework.security.oauth2.client.user.OAuth2UserService; +import org.springframework.security.oauth2.core.AccessToken; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; @@ -139,6 +141,12 @@ public final class OAuth2LoginConfigurer> exten return this; } + public TokenEndpointConfig accessTokenRepository(SecurityTokenRepository accessTokenRepository) { + Assert.notNull(accessTokenRepository, "accessTokenRepository cannot be null"); + OAuth2LoginConfigurer.this.authorizationCodeAuthenticationFilterConfigurer.accessTokenRepository(accessTokenRepository); + return this; + } + public OAuth2LoginConfigurer and() { return OAuth2LoginConfigurer.this; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeAuthenticationProvider.java index 78af55d13a..910018ed71 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeAuthenticationProvider.java @@ -26,6 +26,7 @@ import org.springframework.security.jwt.Jwt; import org.springframework.security.jwt.JwtDecoder; import org.springframework.security.oauth2.client.authentication.jwt.ProviderJwtDecoderRegistry; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.token.SecurityTokenRepository; import org.springframework.security.oauth2.client.user.OAuth2UserService; import org.springframework.security.oauth2.core.AccessToken; import org.springframework.security.oauth2.core.endpoint.TokenResponseAttributes; @@ -79,19 +80,23 @@ import java.util.Collection; */ public class AuthorizationCodeAuthenticationProvider implements AuthenticationProvider { private final AuthorizationGrantTokenExchanger authorizationCodeTokenExchanger; + private final SecurityTokenRepository accessTokenRepository; private final ProviderJwtDecoderRegistry providerJwtDecoderRegistry; private final OAuth2UserService userInfoService; private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper(); public AuthorizationCodeAuthenticationProvider( AuthorizationGrantTokenExchanger authorizationCodeTokenExchanger, + SecurityTokenRepository accessTokenRepository, ProviderJwtDecoderRegistry providerJwtDecoderRegistry, OAuth2UserService userInfoService) { Assert.notNull(authorizationCodeTokenExchanger, "authorizationCodeTokenExchanger cannot be null"); + Assert.notNull(accessTokenRepository, "accessTokenRepository cannot be null"); Assert.notNull(providerJwtDecoderRegistry, "providerJwtDecoderRegistry cannot be null"); Assert.notNull(userInfoService, "userInfoService cannot be null"); this.authorizationCodeTokenExchanger = authorizationCodeTokenExchanger; + this.accessTokenRepository = accessTokenRepository; this.providerJwtDecoderRegistry = providerJwtDecoderRegistry; this.userInfoService = userInfoService; } @@ -134,6 +139,8 @@ public class AuthorizationCodeAuthenticationProvider implements AuthenticationPr accessTokenAuthentication.getAccessToken(), accessTokenAuthentication.getIdToken()); authenticationResult.setDetails(accessTokenAuthentication.getDetails()); + this.accessTokenRepository.saveSecurityToken(accessToken, authenticationResult); + return authenticationResult; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/token/InMemoryAccessTokenRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/token/InMemoryAccessTokenRepository.java new file mode 100644 index 0000000000..e4e1cb6e7f --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/token/InMemoryAccessTokenRepository.java @@ -0,0 +1,72 @@ +/* + * Copyright 2012-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.token; + +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.core.AccessToken; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.oidc.core.user.OidcUser; +import org.springframework.util.Assert; + +import java.util.HashMap; +import java.util.Map; + +/** + * A basic implementation of a {@link SecurityTokenRepository} + * that stores {@link AccessToken}(s) in-memory. + * + * @author Joe Grandja + * @since 5.0 + * @see SecurityTokenRepository + * @see AccessToken + */ +public final class InMemoryAccessTokenRepository implements SecurityTokenRepository { + private final Map accessTokens = new HashMap<>(); + + @Override + public AccessToken loadSecurityToken(OAuth2AuthenticationToken authentication) { + Assert.notNull(authentication, "authentication cannot be null"); + return this.accessTokens.get(this.resolveAuthenticationKey(authentication)); + } + + @Override + public void saveSecurityToken(AccessToken accessToken, OAuth2AuthenticationToken authentication) { + Assert.notNull(accessToken, "accessToken cannot be null"); + Assert.notNull(authentication, "authentication cannot be null"); + this.accessTokens.put(this.resolveAuthenticationKey(authentication), accessToken); + } + + @Override + public void removeSecurityToken(OAuth2AuthenticationToken authentication) { + Assert.notNull(authentication, "authentication cannot be null"); + this.accessTokens.remove(this.resolveAuthenticationKey(authentication)); + } + + private String resolveAuthenticationKey(OAuth2AuthenticationToken authentication) { + String authenticationKey; + + OAuth2User oauth2User = (OAuth2User) authentication.getPrincipal(); + if (OidcUser.class.isAssignableFrom(oauth2User.getClass())) { + OidcUser oidcUser = (OidcUser)oauth2User; + authenticationKey = oidcUser.getIssuer().toString() + "-" + oidcUser.getSubject(); + } else { + authenticationKey = authentication.getClientRegistration().getProviderDetails().getUserInfoUri() + + "-" + oauth2User.getName(); + } + + return authenticationKey; + } +}