@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.authorization.web;
@@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.authorization.web;
import org.springframework.http.HttpMethod ;
import org.springframework.http.HttpStatus ;
import org.springframework.http.MediaType ;
import org.springframework.security.authentication.AnonymousAuthenticationToken ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.core.context.SecurityContextHolder ;
@ -32,6 +33,7 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
@@ -32,6 +33,7 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization ;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames ;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService ;
import org.springframework.security.oauth2.server.authorization.TokenType ;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient ;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository ;
import org.springframework.security.web.DefaultRedirectStrategy ;
@ -39,6 +41,7 @@ import org.springframework.security.web.RedirectStrategy;
@@ -39,6 +41,7 @@ import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher ;
import org.springframework.security.web.util.matcher.RequestMatcher ;
import org.springframework.util.Assert ;
import org.springframework.util.CollectionUtils ;
import org.springframework.util.MultiValueMap ;
import org.springframework.util.StringUtils ;
import org.springframework.web.filter.OncePerRequestFilter ;
@ -49,10 +52,12 @@ import javax.servlet.ServletException;
@@ -49,10 +52,12 @@ import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import java.io.IOException ;
import java.nio.charset.StandardCharsets ;
import java.util.Arrays ;
import java.util.Base64 ;
import java.util.Collections ;
import java.util.HashSet ;
import java.util.List ;
import java.util.Set ;
/ * *
@ -68,6 +73,7 @@ import java.util.Set;
@@ -68,6 +73,7 @@ import java.util.Set;
* @see OAuth2Authorization
* @see < a target = "_blank" href = "https://tools.ietf.org/html/rfc6749#section-4.1" > Section 4 . 1 Authorization Code Grant < / a >
* @see < a target = "_blank" href = "https://tools.ietf.org/html/rfc6749#section-4.1.1" > Section 4 . 1 . 1 Authorization Request < / a >
* @see < a target = "_blank" href = "https://tools.ietf.org/html/rfc6749#section-4.1.2" > Section 4 . 1 . 2 Authorization Response < / a >
* /
public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
/ * *
@ -79,8 +85,10 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
@@ -79,8 +85,10 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
private final RegisteredClientRepository registeredClientRepository ;
private final OAuth2AuthorizationService authorizationService ;
private final RequestMatcher authorizationEndpointMatcher ;
private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator ( Base64 . getUrlEncoder ( ) ) ;
private final RequestMatcher authorizationRequestMatcher ;
private final RequestMatcher userConsentMatcher ;
private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator ( Base64 . getUrlEncoder ( ) . withoutPadding ( ) , 96 ) ;
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator ( Base64 . getUrlEncoder ( ) ) ;
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy ( ) ;
/ * *
@ -108,159 +116,292 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
@@ -108,159 +116,292 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
Assert . hasText ( authorizationEndpointUri , "authorizationEndpointUri cannot be empty" ) ;
this . registeredClientRepository = registeredClientRepository ;
this . authorizationService = authorizationService ;
this . authorizationEndpoin tMatcher = new AntPathRequestMatcher (
this . authorizationReques tMatcher = new AntPathRequestMatcher (
authorizationEndpointUri , HttpMethod . GET . name ( ) ) ;
this . userConsentMatcher = new AntPathRequestMatcher (
authorizationEndpointUri , HttpMethod . POST . name ( ) ) ;
}
@Override
protected void doFilterInternal ( HttpServletRequest request , HttpServletResponse response , FilterChain filterChain )
throws ServletException , IOException {
if ( ! this . authorizationEndpointMatcher . matches ( request ) ) {
if ( this . authorizationRequestMatcher . matches ( request ) ) {
processAuthorizationRequest ( request , response , filterChain ) ;
} else if ( this . userConsentMatcher . matches ( request ) ) {
processUserConsent ( request , response ) ;
} else {
filterChain . doFilter ( request , response ) ;
}
}
private void processAuthorizationRequest ( HttpServletRequest request , HttpServletResponse response , FilterChain filterChain )
throws ServletException , IOException {
OAuth2AuthorizationRequestContext authorizationRequestContext =
new OAuth2AuthorizationRequestContext (
request . getRequestURL ( ) . toString ( ) ,
OAuth2EndpointUtils . getParameters ( request ) ) ;
validateAuthorizationRequest ( authorizationRequestContext ) ;
if ( authorizationRequestContext . hasError ( ) ) {
if ( authorizationRequestContext . isRedirectOnError ( ) ) {
sendErrorResponse ( request , response , authorizationRequestContext . resolveRedirectUri ( ) ,
authorizationRequestContext . getError ( ) , authorizationRequestContext . getState ( ) ) ;
} else {
sendErrorResponse ( response , authorizationRequestContext . getError ( ) ) ;
}
return ;
}
// ---------------
// Validate the request to ensure that all required parameters are present and valid
// The request is valid - ensure the resource owner is authenticate d
// ---------------
MultiValueMap < String , String > parameters = OAuth2EndpointUtils . getParameters ( request ) ;
String stateParameter = parameters . getFirst ( OAuth2ParameterNames . STATE ) ;
Authentication principal = SecurityContextHolder . getContext ( ) . getAuthentication ( ) ;
if ( ! isPrincipalAuthenticated ( principal ) ) {
// Pass through the chain with the expectation that the authentication process
// will commence via AuthenticationEntryPoint
filterChain . doFilter ( request , response ) ;
return ;
}
RegisteredClient registeredClient = authorizationRequestContext . getRegisteredClient ( ) ;
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestContext . buildAuthorizationRequest ( ) ;
OAuth2Authorization . Builder builder = OAuth2Authorization . withRegisteredClient ( registeredClient )
. principalName ( principal . getName ( ) )
. attribute ( OAuth2AuthorizationAttributeNames . AUTHORIZATION_REQUEST , authorizationRequest ) ;
if ( registeredClient . getClientSettings ( ) . requireUserConsent ( ) ) {
String state = this . stateGenerator . generateKey ( ) ;
OAuth2Authorization authorization = builder
. attribute ( OAuth2AuthorizationAttributeNames . STATE , state )
. build ( ) ;
this . authorizationService . save ( authorization ) ;
// TODO Need to remove 'in-flight' authorization if consent step is not completed (e.g. approved or cancelled)
UserConsentPage . displayConsent ( request , response , registeredClient , authorization ) ;
} else {
String code = this . codeGenerator . generateKey ( ) ;
OAuth2Authorization authorization = builder
. attribute ( OAuth2AuthorizationAttributeNames . CODE , code )
. attribute ( OAuth2AuthorizationAttributeNames . AUTHORIZED_SCOPES , authorizationRequest . getScopes ( ) )
. build ( ) ;
this . authorizationService . save ( authorization ) ;
// TODO security checks for code parameter
// The authorization code MUST expire shortly after it is issued to mitigate the risk of leaks.
// A maximum authorization code lifetime of 10 minutes is RECOMMENDED.
// The client MUST NOT use the authorization code more than once.
// If an authorization code is used more than once, the authorization server MUST deny the request
// and SHOULD revoke (when possible) all tokens previously issued based on that authorization code.
// The authorization code is bound to the client identifier and redirection URI.
sendAuthorizationResponse ( request , response ,
authorizationRequestContext . resolveRedirectUri ( ) , code , authorizationRequest . getState ( ) ) ;
}
}
private void processUserConsent ( HttpServletRequest request , HttpServletResponse response )
throws IOException {
UserConsentRequestContext userConsentRequestContext =
new UserConsentRequestContext (
request . getRequestURL ( ) . toString ( ) ,
OAuth2EndpointUtils . getParameters ( request ) ) ;
validateUserConsentRequest ( userConsentRequestContext ) ;
if ( userConsentRequestContext . hasError ( ) ) {
if ( userConsentRequestContext . isRedirectOnError ( ) ) {
sendErrorResponse ( request , response , userConsentRequestContext . resolveRedirectUri ( ) ,
userConsentRequestContext . getError ( ) , userConsentRequestContext . getState ( ) ) ;
} else {
sendErrorResponse ( response , userConsentRequestContext . getError ( ) ) ;
}
return ;
}
if ( ! UserConsentPage . isConsentApproved ( request ) ) {
this . authorizationService . remove ( userConsentRequestContext . getAuthorization ( ) ) ;
OAuth2Error error = createError ( OAuth2ErrorCodes . ACCESS_DENIED , OAuth2ParameterNames . CLIENT_ID ) ;
sendErrorResponse ( request , response , userConsentRequestContext . resolveRedirectUri ( ) ,
error , userConsentRequestContext . getAuthorizationRequest ( ) . getState ( ) ) ;
return ;
}
String code = this . codeGenerator . generateKey ( ) ;
OAuth2Authorization authorization = OAuth2Authorization . from ( userConsentRequestContext . getAuthorization ( ) )
. attributes ( attrs - > {
attrs . remove ( OAuth2AuthorizationAttributeNames . STATE ) ;
attrs . put ( OAuth2AuthorizationAttributeNames . CODE , code ) ;
attrs . put ( OAuth2AuthorizationAttributeNames . AUTHORIZED_SCOPES , userConsentRequestContext . getScopes ( ) ) ;
} )
. build ( ) ;
this . authorizationService . save ( authorization ) ;
sendAuthorizationResponse ( request , response , userConsentRequestContext . resolveRedirectUri ( ) ,
code , userConsentRequestContext . getAuthorizationRequest ( ) . getState ( ) ) ;
}
private void validateAuthorizationRequest ( OAuth2AuthorizationRequestContext authorizationRequestContext ) {
// ---------------
// Validate the request to ensure all required parameters are present and valid
// ---------------
// client_id (REQUIRED)
String clientId = parameters . getFirst ( OAuth2ParameterNames . CLIENT_ID ) ;
if ( ! StringUtils . hasText ( clientId ) | |
parameters . get ( OAuth2ParameterNames . CLIENT_ID ) . size ( ) ! = 1 ) {
OAuth2Error error = createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . CLIENT_ID ) ;
sendErrorResponse ( request , response , error , stateParameter , null ) ; // when redirectUri is null then don't redirect
if ( ! StringUtils . hasText ( authorizationRequestContext . getClientId ( ) ) | |
authorizationRequestContext . getParameters ( ) . get ( OAuth2ParameterNames . CLIENT_ID ) . size ( ) ! = 1 ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . CLIENT_ID ) ) ;
return ;
}
RegisteredClient registeredClient = this . registeredClientRepository . findByClientId ( clientId ) ;
RegisteredClient registeredClient = this . registeredClientRepository . findByClientId (
authorizationRequestContext . getClientId ( ) ) ;
if ( registeredClient = = null ) {
OAuth2Error error = createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . CLIENT_ID ) ;
sendErrorResponse ( request , response , error , stateParameter , null ) ; // when redirectUri is null then don't redirect
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . CLIENT_ID ) ) ;
return ;
} else if ( ! registeredClient . getAuthorizationGrantTypes ( ) . contains ( AuthorizationGrantType . AUTHORIZATION_CODE ) ) {
OAuth2Error error = createError ( OAuth2ErrorCodes . UNAUTHORIZED_CLIENT , OAuth2ParameterNames . CLIENT_ID ) ;
sendErrorResponse ( request , response , error , stateParameter , null ) ; // when redirectUri is null then don't redirect
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . UNAUTHORIZED_CLIENT , OAuth2ParameterNames . CLIENT_ID ) ) ;
return ;
}
authorizationRequestContext . setRegisteredClient ( registeredClient ) ;
// redirect_uri (OPTIONAL)
String redirectUriParameter = parameters . getFirst ( OAuth2ParameterNames . REDIRECT_URI ) ;
if ( StringUtils . hasText ( redirectUriParameter ) ) {
if ( ! registeredClient . getRedirectUris ( ) . contains ( redirectUriParameter ) | |
parameters . get ( OAuth2ParameterNames . REDIRECT_URI ) . size ( ) ! = 1 ) {
OAuth2Error error = createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . REDIRECT_URI ) ;
sendErrorResponse ( request , response , error , stateParameter , null ) ; // when redirectUri is null then don't redirect
if ( StringUtils . hasText ( authorizationRequestContext . getRedirectUri ( ) ) ) {
if ( ! registeredClient . getRedirectUris ( ) . contains ( authorizationRequestContext . getRedirectUri ( ) ) | |
authorizationRequestContext . getParameters ( ) . get ( OAuth2ParameterNames . REDIRECT_URI ) . size ( ) ! = 1 ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . REDIRECT_URI ) ) ;
return ;
}
} else if ( registeredClient . getRedirectUris ( ) . size ( ) ! = 1 ) {
OAuth2Error error = createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . REDIRECT_URI ) ;
sendErrorResponse ( request , response , error , stateParameter , null ) ; // when redirectUri is null then don't redirect
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . REDIRECT_URI ) ) ;
return ;
}
String redirectUri = StringUtils . hasText ( redirectUriParameter ) ?
redirectUriParameter : registeredClient . getRedirectUris ( ) . iterator ( ) . next ( ) ;
authorizationRequestContext . setRedirectOnError ( true ) ;
// response_type (REQUIRED)
String responseType = parameters . getFirst ( OAuth2ParameterNames . RESPONSE_TYPE ) ;
if ( ! StringUtils . hasText ( responseType ) | |
parameters . get ( OAuth2ParameterNames . RESPONSE_TYPE ) . size ( ) ! = 1 ) {
OAuth2Error error = createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . RESPONSE_TYPE ) ;
sendErrorResponse ( request , response , error , stateParameter , redirectUri ) ;
if ( ! StringUtils . hasText ( authorizationRequestContext . getResponseType ( ) ) | |
authorizationRequestContext . getParameters ( ) . get ( OAuth2ParameterNames . RESPONSE_TYPE ) . size ( ) ! = 1 ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . RESPONSE_TYPE ) ) ;
return ;
} else if ( ! responseType . equals ( OAuth2AuthorizationResponseType . CODE . getValue ( ) ) ) {
OAuth2Error error = createError ( OAuth2ErrorCodes . UNSUPPORTED_RESPONSE_TYPE , OAuth2ParameterNames . RESPONSE_TYPE ) ;
sendErrorResponse ( request , response , error , stateParameter , redirectUri ) ;
} else if ( ! authorizationRequestContext . getResponseType ( ) . equals ( OAuth2AuthorizationResponseType . CODE . getValue ( ) ) ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . UNSUPPORTED_RESPONSE_TYPE , OAuth2ParameterNames . RESPONSE_TYPE ) ) ;
return ;
}
// scope (OPTIONAL)
Set < String > requestedScopes = authorizationRequestContext . getScopes ( ) ;
Set < String > allowedScopes = registeredClient . getScopes ( ) ;
if ( ! requestedScopes . isEmpty ( ) & & ! allowedScopes . containsAll ( requestedScopes ) ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_SCOPE , OAuth2ParameterNames . SCOPE ) ) ;
return ;
}
// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
String codeChallenge = parameters . getFirst ( PkceParameterNames . CODE_CHALLENGE ) ;
String codeChallenge = authorizationRequestContext . getPa rameters( ) . getFirst ( PkceParameterNames . CODE_CHALLENGE ) ;
if ( StringUtils . hasText ( codeChallenge ) ) {
if ( parameters . get ( PkceParameterNames . CODE_CHALLENGE ) . size ( ) ! = 1 ) {
OAuth2Error error = createError ( OAuth2ErrorCodes . INVALID_REQUEST , PkceParameterNames . CODE_CHALLENGE , PKCE_ERROR_URI ) ;
sendErrorResponse ( request , response , error , stateParameter , redirectUri ) ;
return ;
}
String codeChallengeMethod = parameters . getFirst ( PkceParameterNames . CODE_CHALLENGE_METHOD ) ;
if ( StringUtils . hasText ( codeChallengeMethod ) & &
parameters . get ( PkceParameterNames . CODE_CHALLENGE_METHOD ) . size ( ) ! = 1 ) {
OAuth2Error error = createError ( OAuth2ErrorCodes . INVALID_REQUEST , PkceParameterNames . CODE_CHALLENGE_METHOD , PKCE_ERROR_URI ) ;
sendErrorResponse ( request , response , error , stateParameter , redirectUri ) ;
if ( authorizationRequestContext . getParameters ( ) . get ( PkceParameterNames . CODE_CHALLENGE ) . size ( ) ! = 1 ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , PkceParameterNames . CODE_CHALLENGE , PKCE_ERROR_URI ) ) ;
return ;
}
if ( StringUtils . hasText ( codeChallengeMethod ) & &
( ! "S256" . equals ( codeChallengeMethod ) & & ! "plain" . equals ( codeChallengeMethod ) ) ) {
OAuth2Error error = createError ( OAuth2ErrorCodes . INVALID_REQUEST , PkceParameterNames . CODE_CHALLENGE_METHOD , PKCE_ERROR_URI ) ;
sendErrorResponse ( request , response , error , stateParameter , redirectUri ) ;
return ;
String codeChallengeMethod = authorizationRequestContext . getParameters ( ) . getFirst ( PkceParameterNames . CODE_CHALLENGE_METHOD ) ;
if ( StringUtils . hasText ( codeChallengeMethod ) ) {
if ( authorizationRequestContext . getParameters ( ) . get ( PkceParameterNames . CODE_CHALLENGE_METHOD ) . size ( ) ! = 1 | |
( ! "S256" . equals ( codeChallengeMethod ) & & ! "plain" . equals ( codeChallengeMethod ) ) ) {
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , PkceParameterNames . CODE_CHALLENGE_METHOD , PKCE_ERROR_URI ) ) ;
return ;
}
}
} else if ( registeredClient . getClientSettings ( ) . requireProofKey ( ) ) {
OAuth2Error error = createError ( OAuth2ErrorCodes . INVALID_REQUEST , PkceParameterNames . CODE_CHALLENGE , PKCE_ERROR_URI ) ;
sendErrorResponse ( request , response , error , stateParameter , redirectUri ) ;
authorizationRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , PkceParameterNames . CODE_CHALLENGE , PKCE_ERROR_URI ) ) ;
return ;
}
}
private void validateUserConsentRequest ( UserConsentRequestContext userConsentRequestContext ) {
// ---------------
// The request is valid - ensure the resource owner is authenticated
// Validate the request to ensure all required parameters are present and vali d
// ---------------
Authentication principal = SecurityContextHolder . getContext ( ) . getAuthentication ( ) ;
if ( ! isPrincipalAuthenticated ( principal ) ) {
// Pass through the chain with the expectation that the authentication process
// will commence via AuthenticationEntryPoint
filterChain . doFilter ( request , response ) ;
// state (REQUIRED)
if ( ! StringUtils . hasText ( userConsentRequestContext . getState ( ) ) | |
userConsentRequestContext . getParameters ( ) . get ( OAuth2ParameterNames . STATE ) . size ( ) ! = 1 ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . STATE ) ) ;
return ;
}
OAuth2Authorization authorization = this . authorizationService . findByToken (
userConsentRequestContext . getState ( ) , new TokenType ( OAuth2AuthorizationAttributeNames . STATE ) ) ;
if ( authorization = = null ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . STATE ) ) ;
return ;
}
userConsentRequestContext . setAuthorization ( authorization ) ;
String code = this . codeGenerator . generateKey ( ) ;
OAuth2AuthorizationRequest authorizationRequest = convertAuthorizationRequest ( request ) ;
OAuth2Authorization authorization = OAuth2Authorization . withRegisteredClient ( registeredClient )
. principalName ( principal . getName ( ) )
. attribute ( OAuth2AuthorizationAttributeNames . CODE , code )
. attribute ( OAuth2AuthorizationAttributeNames . AUTHORIZATION_REQUEST , authorizationRequest )
. build ( ) ;
this . authorizationService . save ( authorization ) ;
// TODO security checks for code parameter
// The authorization code MUST expire shortly after it is issued to mitigate the risk of leaks.
// A maximum authorization code lifetime of 10 minutes is RECOMMENDED.
// The client MUST NOT use the authorization code more than once.
// If an authorization code is used more than once, the authorization server MUST deny the request
// and SHOULD revoke (when possible) all tokens previously issued based on that authorization code.
// The authorization code is bound to the client identifier and redirection URI.
// The 'in-flight' authorization must be associated to the current principal
Authentication principal = SecurityContextHolder . getContext ( ) . getAuthentication ( ) ;
if ( ! isPrincipalAuthenticated ( principal ) | | ! principal . getName ( ) . equals ( authorization . getPrincipalName ( ) ) ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . STATE ) ) ;
return ;
}
sendAuthorizationResponse ( request , response , authorizationRequest , code , redirectUri ) ;
// client_id (REQUIRED)
if ( ! StringUtils . hasText ( userConsentRequestContext . getClientId ( ) ) | |
userConsentRequestContext . getParameters ( ) . get ( OAuth2ParameterNames . CLIENT_ID ) . size ( ) ! = 1 ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . CLIENT_ID ) ) ;
return ;
}
RegisteredClient registeredClient = this . registeredClientRepository . findByClientId (
userConsentRequestContext . getClientId ( ) ) ;
if ( registeredClient = = null | | ! registeredClient . getId ( ) . equals ( authorization . getRegisteredClientId ( ) ) ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_REQUEST , OAuth2ParameterNames . CLIENT_ID ) ) ;
return ;
}
userConsentRequestContext . setRegisteredClient ( registeredClient ) ;
userConsentRequestContext . setRedirectOnError ( true ) ;
// scope (OPTIONAL)
Set < String > requestedScopes = userConsentRequestContext . getAuthorizationRequest ( ) . getScopes ( ) ;
Set < String > authorizedScopes = userConsentRequestContext . getScopes ( ) ;
if ( ! authorizedScopes . isEmpty ( ) & & ! requestedScopes . containsAll ( authorizedScopes ) ) {
userConsentRequestContext . setError (
createError ( OAuth2ErrorCodes . INVALID_SCOPE , OAuth2ParameterNames . SCOPE ) ) ;
return ;
}
}
private void sendAuthorizationResponse ( HttpServletRequest request , HttpServletResponse response ,
OAuth2AuthorizationRequest authorizationRequest , String code , String redirectUri ) throws IOException {
String redirectUri , String code , String state ) throws IOException {
UriComponentsBuilder uriBuilder = UriComponentsBuilder
. fromUriString ( redirectUri )
. queryParam ( OAuth2ParameterNames . CODE , code ) ;
if ( StringUtils . hasText ( authorizationRequest . getState ( ) ) ) {
uriBuilder . queryParam ( OAuth2ParameterNames . STATE , authorizationRequest . getState ( ) ) ;
if ( StringUtils . hasText ( state ) ) {
uriBuilder . queryParam ( OAuth2ParameterNames . STATE , state ) ;
}
this . redirectStrategy . sendRedirect ( request , response , uriBuilder . toUriString ( ) ) ;
}
private void sendErrorResponse ( HttpServletRequest request , HttpServletResponse response ,
OAuth2Error error , String state , String redirectUri ) throws IOException {
if ( redirectUri = = null ) {
// TODO Send default html error response
response . sendError ( HttpStatus . BAD_REQUEST . value ( ) , error . toString ( ) ) ;
return ;
}
String redirectUri , OAuth2Error error , String state ) throws IOException {
UriComponentsBuilder uriBuilder = UriComponentsBuilder
. fromUriString ( redirectUri )
@ -277,6 +418,11 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
@@ -277,6 +418,11 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
this . redirectStrategy . sendRedirect ( request , response , uriBuilder . toUriString ( ) ) ;
}
private void sendErrorResponse ( HttpServletResponse response , OAuth2Error error ) throws IOException {
// TODO Send default html error response
response . sendError ( HttpStatus . BAD_REQUEST . value ( ) , error . toString ( ) ) ;
}
private static OAuth2Error createError ( String errorCode , String parameterName ) {
return createError ( errorCode , parameterName , "https://tools.ietf.org/html/rfc6749#section-4.1.2.1" ) ;
}
@ -291,29 +437,254 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
@@ -291,29 +437,254 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
principal . isAuthenticated ( ) ;
}
private static OAuth2AuthorizationRequest convertAuthorizationRequest ( HttpServletRequest request ) {
MultiValueMap < String , String > parameters = OAuth2EndpointUtils . getParameters ( request ) ;
private static class OAuth2AuthorizationRequestContext extends AbstractRequestContext {
private final String responseType ;
private final String redirectUri ;
private OAuth2AuthorizationRequestContext (
String authorizationUri , MultiValueMap < String , String > parameters ) {
super ( authorizationUri , parameters ,
parameters . getFirst ( OAuth2ParameterNames . CLIENT_ID ) ,
parameters . getFirst ( OAuth2ParameterNames . STATE ) ,
extractScopes ( parameters ) ) ;
this . responseType = parameters . getFirst ( OAuth2ParameterNames . RESPONSE_TYPE ) ;
this . redirectUri = parameters . getFirst ( OAuth2ParameterNames . REDIRECT_URI ) ;
}
Set < String > scopes = Collections . emptySet ( ) ;
if ( parameters . containsKey ( OAuth2ParameterNames . SCOPE ) ) {
private static Set < String > extractScopes ( MultiValueMap < String , String > parameters ) {
String scope = parameters . getFirst ( OAuth2ParameterNames . SCOPE ) ;
scopes = new HashSet < > ( Arrays . asList ( StringUtils . delimitedListToStringArray ( scope , " " ) ) ) ;
}
return OAuth2AuthorizationRequest . authorizationCode ( )
. authorizationUri ( request . getRequestURL ( ) . toString ( ) )
. clientId ( parameters . getFirst ( OAuth2ParameterNames . CLIENT_ID ) )
. redirectUri ( parameters . getFirst ( OAuth2ParameterNames . REDIRECT_URI ) )
. scopes ( scopes )
. state ( parameters . getFirst ( OAuth2ParameterNames . STATE ) )
. additionalParameters ( additionalParameters - >
parameters . entrySet ( ) . stream ( )
. filter ( e - > ! e . getKey ( ) . equals ( OAuth2ParameterNames . RESPONSE_TYPE ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . CLIENT_ID ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . REDIRECT_URI ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . SCOPE ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . STATE ) )
. forEach ( e - > additionalParameters . put ( e . getKey ( ) , e . getValue ( ) . get ( 0 ) ) ) )
. build ( ) ;
return StringUtils . hasText ( scope ) ?
new HashSet < > ( Arrays . asList ( StringUtils . delimitedListToStringArray ( scope , " " ) ) ) :
Collections . emptySet ( ) ;
}
private String getResponseType ( ) {
return this . responseType ;
}
private String getRedirectUri ( ) {
return this . redirectUri ;
}
protected String resolveRedirectUri ( ) {
return StringUtils . hasText ( getRedirectUri ( ) ) ?
getRedirectUri ( ) :
getRegisteredClient ( ) . getRedirectUris ( ) . iterator ( ) . next ( ) ;
}
private OAuth2AuthorizationRequest buildAuthorizationRequest ( ) {
return OAuth2AuthorizationRequest . authorizationCode ( )
. authorizationUri ( getAuthorizationUri ( ) )
. clientId ( getClientId ( ) )
. redirectUri ( getRedirectUri ( ) )
. scopes ( getScopes ( ) )
. state ( getState ( ) )
. additionalParameters ( additionalParameters - >
getParameters ( ) . entrySet ( ) . stream ( )
. filter ( e - > ! e . getKey ( ) . equals ( OAuth2ParameterNames . RESPONSE_TYPE ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . CLIENT_ID ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . REDIRECT_URI ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . SCOPE ) & &
! e . getKey ( ) . equals ( OAuth2ParameterNames . STATE ) )
. forEach ( e - > additionalParameters . put ( e . getKey ( ) , e . getValue ( ) . get ( 0 ) ) ) )
. build ( ) ;
}
}
private static class UserConsentRequestContext extends AbstractRequestContext {
private OAuth2Authorization authorization ;
private UserConsentRequestContext (
String authorizationUri , MultiValueMap < String , String > parameters ) {
super ( authorizationUri , parameters ,
parameters . getFirst ( OAuth2ParameterNames . CLIENT_ID ) ,
parameters . getFirst ( OAuth2ParameterNames . STATE ) ,
extractScopes ( parameters ) ) ;
}
private static Set < String > extractScopes ( MultiValueMap < String , String > parameters ) {
List < String > scope = parameters . get ( OAuth2ParameterNames . SCOPE ) ;
return ! CollectionUtils . isEmpty ( scope ) ? new HashSet < > ( scope ) : Collections . emptySet ( ) ;
}
private OAuth2Authorization getAuthorization ( ) {
return this . authorization ;
}
private void setAuthorization ( OAuth2Authorization authorization ) {
this . authorization = authorization ;
}
protected String resolveRedirectUri ( ) {
OAuth2AuthorizationRequest authorizationRequest = getAuthorizationRequest ( ) ;
return StringUtils . hasText ( authorizationRequest . getRedirectUri ( ) ) ?
authorizationRequest . getRedirectUri ( ) :
getRegisteredClient ( ) . getRedirectUris ( ) . iterator ( ) . next ( ) ;
}
private OAuth2AuthorizationRequest getAuthorizationRequest ( ) {
return getAuthorization ( ) . getAttribute ( OAuth2AuthorizationAttributeNames . AUTHORIZATION_REQUEST ) ;
}
}
private abstract static class AbstractRequestContext {
private final String authorizationUri ;
private final MultiValueMap < String , String > parameters ;
private final String clientId ;
private final String state ;
private final Set < String > scopes ;
private RegisteredClient registeredClient ;
private OAuth2Error error ;
private boolean redirectOnError ;
protected AbstractRequestContext ( String authorizationUri , MultiValueMap < String , String > parameters ,
String clientId , String state , Set < String > scopes ) {
this . authorizationUri = authorizationUri ;
this . parameters = parameters ;
this . clientId = clientId ;
this . state = state ;
this . scopes = scopes ;
}
protected String getAuthorizationUri ( ) {
return this . authorizationUri ;
}
protected MultiValueMap < String , String > getParameters ( ) {
return this . parameters ;
}
protected String getClientId ( ) {
return this . clientId ;
}
protected String getState ( ) {
return this . state ;
}
protected Set < String > getScopes ( ) {
return this . scopes ;
}
protected RegisteredClient getRegisteredClient ( ) {
return this . registeredClient ;
}
protected void setRegisteredClient ( RegisteredClient registeredClient ) {
this . registeredClient = registeredClient ;
}
protected OAuth2Error getError ( ) {
return this . error ;
}
protected void setError ( OAuth2Error error ) {
this . error = error ;
}
protected boolean hasError ( ) {
return getError ( ) ! = null ;
}
protected boolean isRedirectOnError ( ) {
return this . redirectOnError ;
}
protected void setRedirectOnError ( boolean redirectOnError ) {
this . redirectOnError = redirectOnError ;
}
protected abstract String resolveRedirectUri ( ) ;
}
private static class UserConsentPage {
private static final MediaType TEXT_HTML_UTF8 = new MediaType ( "text" , "html" , StandardCharsets . UTF_8 ) ;
private static final String CONSENT_ACTION_PARAMETER_NAME = "consent_action" ;
private static final String CONSENT_ACTION_APPROVE = "approve" ;
private static final String CONSENT_ACTION_CANCEL = "cancel" ;
private static void displayConsent ( HttpServletRequest request , HttpServletResponse response ,
RegisteredClient registeredClient , OAuth2Authorization authorization ) throws IOException {
String consentPage = generateConsentPage ( request , registeredClient , authorization ) ;
response . setContentType ( TEXT_HTML_UTF8 . toString ( ) ) ;
response . setContentLength ( consentPage . getBytes ( StandardCharsets . UTF_8 ) . length ) ;
response . getWriter ( ) . write ( consentPage ) ;
}
private static boolean isConsentApproved ( HttpServletRequest request ) {
return CONSENT_ACTION_APPROVE . equalsIgnoreCase ( request . getParameter ( CONSENT_ACTION_PARAMETER_NAME ) ) ;
}
private static boolean isConsentCancelled ( HttpServletRequest request ) {
return CONSENT_ACTION_CANCEL . equalsIgnoreCase ( request . getParameter ( CONSENT_ACTION_PARAMETER_NAME ) ) ;
}
private static String generateConsentPage ( HttpServletRequest request ,
RegisteredClient registeredClient , OAuth2Authorization authorization ) {
OAuth2AuthorizationRequest authorizationRequest = authorization . getAttribute (
OAuth2AuthorizationAttributeNames . AUTHORIZATION_REQUEST ) ;
String state = authorization . getAttribute (
OAuth2AuthorizationAttributeNames . STATE ) ;
StringBuilder builder = new StringBuilder ( ) ;
builder . append ( "<!DOCTYPE html>" ) ;
builder . append ( "<html lang=\"en\">" ) ;
builder . append ( "<head>" ) ;
builder . append ( " <meta charset=\"utf-8\">" ) ;
builder . append ( " <meta name=\"viewport\" content=\"width=device-width, initial-scale=1, shrink-to-fit=no\">" ) ;
builder . append ( " <link rel=\"stylesheet\" href=\"https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css\" integrity=\"sha384-JcKb8q3iqJ61gNV9KGb8thSsNjpSL0n8PARn9HuZOnIxN0hoP+VmmDGMN5t9UJ0Z\" crossorigin=\"anonymous\">" ) ;
builder . append ( " <title>Consent required</title>" ) ;
builder . append ( "</head>" ) ;
builder . append ( "<body>" ) ;
builder . append ( "<div class=\"container\">" ) ;
builder . append ( " <div class=\"py-5\">" ) ;
builder . append ( " <h1 class=\"text-center\">Consent required</h1>" ) ;
builder . append ( " </div>" ) ;
builder . append ( " <div class=\"row\">" ) ;
builder . append ( " <div class=\"col text-center\">" ) ;
builder . append ( " <p><span class=\"font-weight-bold text-primary\">" + registeredClient . getClientId ( ) + "</span> wants to access your account <span class=\"font-weight-bold\">" + authorization . getPrincipalName ( ) + "</span></p>" ) ;
builder . append ( " </div>" ) ;
builder . append ( " </div>" ) ;
builder . append ( " <div class=\"row pb-3\">" ) ;
builder . append ( " <div class=\"col text-center\">" ) ;
builder . append ( " <p>The following permissions are requested by the above app.<br/>Please review these and consent if you approve.</p>" ) ;
builder . append ( " </div>" ) ;
builder . append ( " </div>" ) ;
builder . append ( " <div class=\"row\">" ) ;
builder . append ( " <div class=\"col text-center\">" ) ;
builder . append ( " <form method=\"post\" action=\"" + request . getRequestURI ( ) + "\">" ) ;
builder . append ( " <input type=\"hidden\" name=\"client_id\" value=\"" + registeredClient . getClientId ( ) + "\">" ) ;
builder . append ( " <input type=\"hidden\" name=\"state\" value=\"" + state + "\">" ) ;
for ( String scope : authorizationRequest . getScopes ( ) ) {
builder . append ( " <div class=\"form-group form-check py-1\">" ) ;
builder . append ( " <input class=\"form-check-input\" type=\"checkbox\" name=\"scope\" value=\"" + scope + "\" id=\"" + scope + "\" checked>" ) ;
builder . append ( " <label class=\"form-check-label\" for=\"" + scope + "\">" + scope + "</label>" ) ;
builder . append ( " </div>" ) ;
}
builder . append ( " <div class=\"form-group pt-3\">" ) ;
builder . append ( " <button class=\"btn btn-primary btn-lg\" type=\"submit\" name=\"consent_action\" value=\"approve\">Submit Consent</button>" ) ;
builder . append ( " </div>" ) ;
builder . append ( " <div class=\"form-group\">" ) ;
builder . append ( " <button class=\"btn btn-link regular\" type=\"submit\" name=\"consent_action\" value=\"cancel\">Cancel</button>" ) ;
builder . append ( " </div>" ) ;
builder . append ( " </form>" ) ;
builder . append ( " </div>" ) ;
builder . append ( " </div>" ) ;
builder . append ( " <div class=\"row pt-4\">" ) ;
builder . append ( " <div class=\"col text-center\">" ) ;
builder . append ( " <p><small>Your consent to provide access is required.<br/>If you do not approve, click Cancel, in which case no information will be shared with the app.</small></p>" ) ;
builder . append ( " </div>" ) ;
builder . append ( " </div>" ) ;
builder . append ( "</div>" ) ;
builder . append ( "</body>" ) ;
builder . append ( "</html>" ) ;
return builder . toString ( ) ;
}
}
}