Browse Source

Merge branch '1.1.x'

Closes gh-1477
pull/1490/head
Joe Grandja 2 years ago
parent
commit
8e8dd22d64
  1. 2
      docs/src/test/java/sample/AuthorizationCodeGrantFlow.java
  2. 2
      docs/src/test/java/sample/DeviceAuthorizationGrantFlow.java
  3. 67
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OAuth2EndpointUtils.java
  4. 9
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java
  5. 25
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java
  6. 14
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java
  7. 10
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java
  8. 6
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java
  9. 7
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java
  10. 6
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java
  11. 6
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java
  12. 2
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationConsentAuthenticationConverter.java
  13. 2
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationRequestAuthenticationConverter.java
  14. 6
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceCodeAuthenticationConverter.java
  15. 5
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverter.java
  16. 29
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java
  17. 6
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java
  18. 2
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java
  19. 4
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java
  20. 5
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java
  21. 29
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java
  22. 33
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java
  23. 4
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2DeviceCodeGrantTests.java
  24. 12
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java
  25. 22
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java
  26. 59
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java
  27. 22
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceVerificationEndpointFilterTests.java
  28. 25
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java
  29. 21
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverterTests.java

2
docs/src/test/java/sample/AuthorizationCodeGrantFlow.java

@ -110,7 +110,7 @@ public class AuthorizationCodeGrantFlow { @@ -110,7 +110,7 @@ public class AuthorizationCodeGrantFlow {
// @formatter:off
MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorize")
.params(parameters)
.queryParams(parameters)
.with(user(this.username).roles("USER")))
.andExpect(status().isOk())
.andExpect(header().string("content-type", containsString(MediaType.TEXT_HTML_VALUE)))

2
docs/src/test/java/sample/DeviceAuthorizationGrantFlow.java

@ -117,7 +117,7 @@ public class DeviceAuthorizationGrantFlow { @@ -117,7 +117,7 @@ public class DeviceAuthorizationGrantFlow {
parameters.set(OAuth2ParameterNames.USER_CODE, userCode);
MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/device_verification")
.params(parameters)
.queryParams(parameters)
.with(user(this.username).roles("USER")))
.andExpect(status().isOk())
.andExpect(header().string(HttpHeaders.CONTENT_TYPE, containsString(MediaType.TEXT_HTML_VALUE)))

67
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OAuth2EndpointUtils.java

@ -0,0 +1,67 @@ @@ -0,0 +1,67 @@
/*
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.oidc.web.authentication;
import java.util.Map;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* Utility methods for the OAuth 2.0 Protocol Endpoints.
*
* @author Joe Grandja
* @author Greg Li
* @since 1.1.4
*/
final class OAuth2EndpointUtils {
private OAuth2EndpointUtils() {
}
static MultiValueMap<String, String> getFormParameters(HttpServletRequest request) {
Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameterMap.forEach((key, values) -> {
String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : "";
// If not query parameter then it's a form parameter
if (!queryString.contains(key) && values.length > 0) {
for (String value : values) {
parameters.add(key, value);
}
}
});
return parameters;
}
static MultiValueMap<String, String> getQueryParameters(HttpServletRequest request) {
Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameterMap.forEach((key, values) -> {
String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : "";
if (queryString.contains(key) && values.length > 0) {
for (String value : values) {
parameters.add(key, value);
}
}
});
return parameters;
}
}

9
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,6 +30,7 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat @@ -30,6 +30,7 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat
import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.oidc.web.OidcClientRegistrationEndpointFilter;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
@ -65,10 +66,12 @@ public final class OidcClientRegistrationAuthenticationConverter implements Auth @@ -65,10 +66,12 @@ public final class OidcClientRegistrationAuthenticationConverter implements Auth
return new OidcClientRegistrationAuthenticationToken(principal, clientRegistration);
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getQueryParameters(request);
// client_id (REQUIRED)
String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
if (!StringUtils.hasText(clientId) ||
request.getParameterValues(OAuth2ParameterNames.CLIENT_ID).length != 1) {
parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
}

25
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java

@ -15,8 +15,6 @@ @@ -15,8 +15,6 @@
*/
package org.springframework.security.oauth2.server.authorization.oidc.web.authentication;
import java.util.Map;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpSession;
@ -31,7 +29,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -31,7 +29,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.oidc.web.OidcLogoutEndpointFilter;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
@ -51,12 +48,15 @@ public final class OidcLogoutAuthenticationConverter implements AuthenticationCo @@ -51,12 +48,15 @@ public final class OidcLogoutAuthenticationConverter implements AuthenticationCo
@Override
public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = getParameters(request);
MultiValueMap<String, String> parameters =
"GET".equals(request.getMethod()) ?
OAuth2EndpointUtils.getQueryParameters(request) :
OAuth2EndpointUtils.getFormParameters(request);
// id_token_hint (REQUIRED) // RECOMMENDED as per spec
String idTokenHint = request.getParameter("id_token_hint");
String idTokenHint = parameters.getFirst("id_token_hint");
if (!StringUtils.hasText(idTokenHint) ||
request.getParameterValues("id_token_hint").length != 1) {
parameters.get("id_token_hint").size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "id_token_hint");
}
@ -96,19 +96,6 @@ public final class OidcLogoutAuthenticationConverter implements AuthenticationCo @@ -96,19 +96,6 @@ public final class OidcLogoutAuthenticationConverter implements AuthenticationCo
sessionId, clientId, postLogoutRedirectUri, state);
}
private static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
parameterMap.forEach((key, values) -> {
if (values.length > 0) {
for (String value : values) {
parameters.add(key, value);
}
}
});
return parameters;
}
private static void throwError(String errorCode, String parameterName) {
OAuth2Error error = new OAuth2Error(
errorCode,

14
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java

@ -23,7 +23,6 @@ import org.springframework.lang.Nullable; @@ -23,7 +23,6 @@ import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
@ -48,7 +47,7 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica @@ -48,7 +47,7 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica
@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// client_id (REQUIRED)
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
@ -70,17 +69,6 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica @@ -70,17 +69,6 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
}
String queryString = request.getQueryString();
if (StringUtils.hasText(queryString) &&
(queryString.contains(OAuth2ParameterNames.CLIENT_ID) ||
queryString.contains(OAuth2ParameterNames.CLIENT_SECRET))) {
OAuth2Error error = new OAuth2Error(
OAuth2ErrorCodes.INVALID_REQUEST,
"Client credentials MUST NOT be included in the request URI.",
null);
throw new OAuth2AuthenticationException(error);
}
Map<String, Object> additionalParameters = OAuth2EndpointUtils.getParametersIfMatchesAuthorizationCodeGrantRequest(request,
OAuth2ParameterNames.CLIENT_ID,
OAuth2ParameterNames.CLIENT_SECRET);

10
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -48,13 +48,13 @@ public final class JwtClientAssertionAuthenticationConverter implements Authenti @@ -48,13 +48,13 @@ public final class JwtClientAssertionAuthenticationConverter implements Authenti
@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
if (request.getParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE) == null ||
request.getParameter(OAuth2ParameterNames.CLIENT_ASSERTION) == null) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
if (parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE) == null ||
parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION) == null) {
return null;
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
// client_assertion_type (REQUIRED)
String clientAssertionType = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE);
if (parameters.get(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE).size() != 1) {

6
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java

@ -47,16 +47,16 @@ public final class OAuth2AuthorizationCodeAuthenticationConverter implements Aut @@ -47,16 +47,16 @@ public final class OAuth2AuthorizationCodeAuthenticationConverter implements Aut
@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// grant_type (REQUIRED)
String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) {
return null;
}
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
// code (REQUIRED)
String code = parameters.getFirst(OAuth2ParameterNames.CODE);
if (!StringUtils.hasText(code) ||

7
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java

@ -66,10 +66,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme @@ -66,10 +66,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme
return null;
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
MultiValueMap<String, String> parameters =
"GET".equals(request.getMethod()) ?
OAuth2EndpointUtils.getQueryParameters(request) :
OAuth2EndpointUtils.getFormParameters(request);
// response_type (REQUIRED)
String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE);
String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE);
if (!StringUtils.hasText(responseType) ||
parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE);

