diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistry.java b/spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistry.java index 03ee40b5e30..852b745e8ed 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistry.java @@ -33,4 +33,6 @@ public interface SessionSubscriptionRegistry { Set getSessionSubscriptions(String sessionId, String destination); + Set getRegistrationsByDestination(String destination); + } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandler.java index 505bb0b7fb2..651ed4b8708 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandler.java @@ -40,7 +40,7 @@ public class SimpleBrokerWebMessageHandler extends AbstractWebMessageHandler { private final MessageChannel clientChannel; - private CachingSessionSubscriptionRegistry subscriptionRegistry= + private SessionSubscriptionRegistry subscriptionRegistry= new CachingSessionSubscriptionRegistry(new DefaultSessionSubscriptionRegistry()); @@ -56,7 +56,7 @@ public class SimpleBrokerWebMessageHandler extends AbstractWebMessageHandler { public void setSubscriptionRegistry(SessionSubscriptionRegistry subscriptionRegistry) { Assert.notNull(subscriptionRegistry, "subscriptionRegistry is required"); - this.subscriptionRegistry = new CachingSessionSubscriptionRegistry(subscriptionRegistry); + this.subscriptionRegistry = subscriptionRegistry; } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java index 66326986928..5b71138e921 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java @@ -28,8 +28,7 @@ import org.springframework.web.messaging.SessionSubscriptionRegistry; /** * A decorator for a {@link SessionSubscriptionRegistry} that intercepts subscriptions - * being added and removed, and maintains a cache that tracks registrations for a - * given destination. + * being added and removed and maintains a lookup cache of registrations by destination. * * @author Rossen Stoyanchev * @since 4.0 @@ -49,7 +48,8 @@ public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRe @Override public SessionSubscriptionRegistration getRegistration(String sessionId) { - return new CachingSessionSubscriptionRegistration(this.delegate.getRegistration(sessionId)); + SessionSubscriptionRegistration reg = this.delegate.getRegistration(sessionId); + return (reg != null) ? new CachingSessionSubscriptionRegistration(reg) : null; } @Override @@ -71,6 +71,7 @@ public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRe return this.delegate.getSessionSubscriptions(sessionId, destination); } + @Override public Set getRegistrationsByDestination(String destination) { return this.destinationCache.getRegistrations(destination); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistry.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistry.java index 53b3f38fa4e..3f5bf75a426 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistry.java @@ -16,6 +16,7 @@ package org.springframework.web.messaging.support; +import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -63,4 +64,20 @@ public class DefaultSessionSubscriptionRegistry implements SessionSubscriptionRe return (registration != null) ? registration.getSubscriptionsByDestination(destination) : null; } + /** + * The default implementation performs a lookup by destination on each registration. + * For a more efficient algorithm consider decorating an instance of this class with + * {@link CachingSessionSubscriptionRegistry}. + */ + @Override + public Set getRegistrationsByDestination(String destination) { + Set result = new HashSet(); + for (SessionSubscriptionRegistration r : this.registrations.values()) { + if (r.getSubscriptionsByDestination(destination) != null) { + result.add(r); + } + } + return result.isEmpty() ? null : result; + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/support/AbstractSessionSubscriptionRegistryTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/support/AbstractSessionSubscriptionRegistryTests.java new file mode 100644 index 00000000000..55c0274803d --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/support/AbstractSessionSubscriptionRegistryTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.support; + +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.web.messaging.SessionSubscriptionRegistration; +import org.springframework.web.messaging.SessionSubscriptionRegistry; + +import static org.junit.Assert.*; + + +/** + * A test fixture for {@link AbstractSessionSubscriptionRegistry}. + * + * @author Rossen Stoyanchev + */ +public abstract class AbstractSessionSubscriptionRegistryTests { + + protected SessionSubscriptionRegistry registry; + + + @Before + public void setup() { + this.registry = createSessionSubscriptionRegistry(); + } + + protected abstract SessionSubscriptionRegistry createSessionSubscriptionRegistry(); + + + @Test + public void getRegistration() { + String sessionId = "sess1"; + assertNull(this.registry.getRegistration(sessionId)); + + this.registry.getOrCreateRegistration(sessionId); + assertNotNull(this.registry.getRegistration(sessionId)); + assertEquals(sessionId, this.registry.getRegistration(sessionId).getSessionId()); + } + + @Test + public void getOrCreateRegistration() { + String sessionId = "sess1"; + assertNull(this.registry.getRegistration(sessionId)); + + SessionSubscriptionRegistration registration = this.registry.getOrCreateRegistration(sessionId); + assertEquals(registration, this.registry.getOrCreateRegistration(sessionId)); + } + + @Test + public void removeRegistration() { + String sessionId = "sess1"; + this.registry.getOrCreateRegistration(sessionId); + assertNotNull(this.registry.getRegistration(sessionId)); + assertEquals(sessionId, this.registry.getRegistration(sessionId).getSessionId()); + + this.registry.removeRegistration(sessionId); + assertNull(this.registry.getRegistration(sessionId)); + } + + @Test + public void getSessionSubscriptions() { + String sessionId = "sess1"; + SessionSubscriptionRegistration registration = this.registry.getOrCreateRegistration(sessionId); + registration.addSubscription("/foo", "sub1"); + registration.addSubscription("/foo", "sub2"); + + Set subscriptions = this.registry.getSessionSubscriptions(sessionId, "/foo"); + assertEquals("Wrong number of subscriptions " + subscriptions, 2, subscriptions.size()); + assertTrue(subscriptions.contains("sub1")); + assertTrue(subscriptions.contains("sub2")); + } + + @Test + public void getRegistrationsByDestination() { + + SessionSubscriptionRegistration reg1 = this.registry.getOrCreateRegistration("sess1"); + reg1.addSubscription("/foo", "sub1"); + + SessionSubscriptionRegistration reg2 = this.registry.getOrCreateRegistration("sess2"); + reg2.addSubscription("/foo", "sub1"); + + Set actual = this.registry.getRegistrationsByDestination("/foo"); + assertEquals(2, actual.size()); + assertTrue(actual.contains(reg1)); + assertTrue(actual.contains(reg2)); + + reg1.removeSubscription("sub1"); + + actual = this.registry.getRegistrationsByDestination("/foo"); + assertEquals("Invalid set of registrations " + actual, 1, actual.size()); + assertTrue(actual.contains(reg2)); + + reg2.removeSubscription("sub1"); + + actual = this.registry.getRegistrationsByDestination("/foo"); + assertNull("Unexpected registrations " + actual, actual); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java index aed8431b6af..b2ecf87c91d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java @@ -16,56 +16,20 @@ package org.springframework.web.messaging.support; -import java.util.Set; - -import org.junit.Before; -import org.junit.Test; -import org.springframework.web.messaging.SessionSubscriptionRegistration; import org.springframework.web.messaging.SessionSubscriptionRegistry; -import static org.junit.Assert.*; - /** * Test fixture for {@link CachingSessionSubscriptionRegistry}. * * @author Rossen Stoyanchev */ -public class CachingSessionSubscriptionRegistryTests { - - private CachingSessionSubscriptionRegistry registry; - - - @Before - public void setup() { - SessionSubscriptionRegistry delegate = new DefaultSessionSubscriptionRegistry(); - this.registry = new CachingSessionSubscriptionRegistry(delegate); - } - - @Test - public void getRegistrationsByDestination() { - - SessionSubscriptionRegistration reg1 = this.registry.getOrCreateRegistration("sess1"); - reg1.addSubscription("/foo", "sub1"); - - SessionSubscriptionRegistration reg2 = this.registry.getOrCreateRegistration("sess2"); - reg2.addSubscription("/foo", "sub1"); - - Set actual = this.registry.getRegistrationsByDestination("/foo"); - assertEquals(2, actual.size()); - assertTrue(actual.contains(reg1)); - assertTrue(actual.contains(reg2)); - - reg1.removeSubscription("sub1"); - - actual = this.registry.getRegistrationsByDestination("/foo"); - assertEquals("Invalid set of registrations " + actual, 1, actual.size()); - assertTrue(actual.contains(reg2)); +public class CachingSessionSubscriptionRegistryTests extends AbstractSessionSubscriptionRegistryTests { - reg2.removeSubscription("sub1"); - actual = this.registry.getRegistrationsByDestination("/foo"); - assertNull("Unexpected registrations " + actual, actual); + @Override + protected SessionSubscriptionRegistry createSessionSubscriptionRegistry() { + return new CachingSessionSubscriptionRegistry(new DefaultSessionSubscriptionRegistry()); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistryTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistryTests.java index 46128d631a6..c7021f3c35c 100644 --- a/spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistryTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistryTests.java @@ -16,13 +16,7 @@ package org.springframework.web.messaging.support; -import java.util.Set; - -import org.junit.Before; -import org.junit.Test; -import org.springframework.web.messaging.SessionSubscriptionRegistration; - -import static org.junit.Assert.*; +import org.springframework.web.messaging.SessionSubscriptionRegistry; /** @@ -30,57 +24,12 @@ import static org.junit.Assert.*; * * @author Rossen Stoyanchev */ -public class DefaultSessionSubscriptionRegistryTests { - - private DefaultSessionSubscriptionRegistry registry; - - - @Before - public void setup() { - this.registry = new DefaultSessionSubscriptionRegistry(); - } - - @Test - public void getRegistration() { - String sessionId = "sess1"; - assertNull(this.registry.getRegistration(sessionId)); - - this.registry.getOrCreateRegistration(sessionId); - assertNotNull(this.registry.getRegistration(sessionId)); - assertEquals(sessionId, this.registry.getRegistration(sessionId).getSessionId()); - } - - @Test - public void getOrCreateRegistration() { - String sessionId = "sess1"; - assertNull(this.registry.getRegistration(sessionId)); - - SessionSubscriptionRegistration registration = this.registry.getOrCreateRegistration(sessionId); - assertSame(registration, this.registry.getOrCreateRegistration(sessionId)); - } - - @Test - public void removeRegistration() { - String sessionId = "sess1"; - this.registry.getOrCreateRegistration(sessionId); - assertNotNull(this.registry.getRegistration(sessionId)); - assertEquals(sessionId, this.registry.getRegistration(sessionId).getSessionId()); - - this.registry.removeRegistration(sessionId); - assertNull(this.registry.getRegistration(sessionId)); - } +public class DefaultSessionSubscriptionRegistryTests extends AbstractSessionSubscriptionRegistryTests { - @Test - public void getSessionSubscriptions() { - String sessionId = "sess1"; - SessionSubscriptionRegistration registration = this.registry.getOrCreateRegistration(sessionId); - registration.addSubscription("/foo", "sub1"); - registration.addSubscription("/foo", "sub2"); - Set subscriptions = this.registry.getSessionSubscriptions(sessionId, "/foo"); - assertEquals("Wrong number of subscriptions " + subscriptions, 2, subscriptions.size()); - assertTrue(subscriptions.contains("sub1")); - assertTrue(subscriptions.contains("sub2")); + @Override + protected SessionSubscriptionRegistry createSessionSubscriptionRegistry() { + return new DefaultSessionSubscriptionRegistry(); } }