diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java index e9ba1f32..d72dc5c0 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2024 the original author or authors. + * Copyright 2020-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; @@ -53,6 +54,7 @@ import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; @@ -90,6 +92,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter private final AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); + private final BasicAuthenticationEntryPoint basicAuthenticationEntryPoint = new BasicAuthenticationEntryPoint(); + private AuthenticationConverter authenticationConverter; private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess; @@ -110,6 +114,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter Assert.notNull(requestMatcher, "requestMatcher cannot be null"); this.authenticationManager = authenticationManager; this.requestMatcher = requestMatcher; + this.basicAuthenticationEntryPoint.setRealmName("default"); // @formatter:off this.authenticationConverter = new DelegatingAuthenticationConverter( Arrays.asList( @@ -130,8 +135,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter return; } + Authentication authenticationRequest = null; try { - Authentication authenticationRequest = this.authenticationConverter.convert(request); + authenticationRequest = this.authenticationConverter.convert(request); if (authenticationRequest instanceof AbstractAuthenticationToken) { ((AbstractAuthenticationToken) authenticationRequest) .setDetails(this.authenticationDetailsSource.buildDetails(request)); @@ -148,7 +154,14 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter if (this.logger.isTraceEnabled()) { this.logger.trace(LogMessage.format("Client authentication failed: %s", ex.getError()), ex); } - this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex); + if (authenticationRequest instanceof OAuth2ClientAuthenticationToken clientAuthentication) { + this.authenticationFailureHandler.onAuthenticationFailure(request, response, + new OAuth2ClientAuthenticationException(ex.getError(), ex, clientAuthentication)); + } + else { + this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex); + } + } } @@ -200,21 +213,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter } private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, - AuthenticationException exception) throws IOException { + AuthenticationException authenticationException) throws IOException { SecurityContextHolder.clearContext(); - // TODO - // The authorization server MAY return an HTTP 401 (Unauthorized) status code - // to indicate which HTTP authentication schemes are supported. - // If the client attempted to authenticate via the "Authorization" request header - // field, - // the authorization server MUST respond with an HTTP 401 (Unauthorized) status - // code and - // include the "WWW-Authenticate" response header field - // matching the authentication scheme used by the client. - - OAuth2Error error = ((OAuth2AuthenticationException) exception).getError(); + if (authenticationException instanceof OAuth2ClientAuthenticationException clientAuthenticationException) { + OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationException + .getClientAuthentication(); + if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC + .equals(clientAuthentication.getClientAuthenticationMethod())) { + this.basicAuthenticationEntryPoint.commence(request, response, authenticationException); + return; + } + } + + OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError(); ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) { httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED); @@ -249,4 +262,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter } } + private static final class OAuth2ClientAuthenticationException extends OAuth2AuthenticationException { + + private final OAuth2ClientAuthenticationToken clientAuthentication; + + private OAuth2ClientAuthenticationException(OAuth2Error error, Throwable cause, + OAuth2ClientAuthenticationToken clientAuthentication) { + super(error, cause); + Assert.notNull(clientAuthentication, "clientAuthentication cannot be null"); + this.clientAuthentication = clientAuthentication; + } + + private OAuth2ClientAuthenticationToken getClientAuthentication() { + return this.clientAuthentication; + } + + } + } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java index d794baf5..edcdb84b 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2024 the original author or authors. + * Copyright 2020-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -538,7 +538,7 @@ public class OAuth2AuthorizationCodeGrantTests { } @Test - public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenBadRequest() throws Exception { + public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenUnauthorized() throws Exception { this.spring.register(AuthorizationServerConfiguration.class).autowire(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); @@ -569,7 +569,7 @@ public class OAuth2AuthorizationCodeGrantTests { .params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization)) .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) .header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient))) - .andExpect(status().isBadRequest()); + .andExpect(status().isUnauthorized()); } // gh-1011 @@ -601,7 +601,7 @@ public class OAuth2AuthorizationCodeGrantTests { } @Test - public void requestWhenConfidentialClientWithPkceAndMissingCodeChallengeButCodeVerifierProvidedThenBadRequest() + public void requestWhenConfidentialClientWithPkceAndMissingCodeChallengeButCodeVerifierProvidedThenUnauthorized() throws Exception { this.spring.register(AuthorizationServerConfiguration.class).autowire(); @@ -631,7 +631,7 @@ public class OAuth2AuthorizationCodeGrantTests { .params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization)) .param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER) .header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient))) - .andExpect(status().isBadRequest()); + .andExpect(status().isUnauthorized()); } @Test diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java index 97dc1750..f2ebe115 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; @@ -175,26 +176,25 @@ public class OAuth2ClientAuthenticationFilterTests { // gh-889 @Test - public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenInvalidRequestError() - throws Exception { + public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenReturnChallenge() throws Exception { // Hex 00 -> null String clientId = new String(Hex.decode("00"), StandardCharsets.UTF_8); - assertWhenInvalidClientIdThenInvalidRequestError(clientId); + assertWhenInvalidClientIdThenReturnChallenge(clientId); // Hex 0a61 -> line feed + a clientId = new String(Hex.decode("0a61"), StandardCharsets.UTF_8); - assertWhenInvalidClientIdThenInvalidRequestError(clientId); + assertWhenInvalidClientIdThenReturnChallenge(clientId); // Hex 1b -> escape clientId = new String(Hex.decode("1b"), StandardCharsets.UTF_8); - assertWhenInvalidClientIdThenInvalidRequestError(clientId); + assertWhenInvalidClientIdThenReturnChallenge(clientId); // Hex 1b61 -> escape + a clientId = new String(Hex.decode("1b61"), StandardCharsets.UTF_8); - assertWhenInvalidClientIdThenInvalidRequestError(clientId); + assertWhenInvalidClientIdThenReturnChallenge(clientId); } - private void assertWhenInvalidClientIdThenInvalidRequestError(String clientId) throws Exception { + private void assertWhenInvalidClientIdThenReturnChallenge(String clientId) throws Exception { given(this.authenticationConverter.convert(any(HttpServletRequest.class))) .willReturn(new OAuth2ClientAuthenticationToken(clientId, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, "secret", null)); @@ -210,13 +210,12 @@ public class OAuth2ClientAuthenticationFilterTests { verifyNoInteractions(this.authenticationManager); assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); - assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - OAuth2Error error = readError(response); - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value()); + assertThat(response.getHeader(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Basic realm=\"default\""); } @Test - public void doFilterWhenRequestMatchesAndBadCredentialsThenInvalidClientError() throws Exception { + public void doFilterWhenRequestMatchesAndBadCredentialsThenReturnChallenge() throws Exception { given(this.authenticationConverter.convert(any(HttpServletRequest.class))) .willReturn(new OAuth2ClientAuthenticationToken("clientId", ClientAuthenticationMethod.CLIENT_SECRET_BASIC, "invalid-secret", null)); @@ -235,8 +234,7 @@ public class OAuth2ClientAuthenticationFilterTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value()); - OAuth2Error error = readError(response); - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT); + assertThat(response.getHeader(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Basic realm=\"default\""); } @Test