From 4bc0df5ef81a690abb0b7ea3fcbdf3c4b4b6ce23 Mon Sep 17 00:00:00 2001 From: Greg Li Date: Tue, 5 Dec 2023 16:40:25 +0800 Subject: [PATCH] Fix to ensure endpoints distinguish between form and query parameters Closes gh-1451 --- .../sample/AuthorizationCodeGrantFlow.java | 2 +- ...ientSecretPostAuthenticationConverter.java | 24 ++++++------ ...lientAssertionAuthenticationConverter.java | 8 ++-- ...horizationCodeAuthenticationConverter.java | 5 +-- ...ionCodeRequestAuthenticationConverter.java | 4 +- ...izationConsentAuthenticationConverter.java | 6 +-- ...entCredentialsAuthenticationConverter.java | 6 +-- .../authentication/OAuth2EndpointUtils.java | 35 +++++++++++++---- ...h2RefreshTokenAuthenticationConverter.java | 6 +-- ...nIntrospectionAuthenticationConverter.java | 2 +- ...okenRevocationAuthenticationConverter.java | 2 +- .../PublicClientAuthenticationConverter.java | 2 +- .../OAuth2AuthorizationCodeGrantTests.java | 29 +++++++------- .../annotation/web/configurers/OidcTests.java | 2 +- ...Auth2AuthorizationEndpointFilterTests.java | 39 ++++++++++++++----- 15 files changed, 105 insertions(+), 67 deletions(-) diff --git a/docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java b/docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java index 7339946e..d4935c75 100644 --- a/docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java +++ b/docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java @@ -94,7 +94,7 @@ public class AuthorizationCodeGrantFlow { parameters.set(OAuth2ParameterNames.STATE, "state"); MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorize") - .params(parameters) + .queryParams(parameters) .with(user(this.username).roles("USER"))) .andExpect(status().isOk()) .andExpect(header().string("content-type", containsString(MediaType.TEXT_HTML_VALUE))) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java index b5ef6691..a3153a83 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java @@ -48,7 +48,18 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica @Nullable @Override public Authentication convert(HttpServletRequest request) { - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + String queryString = request.getQueryString(); + if (StringUtils.hasText(queryString) && + (queryString.contains(OAuth2ParameterNames.CLIENT_ID) || + queryString.contains(OAuth2ParameterNames.CLIENT_SECRET))) { + OAuth2Error error = new OAuth2Error( + OAuth2ErrorCodes.INVALID_REQUEST, + "Client credentials MUST NOT be included in the request URI.", + null); + throw new OAuth2AuthenticationException(error); + } + + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); // client_id (REQUIRED) String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); @@ -70,17 +81,6 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); } - String queryString = request.getQueryString(); - if (StringUtils.hasText(queryString) && - (queryString.contains(OAuth2ParameterNames.CLIENT_ID) || - queryString.contains(OAuth2ParameterNames.CLIENT_SECRET))) { - OAuth2Error error = new OAuth2Error( - OAuth2ErrorCodes.INVALID_REQUEST, - "Client credentials MUST NOT be included in the request URI.", - null); - throw new OAuth2AuthenticationException(error); - } - Map additionalParameters = OAuth2EndpointUtils.getParametersIfMatchesAuthorizationCodeGrantRequest(request, OAuth2ParameterNames.CLIENT_ID, OAuth2ParameterNames.CLIENT_SECRET); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java index 59165ff2..ed8fb779 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java @@ -48,13 +48,13 @@ public final class JwtClientAssertionAuthenticationConverter implements Authenti @Nullable @Override public Authentication convert(HttpServletRequest request) { - if (request.getParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE) == null || - request.getParameter(OAuth2ParameterNames.CLIENT_ASSERTION) == null) { + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + + if (parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE) == null || + parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION) == null) { return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // client_assertion_type (REQUIRED) String clientAssertionType = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE); if (parameters.get(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE).size() != 1) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java index 76c99baf..23b73a8c 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java @@ -47,16 +47,15 @@ public final class OAuth2AuthorizationCodeAuthenticationConverter implements Aut @Nullable @Override public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); // grant_type (REQUIRED) - String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); + String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE); if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) { return null; } Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // code (REQUIRED) String code = parameters.getFirst(OAuth2ParameterNames.CODE); if (!StringUtils.hasText(code) || diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java index cd38789a..f2cd43ec 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java @@ -66,10 +66,10 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = OAuth2EndpointUtils.getQueryParameters(request); // response_type (REQUIRED) - String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE); + String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE); if (!StringUtils.hasText(responseType) || parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) { throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java index d6b0f477..ccb1b736 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java @@ -54,13 +54,13 @@ public final class OAuth2AuthorizationConsentAuthenticationConverter implements @Override public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + if (!"POST".equals(request.getMethod()) || - request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE) != null) { + parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE) != null) { return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - String authorizationUri = request.getRequestURL().toString(); // client_id (REQUIRED) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java index a9578125..2b27cd69 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java @@ -50,16 +50,16 @@ public final class OAuth2ClientCredentialsAuthenticationConverter implements Aut @Nullable @Override public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + // grant_type (REQUIRED) - String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); + String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE); if (!AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(grantType)) { return null; } Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // scope (OPTIONAL) String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE); if (StringUtils.hasText(scope) && diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java index ef08954a..e1f7d555 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java @@ -28,24 +28,41 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; /** * Utility methods for the OAuth 2.0 Protocol Endpoints. * * @author Joe Grandja + * @author Greg Li * @since 0.1.2 */ final class OAuth2EndpointUtils { static final String ACCESS_TOKEN_REQUEST_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2"; - private OAuth2EndpointUtils() { } - static MultiValueMap getParameters(HttpServletRequest request) { + static MultiValueMap getFormParameters(HttpServletRequest request) { + Map parameterMap = request.getParameterMap(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameterMap.forEach((key, values) -> { + // If not query parameter then it's a form parameter + if ((!StringUtils.hasText(request.getQueryString()) && values.length > 0) + || (!request.getQueryString().contains(key) && values.length > 0)) { + for (String value : values) { + parameters.add(key, value); + } + } + }); + return parameters; + } + + static MultiValueMap getQueryParameters(HttpServletRequest request) { Map parameterMap = request.getParameterMap(); - MultiValueMap parameters = new LinkedMultiValueMap<>(parameterMap.size()); + MultiValueMap parameters = new LinkedMultiValueMap<>(); parameterMap.forEach((key, values) -> { - if (values.length > 0) { + if (StringUtils.hasText(request.getQueryString()) + && request.getQueryString().contains(key) && values.length > 0) { for (String value : values) { parameters.add(key, value); } @@ -58,7 +75,7 @@ final class OAuth2EndpointUtils { if (!matchesAuthorizationCodeGrantRequest(request)) { return Collections.emptyMap(); } - MultiValueMap multiValueParameters = getParameters(request); + MultiValueMap multiValueParameters = getFormParameters(request); for (String exclusion : exclusions) { multiValueParameters.remove(exclusion); } @@ -71,14 +88,16 @@ final class OAuth2EndpointUtils { } static boolean matchesAuthorizationCodeGrantRequest(HttpServletRequest request) { + MultiValueMap parameters = getFormParameters(request); return AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals( - request.getParameter(OAuth2ParameterNames.GRANT_TYPE)) && - request.getParameter(OAuth2ParameterNames.CODE) != null; + parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) && + parameters.getFirst(OAuth2ParameterNames.CODE) != null; } static boolean matchesPkceTokenRequest(HttpServletRequest request) { + MultiValueMap parameters = getFormParameters(request); return matchesAuthorizationCodeGrantRequest(request) && - request.getParameter(PkceParameterNames.CODE_VERIFIER) != null; + parameters.getFirst(PkceParameterNames.CODE_VERIFIER) != null; } static void throwError(String errorCode, String parameterName, String errorUri) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java index 49ef5ade..7e08644b 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java @@ -50,16 +50,16 @@ public final class OAuth2RefreshTokenAuthenticationConverter implements Authenti @Nullable @Override public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + // grant_type (REQUIRED) - String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); + String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE); if (!AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) { return null; } Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // refresh_token (REQUIRED) String refreshToken = parameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN); if (!StringUtils.hasText(refreshToken) || diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java index 3829ee42..90726ac5 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java @@ -49,7 +49,7 @@ public final class OAuth2TokenIntrospectionAuthenticationConverter implements Au public Authentication convert(HttpServletRequest request) { Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); // token (REQUIRED) String token = parameters.getFirst(OAuth2ParameterNames.TOKEN); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java index 4a8ffec6..b8a79b73 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java @@ -46,7 +46,7 @@ public final class OAuth2TokenRevocationAuthenticationConverter implements Authe public Authentication convert(HttpServletRequest request) { Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); // token (REQUIRED) String token = parameters.getFirst(OAuth2ParameterNames.TOKEN); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java index 157bdb09..77369f22 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java @@ -53,7 +53,7 @@ public final class PublicClientAuthenticationConverter implements Authentication return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); // client_id (REQUIRED for public clients) String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java index 95d71a7d..fb455439 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java @@ -153,6 +153,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Daniel Garnier-Moiroux * @author Dmitriy Dubson * @author Steve Riesenberg + * @author Greg Li */ @ExtendWith(SpringTestContextExtension.class) public class OAuth2AuthorizationCodeGrantTests { @@ -255,7 +256,7 @@ public class OAuth2AuthorizationCodeGrantTests { this.registeredClientRepository.save(registeredClient); this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient))) + .queryParams(getAuthorizationRequestParameters(registeredClient))) .andExpect(status().isUnauthorized()) .andReturn(); } @@ -297,7 +298,7 @@ public class OAuth2AuthorizationCodeGrantTests { MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MvcResult mvcResult = this.mvc.perform(get(authorizationEndpointUri) - .params(authorizationRequestParameters) + .queryParams(authorizationRequestParameters) .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -389,9 +390,9 @@ public class OAuth2AuthorizationCodeGrantTests { this.registeredClientRepository.save(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient)) - .param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) - .param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") + .queryParams(getAuthorizationRequestParameters(registeredClient)) + .queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) + .queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -434,9 +435,9 @@ public class OAuth2AuthorizationCodeGrantTests { MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(authorizationRequestParameters) - .param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) - .param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") + .queryParams(authorizationRequestParameters) + .queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) + .queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -473,7 +474,7 @@ public class OAuth2AuthorizationCodeGrantTests { MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(authorizationRequestParameters) + .queryParams(authorizationRequestParameters) .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -519,7 +520,7 @@ public class OAuth2AuthorizationCodeGrantTests { this.registeredClientRepository.save(registeredClient); String consentPage = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient)) + .queryParams(getAuthorizationRequestParameters(registeredClient)) .with(user("user"))) .andExpect(status().is2xxSuccessful()) .andReturn() @@ -602,7 +603,7 @@ public class OAuth2AuthorizationCodeGrantTests { this.registeredClientRepository.save(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient)) + .queryParams(getAuthorizationRequestParameters(registeredClient)) .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -737,9 +738,9 @@ public class OAuth2AuthorizationCodeGrantTests { this.registeredClientRepository.save(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient)) - .param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) - .param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") + .queryParams(getAuthorizationRequestParameters(registeredClient)) + .queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) + .queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java index d8335308..6f98406a 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java @@ -184,7 +184,7 @@ public class OidcTests { MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(authorizationRequestParameters) + .queryParams(authorizationRequestParameters) .with(user("user").roles("A", "B"))) .andExpect(status().is3xxRedirection()) .andReturn(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 163c130f..c170734b 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -37,6 +37,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockServletContext; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -58,6 +59,7 @@ import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.WebAuthenticationDetails; +import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -78,6 +80,7 @@ import static org.mockito.Mockito.when; * @author Daniel Garnier-Moiroux * @author Anoop Garlapati * @author Dmitriy Dubson + * @author Greg Li * @since 0.0.1 */ public class OAuth2AuthorizationEndpointFilterTests { @@ -263,6 +266,13 @@ public class OAuth2AuthorizationEndpointFilterTests { request -> { request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge"); request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge"); + String originalQueryString = request.getQueryString(); + if (StringUtils.hasText(originalQueryString)) { + String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE) + .concat("=code-challenge").concat("&") + .concat(PkceParameterNames.CODE_CHALLENGE).concat("=another-code-challenge"); + request.setQueryString(newQueryString); + } }); } @@ -275,6 +285,13 @@ public class OAuth2AuthorizationEndpointFilterTests { request -> { request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); + String originalQueryString = request.getQueryString(); + if (StringUtils.hasText(originalQueryString)) { + String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE_METHOD) + .concat("=S256").concat("&") + .concat(PkceParameterNames.CODE_CHALLENGE_METHOD).concat("=S256"); + request.setQueryString(newQueryString); + } }); } @@ -557,6 +574,10 @@ public class OAuth2AuthorizationEndpointFilterTests { MockHttpServletRequest request = createAuthorizationRequest(registeredClient); request.addParameter("custom-param", "custom-value-1", "custom-value-2"); + String newQueryString = request.getQueryString().concat("custom-param") + .concat("=custom-value-1").concat("&") + .concat("custom-param").concat("=custom-value-2"); + request.setQueryString(newQueryString); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -646,17 +667,15 @@ public class OAuth2AuthorizationEndpointFilterTests { private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) { String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI; - MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); - request.setServletPath(requestUri); + MockHttpServletRequest request = MockMvcRequestBuilders.get(requestUri) + .queryParam(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()) + .queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) + .queryParam(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()) + .queryParam(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")) + .queryParam(OAuth2ParameterNames.STATE, "state") + .buildRequest(new MockServletContext()); request.setRemoteAddr(REMOTE_ADDRESS); - - request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); - request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); - request.addParameter(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); - request.addParameter(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); - request.addParameter(OAuth2ParameterNames.STATE, "state"); - return request; }