Browse Source

Fix concurrency issue in DefaultSubscriptionRegistry

pull/286/merge
Rossen Stoyanchev 13 years ago
parent
commit
329fbf31bc
  1. 4
      spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java
  2. 42
      spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationMethodMessageHandler.java
  3. 68
      spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java
  4. 13
      spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java

4
spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java

@ -43,7 +43,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String @@ -43,7 +43,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
public SimpMessagingTemplate(MessageChannel messageChannel) {
Assert.notNull(messageChannel, "outputChannel is required");
Assert.notNull(messageChannel, "messageChannel is required");
this.messageChannel = messageChannel;
}
@ -117,6 +117,8 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String @@ -117,6 +117,8 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
}
}
@Override
public <T> void convertAndSendToUser(String user, String destination, T message) throws MessagingException {
convertAndSendToUser(user, destination, message, null);

42
spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationMethodMessageHandler.java

@ -34,8 +34,10 @@ import org.springframework.context.ApplicationContextAware; @@ -34,8 +34,10 @@ import org.springframework.context.ApplicationContextAware;
import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.core.AbstractMessageSendingTemplate;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.ReplyTo;
import org.springframework.messaging.handler.annotation.support.ExceptionHandlerMethodResolver;
@ -48,6 +50,7 @@ import org.springframework.messaging.handler.method.InvocableHandlerMethod; @@ -48,6 +50,7 @@ import org.springframework.messaging.handler.method.InvocableHandlerMethod;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageSendingOperations;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.annotation.SubscribeEvent;
import org.springframework.messaging.simp.annotation.UnsubscribeEvent;
import org.springframework.messaging.simp.annotation.support.PrincipalMethodArgumentResolver;
@ -68,9 +71,9 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati @@ -68,9 +71,9 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati
private static final Log logger = LogFactory.getLog(AnnotationMethodMessageHandler.class);
private final SimpMessageSendingOperations inboundMessagingTemplate;
private final SimpMessageSendingOperations dispatchMessagingTemplate;
private final SimpMessageSendingOperations outboundMessagingTemplate;
private final SimpMessageSendingOperations webSocketSessionMessagingTemplate;
private MessageConverter<?> messageConverter;
@ -91,24 +94,20 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati @@ -91,24 +94,20 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati
/**
* @param inboundMessagingTemplate a template for sending messages on the channel
* where incoming messages from clients are sent; essentially messages sent
* through this template will be re-processed by the application. One example
* is the use of {@link ReplyTo} annotation on a method to send a broadcast
* message.
* @param outboundMessagingTemplate a template for sending messages on the client used
* to send messages back out to connected clients; such messages must have all
* necessary information to reach the client such as session and subscription
* id's. One example is returning a value from an {@link SubscribeEvent}
* method.
* @param dispatchMessagingTemplate a messaging template to dispatch messages to for
* further processing, e.g. the use of an {@link ReplyTo} annotation on a
* message handling method, causes a new (broadcast) message to be sent.
* @param webSocketSessionChannel the channel to send messages to WebSocket sessions
* on this application server. This is used primarily for processing the return
* values from {@link SubscribeEvent}-annotated methods.
*/
public AnnotationMethodMessageHandler(SimpMessageSendingOperations inboundMessagingTemplate,
SimpMessageSendingOperations outboundMessagingTemplate) {
public AnnotationMethodMessageHandler(SimpMessageSendingOperations dispatchMessagingTemplate,
MessageChannel webSocketSessionChannel) {
Assert.notNull(inboundMessagingTemplate, "inboundMessagingTemplate is required");
Assert.notNull(outboundMessagingTemplate, "outboundMessagingTemplate is required");
this.inboundMessagingTemplate = inboundMessagingTemplate;
this.outboundMessagingTemplate = outboundMessagingTemplate;
Assert.notNull(dispatchMessagingTemplate, "dispatchMessagingTemplate is required");
Assert.notNull(webSocketSessionChannel, "webSocketSessionChannel is required");
this.dispatchMessagingTemplate = dispatchMessagingTemplate;
this.webSocketSessionMessagingTemplate = new SimpMessagingTemplate(webSocketSessionChannel);
}
/**
@ -116,6 +115,9 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati @@ -116,6 +115,9 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati
*/
public void setMessageConverter(MessageConverter<?> converter) {
this.messageConverter = converter;
if (converter != null) {
((AbstractMessageSendingTemplate<?>) this.webSocketSessionMessagingTemplate).setMessageConverter(converter);
}
}
@Override
@ -131,8 +133,8 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati @@ -131,8 +133,8 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati
this.argumentResolvers.addResolver(new PrincipalMethodArgumentResolver());
this.argumentResolvers.addResolver(new MessageBodyMethodArgumentResolver(this.messageConverter));
this.returnValueHandlers.addHandler(new ReplyToMethodReturnValueHandler(this.inboundMessagingTemplate));
this.returnValueHandlers.addHandler(new SubscriptionMethodReturnValueHandler(this.outboundMessagingTemplate));
this.returnValueHandlers.addHandler(new ReplyToMethodReturnValueHandler(this.dispatchMessagingTemplate));
this.returnValueHandlers.addHandler(new SubscriptionMethodReturnValueHandler(this.webSocketSessionMessagingTemplate));
}
protected void initHandlerMethods() {

68
spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java

@ -17,7 +17,6 @@ @@ -17,7 +17,6 @@
package org.springframework.messaging.simp.handler;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
@ -29,6 +28,8 @@ import org.springframework.util.AntPathMatcher; @@ -29,6 +28,8 @@ import org.springframework.util.AntPathMatcher;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import reactor.util.Assert;
/**
* @author Rossen Stoyanchev
@ -102,6 +103,14 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -102,6 +103,14 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
return result;
}
@Override
public String toString() {
return "[destinationCache=" + this.destinationCache + ", subscriptionRegistry="
+ this.subscriptionRegistry + "]";
}
/**
* Provide direct lookup of session subscriptions by destination (for non-pattern destinations).
@ -116,7 +125,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -116,7 +125,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
public void mapToDestination(String destination, SessionSubscriptionInfo info) {
synchronized (monitor) {
synchronized(this.monitor) {
Set<SessionSubscriptionInfo> registrations = this.subscriptionsByDestination.get(destination);
if (registrations == null) {
registrations = new CopyOnWriteArraySet<SessionSubscriptionInfo>();
@ -127,7 +136,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -127,7 +136,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
}
public void unmapFromDestination(String destination, SessionSubscriptionInfo info) {
synchronized (monitor) {
synchronized(this.monitor) {
Set<SessionSubscriptionInfo> infos = this.subscriptionsByDestination.get(destination);
if (infos != null) {
infos.remove(info);
@ -159,6 +168,11 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -159,6 +168,11 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
}
return result;
}
@Override
public String toString() {
return "[subscriptionsByDestination=" + this.subscriptionsByDestination + "]";
}
}
/**
@ -169,6 +183,8 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -169,6 +183,8 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
private final Map<String, SessionSubscriptionInfo> sessions =
new ConcurrentHashMap<String, SessionSubscriptionInfo>();
private final Object monitor = new Object();
public SessionSubscriptionInfo getSubscriptions(String sessionId) {
return this.sessions.get(sessionId);
@ -181,16 +197,26 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -181,16 +197,26 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId, String destination) {
SessionSubscriptionInfo info = this.sessions.get(sessionId);
if (info == null) {
info = new SessionSubscriptionInfo(sessionId);
this.sessions.put(sessionId, info);
synchronized(this.monitor) {
info = this.sessions.get(sessionId);
if (info == null) {
info = new SessionSubscriptionInfo(sessionId);
this.sessions.put(sessionId, info);
}
}
}
info.addSubscription(subscriptionId, destination);
info.addSubscription(destination, subscriptionId);
return info;
}
public SessionSubscriptionInfo removeSubscriptions(String sessionId) {
return this.sessions.remove(sessionId);
}
@Override
public String toString() {
return "[sessions=" + sessions + "]";
}
}
/**
@ -200,10 +226,13 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -200,10 +226,13 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
private final String sessionId;
private final Map<String, Set<String>> subscriptions = new HashMap<String, Set<String>>(4);
private final Map<String, Set<String>> subscriptions = new ConcurrentHashMap<String, Set<String>>(4);
private final Object monitor = new Object();
public SessionSubscriptionInfo(String sessionId) {
Assert.notNull(sessionId, "sessionId is required");
this.sessionId = sessionId;
}
@ -219,27 +248,36 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -219,27 +248,36 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
return this.subscriptions.get(destination);
}
public void addSubscription(String subscriptionId, String destination) {
Set<String> subs = this.subscriptions.get(destination);
if (subs == null) {
subs = new HashSet<String>(4);
this.subscriptions.put(destination, subs);
public void addSubscription(String destination, String subscriptionId) {
synchronized(this.monitor) {
Set<String> subs = this.subscriptions.get(destination);
if (subs == null) {
subs = new HashSet<String>(4);
this.subscriptions.put(destination, subs);
}
subs.add(subscriptionId);
}
subs.add(subscriptionId);
}
public String removeSubscription(String subscriptionId) {
for (String destination : this.subscriptions.keySet()) {
Set<String> subscriptionIds = this.subscriptions.get(destination);
if (subscriptionIds.remove(subscriptionId)) {
if (subscriptionIds.isEmpty()) {
this.subscriptions.remove(destination);
synchronized(this.monitor) {
if (subscriptionIds.isEmpty()) {
this.subscriptions.remove(destination);
}
}
return destination;
}
}
return null;
}
@Override
public String toString() {
return "[sessionId=" + this.sessionId + ", subscriptions=" + this.subscriptions + "]";
}
}
}

13
spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java

@ -68,23 +68,30 @@ public class SimpleBrokerMessageHandler implements MessageHandler { @@ -68,23 +68,30 @@ public class SimpleBrokerMessageHandler implements MessageHandler {
SimpMessageType messageType = headers.getMessageType();
if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
// TODO: need a way to communicate back if subscription was successfully created or
// not in which case an ERROR should be sent back and close the connection
// http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE
preProcessMessage(message);
this.subscriptionRegistry.registerSubscription(message);
}
else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
preProcessMessage(message);
this.subscriptionRegistry.unregisterSubscription(message);
}
else if (SimpMessageType.MESSAGE.equals(messageType)) {
preProcessMessage(message);
sendMessageToSubscribers(headers.getDestination(), message);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
preProcessMessage(message);
String sessionId = SimpMessageHeaderAccessor.wrap(message).getSessionId();
this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
}
}
private void preProcessMessage(Message<?> message) {
if (logger.isTraceEnabled()) {
logger.trace("Processing " + message);
}
}
protected void sendMessageToSubscribers(String destination, Message<?> message) {
MultiValueMap<String,String> subscriptions = this.subscriptionRegistry.findSubscriptions(message);
for (String sessionId : subscriptions.keySet()) {

Loading…
Cancel
Save