Browse Source

Fix issue in DefaultUserDestinationResolver

DefaultUserDestinationResolver now uses the session id of
SUBSCRIBE/UNSUBSCRIBE messages rather than looking up all session id's
associated with a user.

Issue: SPR-11325
pull/443/merge
Rossen Stoyanchev 12 years ago
parent
commit
b4e48d6749
  1. 45
      spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java
  2. 48
      spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java
  3. 30
      spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java

45
spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2013 the original author or authors. * Copyright 2002-2014 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -96,17 +96,13 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
return Collections.emptySet(); return Collections.emptySet();
} }
Set<String> set = new HashSet<String>(); Set<String> result = new HashSet<String>();
for (String sessionId : this.userSessionRegistry.getSessionIds(info.getUser())) { for (String sessionId : info.getSessionIds()) {
set.add(getTargetDestination(headers.getDestination(), info.getDestination(), sessionId, info.getUser())); result.add(getTargetDestination(
headers.getDestination(), info.getDestination(), sessionId, info.getUser()));
} }
return set;
}
protected String getTargetDestination(String originalDestination, String targetDestination,
String sessionId, String user) {
return targetDestination + "-user" + sessionId; return result;
} }
private UserDestinationInfo getUserDestinationInfo(SimpMessageHeaderAccessor headers) { private UserDestinationInfo getUserDestinationInfo(SimpMessageHeaderAccessor headers) {
@ -115,6 +111,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
String targetUser; String targetUser;
String targetDestination; String targetDestination;
Set<String> sessionIds;
Principal user = headers.getUser(); Principal user = headers.getUser();
SimpMessageType messageType = headers.getMessageType(); SimpMessageType messageType = headers.getMessageType();
@ -124,11 +121,16 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
return null; return null;
} }
if (user == null) { if (user == null) {
logger.warn("Ignoring message, no user information"); logger.error("Ignoring message, no user info available");
return null;
}
if (headers.getSessionId() == null) {
logger.error("Ignoring message, no session id available");
return null; return null;
} }
targetUser = user.getName(); targetUser = user.getName();
targetDestination = destination.substring(this.destinationPrefix.length()-1); targetDestination = destination.substring(this.destinationPrefix.length()-1);
sessionIds = Collections.singleton(headers.getSessionId());
} }
else if (SimpMessageType.MESSAGE.equals(messageType)) { else if (SimpMessageType.MESSAGE.equals(messageType)) {
if (!checkDestination(destination, this.destinationPrefix)) { if (!checkDestination(destination, this.destinationPrefix)) {
@ -139,7 +141,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
Assert.isTrue(endIndex > 0, "Expected destination pattern \"/user/{userId}/**\""); Assert.isTrue(endIndex > 0, "Expected destination pattern \"/user/{userId}/**\"");
targetUser = destination.substring(startIndex, endIndex); targetUser = destination.substring(startIndex, endIndex);
targetDestination = destination.substring(endIndex); targetDestination = destination.substring(endIndex);
sessionIds = this.userSessionRegistry.getSessionIds(targetUser);
} }
else { else {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
@ -148,7 +150,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
return null; return null;
} }
return new UserDestinationInfo(targetUser, targetDestination); return new UserDestinationInfo(targetUser, targetDestination, sessionIds);
} }
protected boolean checkDestination(String destination, String requiredPrefix) { protected boolean checkDestination(String destination, String requiredPrefix) {
@ -165,6 +167,10 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
return true; return true;
} }
protected String getTargetDestination(String origDestination, String targetDestination, String sessionId, String user) {
return targetDestination + "-user" + sessionId;
}
private static class UserDestinationInfo { private static class UserDestinationInfo {
@ -172,18 +178,25 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
private final String destination; private final String destination;
private UserDestinationInfo(String user, String destination) { private final Set<String> sessionIds;
private UserDestinationInfo(String user, String destination, Set<String> sessionIds) {
this.user = user; this.user = user;
this.destination = destination; this.destination = destination;
this.sessionIds = sessionIds;
} }
private String getUser() { public String getUser() {
return this.user; return this.user;
} }
private String getDestination() { public String getDestination() {
return this.destination; return this.destination;
} }
public Set<String> getSessionIds() {
return this.sessionIds;
}
} }
} }

48
spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2013 the original author or authors. * Copyright 2002-2014 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,9 +22,6 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.simp.user.DefaultUserDestinationResolver;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import java.util.Set; import java.util.Set;
@ -36,6 +33,8 @@ import static org.junit.Assert.assertEquals;
*/ */
public class DefaultUserDestinationResolverTests { public class DefaultUserDestinationResolverTests {
public static final String SESSION_ID = "123";
private DefaultUserDestinationResolver resolver; private DefaultUserDestinationResolver resolver;
private UserSessionRegistry registry; private UserSessionRegistry registry;
@ -44,14 +43,30 @@ public class DefaultUserDestinationResolverTests {
@Before @Before
public void setup() { public void setup() {
this.registry = new DefaultUserSessionRegistry(); this.registry = new DefaultUserSessionRegistry();
this.registry.registerSessionId("joe", SESSION_ID);
this.resolver = new DefaultUserDestinationResolver(this.registry); this.resolver = new DefaultUserDestinationResolver(this.registry);
} }
@Test @Test
public void handleSubscribe() { public void handleSubscribe() {
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, "joe", "/user/queue/foo"); Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
this.registry.registerSessionId("joe", "123"); Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size());
assertEquals("/queue/foo-user123", actual.iterator().next());
}
// SPR-11325
@Test
public void handleSubscribeOneUserMultipleSessions() {
this.registry.registerSessionId("joe", "456");
this.registry.registerSessionId("joe", "789");
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
Set<String> actual = this.resolver.resolveDestination(message); Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size()); assertEquals(1, actual.size());
@ -60,8 +75,7 @@ public class DefaultUserDestinationResolverTests {
@Test @Test
public void handleUnsubscribe() { public void handleUnsubscribe() {
Message<?> message = createMessage(SimpMessageType.UNSUBSCRIBE, "joe", "/user/queue/foo"); Message<?> message = createMessage(SimpMessageType.UNSUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
this.registry.registerSessionId("joe", "123");
Set<String> actual = this.resolver.resolveDestination(message); Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size()); assertEquals(1, actual.size());
@ -70,8 +84,7 @@ public class DefaultUserDestinationResolverTests {
@Test @Test
public void handleMessage() { public void handleMessage() {
Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", "/user/joe/queue/foo"); Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/user/joe/queue/foo");
this.registry.registerSessionId("joe", "123");
Set<String> actual = this.resolver.resolveDestination(message); Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size()); assertEquals(1, actual.size());
@ -83,33 +96,33 @@ public class DefaultUserDestinationResolverTests {
public void ignoreMessage() { public void ignoreMessage() {
// no destination // no destination
Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", null); Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, null);
Set<String> actual = this.resolver.resolveDestination(message); Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size()); assertEquals(0, actual.size());
// not a user destination // not a user destination
message = createMessage(SimpMessageType.MESSAGE, "joe", "/queue/foo"); message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/queue/foo");
actual = this.resolver.resolveDestination(message); actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size()); assertEquals(0, actual.size());
// subscribe + no user // subscribe + no user
message = createMessage(SimpMessageType.SUBSCRIBE, null, "/user/queue/foo"); message = createMessage(SimpMessageType.SUBSCRIBE, null, SESSION_ID, "/user/queue/foo");
actual = this.resolver.resolveDestination(message); actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size()); assertEquals(0, actual.size());
// subscribe + not a user destination // subscribe + not a user destination
message = createMessage(SimpMessageType.SUBSCRIBE, "joe", "/queue/foo"); message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/queue/foo");
actual = this.resolver.resolveDestination(message); actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size()); assertEquals(0, actual.size());
// no match on message type // no match on message type
message = createMessage(SimpMessageType.CONNECT, "joe", "user/joe/queue/foo"); message = createMessage(SimpMessageType.CONNECT, "joe", SESSION_ID, "user/joe/queue/foo");
actual = this.resolver.resolveDestination(message); actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size()); assertEquals(0, actual.size());
} }
private Message<?> createMessage(SimpMessageType messageType, String user, String destination) { private Message<?> createMessage(SimpMessageType messageType, String user, String sessionId, String destination) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(messageType); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(messageType);
if (destination != null) { if (destination != null) {
headers.setDestination(destination); headers.setDestination(destination);
@ -117,6 +130,9 @@ public class DefaultUserDestinationResolverTests {
if (user != null) { if (user != null) {
headers.setUser(new TestPrincipal(user)); headers.setUser(new TestPrincipal(user));
} }
if (sessionId != null) {
headers.setSessionId(sessionId);
}
return MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); return MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
} }

30
spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2013 the original author or authors. * Copyright 2002-2014 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,10 +28,6 @@ import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.simp.user.DefaultUserDestinationResolver;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@ -42,6 +38,7 @@ import static org.mockito.Mockito.*;
*/ */
public class UserDestinationMessageHandlerTests { public class UserDestinationMessageHandlerTests {
public static final String SESSION_ID = "123";
private UserDestinationMessageHandler messageHandler; private UserDestinationMessageHandler messageHandler;
@ -63,9 +60,8 @@ public class UserDestinationMessageHandlerTests {
@Test @Test
public void handleSubscribe() { public void handleSubscribe() {
this.registry.registerSessionId("joe", "123");
when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true); when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true);
this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", "/user/queue/foo")); this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo"));
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture()); Mockito.verify(this.brokerChannel).send(captor.capture());
@ -76,9 +72,8 @@ public class UserDestinationMessageHandlerTests {
@Test @Test
public void handleUnsubscribe() { public void handleUnsubscribe() {
this.registry.registerSessionId("joe", "123");
when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true); when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true);
this.messageHandler.handleMessage(createMessage(SimpMessageType.UNSUBSCRIBE, "joe", "/user/queue/foo")); this.messageHandler.handleMessage(createMessage(SimpMessageType.UNSUBSCRIBE, "joe", "123", "/user/queue/foo"));
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture()); Mockito.verify(this.brokerChannel).send(captor.capture());
@ -91,7 +86,7 @@ public class UserDestinationMessageHandlerTests {
public void handleMessage() { public void handleMessage() {
this.registry.registerSessionId("joe", "123"); this.registry.registerSessionId("joe", "123");
when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true); when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true);
this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "/user/joe/queue/foo")); this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "123", "/user/joe/queue/foo"));
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture()); Mockito.verify(this.brokerChannel).send(captor.capture());
@ -105,28 +100,28 @@ public class UserDestinationMessageHandlerTests {
public void ignoreMessage() { public void ignoreMessage() {
// no destination // no destination
this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", null)); this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "123", null));
Mockito.verifyZeroInteractions(this.brokerChannel); Mockito.verifyZeroInteractions(this.brokerChannel);
// not a user destination // not a user destination
this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "/queue/foo")); this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "123", "/queue/foo"));
Mockito.verifyZeroInteractions(this.brokerChannel); Mockito.verifyZeroInteractions(this.brokerChannel);
// subscribe + no user // subscribe + no user
this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, null, "/user/queue/foo")); this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, null, "123", "/user/queue/foo"));
Mockito.verifyZeroInteractions(this.brokerChannel); Mockito.verifyZeroInteractions(this.brokerChannel);
// subscribe + not a user destination // subscribe + not a user destination
this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", "/queue/foo")); this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", "123", "/queue/foo"));
Mockito.verifyZeroInteractions(this.brokerChannel); Mockito.verifyZeroInteractions(this.brokerChannel);
// no match on message type // no match on message type
this.messageHandler.handleMessage(createMessage(SimpMessageType.CONNECT, "joe", "user/joe/queue/foo")); this.messageHandler.handleMessage(createMessage(SimpMessageType.CONNECT, "joe", "123", "user/joe/queue/foo"));
Mockito.verifyZeroInteractions(this.brokerChannel); Mockito.verifyZeroInteractions(this.brokerChannel);
} }
private Message<?> createMessage(SimpMessageType messageType, String user, String destination) { private Message<?> createMessage(SimpMessageType messageType, String user, String sessionId, String destination) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(messageType); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(messageType);
if (destination != null) { if (destination != null) {
headers.setDestination(destination); headers.setDestination(destination);
@ -134,6 +129,9 @@ public class UserDestinationMessageHandlerTests {
if (user != null) { if (user != null) {
headers.setUser(new TestPrincipal(user)); headers.setUser(new TestPrincipal(user));
} }
if (sessionId != null) {
headers.setSessionId(sessionId);
}
return MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); return MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
} }

Loading…
Cancel
Save