Browse Source

Polish gh-1468

pull/1550/head
Joe Grandja 2 years ago
parent
commit
f0a6a4c0bf
  1. 2
      docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java
  2. 26
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java
  3. 12
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java
  4. 2
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java
  5. 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java
  6. 5
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java
  7. 22
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java
  8. 2
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java
  9. 5
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java
  10. 33
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java
  11. 22
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java
  12. 96
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java
  13. 25
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java

2
docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

26
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,6 +15,8 @@ @@ -15,6 +15,8 @@
*/
package org.springframework.security.oauth2.server.authorization.oidc.web.authentication;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import org.springframework.http.converter.HttpMessageConverter;
@ -30,6 +32,8 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat @@ -30,6 +32,8 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat
import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.oidc.web.OidcClientRegistrationEndpointFilter;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
@ -65,14 +69,30 @@ public final class OidcClientRegistrationAuthenticationConverter implements Auth @@ -65,14 +69,30 @@ public final class OidcClientRegistrationAuthenticationConverter implements Auth
return new OidcClientRegistrationAuthenticationToken(principal, clientRegistration);
}
MultiValueMap<String, String> parameters = getQueryParameters(request);
// client_id (REQUIRED)
String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
if (!StringUtils.hasText(clientId) ||
request.getParameterValues(OAuth2ParameterNames.CLIENT_ID).length != 1) {
parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
}
return new OidcClientRegistrationAuthenticationToken(principal, clientId);
}
private static MultiValueMap<String, String> getQueryParameters(HttpServletRequest request) {
Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameterMap.forEach((key, values) -> {
String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : "";
if (queryString.contains(key) && values.length > 0) {
for (String value : values) {
parameters.add(key, value);
}
}
});
return parameters;
}
}

12
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java

@ -23,7 +23,6 @@ import org.springframework.lang.Nullable; @@ -23,7 +23,6 @@ import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
@ -48,17 +47,6 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica @@ -48,17 +47,6 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica
@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
String queryString = request.getQueryString();
if (StringUtils.hasText(queryString) &&
(queryString.contains(OAuth2ParameterNames.CLIENT_ID) ||
queryString.contains(OAuth2ParameterNames.CLIENT_SECRET))) {
OAuth2Error error = new OAuth2Error(
OAuth2ErrorCodes.INVALID_REQUEST,
"Client credentials MUST NOT be included in the request URI.",
null);
throw new OAuth2AuthenticationException(error);
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// client_id (REQUIRED)

2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2021 the original author or authors.
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java

@ -48,6 +48,7 @@ public final class OAuth2AuthorizationCodeAuthenticationConverter implements Aut @@ -48,6 +48,7 @@ public final class OAuth2AuthorizationCodeAuthenticationConverter implements Aut
@Override
public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// grant_type (REQUIRED)
String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) {

5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java

@ -66,7 +66,10 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme @@ -66,7 +66,10 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme
return null;
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getQueryParameters(request);
MultiValueMap<String, String> parameters =
"GET".equals(request.getMethod()) ?
OAuth2EndpointUtils.getQueryParameters(request) :
OAuth2EndpointUtils.getFormParameters(request);
// response_type (REQUIRED)
String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE);

22
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java

