Browse Source

Add check for unused WebSocket sessions

Sessions connected to a STOMP endpoint are expected to receive some
client messages. Having received none after successfully connecting
could be an indication of proxy or network issue. This change adds
periodic checks to see if we have not received any messages on a
session which is an indication the session isn't going anywhere
most likely due to a proxy issue (or unreliable network) and close
those sessions.

This is a backport for commit:
a3fa9c9797

Issue: SPR-11884
pull/579/head
Rossen Stoyanchev 12 years ago
parent
commit
618771d59d
  1. 121
      spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
  2. 36
      spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java

121
spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java

@ -25,6 +25,7 @@ import java.util.Map; @@ -25,6 +25,7 @@ import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -66,6 +67,15 @@ import org.springframework.web.socket.handler.SessionLimitExceededException; @@ -66,6 +67,15 @@ import org.springframework.web.socket.handler.SessionLimitExceededException;
public class SubProtocolWebSocketHandler implements WebSocketHandler,
SubProtocolCapable, MessageHandler, SmartLifecycle, ApplicationEventPublisherAware {
/**
* Sessions connected to this handler use a sub-protocol. Hence we expect to
* receive some client messages. If we don't receive any within a minute, the
* connection isn't doing well (proxy issue, slow network?) and can be closed.
* @see #checkSessions()
*/
private static final int TIME_TO_FIRST_MESSAGE = 60 * 1000;
private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);
private final MessageChannel clientInboundChannel;
@ -77,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -77,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
private SubProtocolHandler defaultProtocolHandler;
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
private final Map<String, WebSocketSessionHolder> sessions = new ConcurrentHashMap<String, WebSocketSessionHolder>();
private int sendTimeLimit = 10 * 1000;
private int sendBufferSizeLimit = 512 * 1024;
private volatile long lastSessionCheckTime = System.currentTimeMillis();
private final ReentrantLock sessionCheckLock = new ReentrantLock();
private Object lifecycleMonitor = new Object();
private volatile boolean running = false;
@ -227,12 +241,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -227,12 +241,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
this.clientOutboundChannel.unsubscribe(this);
// Notify sessions to stop flushing messages
for (WebSocketSession session : this.sessions.values()) {
for (WebSocketSessionHolder holder : this.sessions.values()) {
try {
session.close(CloseStatus.GOING_AWAY);
holder.getSession().close(CloseStatus.GOING_AWAY);
}
catch (Throwable t) {
logger.error("Failed to close session id '" + session.getId() + "': " + t.getMessage());
logger.error("Failed to close session id '" + holder.getSession().getId() + "': " + t.getMessage());
}
}
}
@ -251,7 +265,7 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -251,7 +265,7 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit());
this.sessions.put(session.getId(), session);
this.sessions.put(session.getId(), new WebSocketSessionHolder(session));
if (logger.isDebugEnabled()) {
logger.debug("Started WebSocket session=" + session.getId() +
", number of sessions=" + this.sessions.size());
@ -296,7 +310,13 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -296,7 +310,13 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
findProtocolHandler(session).handleMessageFromClient(session, message, this.clientInboundChannel);
SubProtocolHandler protocolHandler = findProtocolHandler(session);
protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel);
WebSocketSessionHolder holder = this.sessions.get(session.getId());
if (holder != null) {
holder.setHasHandledMessages();
}
checkSessions();
}
@Override
@ -307,13 +327,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -307,13 +327,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
logger.error("sessionId not found in message " + message);
return;
}
WebSocketSession session = this.sessions.get(sessionId);
if (session == null) {
logger.error("Session not found for session with id '" + sessionId + "', ignoring message " + message);
WebSocketSessionHolder holder = this.sessions.get(sessionId);
if (holder == null) {
logger.error("No session for " + message);
return;
}
WebSocketSession session = holder.getSession();
try {
findProtocolHandler(session).handleMessageToClient(session, message);
}
@ -350,6 +369,49 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -350,6 +369,49 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
return null;
}
/**
* When a session is connected through a higher-level protocol it has a chance
* to use heartbeat management to shut down sessions that are too slow to send
* or receive messages. However, after a WebSocketSession is established and
* before the higher level protocol is fully connected there is a possibility
* for sessions to hang. This method checks and closes any sessions that have
* been connected for more than 60 seconds without having received a single
* message.
*/
private void checkSessions() throws IOException {
long currentTime = System.currentTimeMillis();
if (!isRunning() || (currentTime - this.lastSessionCheckTime < TIME_TO_FIRST_MESSAGE)) {
return;
}
if (this.sessionCheckLock.tryLock()) {
try {
for (WebSocketSessionHolder holder : this.sessions.values()) {
if (holder.hasHandledMessages()) {
continue;
}
long timeSinceCreated = currentTime - holder.getCreateTime();
if (timeSinceCreated < TIME_TO_FIRST_MESSAGE) {
continue;
}
WebSocketSession session = holder.getSession();
if (logger.isErrorEnabled()) {
logger.error("No messages received after " + timeSinceCreated + " ms. " +
"Closing " + holder.getSession() + ".");
}
try {
session.close(CloseStatus.SESSION_NOT_RELIABLE);
}
catch (Throwable t) {
logger.error("Failure while closing " + session, t);
}
}
}
finally {
this.sessionCheckLock.unlock();
}
}
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
}
@ -369,4 +431,41 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -369,4 +431,41 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
return false;
}
private static class WebSocketSessionHolder {
private final WebSocketSession session;
private final long createTime = System.currentTimeMillis();
private volatile boolean handledMessages;
private WebSocketSessionHolder(WebSocketSession session) {
this.session = session;
}
public WebSocketSession getSession() {
return this.session;
}
public long getCreateTime() {
return this.createTime;
}
public void setHasHandledMessages() {
this.handledMessages = true;
}
public boolean hasHandledMessages() {
return this.handledMessages;
}
@Override
public String toString() {
return "WebSocketSessionHolder[=session=" + this.session + ", createTime=" +
this.createTime + ", hasHandledMessages=" + this.handledMessages + "]";
}
}
}

