@ -30,6 +30,7 @@ import java.util.Map;
* { @link OAuth2AuthorizationRequest } in the { @code HttpSession } .
* { @link OAuth2AuthorizationRequest } in the { @code HttpSession } .
*
*
* @author Joe Grandja
* @author Joe Grandja
* @author Rob Winch
* @since 5 . 0
* @since 5 . 0
* @see AuthorizationRequestRepository
* @see AuthorizationRequestRepository
* @see OAuth2AuthorizationRequest
* @see OAuth2AuthorizationRequest
@ -37,17 +38,16 @@ import java.util.Map;
public final class HttpSessionOAuth2AuthorizationRequestRepository implements AuthorizationRequestRepository < OAuth2AuthorizationRequest > {
public final class HttpSessionOAuth2AuthorizationRequestRepository implements AuthorizationRequestRepository < OAuth2AuthorizationRequest > {
private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME =
private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME =
HttpSessionOAuth2AuthorizationRequestRepository . class . getName ( ) + ".AUTHORIZATION_REQUEST" ;
HttpSessionOAuth2AuthorizationRequestRepository . class . getName ( ) + ".AUTHORIZATION_REQUEST" ;
private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME ;
private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME ;
@Override
@Override
public OAuth2AuthorizationRequest loadAuthorizationRequest ( HttpServletRequest request ) {
public OAuth2AuthorizationRequest loadAuthorizationRequest ( HttpServletRequest request ) {
Assert . notNull ( request , "request cannot be null" ) ;
Assert . notNull ( request , "request cannot be null" ) ;
Assert . hasText ( request . getParameter ( OAuth2ParameterNames . STATE ) , "state parameter cannot be empty" ) ;
String stateParameter = getStateParameter ( request ) ;
Assert . hasText ( stateParameter , "state parameter cannot be empty" ) ;
Map < String , OAuth2AuthorizationRequest > authorizationRequests = this . getAuthorizationRequests ( request ) ;
Map < String , OAuth2AuthorizationRequest > authorizationRequests = this . getAuthorizationRequests ( request ) ;
if ( authorizationRequests ! = null ) {
return authorizationRequests . get ( stateParameter ) ;
return authorizationRequests . get ( request . getParameter ( OAuth2ParameterNames . STATE ) ) ;
}
return null ;
}
}
@Override
@Override
@ -59,35 +59,46 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
this . removeAuthorizationRequest ( request ) ;
this . removeAuthorizationRequest ( request ) ;
return ;
return ;
}
}
Assert . hasText ( authorizationRequest . getState ( ) , "authorizationRequest.state cannot be empty" ) ;
String state = authorizationRequest . getState ( ) ;
Map < String , OAuth2AuthorizationRequest > authorizationRequests = this . getAuthorizationRequests ( request , true ) ;
Assert . hasText ( state , "authorizationRequest.state cannot be empty" ) ;
authorizationRequests . put ( authorizationRequest . getState ( ) , authorizationRequest ) ;
Map < String , OAuth2AuthorizationRequest > authorizationRequests = this . getAuthorizationRequests ( request ) ;
authorizationRequests . put ( state , authorizationRequest ) ;
request . getSession ( ) . setAttribute ( this . sessionAttributeName , authorizationRequests ) ;
}
}
@Override
@Override
public OAuth2AuthorizationRequest removeAuthorizationRequest ( HttpServletRequest request ) {
public OAuth2AuthorizationRequest removeAuthorizationRequest ( HttpServletRequest request ) {
Assert . notNull ( request , "request cannot be null" ) ;
Assert . notNull ( request , "request cannot be null" ) ;
OAuth2AuthorizationRequest authorizationRequest = this . loadAuthorizationRequest ( request ) ;
String stateParameter = getStateParameter ( request ) ;
if ( authorizationRequest ! = null ) {
if ( stateParameter = = null ) {
Map < String , OAuth2AuthorizationRequest > authorizationRequests = this . getAuthorizationRequests ( request ) ;
return null ;
authorizationRequests . remove ( authorizationRequest . getState ( ) ) ;
}
}
return authorizationRequest ;
Map < String , OAuth2AuthorizationRequest > authorizationRequests = this . getAuthorizationRequests ( request ) ;
OAuth2AuthorizationRequest originalRequest = authorizationRequests . remove ( stateParameter ) ;
request . getSession ( ) . setAttribute ( this . sessionAttributeName , authorizationRequests ) ;
return originalRequest ;
}
}
private Map < String , OAuth2AuthorizationRequest > getAuthorizationRequests ( HttpServletRequest request ) {
/ * *
return this . getAuthorizationRequests ( request , false ) ;
* Gets the state parameter from the { @link HttpServletRequest }
* @param request the request to use
* @return the state parameter or null if not found
* /
private String getStateParameter ( HttpServletRequest request ) {
return request . getParameter ( OAuth2ParameterNames . STATE ) ;
}
}
private Map < String , OAuth2AuthorizationRequest > getAuthorizationRequests ( HttpServletRequest request , boolean createSession ) {
/ * *
Map < String , OAuth2AuthorizationRequest > authorizationRequests = null ;
* Gets a non - null and mutable map of { @link OAuth2AuthorizationRequest # getState ( ) } to an { @link OAuth2AuthorizationRequest }
HttpSession session = request . getSession ( createSession ) ;
* @param request
if ( session ! = null ) {
* @return a non - null and mutable map of { @link OAuth2AuthorizationRequest # getState ( ) } to an { @link OAuth2AuthorizationRequest } .
authorizationRequests = ( Map < String , OAuth2AuthorizationRequest > ) session . getAttribute ( this . sessionAttributeName ) ;
* /
if ( authorizationRequests = = null ) {
private Map < String , OAuth2AuthorizationRequest > getAuthorizationRequests ( HttpServletRequest request ) {
authorizationRequests = new HashMap < > ( ) ;
HttpSession session = request . getSession ( false ) ;
session . setAttribute ( this . sessionAttributeName , authorizationRequests ) ;
Map < String , OAuth2AuthorizationRequest > authorizationRequests = session = = null ? null :
}
( Map < String , OAuth2AuthorizationRequest > ) session . getAttribute ( this . sessionAttributeName ) ;
if ( authorizationRequests = = null ) {
return new HashMap < > ( ) ;
}
}
return authorizationRequests ;
return authorizationRequests ;
}
}