diff --git a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java index 0ee5b887807..b372695aa28 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java @@ -21,10 +21,10 @@ import java.time.Duration; import java.time.Instant; import java.time.ZoneId; import java.time.temporal.ChronoUnit; +import java.util.Collections; import java.util.Iterator; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.ReentrantLock; @@ -47,13 +47,35 @@ public class InMemoryWebSessionStore implements WebSessionStore { private static final IdGenerator idGenerator = new JdkIdGenerator(); + private int maxSessions = 10000; + private Clock clock = Clock.system(ZoneId.of("GMT")); - private final ConcurrentMap sessions = new ConcurrentHashMap<>(); + private final Map sessions = new ConcurrentHashMap<>(); private final ExpiredSessionChecker expiredSessionChecker = new ExpiredSessionChecker(); + /** + * Set the maximum number of sessions that can be stored. Once the limit is + * reached, any attempt to store an additional session will result in an + * {@link IllegalStateException}. + *

By default set to 10000. + * @param maxSessions the maximum number of sessions + * @since 5.1 + */ + public void setMaxSessions(int maxSessions) { + this.maxSessions = maxSessions; + } + + /** + * Return the maximum number of sessions that can be stored. + * @since 5.1 + */ + public int getMaxSessions() { + return this.maxSessions; + } + /** * Configure the {@link Clock} to use to set lastAccessTime on every created * session and to calculate if it is expired. @@ -66,7 +88,7 @@ public class InMemoryWebSessionStore implements WebSessionStore { public void setClock(Clock clock) { Assert.notNull(clock, "Clock is required"); this.clock = clock; - this.expiredSessionChecker.removeExpiredSessions(clock.instant()); + removeExpiredSessions(); } /** @@ -76,6 +98,16 @@ public class InMemoryWebSessionStore implements WebSessionStore { return this.clock; } + /** + * Return the map of sessions with an {@link Collections#unmodifiableMap + * unmodifiable} wrapper. This could be used for management purposes, to + * list active sessions, invalidate expired ones, etc. + * @since 5.1 + */ + public Map getSessions() { + return Collections.unmodifiableMap(this.sessions); + } + @Override public Mono createWebSession() { @@ -108,15 +140,25 @@ public class InMemoryWebSessionStore implements WebSessionStore { return Mono.empty(); } - public Mono updateLastAccessTime(WebSession webSession) { + public Mono updateLastAccessTime(WebSession session) { return Mono.fromSupplier(() -> { - Assert.isInstanceOf(InMemoryWebSession.class, webSession); - InMemoryWebSession session = (InMemoryWebSession) webSession; - session.updateLastAccessTime(getClock().instant()); + Assert.isInstanceOf(InMemoryWebSession.class, session); + ((InMemoryWebSession) session).updateLastAccessTime(this.clock.instant()); return session; }); } + /** + * Check for expired sessions and remove them. Typically such checks are + * kicked off lazily during calls to {@link #createWebSession() create} or + * {@link #retrieveSession retrieve}, no less than 60 seconds apart. + * This method can be called to force a check at a specific time. + * @since 5.1 + */ + public void removeExpiredSessions() { + this.expiredSessionChecker.removeExpiredSessions(this.clock.instant()); + } + private class InMemoryWebSession implements WebSession { @@ -198,6 +240,12 @@ public class InMemoryWebSessionStore implements WebSessionStore { @Override public Mono save() { + if (sessions.size() >= maxSessions) { + expiredSessionChecker.removeExpiredSessions(clock.instant()); + if (sessions.size() >= maxSessions) { + return Mono.error(new IllegalStateException("Max sessions limit reached: " + sessions.size())); + } + } if (!getAttributes().isEmpty()) { this.state.compareAndSet(State.NEW, State.STARTED); } @@ -207,14 +255,14 @@ public class InMemoryWebSessionStore implements WebSessionStore { @Override public boolean isExpired() { - return isExpired(Instant.now(getClock())); + return isExpired(clock.instant()); } - private boolean isExpired(Instant currentTime) { + private boolean isExpired(Instant now) { if (this.state.get().equals(State.EXPIRED)) { return true; } - if (checkExpired(currentTime)) { + if (checkExpired(now)) { this.state.set(State.EXPIRED); return true; } @@ -234,30 +282,21 @@ public class InMemoryWebSessionStore implements WebSessionStore { private class ExpiredSessionChecker { - /** Max time before next expiration checks. */ - private static final int CHECK_PERIOD = 60; - - /** Max sessions that can be created before next expiration checks. */ - private static final int SESSION_COUNT_THRESHOLD = 500; + /** Max time between expiration checks. */ + private static final int CHECK_PERIOD = 60 * 1000; private final ReentrantLock lock = new ReentrantLock(); - private Instant nextCheckTime = Instant.now(clock).plus(CHECK_PERIOD, ChronoUnit.SECONDS); - - private long lastSessionCount; + private Instant checkTime = clock.instant().plus(CHECK_PERIOD, ChronoUnit.MILLIS); public void checkIfNecessary(Instant now) { - if (howManyCreated() > SESSION_COUNT_THRESHOLD || this.nextCheckTime.isBefore(now)) { - removeExpiredSessions(Instant.now(clock)); + if (this.checkTime.isBefore(now)) { + removeExpiredSessions(now); } } - private long howManyCreated() { - return sessions.size() - this.lastSessionCount; - } - public void removeExpiredSessions(Instant now) { if (sessions.isEmpty()) { return; @@ -274,8 +313,7 @@ public class InMemoryWebSessionStore implements WebSessionStore { } } finally { - this.nextCheckTime = clock.instant().plus(CHECK_PERIOD, ChronoUnit.SECONDS); - this.lastSessionCount = sessions.size(); + this.checkTime = now.plus(CHECK_PERIOD, ChronoUnit.MILLIS); this.lock.unlock(); } } diff --git a/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java index b469fc8ac25..ee242cbbc43 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java @@ -18,7 +18,6 @@ package org.springframework.web.server.session; import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.HashMap; import java.util.Map; import java.util.stream.IntStream; @@ -28,7 +27,11 @@ import org.springframework.beans.DirectFieldAccessor; import org.springframework.web.server.WebSession; import static junit.framework.TestCase.assertSame; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * Unit tests for {@link InMemoryWebSessionStore}. @@ -93,49 +96,37 @@ public class InMemoryWebSessionStoreTests { } @Test - public void expirationCheckBasedOnTimeWindow() { + public void expirationCheckPeriod() { DirectFieldAccessor accessor = new DirectFieldAccessor(this.store); Map sessions = (Map) accessor.getPropertyValue("sessions"); + assertNotNull(sessions); // Create 100 sessions IntStream.range(0, 100).forEach(i -> insertSession()); + assertEquals(100, sessions.size()); - // Force a new clock (31 min later) but don't use setter which would clean expired sessions - Clock newClock = Clock.offset(this.store.getClock(), Duration.ofMinutes(31)); - accessor.setPropertyValue("clock", newClock); - + // Force a new clock (31 min later), don't use setter which would clean expired sessions + accessor.setPropertyValue("clock", Clock.offset(this.store.getClock(), Duration.ofMinutes(31))); assertEquals(100, sessions.size()); - // Create 50 more which forces a time-based check (clock moved forward) - IntStream.range(0, 50).forEach(i -> insertSession()); - assertEquals(50, sessions.size()); + // Create 1 more which forces a time-based check (clock moved forward) + insertSession(); + assertEquals(1, sessions.size()); } @Test - @SuppressWarnings("unchecked") - public void expirationCheckBasedOnSessionCount() { - - DirectFieldAccessor accessor = new DirectFieldAccessor(this.store); - Map sessions = (Map) accessor.getPropertyValue("sessions"); + public void maxSessions() { - // Create 100 sessions - IntStream.range(0, 100).forEach(i -> insertSession()); - - // Copy sessions (about to be expired) - Map expiredSessions = new HashMap<>(sessions); - - // Set new clock which expires and removes above sessions - this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31))); - assertEquals(0, sessions.size()); - - // Re-insert expired sessions - sessions.putAll(expiredSessions); - assertEquals(100, sessions.size()); + IntStream.range(0, 10000).forEach(i -> insertSession()); - // Create 600 more to go over the threshold - IntStream.range(0, 600).forEach(i -> insertSession()); - assertEquals(600, sessions.size()); + try { + insertSession(); + fail(); + } + catch (IllegalStateException ex) { + assertEquals("Max sessions limit reached: 10000", ex.getMessage()); + } } private WebSession insertSession() {