diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index 42eaa1bd0c1..d9a06db5927 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -271,53 +271,59 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem private SockJsSession createSockJsSession(String sessionId, SockJsSessionFactory sessionFactory, WebSocketHandler handler, Map attributes) { - synchronized (this.sessions) { - SockJsSession session = this.sessions.get(sessionId); - if (session != null) { - return session; - } - if (this.sessionCleanupTask == null) { - scheduleSessionTask(); - } - - if (logger.isDebugEnabled()) { - logger.debug("Creating new session with session id \"" + sessionId + "\""); - } - session = sessionFactory.createSession(sessionId, handler, attributes); - this.sessions.put(sessionId, session); + SockJsSession session = this.sessions.get(sessionId); + if (session != null) { return session; } + + if (this.sessionCleanupTask == null) { + scheduleSessionTask(); + } + + if (logger.isDebugEnabled()) { + logger.debug("Creating new session with session id \"" + sessionId + "\""); + } + session = sessionFactory.createSession(sessionId, handler, attributes); + this.sessions.put(sessionId, session); + + return session; } private void scheduleSessionTask() { - this.sessionCleanupTask = getTaskScheduler().scheduleAtFixedRate(new Runnable() { - @Override - public void run() { - try { - int count = sessions.size(); - if (logger.isTraceEnabled() && (count != 0)) { - logger.trace("Checking " + count + " session(s) for timeouts [" + getName() + "]"); - } - for (SockJsSession session : sessions.values()) { - if (session.getTimeSinceLastActive() > getDisconnectDelay()) { - if (logger.isTraceEnabled()) { - logger.trace("Removing " + session + " for [" + getName() + "]"); + + synchronized (this.sessions) { + if (this.sessionCleanupTask != null) { + return; + } + this.sessionCleanupTask = getTaskScheduler().scheduleAtFixedRate(new Runnable() { + @Override + public void run() { + try { + int count = sessions.size(); + if (logger.isTraceEnabled() && (count != 0)) { + logger.trace("Checking " + count + " session(s) for timeouts [" + getName() + "]"); + } + for (SockJsSession session : sessions.values()) { + if (session.getTimeSinceLastActive() > getDisconnectDelay()) { + if (logger.isTraceEnabled()) { + logger.trace("Removing " + session + " for [" + getName() + "]"); + } + session.close(); + sessions.remove(session.getId()); } - session.close(); - sessions.remove(session.getId()); + } + if (logger.isTraceEnabled() && count > 0) { + logger.trace(sessions.size() + " remaining session(s) [" + getName() + "]"); } } - if (logger.isTraceEnabled() && count > 0) { - logger.trace(sessions.size() + " remaining session(s) [" + getName() + "]"); - } - } - catch (Throwable ex) { - if (logger.isErrorEnabled()) { - logger.error("Failed to complete session timeout checks for [" + getName() + "]", ex); + catch (Throwable ex) { + if (logger.isErrorEnabled()) { + logger.error("Failed to complete session timeout checks for [" + getName() + "]", ex); + } } } - } - }, getDisconnectDelay()); + }, getDisconnectDelay()); + } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java index 46c2a430268..ded74d49883 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/EventSourceTransportHandler.java @@ -21,6 +21,7 @@ import java.util.Map; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.frame.DefaultSockJsFrameFormat; import org.springframework.web.socket.sockjs.frame.SockJsFrameFormat; @@ -69,10 +70,10 @@ public class EventSourceTransportHandler extends AbstractHttpSendingTransportHan } @Override - protected void writePrelude() throws IOException { - getResponse().getBody().write('\r'); - getResponse().getBody().write('\n'); - getResponse().flush(); + protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) throws IOException { + response.getBody().write('\r'); + response.getBody().write('\n'); + response.flush(); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java index 276ea2bda69..763906f41a3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java @@ -133,15 +133,15 @@ public class HtmlFileTransportHandler extends AbstractHttpSendingTransportHandle } @Override - protected void writePrelude() { + protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) { // we already validated the parameter above.. - String callback = getCallbackParam(getRequest()); + String callback = getCallbackParam(request); String html = String.format(PARTIAL_HTML_CONTENT, callback); try { - getResponse().getBody().write(html.getBytes("UTF-8")); - getResponse().flush(); + response.getBody().write(html.getBytes("UTF-8")); + response.flush(); } catch (IOException e) { tryCloseWithSockJsTransportError(e, CloseStatus.SERVER_ERROR); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java index 65f4f657c4a..a27f369778d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/XhrStreamingTransportHandler.java @@ -21,6 +21,7 @@ import java.util.Map; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.frame.DefaultSockJsFrameFormat; import org.springframework.web.socket.sockjs.frame.SockJsFrameFormat; @@ -69,12 +70,12 @@ public class XhrStreamingTransportHandler extends AbstractHttpSendingTransportHa } @Override - protected void writePrelude() throws IOException { + protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) throws IOException { for (int i=0; i < 2048; i++) { - getResponse().getBody().write('h'); + response.getBody().write('h'); } - getResponse().getBody().write('\n'); - getResponse().flush(); + response.getBody().write('\n'); + response.flush(); } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index 5e49a510cf1..cdd15fa51d9 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -23,8 +23,8 @@ import java.security.Principal; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; +import java.util.Queue; +import java.util.concurrent.LinkedBlockingQueue; import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpAsyncRequestControl; @@ -48,34 +48,36 @@ import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; */ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { - private final BlockingQueue messageCache; + private final Queue messageCache; - private ServerHttpRequest request; - private ServerHttpResponse response; + private volatile ServerHttpResponse response; - private ServerHttpAsyncRequestControl asyncRequestControl; + private volatile ServerHttpAsyncRequestControl asyncRequestControl; - private SockJsFrameFormat frameFormat; + private volatile SockJsFrameFormat frameFormat; - private URI uri; + private volatile boolean requestInitialized; - private HttpHeaders handshakeHeaders; - private Principal principal; + private volatile URI uri; - private InetSocketAddress localAddress; + private volatile HttpHeaders handshakeHeaders; - private InetSocketAddress remoteAddress; + private volatile Principal principal; - private String acceptedProtocol; + private volatile InetSocketAddress localAddress; + + private volatile InetSocketAddress remoteAddress; + + private volatile String acceptedProtocol; public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler, Map attributes) { super(id, config, wsHandler, attributes); - this.messageCache = new ArrayBlockingQueue(config.getHttpMessageCacheSize()); + this.messageCache = new LinkedBlockingQueue(config.getHttpMessageCacheSize()); } @@ -150,7 +152,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { * * @see #handleSuccessiveRequest(org.springframework.http.server.ServerHttpRequest, org.springframework.http.server.ServerHttpResponse, org.springframework.web.socket.sockjs.frame.SockJsFrameFormat) */ - public synchronized void handleInitialRequest(ServerHttpRequest request, ServerHttpResponse response, + public void handleInitialRequest(ServerHttpRequest request, ServerHttpResponse response, SockJsFrameFormat frameFormat) throws SockJsException { initRequest(request, response, frameFormat); @@ -162,7 +164,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.remoteAddress = request.getRemoteAddress(); try { - writePrelude(); + writePrelude(request, response); writeFrame(SockJsFrame.openFrame()); } catch (Throwable ex) { @@ -171,6 +173,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { } try { + this.requestInitialized = true; delegateConnectionEstablished(); } catch (Throwable ex) { @@ -185,13 +188,12 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { Assert.notNull(response, "Response must not be null"); Assert.notNull(frameFormat, "SockJsFrameFormat must not be null"); - this.request = request; this.response = response; - this.asyncRequestControl = request.getAsyncRequestControl(response); this.frameFormat = frameFormat; + this.asyncRequestControl = request.getAsyncRequestControl(response); } - protected void writePrelude() throws IOException { + protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) throws IOException { } /** @@ -217,12 +219,12 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { * * @see #handleInitialRequest(org.springframework.http.server.ServerHttpRequest, org.springframework.http.server.ServerHttpResponse, org.springframework.web.socket.sockjs.frame.SockJsFrameFormat) */ - public synchronized void handleSuccessiveRequest(ServerHttpRequest request, + public void handleSuccessiveRequest(ServerHttpRequest request, ServerHttpResponse response, SockJsFrameFormat frameFormat) throws SockJsException { initRequest(request, response, frameFormat); try { - writePrelude(); + writePrelude(request, response); } catch (Throwable ex) { tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); @@ -234,6 +236,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { protected void startAsyncRequest() throws SockJsException { try { this.asyncRequestControl.start(-1); + this.requestInitialized = true; scheduleHeartbeat(); tryFlushCache(); } @@ -244,24 +247,21 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { } @Override - public synchronized boolean isActive() { - return (this.asyncRequestControl != null && !this.asyncRequestControl.isCompleted()); + public boolean isActive() { + ServerHttpAsyncRequestControl control = this.asyncRequestControl; + return (control != null && !control.isCompleted()); } - protected BlockingQueue getMessageCache() { + protected Queue getMessageCache() { return this.messageCache; } - protected ServerHttpRequest getRequest() { - return this.request; - } - protected ServerHttpResponse getResponse() { return this.response; } @Override - protected final synchronized void sendMessageInternal(String message) throws SockJsTransportFailureException { + protected final void sendMessageInternal(String message) throws SockJsTransportFailureException { this.messageCache.add(message); tryFlushCache(); } @@ -274,7 +274,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { if (logger.isTraceEnabled()) { logger.trace(this.messageCache.size() + " message(s) to flush"); } - if (isActive()) { + if (isActive() && this.requestInitialized) { logger.trace("Flushing messages"); flushCache(); } @@ -295,30 +295,36 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { resetRequest(); } - protected synchronized void resetRequest() { + protected void resetRequest() { + + this.requestInitialized = false; updateLastActiveTime(); - if (isActive() && this.asyncRequestControl.isStarted()) { - try { - logger.debug("Completing asynchronous request"); - this.asyncRequestControl.complete(); - } - catch (Throwable ex) { - logger.error("Failed to complete request: " + ex.getMessage()); + + if (isActive()) { + ServerHttpAsyncRequestControl control = this.asyncRequestControl; + if (control.isStarted()) { + try { + logger.debug("Completing asynchronous request"); + control.complete(); + } + catch (Throwable ex) { + logger.error("Failed to complete request: " + ex.getMessage()); + } } } - this.request = null; + this.response = null; this.asyncRequestControl = null; } @Override - protected synchronized void writeFrameInternal(SockJsFrame frame) throws IOException { + protected void writeFrameInternal(SockJsFrame frame) throws IOException { if (isActive()) { frame = this.frameFormat.format(frame); if (logger.isTraceEnabled()) { logger.trace("Writing " + frame); } - this.response.getBody().write(frame.getContentBytes()); + getResponse().getBody().write(frame.getContentBytes()); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java index b9a11c882de..cbc6b1d467d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java @@ -79,10 +79,11 @@ public abstract class AbstractSockJsSession implements SockJsSession { private static final Set disconnectedClientExceptions; static { + Set set = new HashSet(2); set.add("ClientAbortException"); // Tomcat set.add("EofException"); // Jetty - // IOException("Broken pipe") on WildFly and Glassfish + // java.io.IOException "Broken pipe" on WildFly, Glassfish (already covered) disconnectedClientExceptions = Collections.unmodifiableSet(set); } @@ -95,13 +96,13 @@ public abstract class AbstractSockJsSession implements SockJsSession { private final Map attributes; - private State state = State.NEW; + private volatile State state = State.NEW; private final long timeCreated = System.currentTimeMillis(); - private long timeLastActive = this.timeCreated; + private volatile long timeLastActive = this.timeCreated; - private ScheduledFuture heartbeatTask; + private volatile ScheduledFuture heartbeatTask; /** @@ -229,7 +230,7 @@ public abstract class AbstractSockJsSession implements SockJsSession { this.handler.handleTransportError(this, ex); } - public final synchronized void sendMessage(WebSocketMessage message) throws IOException { + public final void sendMessage(WebSocketMessage message) throws IOException { Assert.isTrue(!isClosed(), "Cannot send a message when session is closed"); Assert.isInstanceOf(TextMessage.class, message, "Expected text message: " + message); sendMessageInternal(((TextMessage) message).getPayload()); @@ -352,7 +353,7 @@ public abstract class AbstractSockJsSession implements SockJsSession { protected abstract void writeFrameInternal(SockJsFrame frame) throws IOException; - public synchronized void sendHeartbeat() throws SockJsTransportFailureException { + public void sendHeartbeat() throws SockJsTransportFailureException { if (isActive()) { writeFrame(SockJsFrame.heartbeatFrame()); scheduleHeartbeat(); @@ -382,13 +383,16 @@ public abstract class AbstractSockJsSession implements SockJsSession { } protected void cancelHeartbeat() { - if ((this.heartbeatTask != null) && !this.heartbeatTask.isDone()) { + + ScheduledFuture task = this.heartbeatTask; + this.heartbeatTask = null; + + if ((task != null) && !task.isDone()) { if (logger.isTraceEnabled()) { logger.trace("Cancelling heartbeat"); } - this.heartbeatTask.cancel(false); + task.cancel(false); } - this.heartbeatTask = null; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java index 32e9a67abfe..b16b66d9830 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.sockjs.transport.session; import java.util.Map; +import java.util.Queue; import java.util.concurrent.BlockingQueue; import org.springframework.web.socket.WebSocketHandler; @@ -43,7 +44,7 @@ public class PollingSockJsSession extends AbstractHttpSockJsSession { @Override protected void flushCache() throws SockJsTransportFailureException { cancelHeartbeat(); - BlockingQueue messageCache = getMessageCache(); + Queue messageCache = getMessageCache(); String[] messages = messageCache.toArray(new String[messageCache.size()]); messageCache.clear(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java index 7af49fb7695..70facc8776e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java @@ -48,7 +48,7 @@ public class StreamingSockJsSession extends AbstractHttpSockJsSession { @Override - public synchronized void handleInitialRequest(ServerHttpRequest request, ServerHttpResponse response, + public void handleInitialRequest(ServerHttpRequest request, ServerHttpResponse response, SockJsFrameFormat frameFormat) throws SockJsException { super.handleInitialRequest(request, response, frameFormat); @@ -87,13 +87,13 @@ public class StreamingSockJsSession extends AbstractHttpSockJsSession { } @Override - protected synchronized void resetRequest() { + protected void resetRequest() { super.resetRequest(); this.byteCount = 0; } @Override - protected synchronized void writeFrameInternal(SockJsFrame frame) throws IOException { + protected void writeFrameInternal(SockJsFrame frame) throws IOException { if (isActive()) { super.writeFrameInternal(frame); getResponse().flush(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java index e5bbe43aee9..d0c96bdf4ca 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java @@ -81,9 +81,6 @@ public class HttpSockJsSessionTests extends AbstractSockJsSessionTests