9 changed files with 615 additions and 626 deletions
@ -1,55 +0,0 @@
@@ -1,55 +0,0 @@
|
||||
/* |
||||
* Copyright 2020 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.springframework.security.oauth2.server.authorization.web; |
||||
|
||||
import java.util.Arrays; |
||||
import java.util.Collections; |
||||
import java.util.LinkedHashSet; |
||||
import java.util.Set; |
||||
|
||||
import javax.servlet.http.HttpServletRequest; |
||||
|
||||
import org.springframework.core.convert.converter.Converter; |
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; |
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; |
||||
import org.springframework.util.StringUtils; |
||||
|
||||
/** |
||||
* @author Paurav Munshi |
||||
* @since 0.0.1 |
||||
* @see Converter |
||||
*/ |
||||
public class OAuth2AuthorizationRequestConverter implements Converter<HttpServletRequest, OAuth2AuthorizationRequest> { |
||||
|
||||
@Override |
||||
public OAuth2AuthorizationRequest convert(HttpServletRequest request) { |
||||
String scope = request.getParameter(OAuth2ParameterNames.SCOPE); |
||||
Set<String> scopes = !StringUtils.isEmpty(scope) |
||||
? new LinkedHashSet<String>(Arrays.asList(scope.split(" "))) |
||||
: Collections.emptySet(); |
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() |
||||
.clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID)) |
||||
.redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI)) |
||||
.scopes(scopes) |
||||
.state(request.getParameter(OAuth2ParameterNames.STATE)) |
||||
.authorizationUri(request.getServletPath()) |
||||
.build(); |
||||
|
||||
return authorizationRequest; |
||||
} |
||||
|
||||
} |
||||
@ -1,371 +0,0 @@
@@ -1,371 +0,0 @@
|
||||
/* |
||||
* Copyright 2020 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.springframework.security.oauth2.server.authorization.web; |
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat; |
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy; |
||||
import static org.mockito.ArgumentMatchers.any; |
||||
import static org.mockito.ArgumentMatchers.anyString; |
||||
import static org.mockito.Mockito.mock; |
||||
import static org.mockito.Mockito.spy; |
||||
import static org.mockito.Mockito.times; |
||||
import static org.mockito.Mockito.verify; |
||||
import static org.mockito.Mockito.when; |
||||
|
||||
import javax.servlet.FilterChain; |
||||
import javax.servlet.http.HttpServletRequest; |
||||
import javax.servlet.http.HttpServletResponse; |
||||
|
||||
import org.junit.Before; |
||||
import org.junit.Test; |
||||
import org.springframework.http.HttpStatus; |
||||
import org.springframework.mock.web.MockHttpServletRequest; |
||||
import org.springframework.mock.web.MockHttpServletResponse; |
||||
import org.springframework.security.core.Authentication; |
||||
import org.springframework.security.core.context.SecurityContextHolder; |
||||
import org.springframework.security.crypto.keygen.StringKeyGenerator; |
||||
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; |
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; |
||||
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; |
||||
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; |
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; |
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; |
||||
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; |
||||
|
||||
|
||||
/** |
||||
* Tests for {@link OAuth2AuthorizationEndpointFilter}. |
||||
* |
||||
* @author Paurav Munshi |
||||
* @since 0.0.1 |
||||
*/ |
||||
|
||||
public class OAuth2AuthorizationEndpointFilterTest { |
||||
|
||||
private static final String VALID_CLIENT = "valid_client"; |
||||
private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri"; |
||||
private static final String VALID_CC_CLIENT = "valid_cc_client"; |
||||
|
||||
private OAuth2AuthorizationEndpointFilter filter; |
||||
|
||||
private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class); |
||||
private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class); |
||||
private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class); |
||||
private Authentication authentication = mock(Authentication.class); |
||||
|
||||
@Before |
||||
public void setUp() { |
||||
this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService); |
||||
this.filter.setCodeGenerator(this.codeGenerator); |
||||
|
||||
SecurityContextHolder.getContext().setAuthentication(this.authentication); |
||||
} |
||||
|
||||
@Test |
||||
public void constructorWhenRegisteredClientRepositoryIsNullThenIllegalArgumentExceptionIsThrows() throws Exception { |
||||
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService)) |
||||
.isInstanceOf(IllegalArgumentException.class); |
||||
} |
||||
|
||||
@Test |
||||
public void constructorWhenAuthorizationServiceIsNullThenIllegalArgumentExceptionIsThrows() throws Exception { |
||||
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null)) |
||||
.isInstanceOf(IllegalArgumentException.class); |
||||
} |
||||
|
||||
@Test |
||||
public void setAuthorizationEndpointMatcherWhenAuthorizationEndpointMatcherIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { |
||||
assertThatThrownBy(() ->this.filter.setAuthorizationEndpointMatcher(null)) |
||||
.isInstanceOf(IllegalArgumentException.class); |
||||
} |
||||
|
||||
@Test |
||||
public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { |
||||
assertThatThrownBy(() ->this.filter.setAuthorizationRedirectStrategy(null)) |
||||
.isInstanceOf(IllegalArgumentException.class); |
||||
} |
||||
|
||||
@Test |
||||
public void setAuthorizationRequestConverterWhenAuthorizationRequestConverterIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { |
||||
assertThatThrownBy(() ->this.filter.setAuthorizationRequestConverter(null)) |
||||
.isInstanceOf(IllegalArgumentException.class); |
||||
} |
||||
|
||||
@Test |
||||
public void setCodeGeneratorWhenCodeGeneratorIsNullThenIllegalArgumentExceptionIsThrown() throws Exception { |
||||
assertThatThrownBy(() ->this.filter.setCodeGenerator(null)) |
||||
.isInstanceOf(IllegalArgumentException.class); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectURIWithCode() throws Exception { |
||||
MockHttpServletRequest request = getValidMockHttpServletRequest(); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); |
||||
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); |
||||
when(this.codeGenerator.generateKey()).thenReturn("sample_code"); |
||||
when(this.authentication.getPrincipal()).thenReturn("test-user"); |
||||
when(this.authentication.isAuthenticated()).thenReturn(true); |
||||
|
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(this.authentication).isAuthenticated(); |
||||
verify(this.registeredClientRepository).findByClientId(VALID_CLIENT); |
||||
verify(this.authorizationService).save(any(OAuth2Authorization.class)); |
||||
verify(this.codeGenerator).generateKey(); |
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); |
||||
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); |
||||
|
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRedirectedToConfiguredRedirectURI() throws Exception { |
||||
MockHttpServletRequest request = getValidMockHttpServletRequest(); |
||||
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); |
||||
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); |
||||
when(this.codeGenerator.generateKey()).thenReturn("sample_code"); |
||||
when(this.authentication.getPrincipal()).thenReturn("test-user"); |
||||
when(this.authentication.isAuthenticated()).thenReturn(true); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(this.authentication).isAuthenticated(); |
||||
verify(this.registeredClientRepository).findByClientId(VALID_CLIENT); |
||||
verify(this.authorizationService).save(any(OAuth2Authorization.class)); |
||||
verify(this.codeGenerator).generateKey(); |
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); |
||||
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate"); |
||||
|
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenRedirectURINotPresentAndClientHasMulitipleUrisThenErrorIsSentInResponse() throws Exception { |
||||
MockHttpServletRequest request = getValidMockHttpServletRequest(); |
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI); |
||||
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, ""); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build(); |
||||
when(this.registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient); |
||||
when(this.authentication.isAuthenticated()).thenReturn(true); |
||||
|
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(this.authentication, times(1)).isAuthenticated(); |
||||
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI); |
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); |
||||
verify(this.codeGenerator, times(0)).generateKey(); |
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); |
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); |
||||
|
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenRequestedRedirectUriNotConfiguredInClientThenErrorSentInResponse() throws Exception { |
||||
MockHttpServletRequest request = getValidMockHttpServletRequest(); |
||||
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback"); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build(); |
||||
when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient); |
||||
when(this.authentication.isAuthenticated()).thenReturn(true); |
||||
|
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(this.authentication, times(1)).isAuthenticated(); |
||||
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT); |
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); |
||||
verify(this.codeGenerator, times(0)).generateKey(); |
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); |
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); |
||||
|
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenClientIdDoesNotSupportAuthorizationGrantFlowThenErrorSentInResponse() throws Exception { |
||||
MockHttpServletRequest request = getValidMockHttpServletRequest(); |
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build(); |
||||
when(this.registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient); |
||||
when(this.authentication.isAuthenticated()).thenReturn(true); |
||||
|
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(this.authentication, times(1)).isAuthenticated(); |
||||
verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT); |
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); |
||||
verify(this.codeGenerator, times(0)).generateKey(); |
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); |
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); |
||||
|
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenClientIdIsMissinInRequestThenErrorSentInResponse() throws Exception { |
||||
MockHttpServletRequest request = getValidMockHttpServletRequest(); |
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, ""); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
when(this.authentication.isAuthenticated()).thenReturn(true); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(this.authentication).isAuthenticated(); |
||||
verify(this.registeredClientRepository, times(0)).findByClientId(anyString()); |
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); |
||||
verify(this.codeGenerator, times(0)).generateKey(); |
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value()); |
||||
assertThat(response.getContentAsString()).isEmpty(); |
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); |
||||
|
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenUnregisteredClientInRequestThenErrorIsSentInResponse() throws Exception { |
||||
MockHttpServletRequest request = getValidMockHttpServletRequest(); |
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client"); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
when(this.registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null); |
||||
when(this.codeGenerator.generateKey()).thenReturn("sample_code"); |
||||
when(this.authentication.isAuthenticated()).thenReturn(true); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(this.authentication).isAuthenticated(); |
||||
verify(this.registeredClientRepository, times(1)).findByClientId("unregistered_client"); |
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); |
||||
verify(this.codeGenerator, times(0)).generateKey(); |
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); |
||||
assertThat(response.getContentAsString()).isEmpty(); |
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); |
||||
|
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenUnauthenticatedUserInRequestThenErrorIsSentInResponse() throws Exception { |
||||
MockHttpServletRequest request = getValidMockHttpServletRequest(); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
when(authentication.isAuthenticated()).thenReturn(false); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(this.authentication).isAuthenticated(); |
||||
verify(this.registeredClientRepository, times(0)).findByClientId(anyString()); |
||||
verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class)); |
||||
verify(this.codeGenerator, times(0)).generateKey(); |
||||
verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value()); |
||||
assertThat(response.getContentAsString()).isEmpty(); |
||||
assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED); |
||||
|
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenRequestEndPointIsNotAuthorizationEndpointThenFilterShouldProceedWithFilterChain() throws Exception { |
||||
MockHttpServletRequest request = getValidMockHttpServletRequest(); |
||||
request.setServletPath("/custom/authorize"); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); |
||||
spyFilter.doFilter(request, response, filterChain); |
||||
|
||||
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); |
||||
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedirectURIQueryParameter() throws Exception { |
||||
MockHttpServletRequest request = getValidMockHttpServletRequest(); |
||||
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, ""); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); |
||||
spyFilter.doFilter(request, response, filterChain); |
||||
|
||||
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); |
||||
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); |
||||
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedirectURIQueryParameter() throws Exception { |
||||
MockHttpServletRequest request = getValidMockHttpServletRequest(); |
||||
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token"); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter); |
||||
spyFilter.doFilter(request, response, filterChain); |
||||
|
||||
verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class)); |
||||
verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); |
||||
verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
} |
||||
|
||||
private MockHttpServletRequest getValidMockHttpServletRequest() { |
||||
|
||||
MockHttpServletRequest request = new MockHttpServletRequest(); |
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT); |
||||
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code"); |
||||
request.setParameter(OAuth2ParameterNames.SCOPE, "openid profile email"); |
||||
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback"); |
||||
request.setParameter(OAuth2ParameterNames.STATE, "teststate"); |
||||
request.setServletPath("/oauth2/authorize"); |
||||
|
||||
return request; |
||||
|
||||
|
||||
} |
||||
|
||||
} |
||||
@ -0,0 +1,399 @@
@@ -0,0 +1,399 @@
|
||||
/* |
||||
* Copyright 2020 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.springframework.security.oauth2.server.authorization.web; |
||||
|
||||
import org.junit.After; |
||||
import org.junit.Before; |
||||
import org.junit.Test; |
||||
import org.mockito.ArgumentCaptor; |
||||
import org.springframework.http.HttpStatus; |
||||
import org.springframework.mock.web.MockHttpServletRequest; |
||||
import org.springframework.mock.web.MockHttpServletResponse; |
||||
import org.springframework.security.authentication.TestingAuthenticationToken; |
||||
import org.springframework.security.core.context.SecurityContext; |
||||
import org.springframework.security.core.context.SecurityContextHolder; |
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType; |
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; |
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; |
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; |
||||
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; |
||||
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; |
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; |
||||
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; |
||||
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; |
||||
import org.springframework.util.StringUtils; |
||||
|
||||
import javax.servlet.FilterChain; |
||||
import javax.servlet.http.HttpServletRequest; |
||||
import javax.servlet.http.HttpServletResponse; |
||||
import java.util.Set; |
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat; |
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy; |
||||
import static org.mockito.ArgumentMatchers.any; |
||||
import static org.mockito.ArgumentMatchers.eq; |
||||
import static org.mockito.Mockito.mock; |
||||
import static org.mockito.Mockito.verify; |
||||
import static org.mockito.Mockito.verifyNoInteractions; |
||||
import static org.mockito.Mockito.when; |
||||
|
||||
/** |
||||
* Tests for {@link OAuth2AuthorizationEndpointFilter}. |
||||
* |
||||
* @author Paurav Munshi |
||||
* @author Joe Grandja |
||||
* @since 0.0.1 |
||||
*/ |
||||
public class OAuth2AuthorizationEndpointFilterTests { |
||||
private RegisteredClientRepository registeredClientRepository; |
||||
private OAuth2AuthorizationService authorizationService; |
||||
private OAuth2AuthorizationEndpointFilter filter; |
||||
private TestingAuthenticationToken authentication; |
||||
|
||||
@Before |
||||
public void setUp() { |
||||
this.registeredClientRepository = mock(RegisteredClientRepository.class); |
||||
this.authorizationService = mock(OAuth2AuthorizationService.class); |
||||
this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService); |
||||
this.authentication = new TestingAuthenticationToken("principalName", "password"); |
||||
this.authentication.setAuthenticated(true); |
||||
SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); |
||||
securityContext.setAuthentication(this.authentication); |
||||
SecurityContextHolder.setContext(securityContext); |
||||
} |
||||
|
||||
@After |
||||
public void cleanup() { |
||||
SecurityContextHolder.clearContext(); |
||||
} |
||||
|
||||
@Test |
||||
public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("registeredClientRepository cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("authorizationService cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void constructorWhenAuthorizationEndpointUriNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService, null)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("authorizationEndpointUri cannot be empty"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenNotAuthorizationRequestThenNotProcessed() throws Exception { |
||||
String requestUri = "/path"; |
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); |
||||
request.setServletPath(requestUri); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestPostThenNotProcessed() throws Exception { |
||||
String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI; |
||||
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); |
||||
request.setServletPath(requestUri); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestNotAuthenticatedThenNotProcessed() throws Exception { |
||||
String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI; |
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); |
||||
request.setServletPath(requestUri); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.authentication.setAuthenticated(false); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception { |
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
||||
request.removeParameter(OAuth2ParameterNames.CLIENT_ID); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception { |
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
||||
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception { |
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
||||
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid"); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestAndClientNotAuthorizedToRequestCodeThenUnauthorizedClientError() throws Exception { |
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient() |
||||
.authorizationGrantTypes(Set::clear) |
||||
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) |
||||
.build(); |
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
||||
.thenReturn(registeredClient); |
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
||||
assertThat(response.getErrorMessage()).isEqualTo("[unauthorized_client] OAuth 2.0 Parameter: client_id"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestInvalidRedirectUriThenInvalidRequestError() throws Exception { |
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
||||
.thenReturn(registeredClient); |
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
||||
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com"); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestMultipleRedirectUriThenInvalidRequestError() throws Exception { |
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
||||
.thenReturn(registeredClient); |
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
||||
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestExcludesRedirectUriAndMultipleRegisteredThenInvalidRequestError() throws Exception { |
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().redirectUri("https://example2.com").build(); |
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
||||
.thenReturn(registeredClient); |
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
||||
request.removeParameter(OAuth2ParameterNames.REDIRECT_URI); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
||||
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestMissingResponseTypeThenInvalidRequestError() throws Exception { |
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
||||
.thenReturn(registeredClient); |
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
||||
request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); |
||||
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + |
||||
"error=invalid_request&" + |
||||
"error_description=OAuth%202.0%20Parameter:%20response_type&" + |
||||
"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" + |
||||
"state=state"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestMultipleResponseTypeThenInvalidRequestError() throws Exception { |
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
||||
.thenReturn(registeredClient); |
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
||||
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); |
||||
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + |
||||
"error=invalid_request&" + |
||||
"error_description=OAuth%202.0%20Parameter:%20response_type&" + |
||||
"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" + |
||||
"state=state"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestInvalidResponseTypeThenUnsupportedResponseTypeError() throws Exception { |
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
||||
.thenReturn(registeredClient); |
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
||||
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); |
||||
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" + |
||||
"error=unsupported_response_type&" + |
||||
"error_description=OAuth%202.0%20Parameter:%20response_type&" + |
||||
"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" + |
||||
"state=state"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAuthorizationRequestValidThenAuthorizationResponse() throws Exception { |
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
||||
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
||||
.thenReturn(registeredClient); |
||||
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); |
||||
assertThat(response.getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state"); |
||||
|
||||
ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); |
||||
|
||||
verify(this.authorizationService).save(authorizationCaptor.capture()); |
||||
|
||||
OAuth2Authorization authorization = authorizationCaptor.getValue(); |
||||
assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); |
||||
assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString()); |
||||
|
||||
String code = authorization.getAttribute(OAuth2ParameterNames.class.getName().concat(".CODE")); |
||||
assertThat(code).isNotNull(); |
||||
|
||||
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName()); |
||||
assertThat(authorizationRequest).isNotNull(); |
||||
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo("http://localhost/oauth2/authorize"); |
||||
assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); |
||||
assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE); |
||||
assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId()); |
||||
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(registeredClient.getRedirectUris().iterator().next()); |
||||
assertThat(authorizationRequest.getScopes()).containsExactlyInAnyOrderElementsOf(registeredClient.getScopes()); |
||||
assertThat(authorizationRequest.getState()).isEqualTo("state"); |
||||
assertThat(authorizationRequest.getAdditionalParameters()).isEmpty(); |
||||
} |
||||
|
||||
private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) { |
||||
String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]); |
||||
|
||||
String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI; |
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); |
||||
request.setServletPath(requestUri); |
||||
|
||||
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); |
||||
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); |
||||
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]); |
||||
request.addParameter(OAuth2ParameterNames.SCOPE, |
||||
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); |
||||
request.addParameter(OAuth2ParameterNames.STATE, "state"); |
||||
|
||||
return request; |
||||
} |
||||
} |
||||
Loading…
Reference in new issue