From cc33bfaf612fde9d74cff6e92433e3edc8ce9c17 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 29 Aug 2016 18:30:17 -0400 Subject: [PATCH] Support receipt on DISCONNECT with simple broker Issue: SPR-14568 --- .../simp/SimpMessageHeaderAccessor.java | 2 ++ .../broker/SimpleBrokerMessageHandler.java | 9 +++-- .../SimpleBrokerMessageHandlerTests.java | 23 +++++++++--- .../messaging/StompSubProtocolHandler.java | 22 ++++++++++-- .../StompSubProtocolHandlerTests.java | 36 ++++++++++++++++++- 5 files changed, 81 insertions(+), 11 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java index 7e3d7f12b69..5bec0672bb1 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java @@ -65,6 +65,8 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { public static final String CONNECT_MESSAGE_HEADER = "simpConnectMessage"; + public static final String DISCONNECT_MESSAGE_HEADER = "simpDisconnectMessage"; + public static final String HEART_BEAT_HEADER = "simpHeartbeat"; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java index 838d6705e77..8555338bdac 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java @@ -250,7 +250,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { } else if (SimpMessageType.DISCONNECT.equals(messageType)) { logMessage(message); - handleDisconnect(sessionId, user); + handleDisconnect(sessionId, user, message); } else if (SimpMessageType.SUBSCRIBE.equals(messageType)) { logMessage(message); @@ -285,12 +285,15 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { } } - private void handleDisconnect(String sessionId, Principal user) { + private void handleDisconnect(String sessionId, Principal user, Message origMessage) { this.sessions.remove(sessionId); this.subscriptionRegistry.unregisterAllSubscriptions(sessionId); SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); accessor.setSessionId(sessionId); accessor.setUser(user); + if (origMessage != null) { + accessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, origMessage); + } initHeaders(accessor); Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); getClientOutboundChannel().send(message); @@ -407,7 +410,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { long now = System.currentTimeMillis(); for (SessionInfo info : sessions.values()) { if (info.getReadInterval() > 0 && (now - info.getLastReadTime()) > info.getReadInterval()) { - handleDisconnect(info.getSessiondId(), info.getUser()); + handleDisconnect(info.getSessiondId(), info.getUser(), null); } if (info.getWriteInterval() > 0 && (now - info.getLastWriteTime()) > info.getWriteInterval()) { SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java index cc969cc5c7c..86d8b4f0fe7 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,6 @@ package org.springframework.messaging.simp.broker; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - import java.security.Principal; import java.util.Collections; import java.util.List; @@ -41,6 +38,21 @@ import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.support.MessageBuilder; import org.springframework.scheduling.TaskScheduler; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + /** * Unit tests for SimpleBrokerMessageHandler. * @@ -72,7 +84,7 @@ public class SimpleBrokerMessageHandlerTests { public void setup() { MockitoAnnotations.initMocks(this); this.messageHandler = new SimpleBrokerMessageHandler(this.clientInboundChannel, - this.clientOutboundChannel, this.brokerChannel, Collections.emptyList()); + this.clientOutboundChannel, this.brokerChannel, Collections.emptyList()); } @@ -130,6 +142,7 @@ public class SimpleBrokerMessageHandlerTests { Message captured = this.messageCaptor.getAllValues().get(0); assertEquals(SimpMessageType.DISCONNECT_ACK, SimpMessageHeaderAccessor.getMessageType(captured.getHeaders())); + assertSame(message, captured.getHeaders().get(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER)); assertEquals(sess1, SimpMessageHeaderAccessor.getSessionId(captured.getHeaders())); assertEquals("joe", SimpMessageHeaderAccessor.getUser(captured.getHeaders()).getName()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index a56a4e06281..98a3d65aead 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -34,6 +34,7 @@ import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.simp.SimpAttributes; import org.springframework.messaging.simp.SimpAttributesContextHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -479,8 +480,15 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE stompAccessor = convertConnectAcktoStompConnected(stompAccessor); } else if (SimpMessageType.DISCONNECT_ACK.equals(messageType)) { - stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR); - stompAccessor.setMessage("Session closed."); + String receipt = getDisconnectReceipt(stompAccessor); + if (receipt != null) { + stompAccessor = StompHeaderAccessor.create(StompCommand.RECEIPT); + stompAccessor.setReceiptId(receipt); + } + else { + stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR); + stompAccessor.setMessage("Session closed."); + } } else if (SimpMessageType.HEARTBEAT.equals(messageType)) { stompAccessor = StompHeaderAccessor.createForHeartbeat(); @@ -533,6 +541,16 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE return connectedHeaders; } + private String getDisconnectReceipt(SimpMessageHeaderAccessor simpHeaders) { + String name = StompHeaderAccessor.DISCONNECT_MESSAGE_HEADER; + Message message = (Message) simpHeaders.getHeader(name); + if (message != null) { + StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + return accessor.getReceipt(); + } + return null; + } + protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccessor, Message message) { return (headerAccessor.isMutable() ? headerAccessor : StompHeaderAccessor.wrap(message)); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 22214eafab3..1c388c08d0e 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -169,6 +169,40 @@ public class StompSubProtocolHandlerTests { "user-name:joe\n" + "\n" + "\u0000", actual.getPayload()); } + @Test + public void handleMessageToClientWithSimpDisconnectAck() { + + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.DISCONNECT); + Message connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); + + SimpMessageHeaderAccessor ackAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); + ackAccessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, connectMessage); + Message ackMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, ackAccessor.getMessageHeaders()); + this.protocolHandler.handleMessageToClient(this.session, ackMessage); + + assertEquals(1, this.session.getSentMessages().size()); + TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); + assertEquals("ERROR\n" + "message:Session closed.\n" + "content-length:0\n" + + "\n\u0000", actual.getPayload()); + } + + @Test + public void handleMessageToClientWithSimpDisconnectAckAndReceipt() { + + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.DISCONNECT); + accessor.setReceipt("message-123"); + Message connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); + + SimpMessageHeaderAccessor ackAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); + ackAccessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, connectMessage); + Message ackMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, ackAccessor.getMessageHeaders()); + this.protocolHandler.handleMessageToClient(this.session, ackMessage); + + assertEquals(1, this.session.getSentMessages().size()); + TextMessage actual = (TextMessage) this.session.getSentMessages().get(0); + assertEquals("RECEIPT\n" + "receipt-id:message-123\n" + "\n\u0000", actual.getPayload()); + } + @Test public void handleMessageToClientWithSimpHeartbeat() {