Browse Source

DefaultSubscriptionRegistry: Reduced thread contention

* DestinationCache is now synchronized on multiple 'destination' locks
 (previously a single shared lock)
* DestinationCache keeps destinations without any subscriptions
 (previously such destinations were recomputed over and over)
* SessionSubscriptionRegistry is now a
 'sessionId -> subscriptionId -> (destination,selector)' map
 for faster lookups
 (previously 'sessionId -> destination -> set of (subscriptionId,selector)')

closes gh-24395
pull/25401/head
Tomas Drabek 6 years ago committed by Rossen Stoyanchev
parent
commit
524ca1a676
  1. 1
      spring-messaging/spring-messaging.gradle
  2. 192
      spring-messaging/src/jmh/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryBenchmark.java
  3. 415
      spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java
  4. 21
      spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java

1
spring-messaging/spring-messaging.gradle

@ -34,4 +34,5 @@ dependencies { @@ -34,4 +34,5 @@ dependencies {
testRuntime("com.sun.xml.bind:jaxb-core")
testRuntime("com.sun.xml.bind:jaxb-impl")
testRuntime("com.sun.activation:javax.activation")
testRuntime(project(":spring-context"))
}

192
spring-messaging/src/jmh/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryBenchmark.java

@ -0,0 +1,192 @@ @@ -0,0 +1,192 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.broker;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.infra.Blackhole;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.MultiValueMap;
@BenchmarkMode(Mode.Throughput)
public class DefaultSubscriptionRegistryBenchmark {
@State(Scope.Benchmark)
public static class ServerState {
@Param("1000")
public int sessions;
@Param("10")
public int destinations;
@Param({"0", "1024"})
int cacheSizeLimit;
@Param({"none", "patternSubscriptions", "selectorHeaders"})
String specialization;
public DefaultSubscriptionRegistry registry;
public String[] destinationIds;
public String[] sessionIds;
public AtomicInteger uniqueIdGenerator;
public Message<?> findMessage;
@Setup(Level.Trial)
public void doSetup() {
this.findMessage = MessageBuilder.createMessage("", SimpMessageHeaderAccessor.create().getMessageHeaders());
this.uniqueIdGenerator = new AtomicInteger();
this.registry = new DefaultSubscriptionRegistry();
this.registry.setCacheLimit(this.cacheSizeLimit);
this.registry.setSelectorHeaderName("selectorHeaders".equals(this.specialization) ? "someSelector" : null);
this.destinationIds = IntStream.range(0, this.destinations)
.mapToObj(i -> "/some/destination/" + i)
.toArray(String[]::new);
this.sessionIds = IntStream.range(0, this.sessions)
.mapToObj(i -> "sessionId_" + i)
.toArray(String[]::new);
for (String sessionId : this.sessionIds) {
for (String destinationId : this.destinationIds) {
registerSubscriptions(sessionId, destinationId);
}
}
}
public void registerSubscriptions(String sessionId, String destination) {
if ("patternSubscriptions".equals(this.specialization)) {
destination = "/**/" + destination;
}
String subscriptionId = "subscription_" + this.uniqueIdGenerator.incrementAndGet();
this.registry.registerSubscription(subscribeMessage(sessionId, subscriptionId, destination));
}
}
@State(Scope.Thread)
public static class Requests {
@Param({"none", "sameDestination", "sameSession"})
String contention;
public String session;
public Message<?> subscribe;
public String findDestination;
public Message<?> unsubscribe;
@Setup(Level.Trial)
public void doSetup(ServerState serverState) {
int uniqueNumber = serverState.uniqueIdGenerator.incrementAndGet();
if ("sameDestination".equals(this.contention)) {
this.findDestination = serverState.destinationIds[0];
}
else {
this.findDestination = serverState.destinationIds[uniqueNumber % serverState.destinationIds.length];
}
if ("sameSession".equals(this.contention)) {
this.session = serverState.sessionIds[0];
}
else {
this.session = serverState.sessionIds[uniqueNumber % serverState.sessionIds.length];
}
String subscription = String.valueOf(uniqueNumber);
String subscribeDestination = "patternSubscriptions".equals(serverState.specialization) ?
"/**/" + this.findDestination : this.findDestination;
this.subscribe = subscribeMessage(this.session, subscription, subscribeDestination);
this.unsubscribe = unsubscribeMessage(this.session, subscription);
}
}
@State(Scope.Thread)
public static class FindRequest {
@Param({"none", "noSubscribers", "sameDestination"})
String contention;
public String destination;
@Setup(Level.Trial)
public void doSetup(ServerState serverState) {
switch (this.contention) {
case "noSubscribers":
this.destination = "someDestination_withNoSubscribers_" + serverState.uniqueIdGenerator.incrementAndGet();
break;
case "sameDestination":
this.destination = serverState.destinationIds[0];
break;
case "none":
int uniqueNumber = serverState.uniqueIdGenerator.getAndIncrement();
this.destination = serverState.destinationIds[uniqueNumber % serverState.destinationIds.length];
break;
default:
throw new IllegalStateException();
}
}
}
@Benchmark
public void registerUnregister(ServerState serverState, Requests request, Blackhole blackhole) {
serverState.registry.registerSubscription(request.subscribe);
blackhole.consume(serverState.registry.findSubscriptionsInternal(request.findDestination, serverState.findMessage));
serverState.registry.unregisterSubscription(request.unsubscribe);
blackhole.consume(serverState.registry.findSubscriptionsInternal(request.findDestination, serverState.findMessage));
}
@Benchmark
public MultiValueMap<String, String> find(ServerState serverState, FindRequest request) {
return serverState.registry.findSubscriptionsInternal(request.destination, serverState.findMessage);
}
public static Message<?> subscribeMessage(String sessionId, String subscriptionId, String dest) {
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE);
accessor.setSessionId(sessionId);
accessor.setSubscriptionId(subscriptionId);
accessor.setDestination(dest);
accessor.setNativeHeader("someSelector", "true");
return MessageBuilder.createMessage("", accessor.getMessageHeaders());
}
public static Message<?> unsubscribeMessage(String sessionId, String subscriptionId) {
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE);
accessor.setSessionId(sessionId);
accessor.setSubscriptionId(subscriptionId);
return MessageBuilder.createMessage("", accessor.getMessageHeaders());
}
}

