diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java index 7e24fafd4d3..f155ec70f1d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java @@ -135,6 +135,16 @@ public class MultiServerUserRegistry implements SimpUserRegistry, SmartApplicati return result; } + @Override + public int getUserCount() { + int userCount = 0; + for (UserRegistrySnapshot registry : this.remoteRegistries.values()) { + userCount += registry.getUserMap().size(); + } + userCount += this.localRegistry.getUserCount(); + return userCount; + } + @Override public Set findSubscriptions(SimpSubscriptionMatcher matcher) { Set result = new HashSet(); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java index a989f628fa8..ea5f9ab3422 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,21 +29,28 @@ public interface SimpUserRegistry { /** * Get the user for the given name. * @param userName the name of the user to look up - * @return the user or {@code null} if not connected + * @return the user, or {@code null} if not connected */ SimpUser getUser(String userName); /** - * Return a snapshot of all connected users. The returned set is a copy and - * will never be modified. - * @return the connected users or an empty set. + * Return a snapshot of all connected users. + *

The returned set is a copy and will not reflect further changes. + * @return the connected users, or an empty set if none */ Set getUsers(); + /** + * Return the count of all connected users. + * @return the number of connected users + * @since 4.3.5 + */ + int getUserCount(); + /** * Find subscriptions with the given matcher. * @param matcher the matcher to use - * @return a set of matching subscriptions or an empty set. + * @return a set of matching subscriptions, or an empty set if none */ Set findSubscriptions(SimpSubscriptionMatcher matcher); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistryAdapter.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistryAdapter.java index 9be049898ec..c61598ab969 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistryAdapter.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistryAdapter.java @@ -59,6 +59,11 @@ public class UserSessionRegistryAdapter implements SimpUserRegistry { throw new UnsupportedOperationException("UserSessionRegistry does not expose a listing of users"); } + @Override + public int getUserCount() { + throw new UnsupportedOperationException("UserSessionRegistry does not expose a user count"); + } + @Override public Set findSubscriptions(SimpSubscriptionMatcher matcher) { throw new UnsupportedOperationException("UserSessionRegistry does not support operations across users"); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java index 5ce5dd708a8..2e93ce20834 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java @@ -61,9 +61,10 @@ public class MultiServerUserRegistryTests { SimpUser user = Mockito.mock(SimpUser.class); Set users = Collections.singleton(user); when(this.localRegistry.getUsers()).thenReturn(users); + when(this.localRegistry.getUserCount()).thenReturn(1); when(this.localRegistry.getUser("joe")).thenReturn(user); - assertEquals(1, this.registry.getUsers().size()); + assertEquals(1, this.registry.getUserCount()); assertSame(user, this.registry.getUser("joe")); } @@ -84,7 +85,7 @@ public class MultiServerUserRegistryTests { this.registry.addRemoteRegistryDto(message, this.converter, 20000); - assertEquals(1, this.registry.getUsers().size()); + assertEquals(1, this.registry.getUserCount()); SimpUser user = this.registry.getUser("joe"); assertNotNull(user); assertTrue(user.hasSessions()); @@ -125,7 +126,7 @@ public class MultiServerUserRegistryTests { this.registry.addRemoteRegistryDto(message, this.converter, 20000); - assertEquals(3, this.registry.getUsers().size()); + assertEquals(3, this.registry.getUserCount()); Set matches = this.registry.findSubscriptions(s -> s.getDestination().equals("/match")); assertEquals(2, matches.size()); Iterator iterator = matches.iterator(); @@ -157,7 +158,7 @@ public class MultiServerUserRegistryTests { this.registry.addRemoteRegistryDto(message, this.converter, 20000); - assertEquals(1, this.registry.getUsers().size()); + assertEquals(1, this.registry.getUserCount()); SimpUser user = this.registry.getUsers().iterator().next(); assertTrue(user.hasSessions()); assertEquals(2, user.getSessions().size()); @@ -187,9 +188,9 @@ public class MultiServerUserRegistryTests { this.registry.addRemoteRegistryDto(message, this.converter, -1); - assertEquals(1, this.registry.getUsers().size()); + assertEquals(1, this.registry.getUserCount()); this.registry.purgeExpiredRegistries(); - assertEquals(0, this.registry.getUsers().size()); + assertEquals(0, this.registry.getUserCount()); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java index 4f441c9e505..790181be25b 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java @@ -126,7 +126,7 @@ public class UserRegistryMessageHandlerTests { MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(mock(SimpUserRegistry.class)); remoteRegistry.addRemoteRegistryDto(message, this.converter, 20000); - assertEquals(2, remoteRegistry.getUsers().size()); + assertEquals(2, remoteRegistry.getUserCount()); assertNotNull(remoteRegistry.getUser("joe")); assertNotNull(remoteRegistry.getUser("jane")); } @@ -142,6 +142,7 @@ public class UserRegistryMessageHandlerTests { HashSet simpUsers = new HashSet<>(Arrays.asList(simpUser1, simpUser2)); SimpUserRegistry remoteUserRegistry = mock(SimpUserRegistry.class); + when(remoteUserRegistry.getUserCount()).thenReturn(2); when(remoteUserRegistry.getUsers()).thenReturn(simpUsers); MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(remoteUserRegistry); @@ -149,7 +150,7 @@ public class UserRegistryMessageHandlerTests { this.handler.handleMessage(message); - assertEquals(2, remoteRegistry.getUsers().size()); + assertEquals(2, remoteRegistry.getUserCount()); assertNotNull(this.multiServerRegistry.getUser("joe")); assertNotNull(this.multiServerRegistry.getUser("jane")); } @@ -159,13 +160,14 @@ public class UserRegistryMessageHandlerTests { TestSimpUser simpUser = new TestSimpUser("joe"); simpUser.addSessions(new TestSimpSession("123")); + when(this.localRegistry.getUserCount()).thenReturn(1); when(this.localRegistry.getUsers()).thenReturn(Collections.singleton(simpUser)); - assertEquals(1, this.multiServerRegistry.getUsers().size()); + assertEquals(1, this.multiServerRegistry.getUserCount()); Message message = this.converter.toMessage(this.multiServerRegistry.getLocalRegistryDto(), null); this.multiServerRegistry.addRemoteRegistryDto(message, this.converter, 20000); - assertEquals(1, this.multiServerRegistry.getUsers().size()); + assertEquals(1, this.multiServerRegistry.getUserCount()); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java index 9ea63ea25fc..90a1aac47d9 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java @@ -141,6 +141,11 @@ public class DefaultSimpUserRegistry implements SimpUserRegistry, SmartApplicati return new HashSet(this.users.values()); } + @Override + public int getUserCount() { + return this.users.size(); + } + public Set findSubscriptions(SimpSubscriptionMatcher matcher) { Set result = new HashSet(); for (LocalSimpSession session : this.sessions.values()) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java index cbccc5a567a..bd43f95e13d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java @@ -57,6 +57,7 @@ public class DefaultSimpUserRegistryTests { SimpUser simpUser = registry.getUser("joe"); assertNotNull(simpUser); + assertEquals(1, registry.getUserCount()); assertEquals(1, simpUser.getSessions().size()); assertNotNull(simpUser.getSession("123")); } @@ -82,6 +83,7 @@ public class DefaultSimpUserRegistryTests { SimpUser simpUser = registry.getUser("joe"); assertNotNull(simpUser); + assertEquals(1, registry.getUserCount()); assertEquals(3, simpUser.getSessions().size()); assertNotNull(simpUser.getSession("123")); assertNotNull(simpUser.getSession("456"));