6
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java

@ -54,13 +54,13 @@ public final class OAuth2AuthorizationConsentAuthenticationConverter implements @@ -54,13 +54,13 @@ public final class OAuth2AuthorizationConsentAuthenticationConverter implements
@Override
public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
if (!"POST".equals(request.getMethod()) ||
request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE) != null) {
parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE) != null) {
return null;
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
String authorizationUri = request.getRequestURL().toString();
// client_id (REQUIRED)

6
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java

@ -50,16 +50,16 @@ public final class OAuth2ClientCredentialsAuthenticationConverter implements Aut @@ -50,16 +50,16 @@ public final class OAuth2ClientCredentialsAuthenticationConverter implements Aut
@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// grant_type (REQUIRED)
String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(grantType)) {
return null;
}
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
// scope (OPTIONAL)
String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope) &&

2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationConsentAuthenticationConverter.java

@ -59,7 +59,7 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationConverter imple @@ -59,7 +59,7 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationConverter imple
return null;
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
String authorizationUri = request.getRequestURL().toString();

2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceAuthorizationRequestAuthenticationConverter.java

@ -53,7 +53,7 @@ public final class OAuth2DeviceAuthorizationRequestAuthenticationConverter imple @@ -53,7 +53,7 @@ public final class OAuth2DeviceAuthorizationRequestAuthenticationConverter imple
public Authentication convert(HttpServletRequest request) {
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
String authorizationUri = request.getRequestURL().toString();

6
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceCodeAuthenticationConverter.java

@ -49,16 +49,16 @@ public final class OAuth2DeviceCodeAuthenticationConverter implements Authentica @@ -49,16 +49,16 @@ public final class OAuth2DeviceCodeAuthenticationConverter implements Authentica
@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// grant_type (REQUIRED)
String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.DEVICE_CODE.getValue().equals(grantType)) {
return null;
}
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
// device_code (REQUIRED)
String deviceCode = parameters.getFirst(OAuth2ParameterNames.DEVICE_CODE);
if (!StringUtils.hasText(deviceCode) ||

5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverter.java

@ -59,7 +59,10 @@ public final class OAuth2DeviceVerificationAuthenticationConverter implements Au @@ -59,7 +59,10 @@ public final class OAuth2DeviceVerificationAuthenticationConverter implements Au
return null;
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
MultiValueMap<String, String> parameters =
"GET".equals(request.getMethod()) ?
OAuth2EndpointUtils.getQueryParameters(request) :
OAuth2EndpointUtils.getFormParameters(request);
// user_code (REQUIRED)
String userCode = parameters.getFirst(OAuth2ParameterNames.USER_CODE);

29
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java

@ -29,11 +29,13 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; @@ -29,11 +29,13 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* Utility methods for the OAuth 2.0 Protocol Endpoints.
*
* @author Joe Grandja
* @author Greg Li
* @since 0.1.2
*/
final class OAuth2EndpointUtils {
@ -42,11 +44,27 @@ final class OAuth2EndpointUtils { @@ -42,11 +44,27 @@ final class OAuth2EndpointUtils {
private OAuth2EndpointUtils() {
}
static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
static MultiValueMap<String, String> getFormParameters(HttpServletRequest request) {
Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameterMap.forEach((key, values) -> {
if (values.length > 0) {
String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : "";
// If not query parameter then it's a form parameter
if (!queryString.contains(key) && values.length > 0) {
for (String value : values) {
parameters.add(key, value);
}
}
});
return parameters;
}
static MultiValueMap<String, String> getQueryParameters(HttpServletRequest request) {
Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameterMap.forEach((key, values) -> {
String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : "";
if (queryString.contains(key) && values.length > 0) {
for (String value : values) {
parameters.add(key, value);
}
@ -59,7 +77,10 @@ final class OAuth2EndpointUtils { @@ -59,7 +77,10 @@ final class OAuth2EndpointUtils {
if (!matchesAuthorizationCodeGrantRequest(request)) {
return Collections.emptyMap();
}
MultiValueMap<String, String> multiValueParameters = getParameters(request);
MultiValueMap<String, String> multiValueParameters =
"GET".equals(request.getMethod()) ?
getQueryParameters(request) :
getFormParameters(request);
for (String exclusion : exclusions) {
multiValueParameters.remove(exclusion);
}

6
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java

@ -50,16 +50,16 @@ public final class OAuth2RefreshTokenAuthenticationConverter implements Authenti @@ -50,16 +50,16 @@ public final class OAuth2RefreshTokenAuthenticationConverter implements Authenti
@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// grant_type (REQUIRED)
String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) {
return null;
}
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
// refresh_token (REQUIRED)
String refreshToken = parameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN);
if (!StringUtils.hasText(refreshToken) ||

2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java

@ -49,7 +49,7 @@ public final class OAuth2TokenIntrospectionAuthenticationConverter implements Au @@ -49,7 +49,7 @@ public final class OAuth2TokenIntrospectionAuthenticationConverter implements Au
public Authentication convert(HttpServletRequest request) {
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// token (REQUIRED)
String token = parameters.getFirst(OAuth2ParameterNames.TOKEN);

4
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -46,7 +46,7 @@ public final class OAuth2TokenRevocationAuthenticationConverter implements Authe @@ -46,7 +46,7 @@ public final class OAuth2TokenRevocationAuthenticationConverter implements Authe
public Authentication convert(HttpServletRequest request) {
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
// token (REQUIRED)
String token = parameters.getFirst(OAuth2ParameterNames.TOKEN);

5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java

@ -53,7 +53,10 @@ public final class PublicClientAuthenticationConverter implements Authentication @@ -53,7 +53,10 @@ public final class PublicClientAuthenticationConverter implements Authentication
return null;
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
MultiValueMap<String, String> parameters =
"GET".equals(request.getMethod()) ?
OAuth2EndpointUtils.getQueryParameters(request) :
OAuth2EndpointUtils.getFormParameters(request);
// client_id (REQUIRED for public clients)
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);

29
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java

@ -158,6 +158,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @@ -158,6 +158,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
* @author Daniel Garnier-Moiroux
* @author Dmitriy Dubson
* @author Steve Riesenberg
* @author Greg Li
*/
@ExtendWith(SpringTestContextExtension.class)
public class OAuth2AuthorizationCodeGrantTests {
@ -260,7 +261,7 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -260,7 +261,7 @@ public class OAuth2AuthorizationCodeGrantTests {
this.registeredClientRepository.save(registeredClient);
this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient)))
.queryParams(getAuthorizationRequestParameters(registeredClient)))
.andExpect(status().isUnauthorized())
.andReturn();
}
@ -302,7 +303,7 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -302,7 +303,7 @@ public class OAuth2AuthorizationCodeGrantTests {
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(authorizationEndpointUri)
.params(authorizationRequestParameters)
.queryParams(authorizationRequestParameters)
.with(user("user")))
.andExpect(status().is3xxRedirection())
.andReturn();
@ -394,9 +395,9 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -394,9 +395,9 @@ public class OAuth2AuthorizationCodeGrantTests {
this.registeredClientRepository.save(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient))
.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
.queryParams(getAuthorizationRequestParameters(registeredClient))
.queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
.queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
.with(user("user")))
.andExpect(status().is3xxRedirection())
.andReturn();
@ -487,9 +488,9 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -487,9 +488,9 @@ public class OAuth2AuthorizationCodeGrantTests {
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(authorizationRequestParameters)
.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
.queryParams(authorizationRequestParameters)
.queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
.queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
.with(user("user")))
.andExpect(status().is3xxRedirection())
.andReturn();
@ -526,7 +527,7 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -526,7 +527,7 @@ public class OAuth2AuthorizationCodeGrantTests {
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(authorizationRequestParameters)
.queryParams(authorizationRequestParameters)
.with(user("user")))
.andExpect(status().is3xxRedirection())
.andReturn();
@ -572,7 +573,7 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -572,7 +573,7 @@ public class OAuth2AuthorizationCodeGrantTests {
this.registeredClientRepository.save(registeredClient);
String consentPage = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient))
.queryParams(getAuthorizationRequestParameters(registeredClient))
.with(user("user")))
.andExpect(status().is2xxSuccessful())
.andReturn()
@ -655,7 +656,7 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -655,7 +656,7 @@ public class OAuth2AuthorizationCodeGrantTests {
this.registeredClientRepository.save(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient))
.queryParams(getAuthorizationRequestParameters(registeredClient))
.with(user("user")))
.andExpect(status().is3xxRedirection())
.andReturn();
@ -790,9 +791,9 @@ public class OAuth2AuthorizationCodeGrantTests { @@ -790,9 +791,9 @@ public class OAuth2AuthorizationCodeGrantTests {
this.registeredClientRepository.save(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient))
.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
.queryParams(getAuthorizationRequestParameters(registeredClient))
.queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
.queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
.with(user("user")))
.andExpect(status().is3xxRedirection())
.andReturn();

33
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java

@ -61,7 +61,6 @@ import org.springframework.security.crypto.password.PasswordEncoder; @@ -61,7 +61,6 @@ import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService;
@ -102,7 +101,6 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand @@ -102,7 +101,6 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.web.util.UriComponentsBuilder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
@ -236,37 +234,6 @@ public class OAuth2ClientCredentialsGrantTests { @@ -236,37 +234,6 @@ public class OAuth2ClientCredentialsGrantTests {
verify(jwtCustomizer).customize(any());
}
// gh-1378
@Test
public void requestWhenTokenRequestWithClientCredentialsInQueryParamThenInvalidRequest() throws Exception {
this.spring.register(AuthorizationServerConfiguration.class).autowire();
RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
this.registeredClientRepository.save(registeredClient);
String tokenEndpointUri = UriComponentsBuilder.fromUriString(DEFAULT_TOKEN_ENDPOINT_URI)
.queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.toUriString();
this.mvc.perform(post(tokenEndpointUri)
.param(OAuth2ParameterNames.CLIENT_SECRET, registeredClient.getClientSecret())
.param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
.param(OAuth2ParameterNames.SCOPE, "scope1 scope2"))
.andExpect(status().isBadRequest())
.andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_REQUEST));
tokenEndpointUri = UriComponentsBuilder.fromUriString(DEFAULT_TOKEN_ENDPOINT_URI)
.queryParam(OAuth2ParameterNames.CLIENT_SECRET, registeredClient.getClientSecret())
.toUriString();
this.mvc.perform(post(tokenEndpointUri)
.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
.param(OAuth2ParameterNames.SCOPE, "scope1 scope2"))
.andExpect(status().isBadRequest())
.andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_REQUEST));
}
@Test
public void requestWhenTokenRequestPostsClientCredentialsAndRequiresUpgradingThenClientSecretUpgraded() throws Exception {
this.spring.register(AuthorizationServerConfigurationCustomPasswordEncoder.class).autowire();

4
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2DeviceCodeGrantTests.java

@ -279,7 +279,7 @@ public class OAuth2DeviceCodeGrantTests { @@ -279,7 +279,7 @@ public class OAuth2DeviceCodeGrantTests {
// @formatter:off
this.mvc.perform(get(DEFAULT_DEVICE_VERIFICATION_ENDPOINT_URI)
.params(parameters))
.queryParams(parameters))
.andExpect(status().isUnauthorized());
// @formatter:on
}
@ -313,7 +313,7 @@ public class OAuth2DeviceCodeGrantTests { @@ -313,7 +313,7 @@ public class OAuth2DeviceCodeGrantTests {
// @formatter:off
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_DEVICE_VERIFICATION_ENDPOINT_URI)
.params(parameters)
.queryParams(parameters)
.with(user("user")))
.andExpect(status().isOk())
.andExpect(content().contentTypeCompatibleWith(MediaType.TEXT_HTML))

12
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java

@ -193,8 +193,8 @@ public class OidcTests { @@ -193,8 +193,8 @@ public class OidcTests {
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(authorizationRequestParameters)
.with(user("user").roles("A", "B")))
.queryParams(authorizationRequestParameters)
.with(user("user").roles("A", "B")))
.andExpect(status().is3xxRedirection())
.andReturn();
String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
@ -249,7 +249,7 @@ public class OidcTests { @@ -249,7 +249,7 @@ public class OidcTests {
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(authorizationRequestParameters)
.queryParams(authorizationRequestParameters)
.with(user("user").roles("A", "B")))
.andExpect(status().is3xxRedirection())
.andReturn();
@ -306,7 +306,7 @@ public class OidcTests { @@ -306,7 +306,7 @@ public class OidcTests {
// Login
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(authorizationRequestParameters)
.queryParams(authorizationRequestParameters)
.with(user("user")))
.andExpect(status().is3xxRedirection())
.andReturn();
@ -355,7 +355,7 @@ public class OidcTests { @@ -355,7 +355,7 @@ public class OidcTests {
MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient1);
MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(authorizationRequestParameters)
.queryParams(authorizationRequestParameters)
.with(user("user1")))
.andExpect(status().is3xxRedirection())
.andReturn();
@ -387,7 +387,7 @@ public class OidcTests { @@ -387,7 +387,7 @@ public class OidcTests {
authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient2);
mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(authorizationRequestParameters)
.queryParams(authorizationRequestParameters)
.with(user("user2")))
.andExpect(status().is3xxRedirection())
.andReturn();

22
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -61,6 +61,7 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut @@ -61,6 +61,7 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.web.util.UriComponentsBuilder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@ -327,6 +328,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -327,6 +328,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -342,6 +344,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -342,6 +344,7 @@ public class OidcClientRegistrationEndpointFilterTests {
request.setServletPath(requestUri);
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-id");
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-id2");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -388,6 +391,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -388,6 +391,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -421,6 +425,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -421,6 +425,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId());
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -463,6 +468,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -463,6 +468,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client-id");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -492,6 +498,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -492,6 +498,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId());
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -513,6 +520,7 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -513,6 +520,7 @@ public class OidcClientRegistrationEndpointFilterTests {
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -522,6 +530,18 @@ public class OidcClientRegistrationEndpointFilterTests { @@ -522,6 +530,18 @@ public class OidcClientRegistrationEndpointFilterTests {
any(OAuth2AuthenticationException.class));
}
private static void updateQueryString(MockHttpServletRequest request) {
UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI());
request.getParameterMap().forEach((key, values) -> {
if (values.length > 0) {
for (String value : values) {
uriBuilder.queryParam(key, value);
}
}
});
request.setQueryString(uriBuilder.build().getQuery());
}
private OAuth2Error readError(MockHttpServletResponse response) throws Exception {
MockClientHttpResponse httpResponse = new MockClientHttpResponse(
response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));

59
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@ -60,6 +60,7 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand @@ -60,6 +60,7 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand
import org.springframework.security.web.authentication.WebAuthenticationDetails;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -79,6 +80,7 @@ import static org.mockito.Mockito.when; @@ -79,6 +80,7 @@ import static org.mockito.Mockito.when;
* @author Daniel Garnier-Moiroux
* @author Anoop Garlapati
* @author Dmitriy Dubson
* @author Greg Li
* @since 0.0.1
*/
public class OAuth2AuthorizationEndpointFilterTests {
@ -178,7 +180,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -178,7 +180,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.RESPONSE_TYPE,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE));
request -> {
request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE);
updateQueryString(request);
});
}
@Test
@ -187,7 +192,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -187,7 +192,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.RESPONSE_TYPE,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"));
request -> {
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
updateQueryString(request);
});
}
@Test
@ -196,7 +204,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -196,7 +204,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.RESPONSE_TYPE,
OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE,
request -> request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"));
request -> {
request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
updateQueryString(request);
});
}
@Test
@ -205,7 +216,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -205,7 +216,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.CLIENT_ID,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID));
request -> {
request.removeParameter(OAuth2ParameterNames.CLIENT_ID);
updateQueryString(request);
});
}
@Test
@ -214,7 +228,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -214,7 +228,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.CLIENT_ID,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"));
request -> {
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2");
updateQueryString(request);
});
}
@Test
@ -223,7 +240,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -223,7 +240,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.REDIRECT_URI,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"));
request -> {
request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
updateQueryString(request);
});
}
@Test
@ -232,7 +252,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -232,7 +252,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.SCOPE,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.SCOPE, "scope2"));
request -> {
request.addParameter(OAuth2ParameterNames.SCOPE, "scope2");
updateQueryString(request);
});
}
@Test
@ -241,7 +264,10 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -241,7 +264,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
TestRegisteredClients.registeredClient().build(),
OAuth2ParameterNames.STATE,
OAuth2ErrorCodes.INVALID_REQUEST,
request -> request.addParameter(OAuth2ParameterNames.STATE, "state2"));
request -> {
request.addParameter(OAuth2ParameterNames.STATE, "state2");
updateQueryString(request);
});
}
@Test
@ -271,6 +297,7 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -271,6 +297,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
request -> {
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge");
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge");
updateQueryString(request);
});
}
@ -283,6 +310,7 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -283,6 +310,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
request -> {
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
updateQueryString(request);
});
}
@ -590,6 +618,7 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -590,6 +618,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.addParameter("custom-param", "custom-value-1", "custom-value-2");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -635,6 +664,7 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -635,6 +664,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
request.setMethod("POST"); // OpenID Connect supports POST method
request.setQueryString(null);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -689,6 +719,7 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -689,6 +719,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
request.addParameter(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
request.addParameter(OAuth2ParameterNames.STATE, "state");
updateQueryString(request);
return request;
}
@ -706,6 +737,18 @@ public class OAuth2AuthorizationEndpointFilterTests { @@ -706,6 +737,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
return request;
}
private static void updateQueryString(MockHttpServletRequest request) {
UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI());
request.getParameterMap().forEach((key, values) -> {
if (values.length > 0) {
for (String value : values) {
uriBuilder.queryParam(key, value);
}
}
});
request.setQueryString(uriBuilder.build().getQuery());
}
private static String scopeCheckbox(String scope) {
return MessageFormat.format(
"<input class=\"form-check-input\" type=\"checkbox\" name=\"scope\" value=\"{0}\" id=\"{0}\">",

22
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceVerificationEndpointFilterTests.java

@ -23,6 +23,7 @@ import java.util.Set; @@ -23,6 +23,7 @@ import java.util.Set;
import jakarta.servlet.FilterChain;
import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -164,6 +165,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { @@ -164,6 +165,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE);
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
@ -223,6 +225,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { @@ -223,6 +225,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE);
request.addParameter("custom-param-1", "custom-value-1");
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
@ -248,6 +251,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { @@ -248,6 +251,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE);
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
@ -268,6 +272,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { @@ -268,6 +272,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE);
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
@ -291,6 +296,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { @@ -291,6 +296,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests {
request.setServerPort(443);
request.setServerName("provider.com");
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE);
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.setConsentPage("/consent");
@ -322,6 +328,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { @@ -322,6 +328,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE);
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
@ -340,6 +347,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { @@ -340,6 +347,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE);
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
@ -367,6 +375,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { @@ -367,6 +375,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE);
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
@ -388,6 +397,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests { @@ -388,6 +397,7 @@ public class OAuth2DeviceVerificationEndpointFilterTests {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE);
updateQueryString(request);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
@ -445,6 +455,18 @@ public class OAuth2DeviceVerificationEndpointFilterTests { @@ -445,6 +455,18 @@ public class OAuth2DeviceVerificationEndpointFilterTests {
return request;
}
private static void updateQueryString(MockHttpServletRequest request) {
UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI());
request.getParameterMap().forEach((key, values) -> {
if (values.length > 0) {
for (String value : values) {
uriBuilder.queryParam(key, value);
}
}
});
request.setQueryString(uriBuilder.build().getQuery());
}
private static String scopeCheckbox(String scope) {
return MessageFormat.format(
"<input class=\"form-check-input\" type=\"checkbox\" name=\"scope\" value=\"{0}\" id=\"{0}\">",

25
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java

@ -79,31 +79,6 @@ public class ClientSecretPostAuthenticationConverterTests { @@ -79,31 +79,6 @@ public class ClientSecretPostAuthenticationConverterTests {
.isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
}
// gh-1378
@Test
public void convertWhenClientCredentialsInQueryParamThenInvalidRequestError() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1");
request.addParameter(OAuth2ParameterNames.CLIENT_SECRET, "client-secret");
request.setQueryString("client_id=client-1");
assertThatThrownBy(() -> this.converter.convert(request))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
assertThat(error.getDescription()).isEqualTo("Client credentials MUST NOT be included in the request URI.");
});
request.setQueryString("client_secret=client-secret");
assertThatThrownBy(() -> this.converter.convert(request))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
assertThat(error.getDescription()).isEqualTo("Client credentials MUST NOT be included in the request URI.");
});
}
@Test
public void convertWhenPostWithValidCredentialsThenReturnClientAuthenticationToken() {
MockHttpServletRequest request = new MockHttpServletRequest();

21
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2DeviceVerificationAuthenticationConverterTests.java

@ -31,6 +31,7 @@ import org.springframework.security.oauth2.core.OAuth2Error; @@ -31,6 +31,7 @@ import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceVerificationAuthenticationToken;
import org.springframework.web.util.UriComponentsBuilder;
import static java.util.Map.entry;
import static org.assertj.core.api.Assertions.assertThat;
@ -69,6 +70,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { @@ -69,6 +70,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests {
public void convertWhenStateThenReturnNull() {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.STATE, "abc123");
updateQueryString(request);
Authentication authentication = this.converter.convert(request);
assertThat(authentication).isNull();
}
@ -84,6 +86,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { @@ -84,6 +86,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests {
public void convertWhenEmptyUserCodeParameterThenInvalidRequestError() {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, "");
updateQueryString(request);
// @formatter:off
assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.converter.convert(request))
@ -98,6 +101,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { @@ -98,6 +101,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests {
public void convertWhenInvalidUserCodeParameterThenInvalidRequestError() {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, "LONG-USER-CODE");
updateQueryString(request);
// @formatter:off
assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.converter.convert(request))
@ -113,6 +117,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { @@ -113,6 +117,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE);
request.addParameter(OAuth2ParameterNames.USER_CODE, "another");
updateQueryString(request);
// @formatter:off
assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.converter.convert(request))
@ -127,6 +132,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { @@ -127,6 +132,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests {
public void convertWhenMissingPrincipalThenReturnDeviceVerificationAuthentication() {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE.toLowerCase().replace("-", " . "));
updateQueryString(request);
OAuth2DeviceVerificationAuthenticationToken authentication =
(OAuth2DeviceVerificationAuthenticationToken) this.converter.convert(request);
@ -140,6 +146,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { @@ -140,6 +146,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests {
public void convertWhenNonNormalizedUserCodeThenReturnDeviceVerificationAuthentication() {
MockHttpServletRequest request = createRequest();
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE.toLowerCase().replace("-", " . "));
updateQueryString(request);
SecurityContextImpl securityContext = new SecurityContextImpl();
securityContext.setAuthentication(new TestingAuthenticationToken("user", null));
@ -159,6 +166,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { @@ -159,6 +166,7 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests {
request.addParameter(OAuth2ParameterNames.USER_CODE, USER_CODE);
request.addParameter("param-1", "value-1");
request.addParameter("param-2", "value-1", "value-2");
updateQueryString(request);
SecurityContextImpl securityContext = new SecurityContextImpl();
securityContext.setAuthentication(new TestingAuthenticationToken("user", null));
@ -180,4 +188,17 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests { @@ -180,4 +188,17 @@ public class OAuth2DeviceVerificationAuthenticationConverterTests {
request.setRequestURI(VERIFICATION_URI);
return request;
}
private static void updateQueryString(MockHttpServletRequest request) {
UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI());
request.getParameterMap().forEach((key, values) -> {
if (values.length > 0) {
for (String value : values) {
uriBuilder.queryParam(key, value);
}
}
});
request.setQueryString(uriBuilder.build().getQuery());
}
}

Loading…
Cancel
Save