From 7af74b24753a2692cbfdde24999b94e13e7fcb1f Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Sun, 23 Mar 2014 02:12:57 -0400 Subject: [PATCH] Disable SockJS heartbeat if STOMP heartbeat is on --- .../WebMvcStompEndpointRegistry.java | 3 +- .../handler/WebSocketHandlerDecorator.java | 9 +++++ .../handler/WebSocketSessionDecorator.java | 33 ++++++++++++------- .../messaging/StompSubProtocolHandler.java | 10 ++++++ .../support/DefaultHandshakeHandler.java | 7 +--- .../sockjs/transport/SockJsSession.java | 9 ++++- .../session/AbstractSockJsSession.java | 13 ++++++++ .../StompSubProtocolHandlerTests.java | 15 +++++++++ .../transport/session/SockJsSessionTests.java | 18 +++++----- 9 files changed, 87 insertions(+), 30 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java index 8703563164e..af1f9b693b6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java @@ -70,8 +70,7 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry { } private static SubProtocolWebSocketHandler unwrapSubProtocolWebSocketHandler(WebSocketHandler webSocketHandler) { - WebSocketHandler actual = (webSocketHandler instanceof WebSocketHandlerDecorator) ? - ((WebSocketHandlerDecorator) webSocketHandler).getLastHandler() : webSocketHandler; + WebSocketHandler actual = WebSocketHandlerDecorator.unwrap(webSocketHandler); Assert.isInstanceOf(SubProtocolWebSocketHandler.class, actual, "No SubProtocolWebSocketHandler found: " + webSocketHandler); return (SubProtocolWebSocketHandler) actual; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java index ec5eda452ba..538c3b61db1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java @@ -56,6 +56,15 @@ public class WebSocketHandlerDecorator implements WebSocketHandler { return result; } + public static WebSocketHandler unwrap(WebSocketHandler handler) { + if (handler instanceof WebSocketHandlerDecorator) { + return ((WebSocketHandlerDecorator) handler).getLastHandler(); + } + else { + return handler; + } + } + @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { this.delegate.afterConnectionEstablished(session); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketSessionDecorator.java index cea06b18f9e..87985d1065c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketSessionDecorator.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketSessionDecorator.java @@ -52,6 +52,27 @@ public class WebSocketSessionDecorator implements WebSocketSession { } + public WebSocketSession getDelegate() { + return this.delegate; + } + + public WebSocketSession getLastSession() { + WebSocketSession result = this.delegate; + while (result instanceof WebSocketSessionDecorator) { + result = ((WebSocketSessionDecorator) result).getDelegate(); + } + return result; + } + + public static WebSocketSession unwrap(WebSocketSession session) { + if (session instanceof WebSocketSessionDecorator) { + return ((WebSocketSessionDecorator) session).getLastSession(); + } + else { + return session; + } + } + @Override public String getId() { return this.delegate.getId(); @@ -117,18 +138,6 @@ public class WebSocketSessionDecorator implements WebSocketSession { this.delegate.close(status); } - public WebSocketSession getDelegate() { - return this.delegate; - } - - public WebSocketSession getLastSession() { - WebSocketSession result = this.delegate; - while (result instanceof WebSocketSessionDecorator) { - result = ((WebSocketSessionDecorator) result).getDelegate(); - } - return result; - } - @Override public String toString() { return getClass().getSimpleName() + " [delegate=" + this.delegate + "]"; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index e9de1bbd34c..67f8963cfbe 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -44,6 +44,8 @@ import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.SessionLimitExceededException; +import org.springframework.web.socket.handler.WebSocketSessionDecorator; +import org.springframework.web.socket.sockjs.transport.SockJsSession; /** * A {@link SubProtocolHandler} for STOMP that supports versions 1.0, 1.1, and 1.2 @@ -253,6 +255,14 @@ public class StompSubProtocolHandler implements SubProtocolHandler { this.userSessionRegistry.registerSessionId(userName, session.getId()); } } + long[] heartbeat = headers.getHeartbeat(); + if (heartbeat[1] > 0) { + session = WebSocketSessionDecorator.unwrap(session); + if (session instanceof SockJsSession) { + logger.debug("STOMP heartbeats negotiated, disabling SockJS heartbeats."); + ((SockJsSession) session).disableHeartbeat(); + } + } } private String resolveNameForUserSessionRegistry(Principal principal) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java index 4f785939f91..7fcca127829 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java @@ -301,16 +301,11 @@ public class DefaultHandshakeHandler implements HandshakeHandler { * @return a list of supported protocols or an empty list */ protected final List determineHandlerSupportedProtocols(WebSocketHandler handler) { + handler = WebSocketHandlerDecorator.unwrap(handler); List subProtocols = null; if (handler instanceof SubProtocolCapable) { subProtocols = ((SubProtocolCapable) handler).getSubProtocols(); } - else if (handler instanceof WebSocketHandlerDecorator) { - WebSocketHandler lastHandler = ((WebSocketHandlerDecorator) handler).getLastHandler(); - if (lastHandler instanceof SubProtocolCapable) { - subProtocols = ((SubProtocolCapable) lastHandler).getSubProtocols();; - } - } return (subProtocols != null) ? subProtocols : Collections.emptyList(); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/SockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/SockJsSession.java index 216216c0085..c600795cf47 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/SockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/SockJsSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,4 +33,11 @@ public interface SockJsSession extends WebSocketSession { */ long getTimeSinceLastActive(); + /** + * Disable SockJS heartbeat, presumably because a higher level protocol has + * heartbeats enabled for the session. It is not recommended to disable this + * otherwise as it helps proxies to know the connection is not hanging. + */ + void disableHeartbeat(); + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java index afd228d6f37..245279387db 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java @@ -99,12 +99,16 @@ public abstract class AbstractSockJsSession implements SockJsSession { private volatile State state = State.NEW; + private final long timeCreated = System.currentTimeMillis(); private volatile long timeLastActive = this.timeCreated; + private volatile ScheduledFuture heartbeatTask; + private volatile boolean heartbeatDisabled; + /** * Create a new instance. @@ -182,6 +186,12 @@ public abstract class AbstractSockJsSession implements SockJsSession { this.timeLastActive = System.currentTimeMillis(); } + @Override + public void disableHeartbeat() { + this.heartbeatDisabled = true; + cancelHeartbeat(); + } + public void delegateConnectionEstablished() throws Exception { this.state = State.OPEN; this.handler.afterConnectionEstablished(this); @@ -366,6 +376,9 @@ public abstract class AbstractSockJsSession implements SockJsSession { } protected void scheduleHeartbeat() { + if (this.heartbeatDisabled) { + return; + } Assert.state(this.config.getTaskScheduler() != null, "No TaskScheduler configured for heartbeat"); cancelHeartbeat(); if (!isActive()) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index bec416f9dc5..cfc769d0336 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -41,6 +41,8 @@ import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.handler.TestWebSocketSession; +import org.springframework.web.socket.sockjs.transport.SockJsSession; +import org.springframework.web.socket.sockjs.transport.session.TestSockJsSession; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -110,6 +112,19 @@ public class StompSubProtocolHandlerTests { assertEquals(Collections.singleton("s1"), registry.getSessionIds("Me myself and I")); } + @Test + public void handleMessageToClientConnectedWithHeartbeats() { + + SockJsSession sockJsSession = Mockito.mock(SockJsSession.class); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); + headers.setHeartbeat(0,10); + Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + this.protocolHandler.handleMessageToClient(sockJsSession, message); + + verify(sockJsSession).disableHeartbeat(); + } + @Test public void handleMessageToClientConnectAck() { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java index b4f5aec93fd..94ae5c64697 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -247,14 +247,6 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests