diff --git a/docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java b/docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java index d4935c75..568572bf 100644 --- a/docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java +++ b/docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java index 0bcebec0..12619883 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ */ package org.springframework.security.oauth2.server.authorization.oidc.web.authentication; +import java.util.Map; + import javax.servlet.http.HttpServletRequest; import org.springframework.http.converter.HttpMessageConverter; @@ -30,6 +32,8 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.oidc.web.OidcClientRegistrationEndpointFilter; import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; /** @@ -65,14 +69,30 @@ public final class OidcClientRegistrationAuthenticationConverter implements Auth return new OidcClientRegistrationAuthenticationToken(principal, clientRegistration); } + MultiValueMap parameters = getQueryParameters(request); + // client_id (REQUIRED) - String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID); + String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); if (!StringUtils.hasText(clientId) || - request.getParameterValues(OAuth2ParameterNames.CLIENT_ID).length != 1) { + parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); } return new OidcClientRegistrationAuthenticationToken(principal, clientId); } + private static MultiValueMap getQueryParameters(HttpServletRequest request) { + Map parameterMap = request.getParameterMap(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameterMap.forEach((key, values) -> { + String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : ""; + if (queryString.contains(key) && values.length > 0) { + for (String value : values) { + parameters.add(key, value); + } + } + }); + return parameters; + } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java index a3153a83..7e9fc0d8 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java @@ -23,7 +23,6 @@ import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; @@ -48,17 +47,6 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica @Nullable @Override public Authentication convert(HttpServletRequest request) { - String queryString = request.getQueryString(); - if (StringUtils.hasText(queryString) && - (queryString.contains(OAuth2ParameterNames.CLIENT_ID) || - queryString.contains(OAuth2ParameterNames.CLIENT_SECRET))) { - OAuth2Error error = new OAuth2Error( - OAuth2ErrorCodes.INVALID_REQUEST, - "Client credentials MUST NOT be included in the request URI.", - null); - throw new OAuth2AuthenticationException(error); - } - MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); // client_id (REQUIRED) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java index ed8fb779..bb426cc2 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2021 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java index 23b73a8c..032584c7 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java @@ -48,6 +48,7 @@ public final class OAuth2AuthorizationCodeAuthenticationConverter implements Aut @Override public Authentication convert(HttpServletRequest request) { MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + // grant_type (REQUIRED) String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE); if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java index f2cd43ec..b408eccc 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java @@ -66,7 +66,10 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getQueryParameters(request); + MultiValueMap parameters = + "GET".equals(request.getMethod()) ? + OAuth2EndpointUtils.getQueryParameters(request) : + OAuth2EndpointUtils.getFormParameters(request); // response_type (REQUIRED) String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java index e1f7d555..49fd55e9 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java @@ -39,6 +39,7 @@ import org.springframework.util.StringUtils; */ final class OAuth2EndpointUtils { static final String ACCESS_TOKEN_REQUEST_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2"; + private OAuth2EndpointUtils() { } @@ -46,9 +47,9 @@ final class OAuth2EndpointUtils { Map parameterMap = request.getParameterMap(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameterMap.forEach((key, values) -> { + String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : ""; // If not query parameter then it's a form parameter - if ((!StringUtils.hasText(request.getQueryString()) && values.length > 0) - || (!request.getQueryString().contains(key) && values.length > 0)) { + if (!queryString.contains(key) && values.length > 0) { for (String value : values) { parameters.add(key, value); } @@ -61,8 +62,8 @@ final class OAuth2EndpointUtils { Map parameterMap = request.getParameterMap(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameterMap.forEach((key, values) -> { - if (StringUtils.hasText(request.getQueryString()) - && request.getQueryString().contains(key) && values.length > 0) { + String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : ""; + if (queryString.contains(key) && values.length > 0) { for (String value : values) { parameters.add(key, value); } @@ -75,7 +76,10 @@ final class OAuth2EndpointUtils { if (!matchesAuthorizationCodeGrantRequest(request)) { return Collections.emptyMap(); } - MultiValueMap multiValueParameters = getFormParameters(request); + MultiValueMap multiValueParameters = + "GET".equals(request.getMethod()) ? + getQueryParameters(request) : + getFormParameters(request); for (String exclusion : exclusions) { multiValueParameters.remove(exclusion); } @@ -88,16 +92,14 @@ final class OAuth2EndpointUtils { } static boolean matchesAuthorizationCodeGrantRequest(HttpServletRequest request) { - MultiValueMap parameters = getFormParameters(request); return AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals( - parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) && - parameters.getFirst(OAuth2ParameterNames.CODE) != null; + request.getParameter(OAuth2ParameterNames.GRANT_TYPE)) && + request.getParameter(OAuth2ParameterNames.CODE) != null; } static boolean matchesPkceTokenRequest(HttpServletRequest request) { - MultiValueMap parameters = getFormParameters(request); return matchesAuthorizationCodeGrantRequest(request) && - parameters.getFirst(PkceParameterNames.CODE_VERIFIER) != null; + request.getParameter(PkceParameterNames.CODE_VERIFIER) != null; } static void throwError(String errorCode, String parameterName, String errorUri) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java index b8a79b73..4cc0b11e 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java index 77369f22..6a714075 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java @@ -53,7 +53,10 @@ public final class PublicClientAuthenticationConverter implements Authentication return null; } - MultiValueMap parameters = OAuth2EndpointUtils.getFormParameters(request); + MultiValueMap parameters = + "GET".equals(request.getMethod()) ? + OAuth2EndpointUtils.getQueryParameters(request) : + OAuth2EndpointUtils.getFormParameters(request); // client_id (REQUIRED for public clients) String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java index a7adcce6..4e65102c 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java @@ -59,7 +59,6 @@ import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService; @@ -98,7 +97,6 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; -import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -232,37 +230,6 @@ public class OAuth2ClientCredentialsGrantTests { verify(jwtCustomizer).customize(any()); } - // gh-1378 - @Test - public void requestWhenTokenRequestWithClientCredentialsInQueryParamThenInvalidRequest() throws Exception { - this.spring.register(AuthorizationServerConfiguration.class).autowire(); - - RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); - this.registeredClientRepository.save(registeredClient); - - String tokenEndpointUri = UriComponentsBuilder.fromUriString(DEFAULT_TOKEN_ENDPOINT_URI) - .queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) - .toUriString(); - - this.mvc.perform(post(tokenEndpointUri) - .param(OAuth2ParameterNames.CLIENT_SECRET, registeredClient.getClientSecret()) - .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) - .param(OAuth2ParameterNames.SCOPE, "scope1 scope2")) - .andExpect(status().isBadRequest()) - .andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_REQUEST)); - - tokenEndpointUri = UriComponentsBuilder.fromUriString(DEFAULT_TOKEN_ENDPOINT_URI) - .queryParam(OAuth2ParameterNames.CLIENT_SECRET, registeredClient.getClientSecret()) - .toUriString(); - - this.mvc.perform(post(tokenEndpointUri) - .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) - .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) - .param(OAuth2ParameterNames.SCOPE, "scope1 scope2")) - .andExpect(status().isBadRequest()) - .andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_REQUEST)); - } - @Test public void requestWhenTokenEndpointCustomizedThenUsed() throws Exception { this.spring.register(AuthorizationServerConfigurationCustomTokenEndpoint.class).autowire(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java index da5564b3..0e88df75 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,6 +61,7 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -327,6 +328,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CLIENT_ID, ""); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -342,6 +344,7 @@ public class OidcClientRegistrationEndpointFilterTests { request.setServletPath(requestUri); request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-id"); request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-id2"); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -388,6 +391,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1"); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -421,6 +425,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId()); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -463,6 +468,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client-id"); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -492,6 +498,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId()); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -513,6 +520,7 @@ public class OidcClientRegistrationEndpointFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1"); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -522,6 +530,18 @@ public class OidcClientRegistrationEndpointFilterTests { any(OAuth2AuthenticationException.class)); } + private static void updateQueryString(MockHttpServletRequest request) { + UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI()); + request.getParameterMap().forEach((key, values) -> { + if (values.length > 0) { + for (String value : values) { + uriBuilder.queryParam(key, value); + } + } + }); + request.setQueryString(uriBuilder.build().getQuery()); + } + private OAuth2Error readError(MockHttpServletResponse response) throws Exception { MockClientHttpResponse httpResponse = new MockClientHttpResponse( response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index c170734b..b83771dc 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -37,7 +37,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.mock.web.MockServletContext; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -59,8 +58,8 @@ import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.WebAuthenticationDetails; -import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -173,7 +172,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.RESPONSE_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE)); + request -> { + request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE); + updateQueryString(request); + }); } @Test @@ -182,7 +184,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.RESPONSE_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token")); + request -> { + request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"); + updateQueryString(request); + }); } @Test @@ -191,7 +196,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.RESPONSE_TYPE, OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, - request -> request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token")); + request -> { + request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"); + updateQueryString(request); + }); } @Test @@ -200,7 +208,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.CLIENT_ID, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID)); + request -> { + request.removeParameter(OAuth2ParameterNames.CLIENT_ID); + updateQueryString(request); + }); } @Test @@ -209,7 +220,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.CLIENT_ID, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2")); + request -> { + request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"); + updateQueryString(request); + }); } @Test @@ -218,7 +232,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com")); + request -> { + request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"); + updateQueryString(request); + }); } @Test @@ -227,7 +244,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.SCOPE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.SCOPE, "scope2")); + request -> { + request.addParameter(OAuth2ParameterNames.SCOPE, "scope2"); + updateQueryString(request); + }); } @Test @@ -236,7 +256,10 @@ public class OAuth2AuthorizationEndpointFilterTests { TestRegisteredClients.registeredClient().build(), OAuth2ParameterNames.STATE, OAuth2ErrorCodes.INVALID_REQUEST, - request -> request.addParameter(OAuth2ParameterNames.STATE, "state2")); + request -> { + request.addParameter(OAuth2ParameterNames.STATE, "state2"); + updateQueryString(request); + }); } @Test @@ -266,13 +289,7 @@ public class OAuth2AuthorizationEndpointFilterTests { request -> { request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge"); request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge"); - String originalQueryString = request.getQueryString(); - if (StringUtils.hasText(originalQueryString)) { - String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE) - .concat("=code-challenge").concat("&") - .concat(PkceParameterNames.CODE_CHALLENGE).concat("=another-code-challenge"); - request.setQueryString(newQueryString); - } + updateQueryString(request); }); } @@ -285,13 +302,7 @@ public class OAuth2AuthorizationEndpointFilterTests { request -> { request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); - String originalQueryString = request.getQueryString(); - if (StringUtils.hasText(originalQueryString)) { - String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE_METHOD) - .concat("=S256").concat("&") - .concat(PkceParameterNames.CODE_CHALLENGE_METHOD).concat("=S256"); - request.setQueryString(newQueryString); - } + updateQueryString(request); }); } @@ -574,10 +585,7 @@ public class OAuth2AuthorizationEndpointFilterTests { MockHttpServletRequest request = createAuthorizationRequest(registeredClient); request.addParameter("custom-param", "custom-value-1", "custom-value-2"); - String newQueryString = request.getQueryString().concat("custom-param") - .concat("=custom-value-1").concat("&") - .concat("custom-param").concat("=custom-value-2"); - request.setQueryString(newQueryString); + updateQueryString(request); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -623,6 +631,7 @@ public class OAuth2AuthorizationEndpointFilterTests { MockHttpServletRequest request = createAuthorizationRequest(registeredClient); request.setMethod("POST"); // OpenID Connect supports POST method + request.setQueryString(null); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); @@ -667,15 +676,18 @@ public class OAuth2AuthorizationEndpointFilterTests { private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) { String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI; - MockHttpServletRequest request = MockMvcRequestBuilders.get(requestUri) - .queryParam(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()) - .queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) - .queryParam(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()) - .queryParam(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")) - .queryParam(OAuth2ParameterNames.STATE, "state") - .buildRequest(new MockServletContext()); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); request.setRemoteAddr(REMOTE_ADDRESS); + + request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); + request.addParameter(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); + request.addParameter(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + updateQueryString(request); + return request; } @@ -692,6 +704,18 @@ public class OAuth2AuthorizationEndpointFilterTests { return request; } + private static void updateQueryString(MockHttpServletRequest request) { + UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI()); + request.getParameterMap().forEach((key, values) -> { + if (values.length > 0) { + for (String value : values) { + uriBuilder.queryParam(key, value); + } + } + }); + request.setQueryString(uriBuilder.build().getQuery()); + } + private static String scopeCheckbox(String scope) { return MessageFormat.format( "", diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java index 6464ab00..894fb409 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java @@ -79,31 +79,6 @@ public class ClientSecretPostAuthenticationConverterTests { .isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); } - // gh-1378 - @Test - public void convertWhenClientCredentialsInQueryParamThenInvalidRequestError() { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1"); - request.addParameter(OAuth2ParameterNames.CLIENT_SECRET, "client-secret"); - request.setQueryString("client_id=client-1"); - assertThatThrownBy(() -> this.converter.convert(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); - assertThat(error.getDescription()).isEqualTo("Client credentials MUST NOT be included in the request URI."); - }); - - request.setQueryString("client_secret=client-secret"); - assertThatThrownBy(() -> this.converter.convert(request)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .satisfies(error -> { - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); - assertThat(error.getDescription()).isEqualTo("Client credentials MUST NOT be included in the request URI."); - }); - } - @Test public void convertWhenPostWithValidCredentialsThenReturnClientAuthenticationToken() { MockHttpServletRequest request = new MockHttpServletRequest();