@ -17,11 +17,14 @@
@@ -17,11 +17,14 @@
package org.springframework.security.config.annotation.web.configurers.oauth2.server.resource ;
import java.util.Collections ;
import java.util.LinkedHashMap ;
import java.util.List ;
import java.util.Map ;
import java.util.regex.Matcher ;
import java.util.regex.Pattern ;
import jakarta.servlet.http.HttpServletRequest ;
import jakarta.servlet.http.HttpServletResponse ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.HttpStatus ;
@ -29,18 +32,21 @@ import org.springframework.security.authentication.AuthenticationManager;
@@ -29,18 +32,21 @@ import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.config.annotation.web.HttpSecurityBuilder ;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.core.AuthenticationException ;
import org.springframework.security.oauth2.core.OAuth2AccessToken ;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException ;
import org.springframework.security.oauth2.core.OAuth2Error ;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes ;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames ;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms ;
import org.springframework.security.oauth2.server.resource.authentication.DPoPAuthenticationProvider ;
import org.springframework.security.oauth2.server.resource.authentication.DPoPAuthenticationToken ;
import org.springframework.security.web.AuthenticationEntryPoint ;
import org.springframework.security.web.authentication.AuthenticationConverter ;
import org.springframework.security.web.authentication.AuthenticationEntryPointFailureHandler ;
import org.springframework.security.web.authentication.AuthenticationFailureHandler ;
import org.springframework.security.web.authentication.AuthenticationFilter ;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler ;
import org.springframework.security.web.authentication.HttpStatusEntryPoint ;
import org.springframework.security.web.context.RequestAttributeSecurityContextRepository ;
import org.springframework.security.web.util.matcher.RequestMatcher ;
import org.springframework.util.CollectionUtils ;
@ -102,7 +108,7 @@ final class DPoPAuthenticationConfigurer<B extends HttpSecurityBuilder<B>>
@@ -102,7 +108,7 @@ final class DPoPAuthenticationConfigurer<B extends HttpSecurityBuilder<B>>
private AuthenticationFailureHandler getAuthenticationFailureHandler ( ) {
if ( this . authenticationFailureHandler = = null ) {
this . authenticationFailureHandler = new AuthenticationEntryPointFailureHandler (
new HttpStatusEntryPoint ( HttpStatus . UNAUTHORIZED ) ) ;
new DPoPAuthenticationEntryPoint ( ) ) ;
}
return this . authenticationFailureHandler ;
}
@ -161,4 +167,47 @@ final class DPoPAuthenticationConfigurer<B extends HttpSecurityBuilder<B>>
@@ -161,4 +167,47 @@ final class DPoPAuthenticationConfigurer<B extends HttpSecurityBuilder<B>>
}
private static final class DPoPAuthenticationEntryPoint implements AuthenticationEntryPoint {
@Override
public void commence ( HttpServletRequest request , HttpServletResponse response ,
AuthenticationException authenticationException ) {
Map < String , String > parameters = new LinkedHashMap < > ( ) ;
if ( authenticationException instanceof OAuth2AuthenticationException oauth2AuthenticationException ) {
OAuth2Error error = oauth2AuthenticationException . getError ( ) ;
parameters . put ( OAuth2ParameterNames . ERROR , error . getErrorCode ( ) ) ;
if ( StringUtils . hasText ( error . getDescription ( ) ) ) {
parameters . put ( OAuth2ParameterNames . ERROR_DESCRIPTION , error . getDescription ( ) ) ;
}
if ( StringUtils . hasText ( error . getUri ( ) ) ) {
parameters . put ( OAuth2ParameterNames . ERROR_URI , error . getUri ( ) ) ;
}
}
parameters . put ( "algs" ,
JwsAlgorithms . RS256 + " " + JwsAlgorithms . RS384 + " " + JwsAlgorithms . RS512 + " "
+ JwsAlgorithms . PS256 + " " + JwsAlgorithms . PS384 + " " + JwsAlgorithms . PS512 + " "
+ JwsAlgorithms . ES256 + " " + JwsAlgorithms . ES384 + " " + JwsAlgorithms . ES512 ) ;
String wwwAuthenticate = toWWWAuthenticateHeader ( parameters ) ;
response . addHeader ( HttpHeaders . WWW_AUTHENTICATE , wwwAuthenticate ) ;
response . setStatus ( HttpStatus . UNAUTHORIZED . value ( ) ) ;
}
private static String toWWWAuthenticateHeader ( Map < String , String > parameters ) {
StringBuilder wwwAuthenticate = new StringBuilder ( ) ;
wwwAuthenticate . append ( OAuth2AccessToken . TokenType . DPOP . getValue ( ) ) ;
if ( ! parameters . isEmpty ( ) ) {
wwwAuthenticate . append ( " " ) ;
int i = 0 ;
for ( Map . Entry < String , String > entry : parameters . entrySet ( ) ) {
wwwAuthenticate . append ( entry . getKey ( ) ) . append ( "=\"" ) . append ( entry . getValue ( ) ) . append ( "\"" ) ;
if ( i + + ! = parameters . size ( ) - 1 ) {
wwwAuthenticate . append ( ", " ) ;
}
}
}
return wwwAuthenticate . toString ( ) ;
}
}
}