|
|
|
|
@ -26,6 +26,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
@@ -26,6 +26,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
|
|
|
|
|
import org.springframework.security.core.context.SecurityContext; |
|
|
|
|
import org.springframework.security.core.context.SecurityContextHolder; |
|
|
|
|
import org.springframework.security.oauth2.core.AuthorizationGrantType; |
|
|
|
|
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; |
|
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; |
|
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; |
|
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; |
|
|
|
|
@ -41,6 +42,7 @@ import javax.servlet.FilterChain;
@@ -41,6 +42,7 @@ import javax.servlet.FilterChain;
|
|
|
|
|
import javax.servlet.http.HttpServletRequest; |
|
|
|
|
import javax.servlet.http.HttpServletResponse; |
|
|
|
|
import java.util.Set; |
|
|
|
|
import java.util.function.Consumer; |
|
|
|
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat; |
|
|
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy; |
|
|
|
|
@ -130,53 +132,29 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -130,53 +132,29 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception { |
|
|
|
|
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
|
|
|
|
|
|
|
|
|
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
|
|
|
|
request.removeParameter(OAuth2ParameterNames.CLIENT_ID); |
|
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse(); |
|
|
|
|
FilterChain filterChain = mock(FilterChain.class); |
|
|
|
|
|
|
|
|
|
this.filter.doFilter(request, response, filterChain); |
|
|
|
|
|
|
|
|
|
verifyNoInteractions(filterChain); |
|
|
|
|
|
|
|
|
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
|
|
|
|
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); |
|
|
|
|
doFilterWhenAuthorizationRequestInvalidParameterThenError( |
|
|
|
|
TestRegisteredClients.registeredClient().build(), |
|
|
|
|
OAuth2ParameterNames.CLIENT_ID, |
|
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST, |
|
|
|
|
request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception { |
|
|
|
|
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
|
|
|
|
|
|
|
|
|
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
|
|
|
|
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); |
|
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse(); |
|
|
|
|
FilterChain filterChain = mock(FilterChain.class); |
|
|
|
|
|
|
|
|
|
this.filter.doFilter(request, response, filterChain); |
|
|
|
|
|
|
|
|
|
verifyNoInteractions(filterChain); |
|
|
|
|
|
|
|
|
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
|
|
|
|
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); |
|
|
|
|
doFilterWhenAuthorizationRequestInvalidParameterThenError( |
|
|
|
|
TestRegisteredClients.registeredClient().build(), |
|
|
|
|
OAuth2ParameterNames.CLIENT_ID, |
|
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST, |
|
|
|
|
request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2")); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception { |
|
|
|
|
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); |
|
|
|
|
|
|
|
|
|
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
|
|
|
|
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid"); |
|
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse(); |
|
|
|
|
FilterChain filterChain = mock(FilterChain.class); |
|
|
|
|
|
|
|
|
|
this.filter.doFilter(request, response, filterChain); |
|
|
|
|
|
|
|
|
|
verifyNoInteractions(filterChain); |
|
|
|
|
|
|
|
|
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
|
|
|
|
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id"); |
|
|
|
|
doFilterWhenAuthorizationRequestInvalidParameterThenError( |
|
|
|
|
TestRegisteredClients.registeredClient().build(), |
|
|
|
|
OAuth2ParameterNames.CLIENT_ID, |
|
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST, |
|
|
|
|
request -> request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid")); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
@ -188,16 +166,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -188,16 +166,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
|
|
|
|
.thenReturn(registeredClient); |
|
|
|
|
|
|
|
|
|
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
|
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse(); |
|
|
|
|
FilterChain filterChain = mock(FilterChain.class); |
|
|
|
|
|
|
|
|
|
this.filter.doFilter(request, response, filterChain); |
|
|
|
|
|
|
|
|
|
verifyNoInteractions(filterChain); |
|
|
|
|
|
|
|
|
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
|
|
|
|
assertThat(response.getErrorMessage()).isEqualTo("[unauthorized_client] OAuth 2.0 Parameter: client_id"); |
|
|
|
|
doFilterWhenAuthorizationRequestInvalidParameterThenError( |
|
|
|
|
registeredClient, |
|
|
|
|
OAuth2ParameterNames.CLIENT_ID, |
|
|
|
|
OAuth2ErrorCodes.UNAUTHORIZED_CLIENT); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
@ -206,17 +178,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -206,17 +178,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
|
|
|
|
.thenReturn(registeredClient); |
|
|
|
|
|
|
|
|
|
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
|
|
|
|
request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com"); |
|
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse(); |
|
|
|
|
FilterChain filterChain = mock(FilterChain.class); |
|
|
|
|
|
|
|
|
|
this.filter.doFilter(request, response, filterChain); |
|
|
|
|
|
|
|
|
|
verifyNoInteractions(filterChain); |
|
|
|
|
|
|
|
|
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
|
|
|
|
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); |
|
|
|
|
doFilterWhenAuthorizationRequestInvalidParameterThenError( |
|
|
|
|
registeredClient, |
|
|
|
|
OAuth2ParameterNames.REDIRECT_URI, |
|
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST, |
|
|
|
|
request -> request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com")); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
@ -225,17 +191,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -225,17 +191,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
|
|
|
|
.thenReturn(registeredClient); |
|
|
|
|
|
|
|
|
|
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
|
|
|
|
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"); |
|
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse(); |
|
|
|
|
FilterChain filterChain = mock(FilterChain.class); |
|
|
|
|
|
|
|
|
|
this.filter.doFilter(request, response, filterChain); |
|
|
|
|
|
|
|
|
|
verifyNoInteractions(filterChain); |
|
|
|
|
|
|
|
|
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
|
|
|
|
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); |
|
|
|
|
doFilterWhenAuthorizationRequestInvalidParameterThenError( |
|
|
|
|
registeredClient, |
|
|
|
|
OAuth2ParameterNames.REDIRECT_URI, |
|
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST, |
|
|
|
|
request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com")); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
@ -244,17 +204,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -244,17 +204,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId())))) |
|
|
|
|
.thenReturn(registeredClient); |
|
|
|
|
|
|
|
|
|
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
|
|
|
|
request.removeParameter(OAuth2ParameterNames.REDIRECT_URI); |
|
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse(); |
|
|
|
|
FilterChain filterChain = mock(FilterChain.class); |
|
|
|
|
|
|
|
|
|
this.filter.doFilter(request, response, filterChain); |
|
|
|
|
|
|
|
|
|
verifyNoInteractions(filterChain); |
|
|
|
|
|
|
|
|
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
|
|
|
|
assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri"); |
|
|
|
|
doFilterWhenAuthorizationRequestInvalidParameterThenError( |
|
|
|
|
registeredClient, |
|
|
|
|
OAuth2ParameterNames.REDIRECT_URI, |
|
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST, |
|
|
|
|
request -> request.removeParameter(OAuth2ParameterNames.REDIRECT_URI)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
@ -383,6 +337,27 @@ public class OAuth2AuthorizationEndpointFilterTests {
@@ -383,6 +337,27 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
assertThat(authorizationRequest.getAdditionalParameters()).isEmpty(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient, |
|
|
|
|
String parameterName, String errorCode) throws Exception { |
|
|
|
|
doFilterWhenAuthorizationRequestInvalidParameterThenError(registeredClient, parameterName, errorCode, request -> {}); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient, |
|
|
|
|
String parameterName, String errorCode, Consumer<MockHttpServletRequest> requestConsumer) throws Exception { |
|
|
|
|
|
|
|
|
|
MockHttpServletRequest request = createAuthorizationRequest(registeredClient); |
|
|
|
|
requestConsumer.accept(request); |
|
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse(); |
|
|
|
|
FilterChain filterChain = mock(FilterChain.class); |
|
|
|
|
|
|
|
|
|
this.filter.doFilter(request, response, filterChain); |
|
|
|
|
|
|
|
|
|
verifyNoInteractions(filterChain); |
|
|
|
|
|
|
|
|
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); |
|
|
|
|
assertThat(response.getErrorMessage()).isEqualTo("[" + errorCode + "] OAuth 2.0 Parameter: " + parameterName); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) { |
|
|
|
|
String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]); |
|
|
|
|
|
|
|
|
|
|