@ -53,11 +53,91 @@ public final class CsrfAuthenticationStrategy implements
@@ -53,11 +53,91 @@ public final class CsrfAuthenticationStrategy implements
throws SessionAuthenticationException {
boolean containsToken = this . csrfTokenRepository . loadToken ( request ) ! = null ;
if ( containsToken ) {
CsrfToken newToken = this . csrfTokenRepository . generateToken ( request ) ;
this . csrfTokenRepository . saveToken ( null , request , response ) ;
this . csrfTokenRepository . saveToken ( newToken , request , response ) ;
request . setAttribute ( CsrfToken . class . getName ( ) , newToken ) ;
request . setAttribute ( newToken . getParameterName ( ) , newToken ) ;
CsrfToken newToken = this . csrfTokenRepository . generateToken ( request ) ;
CsrfToken tokenForRequest = new SaveOnAccessCsrfToken ( csrfTokenRepository , request , response , newToken ) ;
request . setAttribute ( CsrfToken . class . getName ( ) , tokenForRequest ) ;
request . setAttribute ( newToken . getParameterName ( ) , tokenForRequest ) ;
}
}
private static final class SaveOnAccessCsrfToken implements CsrfToken {
private transient CsrfTokenRepository tokenRepository ;
private transient HttpServletRequest request ;
private transient HttpServletResponse response ;
private final CsrfToken delegate ;
public SaveOnAccessCsrfToken ( CsrfTokenRepository tokenRepository ,
HttpServletRequest request , HttpServletResponse response ,
CsrfToken delegate ) {
super ( ) ;
this . tokenRepository = tokenRepository ;
this . request = request ;
this . response = response ;
this . delegate = delegate ;
}
public String getHeaderName ( ) {
return delegate . getHeaderName ( ) ;
}
public String getParameterName ( ) {
return delegate . getParameterName ( ) ;
}
public String getToken ( ) {
saveTokenIfNecessary ( ) ;
return delegate . getToken ( ) ;
}
@Override
public String toString ( ) {
return "SaveOnAccessCsrfToken [delegate=" + delegate + "]" ;
}
@Override
public int hashCode ( ) {
final int prime = 31 ;
int result = 1 ;
result = prime * result
+ ( ( delegate = = null ) ? 0 : delegate . hashCode ( ) ) ;
return result ;
}
@Override
public boolean equals ( Object obj ) {
if ( this = = obj )
return true ;
if ( obj = = null )
return false ;
if ( getClass ( ) ! = obj . getClass ( ) )
return false ;
SaveOnAccessCsrfToken other = ( SaveOnAccessCsrfToken ) obj ;
if ( delegate = = null ) {
if ( other . delegate ! = null )
return false ;
} else if ( ! delegate . equals ( other . delegate ) )
return false ;
return true ;
}
private void saveTokenIfNecessary ( ) {
if ( this . tokenRepository = = null ) {
return ;
}
synchronized ( this ) {
if ( tokenRepository ! = null ) {
this . tokenRepository . saveToken ( delegate , request , response ) ;
this . tokenRepository = null ;
this . request = null ;
this . response = null ;
}
}
}
}
}