diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationContext.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationContext.java new file mode 100644 index 0000000000..7f3fcea96c --- /dev/null +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationContext.java @@ -0,0 +1,156 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; + +/** + * An {@link OAuth2AuthenticationContext} that holds an + * {@link OAuth2DeviceVerificationAuthenticationToken} and additional information and is + * used when determining if authorization consent is required. + * + * @author Dinesh Gupta + * @since 7.0 + * @see OAuth2AuthenticationContext + * @see OAuth2DeviceVerificationAuthenticationToken + * @see OAuth2DeviceVerificationAuthenticationProvider#setAuthorizationConsentRequired(java.util.function.Predicate) + */ +public final class OAuth2DeviceVerificationAuthenticationContext implements OAuth2AuthenticationContext { + + private final Map context; + + private OAuth2DeviceVerificationAuthenticationContext(Map context) { + this.context = Collections.unmodifiableMap(new HashMap<>(context)); + } + + @SuppressWarnings("unchecked") + @Nullable + @Override + public V get(Object key) { + return hasKey(key) ? (V) this.context.get(key) : null; + } + + @Override + public boolean hasKey(Object key) { + Assert.notNull(key, "key cannot be null"); + return this.context.containsKey(key); + } + + /** + * Returns the {@link RegisteredClient registered client}. + * @return the {@link RegisteredClient} + */ + public RegisteredClient getRegisteredClient() { + return get(RegisteredClient.class); + } + + /** + * Returns the {@link OAuth2Authorization authorization}. + * @return the {@link OAuth2Authorization} + */ + public OAuth2Authorization getAuthorization() { + return get(OAuth2Authorization.class); + } + + /** + * Returns the {@link OAuth2AuthorizationConsent authorization consent}. + * @return the {@link OAuth2AuthorizationConsent}, or {@code null} if not available + */ + @Nullable + public OAuth2AuthorizationConsent getAuthorizationConsent() { + return get(OAuth2AuthorizationConsent.class); + } + + /** + * Returns the requested scopes. + * @return the requested scopes + */ + public Set getRequestedScopes() { + Set requestedScopes = getAuthorization().getAttribute(OAuth2ParameterNames.SCOPE); + return (requestedScopes != null) ? requestedScopes : Collections.emptySet(); + } + + /** + * Constructs a new {@link Builder} with the provided + * {@link OAuth2DeviceVerificationAuthenticationToken}. + * @param authentication the {@link OAuth2DeviceVerificationAuthenticationToken} + * @return the {@link Builder} + */ + public static Builder with(OAuth2DeviceVerificationAuthenticationToken authentication) { + return new Builder(authentication); + } + + /** + * A builder for {@link OAuth2DeviceVerificationAuthenticationContext}. + */ + public static final class Builder extends AbstractBuilder { + + private Builder(OAuth2DeviceVerificationAuthenticationToken authentication) { + super(authentication); + } + + /** + * Sets the {@link RegisteredClient registered client}. + * @param registeredClient the {@link RegisteredClient} + * @return the {@link Builder} for further configuration + */ + public Builder registeredClient(RegisteredClient registeredClient) { + return put(RegisteredClient.class, registeredClient); + } + + /** + * Sets the {@link OAuth2Authorization authorization}. + * @param authorization the {@link OAuth2Authorization} + * @return the {@link Builder} for further configuration + */ + public Builder authorization(OAuth2Authorization authorization) { + return put(OAuth2Authorization.class, authorization); + } + + /** + * Sets the {@link OAuth2AuthorizationConsent authorization consent}. + * @param authorizationConsent the {@link OAuth2AuthorizationConsent} + * @return the {@link Builder} for further configuration + */ + public Builder authorizationConsent(OAuth2AuthorizationConsent authorizationConsent) { + return put(OAuth2AuthorizationConsent.class, authorizationConsent); + } + + /** + * Builds a new {@link OAuth2DeviceVerificationAuthenticationContext}. + * @return the {@link OAuth2DeviceVerificationAuthenticationContext} + */ + @Override + public OAuth2DeviceVerificationAuthenticationContext build() { + Assert.notNull(get(RegisteredClient.class), "registeredClient cannot be null"); + Assert.notNull(get(OAuth2Authorization.class), "authorization cannot be null"); + return new OAuth2DeviceVerificationAuthenticationContext(getContext()); + } + + } + +} diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java index e8ab5d1a95..140e096a2e 100644 --- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java +++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java @@ -19,6 +19,7 @@ package org.springframework.security.oauth2.server.authorization.authentication; import java.security.Principal; import java.util.Base64; import java.util.Set; +import java.util.function.Predicate; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -79,6 +80,8 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut private final OAuth2AuthorizationConsentService authorizationConsentService; + private Predicate authorizationConsentRequired = OAuth2DeviceVerificationAuthenticationProvider::isAuthorizationConsentRequired; + /** * Constructs an {@code OAuth2DeviceVerificationAuthenticationProvider} using the * provided parameters. @@ -143,10 +146,18 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut Set requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE); + OAuth2DeviceVerificationAuthenticationContext.Builder authenticationContextBuilder = OAuth2DeviceVerificationAuthenticationContext + .with(deviceVerificationAuthentication) + .registeredClient(registeredClient) + .authorization(authorization); + OAuth2AuthorizationConsent currentAuthorizationConsent = this.authorizationConsentService .findById(registeredClient.getId(), principal.getName()); + if (currentAuthorizationConsent != null) { + authenticationContextBuilder.authorizationConsent(currentAuthorizationConsent); + } - if (requiresAuthorizationConsent(requestedScopes, currentAuthorizationConsent)) { + if (this.authorizationConsentRequired.test(authenticationContextBuilder.build())) { String state = DEFAULT_STATE_GENERATOR.generateKey(); authorization = OAuth2Authorization.from(authorization) .principalName(principal.getName()) @@ -204,10 +215,37 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut return OAuth2DeviceVerificationAuthenticationToken.class.isAssignableFrom(authentication); } - private static boolean requiresAuthorizationConsent(Set requestedScopes, - OAuth2AuthorizationConsent authorizationConsent) { + /** + * Sets the {@code Predicate} used to determine if authorization consent is required. + * + *

+ * The {@link OAuth2DeviceVerificationAuthenticationContext} gives the predicate + * access to the {@link OAuth2DeviceVerificationAuthenticationToken}, as well as, the + * following context attributes: + *

    + *
  • The {@link RegisteredClient} associated with the device authorization + * request.
  • + *
  • The {@link OAuth2Authorization} containing the device authorization request + * parameters.
  • + *
  • The {@link OAuth2AuthorizationConsent} previously granted to the + * {@link RegisteredClient}, or {@code null} if not available.
  • + *
+ *

+ * @param authorizationConsentRequired the {@code Predicate} used to determine if + * authorization consent is required + */ + public void setAuthorizationConsentRequired( + Predicate authorizationConsentRequired) { + Assert.notNull(authorizationConsentRequired, "authorizationConsentRequired cannot be null"); + this.authorizationConsentRequired = authorizationConsentRequired; + } + + private static boolean isAuthorizationConsentRequired( + OAuth2DeviceVerificationAuthenticationContext authenticationContext) { - if (authorizationConsent != null && authorizationConsent.getScopes().containsAll(requestedScopes)) { + if (authenticationContext.getAuthorizationConsent() != null && authenticationContext.getAuthorizationConsent() + .getScopes() + .containsAll(authenticationContext.getRequestedScopes())) { return false; } diff --git a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java index 77478e2dee..23d775db03 100644 --- a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java +++ b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Predicate; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -125,6 +126,13 @@ public class OAuth2DeviceVerificationAuthenticationProviderTests { // @formatter:on } + @Test + public void setAuthorizationConsentRequiredWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authenticationProvider.setAuthorizationConsentRequired(null)) + .withMessage("authorizationConsentRequired cannot be null"); + } + @Test public void supportsWhenTypeOAuth2DeviceVerificationAuthenticationTokenThenReturnTrue() { assertThat(this.authenticationProvider.supports(OAuth2DeviceVerificationAuthenticationToken.class)).isTrue(); @@ -382,6 +390,31 @@ public class OAuth2DeviceVerificationAuthenticationProviderTests { .isEqualTo(authenticationResult.getState()); } + @Test + public void authenticateWhenCustomAuthorizationConsentRequiredThenUsed() { + @SuppressWarnings("unchecked") + Predicate authorizationConsentRequired = mock(Predicate.class); + this.authenticationProvider.setAuthorizationConsentRequired(authorizationConsentRequired); + + // @formatter:off + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE) + .token(createDeviceCode()) + .token(createUserCode()) + .attributes(Map::clear) + .attribute(OAuth2ParameterNames.SCOPE, registeredClient.getScopes()) + .build(); + // @formatter:on + Authentication authentication = createAuthentication(); + given(this.registeredClientRepository.findById(anyString())).willReturn(registeredClient); + given(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).willReturn(authorization); + + this.authenticationProvider.authenticate(authentication); + + verify(authorizationConsentRequired).test(any()); + } + private static void mockAuthorizationServerContext() { AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().build(); TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext(