diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java new file mode 100644 index 00000000000..ce70fad270a --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.handler; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; + +import java.io.IOException; +import java.util.Queue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + + +/** + * Wraps a {@link org.springframework.web.socket.WebSocketSession} and guarantees + * only one thread can send messages at a time. + * + *

If a send is slow, subsequent attempts to send more messages from a different + * thread will fail to acquire the lock and the messages will be buffered instead -- + * at that time the specified buffer size limit and send time limit will be checked + * and the session closed if the limits are exceeded. + * + * @author Rossen Stoyanchev + * @since 4.0.3 + */ +public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorator { + + private static Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class); + + + private final int sendTimeLimit; + + private final int bufferSizeLimit; + + private final Queue> buffer = new LinkedBlockingQueue>(); + + private final AtomicInteger bufferSize = new AtomicInteger(); + + private volatile long sendStartTime; + + private final Lock lock = new ReentrantLock(); + + + public ConcurrentWebSocketSessionDecorator( + WebSocketSession delegateSession, int sendTimeLimit, int bufferSizeLimit) { + + super(delegateSession); + this.sendTimeLimit = sendTimeLimit; + this.bufferSizeLimit = bufferSizeLimit; + } + + + public int getBufferSize() { + return this.bufferSize.get(); + } + + public long getInProgressSendTime() { + long start = this.sendStartTime; + return (start > 0 ? (System.currentTimeMillis() - start) : 0); + } + + + public void sendMessage(WebSocketMessage message) throws IOException { + + this.buffer.add(message); + this.bufferSize.addAndGet(message.getPayloadLength()); + + do { + if (!tryFlushMessageBuffer()) { + checkSessionLimits(); + break; + } + } + while (!this.buffer.isEmpty()); + } + + private boolean tryFlushMessageBuffer() throws IOException { + + if (this.lock.tryLock()) { + try { + while (true) { + WebSocketMessage messageToSend = this.buffer.poll(); + if (messageToSend == null) { + break; + } + this.bufferSize.addAndGet(messageToSend.getPayloadLength() * -1); + this.sendStartTime = System.currentTimeMillis(); + getDelegate().sendMessage(messageToSend); + this.sendStartTime = 0; + } + } + finally { + this.sendStartTime = 0; + lock.unlock(); + } + return true; + } + + return false; + } + + private void checkSessionLimits() throws IOException { + if (getInProgressSendTime() > this.sendTimeLimit) { + logError("A message could not be sent due to a timeout"); + getDelegate().close(); + } + else if (this.bufferSize.get() > this.bufferSizeLimit) { + logError("The total send buffer byte count '" + this.bufferSize.get() + + "' for session '" + getId() + "' exceeds the allowed limit '" + this.bufferSizeLimit + "'"); + getDelegate().close(); + } + } + + private void logError(String reason) { + logger.error(reason + ", number of buffered messages is '" + this.buffer.size() + + "', time since the last send started is '" + getInProgressSendTime() + "' (ms)"); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java index bf13a8dbf4a..ec5eda452ba 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketHandlerDecorator.java @@ -23,6 +23,13 @@ import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; /** + * Wraps another {@link org.springframework.web.socket.WebSocketHandler} + * instance and delegates to it. + * + *

Also provides a {@link #getDelegate()} method to return the decorated + * handler as well as a {@link #getLastHandler()} method to go through all nested + * delegates and return the "last" handler. + * * @author Rossen Stoyanchev * @since 4.0 */ diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketSessionDecorator.java new file mode 100644 index 00000000000..cea06b18f9e --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/WebSocketSessionDecorator.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.handler; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.security.Principal; +import java.util.List; +import java.util.Map; + +/** + * Wraps another {@link org.springframework.web.socket.WebSocketSession} instance + * and delegates to it. + * + *

Also provides a {@link #getDelegate()} method to return the decorated session + * as well as a {@link #getLastSession()} method to go through all nested delegates + * and return the "last" session. + * + * @author Rossen Stoyanchev + * @since 4.0.3 + */ +public class WebSocketSessionDecorator implements WebSocketSession { + + private final WebSocketSession delegate; + + + public WebSocketSessionDecorator(WebSocketSession session) { + Assert.notNull(session, "Delegate WebSocketSessionSession is required"); + this.delegate = session; + } + + + @Override + public String getId() { + return this.delegate.getId(); + } + + @Override + public URI getUri() { + return this.delegate.getUri(); + } + + @Override + public HttpHeaders getHandshakeHeaders() { + return this.delegate.getHandshakeHeaders(); + } + + @Override + public Map getAttributes() { + return this.delegate.getAttributes(); + } + + @Override + public Principal getPrincipal() { + return this.delegate.getPrincipal(); + } + + @Override + public InetSocketAddress getLocalAddress() { + return this.delegate.getLocalAddress(); + } + + @Override + public InetSocketAddress getRemoteAddress() { + return this.delegate.getRemoteAddress(); + } + + @Override + public String getAcceptedProtocol() { + return this.delegate.getAcceptedProtocol(); + } + + @Override + public List getExtensions() { + return this.delegate.getExtensions(); + } + + @Override + public boolean isOpen() { + return this.delegate.isOpen(); + } + + @Override + public void sendMessage(WebSocketMessage message) throws IOException { + + } + + @Override + public void close() throws IOException { + this.delegate.close(); + } + + @Override + public void close(CloseStatus status) throws IOException { + this.delegate.close(status); + } + + public WebSocketSession getDelegate() { + return this.delegate; + } + + public WebSocketSession getLastSession() { + WebSocketSession result = this.delegate; + while (result instanceof WebSocketSessionDecorator) { + result = ((WebSocketSessionDecorator) result).getDelegate(); + } + return result; + } + + @Override + public String toString() { + return getClass().getSimpleName() + " [delegate=" + this.delegate + "]"; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index fee754ced66..670544c580e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -42,6 +42,7 @@ import org.springframework.web.socket.SubProtocolCapable; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; /** * An implementation of {@link WebSocketHandler} that delegates incoming WebSocket @@ -74,6 +75,10 @@ public class SubProtocolWebSocketHandler private final Map sessions = new ConcurrentHashMap(); + private int sendTimeLimit = 20 * 1000; + + private int sendBufferSizeLimit = 1024 * 1024; + private Object lifecycleMonitor = new Object(); private volatile boolean running = false; @@ -155,6 +160,24 @@ public class SubProtocolWebSocketHandler return new ArrayList(this.protocolHandlers.keySet()); } + + public void setSendTimeLimit(int sendTimeLimit) { + this.sendTimeLimit = sendTimeLimit; + } + + public int getSendTimeLimit() { + return this.sendTimeLimit; + } + + public void setSendBufferSizeLimit(int sendBufferSizeLimit) { + this.sendBufferSizeLimit = sendBufferSizeLimit; + } + + public int getSendBufferSizeLimit() { + return sendBufferSizeLimit; + } + + @Override public boolean isAutoStartup() { return true; @@ -198,11 +221,15 @@ public class SubProtocolWebSocketHandler @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { + + session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit()); + this.sessions.put(session.getId(), session); if (logger.isDebugEnabled()) { logger.debug("Started WebSocket session=" + session.getId() + ", number of sessions=" + this.sessions.size()); } + findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java new file mode 100644 index 00000000000..21413c32c67 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java @@ -0,0 +1,220 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.handler; + +import org.junit.Test; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketMessage; + +import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for + * {@link org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator}. + * + * @author Rossen Stoyanchev + */ +public class ConcurrentWebSocketSessionDecoratorTests { + + + @Test + public void send() throws IOException { + + TestWebSocketSession session = new TestWebSocketSession(); + session.setOpen(true); + + ConcurrentWebSocketSessionDecorator concurrentSession = + new ConcurrentWebSocketSessionDecorator(session, 1000, 1024); + + TextMessage textMessage = new TextMessage("payload"); + concurrentSession.sendMessage(textMessage); + + assertEquals(1, session.getSentMessages().size()); + assertEquals(textMessage, session.getSentMessages().get(0)); + + assertEquals(0, concurrentSession.getBufferSize()); + assertEquals(0, concurrentSession.getInProgressSendTime()); + assertTrue(session.isOpen()); + } + + @Test + public void sendAfterBlockedSend() throws IOException, InterruptedException { + + BlockingSession blockingSession = new BlockingSession(); + blockingSession.setOpen(true); + CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); + + final ConcurrentWebSocketSessionDecorator concurrentSession = + new ConcurrentWebSocketSessionDecorator(blockingSession, 10 * 1000, 1024); + + Executors.newSingleThreadExecutor().submit(new Runnable() { + @Override + public void run() { + TextMessage textMessage = new TextMessage("slow message"); + try { + concurrentSession.sendMessage(textMessage); + } + catch (IOException e) { + e.printStackTrace(); + } + } + }); + + assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); + + // ensure some send time elapses + Thread.sleep(100); + assertTrue(concurrentSession.getInProgressSendTime() > 0); + + TextMessage payload = new TextMessage("payload"); + for (int i=0; i < 5; i++) { + concurrentSession.sendMessage(payload); + } + + assertTrue(concurrentSession.getInProgressSendTime() > 0); + assertEquals(5 * payload.getPayloadLength(), concurrentSession.getBufferSize()); + assertTrue(blockingSession.isOpen()); + } + + @Test + public void sendTimeLimitExceeded() throws IOException, InterruptedException { + + BlockingSession blockingSession = new BlockingSession(); + blockingSession.setOpen(true); + CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); + + int sendTimeLimit = 100; + int bufferSizeLimit = 1024; + + final ConcurrentWebSocketSessionDecorator concurrentSession = + new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit); + + Executors.newSingleThreadExecutor().submit(new Runnable() { + @Override + public void run() { + TextMessage textMessage = new TextMessage("slow message"); + try { + concurrentSession.sendMessage(textMessage); + } + catch (IOException e) { + e.printStackTrace(); + } + } + }); + + assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); + + // ensure some send time elapses + Thread.sleep(sendTimeLimit + 100); + + TextMessage payload = new TextMessage("payload"); + concurrentSession.sendMessage(payload); + + assertFalse(blockingSession.isOpen()); + } + + @Test + public void sendBufferSizeExceeded() throws IOException, InterruptedException { + + BlockingSession blockingSession = new BlockingSession(); + blockingSession.setOpen(true); + CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); + + int sendTimeLimit = 10 * 1000; + int bufferSizeLimit = 1024; + + final ConcurrentWebSocketSessionDecorator concurrentSession = + new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit); + + Executors.newSingleThreadExecutor().submit(new Runnable() { + @Override + public void run() { + TextMessage textMessage = new TextMessage("slow message"); + try { + concurrentSession.sendMessage(textMessage); + } + catch (IOException e) { + e.printStackTrace(); + } + } + }); + + assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); + + StringBuilder sb = new StringBuilder(); + for (int i=0 ; i < 1023; i++) { + sb.append("a"); + } + + TextMessage message = new TextMessage(sb.toString()); + concurrentSession.sendMessage(message); + + assertEquals(1023, concurrentSession.getBufferSize()); + assertTrue(blockingSession.isOpen()); + + concurrentSession.sendMessage(message); + assertFalse(blockingSession.isOpen()); + } + + + private static class BlockingSession extends TestWebSocketSession { + + private AtomicReference nextMessageLatch = new AtomicReference<>(); + + private AtomicReference releaseLatch = new AtomicReference<>(); + + + public CountDownLatch getSentMessageLatch() { + this.nextMessageLatch.set(new CountDownLatch(1)); + return this.nextMessageLatch.get(); + } + + @Override + public void sendMessage(WebSocketMessage message) throws IOException { + super.sendMessage(message); + if (this.nextMessageLatch != null) { + this.nextMessageLatch.get().countDown(); + } + block(); + } + + private void block() { + try { + this.releaseLatch.set(new CountDownLatch(1)); + this.releaseLatch.get().await(); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + } + + public void release() { + if (this.releaseLatch.get() != null) { + this.releaseLatch.get().countDown(); + } + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java index 399a7d1a36e..b021199e189 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,9 +21,12 @@ import java.util.Arrays; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.SubscribableChannel; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; import org.springframework.web.socket.handler.TestWebSocketSession; import static org.mockito.Mockito.*; @@ -71,7 +74,8 @@ public class SubProtocolWebSocketHandlerTests { this.session.setAcceptedProtocol("v12.sToMp"); this.webSocketHandler.afterConnectionEstablished(session); - verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel); + verify(this.stompHandler).afterSessionStarted( + isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel)); verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel); } @@ -81,7 +85,8 @@ public class SubProtocolWebSocketHandlerTests { this.session.setAcceptedProtocol("v12.sToMp"); this.webSocketHandler.afterConnectionEstablished(session); - verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel); + verify(this.stompHandler).afterSessionStarted( + isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel)); } @Test(expected=IllegalStateException.class) @@ -98,7 +103,8 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.setDefaultProtocolHandler(defaultHandler); this.webSocketHandler.afterConnectionEstablished(session); - verify(this.defaultHandler).afterSessionStarted(session, this.inClientChannel); + verify(this.defaultHandler).afterSessionStarted( + isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel)); verify(this.stompHandler, times(0)).afterSessionStarted(session, this.inClientChannel); verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel); } @@ -109,7 +115,8 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.setDefaultProtocolHandler(defaultHandler); this.webSocketHandler.afterConnectionEstablished(session); - verify(this.defaultHandler).afterSessionStarted(session, this.inClientChannel); + verify(this.defaultHandler).afterSessionStarted( + isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel)); verify(this.stompHandler, times(0)).afterSessionStarted(session, this.inClientChannel); verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.inClientChannel); } @@ -119,7 +126,8 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.setProtocolHandlers(Arrays.asList(stompHandler)); this.webSocketHandler.afterConnectionEstablished(session); - verify(this.stompHandler).afterSessionStarted(session, this.inClientChannel); + verify(this.stompHandler).afterSessionStarted( + isA(ConcurrentWebSocketSessionDecorator.class), eq(this.inClientChannel)); } @Test(expected=IllegalStateException.class)