From e4ad2b352e35a0cdbc5d4834f9689a13162c8d19 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 21 Jan 2014 11:50:04 -0500 Subject: [PATCH] Add DestinationUserNameProvider interface The interface is to be implemented in addition to java.security.Principal when Principal.getName() is not globally unique enough for use in user destinations. Issue: SPR-11327 --- .../user/DefaultUserDestinationResolver.java | 4 +- .../user/DestinationUserNameProvider.java | 20 +++++ .../DefaultUserDestinationResolverTests.java | 26 ++++--- .../messaging/StompSubProtocolHandler.java | 18 ++++- .../StompSubProtocolHandlerTests.java | 76 ++++++++++++++++--- 5 files changed, 117 insertions(+), 27 deletions(-) create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/user/DestinationUserNameProvider.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java index 75f4146bb0e..0505acf2037 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java @@ -165,7 +165,9 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { return true; } - protected String getTargetDestination(String origDestination, String targetDestination, String sessionId, String user) { + protected String getTargetDestination(String origDestination, String targetDestination, + String sessionId, String user) { + return targetDestination + "-user" + sessionId; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DestinationUserNameProvider.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DestinationUserNameProvider.java new file mode 100644 index 00000000000..8c0ac7d191b --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DestinationUserNameProvider.java @@ -0,0 +1,20 @@ +package org.springframework.messaging.simp.user; + +/** + * An interface to be implemented in addition to {@link java.security.Principal} + * when {@link java.security.Principal#getName()} is not globally unique enough + * for use in user destinations. For more on user destination see + * {@link org.springframework.messaging.simp.user.UserDestinationResolver}. + * + * @author Rossen Stoyanchev + * @since 4.0.1 + */ +public interface DestinationUserNameProvider { + + + /** + * Return the (globally unique) user name to use with user destinations. + */ + String getDestinationUserName(); + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java index bf4f4b50003..a72aae58c43 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java @@ -39,19 +39,21 @@ public class DefaultUserDestinationResolverTests { private UserSessionRegistry registry; + private TestPrincipal user; + @Before public void setup() { + this.user = new TestPrincipal("joe"); this.registry = new DefaultUserSessionRegistry(); - this.registry.registerSessionId("joe", SESSION_ID); - + this.registry.registerSessionId(this.user.getName(), SESSION_ID); this.resolver = new DefaultUserDestinationResolver(this.registry); } @Test public void handleSubscribe() { - Message message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo"); + Message message = createMessage(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo"); Set actual = this.resolver.resolveDestination(message); assertEquals(1, actual.size()); @@ -66,7 +68,7 @@ public class DefaultUserDestinationResolverTests { this.registry.registerSessionId("joe", "456"); this.registry.registerSessionId("joe", "789"); - Message message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo"); + Message message = createMessage(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo"); Set actual = this.resolver.resolveDestination(message); assertEquals(1, actual.size()); @@ -75,7 +77,7 @@ public class DefaultUserDestinationResolverTests { @Test public void handleUnsubscribe() { - Message message = createMessage(SimpMessageType.UNSUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo"); + Message message = createMessage(SimpMessageType.UNSUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo"); Set actual = this.resolver.resolveDestination(message); assertEquals(1, actual.size()); @@ -84,7 +86,7 @@ public class DefaultUserDestinationResolverTests { @Test public void handleMessage() { - Message message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/user/joe/queue/foo"); + Message message = createMessage(SimpMessageType.MESSAGE, this.user, SESSION_ID, "/user/joe/queue/foo"); Set actual = this.resolver.resolveDestination(message); assertEquals(1, actual.size()); @@ -96,12 +98,12 @@ public class DefaultUserDestinationResolverTests { public void ignoreMessage() { // no destination - Message message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, null); + Message message = createMessage(SimpMessageType.MESSAGE, this.user, SESSION_ID, null); Set actual = this.resolver.resolveDestination(message); assertEquals(0, actual.size()); // not a user destination - message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/queue/foo"); + message = createMessage(SimpMessageType.MESSAGE, this.user, SESSION_ID, "/queue/foo"); actual = this.resolver.resolveDestination(message); assertEquals(0, actual.size()); @@ -111,24 +113,24 @@ public class DefaultUserDestinationResolverTests { assertEquals(0, actual.size()); // subscribe + not a user destination - message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/queue/foo"); + message = createMessage(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/queue/foo"); actual = this.resolver.resolveDestination(message); assertEquals(0, actual.size()); // no match on message type - message = createMessage(SimpMessageType.CONNECT, "joe", SESSION_ID, "user/joe/queue/foo"); + message = createMessage(SimpMessageType.CONNECT, this.user, SESSION_ID, "user/joe/queue/foo"); actual = this.resolver.resolveDestination(message); assertEquals(0, actual.size()); } - private Message createMessage(SimpMessageType messageType, String user, String sessionId, String destination) { + private Message createMessage(SimpMessageType messageType, TestPrincipal user, String sessionId, String destination) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(messageType); if (destination != null) { headers.setDestination(destination); } if (user != null) { - headers.setUser(new TestPrincipal(user)); + headers.setUser(user); } if (sessionId != null) { headers.setSessionId(sessionId); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 7349d844db0..aab7bebc6f7 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -35,6 +35,7 @@ import org.springframework.messaging.simp.stomp.StompConversionException; import org.springframework.messaging.simp.stomp.StompDecoder; import org.springframework.messaging.simp.stomp.StompEncoder; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.simp.user.DestinationUserNameProvider; import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; @@ -240,11 +241,20 @@ public class StompSubProtocolHandler implements SubProtocolHandler { if (principal != null) { headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); if (this.userSessionRegistry != null) { - this.userSessionRegistry.registerSessionId(principal.getName(), session.getId()); + String userName = getNameForUserSessionRegistry(principal); + this.userSessionRegistry.registerSessionId(userName, session.getId()); } } } + private String getNameForUserSessionRegistry(Principal principal) { + String userName = principal.getName(); + if (principal instanceof DestinationUserNameProvider) { + userName = ((DestinationUserNameProvider) principal).getDestinationUserName(); + } + return userName; + } + @Override public String resolveSessionId(Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); @@ -258,8 +268,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler { @Override public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) { - if ((this.userSessionRegistry != null) && (session.getPrincipal() != null)) { - this.userSessionRegistry.unregisterSessionId(session.getPrincipal().getName(), session.getId()); + Principal principal = session.getPrincipal(); + if ((this.userSessionRegistry != null) && (principal != null)) { + String userName = getNameForUserSessionRegistry(principal); + this.userSessionRegistry.unregisterSessionId(userName, session.getId()); } StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 84fe4605b2f..0822e2357bc 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.web.socket.messaging; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import org.junit.Before; @@ -32,6 +33,9 @@ import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompDecoder; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; +import org.springframework.messaging.simp.user.DestinationUserNameProvider; +import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; @@ -47,7 +51,7 @@ import static org.mockito.Mockito.*; */ public class StompSubProtocolHandlerTests { - private StompSubProtocolHandler stompHandler; + private StompSubProtocolHandler protocolHandler; private TestWebSocketSession session; @@ -58,7 +62,7 @@ public class StompSubProtocolHandlerTests { @Before public void setup() { - this.stompHandler = new StompSubProtocolHandler(); + this.protocolHandler = new StompSubProtocolHandler(); this.channel = Mockito.mock(MessageChannel.class); this.messageCaptor = ArgumentCaptor.forClass(Message.class); @@ -68,18 +72,55 @@ public class StompSubProtocolHandlerTests { } @Test - public void connectedResponseIsSentWhenConnectAckIsToBeSentToClient() { + public void handleMessageToClientConnected() { + + UserSessionRegistry registry = new DefaultUserSessionRegistry(); + this.protocolHandler.setUserSessionRegistry(registry); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); + Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + this.protocolHandler.handleMessageToClient(this.session, message); + + assertEquals(1, this.session.getSentMessages().size()); + WebSocketMessage textMessage = this.session.getSentMessages().get(0); + assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload()); + + assertEquals(Collections.singleton("s1"), registry.getSessionIds("joe")); + } + + @Test + public void handleMessageToClientConnectedUniqueUserName() { + + this.session.setPrincipal(new UniqueUser("joe")); + + UserSessionRegistry registry = new DefaultUserSessionRegistry(); + this.protocolHandler.setUserSessionRegistry(registry); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); + Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + this.protocolHandler.handleMessageToClient(this.session, message); + + assertEquals(1, this.session.getSentMessages().size()); + WebSocketMessage textMessage = this.session.getSentMessages().get(0); + assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload()); + + assertEquals(Collections.emptySet(), registry.getSessionIds("joe")); + assertEquals(Collections.singleton("s1"), registry.getSessionIds("Me myself and I")); + } + + @Test + public void handleMessageToClientConnectAck() { + StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT); connectHeaders.setHeartbeat(10000, 10000); connectHeaders.setNativeHeader(StompHeaderAccessor.STOMP_ACCEPT_VERSION_HEADER, "1.0,1.1"); - Message connectMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectHeaders).build(); SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage); + Message connectAckMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build(); - Message connectAck = MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build(); - this.stompHandler.handleMessageToClient(this.session, connectAck); + this.protocolHandler.handleMessageToClient(this.session, connectAckMessage); verifyNoMoreInteractions(this.channel); @@ -97,12 +138,12 @@ public class StompSubProtocolHandlerTests { } @Test - public void messagesAreAugmentedAndForwarded() { + public void handleMessageFromClient() { TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers( "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build(); - this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel); + this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); verify(this.channel).send(this.messageCaptor.capture()); Message actual = this.messageCaptor.getValue(); @@ -121,11 +162,11 @@ public class StompSubProtocolHandlerTests { } @Test - public void invalidStompCommand() { + public void handleMessageFromClientInvalidStompCommand() { TextMessage textMessage = new TextMessage("FOO"); - this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel); + this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); verifyZeroInteractions(this.channel); assertEquals(1, this.session.getSentMessages().size()); @@ -133,4 +174,17 @@ public class StompSubProtocolHandlerTests { assertTrue(actual.getPayload().startsWith("ERROR")); } + + private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider { + + private UniqueUser(String name) { + super(name); + } + + @Override + public String getDestinationUserName() { + return "Me myself and I"; + } + } + }