From 618771d59d050c2f36d9e203501abd833fa39b93 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Sun, 29 Jun 2014 16:52:20 -0400 Subject: [PATCH] Add check for unused WebSocket sessions Sessions connected to a STOMP endpoint are expected to receive some client messages. Having received none after successfully connecting could be an indication of proxy or network issue. This change adds periodic checks to see if we have not received any messages on a session which is an indication the session isn't going anywhere most likely due to a proxy issue (or unreliable network) and close those sessions. This is a backport for commit: https://github.com/spring-projects/spring-framework/commit/a3fa9c979777e554efad0df429041767f05dfdb8 Issue: SPR-11884 --- .../SubProtocolWebSocketHandler.java | 121 ++++++++++++++++-- .../SubProtocolWebSocketHandlerTests.java | 36 ++++++ 2 files changed, 146 insertions(+), 11 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index e85f0eadfbc..19b4ec6530f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Set; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -66,6 +67,15 @@ import org.springframework.web.socket.handler.SessionLimitExceededException; public class SubProtocolWebSocketHandler implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle, ApplicationEventPublisherAware { + /** + * Sessions connected to this handler use a sub-protocol. Hence we expect to + * receive some client messages. If we don't receive any within a minute, the + * connection isn't doing well (proxy issue, slow network?) and can be closed. + * @see #checkSessions() + */ + private static final int TIME_TO_FIRST_MESSAGE = 60 * 1000; + + private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class); private final MessageChannel clientInboundChannel; @@ -77,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, private SubProtocolHandler defaultProtocolHandler; - private final Map sessions = new ConcurrentHashMap(); + private final Map sessions = new ConcurrentHashMap(); private int sendTimeLimit = 10 * 1000; private int sendBufferSizeLimit = 512 * 1024; + private volatile long lastSessionCheckTime = System.currentTimeMillis(); + + private final ReentrantLock sessionCheckLock = new ReentrantLock(); + private Object lifecycleMonitor = new Object(); private volatile boolean running = false; @@ -227,12 +241,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, this.clientOutboundChannel.unsubscribe(this); // Notify sessions to stop flushing messages - for (WebSocketSession session : this.sessions.values()) { + for (WebSocketSessionHolder holder : this.sessions.values()) { try { - session.close(CloseStatus.GOING_AWAY); + holder.getSession().close(CloseStatus.GOING_AWAY); } catch (Throwable t) { - logger.error("Failed to close session id '" + session.getId() + "': " + t.getMessage()); + logger.error("Failed to close session id '" + holder.getSession().getId() + "': " + t.getMessage()); } } } @@ -251,7 +265,7 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit()); - this.sessions.put(session.getId(), session); + this.sessions.put(session.getId(), new WebSocketSessionHolder(session)); if (logger.isDebugEnabled()) { logger.debug("Started WebSocket session=" + session.getId() + ", number of sessions=" + this.sessions.size()); @@ -296,7 +310,13 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @Override public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { - findProtocolHandler(session).handleMessageFromClient(session, message, this.clientInboundChannel); + SubProtocolHandler protocolHandler = findProtocolHandler(session); + protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel); + WebSocketSessionHolder holder = this.sessions.get(session.getId()); + if (holder != null) { + holder.setHasHandledMessages(); + } + checkSessions(); } @Override @@ -307,13 +327,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, logger.error("sessionId not found in message " + message); return; } - - WebSocketSession session = this.sessions.get(sessionId); - if (session == null) { - logger.error("Session not found for session with id '" + sessionId + "', ignoring message " + message); + WebSocketSessionHolder holder = this.sessions.get(sessionId); + if (holder == null) { + logger.error("No session for " + message); return; } - + WebSocketSession session = holder.getSession(); try { findProtocolHandler(session).handleMessageToClient(session, message); } @@ -350,6 +369,49 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, return null; } + /** + * When a session is connected through a higher-level protocol it has a chance + * to use heartbeat management to shut down sessions that are too slow to send + * or receive messages. However, after a WebSocketSession is established and + * before the higher level protocol is fully connected there is a possibility + * for sessions to hang. This method checks and closes any sessions that have + * been connected for more than 60 seconds without having received a single + * message. + */ + private void checkSessions() throws IOException { + long currentTime = System.currentTimeMillis(); + if (!isRunning() || (currentTime - this.lastSessionCheckTime < TIME_TO_FIRST_MESSAGE)) { + return; + } + if (this.sessionCheckLock.tryLock()) { + try { + for (WebSocketSessionHolder holder : this.sessions.values()) { + if (holder.hasHandledMessages()) { + continue; + } + long timeSinceCreated = currentTime - holder.getCreateTime(); + if (timeSinceCreated < TIME_TO_FIRST_MESSAGE) { + continue; + } + WebSocketSession session = holder.getSession(); + if (logger.isErrorEnabled()) { + logger.error("No messages received after " + timeSinceCreated + " ms. " + + "Closing " + holder.getSession() + "."); + } + try { + session.close(CloseStatus.SESSION_NOT_RELIABLE); + } + catch (Throwable t) { + logger.error("Failure while closing " + session, t); + } + } + } + finally { + this.sessionCheckLock.unlock(); + } + } + } + @Override public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { } @@ -369,4 +431,41 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, return false; } + + private static class WebSocketSessionHolder { + + private final WebSocketSession session; + + private final long createTime = System.currentTimeMillis(); + + private volatile boolean handledMessages; + + + private WebSocketSessionHolder(WebSocketSession session) { + this.session = session; + } + + public WebSocketSession getSession() { + return this.session; + } + + public long getCreateTime() { + return this.createTime; + } + + public void setHasHandledMessages() { + this.handledMessages = true; + } + + public boolean hasHandledMessages() { + return this.handledMessages; + } + + @Override + public String toString() { + return "WebSocketSessionHolder[=session=" + this.session + ", createTime=" + + this.createTime + ", hasHandledMessages=" + this.handledMessages + "]"; + } + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java index 31cf82ac766..20bf97e8733 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java @@ -17,16 +17,24 @@ package org.springframework.web.socket.messaging; import java.util.Arrays; +import java.util.Map; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.springframework.beans.DirectFieldAccessor; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.SubscribableChannel; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; import org.springframework.web.socket.handler.TestWebSocketSession; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.*; /** @@ -140,4 +148,32 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.afterConnectionEstablished(session); } + @Test + public void checkSession() throws Exception { + TestWebSocketSession session1 = new TestWebSocketSession("id1"); + TestWebSocketSession session2 = new TestWebSocketSession("id2"); + session1.setAcceptedProtocol("v12.stomp"); + session2.setAcceptedProtocol("v12.stomp"); + + this.webSocketHandler.setProtocolHandlers(Arrays.asList(this.stompHandler)); + this.webSocketHandler.afterConnectionEstablished(session1); + this.webSocketHandler.afterConnectionEstablished(session2); + session1.setOpen(true); + session2.setOpen(true); + + long sixtyOneSecondsAgo = System.currentTimeMillis() - 61 * 1000; + new DirectFieldAccessor(this.webSocketHandler).setPropertyValue("lastSessionCheckTime", sixtyOneSecondsAgo); + Map sessions = (Map) new DirectFieldAccessor(this.webSocketHandler).getPropertyValue("sessions"); + new DirectFieldAccessor(sessions.get("id1")).setPropertyValue("createTime", sixtyOneSecondsAgo); + new DirectFieldAccessor(sessions.get("id2")).setPropertyValue("createTime", sixtyOneSecondsAgo); + + this.webSocketHandler.start(); + this.webSocketHandler.handleMessage(session1, new TextMessage("foo")); + + assertTrue(session1.isOpen()); + assertFalse(session2.isOpen()); + assertNull(session1.getCloseStatus()); + assertEquals(CloseStatus.SESSION_NOT_RELIABLE, session2.getCloseStatus()); + } + }