Browse Source

Polish Max Sessions on WebFlux

This commit changes the PreventLoginServerMaximumSessionsExceededHandler to invalidate the WebSession in addition to throwing the error, this is needed otherwise the session would still be saved with the security context. It also changes the SessionRegistryWebSession to first perform the operation on the delegate and then invoke the needed method on the ReactiveSessionRegistry

Issue gh-6192
pull/14659/head
Marcus Hert Da Coregio 2 years ago
parent
commit
a5ce8ae87f
  1. 20
      config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java
  2. 10
      config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java
  3. 4
      web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java
  4. 12
      web/src/main/java/org/springframework/security/web/server/authentication/MaximumSessionsContext.java
  5. 8
      web/src/main/java/org/springframework/security/web/server/authentication/PreventLoginServerMaximumSessionsExceededHandler.java
  6. 6
      web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java
  7. 24
      web/src/test/java/org/springframework/security/web/server/authentication/session/PreventLoginServerMaximumSessionsExceededHandlerTests.java

20
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@ -2143,19 +2143,23 @@ public class ServerHttpSecurity { @@ -2143,19 +2143,23 @@ public class ServerHttpSecurity {
@Override
public Mono<Void> changeSessionId() {
String currentId = this.session.getId();
return SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId)
.flatMap((information) -> this.session.changeSessionId().thenReturn(information))
.flatMap((information) -> {
information = information.withSessionId(this.session.getId());
return SessionRegistryWebFilter.this.sessionRegistry.saveSessionInformation(information);
});
return this.session.changeSessionId()
.then(Mono.defer(
() -> SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId)
.flatMap((information) -> {
information = information.withSessionId(this.session.getId());
return SessionRegistryWebFilter.this.sessionRegistry
.saveSessionInformation(information);
})));
}
@Override
public Mono<Void> invalidate() {
String currentId = this.session.getId();
return SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId)
.flatMap((information) -> this.session.invalidate());
return this.session.invalidate()
.then(Mono.defer(() -> SessionRegistryWebFilter.this.sessionRegistry
.removeSessionInformation(currentId)))
.then();
}
@Override

10
config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java

@ -67,6 +67,7 @@ import org.springframework.web.reactive.function.BodyInserters; @@ -67,6 +67,7 @@ import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import org.springframework.web.server.session.DefaultWebSessionManager;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
@ -95,14 +96,19 @@ public class SessionManagementSpecTests { @@ -95,14 +96,19 @@ public class SessionManagementSpecTests {
ResponseCookie firstLoginSessionCookie = loginReturningCookie(data);
// second login should fail
this.client.mutateWith(csrf())
ResponseCookie secondLoginSessionCookie = this.client.mutateWith(csrf())
.post()
.uri("/login")
.contentType(MediaType.MULTIPART_FORM_DATA)
.body(BodyInserters.fromFormData(data))
.exchange()
.expectHeader()
.location("/login?error");
.location("/login?error")
.returnResult(Void.class)
.getResponseCookies()
.getFirst("SESSION");
assertThat(secondLoginSessionCookie).isNull();
// first login should still be valid
this.client.mutateWith(csrf())

4
web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java

@ -81,8 +81,8 @@ public final class ConcurrentSessionControlServerAuthenticationSuccessHandler @@ -81,8 +81,8 @@ public final class ConcurrentSessionControlServerAuthenticationSuccessHandler
}
}
}
return this.maximumSessionsExceededHandler
.handle(new MaximumSessionsContext(authentication, registeredSessions, maximumSessions));
return this.maximumSessionsExceededHandler.handle(new MaximumSessionsContext(authentication,
registeredSessions, maximumSessions, currentSession));
});
}

