Browse Source

Provide more flexibility on when to display consent page

Closes gh-1541
pull/1609/head
MrJovanovic13 2 years ago committed by Joe Grandja
parent
commit
2b7da9fc5a
  1. 45
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationContext.java
  2. 80
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java
  3. 85
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java

45
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationContext.java

@ -21,6 +21,8 @@ import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -63,6 +65,27 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationContext implement
return get(RegisteredClient.class); return get(RegisteredClient.class);
} }
/**
* Returns the {@link OAuth2AuthorizationRequest oauth2 authorization request}.
*
* @return the {@link OAuth2AuthorizationRequest}
*/
@Nullable
public OAuth2AuthorizationRequest getOAuth2AuthorizationRequest() {
return get(OAuth2AuthorizationRequest.class);
}
/**
* Returns the {@link OAuth2AuthorizationConsent oauth2 authorization consent}.
*
* @return the {@link OAuth2AuthorizationConsent}
*/
@Nullable
public OAuth2AuthorizationConsent getOAuth2AuthorizationConsent() {
return get(OAuth2AuthorizationConsent.class);
}
/** /**
* Constructs a new {@link Builder} with the provided {@link OAuth2AuthorizationCodeRequestAuthenticationToken}. * Constructs a new {@link Builder} with the provided {@link OAuth2AuthorizationCodeRequestAuthenticationToken}.
* *
@ -92,6 +115,28 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationContext implement
return put(RegisteredClient.class, registeredClient); return put(RegisteredClient.class, registeredClient);
} }
/**
* Sets the {@link OAuth2AuthorizationRequest oauth2 authorization request}.
*
* @param authorizationRequest the {@link OAuth2AuthorizationRequest}
* @return the {@link Builder} for further configuration
* @since 1.3.0
*/
public Builder authorizationRequest(OAuth2AuthorizationRequest authorizationRequest) {
return put(OAuth2AuthorizationRequest.class, authorizationRequest);
}
/**
* Sets the {@link OAuth2AuthorizationConsent oauth2 authorization consent}.
*
* @param authorizationConsent the {@link OAuth2AuthorizationConsent}
* @return the {@link Builder} for further configuration
* @since 1.3.0
*/
public Builder authorizationConsent(OAuth2AuthorizationConsent authorizationConsent) {
return put(OAuth2AuthorizationConsent.class, authorizationConsent);
}
/** /**
* Builds a new {@link OAuth2AuthorizationCodeRequestAuthenticationContext}. * Builds a new {@link OAuth2AuthorizationCodeRequestAuthenticationContext}.
* *

80
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

@ -19,6 +19,7 @@ import java.security.Principal;
import java.util.Base64; import java.util.Base64;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Predicate;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
@ -80,6 +81,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
private OAuth2TokenGenerator<OAuth2AuthorizationCode> authorizationCodeGenerator = new OAuth2AuthorizationCodeGenerator(); private OAuth2TokenGenerator<OAuth2AuthorizationCode> authorizationCodeGenerator = new OAuth2AuthorizationCodeGenerator();
private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator = private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator =
new OAuth2AuthorizationCodeRequestAuthenticationValidator(); new OAuth2AuthorizationCodeRequestAuthenticationValidator();
private Predicate<OAuth2AuthorizationCodeRequestAuthenticationContext> requiresAuthorizationConsent;
/** /**
* Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationProvider} using the provided parameters. * Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationProvider} using the provided parameters.
@ -96,6 +98,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
this.registeredClientRepository = registeredClientRepository; this.registeredClientRepository = registeredClientRepository;
this.authorizationService = authorizationService; this.authorizationService = authorizationService;
this.authorizationConsentService = authorizationConsentService; this.authorizationConsentService = authorizationConsentService;
this.requiresAuthorizationConsent = this::requireAuthorizationConsent;
} }
@Override @Override
@ -171,7 +174,19 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
OAuth2AuthorizationConsent currentAuthorizationConsent = this.authorizationConsentService.findById( OAuth2AuthorizationConsent currentAuthorizationConsent = this.authorizationConsentService.findById(
registeredClient.getId(), principal.getName()); registeredClient.getId(), principal.getName());
if (requireAuthorizationConsent(registeredClient, authorizationRequest, currentAuthorizationConsent)) { OAuth2AuthorizationCodeRequestAuthenticationContext.Builder authenticationContextBuilder =
OAuth2AuthorizationCodeRequestAuthenticationContext.with(authorizationCodeRequestAuthentication)
.registeredClient(registeredClient)
.authorizationRequest(authorizationRequest);
if (currentAuthorizationConsent != null) {
authenticationContextBuilder.authorizationConsent(currentAuthorizationConsent);
}
OAuth2AuthorizationCodeRequestAuthenticationContext contextWithAuthorizationRequestAndAuthorizationConsent =
authenticationContextBuilder.build();
if (requiresAuthorizationConsent.test(contextWithAuthorizationRequestAndAuthorizationConsent)) {
String state = DEFAULT_STATE_GENERATOR.generateKey(); String state = DEFAULT_STATE_GENERATOR.generateKey();
OAuth2Authorization authorization = authorizationBuilder(registeredClient, principal, authorizationRequest) OAuth2Authorization authorization = authorizationBuilder(registeredClient, principal, authorizationRequest)
.attribute(OAuth2ParameterNames.STATE, state) .attribute(OAuth2ParameterNames.STATE, state)
@ -264,7 +279,48 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
this.authenticationValidator = authenticationValidator; this.authenticationValidator = authenticationValidator;
} }
private static OAuth2Authorization.Builder authorizationBuilder(RegisteredClient registeredClient, Authentication principal, /**
* Sets the {@link Predicate} used to determine if authorization consent is required.
*
* <p>
* The {@link OAuth2AuthorizationCodeRequestAuthenticationContext} gives the predicate access to the {@link OAuth2AuthorizationCodeRequestAuthenticationToken},
* as well as, the following context attributes:
* {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getRegisteredClient()} containing {@link RegisteredClient} used to make the request.
* {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getOAuth2AuthorizationRequest()} containing {@link OAuth2AuthorizationRequest}.
* {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getOAuth2AuthorizationConsent()} containing {@link OAuth2AuthorizationConsent} granted in the request.
*
* @param requiresAuthorizationConsent the {@link Predicate} that determines if authorization consent is required.
* @since 1.3.0
*/
public void setRequiresAuthorizationConsent(Predicate<OAuth2AuthorizationCodeRequestAuthenticationContext> requiresAuthorizationConsent) {
Assert.notNull(requiresAuthorizationConsent, "requiresAuthorizationConsent cannot be null");
this.requiresAuthorizationConsent = requiresAuthorizationConsent;
}
private boolean requireAuthorizationConsent(OAuth2AuthorizationCodeRequestAuthenticationContext context) {
RegisteredClient registeredClient = context.getRegisteredClient();
if (!registeredClient.getClientSettings().isRequireAuthorizationConsent()) {
return false;
}
OAuth2AuthorizationRequest authorizationRequest = context.getOAuth2AuthorizationRequest();
// 'openid' scope does not require consent
if (authorizationRequest.getScopes().contains(OidcScopes.OPENID) &&
authorizationRequest.getScopes().size() == 1) {
return false;
}
OAuth2AuthorizationConsent authorizationConsent = context.getOAuth2AuthorizationConsent();
if (authorizationConsent != null &&
authorizationConsent.getScopes().containsAll(authorizationRequest.getScopes())) {
return false;
}
return true;
}
private static OAuth2Authorization.Builder authorizationBuilder(RegisteredClient registeredClient,
Authentication principal,
OAuth2AuthorizationRequest authorizationRequest) { OAuth2AuthorizationRequest authorizationRequest) {
return OAuth2Authorization.withRegisteredClient(registeredClient) return OAuth2Authorization.withRegisteredClient(registeredClient)
.principalName(principal.getName()) .principalName(principal.getName())
@ -295,26 +351,6 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
return tokenContextBuilder.build(); return tokenContextBuilder.build();
} }
private static boolean requireAuthorizationConsent(RegisteredClient registeredClient,
OAuth2AuthorizationRequest authorizationRequest, OAuth2AuthorizationConsent authorizationConsent) {
if (!registeredClient.getClientSettings().isRequireAuthorizationConsent()) {
return false;
}
// 'openid' scope does not require consent
if (authorizationRequest.getScopes().contains(OidcScopes.OPENID) &&
authorizationRequest.getScopes().size() == 1) {
return false;
}
if (authorizationConsent != null &&
authorizationConsent.getScopes().containsAll(authorizationRequest.getScopes())) {
return false;
}
return true;
}
private static boolean isPrincipalAuthenticated(Authentication principal) { private static boolean isPrincipalAuthenticated(Authentication principal) {
return principal != null && return principal != null &&
!AnonymousAuthenticationToken.class.isAssignableFrom(principal.getClass()) && !AnonymousAuthenticationToken.class.isAssignableFrom(principal.getClass()) &&

85
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java

@ -21,6 +21,7 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Predicate;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -72,6 +73,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
private OAuth2AuthorizationConsentService authorizationConsentService; private OAuth2AuthorizationConsentService authorizationConsentService;
private OAuth2AuthorizationCodeRequestAuthenticationProvider authenticationProvider; private OAuth2AuthorizationCodeRequestAuthenticationProvider authenticationProvider;
private TestingAuthenticationToken principal; private TestingAuthenticationToken principal;
private Predicate<OAuth2AuthorizationCodeRequestAuthenticationContext> requiresAuthorizationConsent;
@BeforeEach @BeforeEach
public void setUp() { public void setUp() {
@ -129,6 +131,13 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
.hasMessage("authenticationValidator cannot be null"); .hasMessage("authenticationValidator cannot be null");
} }
@Test
public void setRequiresAuthorizationConsentWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authenticationProvider.setRequiresAuthorizationConsent(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("requiresAuthorizationConsent cannot be null");
}
@Test @Test
public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() { public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@ -443,6 +452,82 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
assertThat(authenticationResult.isAuthenticated()).isTrue(); assertThat(authenticationResult.isAuthenticated()).isTrue();
} }
@Test
public void authenticateWhenRequireAuthorizationConsentAndRequiresAuthorizationConsentPredicateTrueThenReturnAuthorizationConsent() {
this.authenticationProvider.setRequiresAuthorizationConsent((authenticationContext) -> true);
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
.clientSettings(ClientSettings.builder().requireAuthorizationConsent(true).build())
.build();
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient);
String redirectUri = registeredClient.getRedirectUris().toArray(new String[0])[0];
OAuth2AuthorizationCodeRequestAuthenticationToken authentication =
new OAuth2AuthorizationCodeRequestAuthenticationToken(
AUTHORIZATION_URI, registeredClient.getClientId(), principal,
redirectUri, STATE, registeredClient.getScopes(), null);
OAuth2AuthorizationConsentAuthenticationToken authenticationResult =
(OAuth2AuthorizationConsentAuthenticationToken) this.authenticationProvider.authenticate(authentication);
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization authorization = authorizationCaptor.getValue();
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(authentication.getAuthorizationUri());
assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId());
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(authentication.getRedirectUri());
assertThat(authorizationRequest.getScopes()).isEqualTo(authentication.getScopes());
assertThat(authorizationRequest.getState()).isEqualTo(authentication.getState());
assertThat(authorizationRequest.getAdditionalParameters()).isEqualTo(authentication.getAdditionalParameters());
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
assertThat(authorization.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(authorization.<Authentication>getAttribute(Principal.class.getName())).isEqualTo(this.principal);
String state = authorization.getAttribute(OAuth2ParameterNames.STATE);
assertThat(state).isNotNull();
assertThat(state).isNotEqualTo(authentication.getState());
assertThat(authenticationResult.getClientId()).isEqualTo(registeredClient.getClientId());
assertThat(authenticationResult.getPrincipal()).isEqualTo(this.principal);
assertThat(authenticationResult.getAuthorizationUri()).isEqualTo(authorizationRequest.getAuthorizationUri());
assertThat(authenticationResult.getScopes()).isEmpty();
assertThat(authenticationResult.getState()).isEqualTo(state);
assertThat(authenticationResult.isAuthenticated()).isTrue();
}
@Test
public void authenticateWhenRequireAuthorizationConsentAndRequiresAuthorizationConsentPredicateFalseThenAuthorizationConsentNotRequired() {
this.authenticationProvider.setRequiresAuthorizationConsent((authenticationContext) -> false);
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
.clientSettings(ClientSettings.builder().requireAuthorizationConsent(true).build())
.scopes(scopes -> {
scopes.clear();
scopes.add(OidcScopes.OPENID);
scopes.add(OidcScopes.EMAIL);
})
.build();
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient);
String redirectUri = registeredClient.getRedirectUris().toArray(new String[0])[1];
OAuth2AuthorizationCodeRequestAuthenticationToken authentication =
new OAuth2AuthorizationCodeRequestAuthenticationToken(
AUTHORIZATION_URI, registeredClient.getClientId(), principal,
redirectUri, STATE, registeredClient.getScopes(), null);
OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult =
(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication);
assertAuthorizationCodeRequestWithAuthorizationCodeResult(registeredClient, authentication, authenticationResult);
}
@Test @Test
public void authenticateWhenRequireAuthorizationConsentAndOnlyOpenidScopeRequestedThenAuthorizationConsentNotRequired() { public void authenticateWhenRequireAuthorizationConsentAndOnlyOpenidScopeRequestedThenAuthorizationConsentNotRequired() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient() RegisteredClient registeredClient = TestRegisteredClients.registeredClient()

Loading…
Cancel
Save