diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 1a8d2e3a73..1939a4f157 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -2143,19 +2143,23 @@ public class ServerHttpSecurity { @Override public Mono changeSessionId() { String currentId = this.session.getId(); - return SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId) - .flatMap((information) -> this.session.changeSessionId().thenReturn(information)) - .flatMap((information) -> { - information = information.withSessionId(this.session.getId()); - return SessionRegistryWebFilter.this.sessionRegistry.saveSessionInformation(information); - }); + return this.session.changeSessionId() + .then(Mono.defer( + () -> SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId) + .flatMap((information) -> { + information = information.withSessionId(this.session.getId()); + return SessionRegistryWebFilter.this.sessionRegistry + .saveSessionInformation(information); + }))); } @Override public Mono invalidate() { String currentId = this.session.getId(); - return SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId) - .flatMap((information) -> this.session.invalidate()); + return this.session.invalidate() + .then(Mono.defer(() -> SessionRegistryWebFilter.this.sessionRegistry + .removeSessionInformation(currentId))) + .then(); } @Override diff --git a/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java index 089a3916d4..22a5eee49c 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java @@ -67,6 +67,7 @@ import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; import org.springframework.web.server.session.DefaultWebSessionManager; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -95,14 +96,19 @@ public class SessionManagementSpecTests { ResponseCookie firstLoginSessionCookie = loginReturningCookie(data); // second login should fail - this.client.mutateWith(csrf()) + ResponseCookie secondLoginSessionCookie = this.client.mutateWith(csrf()) .post() .uri("/login") .contentType(MediaType.MULTIPART_FORM_DATA) .body(BodyInserters.fromFormData(data)) .exchange() .expectHeader() - .location("/login?error"); + .location("/login?error") + .returnResult(Void.class) + .getResponseCookies() + .getFirst("SESSION"); + + assertThat(secondLoginSessionCookie).isNull(); // first login should still be valid this.client.mutateWith(csrf()) diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java b/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java index ca777da888..556bf042df 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java @@ -81,8 +81,8 @@ public final class ConcurrentSessionControlServerAuthenticationSuccessHandler } } } - return this.maximumSessionsExceededHandler - .handle(new MaximumSessionsContext(authentication, registeredSessions, maximumSessions)); + return this.maximumSessionsExceededHandler.handle(new MaximumSessionsContext(authentication, + registeredSessions, maximumSessions, currentSession)); }); } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/MaximumSessionsContext.java b/web/src/main/java/org/springframework/security/web/server/authentication/MaximumSessionsContext.java index 0875051b78..9ba11bc17d 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/MaximumSessionsContext.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/MaximumSessionsContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import java.util.List; import org.springframework.security.core.Authentication; import org.springframework.security.core.session.ReactiveSessionInformation; +import org.springframework.web.server.WebSession; public final class MaximumSessionsContext { @@ -29,11 +30,14 @@ public final class MaximumSessionsContext { private final int maximumSessionsAllowed; + private final WebSession currentSession; + public MaximumSessionsContext(Authentication authentication, List sessions, - int maximumSessionsAllowed) { + int maximumSessionsAllowed, WebSession currentSession) { this.authentication = authentication; this.sessions = sessions; this.maximumSessionsAllowed = maximumSessionsAllowed; + this.currentSession = currentSession; } public Authentication getAuthentication() { @@ -48,4 +52,8 @@ public final class MaximumSessionsContext { return this.maximumSessionsAllowed; } + public WebSession getCurrentSession() { + return this.currentSession; + } + } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/PreventLoginServerMaximumSessionsExceededHandler.java b/web/src/main/java/org/springframework/security/web/server/authentication/PreventLoginServerMaximumSessionsExceededHandler.java index a98f8795e6..1afb5771e3 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/PreventLoginServerMaximumSessionsExceededHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/PreventLoginServerMaximumSessionsExceededHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,9 +31,9 @@ public final class PreventLoginServerMaximumSessionsExceededHandler implements S @Override public Mono handle(MaximumSessionsContext context) { - return Mono - .error(new SessionAuthenticationException("Maximum sessions of " + context.getMaximumSessionsAllowed() - + " for authentication '" + context.getAuthentication().getName() + "' exceeded")); + return context.getCurrentSession() + .invalidate() + .then(Mono.defer(() -> Mono.error(new SessionAuthenticationException("Maximum sessions exceeded")))); } } diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java index 60b6107418..3c16e6fdd9 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests { given(session2.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760000L)); given(session2.invalidate()).willReturn(Mono.empty()); MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class), - List.of(session1, session2), 2); + List.of(session1, session2), 2, null); this.handler.handle(context).block(); @@ -72,7 +72,7 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests { given(session1.invalidate()).willReturn(Mono.empty()); given(session2.invalidate()).willReturn(Mono.empty()); MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class), - List.of(session1, session2, session3), 2); + List.of(session1, session2, session3), 2, null); this.handler.handle(context).block(); diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/session/PreventLoginServerMaximumSessionsExceededHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/session/PreventLoginServerMaximumSessionsExceededHandlerTests.java index 819489ee43..68f1f09650 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/session/PreventLoginServerMaximumSessionsExceededHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/session/PreventLoginServerMaximumSessionsExceededHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,13 +19,19 @@ package org.springframework.security.web.server.authentication.session; import java.util.Collections; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; import org.springframework.security.authentication.TestAuthentication; import org.springframework.security.web.authentication.session.SessionAuthenticationException; import org.springframework.security.web.server.authentication.MaximumSessionsContext; import org.springframework.security.web.server.authentication.PreventLoginServerMaximumSessionsExceededHandler; +import org.springframework.web.server.WebSession; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link PreventLoginServerMaximumSessionsExceededHandler}. @@ -35,13 +41,17 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; class PreventLoginServerMaximumSessionsExceededHandlerTests { @Test - void handleWhenInvokedThenThrowsSessionAuthenticationException() { + void handleWhenInvokedThenInvalidateWebSessionAndThrowsSessionAuthenticationException() { PreventLoginServerMaximumSessionsExceededHandler handler = new PreventLoginServerMaximumSessionsExceededHandler(); + WebSession webSession = mock(); + given(webSession.invalidate()).willReturn(Mono.empty()); MaximumSessionsContext context = new MaximumSessionsContext(TestAuthentication.authenticatedUser(), - Collections.emptyList(), 1); - assertThatExceptionOfType(SessionAuthenticationException.class) - .isThrownBy(() -> handler.handle(context).block()) - .withMessage("Maximum sessions of 1 for authentication 'user' exceeded"); + Collections.emptyList(), 1, webSession); + StepVerifier.create(handler.handle(context)).expectErrorSatisfies((ex) -> { + assertThat(ex).isInstanceOf(SessionAuthenticationException.class); + assertThat(ex.getMessage()).isEqualTo("Maximum sessions exceeded"); + }).verify(); + verify(webSession).invalidate(); } }