diff --git a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSession.java b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSession.java deleted file mode 100644 index 152c602c0ac..00000000000 --- a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSession.java +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Copyright 2002-2017 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.web.server.session; - -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; -import java.util.function.Function; - -import reactor.core.publisher.Mono; - -import org.springframework.util.Assert; -import org.springframework.util.IdGenerator; -import org.springframework.web.server.WebSession; - -/** - * Default implementation of {@link org.springframework.web.server.WebSession}. - * - * @author Rossen Stoyanchev - * @since 5.0 - */ -class DefaultWebSession implements WebSession { - - private final AtomicReference id; - - private final IdGenerator idGenerator; - - private final Map attributes; - - private final Clock clock; - - private final BiFunction> changeIdOperation; - - private final Function> saveOperation; - - private final Instant creationTime; - - private final Instant lastAccessTime; - - private volatile Duration maxIdleTime; - - private volatile State state; - - - /** - * Constructor for creating a brand, new session. - * @param idGenerator the session id generator - * @param clock for access to current time - */ - DefaultWebSession(IdGenerator idGenerator, Clock clock, - BiFunction> changeIdOperation, - Function> saveOperation) { - - Assert.notNull(idGenerator, "'idGenerator' is required."); - Assert.notNull(clock, "'clock' is required."); - Assert.notNull(changeIdOperation, "'changeIdOperation' is required."); - Assert.notNull(saveOperation, "'saveOperation' is required."); - - this.id = new AtomicReference<>(String.valueOf(idGenerator.generateId())); - this.idGenerator = idGenerator; - this.clock = clock; - this.changeIdOperation = changeIdOperation; - this.saveOperation = saveOperation; - this.attributes = new ConcurrentHashMap<>(); - this.creationTime = Instant.now(clock); - this.lastAccessTime = this.creationTime; - this.maxIdleTime = Duration.ofMinutes(30); - this.state = State.NEW; - } - - /** - * Constructor to refresh an existing session for a new request. - * @param existingSession the session to recreate - * @param lastAccessTime the last access time - * @param saveOperation save operation for the current request - */ - DefaultWebSession(DefaultWebSession existingSession, Instant lastAccessTime, - Function> saveOperation) { - - this.id = existingSession.id; - this.idGenerator = existingSession.idGenerator; - this.attributes = existingSession.attributes; - this.clock = existingSession.clock; - this.changeIdOperation = existingSession.changeIdOperation; - this.saveOperation = saveOperation; - this.creationTime = existingSession.creationTime; - this.lastAccessTime = lastAccessTime; - this.maxIdleTime = existingSession.maxIdleTime; - this.state = existingSession.state; - } - - /** - * For testing purposes. - */ - DefaultWebSession(DefaultWebSession existingSession, Instant lastAccessTime) { - this.id = existingSession.id; - this.idGenerator = existingSession.idGenerator; - this.attributes = existingSession.attributes; - this.clock = existingSession.clock; - this.changeIdOperation = existingSession.changeIdOperation; - this.saveOperation = existingSession.saveOperation; - this.creationTime = existingSession.creationTime; - this.lastAccessTime = lastAccessTime; - this.maxIdleTime = existingSession.maxIdleTime; - this.state = existingSession.state; - } - - - @Override - public String getId() { - return this.id.get(); - } - - @Override - public Map getAttributes() { - return this.attributes; - } - - @Override - public Instant getCreationTime() { - return this.creationTime; - } - - @Override - public Instant getLastAccessTime() { - return this.lastAccessTime; - } - - /** - *

By default this is set to 30 minutes. - * @param maxIdleTime the max idle time - */ - @Override - public void setMaxIdleTime(Duration maxIdleTime) { - this.maxIdleTime = maxIdleTime; - } - - @Override - public Duration getMaxIdleTime() { - return this.maxIdleTime; - } - - - @Override - public void start() { - this.state = State.STARTED; - } - - @Override - public boolean isStarted() { - State value = this.state; - return (State.STARTED.equals(value) || (State.NEW.equals(value) && !getAttributes().isEmpty())); - } - - @Override - public Mono changeSessionId() { - String oldId = this.id.get(); - String newId = String.valueOf(this.idGenerator.generateId()); - this.id.set(newId); - return this.changeIdOperation.apply(oldId, this).doOnError(ex -> this.id.set(oldId)); - } - - @Override - public Mono save() { - return this.saveOperation.apply(this); - } - - @Override - public boolean isExpired() { - return (isStarted() && !this.maxIdleTime.isNegative() && - Instant.now(this.clock).minus(this.maxIdleTime).isAfter(this.lastAccessTime)); - } - - - private enum State { NEW, STARTED } - -} diff --git a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java index e9502090701..fa501205541 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java @@ -15,17 +15,12 @@ */ package org.springframework.web.server.session; -import java.time.Clock; -import java.time.Instant; -import java.time.ZoneId; import java.util.List; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.util.Assert; -import org.springframework.util.IdGenerator; -import org.springframework.util.JdkIdGenerator; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; @@ -36,19 +31,15 @@ import org.springframework.web.server.WebSession; * {@link WebSessionStore} * * @author Rossen Stoyanchev + * @author Rob Winch * @since 5.0 */ public class DefaultWebSessionManager implements WebSessionManager { - private static final IdGenerator idGenerator = new JdkIdGenerator(); - - private WebSessionIdResolver sessionIdResolver = new CookieWebSessionIdResolver(); private WebSessionStore sessionStore = new InMemoryWebSessionStore(); - private Clock clock = Clock.system(ZoneId.of("GMT")); - /** * Configure the id resolution strategy. @@ -84,49 +75,24 @@ public class DefaultWebSessionManager implements WebSessionManager { return this.sessionStore; } - /** - * Configure the {@link Clock} to use to set lastAccessTime on every created - * session and to calculate if it is expired. - *

This may be useful to align to different timezone or to set the clock - * back in a test, e.g. {@code Clock.offset(clock, Duration.ofMinutes(-31))} - * in order to simulate session expiration. - *

By default this is {@code Clock.system(ZoneId.of("GMT"))}. - * @param clock the clock to use - */ - public void setClock(Clock clock) { - Assert.notNull(clock, "'clock' is required."); - this.clock = clock; - } - - /** - * Return the configured clock for session lastAccessTime calculations. - */ - public Clock getClock() { - return this.clock; - } - @Override public Mono getSession(ServerWebExchange exchange) { return Mono.defer(() -> retrieveSession(exchange) .flatMap(session -> removeSessionIfExpired(exchange, session)) - .map(session -> { - Instant lastAccessTime = Instant.now(getClock()); - return new DefaultWebSession(session, lastAccessTime, s -> saveSession(exchange, s)); - }) - .switchIfEmpty(createSession(exchange)) - .doOnNext(session -> exchange.getResponse().beforeCommit(session::save))); + .flatMap(this.getSessionStore()::updateLastAccessTime) + .switchIfEmpty(this.sessionStore.createWebSession()) + .doOnNext(session -> exchange.getResponse().beforeCommit(() -> save(exchange, session)))); } - private Mono retrieveSession(ServerWebExchange exchange) { + private Mono retrieveSession(ServerWebExchange exchange) { return Flux.fromIterable(getSessionIdResolver().resolveSessionIds(exchange)) .concatMap(this.sessionStore::retrieveSession) - .cast(DefaultWebSession.class) .next(); } - private Mono removeSessionIfExpired(ServerWebExchange exchange, DefaultWebSession session) { + private Mono removeSessionIfExpired(ServerWebExchange exchange, WebSession session) { if (session.isExpired()) { this.sessionIdResolver.expireSession(exchange); return this.sessionStore.removeSession(session.getId()).then(Mono.empty()); @@ -134,39 +100,28 @@ public class DefaultWebSessionManager implements WebSessionManager { return Mono.just(session); } - private Mono saveSession(ServerWebExchange exchange, WebSession session) { + private Mono save(ServerWebExchange exchange, WebSession session) { if (session.isExpired()) { return Mono.error(new IllegalStateException( "Sessions are checked for expiration and have their " + - "lastAccessTime updated when first accessed during request processing. " + - "However this session is expired meaning that maxIdleTime elapsed " + - "before the call to session.save().")); + "lastAccessTime updated when first accessed during request processing. " + + "However this session is expired meaning that maxIdleTime elapsed " + + "before the call to session.save().")); } if (!session.isStarted()) { return Mono.empty(); } - // Force explicit start - session.start(); - if (hasNewSessionId(exchange, session)) { - this.sessionIdResolver.setSessionId(exchange, session.getId()); + DefaultWebSessionManager.this.sessionIdResolver.setSessionId(exchange, session.getId()); } - return this.sessionStore.storeSession(session); + return session.save(); } private boolean hasNewSessionId(ServerWebExchange exchange, WebSession session) { List ids = getSessionIdResolver().resolveSessionIds(exchange); return ids.isEmpty() || !session.getId().equals(ids.get(0)); } - - private Mono createSession(ServerWebExchange exchange) { - return Mono.fromSupplier(() -> - new DefaultWebSession(idGenerator, getClock(), - (oldId, session) -> this.sessionStore.changeSessionId(oldId, session), - session -> saveSession(exchange, session))); - } - } diff --git a/spring-web/src/main/java/org/springframework/web/server/session/HeaderWebSessionIdResolver.java b/spring-web/src/main/java/org/springframework/web/server/session/HeaderWebSessionIdResolver.java index 3845e293b54..54dafb43cee 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/HeaderWebSessionIdResolver.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/HeaderWebSessionIdResolver.java @@ -26,18 +26,23 @@ import org.springframework.web.server.ServerWebExchange; * Request and response header-based {@link WebSessionIdResolver}. * * @author Greg Turnquist + * @author Rob Winch * @since 5.0 */ public class HeaderWebSessionIdResolver implements WebSessionIdResolver { - private String headerName = "SESSION"; + /** Default value for {@link #setHeaderName(String)}. */ + public static final String DEFAULT_HEADER_NAME = "SESSION"; + + + private String headerName = DEFAULT_HEADER_NAME; /** * Set the name of the session header to use for the session id. * The name is used to extract the session id from the request headers as * well to set the session id on the response headers. - *

By default set to {@literal "SESSION"}. + *

By default set to {@code DEFAULT_HEADER_NAME} * @param headerName the header name */ public void setHeaderName(String headerName) { @@ -47,6 +52,7 @@ public class HeaderWebSessionIdResolver implements WebSessionIdResolver { /** * Return the configured header name. + * @return the configured header name */ public String getHeaderName() { return this.headerName; diff --git a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java index dd193faf5ef..3b833490f9e 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java @@ -15,28 +15,63 @@ */ package org.springframework.web.server.session; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; import reactor.core.publisher.Mono; +import org.springframework.util.Assert; +import org.springframework.util.IdGenerator; +import org.springframework.util.JdkIdGenerator; import org.springframework.web.server.WebSession; /** * Simple Map-based storage for {@link WebSession} instances. * * @author Rossen Stoyanchev + * @author Rob Winch * @since 5.0 */ public class InMemoryWebSessionStore implements WebSessionStore { + private static final IdGenerator idGenerator = new JdkIdGenerator(); + + + private Clock clock = Clock.system(ZoneId.of("GMT")); + private final Map sessions = new ConcurrentHashMap<>(); + /** + * Configure the {@link Clock} to use to set lastAccessTime on every created + * session and to calculate if it is expired. + *

This may be useful to align to different timezone or to set the clock + * back in a test, e.g. {@code Clock.offset(clock, Duration.ofMinutes(-31))} + * in order to simulate session expiration. + *

By default this is {@code Clock.system(ZoneId.of("GMT"))}. + * @param clock the clock to use + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "'clock' is required."); + this.clock = clock; + } + + /** + * Return the configured clock for session lastAccessTime calculations. + */ + public Clock getClock() { + return this.clock; + } + + @Override - public Mono storeSession(WebSession session) { - this.sessions.put(session.getId(), session); - return Mono.empty(); + public Mono createWebSession() { + return Mono.fromSupplier(InMemoryWebSession::new); } @Override @@ -45,16 +80,125 @@ public class InMemoryWebSessionStore implements WebSessionStore { } @Override - public Mono changeSessionId(String oldId, WebSession session) { + public Mono removeSession(String id) { + this.sessions.remove(id); + return Mono.empty(); + } + + public Mono updateLastAccessTime(WebSession webSession) { + return Mono.fromSupplier(() -> { + InMemoryWebSession session = (InMemoryWebSession) webSession; + Instant lastAccessTime = Instant.now(getClock()); + return new InMemoryWebSession(session, lastAccessTime); + }); + } + + /* Private methods for InMemoryWebSession */ + + private Mono changeSessionId(String oldId, WebSession session) { this.sessions.remove(oldId); this.sessions.put(session.getId(), session); return Mono.empty(); } - @Override - public Mono removeSession(String id) { - this.sessions.remove(id); + private Mono storeSession(WebSession session) { + this.sessions.put(session.getId(), session); return Mono.empty(); } + + private class InMemoryWebSession implements WebSession { + + private final AtomicReference id; + + private final Map attributes; + + private final Instant creationTime; + + private final Instant lastAccessTime; + + private volatile Duration maxIdleTime; + + private volatile boolean started; + + + InMemoryWebSession() { + this.id = new AtomicReference<>(String.valueOf(idGenerator.generateId())); + this.attributes = new ConcurrentHashMap<>(); + this.creationTime = Instant.now(getClock()); + this.lastAccessTime = this.creationTime; + this.maxIdleTime = Duration.ofMinutes(30); + } + + InMemoryWebSession(InMemoryWebSession existingSession, Instant lastAccessTime) { + this.id = existingSession.id; + this.attributes = existingSession.attributes; + this.creationTime = existingSession.creationTime; + this.lastAccessTime = lastAccessTime; + this.maxIdleTime = existingSession.maxIdleTime; + this.started = existingSession.isStarted(); // Use method (explicit or implicit start) + } + + + @Override + public String getId() { + return this.id.get(); + } + + @Override + public Map getAttributes() { + return this.attributes; + } + + @Override + public Instant getCreationTime() { + return this.creationTime; + } + + @Override + public Instant getLastAccessTime() { + return this.lastAccessTime; + } + + @Override + public void setMaxIdleTime(Duration maxIdleTime) { + this.maxIdleTime = maxIdleTime; + } + + @Override + public Duration getMaxIdleTime() { + return this.maxIdleTime; + } + + + @Override + public void start() { + this.started = true; + } + + @Override + public boolean isStarted() { + return this.started || !getAttributes().isEmpty(); + } + + @Override + public Mono changeSessionId() { + String oldId = this.id.get(); + String newId = String.valueOf(idGenerator.generateId()); + this.id.set(newId); + return InMemoryWebSessionStore.this.changeSessionId(oldId, this).doOnError(ex -> this.id.set(oldId)); + } + + @Override + public Mono save() { + return InMemoryWebSessionStore.this.storeSession(this); + } + + @Override + public boolean isExpired() { + return (isStarted() && !this.maxIdleTime.isNegative() && + Instant.now(getClock()).minus(this.maxIdleTime).isAfter(this.lastAccessTime)); + } + } + } diff --git a/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java index b52cafe5469..4ff61176b39 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java @@ -23,16 +23,20 @@ import org.springframework.web.server.WebSession; * Strategy for {@link WebSession} persistence. * * @author Rossen Stoyanchev + * @author Rob Winch * @since 5.0 */ public interface WebSessionStore { /** - * Store the given WebSession. - * @param session the session to store - * @return a completion notification (success or error) + * Create a new WebSession. + *

Note that this does nothing more than create a new instance. + * The session can later be started explicitly via {@link WebSession#start()} + * or implicitly by adding attributes -- and then persisted via + * {@link WebSession#save()}. + * @return the created session instance */ - Mono storeSession(WebSession session); + Mono createWebSession(); /** * Return the WebSession for the given id. @@ -41,18 +45,6 @@ public interface WebSessionStore { */ Mono retrieveSession(String sessionId); - /** - * Update WebSession data storage to reflect a change in session id. - *

Note that the same can be achieved via a combination of - * {@link #removeSession} + {@link #storeSession}. The purpose of this method - * is to allow a more efficient replacement of the session id mapping - * without replacing and storing the session with all of its data. - * @param oldId the previous session id - * @param session the session reflecting the changed session id - * @return completion notification (success or error) - */ - Mono changeSessionId(String oldId, WebSession session); - /** * Remove the WebSession for the specified id. * @param sessionId the id of the session to remove @@ -60,4 +52,10 @@ public interface WebSessionStore { */ Mono removeSession(String sessionId); + /** + * Update the last accessed timestamp to "now". + * @param webSession the session to update + * @return the session with the updated last access time + */ + Mono updateLastAccessTime(WebSession webSession); } diff --git a/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java b/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java index d05cd7fd042..3bfe7ccf5a0 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java @@ -15,25 +15,19 @@ */ package org.springframework.web.server.session; -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.time.ZoneId; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.List; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; import reactor.core.publisher.Mono; import org.springframework.http.codec.ServerCodecConfigurer; -import org.springframework.lang.Nullable; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; -import org.springframework.util.IdGenerator; -import org.springframework.util.JdkIdGenerator; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; import org.springframework.web.server.adapter.DefaultServerWebExchange; @@ -42,33 +36,52 @@ import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; /** * Unit tests for {@link DefaultWebSessionManager}. * @author Rossen Stoyanchev + * @author Rob Winch */ +@RunWith(MockitoJUnitRunner.class) public class DefaultWebSessionManagerTests { - private static final Clock CLOCK = Clock.system(ZoneId.of("GMT")); + private DefaultWebSessionManager manager; - private static final IdGenerator idGenerator = new JdkIdGenerator(); + private ServerWebExchange exchange; + @Mock + private WebSessionIdResolver idResolver; - private DefaultWebSessionManager manager; + @Mock + private WebSessionStore store; - private TestWebSessionIdResolver idResolver; + @Mock + private WebSession createSession; - private ServerWebExchange exchange; + @Mock + private WebSession retrieveSession; + @Mock + private WebSession updateSession; @Before public void setUp() throws Exception { + when(this.store.createWebSession()).thenReturn(Mono.just(this.createSession)); + when(this.store.updateLastAccessTime(any())).thenReturn(Mono.just(this.updateSession)); + when(this.store.retrieveSession(any())).thenReturn(Mono.just(this.retrieveSession)); + when(this.createSession.save()).thenReturn(Mono.empty()); + when(this.updateSession.getId()).thenReturn("update-session-id"); + when(this.retrieveSession.getId()).thenReturn("retrieve-session-id"); + this.manager = new DefaultWebSessionManager(); - this.idResolver = new TestWebSessionIdResolver(); this.manager.setSessionIdResolver(this.idResolver); + this.manager.setSessionStore(this.store); MockServerHttpRequest request = MockServerHttpRequest.get("/path").build(); MockServerHttpResponse response = new MockServerHttpResponse(); @@ -76,115 +89,78 @@ public class DefaultWebSessionManagerTests { ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); } - @Test - public void getSessionWithoutStarting() throws Exception { - this.idResolver.setIdsToResolve(Collections.emptyList()); + public void getSessionSaveWhenCreatedAndNotStartedThenNotSaved() throws Exception { + when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.emptyList()); WebSession session = this.manager.getSession(this.exchange).block(); - session.save(); + this.exchange.getResponse().setComplete().block(); assertFalse(session.isStarted()); assertFalse(session.isExpired()); - assertNull(this.idResolver.getSavedId()); - assertNull(this.manager.getSessionStore().retrieveSession(session.getId()).block()); + verifyZeroInteractions(this.retrieveSession, this.updateSession); + verify(this.createSession, never()).save(); + verify(this.idResolver, never()).setSessionId(any(), any()); } @Test - public void startSessionExplicitly() throws Exception { - this.idResolver.setIdsToResolve(Collections.emptyList()); + public void getSessionSaveWhenCreatedAndStartedThenSavesAndSetsId() throws Exception { + when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.emptyList()); WebSession session = this.manager.getSession(this.exchange).block(); - session.start(); - session.save(); + when(this.createSession.isStarted()).thenReturn(true); + this.exchange.getResponse().setComplete().block(); String id = session.getId(); - assertNotNull(this.idResolver.getSavedId()); - assertEquals(id, this.idResolver.getSavedId()); - assertSame(session, this.manager.getSessionStore().retrieveSession(id).block()); + verify(this.store).createWebSession(); + verify(this.createSession).save(); + verify(this.idResolver).setSessionId(any(), eq(id)); } @Test - public void startSessionImplicitly() throws Exception { - this.idResolver.setIdsToResolve(Collections.emptyList()); + public void exchangeWhenResponseSetCompleteThenSavesAndSetsId() throws Exception { + when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.emptyList()); + String id = this.createSession.getId(); WebSession session = this.manager.getSession(this.exchange).block(); - session.getAttributes().put("foo", "bar"); - session.save(); + when(this.createSession.isStarted()).thenReturn(true); + this.exchange.getResponse().setComplete().block(); - assertNotNull(this.idResolver.getSavedId()); + verify(this.idResolver).setSessionId(any(), eq(id)); + verify(this.createSession).save(); } @Test public void existingSession() throws Exception { - DefaultWebSession existing = createDefaultWebSession(); - String id = existing.getId(); - this.manager.getSessionStore().storeSession(existing); - this.idResolver.setIdsToResolve(Collections.singletonList(id)); + String id = this.updateSession.getId(); + when(this.store.retrieveSession(id)).thenReturn(Mono.just(this.updateSession)); + when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.singletonList(id)); WebSession actual = this.manager.getSession(this.exchange).block(); assertNotNull(actual); - assertEquals(existing.getId(), actual.getId()); + assertEquals(id, actual.getId()); } @Test public void existingSessionIsExpired() throws Exception { - DefaultWebSession existing = createDefaultWebSession(); - existing.start(); - Instant lastAccessTime = Instant.now(CLOCK).minus(Duration.ofMinutes(31)); - existing = new DefaultWebSession(existing, lastAccessTime, s -> Mono.empty()); - this.manager.getSessionStore().storeSession(existing); - this.idResolver.setIdsToResolve(Collections.singletonList("1")); + String id = this.retrieveSession.getId(); + when(this.retrieveSession.isExpired()).thenReturn(true); + when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.singletonList(id)); + when(this.store.removeSession(any())).thenReturn(Mono.empty()); WebSession actual = this.manager.getSession(this.exchange).block(); - assertNotSame(existing, actual); + assertEquals(this.createSession.getId(), actual.getId()); + verify(this.store).removeSession(id); + verify(this.idResolver).expireSession(any()); } @Test public void multipleSessionIds() throws Exception { - DefaultWebSession existing = createDefaultWebSession(); + WebSession existing = this.updateSession; String id = existing.getId(); - this.manager.getSessionStore().storeSession(existing); - this.idResolver.setIdsToResolve(Arrays.asList("neither-this", "nor-that", id)); + when(this.store.retrieveSession(any())).thenReturn(Mono.empty()); + when(this.store.retrieveSession(id)).thenReturn(Mono.just(existing)); + when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Arrays.asList("neither-this", "nor-that", id)); WebSession actual = this.manager.getSession(this.exchange).block(); assertNotNull(actual); assertEquals(existing.getId(), actual.getId()); } - - private DefaultWebSession createDefaultWebSession() { - return new DefaultWebSession(idGenerator, CLOCK, (s, session) -> Mono.empty(), s -> Mono.empty()); - } - - - private static class TestWebSessionIdResolver implements WebSessionIdResolver { - - private List idsToResolve = new ArrayList<>(); - - @Nullable - private String id = null; - - - public void setIdsToResolve(List idsToResolve) { - this.idsToResolve = idsToResolve; - } - - @Nullable - public String getSavedId() { - return this.id; - } - - @Override - public List resolveSessionIds(ServerWebExchange exchange) { - return this.idsToResolve; - } - - @Override - public void setSessionId(ServerWebExchange exchange, String sessionId) { - this.id = sessionId; - } - - @Override - public void expireSession(ServerWebExchange exchange) { - this.id = null; - } - } - } diff --git a/spring-web/src/test/java/org/springframework/web/server/session/HeaderWebSessionIdResolverTests.java b/spring-web/src/test/java/org/springframework/web/server/session/HeaderWebSessionIdResolverTests.java index d0d4a7117a0..ad31462a56e 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/HeaderWebSessionIdResolverTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/HeaderWebSessionIdResolverTests.java @@ -15,170 +15,131 @@ */ package org.springframework.web.server.session; -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.time.ZoneId; -import java.util.UUID; - import org.junit.Before; import org.junit.Test; -import reactor.core.publisher.Mono; - -import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; -import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.web.server.ServerWebExchange; -import org.springframework.web.server.WebSession; -import org.springframework.web.server.adapter.DefaultServerWebExchange; -import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; -import static org.hamcrest.collection.IsCollectionWithSize.hasSize; -import static org.hamcrest.core.Is.is; -import static org.hamcrest.core.IsCollectionContaining.hasItem; +import java.util.Arrays; +import java.util.List; + import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; /** * Tests using {@link HeaderWebSessionIdResolver}. * * @author Greg Turnquist + * @author Rob Winch */ public class HeaderWebSessionIdResolverTests { - - private static final Clock CLOCK = Clock.system(ZoneId.of("GMT")); - - private HeaderWebSessionIdResolver idResolver; - private DefaultWebSessionManager manager; - private ServerWebExchange exchange; - @Before public void setUp() { this.idResolver = new HeaderWebSessionIdResolver(); - this.manager = new DefaultWebSessionManager(); - this.manager.setSessionIdResolver(this.idResolver); + this.exchange = MockServerHttpRequest.get("/path").toExchange(); + } - MockServerHttpRequest request = MockServerHttpRequest.get("/path").build(); - MockServerHttpResponse response = new MockServerHttpResponse(); + @Test + public void expireWhenValidThenSetsEmptyHeader() { + this.idResolver.expireSession(this.exchange); - this.exchange = new DefaultServerWebExchange(request, response, this.manager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + assertEquals(Arrays.asList(""), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); } @Test - public void getSessionWithoutStarting() throws Exception { - WebSession session = this.manager.getSession(this.exchange).block(); - session.save(); + public void expireWhenMultipleInvocationThenSetsSingleEmptyHeader() { + this.idResolver.expireSession(this.exchange); + + this.idResolver.expireSession(this.exchange); - assertFalse(session.isStarted()); - assertFalse(session.isExpired()); - assertNull(this.manager.getSessionStore().retrieveSession(session.getId()).block()); + assertEquals(Arrays.asList(""), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); } @Test - public void startSessionExplicitly() throws Exception { - WebSession session = this.manager.getSession(this.exchange).block(); - session.start(); - session.save().block(); - - assertThat(this.exchange.getResponse().getHeaders().containsKey("SESSION"), is(true)); - assertThat(this.exchange.getResponse().getHeaders().get("SESSION"), hasSize(1)); - assertThat(this.exchange.getResponse().getHeaders().get("SESSION"), hasItem(session.getId())); + public void expireWhenAfterSetSessionIdThenSetsEmptyHeader() { + this.idResolver.setSessionId(this.exchange, "123"); + + this.idResolver.expireSession(this.exchange); + + assertEquals(Arrays.asList(""), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); } @Test - public void startSessionImplicitly() throws Exception { - WebSession session = this.manager.getSession(this.exchange).block(); - session.getAttributes().put("foo", "bar"); - session.save(); + public void setSessionIdWhenValidThenSetsHeader() { + String id = "123"; + + this.idResolver.setSessionId(this.exchange, id); - assertNotNull(this.exchange.getResponse().getHeaders().get("SESSION")); + assertEquals(Arrays.asList(id), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); } @Test - public void existingSession() throws Exception { - UUID sessionId = UUID.randomUUID(); - DefaultWebSession existing = createDefaultWebSession(sessionId); - this.manager.getSessionStore().storeSession(existing); - - this.exchange = this.exchange.mutate() - .request(this.exchange.getRequest().mutate() - .header("SESSION", sessionId.toString()) - .build()) - .build(); - - WebSession actual = this.manager.getSession(this.exchange).block(); - assertNotNull(actual); - assertEquals(existing.getId(), actual.getId()); + public void setSessionIdWhenMultipleThenSetsSingleHeader() { + String id = "123"; + this.idResolver.setSessionId(this.exchange, "overriddenByNextInvocation"); + + this.idResolver.setSessionId(this.exchange, id); + + assertEquals(Arrays.asList(id), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); } @Test - public void existingSessionIsExpired() throws Exception { - UUID sessionId = UUID.randomUUID(); - DefaultWebSession existing = createDefaultWebSession(sessionId); - existing.start(); - Instant lastAccessTime = Instant.now(CLOCK).minus(Duration.ofMinutes(31)); - existing = new DefaultWebSession(existing, lastAccessTime, s -> Mono.empty()); - this.manager.getSessionStore().storeSession(existing); - - this.exchange = this.exchange.mutate() - .request(this.exchange.getRequest().mutate() - .header("SESSION", sessionId.toString()) - .build()) - .build(); - - WebSession actual = this.manager.getSession(this.exchange).block(); - assertNotSame(existing, actual); + public void setSessionIdWhenCustomHeaderNameThenSetsHeader() { + String headerName = "x-auth"; + String id = "123"; + this.idResolver.setHeaderName(headerName); + + this.idResolver.setSessionId(this.exchange, id); + + assertEquals(Arrays.asList(id), + this.exchange.getResponse().getHeaders().get(headerName)); } - @Test - public void multipleSessionIds() throws Exception { - UUID sessionId = UUID.randomUUID(); - DefaultWebSession existing = createDefaultWebSession(sessionId); - this.manager.getSessionStore().storeSession(existing); - this.manager.getSessionStore().storeSession(createDefaultWebSession(UUID.randomUUID())); - this.manager.getSessionStore().storeSession(createDefaultWebSession(UUID.randomUUID())); - - this.exchange = this.exchange.mutate() - .request(this.exchange.getRequest().mutate() - .header("SESSION", sessionId.toString()) - .build()) - .build(); - - WebSession actual = this.manager.getSession(this.exchange).block(); - assertNotNull(actual); - assertEquals(existing.getId(), actual.getId()); + @Test(expected = IllegalArgumentException.class) + public void setSessionIdWhenNullIdThenIllegalArgumentException() { + String id = null; + + this.idResolver.setSessionId(this.exchange, id); } @Test - public void alternateHeaderName() throws Exception { - this.idResolver.setHeaderName("alternateHeaderName"); - - UUID sessionId = UUID.randomUUID(); - DefaultWebSession existing = createDefaultWebSession(sessionId); - this.manager.getSessionStore().storeSession(existing); - - this.exchange = this.exchange.mutate() - .request(this.exchange.getRequest().mutate() - .header("alternateHeaderName", sessionId.toString()) - .build()) - .build(); - - WebSession actual = this.manager.getSession(this.exchange).block(); - assertNotNull(actual); - assertEquals(existing.getId(), actual.getId()); + public void resolveSessionIdsWhenNoIdsThenEmpty() { + List ids = this.idResolver.resolveSessionIds(this.exchange); + + assertTrue(ids.isEmpty()); } - private DefaultWebSession createDefaultWebSession(UUID sessionId) { - return new DefaultWebSession(() -> sessionId, CLOCK, (s, session) -> Mono.empty(), s -> Mono.empty()); + @Test + public void resolveSessionIdsWhenIdThenIdFound() { + String id = "123"; + this.exchange = MockServerHttpRequest.get("/path") + .header(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME, id) + .toExchange(); + + List ids = this.idResolver.resolveSessionIds(this.exchange); + + assertEquals(Arrays.asList(id), ids); } + @Test + public void resolveSessionIdsWhenMultipleIdsThenIdsFound() { + String id1 = "123"; + String id2 = "abc"; + this.exchange = MockServerHttpRequest.get("/path") + .header(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME, id1, id2) + .toExchange(); + + List ids = this.idResolver.resolveSessionIds(this.exchange); + + assertEquals(Arrays.asList(id1, id2), ids); + } } diff --git a/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java new file mode 100644 index 00000000000..efc92a315ae --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.server.session; + +import org.junit.Test; + +import org.springframework.web.server.WebSession; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests. + * @author Rob Winch + */ +public class InMemoryWebSessionStoreTests { + + private InMemoryWebSessionStore sessionStore = new InMemoryWebSessionStore(); + + + @Test + public void constructorWhenImplicitStartCopiedThenCopyIsStarted() { + WebSession original = this.sessionStore.createWebSession().block(); + assertNotNull(original); + original.getAttributes().put("foo", "bar"); + + WebSession copy = this.sessionStore.updateLastAccessTime(original).block(); + assertNotNull(copy); + assertTrue(copy.isStarted()); + } + + @Test + public void constructorWhenExplicitStartCopiedThenCopyIsStarted() { + WebSession original = this.sessionStore.createWebSession().block(); + assertNotNull(original); + original.start(); + + WebSession copy = this.sessionStore.updateLastAccessTime(original).block(); + assertNotNull(copy); + assertTrue(copy.isStarted()); + } + + @Test + public void startsSessionExplicitly() { + WebSession session = this.sessionStore.createWebSession().block(); + assertNotNull(session); + session.start(); + assertTrue(session.isStarted()); + } + + @Test + public void startsSessionImplicitly() { + WebSession session = this.sessionStore.createWebSession().block(); + assertNotNull(session); + session.start(); + session.getAttributes().put("foo", "bar"); + assertTrue(session.isStarted()); + } + +} \ No newline at end of file diff --git a/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java index 705044447f8..48958d75ada 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java @@ -110,12 +110,10 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe assertEquals(2, this.handler.getSessionRequestCount()); // Now set the clock of the session back by 31 minutes - WebSessionStore store = this.sessionManager.getSessionStore(); - DefaultWebSession session = (DefaultWebSession) store.retrieveSession(id).block(); + InMemoryWebSessionStore store = (InMemoryWebSessionStore) this.sessionManager.getSessionStore(); + WebSession session = store.retrieveSession(id).block(); assertNotNull(session); - Instant lastAccessTime = Clock.offset(this.sessionManager.getClock(), Duration.ofMinutes(-31)).instant(); - session = new DefaultWebSession(session, lastAccessTime); - store.storeSession(session); + store.setClock(Clock.offset(store.getClock(), Duration.ofMinutes(31))); // Third request: expired session, new session created request = RequestEntity.get(createUri()).header("Cookie", "SESSION=" + id).build();