36
spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java

@ -17,16 +17,24 @@ @@ -17,16 +17,24 @@
package org.springframework.web.socket.messaging;
import java.util.Arrays;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.TestWebSocketSession;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.*;
/**
@ -140,4 +148,32 @@ public class SubProtocolWebSocketHandlerTests { @@ -140,4 +148,32 @@ public class SubProtocolWebSocketHandlerTests {
this.webSocketHandler.afterConnectionEstablished(session);
}
@Test
public void checkSession() throws Exception {
TestWebSocketSession session1 = new TestWebSocketSession("id1");
TestWebSocketSession session2 = new TestWebSocketSession("id2");
session1.setAcceptedProtocol("v12.stomp");
session2.setAcceptedProtocol("v12.stomp");
this.webSocketHandler.setProtocolHandlers(Arrays.asList(this.stompHandler));
this.webSocketHandler.afterConnectionEstablished(session1);
this.webSocketHandler.afterConnectionEstablished(session2);
session1.setOpen(true);
session2.setOpen(true);
long sixtyOneSecondsAgo = System.currentTimeMillis() - 61 * 1000;
new DirectFieldAccessor(this.webSocketHandler).setPropertyValue("lastSessionCheckTime", sixtyOneSecondsAgo);
Map<String, ?> sessions = (Map<String, ?>) new DirectFieldAccessor(this.webSocketHandler).getPropertyValue("sessions");
new DirectFieldAccessor(sessions.get("id1")).setPropertyValue("createTime", sixtyOneSecondsAgo);
new DirectFieldAccessor(sessions.get("id2")).setPropertyValue("createTime", sixtyOneSecondsAgo);
this.webSocketHandler.start();
this.webSocketHandler.handleMessage(session1, new TextMessage("foo"));
assertTrue(session1.isOpen());
assertFalse(session2.isOpen());
assertNull(session1.getCloseStatus());
assertEquals(CloseStatus.SESSION_NOT_RELIABLE, session2.getCloseStatus());
}
}

Loading…
Cancel
Save