diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java
new file mode 100644
index 00000000000..ce70fad270a
--- /dev/null
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright 2002-2014 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.socket.handler;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.web.socket.WebSocketMessage;
+import org.springframework.web.socket.WebSocketSession;
+
+import java.io.IOException;
+import java.util.Queue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+
+
+/**
+ * Wraps a {@link org.springframework.web.socket.WebSocketSession} and guarantees
+ * only one thread can send messages at a time.
+ *
+ *
If a send is slow, subsequent attempts to send more messages from a different
+ * thread will fail to acquire the lock and the messages will be buffered instead --
+ * at that time the specified buffer size limit and send time limit will be checked
+ * and the session closed if the limits are exceeded.
+ *
+ * @author Rossen Stoyanchev
+ * @since 4.0.3
+ */
+public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorator {
+
+ private static Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class);
+
+
+ private final int sendTimeLimit;
+
+ private final int bufferSizeLimit;
+
+ private final Queue> buffer = new LinkedBlockingQueue>();
+
+ private final AtomicInteger bufferSize = new AtomicInteger();
+
+ private volatile long sendStartTime;
+
+ private final Lock lock = new ReentrantLock();
+
+
+ public ConcurrentWebSocketSessionDecorator(
+ WebSocketSession delegateSession, int sendTimeLimit, int bufferSizeLimit) {
+
+ super(delegateSession);
+ this.sendTimeLimit = sendTimeLimit;
+ this.bufferSizeLimit = bufferSizeLimit;
+ }
+
+
+ public int getBufferSize() {
+ return this.bufferSize.get();
+ }
+
+ public long getInProgressSendTime() {
+ long start = this.sendStartTime;
+ return (start > 0 ? (System.currentTimeMillis() - start) : 0);
+ }
+
+
+ public void sendMessage(WebSocketMessage> message) throws IOException {
+
+ this.buffer.add(message);
+ this.bufferSize.addAndGet(message.getPayloadLength());
+
+ do {
+ if (!tryFlushMessageBuffer()) {
+ checkSessionLimits();
+ break;
+ }
+ }
+ while (!this.buffer.isEmpty());
+ }
+
+ private boolean tryFlushMessageBuffer() throws IOException {
+
+ if (this.lock.tryLock()) {
+ try {
+ while (true) {
+ WebSocketMessage> messageToSend = this.buffer.poll();
+ if (messageToSend == null) {
+ break;
+ }
+ this.bufferSize.addAndGet(messageToSend.getPayloadLength() * -1);
+ this.sendStartTime = System.currentTimeMillis();
+ getDelegate().sendMessage(messageToSend);
+ this.sendStartTime = 0;
+ }
+ }
+ finally {
+ this.sendStartTime = 0;
+ lock.unlock();
+ }
+ return true;
+ }
+
+ return false;
+ }
+
+ private void checkSessionLimits() throws IOException {
+ if (getInProgressSendTime() > this.sendTimeLimit) {
+ logError("A message could not be sent due to a timeout");
+ getDelegate().close();
+ }
+ else if (this.bufferSize.get() > this.bufferSizeLimit) {
+ logError("The total send buffer byte count '" + this.bufferSize.get() +
+ "' for session '" + getId() + "' exceeds the allowed limit '" + this.bufferSizeLimit + "'");
+ getDelegate().close();
+ }
+ }
+
+ private void logError(String reason) {
+ logger.error(reason + ", number of buffered messages is '" + this.buffer.size() +
+ "', time since the last send started is '" + getInProgressSendTime() + "' (ms)");
+ }
+
+}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java
index bf13a8dbf4a..ec5eda452ba 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java
@@ -23,6 +23,13 @@ import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
/**
+ * Wraps another {@link org.springframework.web.socket.WebSocketHandler}
+ * instance and delegates to it.
+ *
+ * Also provides a {@link #getDelegate()} method to return the decorated
+ * handler as well as a {@link #getLastHandler()} method to go through all nested
+ * delegates and return the "last" handler.
+ *
* @author Rossen Stoyanchev
* @since 4.0
*/
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketSessionDecorator.java
new file mode 100644
index 00000000000..cea06b18f9e
--- /dev/null
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketSessionDecorator.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright 2002-2014 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.socket.handler;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.util.Assert;
+import org.springframework.web.socket.CloseStatus;
+import org.springframework.web.socket.WebSocketExtension;
+import org.springframework.web.socket.WebSocketMessage;
+import org.springframework.web.socket.WebSocketSession;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.URI;
+import java.security.Principal;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Wraps another {@link org.springframework.web.socket.WebSocketSession} instance
+ * and delegates to it.
+ *
+ *
Also provides a {@link #getDelegate()} method to return the decorated session
+ * as well as a {@link #getLastSession()} method to go through all nested delegates
+ * and return the "last" session.
+ *
+ * @author Rossen Stoyanchev
+ * @since 4.0.3
+ */
+public class WebSocketSessionDecorator implements WebSocketSession {
+
+ private final WebSocketSession delegate;
+
+
+ public WebSocketSessionDecorator(WebSocketSession session) {
+ Assert.notNull(session, "Delegate WebSocketSessionSession is required");
+ this.delegate = session;
+ }
+
+
+ @Override
+ public String getId() {
+ return this.delegate.getId();
+ }
+
+ @Override
+ public URI getUri() {
+ return this.delegate.getUri();
+ }
+
+ @Override
+ public HttpHeaders getHandshakeHeaders() {
+ return this.delegate.getHandshakeHeaders();
+ }
+
+ @Override
+ public Map getAttributes() {
+ return this.delegate.getAttributes();
+ }
+
+ @Override
+ public Principal getPrincipal() {
+ return this.delegate.getPrincipal();
+ }
+
+ @Override
+ public InetSocketAddress getLocalAddress() {
+ return this.delegate.getLocalAddress();
+ }
+
+ @Override
+ public InetSocketAddress getRemoteAddress() {
+ return this.delegate.getRemoteAddress();
+ }
+
+ @Override
+ public String getAcceptedProtocol() {
+ return this.delegate.getAcceptedProtocol();
+ }
+
+ @Override
+ public List getExtensions() {
+ return this.delegate.getExtensions();
+ }
+
+ @Override
+ public boolean isOpen() {
+ return this.delegate.isOpen();
+ }
+
+ @Override
+ public void sendMessage(WebSocketMessage> message) throws IOException {
+
+ }
+
+ @Override
+ public void close() throws IOException {
+ this.delegate.close();
+ }
+
+ @Override
+ public void close(CloseStatus status) throws IOException {
+ this.delegate.close(status);
+ }
+
+ public WebSocketSession getDelegate() {
+ return this.delegate;
+ }
+
+ public WebSocketSession getLastSession() {
+ WebSocketSession result = this.delegate;
+ while (result instanceof WebSocketSessionDecorator) {
+ result = ((WebSocketSessionDecorator) result).getDelegate();
+ }
+ return result;
+ }
+
+ @Override
+ public String toString() {
+ return getClass().getSimpleName() + " [delegate=" + this.delegate + "]";
+ }
+
+}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
index fee754ced66..670544c580e 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
@@ -42,6 +42,7 @@ import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
+import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
/**
* An implementation of {@link WebSocketHandler} that delegates incoming WebSocket
@@ -74,6 +75,10 @@ public class SubProtocolWebSocketHandler
private final Map sessions = new ConcurrentHashMap();
+ private int sendTimeLimit = 20 * 1000;
+
+ private int sendBufferSizeLimit = 1024 * 1024;
+
private Object lifecycleMonitor = new Object();
private volatile boolean running = false;
@@ -155,6 +160,24 @@ public class SubProtocolWebSocketHandler
return new ArrayList(this.protocolHandlers.keySet());
}
+
+ public void setSendTimeLimit(int sendTimeLimit) {
+ this.sendTimeLimit = sendTimeLimit;
+ }
+
+ public int getSendTimeLimit() {
+ return this.sendTimeLimit;
+ }
+
+ public void setSendBufferSizeLimit(int sendBufferSizeLimit) {
+ this.sendBufferSizeLimit = sendBufferSizeLimit;
+ }
+
+ public int getSendBufferSizeLimit() {
+ return sendBufferSizeLimit;
+ }
+
+
@Override
public boolean isAutoStartup() {
return true;
@@ -198,11 +221,15 @@ public class SubProtocolWebSocketHandler
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
+
+ session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit());
+
this.sessions.put(session.getId(), session);
if (logger.isDebugEnabled()) {
logger.debug("Started WebSocket session=" + session.getId() +
", number of sessions=" + this.sessions.size());
}
+
findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel);
}
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java
new file mode 100644
index 00000000000..21413c32c67
--- /dev/null
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java
@@ -0,0 +1,220 @@
+/*
+ * Copyright 2002-2014 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.socket.handler;
+
+import org.junit.Test;
+import org.springframework.web.socket.TextMessage;
+import org.springframework.web.socket.WebSocketMessage;
+
+import java.io.IOException;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Unit tests for
+ * {@link org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator}.
+ *
+ * @author Rossen Stoyanchev
+ */
+public class ConcurrentWebSocketSessionDecoratorTests {
+
+
+ @Test
+ public void send() throws IOException {
+
+ TestWebSocketSession session = new TestWebSocketSession();
+ session.setOpen(true);
+
+ ConcurrentWebSocketSessionDecorator concurrentSession =
+ new ConcurrentWebSocketSessionDecorator(session, 1000, 1024);
+
+ TextMessage textMessage = new TextMessage("payload");
+ concurrentSession.sendMessage(textMessage);
+
+ assertEquals(1, session.getSentMessages().size());
+ assertEquals(textMessage, session.getSentMessages().get(0));
+
+ assertEquals(0, concurrentSession.getBufferSize());
+ assertEquals(0, concurrentSession.getInProgressSendTime());
+ assertTrue(session.isOpen());
+ }
+
+ @Test
+ public void sendAfterBlockedSend() throws IOException, InterruptedException {
+
+ BlockingSession blockingSession = new BlockingSession();
+ blockingSession.setOpen(true);
+ CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch();
+
+ final ConcurrentWebSocketSessionDecorator concurrentSession =
+ new ConcurrentWebSocketSessionDecorator(blockingSession, 10 * 1000, 1024);
+
+ Executors.newSingleThreadExecutor().submit(new Runnable() {
+ @Override
+ public void run() {
+ TextMessage textMessage = new TextMessage("slow message");
+ try {
+ concurrentSession.sendMessage(textMessage);
+ }
+ catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ });
+
+ assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS));
+
+ // ensure some send time elapses
+ Thread.sleep(100);
+ assertTrue(concurrentSession.getInProgressSendTime() > 0);
+
+ TextMessage payload = new TextMessage("payload");
+ for (int i=0; i < 5; i++) {
+ concurrentSession.sendMessage(payload);
+ }
+
+ assertTrue(concurrentSession.getInProgressSendTime() > 0);
+ assertEquals(5 * payload.getPayloadLength(), concurrentSession.getBufferSize());
+ assertTrue(blockingSession.isOpen());
+ }
+
+ @Test
+ public void sendTimeLimitExceeded() throws IOException, InterruptedException {
+
+ BlockingSession blockingSession = new BlockingSession();
+ blockingSession.setOpen(true);
+ CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch();
+
+ int sendTimeLimit = 100;
+ int bufferSizeLimit = 1024;
+
+ final ConcurrentWebSocketSessionDecorator concurrentSession =
+ new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit);
+
+ Executors.newSingleThreadExecutor().submit(new Runnable() {
+ @Override
+ public void run() {
+ TextMessage textMessage = new TextMessage("slow message");
+ try {
+ concurrentSession.sendMessage(textMessage);
+ }
+ catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ });
+
+ assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS));
+
+ // ensure some send time elapses
+ Thread.sleep(sendTimeLimit + 100);
+
+ TextMessage payload = new TextMessage("payload");
+ concurrentSession.sendMessage(payload);
+
+ assertFalse(blockingSession.isOpen());
+ }
+
+ @Test
+ public void sendBufferSizeExceeded() throws IOException, InterruptedException {
+
+ BlockingSession blockingSession = new BlockingSession();
+ blockingSession.setOpen(true);
+ CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch();
+
+ int sendTimeLimit = 10 * 1000;
+ int bufferSizeLimit = 1024;
+
+ final ConcurrentWebSocketSessionDecorator concurrentSession =
+ new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit);
+
+ Executors.newSingleThreadExecutor().submit(new Runnable() {
+ @Override
+ public void run() {
+ TextMessage textMessage = new TextMessage("slow message");
+ try {
+ concurrentSession.sendMessage(textMessage);
+ }
+ catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ });
+
+ assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS));
+
+ StringBuilder sb = new StringBuilder();
+ for (int i=0 ; i < 1023; i++) {
+ sb.append("a");
+ }
+
+ TextMessage message = new TextMessage(sb.toString());
+ concurrentSession.sendMessage(message);
+
+ assertEquals(1023, concurrentSession.getBufferSize());
+ assertTrue(blockingSession.isOpen());
+
+ concurrentSession.sendMessage(message);
+ assertFalse(blockingSession.isOpen());
+ }
+
+
+ private static class BlockingSession extends TestWebSocketSession {
+
+ private AtomicReference nextMessageLatch = new AtomicReference<>();
+
+ private AtomicReference releaseLatch = new AtomicReference<>();
+
+
+ public CountDownLatch getSentMessageLatch() {
+ this.nextMessageLatch.set(new CountDownLatch(1));
+ return this.nextMessageLatch.get();
+ }
+
+ @Override
+ public void sendMessage(WebSocketMessage> message) throws IOException {
+ super.sendMessage(message);
+ if (this.nextMessageLatch != null) {
+ this.nextMessageLatch.get().countDown();
+ }
+ block();
+ }
+
+ private void block() {
+ try {
+ this.releaseLatch.set(new CountDownLatch(1));
+ this.releaseLatch.get().await();
+ }
+ catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+
+ public void release() {
+ if (this.releaseLatch.get() != null) {
+ this.releaseLatch.get().countDown();
+ }
+ }
+ }
+
+}
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java
index 399a7d1a36e..b021199e189 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2013 the original author or authors.
+ * Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,9 +21,12 @@ import java.util.Arrays;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
+import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
+import org.springframework.web.socket.WebSocketSession;
+import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.TestWebSocketSession;
import static org.mockito.Mockito.*;
@@ -71,7 +74,8 @@ public class SubProtocolWebSocketHandlerTests {
this.session.setAcceptedProtocol("v12.sToMp");
this.webSocketHandler.afterConnectionEstablished(session);
- verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel);
+ verify(this.stompHandler).afterSessionStarted(
+ isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel));
verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel);
}
@@ -81,7 +85,8 @@ public class SubProtocolWebSocketHandlerTests {
this.session.setAcceptedProtocol("v12.sToMp");
this.webSocketHandler.afterConnectionEstablished(session);
- verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel);
+ verify(this.stompHandler).afterSessionStarted(
+ isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel));
}
@Test(expected=IllegalStateException.class)
@@ -98,7 +103,8 @@ public class SubProtocolWebSocketHandlerTests {
this.webSocketHandler.setDefaultProtocolHandler(defaultHandler);
this.webSocketHandler.afterConnectionEstablished(session);
- verify(this.defaultHandler).afterSessionStarted(session, this.inClientChannel);
+ verify(this.defaultHandler).afterSessionStarted(
+ isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel));
verify(this.stompHandler, times(0)).afterSessionStarted(session, this.inClientChannel);
verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel);
}
@@ -109,7 +115,8 @@ public class SubProtocolWebSocketHandlerTests {
this.webSocketHandler.setDefaultProtocolHandler(defaultHandler);
this.webSocketHandler.afterConnectionEstablished(session);
- verify(this.defaultHandler).afterSessionStarted(session, this.inClientChannel);
+ verify(this.defaultHandler).afterSessionStarted(
+ isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel));
verify(this.stompHandler, times(0)).afterSessionStarted(session, this.inClientChannel);
verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel);
}
@@ -119,7 +126,8 @@ public class SubProtocolWebSocketHandlerTests {
this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler));
this.webSocketHandler.afterConnectionEstablished(session);
- verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel);
+ verify(this.stompHandler).afterSessionStarted(
+ isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel));
}
@Test(expected=IllegalStateException.class)