12
web/src/main/java/org/springframework/security/web/server/authentication/MaximumSessionsContext.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import java.util.List; @@ -20,6 +20,7 @@ import java.util.List;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.session.ReactiveSessionInformation;
import org.springframework.web.server.WebSession;
public final class MaximumSessionsContext {
@ -29,11 +30,14 @@ public final class MaximumSessionsContext { @@ -29,11 +30,14 @@ public final class MaximumSessionsContext {
private final int maximumSessionsAllowed;
private final WebSession currentSession;
public MaximumSessionsContext(Authentication authentication, List<ReactiveSessionInformation> sessions,
int maximumSessionsAllowed) {
int maximumSessionsAllowed, WebSession currentSession) {
this.authentication = authentication;
this.sessions = sessions;
this.maximumSessionsAllowed = maximumSessionsAllowed;
this.currentSession = currentSession;
}
public Authentication getAuthentication() {
@ -48,4 +52,8 @@ public final class MaximumSessionsContext { @@ -48,4 +52,8 @@ public final class MaximumSessionsContext {
return this.maximumSessionsAllowed;
}
public WebSession getCurrentSession() {
return this.currentSession;
}
}

8
web/src/main/java/org/springframework/security/web/server/authentication/PreventLoginServerMaximumSessionsExceededHandler.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,9 +31,9 @@ public final class PreventLoginServerMaximumSessionsExceededHandler implements S @@ -31,9 +31,9 @@ public final class PreventLoginServerMaximumSessionsExceededHandler implements S
@Override
public Mono<Void> handle(MaximumSessionsContext context) {
return Mono
.error(new SessionAuthenticationException("Maximum sessions of " + context.getMaximumSessionsAllowed()
+ " for authentication '" + context.getAuthentication().getName() + "' exceeded"));
return context.getCurrentSession()
.invalidate()
.then(Mono.defer(() -> Mono.error(new SessionAuthenticationException("Maximum sessions exceeded"))));
}
}

6
web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -50,7 +50,7 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests { @@ -50,7 +50,7 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests {
given(session2.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760000L));
given(session2.invalidate()).willReturn(Mono.empty());
MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class),
List.of(session1, session2), 2);
List.of(session1, session2), 2, null);
this.handler.handle(context).block();
@ -72,7 +72,7 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests { @@ -72,7 +72,7 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests {
given(session1.invalidate()).willReturn(Mono.empty());
given(session2.invalidate()).willReturn(Mono.empty());
MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class),
List.of(session1, session2, session3), 2);
List.of(session1, session2, session3), 2, null);
this.handler.handle(context).block();

24
web/src/test/java/org/springframework/security/web/server/authentication/session/PreventLoginServerMaximumSessionsExceededHandlerTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,13 +19,19 @@ package org.springframework.security.web.server.authentication.session; @@ -19,13 +19,19 @@ package org.springframework.security.web.server.authentication.session;
import java.util.Collections;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.security.authentication.TestAuthentication;
import org.springframework.security.web.authentication.session.SessionAuthenticationException;
import org.springframework.security.web.server.authentication.MaximumSessionsContext;
import org.springframework.security.web.server.authentication.PreventLoginServerMaximumSessionsExceededHandler;
import org.springframework.web.server.WebSession;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link PreventLoginServerMaximumSessionsExceededHandler}.
@ -35,13 +41,17 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -35,13 +41,17 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
class PreventLoginServerMaximumSessionsExceededHandlerTests {
@Test
void handleWhenInvokedThenThrowsSessionAuthenticationException() {
void handleWhenInvokedThenInvalidateWebSessionAndThrowsSessionAuthenticationException() {
PreventLoginServerMaximumSessionsExceededHandler handler = new PreventLoginServerMaximumSessionsExceededHandler();
WebSession webSession = mock();
given(webSession.invalidate()).willReturn(Mono.empty());
MaximumSessionsContext context = new MaximumSessionsContext(TestAuthentication.authenticatedUser(),
Collections.emptyList(), 1);
assertThatExceptionOfType(SessionAuthenticationException.class)
.isThrownBy(() -> handler.handle(context).block())
.withMessage("Maximum sessions of 1 for authentication 'user' exceeded");
Collections.emptyList(), 1, webSession);
StepVerifier.create(handler.handle(context)).expectErrorSatisfies((ex) -> {
assertThat(ex).isInstanceOf(SessionAuthenticationException.class);
assertThat(ex.getMessage()).isEqualTo("Maximum sessions exceeded");
}).verify();
verify(webSession).invalidate();
}
}

Loading…
Cancel
Save