diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index e85f0eadfbc..19b4ec6530f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Set; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -66,6 +67,15 @@ import org.springframework.web.socket.handler.SessionLimitExceededException; public class SubProtocolWebSocketHandler implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle, ApplicationEventPublisherAware { + /** + * Sessions connected to this handler use a sub-protocol. Hence we expect to + * receive some client messages. If we don't receive any within a minute, the + * connection isn't doing well (proxy issue, slow network?) and can be closed. + * @see #checkSessions() + */ + private static final int TIME_TO_FIRST_MESSAGE = 60 * 1000; + + private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class); private final MessageChannel clientInboundChannel; @@ -77,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, private SubProtocolHandler defaultProtocolHandler; - private final Map sessions = new ConcurrentHashMap(); + private final Map sessions = new ConcurrentHashMap(); private int sendTimeLimit = 10 * 1000; private int sendBufferSizeLimit = 512 * 1024; + private volatile long lastSessionCheckTime = System.currentTimeMillis(); + + private final ReentrantLock sessionCheckLock = new ReentrantLock(); + private Object lifecycleMonitor = new Object(); private volatile boolean running = false; @@ -227,12 +241,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, this.clientOutboundChannel.unsubscribe(this); // Notify sessions to stop flushing messages - for (WebSocketSession session : this.sessions.values()) { + for (WebSocketSessionHolder holder : this.sessions.values()) { try { - session.close(CloseStatus.GOING_AWAY); + holder.getSession().close(CloseStatus.GOING_AWAY); } catch (Throwable t) { - logger.error("Failed to close session id '" + session.getId() + "': " + t.getMessage()); + logger.error("Failed to close session id '" + holder.getSession().getId() + "': " + t.getMessage()); } } } @@ -251,7 +265,7 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit()); - this.sessions.put(session.getId(), session); + this.sessions.put(session.getId(), new WebSocketSessionHolder(session)); if (logger.isDebugEnabled()) { logger.debug("Started WebSocket session=" + session.getId() + ", number of sessions=" + this.sessions.size()); @@ -296,7 +310,13 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @Override public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { - findProtocolHandler(session).handleMessageFromClient(session, message, this.clientInboundChannel); + SubProtocolHandler protocolHandler = findProtocolHandler(session); + protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel); + WebSocketSessionHolder holder = this.sessions.get(session.getId()); + if (holder != null) { + holder.setHasHandledMessages(); + } + checkSessions(); } @Override @@ -307,13 +327,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, logger.error("sessionId not found in message " + message); return; } - - WebSocketSession session = this.sessions.get(sessionId); - if (session == null) { - logger.error("Session not found for session with id '" + sessionId + "', ignoring message " + message); + WebSocketSessionHolder holder = this.sessions.get(sessionId); + if (holder == null) { + logger.error("No session for " + message); return; } - + WebSocketSession session = holder.getSession(); try { findProtocolHandler(session).handleMessageToClient(session, message); } @@ -350,6 +369,49 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, return null; } + /** + * When a session is connected through a higher-level protocol it has a chance + * to use heartbeat management to shut down sessions that are too slow to send + * or receive messages. However, after a WebSocketSession is established and + * before the higher level protocol is fully connected there is a possibility + * for sessions to hang. This method checks and closes any sessions that have + * been connected for more than 60 seconds without having received a single + * message. + */ + private void checkSessions() throws IOException { + long currentTime = System.currentTimeMillis(); + if (!isRunning() || (currentTime - this.lastSessionCheckTime < TIME_TO_FIRST_MESSAGE)) { + return; + } + if (this.sessionCheckLock.tryLock()) { + try { + for (WebSocketSessionHolder holder : this.sessions.values()) { + if (holder.hasHandledMessages()) { + continue; + } + long timeSinceCreated = currentTime - holder.getCreateTime(); + if (timeSinceCreated < TIME_TO_FIRST_MESSAGE) { + continue; + } + WebSocketSession session = holder.getSession(); + if (logger.isErrorEnabled()) { + logger.error("No messages received after " + timeSinceCreated + " ms. " + + "Closing " + holder.getSession() + "."); + } + try { + session.close(CloseStatus.SESSION_NOT_RELIABLE); + } + catch (Throwable t) { + logger.error("Failure while closing " + session, t); + } + } + } + finally { + this.sessionCheckLock.unlock(); + } + } + } + @Override public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { } @@ -369,4 +431,41 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, return false; } + + private static class WebSocketSessionHolder { + + private final WebSocketSession session; + + private final long createTime = System.currentTimeMillis(); + + private volatile boolean handledMessages; + + + private WebSocketSessionHolder(WebSocketSession session) { + this.session = session; + } + + public WebSocketSession getSession() { + return this.session; + } + + public long getCreateTime() { + return this.createTime; + } + + public void setHasHandledMessages() { + this.handledMessages = true; + } + + public boolean hasHandledMessages() { + return this.handledMessages; + } + + @Override + public String toString() { + return "WebSocketSessionHolder[=session=" + this.session + ", createTime=" + + this.createTime + ", hasHandledMessages=" + this.handledMessages + "]"; + } + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java index 31cf82ac766..20bf97e8733 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java @@ -17,16 +17,24 @@ package org.springframework.web.socket.messaging; import java.util.Arrays; +import java.util.Map; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.springframework.beans.DirectFieldAccessor; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.SubscribableChannel; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; import org.springframework.web.socket.handler.TestWebSocketSession; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.*; /** @@ -140,4 +148,32 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.afterConnectionEstablished(session); } + @Test + public void checkSession() throws Exception { + TestWebSocketSession session1 = new TestWebSocketSession("id1"); + TestWebSocketSession session2 = new TestWebSocketSession("id2"); + session1.setAcceptedProtocol("v12.stomp"); + session2.setAcceptedProtocol("v12.stomp"); + + this.webSocketHandler.setProtocolHandlers(Arrays.asList(this.stompHandler)); + this.webSocketHandler.afterConnectionEstablished(session1); + this.webSocketHandler.afterConnectionEstablished(session2); + session1.setOpen(true); + session2.setOpen(true); + + long sixtyOneSecondsAgo = System.currentTimeMillis() - 61 * 1000; + new DirectFieldAccessor(this.webSocketHandler).setPropertyValue("lastSessionCheckTime", sixtyOneSecondsAgo); + Map sessions = (Map) new DirectFieldAccessor(this.webSocketHandler).getPropertyValue("sessions"); + new DirectFieldAccessor(sessions.get("id1")).setPropertyValue("createTime", sixtyOneSecondsAgo); + new DirectFieldAccessor(sessions.get("id2")).setPropertyValue("createTime", sixtyOneSecondsAgo); + + this.webSocketHandler.start(); + this.webSocketHandler.handleMessage(session1, new TextMessage("foo")); + + assertTrue(session1.isOpen()); + assertFalse(session2.isOpen()); + assertNull(session1.getCloseStatus()); + assertEquals(CloseStatus.SESSION_NOT_RELIABLE, session2.getCloseStatus()); + } + }