diff --git a/docs/src/test/java/sample/AuthorizationCodeGrantFlow.java b/docs/src/test/java/sample/AuthorizationCodeGrantFlow.java index 9c35a4db..534e2b7e 100644 --- a/docs/src/test/java/sample/AuthorizationCodeGrantFlow.java +++ b/docs/src/test/java/sample/AuthorizationCodeGrantFlow.java @@ -110,7 +110,7 @@ public class AuthorizationCodeGrantFlow { // @formatter:off MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorize") - .params(parameters) + .queryParams(parameters) .with(user(this.username).roles("USER"))) .andExpect(status().isOk()) .andExpect(header().string("content-type", containsString(MediaType.TEXT_HTML_VALUE))) diff --git a/docs/src/test/java/sample/DeviceAuthorizationGrantFlow.java b/docs/src/test/java/sample/DeviceAuthorizationGrantFlow.java index 0adcdfc2..3abe7dce 100644 --- a/docs/src/test/java/sample/DeviceAuthorizationGrantFlow.java +++ b/docs/src/test/java/sample/DeviceAuthorizationGrantFlow.java @@ -117,7 +117,7 @@ public class DeviceAuthorizationGrantFlow { parameters.set(OAuth2ParameterNames.USER_CODE, userCode); MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/device_verification") - .params(parameters) + .queryParams(parameters) .with(user(this.username).roles("USER"))) .andExpect(status().isOk()) .andExpect(header().string(HttpHeaders.CONTENT_TYPE, containsString(MediaType.TEXT_HTML_VALUE))) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OAuth2EndpointUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OAuth2EndpointUtils.java new file mode 100644 index 00000000..0eb6d711 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OAuth2EndpointUtils.java @@ -0,0 +1,67 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.web.authentication; + +import java.util.Map; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * Utility methods for the OAuth 2.0 Protocol Endpoints. + * + * @author Joe Grandja + * @author Greg Li + * @since 1.1.4 + */ +final class OAuth2EndpointUtils { + + private OAuth2EndpointUtils() { + } + + static MultiValueMap getFormParameters(HttpServletRequest request) { + Map parameterMap = request.getParameterMap(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameterMap.forEach((key, values) -> { + String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : ""; + // If not query parameter then it's a form parameter + if (!queryString.contains(key) && values.length > 0) { + for (String value : values) { + parameters.add(key, value); + } + } + }); + return parameters; + } + + static MultiValueMap getQueryParameters(HttpServletRequest request) { + Map parameterMap = request.getParameterMap(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameterMap.forEach((key, values) -> { + String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : ""; + if (queryString.contains(key) && values.length > 0) { + for (String value : values) { + parameters.add(key, value); + } + } + }); + return parameters; + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java index 54c6c1ba..c4933398 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.oidc.web.OidcClientRegistrationEndpointFilter; import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; /** @@ -65,10 +66,12 @@ public final class OidcClientRegistrationAuthenticationConverter implements Auth return new OidcClientRegistrationAuthenticationToken(principal, clientRegistration); } + MultiValueMap parameters = OAuth2EndpointUtils.getQueryParameters(request); + // client_id (REQUIRED) - String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID); + String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); if (!StringUtils.hasText(clientId) || - request.getParameterValues(OAuth2ParameterNames.CLIENT_ID).length != 1) { + parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java index 530c2f06..dfa1f2cd 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java @@ -15,8 +15,6 @@ */ package org.springframework.security.oauth2.server.authorization.oidc.web.authentication; -import java.util.Map; - import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpSession; @@ -31,7 +29,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.web.OidcLogoutEndpointFilter; import org.springframework.security.web.authentication.AuthenticationConverter; -import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; @@ -51,12 +48,15 @@ public final class OidcLogoutAuthenticationConverter implements AuthenticationCo @Override public Authentication convert(HttpServletRequest request) { - MultiValueMap parameters = getParameters(request); + MultiValueMap parameters = + "GET".equals(request.getMethod()) ? + OAuth2EndpointUtils.getQueryParameters(request) : + OAuth2EndpointUtils.getFormParameters(request); // id_token_hint (REQUIRED) // RECOMMENDED as per spec - String idTokenHint = request.getParameter("id_token_hint"); + String idTokenHint = parameters.getFirst("id_token_hint"); if (!StringUtils.hasText(idTokenHint) || - request.getParameterValues("id_token_hint").length != 1) { + parameters.get("id_token_hint").size() != 1) { throwError(OAuth2ErrorCodes.INVALID_REQUEST, "id_token_hint"); } @@ -96,19 +96,6 @@ public final class OidcLogoutAuthenticationConverter implements AuthenticationCo sessionId, clientId, postLogoutRedirectUri, state); } - private static MultiValueMap getParameters(HttpServletRequest request) { - Map parameterMap = request.getParameterMap(); - MultiValueMap parameters = new LinkedMultiValueMap<>(parameterMap.size()); - parameterMap.forEach((key, values) -> { - if (values.length > 0) { - for (String value : values) { - parameters.add(key, value); - } - } - }); - return parameters; - } - private static void throwError(String errorCode, String parameterName) { OAuth2Error error = new OAuth2Error( errorCode, diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java index 4f225873..04207fb5 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java @@ -23,7 +23,6 @@ import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; @@ -48,7 +47,7 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica @Nullable @Override public Authentication convert(HttpServletRequest request) { - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); // client_id (REQUIRED) String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); @@ -70,17 +69,6 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); } - String queryString = request.getQueryString(); - if (StringUtils.hasText(queryString) && - (queryString.contains(OAuth2ParameterNames.CLIENT_ID) || - queryString.contains(OAuth2ParameterNames.CLIENT_SECRET))) { - OAuth2Error error = new OAuth2Error( - OAuth2ErrorCodes.INVALID_REQUEST, - "Client credentials MUST NOT be included in the request URI.", - null); - throw new OAuth2AuthenticationException(error); - } - Map additionalParameters = OAuth2EndpointUtils.getParametersIfMatchesAuthorizationCodeGrantRequest(request, OAuth2ParameterNames.CLIENT_ID, OAuth2ParameterNames.CLIENT_SECRET); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java index ce6118d5..a29c663f 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,13 +48,13 @@ public final class JwtClientAssertionAuthenticationConverter implements Authenti @Nullable @Override public Authentication convert(HttpServletRequest request) { - if (request.getParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE) == null || - request.getParameter(OAuth2ParameterNames.CLIENT_ASSERTION) == null) { + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + + if (parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE) == null || + parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION) == null) { return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // client_assertion_type (REQUIRED) String clientAssertionType = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE); if (parameters.get(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE).size() != 1) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java index 27ace5c7..2dda4c0e 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java @@ -47,16 +47,16 @@ public final class OAuth2AuthorizationCodeAuthenticationConverter implements Aut @Nullable @Override public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + // grant_type (REQUIRED) - String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); + String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE); if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) { return null; } Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // code (REQUIRED) String code = parameters.getFirst(OAuth2ParameterNames.CODE); if (!StringUtils.hasText(code) || diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java index b116c6a7..3d5ea80d 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java @@ -66,10 +66,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = + "GET".equals(request.getMethod()) ? + OAuth2EndpointUtils.getQueryParameters(request) : + OAuth2EndpointUtils.getFormParameters(request); // response_type (REQUIRED) - String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE); + String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE); if (!StringUtils.hasText(responseType) || parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) { throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java index 9e347b41..7da9e9f1 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java @@ -54,13 +54,13 @@ public final class OAuth2AuthorizationConsentAuthenticationConverter implements @Override public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + if (!"POST".equals(request.getMethod()) || - request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE) != null) { + parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE) != null) { return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - String authorizationUri = request.getRequestURL().toString(); // client_id (REQUIRED) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java index 9eac61c8..7c34754a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java @@ -50,16 +50,16 @@ public final class OAuth2ClientCredentialsAuthenticationConverter implements Aut @Nullable @Override public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + // grant_type (REQUIRED) - String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); + String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE); if (!AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(grantType)) { return null; } Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // scope (OPTIONAL) String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE); if (StringUtils.hasText(scope) && diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationConsentAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationConsentAuthenticationConverter.java index 17625281..3c4ea0ed 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationConsentAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationConsentAuthenticationConverter.java @@ -59,7 +59,7 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationConverter imple return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); String authorizationUri = request.getRequestURL().toString(); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationRequestAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationRequestAuthenticationConverter.java index 5e2c8cb2..925dfed6 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationRequestAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationRequestAuthenticationConverter.java @@ -53,7 +53,7 @@ public final class OAuth2DeviceAuthorizationRequestAuthenticationConverter imple public Authentication convert(HttpServletRequest request) { Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); String authorizationUri = request.getRequestURL().toString(); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceCodeAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceCodeAuthenticationConverter.java index c61941b0..1405b76f 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceCodeAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceCodeAuthenticationConverter.java @@ -49,16 +49,16 @@ public final class OAuth2DeviceCodeAuthenticationConverter implements Authentica @Nullable @Override public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + // grant_type (REQUIRED) - String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); + String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE); if (!AuthorizationGrantType.DEVICE_CODE.getValue().equals(grantType)) { return null; } Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // device_code (REQUIRED) String deviceCode = parameters.getFirst(OAuth2ParameterNames.DEVICE_CODE); if (!StringUtils.hasText(deviceCode) || diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverter.java index ee90d815..ad879ba8 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverter.java @@ -59,7 +59,10 @@ public final class OAuth2DeviceVerificationAuthenticationConverter implements Au return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = + "GET".equals(request.getMethod()) ? + OAuth2EndpointUtils.getQueryParameters(request) : + OAuth2EndpointUtils.getFormParameters(request); // user_code (REQUIRED) String userCode = parameters.getFirst(OAuth2ParameterNames.USER_CODE); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java index 9469e23f..972deaa0 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java @@ -29,11 +29,13 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; /** * Utility methods for the OAuth 2.0 Protocol Endpoints. * * @author Joe Grandja + * @author Greg Li * @since 0.1.2 */ final class OAuth2EndpointUtils { @@ -42,11 +44,27 @@ final class OAuth2EndpointUtils { private OAuth2EndpointUtils() { } - static MultiValueMap getParameters(HttpServletRequest request) { + static MultiValueMap getFormParameters(HttpServletRequest request) { Map parameterMap = request.getParameterMap(); - MultiValueMap parameters = new LinkedMultiValueMap<>(parameterMap.size()); + MultiValueMap parameters = new LinkedMultiValueMap<>(); parameterMap.forEach((key, values) -> { - if (values.length > 0) { + String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : ""; + // If not query parameter then it's a form parameter + if (!queryString.contains(key) && values.length > 0) { + for (String value : values) { + parameters.add(key, value); + } + } + }); + return parameters; + } + + static MultiValueMap getQueryParameters(HttpServletRequest request) { + Map parameterMap = request.getParameterMap(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameterMap.forEach((key, values) -> { + String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : ""; + if (queryString.contains(key) && values.length > 0) { for (String value : values) { parameters.add(key, value); } @@ -59,7 +77,10 @@ final class OAuth2EndpointUtils { if (!matchesAuthorizationCodeGrantRequest(request)) { return Collections.emptyMap(); } - MultiValueMap multiValueParameters = getParameters(request); + MultiValueMap multiValueParameters = + "GET".equals(request.getMethod()) ? + getQueryParameters(request) : + getFormParameters(request); for (String exclusion : exclusions) { multiValueParameters.remove(exclusion); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java index 437f8f43..1b34eef8 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java @@ -50,16 +50,16 @@ public final class OAuth2RefreshTokenAuthenticationConverter implements Authenti @Nullable @Override public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + // grant_type (REQUIRED) - String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE); + String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE); if (!AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) { return null; } Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); - // refresh_token (REQUIRED) String refreshToken = parameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN); if (!StringUtils.hasText(refreshToken) || diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java index f2b48d85..28dbcb0c 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java @@ -49,7 +49,7 @@ public final class OAuth2TokenIntrospectionAuthenticationConverter implements Au public Authentication convert(HttpServletRequest request) { Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); // token (REQUIRED) String token = parameters.getFirst(OAuth2ParameterNames.TOKEN); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java index fea3fab2..a8765ed2 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ public final class OAuth2TokenRevocationAuthenticationConverter implements Authe public Authentication convert(HttpServletRequest request) { Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication(); - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); // token (REQUIRED) String token = parameters.getFirst(OAuth2ParameterNames.TOKEN); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java index 2353a18d..2fda1393 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java @@ -53,7 +53,10 @@ public final class PublicClientAuthenticationConverter implements Authentication return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getParameters(request); + MultiValueMap parameters = + "GET".equals(request.getMethod()) ? + OAuth2EndpointUtils.getQueryParameters(request) : + OAuth2EndpointUtils.getFormParameters(request); // client_id (REQUIRED for public clients) String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java index ed47b879..06f7ccb8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java @@ -158,6 +158,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. * @author Daniel Garnier-Moiroux * @author Dmitriy Dubson * @author Steve Riesenberg + * @author Greg Li */ @ExtendWith(SpringTestContextExtension.class) public class OAuth2AuthorizationCodeGrantTests { @@ -260,7 +261,7 @@ public class OAuth2AuthorizationCodeGrantTests { this.registeredClientRepository.save(registeredClient); this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient))) + .queryParams(getAuthorizationRequestParameters(registeredClient))) .andExpect(status().isUnauthorized()) .andReturn(); } @@ -302,7 +303,7 @@ public class OAuth2AuthorizationCodeGrantTests { MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MvcResult mvcResult = this.mvc.perform(get(authorizationEndpointUri) - .params(authorizationRequestParameters) + .queryParams(authorizationRequestParameters) .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -394,9 +395,9 @@ public class OAuth2AuthorizationCodeGrantTests { this.registeredClientRepository.save(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient)) - .param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) - .param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") + .queryParams(getAuthorizationRequestParameters(registeredClient)) + .queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) + .queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -487,9 +488,9 @@ public class OAuth2AuthorizationCodeGrantTests { MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(authorizationRequestParameters) - .param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) - .param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") + .queryParams(authorizationRequestParameters) + .queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) + .queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -526,7 +527,7 @@ public class OAuth2AuthorizationCodeGrantTests { MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(authorizationRequestParameters) + .queryParams(authorizationRequestParameters) .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -572,7 +573,7 @@ public class OAuth2AuthorizationCodeGrantTests { this.registeredClientRepository.save(registeredClient); String consentPage = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient)) + .queryParams(getAuthorizationRequestParameters(registeredClient)) .with(user("user"))) .andExpect(status().is2xxSuccessful()) .andReturn() @@ -655,7 +656,7 @@ public class OAuth2AuthorizationCodeGrantTests { this.registeredClientRepository.save(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient)) + .queryParams(getAuthorizationRequestParameters(registeredClient)) .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -790,9 +791,9 @@ public class OAuth2AuthorizationCodeGrantTests { this.registeredClientRepository.save(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(getAuthorizationRequestParameters(registeredClient)) - .param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) - .param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") + .queryParams(getAuthorizationRequestParameters(registeredClient)) + .queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE) + .queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256") .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java index 1e559b00..9b097db8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java @@ -61,7 +61,6 @@ import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService; @@ -102,7 +101,6 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; -import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -236,37 +234,6 @@ public class OAuth2ClientCredentialsGrantTests { verify(jwtCustomizer).customize(any()); } - // gh-1378 - @Test - public void requestWhenTokenRequestWithClientCredentialsInQueryParamThenInvalidRequest() throws Exception { - this.spring.register(AuthorizationServerConfiguration.class).autowire(); - - RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); - this.registeredClientRepository.save(registeredClient); - - String tokenEndpointUri = UriComponentsBuilder.fromUriString(DEFAULT_TOKEN_ENDPOINT_URI) - .queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) - .toUriString(); - - this.mvc.perform(post(tokenEndpointUri) - .param(OAuth2ParameterNames.CLIENT_SECRET, registeredClient.getClientSecret()) - .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) - .param(OAuth2ParameterNames.SCOPE, "scope1 scope2")) - .andExpect(status().isBadRequest()) - .andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_REQUEST)); - - tokenEndpointUri = UriComponentsBuilder.fromUriString(DEFAULT_TOKEN_ENDPOINT_URI) - .queryParam(OAuth2ParameterNames.CLIENT_SECRET, registeredClient.getClientSecret()) - .toUriString(); - - this.mvc.perform(post(tokenEndpointUri) - .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) - .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) - .param(OAuth2ParameterNames.SCOPE, "scope1 scope2")) - .andExpect(status().isBadRequest()) - .andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_REQUEST)); - } - @Test public void requestWhenTokenRequestPostsClientCredentialsAndRequiresUpgradingThenClientSecretUpgraded() throws Exception { this.spring.register(AuthorizationServerConfigurationCustomPasswordEncoder.class).autowire(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2DeviceCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2DeviceCodeGrantTests.java index e3ac6ece..f8b03983 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2DeviceCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2DeviceCodeGrantTests.java @@ -279,7 +279,7 @@ public class OAuth2DeviceCodeGrantTests { // @formatter:off this.mvc.perform(get(DEFAULT_DEVICE_VERIFICATION_ENDPOINT_URI) - .params(parameters)) + .queryParams(parameters)) .andExpect(status().isUnauthorized()); // @formatter:on } @@ -313,7 +313,7 @@ public class OAuth2DeviceCodeGrantTests { // @formatter:off MvcResult mvcResult = this.mvc.perform(get(DEFAULT_DEVICE_VERIFICATION_ENDPOINT_URI) - .params(parameters) + .queryParams(parameters) .with(user("user"))) .andExpect(status().isOk()) .andExpect(content().contentTypeCompatibleWith(MediaType.TEXT_HTML)) diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java index 72ecd6bb..78758702 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java @@ -193,8 +193,8 @@ public class OidcTests { MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(authorizationRequestParameters) - .with(user("user").roles("A", "B"))) + .queryParams(authorizationRequestParameters) + .with(user("user").roles("A", "B"))) .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); @@ -249,7 +249,7 @@ public class OidcTests { MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(authorizationRequestParameters) + .queryParams(authorizationRequestParameters) .with(user("user").roles("A", "B"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -306,7 +306,7 @@ public class OidcTests { // Login MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(authorizationRequestParameters) + .queryParams(authorizationRequestParameters) .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -355,7 +355,7 @@ public class OidcTests { MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient1); MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(authorizationRequestParameters) + .queryParams(authorizationRequestParameters) .with(user("user1"))) .andExpect(status().is3xxRedirection()) .andReturn(); @@ -387,7 +387,7 @@ public class OidcTests { authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient2); mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) - .params(authorizationRequestParameters) + .queryParams(authorizationRequestParameters) .with(user("user2"))) .andExpect(status().is3xxRedirection()) .andReturn(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java index 20fd0725..51208b45 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,6 +61,7 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -327,6 +328,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CLIENT_ID, ""); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -342,6 +344,7 @@ public class OidcClientRegistrationEndpointFilterTests { request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-id"); request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-id2"); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -388,6 +391,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1"); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -421,6 +425,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId()); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -463,6 +468,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client-id"); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -492,6 +498,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId()); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -513,6 +520,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1"); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -522,6 +530,18 @@ public class OidcClientRegistrationEndpointFilterTests { any(OAuth2AuthenticationException.class)); } + private static void updateQueryString(MockHttpServletRequest request) { + UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI()); + request.getParameterMap().forEach((key, values) -> { + if (values.length > 0) { + for (String value : values) { + uriBuilder.queryParam(key, value); + } + } + }); + request.setQueryString(uriBuilder.build().getQuery()); + } + private OAuth2Error readError(MockHttpServletResponse response) throws Exception { MockClientHttpResponse httpResponse = new MockClientHttpResponse( response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index eb244829..9419de80 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -60,6 +60,7 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand import org.springframework.security.web.authentication.WebAuthenticationDetails; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -79,6 +80,7 @@ import static org.mockito.Mockito.when; * @author Daniel Garnier-Moiroux * @author Anoop Garlapati * @author Dmitriy Dubson + * @author Greg Li * @since 0.0.1 */ public class OAuth2AuthorizationEndpointFilterTests { @@ -178,7 +180,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.RESPONSE_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE)); + request -> { + request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE); + updateQueryString(request); + }); } @Test @@ -187,7 +192,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.RESPONSE_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token")); + request -> { + request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"); + updateQueryString(request); + }); } @Test @@ -196,7 +204,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.RESPONSE_TYPE, OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, - request -> request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token")); + request -> { + request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"); + updateQueryString(request); + }); } @Test @@ -205,7 +216,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.CLIENT_ID, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID)); + request -> { + request.removeParameter(OAuth2ParameterNames.CLIENT_ID); + updateQueryString(request); + }); } @Test @@ -214,7 +228,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.CLIENT_ID, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2")); + request -> { + request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"); + updateQueryString(request); + }); } @Test @@ -223,7 +240,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com")); + request -> { + request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"); + updateQueryString(request); + }); } @Test @@ -232,7 +252,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.SCOPE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.SCOPE, "scope2")); + request -> { + request.addParameter(OAuth2ParameterNames.SCOPE, "scope2"); + updateQueryString(request); + }); } @Test @@ -241,7 +264,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.STATE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.STATE, "state2")); + request -> { + request.addParameter(OAuth2ParameterNames.STATE, "state2"); + updateQueryString(request); + }); } @Test @@ -271,6 +297,7 @@ public class OAuth2AuthorizationEndpointFilterTests { request -> { request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge"); request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge"); + updateQueryString(request); }); } @@ -283,6 +310,7 @@ public class OAuth2AuthorizationEndpointFilterTests { request -> { request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); + updateQueryString(request); }); } @@ -590,6 +618,7 @@ public class OAuth2AuthorizationEndpointFilterTests { MockHttpServletRequest request = createAuthorizationRequest(registeredClient); request.addParameter("custom-param", "custom-value-1", "custom-value-2"); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -635,6 +664,7 @@ public class OAuth2AuthorizationEndpointFilterTests { MockHttpServletRequest request = createAuthorizationRequest(registeredClient); request.setMethod("POST"); // OpenID Connect supports POST method + request.setQueryString(null); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -689,6 +719,7 @@ public class OAuth2AuthorizationEndpointFilterTests { request.addParameter(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); request.addParameter(OAuth2ParameterNames.STATE, "state"); + updateQueryString(request); return request; } @@ -706,6 +737,18 @@ public class OAuth2AuthorizationEndpointFilterTests { return request; } + private static void updateQueryString(MockHttpServletRequest request) { + UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI()); + request.getParameterMap().forEach((key, values) -> { + if (values.length > 0) { + for (String value : values) { + uriBuilder.queryParam(key, value); + } + } + }); + request.setQueryString(uriBuilder.build().getQuery()); + } + private static String scopeCheckbox(String scope) { return MessageFormat.format( "", diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceVerificationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceVerificationEndpointFilterTests.java index e30069fb..524112b5 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceVerificationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceVerificationEndpointFilterTests.java @@ -23,6 +23,7 @@ import java.util.Set; import jakarta.servlet.FilterChain; import jakarta.servlet.http.HttpServletRequest; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -164,6 +165,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -223,6 +225,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); request.addParameter("custom-param-1", "custom-value-1"); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -248,6 +251,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -268,6 +272,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -291,6 +296,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { request.setServerPort(443); request.setServerName("provider.com"); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.setConsentPage("/consent"); @@ -322,6 +328,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -340,6 +347,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -367,6 +375,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -388,6 +397,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); this.filter.doFilter(request, response, filterChain); @@ -445,6 +455,18 @@ public class OAuth2DeviceVerificationEndpointFilterTests { return request; } + private static void updateQueryString(MockHttpServletRequest request) { + UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI()); + request.getParameterMap().forEach((key, values) -> { + if (values.length > 0) { + for (String value : values) { + uriBuilder.queryParam(key, value); + } + } + }); + request.setQueryString(uriBuilder.build().getQuery()); + } + private static String scopeCheckbox(String scope) { return MessageFormat.format( "", diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java index 6464ab00..894fb409 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java @@ -79,31 +79,6 @@ public class ClientSecretPostAuthenticationConverterTests { .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); } - // gh-1378 - @Test - public void convertWhenClientCredentialsInQueryParamThenInvalidRequestError() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1"); - request.addParameter(OAuth2ParameterNames.CLIENT_SECRET, "client-secret"); - request.setQueryString("client_id=client-1"); - assertThatThrownBy(() -> this.converter.convert(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); - assertThat(error.getDescription()).isEqualTo("Client credentials MUST NOT be included in the request URI."); - }); - - request.setQueryString("client_secret=client-secret"); - assertThatThrownBy(() -> this.converter.convert(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); - assertThat(error.getDescription()).isEqualTo("Client credentials MUST NOT be included in the request URI."); - }); - } - @Test public void convertWhenPostWithValidCredentialsThenReturnClientAuthenticationToken() { MockHttpServletRequest request = new MockHttpServletRequest(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverterTests.java index d7d17980..0b8ba1ef 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverterTests.java @@ -31,6 +31,7 @@ import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceVerificationAuthenticationToken; +import org.springframework.web.util.UriComponentsBuilder; import static java.util.Map.entry; import static org.assertj.core.api.Assertions.assertThat; @@ -69,6 +70,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { public void convertWhenStateThenReturnNull() { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.STATE, "abc123"); + updateQueryString(request); Authentication authentication = this.converter.convert(request); assertThat(authentication).isNull(); } @@ -84,6 +86,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { public void convertWhenEmptyUserCodeParameterThenInvalidRequestError() { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, ""); + updateQueryString(request); // @formatter:off assertThatExceptionOfType(OAuth2AuthenticationException.class) .isThrownBy(() -> this.converter.convert(request)) @@ -98,6 +101,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { public void convertWhenInvalidUserCodeParameterThenInvalidRequestError() { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, "LONG-USER-CODE"); + updateQueryString(request); // @formatter:off assertThatExceptionOfType(OAuth2AuthenticationException.class) .isThrownBy(() -> this.converter.convert(request)) @@ -113,6 +117,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); request.addParameter(OAuth2ParameterNames.USER_CODE, "another"); + updateQueryString(request); // @formatter:off assertThatExceptionOfType(OAuth2AuthenticationException.class) .isThrownBy(() -> this.converter.convert(request)) @@ -127,6 +132,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { public void convertWhenMissingPrincipalThenReturnDeviceVerificationAuthentication() { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE.toLowerCase().replace("-", " . ")); + updateQueryString(request); OAuth2DeviceVerificationAuthenticationToken authentication = (OAuth2DeviceVerificationAuthenticationToken) this.converter.convert(request); @@ -140,6 +146,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { public void convertWhenNonNormalizedUserCodeThenReturnDeviceVerificationAuthentication() { MockHttpServletRequest request = createRequest(); request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE.toLowerCase().replace("-", " . ")); + updateQueryString(request); SecurityContextImpl securityContext = new SecurityContextImpl(); securityContext.setAuthentication(new TestingAuthenticationToken("user", null)); @@ -159,6 +166,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE); request.addParameter("param-1", "value-1"); request.addParameter("param-2", "value-1", "value-2"); + updateQueryString(request); SecurityContextImpl securityContext = new SecurityContextImpl(); securityContext.setAuthentication(new TestingAuthenticationToken("user", null)); @@ -180,4 +188,17 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { request.setRequestURI(VERIFICATION_URI); return request; } + + private static void updateQueryString(MockHttpServletRequest request) { + UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI()); + request.getParameterMap().forEach((key, values) -> { + if (values.length > 0) { + for (String value : values) { + uriBuilder.queryParam(key, value); + } + } + }); + request.setQueryString(uriBuilder.build().getQuery()); + } + }