From 8e04da773d644bd335323abcbaee2ab89451a541 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Fri, 17 Mar 2023 18:05:15 -0500 Subject: [PATCH] Add tests for OAuth 2.0 Device Authorization Grant This commit adds tests for the following components: * AuthenticationConverters * AuthenticationProviders * Endpoint Filters Issue gh-44 Closes gh-1127 --- ...ionConsentAuthenticationProviderTests.java | 444 +++++++++++++++++ ...ionRequestAuthenticationProviderTests.java | 351 +++++++++++++ ...DeviceCodeAuthenticationProviderTests.java | 431 ++++++++++++++++ ...rificationAuthenticationProviderTests.java | 326 +++++++++++++ ...eviceAuthorizationEndpointFilterTests.java | 423 ++++++++++++++++ ...DeviceVerificationEndpointFilterTests.java | 460 ++++++++++++++++++ ...onConsentAuthenticationConverterTests.java | 295 +++++++++++ ...onRequestAuthenticationConverterTests.java | 120 +++++ ...eviceCodeAuthenticationConverterTests.java | 126 +++++ ...ificationAuthenticationConverterTests.java | 168 +++++++ 10 files changed, 3144 insertions(+) create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceAuthorizationConsentAuthenticationProviderTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceAuthorizationRequestAuthenticationProviderTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceCodeAuthenticationProviderTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceAuthorizationEndpointFilterTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceVerificationEndpointFilterTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationConsentAuthenticationConverterTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationRequestAuthenticationConverterTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceCodeAuthenticationConverterTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverterTests.java diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceAuthorizationConsentAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceAuthorizationConsentAuthenticationProviderTests.java new file mode 100644 index 00000000..8e0b5ffe --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceAuthorizationConsentAuthenticationProviderTests.java @@ -0,0 +1,444 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2DeviceCode; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.OAuth2UserCode; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationConsentAuthenticationProvider.STATE_TOKEN_TYPE; + +/** + * Tests for {@link OAuth2DeviceAuthorizationConsentAuthenticationProvider}. + * + * @author Steve Riesenberg + */ +public class OAuth2DeviceAuthorizationConsentAuthenticationProviderTests { + private static final String AUTHORIZATION_URI = "/oauth2/device_authorization"; + private static final String DEVICE_CODE = "EfYu_0jEL"; + private static final String USER_CODE = "BCDF-GHJK"; + private static final String STATE = "abc123"; + + private RegisteredClientRepository registeredClientRepository; + private OAuth2AuthorizationService authorizationService; + private OAuth2AuthorizationConsentService authorizationConsentService; + private OAuth2DeviceAuthorizationConsentAuthenticationProvider authenticationProvider; + + @BeforeEach + public void setUp() { + this.registeredClientRepository = mock(RegisteredClientRepository.class); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.authorizationConsentService = mock(OAuth2AuthorizationConsentService.class); + this.authenticationProvider = new OAuth2DeviceAuthorizationConsentAuthenticationProvider( + this.registeredClientRepository, this.authorizationService, this.authorizationConsentService); + } + + @Test + public void constructorWhenRegisteredClientRepositoryIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceAuthorizationConsentAuthenticationProvider( + null, this.authorizationService, this.authorizationConsentService)) + .withMessage("registeredClientRepository cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenAuthorizationServiceIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceAuthorizationConsentAuthenticationProvider( + this.registeredClientRepository, null, this.authorizationConsentService)) + .withMessage("authorizationService cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenAuthorizationConsentServiceIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceAuthorizationConsentAuthenticationProvider( + this.registeredClientRepository, this.authorizationService, null)) + .withMessage("authorizationConsentService cannot be null"); + // @formatter:on + } + + @Test + public void setAuthorizationConsentCustomizerWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authenticationProvider.setAuthorizationConsentCustomizer(null)) + .withMessageContaining("authorizationConsentCustomizer cannot be null"); + // @formatter:on + } + + @Test + public void supportsWhenTypeOAuth2DeviceAuthorizationRequestAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OAuth2DeviceAuthorizationConsentAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenAuthorizationNotFoundThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication authentication = createAuthentication(registeredClient); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessageContaining(OAuth2ParameterNames.STATE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + + verify(this.authorizationService).findByToken(STATE, STATE_TOKEN_TYPE); + verifyNoInteractions(this.registeredClientRepository, this.authorizationConsentService); + } + + @Test + public void authenticateWhenPrincipalIsNotAuthenticatedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = createAuthorization(registeredClient); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + TestingAuthenticationToken principal = new TestingAuthenticationToken(authorization.getPrincipalName(), null); + Authentication authentication = new OAuth2DeviceAuthorizationConsentAuthenticationToken(AUTHORIZATION_URI, + registeredClient.getClientId(), principal, USER_CODE, STATE, null, Collections.emptyMap()); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessageContaining(OAuth2ParameterNames.STATE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + + verify(this.authorizationService).findByToken(STATE, STATE_TOKEN_TYPE); + verifyNoInteractions(this.registeredClientRepository, this.authorizationConsentService); + } + + @Test + public void authenticateWhenPrincipalNameDoesNotMatchThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = createAuthorization(registeredClient); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + TestingAuthenticationToken principal = new TestingAuthenticationToken("invalid", null, Collections.emptyList()); + Authentication authentication = new OAuth2DeviceAuthorizationConsentAuthenticationToken(AUTHORIZATION_URI, + registeredClient.getClientId(), principal, USER_CODE, STATE, null, Collections.emptyMap()); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessageContaining(OAuth2ParameterNames.STATE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + + verify(this.authorizationService).findByToken(STATE, STATE_TOKEN_TYPE); + verifyNoInteractions(this.registeredClientRepository, this.authorizationConsentService); + } + + @Test + public void authenticateWhenRegisteredClientNotFoundThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = createAuthorization(registeredClient); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + Authentication authentication = createAuthentication(registeredClient); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessageContaining(OAuth2ParameterNames.CLIENT_ID) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + + verify(this.registeredClientRepository).findByClientId(registeredClient.getClientId()); + verify(this.authorizationService).findByToken(STATE, STATE_TOKEN_TYPE); + verifyNoMoreInteractions(this.registeredClientRepository, this.authorizationService); + verifyNoInteractions(this.authorizationConsentService); + } + + @Test + public void authenticateWhenRegisteredClientDoesNotMatchAuthorizationThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + RegisteredClient registeredClient2 = TestRegisteredClients.registeredClient2().build(); + OAuth2Authorization authorization = createAuthorization(registeredClient2); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + when(this.registeredClientRepository.findByClientId(anyString())).thenReturn(registeredClient); + Authentication authentication = createAuthentication(registeredClient); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessageContaining(OAuth2ParameterNames.CLIENT_ID) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + + verify(this.registeredClientRepository).findByClientId(registeredClient.getClientId()); + verify(this.authorizationService).findByToken(STATE, STATE_TOKEN_TYPE); + verifyNoMoreInteractions(this.registeredClientRepository, this.authorizationService); + verifyNoInteractions(this.authorizationConsentService); + } + + @Test + public void authenticateWhenRequestedScopesNotAuthorizedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + RegisteredClient registeredClient2 = TestRegisteredClients.registeredClient().scopes(Set::clear) + .scope("invalid").build(); + OAuth2Authorization authorization = createAuthorization(registeredClient); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + when(this.registeredClientRepository.findByClientId(anyString())).thenReturn(registeredClient); + Authentication authentication = createAuthentication(registeredClient2); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessageContaining(OAuth2ParameterNames.SCOPE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_SCOPE); + // @formatter:on + + verify(this.registeredClientRepository).findByClientId(registeredClient.getClientId()); + verify(this.authorizationService).findByToken(STATE, STATE_TOKEN_TYPE); + verifyNoMoreInteractions(this.registeredClientRepository, this.authorizationService); + verifyNoInteractions(this.authorizationConsentService); + } + + @Test + public void authenticateWhenAuthoritiesIsEmptyThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + RegisteredClient registeredClient2 = TestRegisteredClients.registeredClient().scopes(Set::clear).build(); + OAuth2Authorization authorization = createAuthorization(registeredClient2); + Authentication authentication = createAuthentication(registeredClient2); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + when(this.registeredClientRepository.findByClientId(anyString())).thenReturn(registeredClient); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); + // @formatter:on + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).findByToken(STATE, STATE_TOKEN_TYPE); + verify(this.registeredClientRepository).findByClientId(registeredClient.getClientId()); + verify(this.authorizationConsentService).findById(registeredClient.getId(), authentication.getName()); + verify(this.authorizationService).save(authorizationCaptor.capture()); + verifyNoMoreInteractions(this.registeredClientRepository, this.authorizationService, + this.authorizationConsentService); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + assertThat(updatedAuthorization.getAttribute(OAuth2ParameterNames.STATE)).isNull(); + // @formatter:off + assertThat(updatedAuthorization.getToken(OAuth2DeviceCode.class)) + .extracting(isInvalidated()) + .isEqualTo(true); + assertThat(updatedAuthorization.getToken(OAuth2UserCode.class)) + .extracting(isInvalidated()) + .isEqualTo(true); + // @formatter:on + } + + @Test + public void authenticateWhenAuthoritiesIsNotEmptyThenAuthorizationConsentSaved() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = createAuthorization(registeredClient); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + when(this.registeredClientRepository.findByClientId(anyString())).thenReturn(registeredClient); + + Authentication authentication = createAuthentication(registeredClient); + OAuth2DeviceVerificationAuthenticationToken authenticationResult = + (OAuth2DeviceVerificationAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getPrincipal()).isSameAs(authentication.getPrincipal()); + assertThat(authenticationResult.getUserCode()).isEqualTo(USER_CODE); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).findByToken(STATE, STATE_TOKEN_TYPE); + verify(this.registeredClientRepository).findByClientId(registeredClient.getClientId()); + verify(this.authorizationConsentService).findById(registeredClient.getId(), authentication.getName()); + verify(this.authorizationConsentService).save(any(OAuth2AuthorizationConsent.class)); + verify(this.authorizationService).save(authorizationCaptor.capture()); + verifyNoMoreInteractions(this.registeredClientRepository, this.authorizationService, + this.authorizationConsentService); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(authentication.getName()); + assertThat(updatedAuthorization.getAuthorizedScopes()).hasSameElementsAs(registeredClient.getScopes()); + assertThat(updatedAuthorization.getAttribute(OAuth2ParameterNames.STATE)).isNull(); + assertThat(updatedAuthorization.>getAttribute(OAuth2ParameterNames.SCOPE)).isNull(); + // @formatter:off + assertThat(updatedAuthorization.getToken(OAuth2DeviceCode.class)) + .extracting(isInvalidated()) + .isEqualTo(false); + assertThat(updatedAuthorization.getToken(OAuth2UserCode.class)) + .extracting(isInvalidated()) + .isEqualTo(true); + // @formatter:on + } + + @Test + public void authenticateWhenExistingAuthorizationConsentThenUpdated() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope("additional").build(); + RegisteredClient registeredClient2 = TestRegisteredClients.registeredClient().scopes(Set::clear) + .scope("additional").build(); + OAuth2Authorization authorization = createAuthorization(registeredClient2); + Authentication authentication = createAuthentication(registeredClient2); + // @formatter:off + OAuth2AuthorizationConsent authorizationConsent = + OAuth2AuthorizationConsent.withId(registeredClient.getId(), authentication.getName()) + .scope("scope1").build(); + // @formatter:on + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + when(this.registeredClientRepository.findByClientId(anyString())).thenReturn(registeredClient); + when(this.authorizationConsentService.findById(anyString(), anyString())).thenReturn(authorizationConsent); + + OAuth2DeviceVerificationAuthenticationToken authenticationResult = + (OAuth2DeviceVerificationAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getPrincipal()).isSameAs(authentication.getPrincipal()); + assertThat(authenticationResult.getUserCode()).isEqualTo(USER_CODE); + + ArgumentCaptor authorizationConsentCaptor = ArgumentCaptor.forClass( + OAuth2AuthorizationConsent.class); + verify(this.authorizationService).findByToken(STATE, STATE_TOKEN_TYPE); + verify(this.registeredClientRepository).findByClientId(registeredClient.getClientId()); + verify(this.authorizationConsentService).findById(registeredClient.getId(), authentication.getName()); + verify(this.authorizationConsentService).save(authorizationConsentCaptor.capture()); + verify(this.authorizationService).save(any(OAuth2Authorization.class)); + verifyNoMoreInteractions(this.registeredClientRepository, this.authorizationService, + this.authorizationConsentService); + + OAuth2AuthorizationConsent updatedAuthorizationConsent = authorizationConsentCaptor.getValue(); + assertThat(updatedAuthorizationConsent.getRegisteredClientId()).isEqualTo(registeredClient.getId()); + assertThat(updatedAuthorizationConsent.getPrincipalName()).isEqualTo(authentication.getName()); + assertThat(updatedAuthorizationConsent.getScopes()).hasSameElementsAs(registeredClient.getScopes()); + } + + @Test + public void authenticateWhenAuthorizationConsentCustomizerSetThenUsed() { + SimpleGrantedAuthority customAuthority = new SimpleGrantedAuthority("test"); + this.authenticationProvider.setAuthorizationConsentCustomizer((context) -> context.getAuthorizationConsent() + .authority(customAuthority)); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scopes(Set::clear).build(); + OAuth2Authorization authorization = createAuthorization(registeredClient); + Authentication authentication = createAuthentication(registeredClient); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + when(this.registeredClientRepository.findByClientId(anyString())).thenReturn(registeredClient); + when(this.authorizationConsentService.findById(anyString(), anyString())).thenReturn(null); + + OAuth2DeviceVerificationAuthenticationToken authenticationResult = + (OAuth2DeviceVerificationAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getPrincipal()).isSameAs(authentication.getPrincipal()); + assertThat(authenticationResult.getUserCode()).isEqualTo(USER_CODE); + + ArgumentCaptor authorizationConsentCaptor = ArgumentCaptor.forClass( + OAuth2AuthorizationConsent.class); + verify(this.authorizationService).findByToken(STATE, STATE_TOKEN_TYPE); + verify(this.registeredClientRepository).findByClientId(registeredClient.getClientId()); + verify(this.authorizationConsentService).findById(registeredClient.getId(), authentication.getName()); + verify(this.authorizationConsentService).save(authorizationConsentCaptor.capture()); + verify(this.authorizationService).save(any(OAuth2Authorization.class)); + verifyNoMoreInteractions(this.registeredClientRepository, this.authorizationService, + this.authorizationConsentService); + + OAuth2AuthorizationConsent updatedAuthorizationConsent = authorizationConsentCaptor.getValue(); + assertThat(updatedAuthorizationConsent.getRegisteredClientId()).isEqualTo(registeredClient.getId()); + assertThat(updatedAuthorizationConsent.getPrincipalName()).isEqualTo(authentication.getName()); + assertThat(updatedAuthorizationConsent.getAuthorities()).containsExactly(customAuthority); + } + + private static OAuth2Authorization createAuthorization(RegisteredClient registeredClient) { + // @formatter:off + return TestOAuth2Authorizations.authorization(registeredClient) + .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE) + .token(createDeviceCode()) + .token(createUserCode()) + .attributes(Map::clear) + .attribute(OAuth2ParameterNames.SCOPE, registeredClient.getScopes()) + .build(); + // @formatter:on + } + + private static OAuth2DeviceAuthorizationConsentAuthenticationToken createAuthentication(RegisteredClient registeredClient) { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", null, Collections.emptyList()); + Set authorizedScopes = registeredClient.getScopes(); + if (authorizedScopes.isEmpty()) { + authorizedScopes = null; + } + Map additionalParameters = null; + return new OAuth2DeviceAuthorizationConsentAuthenticationToken(AUTHORIZATION_URI, + registeredClient.getClientId(), principal, USER_CODE, STATE, authorizedScopes, additionalParameters); + } + + private static OAuth2DeviceCode createDeviceCode() { + Instant issuedAt = Instant.now(); + return new OAuth2DeviceCode(DEVICE_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } + + private static OAuth2UserCode createUserCode() { + Instant issuedAt = Instant.now(); + return new OAuth2UserCode(USER_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } + + private static Function, Boolean> isInvalidated() { + return (token) -> token.getMetadata(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceAuthorizationRequestAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceAuthorizationRequestAuthenticationProviderTests.java new file mode 100644 index 00000000..d94fe9e6 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceAuthorizationRequestAuthenticationProviderTests.java @@ -0,0 +1,351 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.Set; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2DeviceCode; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2UserCode; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; +import org.springframework.security.oauth2.server.authorization.context.TestAuthorizationServerContext; +import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationRequestAuthenticationProvider.DEVICE_CODE_TOKEN_TYPE; +import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationRequestAuthenticationProvider.USER_CODE_TOKEN_TYPE; + +/** + * Tests for {@link OAuth2DeviceAuthorizationRequestAuthenticationProvider}. + * + * @author Steve Riesenberg + */ +public class OAuth2DeviceAuthorizationRequestAuthenticationProviderTests { + private static final String AUTHORIZATION_URI = "/oauth2/device_authorization"; + private static final String DEVICE_CODE = "EfYu_0jEL"; + private static final String USER_CODE = "BCDF-GHJK"; + + private OAuth2AuthorizationService authorizationService; + private OAuth2DeviceAuthorizationRequestAuthenticationProvider authenticationProvider; + + @BeforeEach + public void setUp() { + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.authenticationProvider = new OAuth2DeviceAuthorizationRequestAuthenticationProvider( + this.authorizationService); + mockAuthorizationServerContext(); + } + + @AfterEach + public void tearDown() { + AuthorizationServerContextHolder.resetContext(); + } + + @Test + public void constructorWhenAuthorizationServiceIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceAuthorizationRequestAuthenticationProvider(null)) + .withMessage("authorizationService cannot be null"); + // @formatter:on + } + + @Test + public void setDeviceCodeGeneratorWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authenticationProvider.setDeviceCodeGenerator(null)) + .withMessage("deviceCodeGenerator cannot be null"); + // @formatter:on + } + + @Test + public void setUserCodeGeneratorWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authenticationProvider.setUserCodeGenerator(null)) + .withMessage("userCodeGenerator cannot be null"); + // @formatter:on + } + + @Test + public void supportsWhenTypeOAuth2DeviceAuthorizationRequestAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OAuth2DeviceAuthorizationRequestAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenClientNotAuthenticatedThenThrowOAuth2AuthenticationException() { + OAuth2ClientAuthenticationToken clientPrincipal = + new OAuth2ClientAuthenticationToken("client-1", ClientAuthenticationMethod.CLIENT_SECRET_BASIC, null, null); + OAuth2DeviceAuthorizationRequestAuthenticationToken authentication = + new OAuth2DeviceAuthorizationRequestAuthenticationToken(clientPrincipal, AUTHORIZATION_URI, null, null); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + // @formatter:on + } + + @Test + public void authenticateWhenInvalidGrantTypeThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication authentication = createAuthentication(registeredClient); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessageContaining(OAuth2ParameterNames.CLIENT_ID) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT); + // @formatter:on + } + + @Test + public void authenticateWhenInvalidScopesThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE).build(); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient, + ClientAuthenticationMethod.CLIENT_SECRET_BASIC, null); + Authentication authentication = new OAuth2DeviceAuthorizationRequestAuthenticationToken(clientPrincipal, + AUTHORIZATION_URI, Collections.singleton("invalid"), null); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessageContaining(OAuth2ParameterNames.SCOPE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_SCOPE); + // @formatter:on + } + + @Test + public void authenticateWhenDeviceCodeIsNullThenThrowOAuth2AuthenticationException() { + @SuppressWarnings("unchecked") + OAuth2TokenGenerator deviceCodeGenerator = mock(OAuth2TokenGenerator.class); + when(deviceCodeGenerator.generate(any(OAuth2TokenContext.class))).thenReturn(null); + this.authenticationProvider.setDeviceCodeGenerator(deviceCodeGenerator); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE).build(); + Authentication authentication = createAuthentication(registeredClient); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessageContaining("The token generator failed to generate the device code.") + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + // @formatter:on + + verify(deviceCodeGenerator).generate(any(OAuth2TokenContext.class)); + verifyNoMoreInteractions(deviceCodeGenerator); + verifyNoInteractions(this.authorizationService); + } + + @Test + public void authenticateWhenUserCodeIsNullThenThrowOAuth2AuthenticationException() { + @SuppressWarnings("unchecked") + OAuth2TokenGenerator userCodeGenerator = mock(OAuth2TokenGenerator.class); + when(userCodeGenerator.generate(any(OAuth2TokenContext.class))).thenReturn(null); + this.authenticationProvider.setUserCodeGenerator(userCodeGenerator); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE).build(); + Authentication authentication = createAuthentication(registeredClient); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessageContaining("The token generator failed to generate the user code.") + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + // @formatter:on + + verify(userCodeGenerator).generate(any(OAuth2TokenContext.class)); + verifyNoMoreInteractions(userCodeGenerator); + verifyNoInteractions(this.authorizationService); + } + + @Test + public void authenticateWhenScopesRequestedThenReturnDeviceCodeAndUserCode() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE).build(); + Authentication authentication = createAuthentication(registeredClient); + OAuth2DeviceAuthorizationRequestAuthenticationToken authenticationResult = + (OAuth2DeviceAuthorizationRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.getPrincipal()).isEqualTo(authentication.getPrincipal()); + assertThat(authenticationResult.getScopes()).hasSameElementsAs(registeredClient.getScopes()); + assertThat(authenticationResult.getDeviceCode().getTokenValue()).hasSize(128); + assertThat(authenticationResult.getUserCode().getTokenValue()).hasSize(9); // 8 chars + 1 dash + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + verifyNoMoreInteractions(this.authorizationService); + + OAuth2Authorization authorization = authorizationCaptor.getValue(); + assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); + assertThat(authorization.getPrincipalName()).isEqualTo(authentication.getName()); + assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.DEVICE_CODE); + assertThat(authorization.getToken(OAuth2DeviceCode.class)).isNotNull(); + assertThat(authorization.getToken(OAuth2UserCode.class)).isNotNull(); + assertThat(authorization.>getAttribute(OAuth2ParameterNames.SCOPE)) + .hasSameElementsAs(registeredClient.getScopes()); + } + + @Test + public void authenticateWhenNoScopesRequestedThenReturnDeviceCodeAndUserCode() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scopes(Set::clear) + .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE).build(); + Authentication authentication = createAuthentication(registeredClient); + OAuth2DeviceAuthorizationRequestAuthenticationToken authenticationResult = + (OAuth2DeviceAuthorizationRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.getPrincipal()).isEqualTo(authentication.getPrincipal()); + assertThat(authenticationResult.getScopes()).hasSameElementsAs(registeredClient.getScopes()); + assertThat(authenticationResult.getDeviceCode().getTokenValue()).hasSize(128); + assertThat(authenticationResult.getUserCode().getTokenValue()).hasSize(9); // 8 chars + 1 dash + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + verifyNoMoreInteractions(this.authorizationService); + + OAuth2Authorization authorization = authorizationCaptor.getValue(); + assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); + assertThat(authorization.getPrincipalName()).isEqualTo(authentication.getName()); + assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.DEVICE_CODE); + assertThat(authorization.getToken(OAuth2DeviceCode.class)).isNotNull(); + assertThat(authorization.getToken(OAuth2UserCode.class)).isNotNull(); + assertThat(authorization.>getAttribute(OAuth2ParameterNames.SCOPE)) + .hasSameElementsAs(registeredClient.getScopes()); + } + + @Test + public void authenticateWhenDeviceCodeGeneratorSetThenUsed() { + @SuppressWarnings("unchecked") + OAuth2TokenGenerator deviceCodeGenerator = mock(OAuth2TokenGenerator.class); + when(deviceCodeGenerator.generate(any(OAuth2TokenContext.class))).thenReturn(createDeviceCode()); + this.authenticationProvider.setDeviceCodeGenerator(deviceCodeGenerator); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE).build(); + Authentication authentication = createAuthentication(registeredClient); + OAuth2DeviceAuthorizationRequestAuthenticationToken authenticationResult = + (OAuth2DeviceAuthorizationRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.getPrincipal()).isEqualTo(authentication.getPrincipal()); + assertThat(authenticationResult.getScopes()).hasSameElementsAs(registeredClient.getScopes()); + assertThat(authenticationResult.getDeviceCode().getTokenValue()).isEqualTo(DEVICE_CODE); + assertThat(authenticationResult.getUserCode().getTokenValue()).hasSize(9); // 8 chars + 1 dash + + ArgumentCaptor tokenContextCaptor = ArgumentCaptor.forClass(OAuth2TokenContext.class); + verify(deviceCodeGenerator).generate(tokenContextCaptor.capture()); + verify(this.authorizationService).save(any(OAuth2Authorization.class)); + verifyNoMoreInteractions(this.authorizationService, deviceCodeGenerator); + + OAuth2TokenContext tokenContext = tokenContextCaptor.getValue(); + assertThat(tokenContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(tokenContext.getPrincipal()).isEqualTo(authentication.getPrincipal()); + assertThat(tokenContext.getAuthorizationServerContext()).isNotNull(); + assertThat(tokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.DEVICE_CODE); + assertThat(tokenContext.getAuthorizationGrant()).isEqualTo(authentication); + assertThat(tokenContext.getTokenType()).isEqualTo(DEVICE_CODE_TOKEN_TYPE); + } + + @Test + public void authenticateWhenUserCodeGeneratorSetThenUsed() { + @SuppressWarnings("unchecked") + OAuth2TokenGenerator userCodeGenerator = mock(OAuth2TokenGenerator.class); + when(userCodeGenerator.generate(any(OAuth2TokenContext.class))).thenReturn(createUserCode()); + this.authenticationProvider.setUserCodeGenerator(userCodeGenerator); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE).build(); + Authentication authentication = createAuthentication(registeredClient); + OAuth2DeviceAuthorizationRequestAuthenticationToken authenticationResult = + (OAuth2DeviceAuthorizationRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.getPrincipal()).isEqualTo(authentication.getPrincipal()); + assertThat(authenticationResult.getScopes()).hasSameElementsAs(registeredClient.getScopes()); + assertThat(authenticationResult.getDeviceCode().getTokenValue()).hasSize(128); + assertThat(authenticationResult.getUserCode().getTokenValue()).isEqualTo(USER_CODE); + + ArgumentCaptor tokenContextCaptor = ArgumentCaptor.forClass(OAuth2TokenContext.class); + verify(userCodeGenerator).generate(tokenContextCaptor.capture()); + verify(this.authorizationService).save(any(OAuth2Authorization.class)); + verifyNoMoreInteractions(this.authorizationService, userCodeGenerator); + + OAuth2TokenContext tokenContext = tokenContextCaptor.getValue(); + assertThat(tokenContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(tokenContext.getPrincipal()).isEqualTo(authentication.getPrincipal()); + assertThat(tokenContext.getAuthorizationServerContext()).isNotNull(); + assertThat(tokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.DEVICE_CODE); + assertThat(tokenContext.getAuthorizationGrant()).isEqualTo(authentication); + assertThat(tokenContext.getTokenType()).isEqualTo(USER_CODE_TOKEN_TYPE); + } + + private static void mockAuthorizationServerContext() { + AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().build(); + TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext( + authorizationServerSettings, () -> "https://provider.com"); + AuthorizationServerContextHolder.setContext(authorizationServerContext); + } + + private static OAuth2DeviceAuthorizationRequestAuthenticationToken createAuthentication(RegisteredClient registeredClient) { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient, + ClientAuthenticationMethod.CLIENT_SECRET_BASIC, null); + Set requestedScopes = registeredClient.getScopes(); + if (requestedScopes.isEmpty()) { + requestedScopes = null; + } + return new OAuth2DeviceAuthorizationRequestAuthenticationToken(clientPrincipal, AUTHORIZATION_URI, requestedScopes, null); + } + + private static OAuth2DeviceCode createDeviceCode() { + Instant issuedAt = Instant.now(); + return new OAuth2DeviceCode(DEVICE_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } + + private static OAuth2UserCode createUserCode() { + Instant issuedAt = Instant.now(); + return new OAuth2UserCode(USER_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceCodeAuthenticationProviderTests.java new file mode 100644 index 00000000..7d2720a8 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceCodeAuthenticationProviderTests.java @@ -0,0 +1,431 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.security.Principal; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2DeviceCode; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.OAuth2UserCode; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; +import org.springframework.security.oauth2.server.authorization.context.TestAuthorizationServerContext; +import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext; +import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceCodeAuthenticationProvider.AUTHORIZATION_PENDING; +import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceCodeAuthenticationProvider.DEVICE_CODE_TOKEN_TYPE; +import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceCodeAuthenticationProvider.EXPIRED_TOKEN; + +/** + * Tests for {@link OAuth2DeviceCodeAuthenticationProvider}. + * + * @author Steve Riesenberg + */ +public class OAuth2DeviceCodeAuthenticationProviderTests { + private static final String DEVICE_CODE = "EfYu_0jEL"; + private static final String USER_CODE = "BCDF-GHJK"; + private static final String ACCESS_TOKEN = "abc123"; + private static final String REFRESH_TOKEN = "xyz456"; + + private OAuth2AuthorizationService authorizationService; + private OAuth2TokenGenerator tokenGenerator; + private OAuth2DeviceCodeAuthenticationProvider authenticationProvider; + + @BeforeEach + @SuppressWarnings("unchecked") + public void setUp() { + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.tokenGenerator = mock(OAuth2TokenGenerator.class); + this.authenticationProvider = new OAuth2DeviceCodeAuthenticationProvider(this.authorizationService, + this.tokenGenerator); + mockAuthorizationServerContext(); + } + + @AfterEach + public void tearDown() { + AuthorizationServerContextHolder.resetContext(); + } + + @Test + public void constructorWhenAuthorizationServiceIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceCodeAuthenticationProvider(null, this.tokenGenerator)) + .withMessage("authorizationService cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenTokenGeneratorIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceCodeAuthenticationProvider(this.authorizationService, null)) + .withMessage("tokenGenerator cannot be null"); + // @formatter:on + } + + @Test + public void supportsWhenTypeOAuth2DeviceCodeAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OAuth2DeviceCodeAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenClientNotAuthenticatedThenThrowOAuth2AuthenticationException() { + OAuth2ClientAuthenticationToken clientPrincipal = + new OAuth2ClientAuthenticationToken("client-1", ClientAuthenticationMethod.CLIENT_SECRET_BASIC, null, null); + Authentication authentication = new OAuth2DeviceCodeAuthenticationToken(DEVICE_CODE, clientPrincipal, null); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + // @formatter:on + } + + @Test + public void authenticateWhenAuthorizationNotFoundThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication authentication = createAuthentication(registeredClient); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(null); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + // @formatter:on + + verify(this.authorizationService).findByToken(DEVICE_CODE, DEVICE_CODE_TOKEN_TYPE); + verifyNoMoreInteractions(this.authorizationService); + verifyNoInteractions(this.tokenGenerator); + } + + @Test + public void authenticateWhenRegisteredClientDoesNotMatchClientIdThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + RegisteredClient registeredClient2 = TestRegisteredClients.registeredClient2().build(); + Authentication authentication = createAuthentication(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient2) + .token(createDeviceCode()).build(); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + // @formatter:on + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).findByToken(DEVICE_CODE, DEVICE_CODE_TOKEN_TYPE); + verify(this.authorizationService).save(authorizationCaptor.capture()); + verifyNoMoreInteractions(this.authorizationService); + verifyNoInteractions(this.tokenGenerator); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + // @formatter:off + assertThat(updatedAuthorization.getToken(OAuth2DeviceCode.class)) + .extracting(isInvalidated()) + .isEqualTo(true); + // @formatter:on + } + + @Test + public void authenticateWhenUserCodeIsNotInvalidatedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication authentication = createAuthentication(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createUserCode()).build(); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(AUTHORIZATION_PENDING); + // @formatter:on + + verify(this.authorizationService).findByToken(DEVICE_CODE, DEVICE_CODE_TOKEN_TYPE); + verifyNoMoreInteractions(this.authorizationService); + verifyNoInteractions(this.tokenGenerator); + } + + @Test + public void authenticateWhenDeviceCodeIsInvalidatedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication authentication = createAuthentication(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createDeviceCode(), withInvalidated()).token(createUserCode(), withInvalidated()).build(); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); + // @formatter:on + + verify(this.authorizationService).findByToken(DEVICE_CODE, DEVICE_CODE_TOKEN_TYPE); + verifyNoMoreInteractions(this.authorizationService); + verifyNoInteractions(this.tokenGenerator); + } + + @Test + public void authenticateWhenDeviceCodeIsExpiredThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication authentication = createAuthentication(registeredClient); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createExpiredDeviceCode()).token(createUserCode(), withInvalidated()).build(); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(EXPIRED_TOKEN); + // @formatter:on + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).findByToken(DEVICE_CODE, DEVICE_CODE_TOKEN_TYPE); + verify(this.authorizationService).save(authorizationCaptor.capture()); + verifyNoMoreInteractions(this.authorizationService); + verifyNoInteractions(this.tokenGenerator); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + // @formatter:off + assertThat(updatedAuthorization.getToken(OAuth2DeviceCode.class)) + .extracting(isInvalidated()) + .isEqualTo(true); + // @formatter:on + } + + @Test + public void authenticateWhenAccessTokenIsNullThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication authentication = createAuthentication(registeredClient); + // @formatter:off + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createDeviceCode()) + .token(createUserCode(), withInvalidated()) + .attribute(Principal.class.getName(), authentication.getPrincipal()) + .build(); + // @formatter:on + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + when(this.tokenGenerator.generate(any(OAuth2TokenContext.class))).thenReturn(null); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessage("The token generator failed to generate the access token.") + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + // @formatter:on + + verify(this.authorizationService).findByToken(DEVICE_CODE, DEVICE_CODE_TOKEN_TYPE); + verify(this.tokenGenerator).generate(any(OAuth2TokenContext.class)); + verifyNoMoreInteractions(this.authorizationService, this.tokenGenerator); + } + + @Test + public void authenticateWhenRefreshTokenIsNullThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication authentication = createAuthentication(registeredClient); + // @formatter:off + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createDeviceCode()) + .token(createUserCode(), withInvalidated()) + .attribute(Principal.class.getName(), authentication.getPrincipal()) + .build(); + // @formatter:on + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + when(this.tokenGenerator.generate(any(OAuth2TokenContext.class))).thenReturn(createAccessToken(), + (OAuth2RefreshToken) null); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessage("The token generator failed to generate the refresh token.") + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + // @formatter:on + + verify(this.authorizationService).findByToken(DEVICE_CODE, DEVICE_CODE_TOKEN_TYPE); + verify(this.tokenGenerator, times(2)).generate(any(OAuth2TokenContext.class)); + verifyNoMoreInteractions(this.authorizationService, this.tokenGenerator); + } + + @Test + public void authenticateWhenTokenGeneratorReturnsWrongTypeThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication authentication = createAuthentication(registeredClient); + // @formatter:off + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createDeviceCode()) + .token(createUserCode(), withInvalidated()) + .attribute(Principal.class.getName(), authentication.getPrincipal()) + .build(); + // @formatter:on + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + OAuth2AccessToken accessToken = createAccessToken(); + when(this.tokenGenerator.generate(any(OAuth2TokenContext.class))).thenReturn(accessToken, accessToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .withMessage("The token generator failed to generate the refresh token.") + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + // @formatter:on + + verify(this.authorizationService).findByToken(DEVICE_CODE, DEVICE_CODE_TOKEN_TYPE); + verify(this.tokenGenerator, times(2)).generate(any(OAuth2TokenContext.class)); + verifyNoMoreInteractions(this.authorizationService, this.tokenGenerator); + } + + @Test + public void authenticateWhenValidDeviceCodeThenReturnAccessTokenAndRefreshToken() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication authentication = createAuthentication(registeredClient); + // @formatter:off + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createDeviceCode()) + .token(createUserCode(), withInvalidated()) + .attribute(Principal.class.getName(), authentication.getPrincipal()) + .build(); + // @formatter:on + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + OAuth2AccessToken accessToken = createAccessToken(); + OAuth2RefreshToken refreshToken = createRefreshToken(); + when(this.tokenGenerator.generate(any(OAuth2TokenContext.class))).thenReturn(accessToken, refreshToken); + OAuth2AccessTokenAuthenticationToken authenticationResult = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(authenticationResult.getPrincipal()).isEqualTo(authentication.getPrincipal()); + assertThat(authenticationResult.getAccessToken()).isEqualTo(accessToken); + assertThat(authenticationResult.getRefreshToken()).isEqualTo(refreshToken); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + ArgumentCaptor tokenContextCaptor = ArgumentCaptor.forClass(OAuth2TokenContext.class); + verify(this.authorizationService).findByToken(DEVICE_CODE, DEVICE_CODE_TOKEN_TYPE); + verify(this.authorizationService).save(authorizationCaptor.capture()); + verify(this.tokenGenerator, times(2)).generate(tokenContextCaptor.capture()); + verifyNoMoreInteractions(this.authorizationService, this.tokenGenerator); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + // @formatter:off + assertThat(updatedAuthorization.getToken(OAuth2DeviceCode.class)) + .extracting(isInvalidated()) + .isEqualTo(true); + // @formatter:on + assertThat(updatedAuthorization.getAccessToken().getToken()).isEqualTo(accessToken); + assertThat(updatedAuthorization.getRefreshToken().getToken()).isEqualTo(refreshToken); + + for (OAuth2TokenContext tokenContext : tokenContextCaptor.getAllValues()) { + assertThat(tokenContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(tokenContext.getPrincipal()).isEqualTo(authentication.getPrincipal()); + assertThat(tokenContext.getAuthorizationServerContext()).isNotNull(); + assertThat(tokenContext.getAuthorization()).isEqualTo(authorization); + assertThat(tokenContext.getAuthorizedScopes()).isEqualTo(authorization.getAuthorizedScopes()); + assertThat(tokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.DEVICE_CODE); + assertThat(tokenContext.getAuthorizationGrant()).isEqualTo(authentication); + } + assertThat(tokenContextCaptor.getAllValues().get(0).getTokenType()).isEqualTo(OAuth2TokenType.ACCESS_TOKEN); + assertThat(tokenContextCaptor.getAllValues().get(1).getTokenType()).isEqualTo(OAuth2TokenType.REFRESH_TOKEN); + } + + private static void mockAuthorizationServerContext() { + AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().build(); + TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext( + authorizationServerSettings, () -> "https://provider.com"); + AuthorizationServerContextHolder.setContext(authorizationServerContext); + } + + private static OAuth2DeviceCodeAuthenticationToken createAuthentication(RegisteredClient registeredClient) { + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient, + ClientAuthenticationMethod.CLIENT_SECRET_BASIC, null); + return new OAuth2DeviceCodeAuthenticationToken(DEVICE_CODE, clientPrincipal, null); + } + + private static OAuth2DeviceCode createDeviceCode() { + Instant issuedAt = Instant.now(); + return new OAuth2DeviceCode(DEVICE_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } + + private static OAuth2DeviceCode createExpiredDeviceCode() { + Instant issuedAt = Instant.now().minus(45, ChronoUnit.MINUTES); + return new OAuth2DeviceCode(DEVICE_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } + + private static OAuth2UserCode createUserCode() { + Instant issuedAt = Instant.now(); + return new OAuth2UserCode(USER_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } + + private static OAuth2AccessToken createAccessToken() { + Instant issuedAt = Instant.now(); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, ACCESS_TOKEN, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } + + private static OAuth2RefreshToken createRefreshToken() { + Instant issuedAt = Instant.now(); + return new OAuth2RefreshToken(REFRESH_TOKEN, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } + + private static Consumer> withInvalidated() { + return (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true); + } + + public static Function, Boolean> isInvalidated() { + return (token) -> token.getMetadata(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java new file mode 100644 index 00000000..97bf5904 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java @@ -0,0 +1,326 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import java.security.Principal; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.Map; +import java.util.function.Function; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2DeviceCode; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.OAuth2UserCode; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; +import org.springframework.security.oauth2.server.authorization.context.TestAuthorizationServerContext; +import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceVerificationAuthenticationProvider.USER_CODE_TOKEN_TYPE; + +/** + * Tests for {@link OAuth2DeviceVerificationAuthenticationProvider}. + * + * @author Steve Riesenberg + */ +public class OAuth2DeviceVerificationAuthenticationProviderTests { + private static final String AUTHORIZATION_URI = "/oauth2/device_verification"; + private static final String DEVICE_CODE = "EfYu_0jEL"; + private static final String USER_CODE = "BCDF-GHJK"; + + private RegisteredClientRepository registeredClientRepository; + private OAuth2AuthorizationService authorizationService; + private OAuth2AuthorizationConsentService authorizationConsentService; + private OAuth2DeviceVerificationAuthenticationProvider authenticationProvider; + + @BeforeEach + public void setUp() { + this.registeredClientRepository = mock(RegisteredClientRepository.class); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.authorizationConsentService = mock(OAuth2AuthorizationConsentService.class); + this.authenticationProvider = new OAuth2DeviceVerificationAuthenticationProvider( + this.registeredClientRepository, this.authorizationService, this.authorizationConsentService); + mockAuthorizationServerContext(); + } + + @Test + public void constructorWhenRegisteredClientRepositoryIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceVerificationAuthenticationProvider( + null, this.authorizationService, this.authorizationConsentService)) + .withMessage("registeredClientRepository cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenAuthorizationServiceIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceVerificationAuthenticationProvider( + this.registeredClientRepository, null, this.authorizationConsentService)) + .withMessage("authorizationService cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenAuthorizationConsentServiceIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceVerificationAuthenticationProvider( + this.registeredClientRepository, this.authorizationService, null)) + .withMessage("authorizationConsentService cannot be null"); + // @formatter:on + } + + @Test + public void supportsWhenTypeOAuth2DeviceAuthorizationRequestAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OAuth2DeviceVerificationAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenAuthorizationNotFoundThenThrowOAuth2AuthenticationException() { + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(null); + Authentication authentication = createAuthentication(); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + // @formatter:on + + verify(this.authorizationService).findByToken(USER_CODE, USER_CODE_TOKEN_TYPE); + verifyNoMoreInteractions(this.authorizationService); + verifyNoInteractions(this.registeredClientRepository, this.authorizationConsentService); + } + + @Test + public void authenticateWhenPrincipalNotAuthenticatedThenReturnUnauthenticated() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", null); + Authentication authentication = new OAuth2DeviceVerificationAuthenticationToken(principal, USER_CODE, Collections.emptyMap()); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + + OAuth2DeviceVerificationAuthenticationToken authenticationResult = + (OAuth2DeviceVerificationAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult).isEqualTo(authentication); + assertThat(authenticationResult.isAuthenticated()).isFalse(); + + verify(this.authorizationService).findByToken(USER_CODE, USER_CODE_TOKEN_TYPE); + verifyNoMoreInteractions(this.authorizationService); + verifyNoInteractions(this.registeredClientRepository, this.authorizationConsentService); + } + + @Test + public void authenticateWhenAuthorizationConsentDoesNotExistThenReturnAuthorizationConsentWithState() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + // @formatter:off + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .token(createDeviceCode()) + .token(createUserCode()) + .attribute(OAuth2ParameterNames.SCOPE, registeredClient.getScopes()) + .build(); + // @formatter:on + Authentication authentication = createAuthentication(); + when(this.registeredClientRepository.findById(anyString())).thenReturn(registeredClient); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + when(this.authorizationConsentService.findById(anyString(), anyString())).thenReturn(null); + + OAuth2DeviceAuthorizationConsentAuthenticationToken authenticationResult = + (OAuth2DeviceAuthorizationConsentAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI); + assertThat(authenticationResult.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getPrincipal()).isEqualTo(authentication.getPrincipal()); + assertThat(authenticationResult.getUserCode()).isEqualTo(USER_CODE); + assertThat(authenticationResult.getState()).hasSize(44); + assertThat(authenticationResult.getRequestedScopes()).hasSameElementsAs(registeredClient.getScopes()); + assertThat(authenticationResult.getScopes()).isEmpty(); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).findByToken(USER_CODE, USER_CODE_TOKEN_TYPE); + verify(this.registeredClientRepository).findById(authorization.getRegisteredClientId()); + verify(this.authorizationService).save(authorizationCaptor.capture()); + verify(this.authorizationConsentService).findById(registeredClient.getId(), authentication.getName()); + verifyNoMoreInteractions(this.registeredClientRepository, this.authorizationService, + this.authorizationConsentService); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + assertThat(updatedAuthorization.getAttribute(OAuth2ParameterNames.STATE)) + .isEqualTo(authenticationResult.getState()); + } + + @Test + public void authenticateWhenAuthorizationConsentExistsAndRequestedScopesMatchThenReturnDeviceVerification() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + // @formatter:off + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE) + .token(createDeviceCode()) + .token(createUserCode()) + .attributes(Map::clear) + .attribute(OAuth2ParameterNames.SCOPE, registeredClient.getScopes()) + .build(); + // @formatter:on + Authentication authentication = createAuthentication(); + // @formatter:off + OAuth2AuthorizationConsent authorizationConsent = + OAuth2AuthorizationConsent.withId(registeredClient.getId(), authentication.getName()) + .scope(registeredClient.getScopes().iterator().next()) + .build(); + // @formatter:on + when(this.registeredClientRepository.findById(anyString())).thenReturn(registeredClient); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + when(this.authorizationConsentService.findById(anyString(), anyString())).thenReturn(authorizationConsent); + + OAuth2DeviceVerificationAuthenticationToken authenticationResult = + (OAuth2DeviceVerificationAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getPrincipal()).isEqualTo(authentication.getPrincipal()); + assertThat(authenticationResult.getUserCode()).isEqualTo(USER_CODE); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).findByToken(USER_CODE, USER_CODE_TOKEN_TYPE); + verify(this.registeredClientRepository).findById(authorization.getRegisteredClientId()); + verify(this.authorizationService).save(authorizationCaptor.capture()); + verify(this.authorizationConsentService).findById(registeredClient.getId(), authentication.getName()); + verifyNoMoreInteractions(this.registeredClientRepository, this.authorizationService, + this.authorizationConsentService); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(authentication.getName()); + assertThat(updatedAuthorization.getAuthorizedScopes()).hasSameElementsAs(registeredClient.getScopes()); + assertThat(updatedAuthorization.getAttribute(Principal.class.getName())) + .isEqualTo(authentication.getPrincipal()); + assertThat(updatedAuthorization.getAttribute(OAuth2ParameterNames.STATE)).isNull(); + // @formatter:off + assertThat(updatedAuthorization.getToken(OAuth2DeviceCode.class)) + .extracting(isInvalidated()) + .isEqualTo(false); + assertThat(updatedAuthorization.getToken(OAuth2UserCode.class)) + .extracting(isInvalidated()) + .isEqualTo(true); + // @formatter:on + } + + @Test + public void authenticateWhenAuthorizationConsentExistsAndRequestedScopesDoNotMatchThenReturnAuthorizationConsentWithState() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + // @formatter:off + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .authorizationGrantType(AuthorizationGrantType.DEVICE_CODE) + .token(createDeviceCode()) + .token(createUserCode()) + .attributes(Map::clear) + .attribute(OAuth2ParameterNames.SCOPE, registeredClient.getScopes()) + .build(); + // @formatter:on + Authentication authentication = createAuthentication(); + // @formatter:off + OAuth2AuthorizationConsent authorizationConsent = + OAuth2AuthorizationConsent.withId(registeredClient.getId(), authentication.getName()) + .scope("previous") + .build(); + // @formatter:on + when(this.registeredClientRepository.findById(anyString())).thenReturn(registeredClient); + when(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).thenReturn(authorization); + when(this.authorizationConsentService.findById(anyString(), anyString())).thenReturn(authorizationConsent); + + OAuth2DeviceAuthorizationConsentAuthenticationToken authenticationResult = + (OAuth2DeviceAuthorizationConsentAuthenticationToken) this.authenticationProvider.authenticate(authentication); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + assertThat(authenticationResult.getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI); + assertThat(authenticationResult.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getPrincipal()).isEqualTo(authentication.getPrincipal()); + assertThat(authenticationResult.getUserCode()).isEqualTo(USER_CODE); + assertThat(authenticationResult.getState()).hasSize(44); + assertThat(authenticationResult.getRequestedScopes()).hasSameElementsAs(registeredClient.getScopes()); + assertThat(authenticationResult.getScopes()).containsExactly("previous"); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).findByToken(USER_CODE, USER_CODE_TOKEN_TYPE); + verify(this.registeredClientRepository).findById(authorization.getRegisteredClientId()); + verify(this.authorizationService).save(authorizationCaptor.capture()); + verify(this.authorizationConsentService).findById(registeredClient.getId(), authentication.getName()); + verifyNoMoreInteractions(this.registeredClientRepository, this.authorizationService, + this.authorizationConsentService); + + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + assertThat(updatedAuthorization.getAttribute(OAuth2ParameterNames.STATE)) + .isEqualTo(authenticationResult.getState()); + } + + private static void mockAuthorizationServerContext() { + AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().build(); + TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext( + authorizationServerSettings, () -> "https://provider.com"); + AuthorizationServerContextHolder.setContext(authorizationServerContext); + } + + private static OAuth2DeviceVerificationAuthenticationToken createAuthentication() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", null, + AuthorityUtils.createAuthorityList("USER")); + return new OAuth2DeviceVerificationAuthenticationToken(principal, USER_CODE, Collections.emptyMap()); + } + + private static OAuth2DeviceCode createDeviceCode() { + Instant issuedAt = Instant.now(); + return new OAuth2DeviceCode(DEVICE_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } + + private static OAuth2UserCode createUserCode() { + Instant issuedAt = Instant.now(); + return new OAuth2UserCode(USER_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } + + private static Function, Boolean> isInvalidated() { + return (token) -> token.getMetadata(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceAuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceAuthorizationEndpointFilterTests.java new file mode 100644 index 00000000..8a3193f6 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceAuthorizationEndpointFilterTests.java @@ -0,0 +1,423 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.http.HttpServletRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationDetailsSource; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2DeviceCode; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2UserCode; +import org.springframework.security.oauth2.core.endpoint.OAuth2DeviceAuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2DeviceAuthorizationResponseHttpMessageConverter; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationRequestAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; +import org.springframework.security.oauth2.server.authorization.context.TestAuthorizationServerContext; +import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.authentication.WebAuthenticationDetails; + +import static java.util.Map.entry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OAuth2DeviceAuthorizationEndpointFilter}. + * + * @author Steve Riesenberg + */ +public class OAuth2DeviceAuthorizationEndpointFilterTests { + private static final String ISSUER_URI = "https://provider.com"; + private static final String REMOTE_ADDRESS = "remote-address"; + private static final String AUTHORIZATION_URI = "/oauth2/device_authorization"; + private static final String VERIFICATION_URI = "/oauth2/device_verification"; + private static final String CLIENT_ID = "client-1"; + private static final String DEVICE_CODE = "EfYu_0jEL"; + private static final String USER_CODE = "BCDF-GHJK"; + + private AuthenticationManager authenticationManager; + private OAuth2DeviceAuthorizationEndpointFilter filter; + + private final HttpMessageConverter deviceAuthorizationHttpResponseConverter = + new OAuth2DeviceAuthorizationResponseHttpMessageConverter(); + private final HttpMessageConverter errorHttpResponseConverter = + new OAuth2ErrorHttpMessageConverter(); + + @BeforeEach + public void setUp() { + this.authenticationManager = mock(AuthenticationManager.class); + this.filter = new OAuth2DeviceAuthorizationEndpointFilter(this.authenticationManager); + mockAuthorizationServerContext(); + } + + @AfterEach + public void tearDown() { + SecurityContextHolder.clearContext(); + AuthorizationServerContextHolder.resetContext(); + } + + @Test + public void constructorWhenAuthenticationMangerIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceAuthorizationEndpointFilter(null)) + .withMessage("authenticationManager cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenDeviceAuthorizationEndpointUriIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceAuthorizationEndpointFilter(this.authenticationManager, null)) + .withMessage("deviceAuthorizationEndpointUri cannot be empty"); + // @formatter:on + } + + @Test + public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationConverter(null)) + .withMessage("authenticationConverter cannot be null"); + // @formatter:on + } + + @Test + public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationDetailsSource(null)) + .withMessage("authenticationDetailsSource cannot be null"); + // @formatter:on + } + + @Test + public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null)) + .withMessage("authenticationSuccessHandler cannot be null"); + // @formatter:on + } + + @Test + public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null)) + .withMessage("authenticationFailureHandler cannot be null"); + // @formatter:on + } + + @Test + public void setVerificationUriWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setVerificationUri(null)) + .withMessage("verificationUri cannot be empty"); + // @formatter:on + } + + @Test + public void doFilterWhenNotDeviceAuthorizationRequestThenNotProcessed() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.GET.name(), "/path"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + verify(filterChain).doFilter(request, response); + verifyNoInteractions(this.authenticationManager); + } + + @Test + public void doFilterWhenDeviceAuthorizationRequestGetThenNotProcessed() throws Exception { + MockHttpServletRequest request = createRequest(); + request.setMethod(HttpMethod.GET.name()); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + verify(filterChain).doFilter(request, response); + verifyNoInteractions(this.authenticationManager); + } + + @Test + public void doFilterWhenDeviceAuthorizationRequestThenDeviceAuthorizationResponse() throws Exception { + Authentication authenticationResult = createAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + Authentication clientPrincipal = (Authentication) authenticationResult.getPrincipal(); + mockSecurityContext(clientPrincipal); + + MockHttpServletRequest request = createRequest(); + request.addParameter("custom-param-1", "custom-value-1"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + + ArgumentCaptor deviceAuthorizationRequestAuthenticationCaptor = + ArgumentCaptor.forClass(OAuth2DeviceAuthorizationRequestAuthenticationToken.class); + verify(this.authenticationManager).authenticate(deviceAuthorizationRequestAuthenticationCaptor.capture()); + verifyNoInteractions(filterChain); + + OAuth2DeviceAuthorizationRequestAuthenticationToken deviceAuthorizationRequestAuthentication = + deviceAuthorizationRequestAuthenticationCaptor.getValue(); + assertThat(deviceAuthorizationRequestAuthentication.getAuthorizationUri()).endsWith(AUTHORIZATION_URI); + assertThat(deviceAuthorizationRequestAuthentication.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(deviceAuthorizationRequestAuthentication.getScopes()).isEmpty(); + assertThat(deviceAuthorizationRequestAuthentication.getAdditionalParameters()) + .containsExactly(entry("custom-param-1", "custom-value-1")); + // @formatter:off + assertThat(deviceAuthorizationRequestAuthentication.getDetails()) + .asInstanceOf(type(WebAuthenticationDetails.class)) + .extracting(WebAuthenticationDetails::getRemoteAddress) + .isEqualTo(REMOTE_ADDRESS); + // @formatter:on + + OAuth2DeviceAuthorizationResponse deviceAuthorizationResponse = readDeviceAuthorizationResponse(response); + String verificationUri = ISSUER_URI + VERIFICATION_URI; + assertThat(deviceAuthorizationResponse.getVerificationUri()).isEqualTo(verificationUri); + assertThat(deviceAuthorizationResponse.getVerificationUriComplete()) + .isEqualTo("%s?%s=%s".formatted(verificationUri, OAuth2ParameterNames.USER_CODE, USER_CODE)); + OAuth2DeviceCode deviceCode = deviceAuthorizationResponse.getDeviceCode(); + assertThat(deviceCode.getTokenValue()).isEqualTo(DEVICE_CODE); + assertThat(deviceCode.getExpiresAt()).isAfter(deviceCode.getIssuedAt()); + OAuth2UserCode userCode = deviceAuthorizationResponse.getUserCode(); + assertThat(userCode.getTokenValue()).isEqualTo(USER_CODE); + assertThat(deviceCode.getExpiresAt()).isAfter(deviceCode.getIssuedAt()); + } + + @Test + public void doFilterWhenInvalidRequestErrorThenBadRequest() throws Exception { + AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); + OAuth2AuthenticationException authenticationException = new OAuth2AuthenticationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, "Invalid request", "error-uri")); + when(authenticationConverter.convert(any(HttpServletRequest.class))).thenThrow(authenticationException); + this.filter.setAuthenticationConverter(authenticationConverter); + + MockHttpServletRequest request = createRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + + verify(authenticationConverter).convert(request); + verifyNoInteractions(filterChain, this.authenticationManager); + + OAuth2Error error = readError(response); + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + assertThat(error.getDescription()).isEqualTo("Invalid request"); + assertThat(error.getUri()).isEqualTo("error-uri"); + } + + @Test + public void doFilterWhenCustomDeviceAuthorizationEndpointUriThenUsed() throws Exception { + Authentication authenticationResult = createAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + Authentication clientPrincipal = (Authentication) authenticationResult.getPrincipal(); + mockSecurityContext(clientPrincipal); + + MockHttpServletRequest request = createRequest(); + request.setRequestURI("/device"); + request.setServletPath("/device"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter = new OAuth2DeviceAuthorizationEndpointFilter(this.authenticationManager, "/device"); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenAuthenticationConverterSetThenUsed() throws Exception { + Authentication authenticationResult = createAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + Authentication clientPrincipal = (Authentication) authenticationResult.getPrincipal(); + mockSecurityContext(clientPrincipal); + + AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); + OAuth2DeviceAuthorizationRequestAuthenticationToken authenticationRequest = + new OAuth2DeviceAuthorizationRequestAuthenticationToken(clientPrincipal, AUTHORIZATION_URI, null, null); + when(authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn(authenticationRequest); + this.filter.setAuthenticationConverter(authenticationConverter); + + MockHttpServletRequest request = createRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + + verify(authenticationConverter).convert(request); + verify(this.authenticationManager).authenticate(authenticationRequest); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenAuthenticationDetailsSourceSetThenUsed() throws Exception { + Authentication authenticationResult = createAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + Authentication clientPrincipal = (Authentication) authenticationResult.getPrincipal(); + mockSecurityContext(clientPrincipal); + + MockHttpServletRequest request = createRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + @SuppressWarnings("unchecked") + AuthenticationDetailsSource authenticationDetailsSource = mock(AuthenticationDetailsSource.class); + when(authenticationDetailsSource.buildDetails(any(HttpServletRequest.class))).thenReturn(new WebAuthenticationDetails(request)); + this.filter.setAuthenticationDetailsSource(authenticationDetailsSource); + + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verify(authenticationDetailsSource).buildDetails(request); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenAuthenticationSuccessHandlerSetThenUsed() throws Exception { + Authentication authenticationResult = createAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + Authentication clientPrincipal = (Authentication) authenticationResult.getPrincipal(); + mockSecurityContext(clientPrincipal); + + AuthenticationSuccessHandler authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class); + this.filter.setAuthenticationSuccessHandler(authenticationSuccessHandler); + + MockHttpServletRequest request = createRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verify(authenticationSuccessHandler).onAuthenticationSuccess(request, response, authenticationResult); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenAuthenticationFailureHandlerSetThenUsed() throws Exception { + OAuth2AuthenticationException authenticationException = + new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenThrow(authenticationException); + + Authentication clientPrincipal = (Authentication) createAuthentication().getPrincipal(); + mockSecurityContext(clientPrincipal); + + AuthenticationFailureHandler authenticationFailureHandler = mock(AuthenticationFailureHandler.class); + this.filter.setAuthenticationFailureHandler(authenticationFailureHandler); + + MockHttpServletRequest request = createRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verify(authenticationFailureHandler).onAuthenticationFailure(request, response, authenticationException); + verifyNoInteractions(filterChain); + } + + private OAuth2DeviceAuthorizationResponse readDeviceAuthorizationResponse(MockHttpServletResponse response) throws IOException { + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); + return this.deviceAuthorizationHttpResponseConverter.read(OAuth2DeviceAuthorizationResponse.class, httpResponse); + } + + private OAuth2Error readError(MockHttpServletResponse response) throws IOException { + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); + return this.errorHttpResponseConverter.read(OAuth2Error.class, httpResponse); + } + + private static void mockAuthorizationServerContext() { + AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().build(); + TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext( + authorizationServerSettings, () -> ISSUER_URI); + AuthorizationServerContextHolder.setContext(authorizationServerContext); + } + + private static void mockSecurityContext(Authentication clientPrincipal) { + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); + } + + private static MockHttpServletRequest createRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod(HttpMethod.POST.name()); + request.setRequestURI(AUTHORIZATION_URI); + request.setServletPath(AUTHORIZATION_URI); + request.setRemoteAddr(REMOTE_ADDRESS); + return request; + } + + private static OAuth2DeviceAuthorizationRequestAuthenticationToken createAuthentication() { + TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken(CLIENT_ID, null); + return new OAuth2DeviceAuthorizationRequestAuthenticationToken(clientPrincipal, null, createDeviceCode(), + createUserCode()); + } + + private static OAuth2DeviceCode createDeviceCode() { + Instant issuedAt = Instant.now(); + return new OAuth2DeviceCode(DEVICE_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } + + private static OAuth2UserCode createUserCode() { + Instant issuedAt = Instant.now(); + return new OAuth2UserCode(USER_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES)); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceVerificationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceVerificationEndpointFilterTests.java new file mode 100644 index 00000000..43ee6f77 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceVerificationEndpointFilterTests.java @@ -0,0 +1,460 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web; + +import java.nio.charset.StandardCharsets; +import java.text.MessageFormat; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.http.HttpServletRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationDetailsSource; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationConsentAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceVerificationAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; +import org.springframework.security.oauth2.server.authorization.context.TestAuthorizationServerContext; +import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.authentication.WebAuthenticationDetails; +import org.springframework.web.util.UriComponentsBuilder; + +import static java.util.Map.entry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OAuth2DeviceVerificationEndpointFilter}. + * + * @author Steve Riesenberg + */ +public class OAuth2DeviceVerificationEndpointFilterTests { + private static final String ISSUER_URI = "https://provider.com"; + private static final String REMOTE_ADDRESS = "remote-address"; + private static final String AUTHORIZATION_URI = "/oauth2/device_authorization"; + private static final String VERIFICATION_URI = "/oauth2/device_verification"; + private static final String CLIENT_ID = "client-1"; + private static final String STATE = "12345"; + private static final String DEVICE_CODE = "EfYu_0jEL"; + private static final String USER_CODE = "BCDF-GHJK"; + + private AuthenticationManager authenticationManager; + private OAuth2DeviceVerificationEndpointFilter filter; + + @BeforeEach + public void setUp() { + this.authenticationManager = mock(AuthenticationManager.class); + this.filter = new OAuth2DeviceVerificationEndpointFilter(this.authenticationManager); + mockAuthorizationServerContext(); + } + + @AfterEach + public void tearDown() { + SecurityContextHolder.clearContext(); + AuthorizationServerContextHolder.resetContext(); + } + + @Test + public void constructorWhenAuthenticationMangerIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceVerificationEndpointFilter(null)) + .withMessage("authenticationManager cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenDeviceVerificationEndpointUriIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2DeviceVerificationEndpointFilter(this.authenticationManager, null)) + .withMessage("deviceVerificationEndpointUri cannot be empty"); + // @formatter:on + } + + @Test + public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationConverter(null)) + .withMessage("authenticationConverter cannot be null"); + // @formatter:on + } + + @Test + public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationDetailsSource(null)) + .withMessage("authenticationDetailsSource cannot be null"); + // @formatter:on + } + + @Test + public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null)) + .withMessage("authenticationSuccessHandler cannot be null"); + // @formatter:on + } + + @Test + public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null)) + .withMessage("authenticationFailureHandler cannot be null"); + // @formatter:on + } + + @Test + public void doFilterWhenNotDeviceVerificationRequestThenNotProcessed() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.GET.name(), "/path"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + verify(filterChain).doFilter(request, response); + verifyNoInteractions(this.authenticationManager); + } + + @Test + public void doFilterWhenUnauthenticatedThenPassThrough() throws Exception { + TestingAuthenticationToken unauthenticatedResult = new TestingAuthenticationToken("user", null); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(unauthenticatedResult); + + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verify(filterChain).doFilter(request, response); + } + + @Test + public void doFilterWhenDeviceAuthorizationConsentRequestThenSuccess() throws Exception { + Authentication authenticationResult = createDeviceVerificationAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + Authentication clientPrincipal = (Authentication) authenticationResult.getPrincipal(); + mockSecurityContext(clientPrincipal); + + MockHttpServletRequest request = createRequest(); + request.setMethod(HttpMethod.POST.name()); + request.addParameter(OAuth2ParameterNames.SCOPE, "scope-1"); + request.addParameter(OAuth2ParameterNames.SCOPE, "scope-2"); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + request.addParameter("custom-param-1", "custom-value-1"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getHeader(HttpHeaders.LOCATION)).isEqualTo("/?success"); + + ArgumentCaptor authenticationCaptor = + ArgumentCaptor.forClass(OAuth2DeviceAuthorizationConsentAuthenticationToken.class); + verify(this.authenticationManager).authenticate(authenticationCaptor.capture()); + verifyNoInteractions(filterChain); + + OAuth2DeviceAuthorizationConsentAuthenticationToken deviceAuthorizationConsentAuthentication = + authenticationCaptor.getValue(); + assertThat(deviceAuthorizationConsentAuthentication.getAuthorizationUri()).endsWith(VERIFICATION_URI); + assertThat(deviceAuthorizationConsentAuthentication.getClientId()).isEqualTo(CLIENT_ID); + assertThat(deviceAuthorizationConsentAuthentication.getPrincipal()) + .isInstanceOf(TestingAuthenticationToken.class); + assertThat(deviceAuthorizationConsentAuthentication.getUserCode()).isEqualTo(USER_CODE); + assertThat(deviceAuthorizationConsentAuthentication.getScopes()).containsExactly("scope-1", "scope-2"); + assertThat(deviceAuthorizationConsentAuthentication.getAdditionalParameters()) + .containsExactly(entry("custom-param-1", "custom-value-1")); + } + + @Test + public void doFilterWhenDeviceVerificationRequestAndConsentNotRequiredThenSuccess() throws Exception { + Authentication authenticationResult = createDeviceVerificationAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + Authentication clientPrincipal = (Authentication) authenticationResult.getPrincipal(); + mockSecurityContext(clientPrincipal); + + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + request.addParameter("custom-param-1", "custom-value-1"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getHeader(HttpHeaders.LOCATION)).isEqualTo("/?success"); + + ArgumentCaptor authenticationCaptor = + ArgumentCaptor.forClass(OAuth2DeviceVerificationAuthenticationToken.class); + verify(this.authenticationManager).authenticate(authenticationCaptor.capture()); + verifyNoInteractions(filterChain); + + OAuth2DeviceVerificationAuthenticationToken deviceVerificationAuthentication = authenticationCaptor.getValue(); + assertThat(deviceVerificationAuthentication.getPrincipal()).isInstanceOf(TestingAuthenticationToken.class); + assertThat(deviceVerificationAuthentication.getUserCode()).isEqualTo(USER_CODE); + assertThat(deviceVerificationAuthentication.getAdditionalParameters()) + .containsExactly(entry("custom-param-1", "custom-value-1")); + } + + @Test + public void doFilterWhenDeviceVerificationRequestAndConsentRequiredThenConsentScreen() throws Exception { + Authentication authenticationResult = createDeviceAuthorizationConsentAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + assertThat(response.getContentType()) + .isEqualTo(new MediaType("text", "html", StandardCharsets.UTF_8).toString()); + assertThat(response.getContentAsString()).contains(scopeCheckbox("scope-1")); + assertThat(response.getContentAsString()).contains(scopeCheckbox("scope-2")); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenDeviceVerificationRequestAndConsentRequiredWithPreviouslyApprovedThenConsentScreen() throws Exception { + Authentication authenticationResult = createDeviceAuthorizationConsentAuthenticationWithAuthorizedScopes(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + assertThat(response.getContentType()) + .isEqualTo(new MediaType("text", "html", StandardCharsets.UTF_8).toString()); + assertThat(response.getContentAsString()).contains(disabledScopeCheckbox("scope-1")); + assertThat(response.getContentAsString()).contains(scopeCheckbox("scope-2")); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenDeviceVerificationRequestAndConsentRequiredAndConsentPageSetThenRedirect() throws Exception { + Authentication authentication = createDeviceAuthorizationConsentAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authentication); + + MockHttpServletRequest request = createRequest(); + request.setScheme("https"); + request.setServerPort(443); + request.setServerName("provider.com"); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.setConsentPage("/consent"); + this.filter.doFilter(request, response, filterChain); + String redirectUri = UriComponentsBuilder.fromUriString("https://provider.com/consent") + .queryParam(OAuth2ParameterNames.SCOPE, "scope-1 scope-2") + .queryParam(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID) + .queryParam(OAuth2ParameterNames.STATE, STATE) + .queryParam(OAuth2ParameterNames.USER_CODE, USER_CODE) + .toUriString(); + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getHeader(HttpHeaders.LOCATION)).isEqualTo(redirectUri); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenAuthenticationConverterSetThenUsed() throws Exception { + Authentication authenticationResult = createDeviceVerificationAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); + OAuth2DeviceVerificationAuthenticationToken deviceVerificationAuthentication = + new OAuth2DeviceVerificationAuthenticationToken((Authentication) authenticationResult.getPrincipal(), + USER_CODE, Collections.emptyMap()); + when(authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn(deviceVerificationAuthentication); + this.filter.setAuthenticationConverter(authenticationConverter); + + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getHeader(HttpHeaders.LOCATION)).isEqualTo("/?success"); + + verify(authenticationConverter).convert(request); + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenAuthenticationDetailsSourceSetThenUsed() throws Exception { + Authentication authenticationResult = createDeviceVerificationAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + @SuppressWarnings("unchecked") + AuthenticationDetailsSource authenticationDetailsSource = mock(AuthenticationDetailsSource.class); + when(authenticationDetailsSource.buildDetails(any(HttpServletRequest.class))).thenReturn(new WebAuthenticationDetails(request)); + this.filter.setAuthenticationDetailsSource(authenticationDetailsSource); + + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getHeader(HttpHeaders.LOCATION)).isEqualTo("/?success"); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verify(authenticationDetailsSource).buildDetails(request); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenAuthenticationSuccessHandlerSetThenUsed() throws Exception { + Authentication authenticationResult = createDeviceVerificationAuthentication(); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authenticationResult); + + AuthenticationSuccessHandler authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class); + this.filter.setAuthenticationSuccessHandler(authenticationSuccessHandler); + + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verify(authenticationSuccessHandler).onAuthenticationSuccess(request, response, authenticationResult); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenAuthenticationFailureHandlerSetThenUsed() throws Exception { + OAuth2AuthenticationException authenticationException = + new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + when(this.authenticationManager.authenticate(any(Authentication.class))).thenThrow(authenticationException); + + AuthenticationFailureHandler authenticationFailureHandler = mock(AuthenticationFailureHandler.class); + this.filter.setAuthenticationFailureHandler(authenticationFailureHandler); + + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + this.filter.doFilter(request, response, filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verify(authenticationFailureHandler).onAuthenticationFailure(request, response, authenticationException); + verifyNoInteractions(filterChain); + } + + private static void mockAuthorizationServerContext() { + AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().build(); + TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext( + authorizationServerSettings, () -> ISSUER_URI); + AuthorizationServerContextHolder.setContext(authorizationServerContext); + } + + private static void mockSecurityContext(Authentication clientPrincipal) { + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); + } + + private static OAuth2DeviceVerificationAuthenticationToken createDeviceVerificationAuthentication() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", null); + return new OAuth2DeviceVerificationAuthenticationToken(principal, CLIENT_ID, USER_CODE); + } + + private static Authentication createDeviceAuthorizationConsentAuthentication() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", null); + Set requestedScopes = new HashSet<>(); + requestedScopes.add("scope-1"); + requestedScopes.add("scope-2"); + return new OAuth2DeviceAuthorizationConsentAuthenticationToken(AUTHORIZATION_URI, CLIENT_ID, principal, + USER_CODE, STATE, requestedScopes, new HashSet<>()); + } + + private static Authentication createDeviceAuthorizationConsentAuthenticationWithAuthorizedScopes() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", null); + Set requestedScopes = new HashSet<>(); + requestedScopes.add("scope-1"); + requestedScopes.add("scope-2"); + Set authorizedScopes = new HashSet<>(); + authorizedScopes.add("scope-1"); + return new OAuth2DeviceAuthorizationConsentAuthenticationToken(AUTHORIZATION_URI, CLIENT_ID, principal, + USER_CODE, STATE, requestedScopes, authorizedScopes); + } + + private static MockHttpServletRequest createRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod(HttpMethod.GET.name()); + request.setRequestURI(VERIFICATION_URI); + request.setServletPath(VERIFICATION_URI); + request.setRemoteAddr(REMOTE_ADDRESS); + return request; + } + + private static String scopeCheckbox(String scope) { + return MessageFormat.format( + "", + scope + ); + } + + private static String disabledScopeCheckbox(String scope) { + return MessageFormat.format( + "", + scope + ); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationConsentAuthenticationConverterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationConsentAuthenticationConverterTests.java new file mode 100644 index 00000000..c73ecd06 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationConsentAuthenticationConverterTests.java @@ -0,0 +1,295 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web.authentication; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationConsentAuthenticationToken; + +import static java.util.Map.entry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/** + * Tests for {@link OAuth2DeviceAuthorizationConsentAuthenticationConverter}. + * + * @author Steve Riesenberg + */ +public class OAuth2DeviceAuthorizationConsentAuthenticationConverterTests { + private static final String VERIFICATION_URI = "/oauth2/device_verification"; + private static final String USER_CODE = "BCDF-GHJK"; + private static final String CLIENT_ID = "client-1"; + private static final String STATE = "abc123"; + + private OAuth2DeviceAuthorizationConsentAuthenticationConverter converter; + + @BeforeEach + public void setUp() { + this.converter = new OAuth2DeviceAuthorizationConsentAuthenticationConverter(); + } + + @AfterEach + public void tearDown() { + SecurityContextHolder.clearContext(); + } + + @Test + public void convertWhenGetThenReturnNull() { + MockHttpServletRequest request = createRequest(); + request.setMethod(HttpMethod.GET.name()); + assertThat(this.converter.convert(request)).isNull(); + } + + @Test + public void convertWhenMissingStateThenReturnNull() { + MockHttpServletRequest request = createRequest(); + assertThat(this.converter.convert(request)).isNull(); + } + + @Test + public void convertWhenMissingClientIdThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.CLIENT_ID) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenBlankClientIdThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, ""); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.CLIENT_ID) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenMultipleClientIdParametersThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, "another"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.CLIENT_ID) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenMissingUserCodeThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.USER_CODE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenBlankUserCodeThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.USER_CODE, ""); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.USER_CODE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenMultipleUserCodeParametersThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + request.addParameter(OAuth2ParameterNames.USER_CODE, "another"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.USER_CODE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenBlankStateParameterThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.STATE, ""); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.STATE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenMultipleStateParametersThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + request.addParameter(OAuth2ParameterNames.STATE, "another"); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.STATE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenMissingPrincipalThenReturnDeviceAuthorizationConsentAuthentication() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + + OAuth2DeviceAuthorizationConsentAuthenticationToken authentication = + (OAuth2DeviceAuthorizationConsentAuthenticationToken) this.converter.convert(request); + assertThat(authentication).isNotNull(); + assertThat(authentication.getAuthorizationUri()).endsWith(VERIFICATION_URI); + assertThat(authentication.getClientId()).isEqualTo(CLIENT_ID); + assertThat(authentication.getPrincipal()).isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(authentication.getUserCode()).isEqualTo(USER_CODE); + assertThat(authentication.getScopes()).isEmpty(); + assertThat(authentication.getAdditionalParameters()).isEmpty(); + } + + @Test + public void convertWhenMissingScopeThenReturnDeviceAuthorizationConsentAuthentication() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + + SecurityContextImpl securityContext = new SecurityContextImpl(); + securityContext.setAuthentication(new TestingAuthenticationToken("user", null)); + SecurityContextHolder.setContext(securityContext); + + OAuth2DeviceAuthorizationConsentAuthenticationToken authentication = + (OAuth2DeviceAuthorizationConsentAuthenticationToken) this.converter.convert(request); + assertThat(authentication).isNotNull(); + assertThat(authentication.getAuthorizationUri()).endsWith(VERIFICATION_URI); + assertThat(authentication.getClientId()).isEqualTo(CLIENT_ID); + assertThat(authentication.getPrincipal()).isInstanceOf(TestingAuthenticationToken.class); + assertThat(authentication.getUserCode()).isEqualTo(USER_CODE); + assertThat(authentication.getScopes()).isEmpty(); + assertThat(authentication.getAdditionalParameters()).isEmpty(); + } + + @Test + public void convertWhenAllParametersThenReturnDeviceAuthorizationConsentAuthentication() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + request.addParameter(OAuth2ParameterNames.SCOPE, "message.read"); + request.addParameter(OAuth2ParameterNames.SCOPE, "message.write"); + request.addParameter("param-1", "value-1"); + request.addParameter("param-2", "value-2"); + + SecurityContextImpl securityContext = new SecurityContextImpl(); + securityContext.setAuthentication(new TestingAuthenticationToken("user", null)); + SecurityContextHolder.setContext(securityContext); + + OAuth2DeviceAuthorizationConsentAuthenticationToken authentication = + (OAuth2DeviceAuthorizationConsentAuthenticationToken) this.converter.convert(request); + assertThat(authentication).isNotNull(); + assertThat(authentication.getAuthorizationUri()).endsWith(VERIFICATION_URI); + assertThat(authentication.getClientId()).isEqualTo(CLIENT_ID); + assertThat(authentication.getPrincipal()).isInstanceOf(TestingAuthenticationToken.class); + assertThat(authentication.getUserCode()).isEqualTo(USER_CODE); + assertThat(authentication.getScopes()).containsExactly("message.read", "message.write"); + assertThat(authentication.getAdditionalParameters()) + .containsExactly(entry("param-1", "value-1"), entry("param-2", "value-2")); + } + + @Test + public void convertWhenNonNormalizedUserCodeThenReturnDeviceAuthorizationConsentAuthentication() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.STATE, STATE); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE.toLowerCase().replace("-", " . ")); + + SecurityContextImpl securityContext = new SecurityContextImpl(); + securityContext.setAuthentication(new TestingAuthenticationToken("user", null)); + SecurityContextHolder.setContext(securityContext); + + OAuth2DeviceAuthorizationConsentAuthenticationToken authentication = + (OAuth2DeviceAuthorizationConsentAuthenticationToken) this.converter.convert(request); + assertThat(authentication).isNotNull(); + assertThat(authentication.getAuthorizationUri()).endsWith(VERIFICATION_URI); + assertThat(authentication.getClientId()).isEqualTo(CLIENT_ID); + assertThat(authentication.getPrincipal()).isInstanceOf(TestingAuthenticationToken.class); + assertThat(authentication.getUserCode()).isEqualTo(USER_CODE); + assertThat(authentication.getScopes()).isEmpty(); + assertThat(authentication.getAdditionalParameters()).isEmpty(); + } + + private static MockHttpServletRequest createRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod(HttpMethod.POST.name()); + request.setRequestURI(VERIFICATION_URI); + return request; + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationRequestAuthenticationConverterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationRequestAuthenticationConverterTests.java new file mode 100644 index 00000000..147f7409 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationRequestAuthenticationConverterTests.java @@ -0,0 +1,120 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web.authentication; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationRequestAuthenticationToken; + +import static java.util.Map.entry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/** + * Tests for {@link OAuth2DeviceAuthorizationRequestAuthenticationConverter}. + * + * @author Steve Riesenberg + */ +public class OAuth2DeviceAuthorizationRequestAuthenticationConverterTests { + private static final String AUTHORIZATION_URI = "/oauth2/device_authorization"; + private static final String CLIENT_ID = "client-1"; + + private OAuth2DeviceAuthorizationRequestAuthenticationConverter converter; + + @BeforeEach + public void setUp() { + this.converter = new OAuth2DeviceAuthorizationRequestAuthenticationConverter(); + } + + @AfterEach + public void tearDown() { + SecurityContextHolder.clearContext(); + } + + @Test + public void convertWhenMultipleScopeParametersThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.SCOPE, "message.read"); + request.addParameter(OAuth2ParameterNames.SCOPE, "message.write"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.SCOPE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenMissingScopeThenReturnDeviceAuthorizationRequestAuthenticationToken() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + + SecurityContextImpl securityContext = new SecurityContextImpl(); + securityContext.setAuthentication(new TestingAuthenticationToken(CLIENT_ID, null)); + SecurityContextHolder.setContext(securityContext); + + OAuth2DeviceAuthorizationRequestAuthenticationToken authentication = + (OAuth2DeviceAuthorizationRequestAuthenticationToken) this.converter.convert(request); + assertThat(authentication).isNotNull(); + assertThat(authentication.getPrincipal()).isInstanceOf(TestingAuthenticationToken.class); + assertThat(authentication.getAuthorizationUri()).endsWith(AUTHORIZATION_URI); + assertThat(authentication.getScopes()).isEmpty(); + assertThat(authentication.getAdditionalParameters()).isEmpty(); + } + + @Test + public void convertWhenAllParametersThenReturnDeviceAuthorizationRequestAuthenticationToken() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.SCOPE, "message.read message.write"); + request.addParameter("param-1", "value-1"); + request.addParameter("param-2", "value-2"); + + SecurityContextImpl securityContext = new SecurityContextImpl(); + securityContext.setAuthentication(new TestingAuthenticationToken(CLIENT_ID, null)); + SecurityContextHolder.setContext(securityContext); + + OAuth2DeviceAuthorizationRequestAuthenticationToken authentication = + (OAuth2DeviceAuthorizationRequestAuthenticationToken) this.converter.convert(request); + assertThat(authentication).isNotNull(); + assertThat(authentication.getPrincipal()).isInstanceOf(TestingAuthenticationToken.class); + assertThat(authentication.getAuthorizationUri()).endsWith(AUTHORIZATION_URI); + assertThat(authentication.getScopes()).containsExactly("message.read", "message.write"); + assertThat(authentication.getAdditionalParameters()) + .containsExactly(entry("param-1", "value-1"), entry("param-2", "value-2")); + } + + private static MockHttpServletRequest createRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod(HttpMethod.POST.name()); + request.setRequestURI(AUTHORIZATION_URI); + return request; + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceCodeAuthenticationConverterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceCodeAuthenticationConverterTests.java new file mode 100644 index 00000000..605dc7b6 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceCodeAuthenticationConverterTests.java @@ -0,0 +1,126 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web.authentication; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceCodeAuthenticationToken; + +import static java.util.Map.entry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/** + * Tests for {@link OAuth2DeviceCodeAuthenticationConverter}. + * + * @author Steve Riesenberg + */ +public class OAuth2DeviceCodeAuthenticationConverterTests { + private static final String CLIENT_ID = "client-1"; + private static final String TOKEN_URI = "/oauth2/token"; + private static final String DEVICE_CODE = "EfYu_0jEL"; + + private OAuth2DeviceCodeAuthenticationConverter converter; + + @BeforeEach + public void setUp() { + this.converter = new OAuth2DeviceCodeAuthenticationConverter(); + } + + @AfterEach + public void tearDown() { + SecurityContextHolder.clearContext(); + } + + @Test + public void convertWhenMissingGrantTypeThenReturnNull() { + MockHttpServletRequest request = createRequest(); + Authentication authentication = this.converter.convert(request); + assertThat(authentication).isNull(); + } + + @Test + public void convertWhenMissingDeviceCodeThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.DEVICE_CODE.getValue()); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.DEVICE_CODE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenMultipleDeviceCodeParametersThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.DEVICE_CODE.getValue()); + request.addParameter(OAuth2ParameterNames.DEVICE_CODE, DEVICE_CODE); + request.addParameter(OAuth2ParameterNames.DEVICE_CODE, "another"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.DEVICE_CODE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenAllParametersThenReturnDeviceCodeAuthenticationToken() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, CLIENT_ID); + request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.DEVICE_CODE.getValue()); + request.addParameter(OAuth2ParameterNames.DEVICE_CODE, DEVICE_CODE); + request.addParameter("param-1", "value-1"); + request.addParameter("param-2", "value-2"); + + SecurityContextImpl securityContext = new SecurityContextImpl(); + securityContext.setAuthentication(new TestingAuthenticationToken(CLIENT_ID, null)); + SecurityContextHolder.setContext(securityContext); + + OAuth2DeviceCodeAuthenticationToken authentication = + (OAuth2DeviceCodeAuthenticationToken) this.converter.convert(request); + assertThat(authentication).isNotNull(); + assertThat(authentication.getDeviceCode()).isEqualTo(DEVICE_CODE); + assertThat(authentication.getPrincipal()).isInstanceOf(TestingAuthenticationToken.class); + assertThat(authentication.getAdditionalParameters()) + .containsExactly(entry("param-1", "value-1"), entry("param-2", "value-2")); + } + + private static MockHttpServletRequest createRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod(HttpMethod.POST.name()); + request.setRequestURI(TOKEN_URI); + return request; + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverterTests.java new file mode 100644 index 00000000..0d19f1a9 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverterTests.java @@ -0,0 +1,168 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web.authentication; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.http.HttpMethod; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceVerificationAuthenticationToken; + +import static java.util.Map.entry; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/** + * Tests for {@link OAuth2DeviceVerificationAuthenticationConverter}. + * + * @author Steve Riesenberg + */ +public class OAuth2DeviceVerificationAuthenticationConverterTests { + private static final String VERIFICATION_URI = "/oauth2/device_verification"; + private static final String USER_CODE = "BCDF-GHJK"; + + private OAuth2DeviceVerificationAuthenticationConverter converter; + + @BeforeEach + public void setUp() { + this.converter = new OAuth2DeviceVerificationAuthenticationConverter(); + } + + @AfterEach + public void tearDown() { + SecurityContextHolder.clearContext(); + } + + @Test + public void convertWhenPutThenReturnNull() { + MockHttpServletRequest request = createRequest(); + request.setMethod(HttpMethod.PUT.name()); + Authentication authentication = this.converter.convert(request); + assertThat(authentication).isNull(); + } + + @Test + public void convertWhenStateThenReturnNull() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.STATE, "abc123"); + Authentication authentication = this.converter.convert(request); + assertThat(authentication).isNull(); + } + + @Test + public void convertWhenMissingUserCodeThenReturnNull() { + MockHttpServletRequest request = createRequest(); + Authentication authentication = this.converter.convert(request); + assertThat(authentication).isNull(); + } + + @Test + public void convertWhenBlankUserCodeParametersThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, ""); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.USER_CODE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenMultipleUserCodeParametersThenInvalidRequestError() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + request.addParameter(OAuth2ParameterNames.USER_CODE, "another"); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.converter.convert(request)) + .withMessageContaining(OAuth2ParameterNames.USER_CODE) + .extracting(OAuth2AuthenticationException::getError) + .extracting(OAuth2Error::getErrorCode) + .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void convertWhenMissingPrincipalThenReturnDeviceVerificationAuthentication() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE.toLowerCase().replace("-", " . ")); + + OAuth2DeviceVerificationAuthenticationToken authentication = + (OAuth2DeviceVerificationAuthenticationToken) this.converter.convert(request); + assertThat(authentication).isNotNull(); + assertThat(authentication.getPrincipal()).isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(authentication.getUserCode()).isEqualTo(USER_CODE); + assertThat(authentication.getAdditionalParameters()).isEmpty(); + } + + @Test + public void convertWhenNonNormalizedUserCodeThenReturnDeviceVerificationAuthentication() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE.toLowerCase().replace("-", " . ")); + + SecurityContextImpl securityContext = new SecurityContextImpl(); + securityContext.setAuthentication(new TestingAuthenticationToken("user", null)); + SecurityContextHolder.setContext(securityContext); + + OAuth2DeviceVerificationAuthenticationToken authentication = + (OAuth2DeviceVerificationAuthenticationToken) this.converter.convert(request); + assertThat(authentication).isNotNull(); + assertThat(authentication.getPrincipal()).isInstanceOf(TestingAuthenticationToken.class); + assertThat(authentication.getUserCode()).isEqualTo(USER_CODE); + assertThat(authentication.getAdditionalParameters()).isEmpty(); + } + + @Test + public void convertWhenAllParametersThenReturnDeviceVerificationAuthentication() { + MockHttpServletRequest request = createRequest(); + request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + request.addParameter("param-1", "value-1"); + request.addParameter("param-2", "value-2"); + + SecurityContextImpl securityContext = new SecurityContextImpl(); + securityContext.setAuthentication(new TestingAuthenticationToken("user", null)); + SecurityContextHolder.setContext(securityContext); + + OAuth2DeviceVerificationAuthenticationToken authentication = + (OAuth2DeviceVerificationAuthenticationToken) this.converter.convert(request); + assertThat(authentication).isNotNull(); + assertThat(authentication.getPrincipal()).isInstanceOf(TestingAuthenticationToken.class); + assertThat(authentication.getUserCode()).isEqualTo(USER_CODE); + assertThat(authentication.getAdditionalParameters()) + .containsExactly(entry("param-1", "value-1"), entry("param-2", "value-2")); + } + + private static MockHttpServletRequest createRequest() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod(HttpMethod.GET.name()); + request.setRequestURI(VERIFICATION_URI); + return request; + } +}