From ca9beeac9d517a7ee0356f6c449a2513abe38497 Mon Sep 17 00:00:00 2001 From: Juergen Hoeller Date: Tue, 7 Jul 2015 02:08:39 +0200 Subject: [PATCH] DefaultSubscriptionRegistry uses deep LinkedMultiValueMap copies between accessCache and updateCache Also backported CopyOnWriteArraySet use from 4.2, for defensive iteration over registered subscriptions. Issue: SPR-13185 --- .../broker/DefaultSubscriptionRegistry.java | 26 +++-- .../DefaultSubscriptionRegistryTests.java | 108 +++++++++--------- 2 files changed, 73 insertions(+), 61 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java index 613b2c64564..ebc1b783738 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java @@ -19,11 +19,13 @@ package org.springframework.messaging.simp.broker; import java.util.Collection; import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CopyOnWriteArraySet; import org.springframework.messaging.Message; import org.springframework.util.AntPathMatcher; @@ -32,7 +34,6 @@ import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.PathMatcher; - /** * A default, simple in-memory implementation of {@link SubscriptionRegistry}. * @@ -166,7 +167,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { public void addSubscriptions(String destination, MultiValueMap subscriptions) { synchronized (this.updateCache) { - this.updateCache.put(destination, new LinkedMultiValueMap(subscriptions)); + this.updateCache.put(destination, deepCopy(subscriptions)); this.accessCache.put(destination, subscriptions); } } @@ -178,7 +179,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { if (getPathMatcher().match(destination, cachedDestination)) { MultiValueMap subs = entry.getValue(); subs.add(sessionId, subsId); - this.accessCache.put(cachedDestination, new LinkedMultiValueMap(subs)); + this.accessCache.put(cachedDestination, deepCopy(subs)); } } } @@ -200,7 +201,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { destinationsToRemove.add(destination); } else { - this.accessCache.put(destination, new LinkedMultiValueMap(sessionMap)); + this.accessCache.put(destination, deepCopy(sessionMap)); } } } @@ -222,7 +223,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { destinationsToRemove.add(destination); } else { - this.accessCache.put(destination, new LinkedMultiValueMap(sessionMap)); + this.accessCache.put(destination, deepCopy(sessionMap)); } } } @@ -233,12 +234,21 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { } } + private LinkedMultiValueMap deepCopy(Map> map) { + LinkedMultiValueMap copy = new LinkedMultiValueMap(); + for (Map.Entry> entry : map.entrySet()) { + copy.put(entry.getKey(), new LinkedList(entry.getValue())); + } + return copy; + } + @Override public String toString() { return "cache[" + this.accessCache.size() + " destination(s)]"; } } + /** * Provide access to session subscriptions by sessionId. */ @@ -276,10 +286,11 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @Override public String toString() { - return "registry[" + sessions.size() + " sessions]"; + return "registry[" + this.sessions.size() + " sessions]"; } } + /** * Hold subscriptions for a session. */ @@ -292,7 +303,6 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { private final Object monitor = new Object(); - public SessionSubscriptionInfo(String sessionId) { Assert.notNull(sessionId, "sessionId must not be null"); this.sessionId = sessionId; @@ -316,7 +326,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { synchronized (this.monitor) { subs = this.subscriptions.get(destination); if (subs == null) { - subs = new HashSet(4); + subs = new CopyOnWriteArraySet(); this.subscriptions.put(destination, subs); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java index d14a1588460..8cc1067aa36 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java @@ -34,25 +34,19 @@ import static org.junit.Assert.*; /** - * Test fixture for {@link org.springframework.messaging.simp.broker.DefaultSubscriptionRegistry}. + * Test fixture for + * {@link org.springframework.messaging.simp.broker.DefaultSubscriptionRegistry}. * * @author Rossen Stoyanchev * @author Sebastien Deleuze */ public class DefaultSubscriptionRegistryTests { - private DefaultSubscriptionRegistry registry; - - - @Before - public void setup() { - this.registry = new DefaultSubscriptionRegistry(); - } + private final DefaultSubscriptionRegistry registry = new DefaultSubscriptionRegistry(); @Test public void registerSubscriptionInvalidInput() { - String sessId = "sess01"; String subsId = "subs01"; String dest = "/foo"; @@ -69,7 +63,6 @@ public class DefaultSubscriptionRegistryTests { @Test public void registerSubscription() { - String sessId = "sess01"; String subsId = "subs01"; String dest = "/foo"; @@ -83,7 +76,6 @@ public class DefaultSubscriptionRegistryTests { @Test public void registerSubscriptionOneSession() { - String sessId = "sess01"; List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); String dest = "/foo"; @@ -93,14 +85,12 @@ public class DefaultSubscriptionRegistryTests { } MultiValueMap actual = this.registry.findSubscriptions(message(dest)); - assertEquals("Expected one element " + actual, 1, actual.size()); assertEquals(subscriptionIds, sort(actual.get(sessId))); } @Test public void registerSubscriptionMultipleSessions() { - List sessIds = Arrays.asList("sess01", "sess02", "sess03"); List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); String dest = "/foo"; @@ -112,7 +102,6 @@ public class DefaultSubscriptionRegistryTests { } MultiValueMap actual = this.registry.findSubscriptions(message(dest)); - assertEquals("Expected three elements " + actual, 3, actual.size()); assertEquals(subscriptionIds, sort(actual.get(sessIds.get(0)))); assertEquals(subscriptionIds, sort(actual.get(sessIds.get(1)))); @@ -121,22 +110,18 @@ public class DefaultSubscriptionRegistryTests { @Test public void registerSubscriptionWithDestinationPattern() { - String sessId = "sess01"; String subsId = "subs01"; String destPattern = "/topic/PRICE.STOCK.*.IBM"; String dest = "/topic/PRICE.STOCK.NASDAQ.IBM"; - this.registry.registerSubscription(subscribeMessage(sessId, subsId, destPattern)); - MultiValueMap actual = this.registry.findSubscriptions(message(dest)); + MultiValueMap actual = this.registry.findSubscriptions(message(dest)); assertEquals("Expected one element " + actual, 1, actual.size()); assertEquals(Arrays.asList(subsId), actual.get(sessId)); } - // SPR-11657 - - @Test + @Test // SPR-11657 public void registerSubscriptionsWithSimpleAndPatternDestinations() { String sess1 = "sess01"; @@ -148,6 +133,7 @@ public class DefaultSubscriptionRegistryTests { this.registry.registerSubscription(subscribeMessage(sess1, subs2, "/topic/PRICE.STOCK.NASDAQ.IBM")); this.registry.registerSubscription(subscribeMessage(sess1, subs1, "/topic/PRICE.STOCK.*.IBM")); + MultiValueMap actual = this.registry.findSubscriptions(message("/topic/PRICE.STOCK.NASDAQ.IBM")); assertEquals(1, actual.size()); assertEquals(Arrays.asList(subs2, subs1), actual.get(sess1)); @@ -155,44 +141,47 @@ public class DefaultSubscriptionRegistryTests { this.registry.registerSubscription(subscribeMessage(sess2, subs1, "/topic/PRICE.STOCK.NASDAQ.IBM")); this.registry.registerSubscription(subscribeMessage(sess2, subs2, "/topic/PRICE.STOCK.NYSE.IBM")); this.registry.registerSubscription(subscribeMessage(sess2, subs3, "/topic/PRICE.STOCK.NASDAQ.GOOG")); + actual = this.registry.findSubscriptions(message("/topic/PRICE.STOCK.NASDAQ.IBM")); assertEquals(2, actual.size()); assertEquals(Arrays.asList(subs2, subs1), actual.get(sess1)); assertEquals(Arrays.asList(subs1), actual.get(sess2)); this.registry.unregisterAllSubscriptions(sess1); + actual = this.registry.findSubscriptions(message("/topic/PRICE.STOCK.NASDAQ.IBM")); assertEquals(1, actual.size()); assertEquals(Arrays.asList(subs1), actual.get(sess2)); this.registry.registerSubscription(subscribeMessage(sess1, subs1, "/topic/PRICE.STOCK.*.IBM")); this.registry.registerSubscription(subscribeMessage(sess1, subs2, "/topic/PRICE.STOCK.NASDAQ.IBM")); + actual = this.registry.findSubscriptions(message("/topic/PRICE.STOCK.NASDAQ.IBM")); assertEquals(2, actual.size()); assertEquals(Arrays.asList(subs1, subs2), actual.get(sess1)); assertEquals(Arrays.asList(subs1), actual.get(sess2)); this.registry.unregisterSubscription(unsubscribeMessage(sess1, subs2)); + actual = this.registry.findSubscriptions(message("/topic/PRICE.STOCK.NASDAQ.IBM")); assertEquals(2, actual.size()); assertEquals(Arrays.asList(subs1), actual.get(sess1)); assertEquals(Arrays.asList(subs1), actual.get(sess2)); this.registry.unregisterSubscription(unsubscribeMessage(sess1, subs1)); + actual = this.registry.findSubscriptions(message("/topic/PRICE.STOCK.NASDAQ.IBM")); assertEquals(1, actual.size()); assertEquals(Arrays.asList(subs1), actual.get(sess2)); this.registry.unregisterSubscription(unsubscribeMessage(sess2, subs1)); + actual = this.registry.findSubscriptions(message("/topic/PRICE.STOCK.NASDAQ.IBM")); assertEquals(0, actual.size()); } - // SPR-11755 - - @Test + @Test // SPR-11755 public void registerAndUnregisterMultipleDestinations() { - String sess1 = "sess01"; String sess2 = "sess02"; @@ -219,13 +208,13 @@ public class DefaultSubscriptionRegistryTests { this.registry.registerSubscription(subscribeMessage(sess1, subs3, "/topic/PRICE.STOCK.NASDAQ.GOOG")); this.registry.registerSubscription(subscribeMessage(sess1, subs4, "/topic/PRICE.STOCK.NYSE.IBM")); this.registry.registerSubscription(subscribeMessage(sess2, subs5, "/topic/PRICE.STOCK.NASDAQ.GOOG")); + this.registry.unregisterAllSubscriptions(sess1); this.registry.unregisterAllSubscriptions(sess2); } @Test public void registerSubscriptionWithDestinationPatternRegex() { - String sessId = "sess01"; String subsId = "subs01"; String destPattern = "/topic/PRICE.STOCK.*.{ticker:(IBM|MSFT)}"; @@ -249,9 +238,29 @@ public class DefaultSubscriptionRegistryTests { assertEquals("Expected no elements " + actual, 0, actual.size()); } + @Test // SPR-11931 + public void registerTwiceAndUnregisterSubscriptions() { + this.registry.registerSubscription(subscribeMessage("sess01", "subs01", "/foo")); + this.registry.registerSubscription(subscribeMessage("sess01", "subs02", "/foo")); + + MultiValueMap actual = this.registry.findSubscriptions(message("/foo")); + assertEquals("Expected 1 element", 1, actual.size()); + assertEquals(Arrays.asList("subs01", "subs02"), actual.get("sess01")); + + this.registry.unregisterSubscription(unsubscribeMessage("sess01", "subs01")); + + actual = this.registry.findSubscriptions(message("/foo")); + assertEquals("Expected 1 element", 1, actual.size()); + assertEquals(Arrays.asList("subs02"), actual.get("sess01")); + + this.registry.unregisterSubscription(unsubscribeMessage("sess01", "subs02")); + + actual = this.registry.findSubscriptions(message("/foo")); + assertEquals("Expected no element", 0, actual.size()); + } + @Test public void unregisterSubscription() { - List sessIds = Arrays.asList("sess01", "sess02", "sess03"); List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); String dest = "/foo"; @@ -267,36 +276,13 @@ public class DefaultSubscriptionRegistryTests { this.registry.unregisterSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(2))); MultiValueMap actual = this.registry.findSubscriptions(message(dest)); - assertEquals("Expected two elements: " + actual, 2, actual.size()); assertEquals(subscriptionIds, sort(actual.get(sessIds.get(1)))); assertEquals(subscriptionIds, sort(actual.get(sessIds.get(2)))); } - // SPR-11931 - - @Test - public void registerTwiceAndUnregisterSubscriptions() { - - this.registry.registerSubscription(subscribeMessage("sess01", "subs01", "/foo")); - this.registry.registerSubscription(subscribeMessage("sess01", "subs02", "/foo")); - MultiValueMap actual = this.registry.findSubscriptions(message("/foo")); - assertEquals("Expected 1 element", 1, actual.size()); - assertEquals(Arrays.asList("subs01", "subs02"), actual.get("sess01")); - - this.registry.unregisterSubscription(unsubscribeMessage("sess01", "subs01")); - actual = this.registry.findSubscriptions(message("/foo")); - assertEquals("Expected 1 element", 1, actual.size()); - assertEquals(Arrays.asList("subs02"), actual.get("sess01")); - - this.registry.unregisterSubscription(unsubscribeMessage("sess01", "subs02")); - actual = this.registry.findSubscriptions(message("/foo")); - assertEquals("Expected no element", 0, actual.size()); - } - @Test public void unregisterAllSubscriptions() { - List sessIds = Arrays.asList("sess01", "sess02", "sess03"); List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); String dest = "/foo"; @@ -311,7 +297,6 @@ public class DefaultSubscriptionRegistryTests { this.registry.unregisterAllSubscriptions(sessIds.get(1)); MultiValueMap actual = this.registry.findSubscriptions(message(dest)); - assertEquals("Expected one element: " + actual, 1, actual.size()); assertEquals(subscriptionIds, sort(actual.get(sessIds.get(2)))); } @@ -328,9 +313,7 @@ public class DefaultSubscriptionRegistryTests { assertEquals("Expected no elements " + actual, 0, actual.size()); } - // SPR-12665 - - @Test + @Test // SPR-12665 public void findSubscriptionsReturnsMapSafeToIterate() throws Exception { this.registry.registerSubscription(subscribeMessage("sess1", "1", "/foo")); this.registry.registerSubscription(subscribeMessage("sess2", "1", "/foo")); @@ -346,6 +329,25 @@ public class DefaultSubscriptionRegistryTests { // no ConcurrentModificationException } + @Test // SPR-13185 + public void findSubscriptionsReturnsMapSafeToIterateIncludingValues() throws Exception { + this.registry.registerSubscription(subscribeMessage("sess1", "1", "/foo")); + this.registry.registerSubscription(subscribeMessage("sess1", "2", "/foo")); + + MultiValueMap allSubscriptions = this.registry.findSubscriptions(message("/foo")); + assertNotNull(allSubscriptions); + assertEquals(1, allSubscriptions.size()); + + Iterator iteratorValues = allSubscriptions.get("sess1").iterator(); + iteratorValues.next(); + + this.registry.unregisterSubscription(unsubscribeMessage("sess1", "2")); + + iteratorValues.next(); + // no ConcurrentModificationException + } + + private Message subscribeMessage(String sessionId, String subscriptionId, String destination) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE); headers.setSessionId(sessionId);