@ -39,6 +39,7 @@ import org.springframework.util.StringUtils; @@ -39,6 +39,7 @@ import org.springframework.util.StringUtils;
*/
final class OAuth2EndpointUtils {
static final String ACCESS_TOKEN_REQUEST_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
private OAuth2EndpointUtils() {
}
@ -46,9 +47,9 @@ final class OAuth2EndpointUtils { @@ -46,9 +47,9 @@ final class OAuth2EndpointUtils {
Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameterMap.forEach((key, values) -> {
String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : "";
// If not query parameter then it's a form parameter
if ((!StringUtils.hasText(request.getQueryString()) && values.length > 0)
|| (!request.getQueryString().contains(key) && values.length > 0)) {
if (!queryString.contains(key) && values.length > 0) {
for (String value : values) {
parameters.add(key, value);
}
@ -61,8 +62,8 @@ final class OAuth2EndpointUtils { @@ -61,8 +62,8 @@ final class OAuth2EndpointUtils {
Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameterMap.forEach((key, values) -> {
if (StringUtils.hasText(request.getQueryString())
&& request.getQueryString().contains(key) && values.length > 0) {
String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : "";
if (queryString.contains(key) && values.length > 0) {
for (String value : values) {
parameters.add(key, value);
}
@ -75,7 +76,10 @@ final class OAuth2EndpointUtils { @@ -75,7 +76,10 @@ final class OAuth2EndpointUtils {
if (!matchesAuthorizationCodeGrantRequest(request)) {
return Collections.emptyMap();
}
MultiValueMap<String, String> multiValueParameters = getFormParameters(request);
MultiValueMap<String, String> multiValueParameters =
"GET".equals(request.getMethod()) ?
getQueryParameters(request) :
getFormParameters(request);
for (String exclusion : exclusions) {
multiValueParameters.remove(exclusion);
}
@ -88,16 +92,14 @@ final class OAuth2EndpointUtils { @@ -88,16 +92,14 @@ final class OAuth2EndpointUtils {
}
static boolean matchesAuthorizationCodeGrantRequest(HttpServletRequest request) {
MultiValueMap<String, String> parameters = getFormParameters(request);
return AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(
parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) &&
parameters.getFirst(OAuth2ParameterNames.CODE) != null;
request.getParameter(OAuth2ParameterNames.GRANT_TYPE)) &&
request.getParameter(OAuth2ParameterNames.CODE) != null;
}
static boolean matchesPkceTokenRequest(HttpServletRequest request) {
MultiValueMap<String, String> parameters = getFormParameters(request);
return matchesAuthorizationCodeGrantRequest(request) &&
parameters.getFirst(PkceParameterNames.CODE_VERIFIER) != null;
request.getParameter(PkceParameterNames.CODE_VERIFIER) != null;
}
static void throwError(String errorCode, String parameterName, String errorUri) {

2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java

@ -53,7 +53,10 @@ public final class PublicClientAuthenticationConverter implements Authentication @@ -53,7 +53,10 @@ public final class PublicClientAuthenticationConverter implements Authentication
return null;
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
MultiValueMap<String, String> parameters =
"GET".equals(request.getMethod()) ?
OAuth2EndpointUtils.getQueryParameters(request) :
OAuth2EndpointUtils.getFormParameters(request);
// client_id (REQUIRED for public clients)
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);

33
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java

@ -59,7 +59,6 @@ import org.springframework.security.crypto.password.PasswordEncoder; @@ -59,7 +59,6 @@ import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService;
@ -98,7 +97,6 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand @@ -98,7 +97,6 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.web.util.UriComponentsBuilder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
@ -232,37 +230,6 @@ public class OAuth2ClientCredentialsGrantTests { @@ -232,37 +230,6 @@ public class OAuth2ClientCredentialsGrantTests {
verify(jwtCustomizer).customize(any());
}
// gh-1378
@Test
public void requestWhenTokenRequestWithClientCredentialsInQueryParamThenInvalidRequest() throws Exception {
this.spring.register(AuthorizationServerConfiguration.class).autowire();
RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
this.registeredClientRepository.save(registeredClient);
String tokenEndpointUri = UriComponentsBuilder.fromUriString(DEFAULT_TOKEN_ENDPOINT_URI)
.queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.toUriString();
this.mvc.perform(post(tokenEndpointUri)
.param(OAuth2ParameterNames.CLIENT_SECRET, registeredClient.getClientSecret())
.param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
.param(OAuth2ParameterNames.SCOPE, "scope1 scope2"))
.andExpect(status().isBadRequest())
.andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_REQUEST));
tokenEndpointUri = UriComponentsBuilder.fromUriString(DEFAULT_TOKEN_ENDPOINT_URI)
.queryParam(OAuth2ParameterNames.CLIENT_SECRET, registeredClient.getClientSecret())
.toUriString();
this.mvc.perform(post(tokenEndpointUri)
.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
.param(OAuth2ParameterNames.SCOPE, "scope1 scope2"))
.andExpect(status().isBadRequest())
.andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_REQUEST));
}
@Test
public void requestWhenTokenEndpointCustomizedThenUsed() throws Exception {
this.spring.register(AuthorizationServerConfigurationCustomTokenEndpoint.class).autowire();

22
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -61,6 +61,7 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut @@ -61,6 +61,7 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.web.util.UriComponentsBuilder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@ -327,6 +328,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -327,6 +328,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -342,6 +344,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -342,6 +344,7 @@ public class OidcClientRegistrationEndpointFilterTests {
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-id");
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-id2");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -388,6 +391,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -388,6 +391,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -421,6 +425,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -421,6 +425,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId());
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -463,6 +468,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -463,6 +468,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client-id");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -492,6 +498,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -492,6 +498,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId());
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -513,6 +520,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -513,6 +520,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -522,6 +530,18 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -522,6 +530,18 @@ public class OidcClientRegistrationEndpointFilterTests {
any(OAuth2AuthenticationException.class));
}
private static void updateQueryString(MockHttpServletRequest request) {
UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI());
request.getParameterMap().forEach((key, values) -> {
if (values.length > 0) {
for (String value : values) {
uriBuilder.queryParam(key, value);
}
}
});
request.setQueryString(uriBuilder.build().getQuery());
}
private OAuth2Error readError(MockHttpServletResponse response) throws Exception {
MockClientHttpResponse httpResponse = new MockClientHttpResponse(
response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));

96
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@ -37,7 +37,6 @@ import org.springframework.http.HttpStatus; @@ -37,7 +37,6 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken;
@ -59,8 +58,8 @@ import org.springframework.security.web.authentication.AuthenticationConverter; @@ -59,8 +58,8 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetails;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -173,7 +172,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -173,7 +172,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.RESPONSE_TYPE,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE));
request -> {
request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE);
updateQueryString(request);
});
}
@Test
@ -182,7 +184,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -182,7 +184,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.RESPONSE_TYPE,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"));
request -> {
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
updateQueryString(request);
});
}
@Test
@ -191,7 +196,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -191,7 +196,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.RESPONSE_TYPE,
OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE,
request -> request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"));
request -> {
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
updateQueryString(request);
});
}
@Test
@ -200,7 +208,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -200,7 +208,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.CLIENT_ID,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID));
request -> {
request.removeParameter(OAuth2ParameterNames.CLIENT_ID);
updateQueryString(request);
});
}
@Test
@ -209,7 +220,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -209,7 +220,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.CLIENT_ID,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"));
request -> {
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2");
updateQueryString(request);
});
}
@Test
@ -218,7 +232,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -218,7 +232,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.REDIRECT_URI,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"));
request -> {
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
updateQueryString(request);
});
}
@Test
@ -227,7 +244,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -227,7 +244,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.SCOPE,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.SCOPE, "scope2"));
request -> {
request.addParameter(OAuth2ParameterNames.SCOPE, "scope2");
updateQueryString(request);
});
}
@Test
@ -236,7 +256,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -236,7 +256,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.STATE,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.STATE, "state2"));
request -> {
request.addParameter(OAuth2ParameterNames.STATE, "state2");
updateQueryString(request);
});
}
@Test
@ -266,13 +289,7 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -266,13 +289,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
request -> {
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge");
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge");
String originalQueryString = request.getQueryString();
if (StringUtils.hasText(originalQueryString)) {
String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE)
.concat("=code-challenge").concat("&")
.concat(PkceParameterNames.CODE_CHALLENGE).concat("=another-code-challenge");
request.setQueryString(newQueryString);
}
updateQueryString(request);
});
}
@ -285,13 +302,7 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -285,13 +302,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
request -> {
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
String originalQueryString = request.getQueryString();
if (StringUtils.hasText(originalQueryString)) {
String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE_METHOD)
.concat("=S256").concat("&")
.concat(PkceParameterNames.CODE_CHALLENGE_METHOD).concat("=S256");
request.setQueryString(newQueryString);
}
updateQueryString(request);
});
}
@ -574,10 +585,7 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -574,10 +585,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.addParameter("custom-param", "custom-value-1", "custom-value-2");
String newQueryString = request.getQueryString().concat("custom-param")
.concat("=custom-value-1").concat("&")
.concat("custom-param").concat("=custom-value-2");
request.setQueryString(newQueryString);
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -623,6 +631,7 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -623,6 +631,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.setMethod("POST"); // OpenID Connect supports POST method
request.setQueryString(null);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -667,15 +676,18 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -667,15 +676,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
MockHttpServletRequest request = MockMvcRequestBuilders.get(requestUri)
.queryParam(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue())
.queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.queryParam(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next())
.queryParam(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "))
.queryParam(OAuth2ParameterNames.STATE, "state")
.buildRequest(new MockServletContext());
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setRemoteAddr(REMOTE_ADDRESS);
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
request.addParameter(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
request.addParameter(OAuth2ParameterNames.STATE, "state");
updateQueryString(request);
return request;
}
@ -692,6 +704,18 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -692,6 +704,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
return request;
}
private static void updateQueryString(MockHttpServletRequest request) {
UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI());
request.getParameterMap().forEach((key, values) -> {
if (values.length > 0) {
for (String value : values) {
uriBuilder.queryParam(key, value);
}
}
});
request.setQueryString(uriBuilder.build().getQuery());
}
private static String scopeCheckbox(String scope) {
return MessageFormat.format(
"<input class=\"form-check-input\" type=\"checkbox\" name=\"scope\" value=\"{0}\" id=\"{0}\">",

25
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java

@ -79,31 +79,6 @@ public class ClientSecretPostAuthenticationConverterTests { @@ -79,31 +79,6 @@ public class ClientSecretPostAuthenticationConverterTests {
.isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
}
// gh-1378
@Test
public void convertWhenClientCredentialsInQueryParamThenInvalidRequestError() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1");
request.addParameter(OAuth2ParameterNames.CLIENT_SECRET, "client-secret");
request.setQueryString("client_id=client-1");
assertThatThrownBy(() -> this.converter.convert(request))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
assertThat(error.getDescription()).isEqualTo("Client credentials MUST NOT be included in the request URI.");
});
request.setQueryString("client_secret=client-secret");
assertThatThrownBy(() -> this.converter.convert(request))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
assertThat(error.getDescription()).isEqualTo("Client credentials MUST NOT be included in the request URI.");
});
}
@Test
public void convertWhenPostWithValidCredentialsThenReturnClientAuthenticationToken() {
MockHttpServletRequest request = new MockHttpServletRequest();

Loading…
Cancel
Save