diff --git a/spring-web/src/main/java/org/springframework/web/server/session/HeaderWebSessionIdResolver.java b/spring-web/src/main/java/org/springframework/web/server/session/HeaderWebSessionIdResolver.java index 3845e293b54..9ee5cffd614 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/HeaderWebSessionIdResolver.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/HeaderWebSessionIdResolver.java @@ -26,18 +26,23 @@ import org.springframework.web.server.ServerWebExchange; * Request and response header-based {@link WebSessionIdResolver}. * * @author Greg Turnquist + * @author Rob Winch * @since 5.0 */ public class HeaderWebSessionIdResolver implements WebSessionIdResolver { - private String headerName = "SESSION"; + /** + * The default header name + */ + public static final String DEFAULT_HEADER_NAME = "SESSION"; + private String headerName = DEFAULT_HEADER_NAME; /** * Set the name of the session header to use for the session id. * The name is used to extract the session id from the request headers as * well to set the session id on the response headers. - *

By default set to {@literal "SESSION"}. + *

By default set to {@code DEFAULT_HEADER_NAME} * @param headerName the header name */ public void setHeaderName(String headerName) { @@ -45,14 +50,6 @@ public class HeaderWebSessionIdResolver implements WebSessionIdResolver { this.headerName = headerName; } - /** - * Return the configured header name. - */ - public String getHeaderName() { - return this.headerName; - } - - @Override public List resolveSessionIds(ServerWebExchange exchange) { HttpHeaders headers = exchange.getRequest().getHeaders(); @@ -70,4 +67,11 @@ public class HeaderWebSessionIdResolver implements WebSessionIdResolver { this.setSessionId(exchange, ""); } + /** + * Return the configured header name. + * @return the configured header name + */ + private String getHeaderName() { + return this.headerName; + } } diff --git a/spring-web/src/test/java/org/springframework/web/server/session/HeaderWebSessionIdResolverTests.java b/spring-web/src/test/java/org/springframework/web/server/session/HeaderWebSessionIdResolverTests.java index d0d4a7117a0..ad31462a56e 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/HeaderWebSessionIdResolverTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/HeaderWebSessionIdResolverTests.java @@ -15,170 +15,131 @@ */ package org.springframework.web.server.session; -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.time.ZoneId; -import java.util.UUID; - import org.junit.Before; import org.junit.Test; -import reactor.core.publisher.Mono; - -import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; -import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.web.server.ServerWebExchange; -import org.springframework.web.server.WebSession; -import org.springframework.web.server.adapter.DefaultServerWebExchange; -import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; -import static org.hamcrest.collection.IsCollectionWithSize.hasSize; -import static org.hamcrest.core.Is.is; -import static org.hamcrest.core.IsCollectionContaining.hasItem; +import java.util.Arrays; +import java.util.List; + import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; /** * Tests using {@link HeaderWebSessionIdResolver}. * * @author Greg Turnquist + * @author Rob Winch */ public class HeaderWebSessionIdResolverTests { - - private static final Clock CLOCK = Clock.system(ZoneId.of("GMT")); - - private HeaderWebSessionIdResolver idResolver; - private DefaultWebSessionManager manager; - private ServerWebExchange exchange; - @Before public void setUp() { this.idResolver = new HeaderWebSessionIdResolver(); - this.manager = new DefaultWebSessionManager(); - this.manager.setSessionIdResolver(this.idResolver); + this.exchange = MockServerHttpRequest.get("/path").toExchange(); + } - MockServerHttpRequest request = MockServerHttpRequest.get("/path").build(); - MockServerHttpResponse response = new MockServerHttpResponse(); + @Test + public void expireWhenValidThenSetsEmptyHeader() { + this.idResolver.expireSession(this.exchange); - this.exchange = new DefaultServerWebExchange(request, response, this.manager, - ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); + assertEquals(Arrays.asList(""), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); } @Test - public void getSessionWithoutStarting() throws Exception { - WebSession session = this.manager.getSession(this.exchange).block(); - session.save(); + public void expireWhenMultipleInvocationThenSetsSingleEmptyHeader() { + this.idResolver.expireSession(this.exchange); + + this.idResolver.expireSession(this.exchange); - assertFalse(session.isStarted()); - assertFalse(session.isExpired()); - assertNull(this.manager.getSessionStore().retrieveSession(session.getId()).block()); + assertEquals(Arrays.asList(""), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); } @Test - public void startSessionExplicitly() throws Exception { - WebSession session = this.manager.getSession(this.exchange).block(); - session.start(); - session.save().block(); - - assertThat(this.exchange.getResponse().getHeaders().containsKey("SESSION"), is(true)); - assertThat(this.exchange.getResponse().getHeaders().get("SESSION"), hasSize(1)); - assertThat(this.exchange.getResponse().getHeaders().get("SESSION"), hasItem(session.getId())); + public void expireWhenAfterSetSessionIdThenSetsEmptyHeader() { + this.idResolver.setSessionId(this.exchange, "123"); + + this.idResolver.expireSession(this.exchange); + + assertEquals(Arrays.asList(""), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); } @Test - public void startSessionImplicitly() throws Exception { - WebSession session = this.manager.getSession(this.exchange).block(); - session.getAttributes().put("foo", "bar"); - session.save(); + public void setSessionIdWhenValidThenSetsHeader() { + String id = "123"; + + this.idResolver.setSessionId(this.exchange, id); - assertNotNull(this.exchange.getResponse().getHeaders().get("SESSION")); + assertEquals(Arrays.asList(id), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); } @Test - public void existingSession() throws Exception { - UUID sessionId = UUID.randomUUID(); - DefaultWebSession existing = createDefaultWebSession(sessionId); - this.manager.getSessionStore().storeSession(existing); - - this.exchange = this.exchange.mutate() - .request(this.exchange.getRequest().mutate() - .header("SESSION", sessionId.toString()) - .build()) - .build(); - - WebSession actual = this.manager.getSession(this.exchange).block(); - assertNotNull(actual); - assertEquals(existing.getId(), actual.getId()); + public void setSessionIdWhenMultipleThenSetsSingleHeader() { + String id = "123"; + this.idResolver.setSessionId(this.exchange, "overriddenByNextInvocation"); + + this.idResolver.setSessionId(this.exchange, id); + + assertEquals(Arrays.asList(id), + this.exchange.getResponse().getHeaders().get(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME)); } @Test - public void existingSessionIsExpired() throws Exception { - UUID sessionId = UUID.randomUUID(); - DefaultWebSession existing = createDefaultWebSession(sessionId); - existing.start(); - Instant lastAccessTime = Instant.now(CLOCK).minus(Duration.ofMinutes(31)); - existing = new DefaultWebSession(existing, lastAccessTime, s -> Mono.empty()); - this.manager.getSessionStore().storeSession(existing); - - this.exchange = this.exchange.mutate() - .request(this.exchange.getRequest().mutate() - .header("SESSION", sessionId.toString()) - .build()) - .build(); - - WebSession actual = this.manager.getSession(this.exchange).block(); - assertNotSame(existing, actual); + public void setSessionIdWhenCustomHeaderNameThenSetsHeader() { + String headerName = "x-auth"; + String id = "123"; + this.idResolver.setHeaderName(headerName); + + this.idResolver.setSessionId(this.exchange, id); + + assertEquals(Arrays.asList(id), + this.exchange.getResponse().getHeaders().get(headerName)); } - @Test - public void multipleSessionIds() throws Exception { - UUID sessionId = UUID.randomUUID(); - DefaultWebSession existing = createDefaultWebSession(sessionId); - this.manager.getSessionStore().storeSession(existing); - this.manager.getSessionStore().storeSession(createDefaultWebSession(UUID.randomUUID())); - this.manager.getSessionStore().storeSession(createDefaultWebSession(UUID.randomUUID())); - - this.exchange = this.exchange.mutate() - .request(this.exchange.getRequest().mutate() - .header("SESSION", sessionId.toString()) - .build()) - .build(); - - WebSession actual = this.manager.getSession(this.exchange).block(); - assertNotNull(actual); - assertEquals(existing.getId(), actual.getId()); + @Test(expected = IllegalArgumentException.class) + public void setSessionIdWhenNullIdThenIllegalArgumentException() { + String id = null; + + this.idResolver.setSessionId(this.exchange, id); } @Test - public void alternateHeaderName() throws Exception { - this.idResolver.setHeaderName("alternateHeaderName"); - - UUID sessionId = UUID.randomUUID(); - DefaultWebSession existing = createDefaultWebSession(sessionId); - this.manager.getSessionStore().storeSession(existing); - - this.exchange = this.exchange.mutate() - .request(this.exchange.getRequest().mutate() - .header("alternateHeaderName", sessionId.toString()) - .build()) - .build(); - - WebSession actual = this.manager.getSession(this.exchange).block(); - assertNotNull(actual); - assertEquals(existing.getId(), actual.getId()); + public void resolveSessionIdsWhenNoIdsThenEmpty() { + List ids = this.idResolver.resolveSessionIds(this.exchange); + + assertTrue(ids.isEmpty()); } - private DefaultWebSession createDefaultWebSession(UUID sessionId) { - return new DefaultWebSession(() -> sessionId, CLOCK, (s, session) -> Mono.empty(), s -> Mono.empty()); + @Test + public void resolveSessionIdsWhenIdThenIdFound() { + String id = "123"; + this.exchange = MockServerHttpRequest.get("/path") + .header(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME, id) + .toExchange(); + + List ids = this.idResolver.resolveSessionIds(this.exchange); + + assertEquals(Arrays.asList(id), ids); } + @Test + public void resolveSessionIdsWhenMultipleIdsThenIdsFound() { + String id1 = "123"; + String id2 = "abc"; + this.exchange = MockServerHttpRequest.get("/path") + .header(HeaderWebSessionIdResolver.DEFAULT_HEADER_NAME, id1, id2) + .toExchange(); + + List ids = this.idResolver.resolveSessionIds(this.exchange); + + assertEquals(Arrays.asList(id1, id2), ids); + } }