415
spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java

@ -16,15 +16,17 @@ @@ -16,15 +16,17 @@
package org.springframework.messaging.simp.broker;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
@ -34,6 +36,7 @@ import org.springframework.expression.TypedValue; @@ -34,6 +36,7 @@ import org.springframework.expression.TypedValue;
import org.springframework.expression.spel.SpelEvaluationException;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.SimpleEvaluationContext;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
@ -72,7 +75,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -72,7 +75,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
private PathMatcher pathMatcher = new AntPathMatcher();
private volatile int cacheLimit = DEFAULT_CACHE_LIMIT;
private int cacheLimit = DEFAULT_CACHE_LIMIT;
@Nullable
private String selectorHeaderName = "selector";
@ -106,6 +109,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -106,6 +109,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
*/
public void setCacheLimit(int cacheLimit) {
this.cacheLimit = cacheLimit;
this.destinationCache.ensureCacheLimit();
}
/**
@ -142,14 +146,17 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -142,14 +146,17 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
return this.selectorHeaderName;
}
@Override
protected void addSubscriptionInternal(
String sessionId, String subsId, String destination, Message<?> message) {
protected void addSubscriptionInternal(@NonNull String sessionId, @NonNull String subscriptionId,
@NonNull String destination, @NonNull Message<?> message) {
Expression expression = getSelectorExpression(message.getHeaders());
this.subscriptionRegistry.addSubscription(sessionId, subsId, destination, expression);
this.destinationCache.updateAfterNewSubscription(destination, sessionId, subsId);
boolean isAntPattern = this.pathMatcher.isPattern(destination);
Subscription subscription = new Subscription(subscriptionId, expression, destination, isAntPattern);
Subscription previousValue = this.subscriptionRegistry.addSubscription(sessionId, subscriptionId, subscription);
if (previousValue == null) {
this.destinationCache.updateAfterNewSubscription(destination, isAntPattern, sessionId, subscriptionId);
}
}
@Nullable
@ -179,9 +186,9 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -179,9 +186,9 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
protected void removeSubscriptionInternal(String sessionId, String subsId, Message<?> message) {
SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId);
if (info != null) {
String destination = info.removeSubscription(subsId);
if (destination != null) {
this.destinationCache.updateAfterRemovedSubscription(sessionId, subsId);
Subscription subscription = info.removeSubscription(subsId);
if (subscription != null) {
this.destinationCache.updateAfterRemovedSubscription(sessionId, subscription);
}
}
}
@ -190,13 +197,13 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -190,13 +197,13 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
public void unregisterAllSubscriptions(String sessionId) {
SessionSubscriptionInfo info = this.subscriptionRegistry.removeSubscriptions(sessionId);
if (info != null) {
this.destinationCache.updateAfterRemovedSession(info);
this.destinationCache.updateAfterRemovedSession(sessionId, info.getSubscriptions());
}
}
@Override
protected MultiValueMap<String, String> findSubscriptionsInternal(String destination, Message<?> message) {
MultiValueMap<String, String> result = this.destinationCache.getSubscriptions(destination, message);
MultiValueMap<String, String> result = this.destinationCache.getSubscriptions(destination);
return filterSubscriptions(result, message);
}
@ -207,168 +214,181 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -207,168 +214,181 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
return allMatches;
}
MultiValueMap<String, String> result = new LinkedMultiValueMap<>(allMatches.size());
allMatches.forEach((sessionId, subIds) -> {
for (String subId : subIds) {
SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId);
if (info == null) {
continue;
}
Subscription sub = info.getSubscription(subId);
if (sub == null) {
continue;
}
Expression expression = sub.getSelectorExpression();
if (expression == null) {
result.add(sessionId, subId);
continue;
}
try {
if (Boolean.TRUE.equals(expression.getValue(messageEvalContext, message, Boolean.class))) {
result.add(sessionId, subId);
}
}
catch (SpelEvaluationException ex) {
if (logger.isDebugEnabled()) {
logger.debug("Failed to evaluate selector: " + ex.getMessage());
allMatches.forEach((sessionId, subscriptionsIds) -> {
SessionSubscriptionInfo subscriptions = this.subscriptionRegistry.getSubscriptions(sessionId);
if (subscriptions != null) {
for (String subscriptionId : subscriptionsIds) {
Subscription subscription = subscriptions.getSubscription(subscriptionId);
if (subscription != null && evaluateExpression(subscription.getSelectorExpression(), message)) {
result.add(sessionId, subscription.getId());
}
}
catch (Throwable ex) {
logger.debug("Failed to evaluate selector", ex);
}
}
});
return result;
}
@Override
public String toString() {
return "DefaultSubscriptionRegistry[" + this.destinationCache + ", " + this.subscriptionRegistry + "]";
private boolean evaluateExpression(@Nullable Expression expression, Message<?> message) {
boolean result = false;
try {
if (expression == null || Boolean.TRUE.equals(expression.getValue(messageEvalContext, message, Boolean.class))) {
result = true;
}
}
catch (SpelEvaluationException ex) {
if (logger.isDebugEnabled()) {
logger.debug("Failed to evaluate selector: " + ex.getMessage());
}
}
catch (Throwable ex) {
logger.debug("Failed to evaluate selector", ex);
}
return result;
}
/**
* A cache for destinations previously resolved via
* {@link DefaultSubscriptionRegistry#findSubscriptionsInternal(String, Message)}.
*/
private class DestinationCache {
private final class DestinationCache {
/** Map from destination to {@code <sessionId, subscriptionId>} for fast look-ups. */
private final Map<String, LinkedMultiValueMap<String, String>> accessCache =
private final Map<String, LinkedMultiValueMap<String, String>> destinationCache =
new ConcurrentHashMap<>(DEFAULT_CACHE_LIMIT);
/** Map from destination to {@code <sessionId, subscriptionId>} with locking. */
@SuppressWarnings("serial")
private final Map<String, LinkedMultiValueMap<String, String>> updateCache =
new LinkedHashMap<String, LinkedMultiValueMap<String, String>>(DEFAULT_CACHE_LIMIT, 0.75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<String, LinkedMultiValueMap<String, String>> eldest) {
if (size() > getCacheLimit()) {
accessCache.remove(eldest.getKey());
return true;
}
else {
return false;
}
}
};
public LinkedMultiValueMap<String, String> getSubscriptions(String destination, Message<?> message) {
LinkedMultiValueMap<String, String> result = this.accessCache.get(destination);
if (result == null) {
synchronized (this.updateCache) {
result = new LinkedMultiValueMap<>();
for (SessionSubscriptionInfo info : subscriptionRegistry.getAllSubscriptions()) {
for (String destinationPattern : info.getDestinations()) {
if (getPathMatcher().match(destinationPattern, destination)) {
for (Subscription sub : info.getSubscriptions(destinationPattern)) {
result.add(info.sessionId, sub.getId());
}
}
}
}
if (!result.isEmpty()) {
this.updateCache.put(destination, result.deepCopy());
this.accessCache.put(destination, result);
private final Queue<String> cacheEvictionPolicy = new ConcurrentLinkedQueue<>();
private final AtomicInteger cacheSize = new AtomicInteger();
public LinkedMultiValueMap<String, String> getSubscriptions(String destination) {
LinkedMultiValueMap<String, String> subscriptions = this.destinationCache.get(destination);
if (subscriptions == null) {
subscriptions = this.destinationCache.computeIfAbsent(destination, dest -> {
LinkedMultiValueMap<String, String> sessionSubscriptions = calculateSubscriptions(destination);
this.cacheEvictionPolicy.add(destination);
this.cacheSize.incrementAndGet();
return sessionSubscriptions;
});
ensureCacheLimit();
}
return subscriptions;
}
@NonNull
private LinkedMultiValueMap<String, String> calculateSubscriptions(String destination) {
LinkedMultiValueMap<String, String> sessionsToSubscriptions = new LinkedMultiValueMap<>();
DefaultSubscriptionRegistry.this.subscriptionRegistry.forEachSubscription((sessionId, subscriptionDetail) -> {
if (subscriptionDetail.isAntPattern()) {
if (pathMatcher.match(subscriptionDetail.getDestination(), destination)) {
sessionsToSubscriptions.compute(sessionId, (s, subscriptions) ->
addToList(subscriptionDetail.getId(), subscriptions));
}
}
else if (destination.equals(subscriptionDetail.getDestination())) {
sessionsToSubscriptions.compute(sessionId, (s, subscriptions) ->
addToList(subscriptionDetail.getId(), subscriptions));
}
});
return sessionsToSubscriptions;
}
@NonNull
private List<String> addToList(String subscriptionId, @Nullable List<String> subscriptions) {
if (subscriptions == null) {
return Collections.singletonList(subscriptionId);
}
else {
List<String> newSubscriptions = new ArrayList<>(subscriptions.size() + 1);
newSubscriptions.addAll(subscriptions);
newSubscriptions.add(subscriptionId);
return newSubscriptions;
}
return result;
}
public void updateAfterNewSubscription(String destination, String sessionId, String subsId) {
synchronized (this.updateCache) {
this.updateCache.forEach((cachedDestination, subscriptions) -> {
if (getPathMatcher().match(destination, cachedDestination)) {
// Subscription id's may also be populated via getSubscriptions()
List<String> subsForSession = subscriptions.get(sessionId);
if (subsForSession == null || !subsForSession.contains(subsId)) {
subscriptions.add(sessionId, subsId);
this.accessCache.put(cachedDestination, subscriptions.deepCopy());
}
}
private void ensureCacheLimit() {
int size = this.cacheSize.get();
if (size > cacheLimit) {
do {
if (this.cacheSize.compareAndSet(size, size - 1)) {
this.destinationCache.remove(this.cacheEvictionPolicy.poll());
}
});
} while ((size = this.cacheSize.get()) > cacheLimit);
}
}
public void updateAfterRemovedSubscription(String sessionId, String subsId) {
synchronized (this.updateCache) {
Set<String> destinationsToRemove = new HashSet<>();
this.updateCache.forEach((destination, sessionMap) -> {
List<String> subscriptions = sessionMap.get(sessionId);
if (subscriptions != null) {
subscriptions.remove(subsId);
if (subscriptions.isEmpty()) {
sessionMap.remove(sessionId);
}
if (sessionMap.isEmpty()) {
destinationsToRemove.add(destination);
}
else {
this.accessCache.put(destination, sessionMap.deepCopy());
}
public void updateAfterNewSubscription(String destination, boolean isPattern, String sessionId, String subscriptionId) {
if (isPattern) {
for (String cachedDestination : this.destinationCache.keySet()) {
if (pathMatcher.match(destination, cachedDestination)) {
addToDestination(cachedDestination, sessionId, subscriptionId);
}
});
for (String destination : destinationsToRemove) {
this.updateCache.remove(destination);
this.accessCache.remove(destination);
}
}
else {
addToDestination(destination, sessionId, subscriptionId);
}
}
public void updateAfterRemovedSession(SessionSubscriptionInfo info) {
synchronized (this.updateCache) {
Set<String> destinationsToRemove = new HashSet<>();
this.updateCache.forEach((destination, sessionMap) -> {
if (sessionMap.remove(info.getSessionId()) != null) {
if (sessionMap.isEmpty()) {
destinationsToRemove.add(destination);
}
else {
this.accessCache.put(destination, sessionMap.deepCopy());
}
private void addToDestination(String destination, String sessionId, String subscriptionId) {
this.destinationCache.computeIfPresent(destination, (dest, sessionsToSubscriptions) -> {
sessionsToSubscriptions = sessionsToSubscriptions.clone();
sessionsToSubscriptions.compute(sessionId, (s, subscriptions) -> addToList(subscriptionId, subscriptions));
return sessionsToSubscriptions;
});
}
public void updateAfterRemovedSubscription(String sessionId, Subscription subscriptionDetail) {
if (subscriptionDetail.isAntPattern()) {
String patternDestination = subscriptionDetail.getDestination();
for (String destination : this.destinationCache.keySet()) {
if (pathMatcher.match(patternDestination, destination)) {
removeInternal(destination, sessionId, subscriptionDetail.getId());
}
});
for (String destination : destinationsToRemove) {
this.updateCache.remove(destination);
this.accessCache.remove(destination);
}
}
else {
removeInternal(subscriptionDetail.getDestination(), sessionId, subscriptionDetail.getId());
}
}
@Override
public String toString() {
return "cache[" + this.accessCache.size() + " destination(s)]";
private void removeInternal(String destination, String sessionId, String subscription) {
this.destinationCache.computeIfPresent(destination, (dest, subscriptions) -> {
subscriptions = subscriptions.clone();
subscriptions.computeIfPresent(sessionId, (session, subs) -> {
/* it is very likely that one session has only one subscription per one destination */
if (subs.size() == 1 && subscription.equals(subs.get(0))) {
return null;
}
else {
subs = new ArrayList<>(subs);
subs.remove(subscription);
return emptyListToNUll(subs);
}
});
return subscriptions;
});
}
@Nullable
private <T> List<T> emptyListToNUll(@NonNull List<T> list) {
return list.isEmpty() ? null : list;
}
}
public void updateAfterRemovedSession(String sessionId, Collection<Subscription> subscriptionDetails) {
for (Subscription subscriptionDetail : subscriptionDetails) {
updateAfterRemovedSubscription(sessionId, subscriptionDetail);
}
}
}
/**
* Provide access to session subscriptions by sessionId.
*/
private static class SessionSubscriptionRegistry {
private static final class SessionSubscriptionRegistry {
// sessionId -> SessionSubscriptionInfo
// 'sessionId' -> 'subscriptionId' -> 'destination, selector expression'
private final ConcurrentMap<String, SessionSubscriptionInfo> sessions = new ConcurrentHashMap<>();
@Nullable
@ -376,119 +396,51 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -376,119 +396,51 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
return this.sessions.get(sessionId);
}
public Collection<SessionSubscriptionInfo> getAllSubscriptions() {
return this.sessions.values();
public void forEachSubscription(BiConsumer<String, Subscription> consumer) {
this.sessions.forEach((sessionId, subscriptions) ->
subscriptions.getSubscriptions().forEach(subscriptionDetail ->
consumer.accept(sessionId, subscriptionDetail)));
}
public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId,
String destination, @Nullable Expression selectorExpression) {
SessionSubscriptionInfo info = this.sessions.get(sessionId);
if (info == null) {
info = new SessionSubscriptionInfo(sessionId);
SessionSubscriptionInfo value = this.sessions.putIfAbsent(sessionId, info);
if (value != null) {
info = value;
}
}
info.addSubscription(destination, subscriptionId, selectorExpression);
return info;
@Nullable
public Subscription addSubscription(String sessionId, String subscriptionId, Subscription subscriptionDetail) {
SessionSubscriptionInfo subscriptions = this.sessions.computeIfAbsent(sessionId, s -> new SessionSubscriptionInfo());
return subscriptions.addSubscription(subscriptionId, subscriptionDetail);
}
@Nullable
public SessionSubscriptionInfo removeSubscriptions(String sessionId) {
return this.sessions.remove(sessionId);
}
@Override
public String toString() {
return "registry[" + this.sessions.size() + " sessions]";
}
}
/**
* Hold subscriptions for a session.
*/
private static class SessionSubscriptionInfo {
private final String sessionId;
// destination -> subscriptions
private final Map<String, Set<Subscription>> destinationLookup = new ConcurrentHashMap<>(4);
public SessionSubscriptionInfo(String sessionId) {
Assert.notNull(sessionId, "'sessionId' must not be null");
this.sessionId = sessionId;
}
private static final class SessionSubscriptionInfo {
public String getSessionId() {
return this.sessionId;
}
public Set<String> getDestinations() {
return this.destinationLookup.keySet();
}
private final Map<String, Subscription> subscriptionLookup = new ConcurrentHashMap<>();
public Set<Subscription> getSubscriptions(String destination) {
return this.destinationLookup.get(destination);
public Collection<Subscription> getSubscriptions() {
return this.subscriptionLookup.values();
}
@Nullable
public Subscription getSubscription(String subscriptionId) {
for (Map.Entry<String, Set<DefaultSubscriptionRegistry.Subscription>> destinationEntry :
this.destinationLookup.entrySet()) {
for (Subscription sub : destinationEntry.getValue()) {
if (sub.getId().equalsIgnoreCase(subscriptionId)) {
return sub;
}
}
}
return null;
}
public void addSubscription(String destination, String subscriptionId, @Nullable Expression selectorExpression) {
Set<Subscription> subs = this.destinationLookup.get(destination);
if (subs == null) {
synchronized (this.destinationLookup) {
subs = this.destinationLookup.get(destination);
if (subs == null) {
subs = new CopyOnWriteArraySet<>();
this.destinationLookup.put(destination, subs);
}
}
}
subs.add(new Subscription(subscriptionId, selectorExpression));
return this.subscriptionLookup.get(subscriptionId);
}
@Nullable
public String removeSubscription(String subscriptionId) {
for (Map.Entry<String, Set<DefaultSubscriptionRegistry.Subscription>> destinationEntry :
this.destinationLookup.entrySet()) {
Set<Subscription> subs = destinationEntry.getValue();
if (subs != null) {
for (Subscription sub : subs) {
if (sub.getId().equals(subscriptionId) && subs.remove(sub)) {
synchronized (this.destinationLookup) {
if (subs.isEmpty()) {
this.destinationLookup.remove(destinationEntry.getKey());
}
}
return destinationEntry.getKey();
}
}
}
}
return null;
public Subscription addSubscription(String subscriptionId, Subscription subscriptionDetail) {
return this.subscriptionLookup.putIfAbsent(subscriptionId, subscriptionDetail);
}
@Override
public String toString() {
return "[sessionId=" + this.sessionId + ", subscriptions=" + this.destinationLookup + "]";
@Nullable
public Subscription removeSubscription(String subscriptionId) {
return this.subscriptionLookup.remove(subscriptionId);
}
}
private static final class Subscription {
private final String id;
@ -496,16 +448,31 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -496,16 +448,31 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
@Nullable
private final Expression selectorExpression;
public Subscription(String id, @Nullable Expression selector) {
private final String destination;
private final boolean isAntPattern;
public Subscription(String id, @Nullable Expression selector, String destination, boolean isAntPattern) {
Assert.notNull(id, "Subscription id must not be null");
Assert.notNull(destination, "Subscription destination must not be null");
this.id = id;
this.selectorExpression = selector;
this.destination = destination;
this.isAntPattern = isAntPattern;
}
public String getId() {
return this.id;
}
public String getDestination() {
return this.destination;
}
public boolean isAntPattern() {
return this.isAntPattern;
}
@Nullable
public Expression getSelectorExpression() {
return this.selectorExpression;

21
spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java

@ -96,6 +96,21 @@ public class DefaultSubscriptionRegistryTests { @@ -96,6 +96,21 @@ public class DefaultSubscriptionRegistryTests {
assertThat(sort(actual.get(sessId))).isEqualTo(subscriptionIds);
}
@Test
public void registerSameSubscriptionTwice() {
String sessId = "sess01";
String subId = "subs01";
String dest = "/foo";
this.registry.registerSubscription(subscribeMessage(sessId, subId, dest));
this.registry.registerSubscription(subscribeMessage(sessId, subId, dest));
MultiValueMap<String, String> actual = this.registry.findSubscriptions(createMessage(dest));
assertThat(actual).isNotNull();
assertThat(actual.size()).isEqualTo(1);
assertThat(actual.get(sessId)).containsExactly(subId);
}
@Test
public void registerSubscriptionMultipleSessions() {
List<String> sessIds = Arrays.asList("sess01", "sess02", "sess03");
@ -148,7 +163,7 @@ public class DefaultSubscriptionRegistryTests { @@ -148,7 +163,7 @@ public class DefaultSubscriptionRegistryTests {
MultiValueMap<String, String> actual = this.registry.findSubscriptions(destNasdaqIbmMessage);
assertThat(actual).isNotNull();
assertThat(actual.size()).isEqualTo(1);
assertThat(actual.get(sess1)).isEqualTo(Arrays.asList(subs2, subs1));
assertThat(actual.get(sess1)).containsExactlyInAnyOrder(subs2, subs1);
this.registry.registerSubscription(subscribeMessage(sess2, subs1, destNasdaqIbm));
this.registry.registerSubscription(subscribeMessage(sess2, subs2, "/topic/PRICE.STOCK.NYSE.IBM"));
@ -157,7 +172,7 @@ public class DefaultSubscriptionRegistryTests { @@ -157,7 +172,7 @@ public class DefaultSubscriptionRegistryTests {
actual = this.registry.findSubscriptions(destNasdaqIbmMessage);
assertThat(actual).isNotNull();
assertThat(actual.size()).isEqualTo(2);
assertThat(actual.get(sess1)).isEqualTo(Arrays.asList(subs2, subs1));
assertThat(actual.get(sess1)).containsExactlyInAnyOrder(subs2, subs1);
assertThat(actual.get(sess2)).isEqualTo(Collections.singletonList(subs1));
this.registry.unregisterAllSubscriptions(sess1);
@ -173,7 +188,7 @@ public class DefaultSubscriptionRegistryTests { @@ -173,7 +188,7 @@ public class DefaultSubscriptionRegistryTests {
actual = this.registry.findSubscriptions(destNasdaqIbmMessage);
assertThat(actual).isNotNull();
assertThat(actual.size()).isEqualTo(2);
assertThat(actual.get(sess1)).isEqualTo(Arrays.asList(subs1, subs2));
assertThat(actual.get(sess1)).containsExactlyInAnyOrder(subs1, subs2);
assertThat(actual.get(sess2)).isEqualTo(Collections.singletonList(subs1));
this.registry.unregisterSubscription(unsubscribeMessage(sess1, subs2));

Loading…
Cancel
Save