diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationContext.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationContext.java
new file mode 100644
index 0000000000..7f3fcea96c
--- /dev/null
+++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationContext.java
@@ -0,0 +1,156 @@
+/*
+ * Copyright 2004-present the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.util.Assert;
+
+/**
+ * An {@link OAuth2AuthenticationContext} that holds an
+ * {@link OAuth2DeviceVerificationAuthenticationToken} and additional information and is
+ * used when determining if authorization consent is required.
+ *
+ * @author Dinesh Gupta
+ * @since 7.0
+ * @see OAuth2AuthenticationContext
+ * @see OAuth2DeviceVerificationAuthenticationToken
+ * @see OAuth2DeviceVerificationAuthenticationProvider#setAuthorizationConsentRequired(java.util.function.Predicate)
+ */
+public final class OAuth2DeviceVerificationAuthenticationContext implements OAuth2AuthenticationContext {
+
+ private final Map context;
+
+ private OAuth2DeviceVerificationAuthenticationContext(Map context) {
+ this.context = Collections.unmodifiableMap(new HashMap<>(context));
+ }
+
+ @SuppressWarnings("unchecked")
+ @Nullable
+ @Override
+ public V get(Object key) {
+ return hasKey(key) ? (V) this.context.get(key) : null;
+ }
+
+ @Override
+ public boolean hasKey(Object key) {
+ Assert.notNull(key, "key cannot be null");
+ return this.context.containsKey(key);
+ }
+
+ /**
+ * Returns the {@link RegisteredClient registered client}.
+ * @return the {@link RegisteredClient}
+ */
+ public RegisteredClient getRegisteredClient() {
+ return get(RegisteredClient.class);
+ }
+
+ /**
+ * Returns the {@link OAuth2Authorization authorization}.
+ * @return the {@link OAuth2Authorization}
+ */
+ public OAuth2Authorization getAuthorization() {
+ return get(OAuth2Authorization.class);
+ }
+
+ /**
+ * Returns the {@link OAuth2AuthorizationConsent authorization consent}.
+ * @return the {@link OAuth2AuthorizationConsent}, or {@code null} if not available
+ */
+ @Nullable
+ public OAuth2AuthorizationConsent getAuthorizationConsent() {
+ return get(OAuth2AuthorizationConsent.class);
+ }
+
+ /**
+ * Returns the requested scopes.
+ * @return the requested scopes
+ */
+ public Set getRequestedScopes() {
+ Set requestedScopes = getAuthorization().getAttribute(OAuth2ParameterNames.SCOPE);
+ return (requestedScopes != null) ? requestedScopes : Collections.emptySet();
+ }
+
+ /**
+ * Constructs a new {@link Builder} with the provided
+ * {@link OAuth2DeviceVerificationAuthenticationToken}.
+ * @param authentication the {@link OAuth2DeviceVerificationAuthenticationToken}
+ * @return the {@link Builder}
+ */
+ public static Builder with(OAuth2DeviceVerificationAuthenticationToken authentication) {
+ return new Builder(authentication);
+ }
+
+ /**
+ * A builder for {@link OAuth2DeviceVerificationAuthenticationContext}.
+ */
+ public static final class Builder extends AbstractBuilder {
+
+ private Builder(OAuth2DeviceVerificationAuthenticationToken authentication) {
+ super(authentication);
+ }
+
+ /**
+ * Sets the {@link RegisteredClient registered client}.
+ * @param registeredClient the {@link RegisteredClient}
+ * @return the {@link Builder} for further configuration
+ */
+ public Builder registeredClient(RegisteredClient registeredClient) {
+ return put(RegisteredClient.class, registeredClient);
+ }
+
+ /**
+ * Sets the {@link OAuth2Authorization authorization}.
+ * @param authorization the {@link OAuth2Authorization}
+ * @return the {@link Builder} for further configuration
+ */
+ public Builder authorization(OAuth2Authorization authorization) {
+ return put(OAuth2Authorization.class, authorization);
+ }
+
+ /**
+ * Sets the {@link OAuth2AuthorizationConsent authorization consent}.
+ * @param authorizationConsent the {@link OAuth2AuthorizationConsent}
+ * @return the {@link Builder} for further configuration
+ */
+ public Builder authorizationConsent(OAuth2AuthorizationConsent authorizationConsent) {
+ return put(OAuth2AuthorizationConsent.class, authorizationConsent);
+ }
+
+ /**
+ * Builds a new {@link OAuth2DeviceVerificationAuthenticationContext}.
+ * @return the {@link OAuth2DeviceVerificationAuthenticationContext}
+ */
+ @Override
+ public OAuth2DeviceVerificationAuthenticationContext build() {
+ Assert.notNull(get(RegisteredClient.class), "registeredClient cannot be null");
+ Assert.notNull(get(OAuth2Authorization.class), "authorization cannot be null");
+ return new OAuth2DeviceVerificationAuthenticationContext(getContext());
+ }
+
+ }
+
+}
diff --git a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java
index e8ab5d1a95..140e096a2e 100644
--- a/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java
+++ b/oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java
@@ -19,6 +19,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.security.Principal;
import java.util.Base64;
import java.util.Set;
+import java.util.function.Predicate;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -79,6 +80,8 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
private final OAuth2AuthorizationConsentService authorizationConsentService;
+ private Predicate authorizationConsentRequired = OAuth2DeviceVerificationAuthenticationProvider::isAuthorizationConsentRequired;
+
/**
* Constructs an {@code OAuth2DeviceVerificationAuthenticationProvider} using the
* provided parameters.
@@ -143,10 +146,18 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
Set requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE);
+ OAuth2DeviceVerificationAuthenticationContext.Builder authenticationContextBuilder = OAuth2DeviceVerificationAuthenticationContext
+ .with(deviceVerificationAuthentication)
+ .registeredClient(registeredClient)
+ .authorization(authorization);
+
OAuth2AuthorizationConsent currentAuthorizationConsent = this.authorizationConsentService
.findById(registeredClient.getId(), principal.getName());
+ if (currentAuthorizationConsent != null) {
+ authenticationContextBuilder.authorizationConsent(currentAuthorizationConsent);
+ }
- if (requiresAuthorizationConsent(requestedScopes, currentAuthorizationConsent)) {
+ if (this.authorizationConsentRequired.test(authenticationContextBuilder.build())) {
String state = DEFAULT_STATE_GENERATOR.generateKey();
authorization = OAuth2Authorization.from(authorization)
.principalName(principal.getName())
@@ -204,10 +215,37 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
return OAuth2DeviceVerificationAuthenticationToken.class.isAssignableFrom(authentication);
}
- private static boolean requiresAuthorizationConsent(Set requestedScopes,
- OAuth2AuthorizationConsent authorizationConsent) {
+ /**
+ * Sets the {@code Predicate} used to determine if authorization consent is required.
+ *
+ *
+ * The {@link OAuth2DeviceVerificationAuthenticationContext} gives the predicate
+ * access to the {@link OAuth2DeviceVerificationAuthenticationToken}, as well as, the
+ * following context attributes:
+ *
+ * The {@link RegisteredClient} associated with the device authorization
+ * request.
+ * The {@link OAuth2Authorization} containing the device authorization request
+ * parameters.
+ * The {@link OAuth2AuthorizationConsent} previously granted to the
+ * {@link RegisteredClient}, or {@code null} if not available.
+ *
+ *
+ * @param authorizationConsentRequired the {@code Predicate} used to determine if
+ * authorization consent is required
+ */
+ public void setAuthorizationConsentRequired(
+ Predicate authorizationConsentRequired) {
+ Assert.notNull(authorizationConsentRequired, "authorizationConsentRequired cannot be null");
+ this.authorizationConsentRequired = authorizationConsentRequired;
+ }
+
+ private static boolean isAuthorizationConsentRequired(
+ OAuth2DeviceVerificationAuthenticationContext authenticationContext) {
- if (authorizationConsent != null && authorizationConsent.getScopes().containsAll(requestedScopes)) {
+ if (authenticationContext.getAuthorizationConsent() != null && authenticationContext.getAuthorizationConsent()
+ .getScopes()
+ .containsAll(authenticationContext.getRequestedScopes())) {
return false;
}
diff --git a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java
index 77478e2dee..23d775db03 100644
--- a/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java
+++ b/oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java
@@ -23,6 +23,7 @@ import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
+import java.util.function.Predicate;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -125,6 +126,13 @@ public class OAuth2DeviceVerificationAuthenticationProviderTests {
// @formatter:on
}
+ @Test
+ public void setAuthorizationConsentRequiredWhenNullThenThrowIllegalArgumentException() {
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> this.authenticationProvider.setAuthorizationConsentRequired(null))
+ .withMessage("authorizationConsentRequired cannot be null");
+ }
+
@Test
public void supportsWhenTypeOAuth2DeviceVerificationAuthenticationTokenThenReturnTrue() {
assertThat(this.authenticationProvider.supports(OAuth2DeviceVerificationAuthenticationToken.class)).isTrue();
@@ -382,6 +390,31 @@ public class OAuth2DeviceVerificationAuthenticationProviderTests {
.isEqualTo(authenticationResult.getState());
}
+ @Test
+ public void authenticateWhenCustomAuthorizationConsentRequiredThenUsed() {
+ @SuppressWarnings("unchecked")
+ Predicate authorizationConsentRequired = mock(Predicate.class);
+ this.authenticationProvider.setAuthorizationConsentRequired(authorizationConsentRequired);
+
+ // @formatter:off
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+ OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
+ .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE)
+ .token(createDeviceCode())
+ .token(createUserCode())
+ .attributes(Map::clear)
+ .attribute(OAuth2ParameterNames.SCOPE, registeredClient.getScopes())
+ .build();
+ // @formatter:on
+ Authentication authentication = createAuthentication();
+ given(this.registeredClientRepository.findById(anyString())).willReturn(registeredClient);
+ given(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).willReturn(authorization);
+
+ this.authenticationProvider.authenticate(authentication);
+
+ verify(authorizationConsentRequired).test(any());
+ }
+
private static void mockAuthorizationServerContext() {
AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().build();
TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext(