Browse Source

Add DestinationUserNameProvider interface

The interface is to be implemented in addition to
java.security.Principal when Principal.getName() is not globally unique
enough for use in user destinations.

Issue: SPR-11327
pull/449/head
Rossen Stoyanchev 12 years ago
parent
commit
e4ad2b352e
  1. 4
      spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java
  2. 20
      spring-messaging/src/main/java/org/springframework/messaging/simp/user/DestinationUserNameProvider.java
  3. 26
      spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java
  4. 18
      spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
  5. 76
      spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java

4
spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java

@ -165,7 +165,9 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { @@ -165,7 +165,9 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
return true;
}
protected String getTargetDestination(String origDestination, String targetDestination, String sessionId, String user) {
protected String getTargetDestination(String origDestination, String targetDestination,
String sessionId, String user) {
return targetDestination + "-user" + sessionId;
}

20
spring-messaging/src/main/java/org/springframework/messaging/simp/user/DestinationUserNameProvider.java

@ -0,0 +1,20 @@ @@ -0,0 +1,20 @@
package org.springframework.messaging.simp.user;
/**
* An interface to be implemented in addition to {@link java.security.Principal}
* when {@link java.security.Principal#getName()} is not globally unique enough
* for use in user destinations. For more on user destination see
* {@link org.springframework.messaging.simp.user.UserDestinationResolver}.
*
* @author Rossen Stoyanchev
* @since 4.0.1
*/
public interface DestinationUserNameProvider {
/**
* Return the (globally unique) user name to use with user destinations.
*/
String getDestinationUserName();
}

26
spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java

@ -39,19 +39,21 @@ public class DefaultUserDestinationResolverTests { @@ -39,19 +39,21 @@ public class DefaultUserDestinationResolverTests {
private UserSessionRegistry registry;
private TestPrincipal user;
@Before
public void setup() {
this.user = new TestPrincipal("joe");
this.registry = new DefaultUserSessionRegistry();
this.registry.registerSessionId("joe", SESSION_ID);
this.registry.registerSessionId(this.user.getName(), SESSION_ID);
this.resolver = new DefaultUserDestinationResolver(this.registry);
}
@Test
public void handleSubscribe() {
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo");
Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size());
@ -66,7 +68,7 @@ public class DefaultUserDestinationResolverTests { @@ -66,7 +68,7 @@ public class DefaultUserDestinationResolverTests {
this.registry.registerSessionId("joe", "456");
this.registry.registerSessionId("joe", "789");
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo");
Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size());
@ -75,7 +77,7 @@ public class DefaultUserDestinationResolverTests { @@ -75,7 +77,7 @@ public class DefaultUserDestinationResolverTests {
@Test
public void handleUnsubscribe() {
Message<?> message = createMessage(SimpMessageType.UNSUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
Message<?> message = createMessage(SimpMessageType.UNSUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo");
Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size());
@ -84,7 +86,7 @@ public class DefaultUserDestinationResolverTests { @@ -84,7 +86,7 @@ public class DefaultUserDestinationResolverTests {
@Test
public void handleMessage() {
Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/user/joe/queue/foo");
Message<?> message = createMessage(SimpMessageType.MESSAGE, this.user, SESSION_ID, "/user/joe/queue/foo");
Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size());
@ -96,12 +98,12 @@ public class DefaultUserDestinationResolverTests { @@ -96,12 +98,12 @@ public class DefaultUserDestinationResolverTests {
public void ignoreMessage() {
// no destination
Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, null);
Message<?> message = createMessage(SimpMessageType.MESSAGE, this.user, SESSION_ID, null);
Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size());
// not a user destination
message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/queue/foo");
message = createMessage(SimpMessageType.MESSAGE, this.user, SESSION_ID, "/queue/foo");
actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size());
@ -111,24 +113,24 @@ public class DefaultUserDestinationResolverTests { @@ -111,24 +113,24 @@ public class DefaultUserDestinationResolverTests {
assertEquals(0, actual.size());
// subscribe + not a user destination
message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/queue/foo");
message = createMessage(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/queue/foo");
actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size());
// no match on message type
message = createMessage(SimpMessageType.CONNECT, "joe", SESSION_ID, "user/joe/queue/foo");
message = createMessage(SimpMessageType.CONNECT, this.user, SESSION_ID, "user/joe/queue/foo");
actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size());
}
private Message<?> createMessage(SimpMessageType messageType, String user, String sessionId, String destination) {
private Message<?> createMessage(SimpMessageType messageType, TestPrincipal user, String sessionId, String destination) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(messageType);
if (destination != null) {
headers.setDestination(destination);
}
if (user != null) {
headers.setUser(new TestPrincipal(user));
headers.setUser(user);
}
if (sessionId != null) {
headers.setSessionId(sessionId);

18
spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

@ -35,6 +35,7 @@ import org.springframework.messaging.simp.stomp.StompConversionException; @@ -35,6 +35,7 @@ import org.springframework.messaging.simp.stomp.StompConversionException;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
@ -240,11 +241,20 @@ public class StompSubProtocolHandler implements SubProtocolHandler { @@ -240,11 +241,20 @@ public class StompSubProtocolHandler implements SubProtocolHandler {
if (principal != null) {
headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
if (this.userSessionRegistry != null) {
this.userSessionRegistry.registerSessionId(principal.getName(), session.getId());
String userName = getNameForUserSessionRegistry(principal);
this.userSessionRegistry.registerSessionId(userName, session.getId());
}
}
}
private String getNameForUserSessionRegistry(Principal principal) {
String userName = principal.getName();
if (principal instanceof DestinationUserNameProvider) {
userName = ((DestinationUserNameProvider) principal).getDestinationUserName();
}
return userName;
}
@Override
public String resolveSessionId(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
@ -258,8 +268,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler { @@ -258,8 +268,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler {
@Override
public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) {
if ((this.userSessionRegistry != null) && (session.getPrincipal() != null)) {
this.userSessionRegistry.unregisterSessionId(session.getPrincipal().getName(), session.getId());
Principal principal = session.getPrincipal();
if ((this.userSessionRegistry != null) && (principal != null)) {
String userName = getNameForUserSessionRegistry(principal);
this.userSessionRegistry.unregisterSessionId(userName, session.getId());
}
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);

76
spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@ package org.springframework.web.socket.messaging; @@ -18,6 +18,7 @@ package org.springframework.web.socket.messaging;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import org.junit.Before;
@ -32,6 +33,9 @@ import org.springframework.messaging.simp.TestPrincipal; @@ -32,6 +33,9 @@ import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
@ -47,7 +51,7 @@ import static org.mockito.Mockito.*; @@ -47,7 +51,7 @@ import static org.mockito.Mockito.*;
*/
public class StompSubProtocolHandlerTests {
private StompSubProtocolHandler stompHandler;
private StompSubProtocolHandler protocolHandler;
private TestWebSocketSession session;
@ -58,7 +62,7 @@ public class StompSubProtocolHandlerTests { @@ -58,7 +62,7 @@ public class StompSubProtocolHandlerTests {
@Before
public void setup() {
this.stompHandler = new StompSubProtocolHandler();
this.protocolHandler = new StompSubProtocolHandler();
this.channel = Mockito.mock(MessageChannel.class);
this.messageCaptor = ArgumentCaptor.forClass(Message.class);
@ -68,18 +72,55 @@ public class StompSubProtocolHandlerTests { @@ -68,18 +72,55 @@ public class StompSubProtocolHandlerTests {
}
@Test
public void connectedResponseIsSentWhenConnectAckIsToBeSentToClient() {
public void handleMessageToClientConnected() {
UserSessionRegistry registry = new DefaultUserSessionRegistry();
this.protocolHandler.setUserSessionRegistry(registry);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
this.protocolHandler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size());
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload());
assertEquals(Collections.singleton("s1"), registry.getSessionIds("joe"));
}
@Test
public void handleMessageToClientConnectedUniqueUserName() {
this.session.setPrincipal(new UniqueUser("joe"));
UserSessionRegistry registry = new DefaultUserSessionRegistry();
this.protocolHandler.setUserSessionRegistry(registry);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
this.protocolHandler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size());
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload());
assertEquals(Collections.<String>emptySet(), registry.getSessionIds("joe"));
assertEquals(Collections.singleton("s1"), registry.getSessionIds("Me myself and I"));
}
@Test
public void handleMessageToClientConnectAck() {
StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT);
connectHeaders.setHeartbeat(10000, 10000);
connectHeaders.setNativeHeader(StompHeaderAccessor.STOMP_ACCEPT_VERSION_HEADER, "1.0,1.1");
Message<?> connectMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectHeaders).build();
SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage);
Message<byte[]> connectAckMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build();
Message<byte[]> connectAck = MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build();
this.stompHandler.handleMessageToClient(this.session, connectAck);
this.protocolHandler.handleMessageToClient(this.session, connectAckMessage);
verifyNoMoreInteractions(this.channel);
@ -97,12 +138,12 @@ public class StompSubProtocolHandlerTests { @@ -97,12 +138,12 @@ public class StompSubProtocolHandlerTests {
}
@Test
public void messagesAreAugmentedAndForwarded() {
public void handleMessageFromClient() {
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers(
"login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build();
this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel);
this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
verify(this.channel).send(this.messageCaptor.capture());
Message<?> actual = this.messageCaptor.getValue();
@ -121,11 +162,11 @@ public class StompSubProtocolHandlerTests { @@ -121,11 +162,11 @@ public class StompSubProtocolHandlerTests {
}
@Test
public void invalidStompCommand() {
public void handleMessageFromClientInvalidStompCommand() {
TextMessage textMessage = new TextMessage("FOO");
this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel);
this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
verifyZeroInteractions(this.channel);
assertEquals(1, this.session.getSentMessages().size());
@ -133,4 +174,17 @@ public class StompSubProtocolHandlerTests { @@ -133,4 +174,17 @@ public class StompSubProtocolHandlerTests {
assertTrue(actual.getPayload().startsWith("ERROR"));
}
private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider {
private UniqueUser(String name) {
super(name);
}
@Override
public String getDestinationUserName() {
return "Me myself and I";
}
}
}

Loading…
Cancel
Save