@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2020 - 2024 the original author or authors .
* Copyright 2020 - 2025 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -34,6 +34,7 @@ import org.springframework.security.core.Authentication;
@@ -34,6 +34,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException ;
import org.springframework.security.core.context.SecurityContext ;
import org.springframework.security.core.context.SecurityContextHolder ;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod ;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException ;
import org.springframework.security.oauth2.core.OAuth2Error ;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes ;
@ -53,6 +54,7 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
@@ -53,6 +54,7 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler ;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler ;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource ;
import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint ;
import org.springframework.security.web.util.matcher.RequestMatcher ;
import org.springframework.util.Assert ;
import org.springframework.web.filter.OncePerRequestFilter ;
@ -90,6 +92,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
@@ -90,6 +92,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
private final AuthenticationDetailsSource < HttpServletRequest , ? > authenticationDetailsSource = new WebAuthenticationDetailsSource ( ) ;
private final BasicAuthenticationEntryPoint basicAuthenticationEntryPoint = new BasicAuthenticationEntryPoint ( ) ;
private AuthenticationConverter authenticationConverter ;
private AuthenticationSuccessHandler authenticationSuccessHandler = this : : onAuthenticationSuccess ;
@ -110,6 +114,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
@@ -110,6 +114,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
Assert . notNull ( requestMatcher , "requestMatcher cannot be null" ) ;
this . authenticationManager = authenticationManager ;
this . requestMatcher = requestMatcher ;
this . basicAuthenticationEntryPoint . setRealmName ( "default" ) ;
// @formatter:off
this . authenticationConverter = new DelegatingAuthenticationConverter (
Arrays . asList (
@ -130,8 +135,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
@@ -130,8 +135,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
return ;
}
Authentication authenticationRequest = null ;
try {
Authentication authenticationRequest = this . authenticationConverter . convert ( request ) ;
authenticationRequest = this . authenticationConverter . convert ( request ) ;
if ( authenticationRequest instanceof AbstractAuthenticationToken ) {
( ( AbstractAuthenticationToken ) authenticationRequest )
. setDetails ( this . authenticationDetailsSource . buildDetails ( request ) ) ;
@ -148,7 +154,14 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
@@ -148,7 +154,14 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
if ( this . logger . isTraceEnabled ( ) ) {
this . logger . trace ( LogMessage . format ( "Client authentication failed: %s" , ex . getError ( ) ) , ex ) ;
}
this . authenticationFailureHandler . onAuthenticationFailure ( request , response , ex ) ;
if ( authenticationRequest instanceof OAuth2ClientAuthenticationToken clientAuthentication ) {
this . authenticationFailureHandler . onAuthenticationFailure ( request , response ,
new OAuth2ClientAuthenticationException ( ex . getError ( ) , ex , clientAuthentication ) ) ;
}
else {
this . authenticationFailureHandler . onAuthenticationFailure ( request , response , ex ) ;
}
}
}
@ -200,21 +213,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
@@ -200,21 +213,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
}
private void onAuthenticationFailure ( HttpServletRequest request , HttpServletResponse response ,
AuthenticationException exception ) throws IOException {
AuthenticationException auth enticationE xception) throws IOException {
SecurityContextHolder . clearContext ( ) ;
// TODO
// The authorization server MAY return an HTTP 401 (Unauthorized) status code
// to indicate which HTTP authentication schemes are supported.
// If the client attempted to authenticate via the "Authorization" request header
// field,
// the authorization server MUST respond with an HTTP 401 (Unauthorized) status
// code and
// include the "WWW-Authenticate" response header field
// matching the authentication scheme used by the client.
OAuth2Error error = ( ( OAuth2AuthenticationException ) exception ) . getError ( ) ;
if ( authenticationException instanceof OAuth2ClientAuthenticationException clientAuthenticationException ) {
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationException
. getClientAuthentication ( ) ;
if ( ClientAuthenticationMethod . CLIENT_SECRET_BASIC
. equals ( clientAuthentication . getClientAuthenticationMethod ( ) ) ) {
this . basicAuthenticationEntryPoint . commence ( request , response , authenticationException ) ;
return ;
}
}
OAuth2Error error = ( ( OAuth2AuthenticationException ) auth enticationE xception) . getError ( ) ;
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse ( response ) ;
if ( OAuth2ErrorCodes . INVALID_CLIENT . equals ( error . getErrorCode ( ) ) ) {
httpResponse . setStatusCode ( HttpStatus . UNAUTHORIZED ) ;
@ -249,4 +262,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
@@ -249,4 +262,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
}
}
private static final class OAuth2ClientAuthenticationException extends OAuth2AuthenticationException {
private final OAuth2ClientAuthenticationToken clientAuthentication ;
private OAuth2ClientAuthenticationException ( OAuth2Error error , Throwable cause ,
OAuth2ClientAuthenticationToken clientAuthentication ) {
super ( error , cause ) ;
Assert . notNull ( clientAuthentication , "clientAuthentication cannot be null" ) ;
this . clientAuthentication = clientAuthentication ;
}
private OAuth2ClientAuthenticationToken getClientAuthentication ( ) {
return this . clientAuthentication ;
}
}
}