@ -20,6 +20,7 @@ import java.time.Clock;
import java.time.Duration ;
import java.time.Duration ;
import java.time.Instant ;
import java.time.Instant ;
import java.time.ZoneId ;
import java.time.ZoneId ;
import java.time.temporal.ChronoUnit ;
import java.util.Iterator ;
import java.util.Iterator ;
import java.util.Map ;
import java.util.Map ;
import java.util.concurrent.ConcurrentHashMap ;
import java.util.concurrent.ConcurrentHashMap ;
@ -43,9 +44,6 @@ import org.springframework.web.server.WebSession;
* /
* /
public class InMemoryWebSessionStore implements WebSessionStore {
public class InMemoryWebSessionStore implements WebSessionStore {
/** Minimum period between expiration checks. */
private static final Duration EXPIRATION_CHECK_PERIOD = Duration . ofSeconds ( 60 ) ;
private static final IdGenerator idGenerator = new JdkIdGenerator ( ) ;
private static final IdGenerator idGenerator = new JdkIdGenerator ( ) ;
@ -53,9 +51,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
private final ConcurrentMap < String , InMemoryWebSession > sessions = new ConcurrentHashMap < > ( ) ;
private final ConcurrentMap < String , InMemoryWebSession > sessions = new ConcurrentHashMap < > ( ) ;
private volatile Instant nextExpirationCheckTime = Instant . now ( this . clock ) . plus ( EXPIRATION_CHECK_PERIOD ) ;
private final ExpiredSessionChecker expiredSessionChecker = new ExpiredSessionChecker ( ) ;
private final ReentrantLock expirationCheckLock = new ReentrantLock ( ) ;
/ * *
/ * *
@ -70,8 +66,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
public void setClock ( Clock clock ) {
public void setClock ( Clock clock ) {
Assert . notNull ( clock , "Clock is required" ) ;
Assert . notNull ( clock , "Clock is required" ) ;
this . clock = clock ;
this . clock = clock ;
// Force a check when clock changes..
this . expiredSessionChecker . removeExpiredSessions ( clock . instant ( ) ) ;
this . nextExpirationCheckTime = Instant . now ( this . clock ) ;
}
}
/ * *
/ * *
@ -84,49 +79,29 @@ public class InMemoryWebSessionStore implements WebSessionStore {
@Override
@Override
public Mono < WebSession > createWebSession ( ) {
public Mono < WebSession > createWebSession ( ) {
return Mono . fromSupplier ( InMemoryWebSession : : new ) ;
Instant now = this . clock . instant ( ) ;
this . expiredSessionChecker . checkIfNecessary ( now ) ;
return Mono . fromSupplier ( ( ) - > new InMemoryWebSession ( now ) ) ;
}
}
@Override
@Override
public Mono < WebSession > retrieveSession ( String id ) {
public Mono < WebSession > retrieveSession ( String id ) {
Instant currentTime = Instant . now ( this . clock ) ;
Instant now = this . clock . instant ( ) ;
if ( ! this . sessions . isEmpty ( ) & & ! currentTime . isBefore ( this . nextExpirationCheckTime ) ) {
this . expiredSessionChecker . checkIfNecessary ( now ) ;
checkExpiredSessions ( currentTime ) ;
}
InMemoryWebSession session = this . sessions . get ( id ) ;
InMemoryWebSession session = this . sessions . get ( id ) ;
if ( session = = null ) {
if ( session = = null ) {
return Mono . empty ( ) ;
return Mono . empty ( ) ;
}
}
else if ( session . isExpired ( currentTime ) ) {
else if ( session . isExpired ( now ) ) {
this . sessions . remove ( id ) ;
this . sessions . remove ( id ) ;
return Mono . empty ( ) ;
return Mono . empty ( ) ;
}
}
else {
else {
session . updateLastAccessTime ( currentTime ) ;
session . updateLastAccessTime ( now ) ;
return Mono . just ( session ) ;
return Mono . just ( session ) ;
}
}
}
}
private void checkExpiredSessions ( Instant currentTime ) {
if ( this . expirationCheckLock . tryLock ( ) ) {
try {
Iterator < InMemoryWebSession > iterator = this . sessions . values ( ) . iterator ( ) ;
while ( iterator . hasNext ( ) ) {
InMemoryWebSession session = iterator . next ( ) ;
if ( session . isExpired ( currentTime ) ) {
iterator . remove ( ) ;
session . invalidate ( ) ;
}
}
}
finally {
this . nextExpirationCheckTime = currentTime . plus ( EXPIRATION_CHECK_PERIOD ) ;
this . expirationCheckLock . unlock ( ) ;
}
}
}
@Override
@Override
public Mono < Void > removeSession ( String id ) {
public Mono < Void > removeSession ( String id ) {
this . sessions . remove ( id ) ;
this . sessions . remove ( id ) ;
@ -137,7 +112,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
return Mono . fromSupplier ( ( ) - > {
return Mono . fromSupplier ( ( ) - > {
Assert . isInstanceOf ( InMemoryWebSession . class , webSession ) ;
Assert . isInstanceOf ( InMemoryWebSession . class , webSession ) ;
InMemoryWebSession session = ( InMemoryWebSession ) webSession ;
InMemoryWebSession session = ( InMemoryWebSession ) webSession ;
session . updateLastAccessTime ( Instant . now ( getClock ( ) ) ) ;
session . updateLastAccessTime ( getClock ( ) . instant ( ) ) ;
return session ;
return session ;
} ) ;
} ) ;
}
}
@ -157,8 +132,9 @@ public class InMemoryWebSessionStore implements WebSessionStore {
private final AtomicReference < State > state = new AtomicReference < > ( State . NEW ) ;
private final AtomicReference < State > state = new AtomicReference < > ( State . NEW ) ;
public InMemoryWebSession ( ) {
this . creationTime = Instant . now ( getClock ( ) ) ;
public InMemoryWebSession ( Instant creationTime ) {
this . creationTime = creationTime ;
this . lastAccessTime = this . creationTime ;
this . lastAccessTime = this . creationTime ;
}
}
@ -256,6 +232,57 @@ public class InMemoryWebSessionStore implements WebSessionStore {
}
}
private class ExpiredSessionChecker {
/** Max time before next expiration checks. */
private static final int CHECK_PERIOD = 60 ;
/** Max sessions that can be created before next expiration checks. */
private static final int SESSION_COUNT_THRESHOLD = 500 ;
private final ReentrantLock lock = new ReentrantLock ( ) ;
private Instant nextCheckTime = Instant . now ( clock ) . plus ( CHECK_PERIOD , ChronoUnit . SECONDS ) ;
private long lastSessionCount ;
public void checkIfNecessary ( Instant now ) {
if ( howManyCreated ( ) > SESSION_COUNT_THRESHOLD | | this . nextCheckTime . isBefore ( now ) ) {
removeExpiredSessions ( Instant . now ( clock ) ) ;
}
}
private long howManyCreated ( ) {
return sessions . size ( ) - this . lastSessionCount ;
}
public void removeExpiredSessions ( Instant now ) {
if ( sessions . isEmpty ( ) ) {
return ;
}
if ( this . lock . tryLock ( ) ) {
try {
Iterator < InMemoryWebSession > iterator = sessions . values ( ) . iterator ( ) ;
while ( iterator . hasNext ( ) ) {
InMemoryWebSession session = iterator . next ( ) ;
if ( session . isExpired ( now ) ) {
iterator . remove ( ) ;
session . invalidate ( ) ;
}
}
}
finally {
this . nextCheckTime = clock . instant ( ) . plus ( CHECK_PERIOD , ChronoUnit . SECONDS ) ;
this . lastSessionCount = sessions . size ( ) ;
this . lock . unlock ( ) ;
}
}
}
}
private enum State { NEW , STARTED , EXPIRED }
private enum State { NEW , STARTED , EXPIRED }
}
}