6 changed files with 542 additions and 6 deletions
@ -0,0 +1,137 @@
@@ -0,0 +1,137 @@
|
||||
/* |
||||
* Copyright 2002-2014 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package org.springframework.web.socket.handler; |
||||
|
||||
import org.apache.commons.logging.Log; |
||||
import org.apache.commons.logging.LogFactory; |
||||
import org.springframework.web.socket.WebSocketMessage; |
||||
import org.springframework.web.socket.WebSocketSession; |
||||
|
||||
import java.io.IOException; |
||||
import java.util.Queue; |
||||
import java.util.concurrent.LinkedBlockingQueue; |
||||
import java.util.concurrent.atomic.AtomicInteger; |
||||
import java.util.concurrent.locks.Lock; |
||||
import java.util.concurrent.locks.ReentrantLock; |
||||
|
||||
|
||||
/** |
||||
* Wraps a {@link org.springframework.web.socket.WebSocketSession} and guarantees |
||||
* only one thread can send messages at a time. |
||||
* |
||||
* <p>If a send is slow, subsequent attempts to send more messages from a different |
||||
* thread will fail to acquire the lock and the messages will be buffered instead -- |
||||
* at that time the specified buffer size limit and send time limit will be checked |
||||
* and the session closed if the limits are exceeded. |
||||
* |
||||
* @author Rossen Stoyanchev |
||||
* @since 4.0.3 |
||||
*/ |
||||
public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorator { |
||||
|
||||
private static Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class); |
||||
|
||||
|
||||
private final int sendTimeLimit; |
||||
|
||||
private final int bufferSizeLimit; |
||||
|
||||
private final Queue<WebSocketMessage<?>> buffer = new LinkedBlockingQueue<WebSocketMessage<?>>(); |
||||
|
||||
private final AtomicInteger bufferSize = new AtomicInteger(); |
||||
|
||||
private volatile long sendStartTime; |
||||
|
||||
private final Lock lock = new ReentrantLock(); |
||||
|
||||
|
||||
public ConcurrentWebSocketSessionDecorator( |
||||
WebSocketSession delegateSession, int sendTimeLimit, int bufferSizeLimit) { |
||||
|
||||
super(delegateSession); |
||||
this.sendTimeLimit = sendTimeLimit; |
||||
this.bufferSizeLimit = bufferSizeLimit; |
||||
} |
||||
|
||||
|
||||
public int getBufferSize() { |
||||
return this.bufferSize.get(); |
||||
} |
||||
|
||||
public long getInProgressSendTime() { |
||||
long start = this.sendStartTime; |
||||
return (start > 0 ? (System.currentTimeMillis() - start) : 0); |
||||
} |
||||
|
||||
|
||||
public void sendMessage(WebSocketMessage<?> message) throws IOException { |
||||
|
||||
this.buffer.add(message); |
||||
this.bufferSize.addAndGet(message.getPayloadLength()); |
||||
|
||||
do { |
||||
if (!tryFlushMessageBuffer()) { |
||||
checkSessionLimits(); |
||||
break; |
||||
} |
||||
} |
||||
while (!this.buffer.isEmpty()); |
||||
} |
||||
|
||||
private boolean tryFlushMessageBuffer() throws IOException { |
||||
|
||||
if (this.lock.tryLock()) { |
||||
try { |
||||
while (true) { |
||||
WebSocketMessage<?> messageToSend = this.buffer.poll(); |
||||
if (messageToSend == null) { |
||||
break; |
||||
} |
||||
this.bufferSize.addAndGet(messageToSend.getPayloadLength() * -1); |
||||
this.sendStartTime = System.currentTimeMillis(); |
||||
getDelegate().sendMessage(messageToSend); |
||||
this.sendStartTime = 0; |
||||
} |
||||
} |
||||
finally { |
||||
this.sendStartTime = 0; |
||||
lock.unlock(); |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
return false; |
||||
} |
||||
|
||||
private void checkSessionLimits() throws IOException { |
||||
if (getInProgressSendTime() > this.sendTimeLimit) { |
||||
logError("A message could not be sent due to a timeout"); |
||||
getDelegate().close(); |
||||
} |
||||
else if (this.bufferSize.get() > this.bufferSizeLimit) { |
||||
logError("The total send buffer byte count '" + this.bufferSize.get() + |
||||
"' for session '" + getId() + "' exceeds the allowed limit '" + this.bufferSizeLimit + "'"); |
||||
getDelegate().close(); |
||||
} |
||||
} |
||||
|
||||
private void logError(String reason) { |
||||
logger.error(reason + ", number of buffered messages is '" + this.buffer.size() + |
||||
"', time since the last send started is '" + getInProgressSendTime() + "' (ms)"); |
||||
} |
||||
|
||||
} |
||||
@ -0,0 +1,137 @@
@@ -0,0 +1,137 @@
|
||||
/* |
||||
* Copyright 2002-2014 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package org.springframework.web.socket.handler; |
||||
|
||||
import org.springframework.http.HttpHeaders; |
||||
import org.springframework.util.Assert; |
||||
import org.springframework.web.socket.CloseStatus; |
||||
import org.springframework.web.socket.WebSocketExtension; |
||||
import org.springframework.web.socket.WebSocketMessage; |
||||
import org.springframework.web.socket.WebSocketSession; |
||||
|
||||
import java.io.IOException; |
||||
import java.net.InetSocketAddress; |
||||
import java.net.URI; |
||||
import java.security.Principal; |
||||
import java.util.List; |
||||
import java.util.Map; |
||||
|
||||
/** |
||||
* Wraps another {@link org.springframework.web.socket.WebSocketSession} instance |
||||
* and delegates to it. |
||||
* |
||||
* <p>Also provides a {@link #getDelegate()} method to return the decorated session |
||||
* as well as a {@link #getLastSession()} method to go through all nested delegates |
||||
* and return the "last" session. |
||||
* |
||||
* @author Rossen Stoyanchev |
||||
* @since 4.0.3 |
||||
*/ |
||||
public class WebSocketSessionDecorator implements WebSocketSession { |
||||
|
||||
private final WebSocketSession delegate; |
||||
|
||||
|
||||
public WebSocketSessionDecorator(WebSocketSession session) { |
||||
Assert.notNull(session, "Delegate WebSocketSessionSession is required"); |
||||
this.delegate = session; |
||||
} |
||||
|
||||
|
||||
@Override |
||||
public String getId() { |
||||
return this.delegate.getId(); |
||||
} |
||||
|
||||
@Override |
||||
public URI getUri() { |
||||
return this.delegate.getUri(); |
||||
} |
||||
|
||||
@Override |
||||
public HttpHeaders getHandshakeHeaders() { |
||||
return this.delegate.getHandshakeHeaders(); |
||||
} |
||||
|
||||
@Override |
||||
public Map<String, Object> getAttributes() { |
||||
return this.delegate.getAttributes(); |
||||
} |
||||
|
||||
@Override |
||||
public Principal getPrincipal() { |
||||
return this.delegate.getPrincipal(); |
||||
} |
||||
|
||||
@Override |
||||
public InetSocketAddress getLocalAddress() { |
||||
return this.delegate.getLocalAddress(); |
||||
} |
||||
|
||||
@Override |
||||
public InetSocketAddress getRemoteAddress() { |
||||
return this.delegate.getRemoteAddress(); |
||||
} |
||||
|
||||
@Override |
||||
public String getAcceptedProtocol() { |
||||
return this.delegate.getAcceptedProtocol(); |
||||
} |
||||
|
||||
@Override |
||||
public List<WebSocketExtension> getExtensions() { |
||||
return this.delegate.getExtensions(); |
||||
} |
||||
|
||||
@Override |
||||
public boolean isOpen() { |
||||
return this.delegate.isOpen(); |
||||
} |
||||
|
||||
@Override |
||||
public void sendMessage(WebSocketMessage<?> message) throws IOException { |
||||
|
||||
} |
||||
|
||||
@Override |
||||
public void close() throws IOException { |
||||
this.delegate.close(); |
||||
} |
||||
|
||||
@Override |
||||
public void close(CloseStatus status) throws IOException { |
||||
this.delegate.close(status); |
||||
} |
||||
|
||||
public WebSocketSession getDelegate() { |
||||
return this.delegate; |
||||
} |
||||
|
||||
public WebSocketSession getLastSession() { |
||||
WebSocketSession result = this.delegate; |
||||
while (result instanceof WebSocketSessionDecorator) { |
||||
result = ((WebSocketSessionDecorator) result).getDelegate(); |
||||
} |
||||
return result; |
||||
} |
||||
|
||||
@Override |
||||
public String toString() { |
||||
return getClass().getSimpleName() + " [delegate=" + this.delegate + "]"; |
||||
} |
||||
|
||||
} |
||||
@ -0,0 +1,220 @@
@@ -0,0 +1,220 @@
|
||||
/* |
||||
* Copyright 2002-2014 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package org.springframework.web.socket.handler; |
||||
|
||||
import org.junit.Test; |
||||
import org.springframework.web.socket.TextMessage; |
||||
import org.springframework.web.socket.WebSocketMessage; |
||||
|
||||
import java.io.IOException; |
||||
import java.util.concurrent.CountDownLatch; |
||||
import java.util.concurrent.Executors; |
||||
import java.util.concurrent.TimeUnit; |
||||
import java.util.concurrent.atomic.AtomicReference; |
||||
|
||||
import static org.junit.Assert.assertEquals; |
||||
import static org.junit.Assert.assertFalse; |
||||
import static org.junit.Assert.assertTrue; |
||||
|
||||
/** |
||||
* Unit tests for |
||||
* {@link org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator}. |
||||
* |
||||
* @author Rossen Stoyanchev |
||||
*/ |
||||
public class ConcurrentWebSocketSessionDecoratorTests { |
||||
|
||||
|
||||
@Test |
||||
public void send() throws IOException { |
||||
|
||||
TestWebSocketSession session = new TestWebSocketSession(); |
||||
session.setOpen(true); |
||||
|
||||
ConcurrentWebSocketSessionDecorator concurrentSession = |
||||
new ConcurrentWebSocketSessionDecorator(session, 1000, 1024); |
||||
|
||||
TextMessage textMessage = new TextMessage("payload"); |
||||
concurrentSession.sendMessage(textMessage); |
||||
|
||||
assertEquals(1, session.getSentMessages().size()); |
||||
assertEquals(textMessage, session.getSentMessages().get(0)); |
||||
|
||||
assertEquals(0, concurrentSession.getBufferSize()); |
||||
assertEquals(0, concurrentSession.getInProgressSendTime()); |
||||
assertTrue(session.isOpen()); |
||||
} |
||||
|
||||
@Test |
||||
public void sendAfterBlockedSend() throws IOException, InterruptedException { |
||||
|
||||
BlockingSession blockingSession = new BlockingSession(); |
||||
blockingSession.setOpen(true); |
||||
CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); |
||||
|
||||
final ConcurrentWebSocketSessionDecorator concurrentSession = |
||||
new ConcurrentWebSocketSessionDecorator(blockingSession, 10 * 1000, 1024); |
||||
|
||||
Executors.newSingleThreadExecutor().submit(new Runnable() { |
||||
@Override |
||||
public void run() { |
||||
TextMessage textMessage = new TextMessage("slow message"); |
||||
try { |
||||
concurrentSession.sendMessage(textMessage); |
||||
} |
||||
catch (IOException e) { |
||||
e.printStackTrace(); |
||||
} |
||||
} |
||||
}); |
||||
|
||||
assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); |
||||
|
||||
// ensure some send time elapses
|
||||
Thread.sleep(100); |
||||
assertTrue(concurrentSession.getInProgressSendTime() > 0); |
||||
|
||||
TextMessage payload = new TextMessage("payload"); |
||||
for (int i=0; i < 5; i++) { |
||||
concurrentSession.sendMessage(payload); |
||||
} |
||||
|
||||
assertTrue(concurrentSession.getInProgressSendTime() > 0); |
||||
assertEquals(5 * payload.getPayloadLength(), concurrentSession.getBufferSize()); |
||||
assertTrue(blockingSession.isOpen()); |
||||
} |
||||
|
||||
@Test |
||||
public void sendTimeLimitExceeded() throws IOException, InterruptedException { |
||||
|
||||
BlockingSession blockingSession = new BlockingSession(); |
||||
blockingSession.setOpen(true); |
||||
CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); |
||||
|
||||
int sendTimeLimit = 100; |
||||
int bufferSizeLimit = 1024; |
||||
|
||||
final ConcurrentWebSocketSessionDecorator concurrentSession = |
||||
new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit); |
||||
|
||||
Executors.newSingleThreadExecutor().submit(new Runnable() { |
||||
@Override |
||||
public void run() { |
||||
TextMessage textMessage = new TextMessage("slow message"); |
||||
try { |
||||
concurrentSession.sendMessage(textMessage); |
||||
} |
||||
catch (IOException e) { |
||||
e.printStackTrace(); |
||||
} |
||||
} |
||||
}); |
||||
|
||||
assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); |
||||
|
||||
// ensure some send time elapses
|
||||
Thread.sleep(sendTimeLimit + 100); |
||||
|
||||
TextMessage payload = new TextMessage("payload"); |
||||
concurrentSession.sendMessage(payload); |
||||
|
||||
assertFalse(blockingSession.isOpen()); |
||||
} |
||||
|
||||
@Test |
||||
public void sendBufferSizeExceeded() throws IOException, InterruptedException { |
||||
|
||||
BlockingSession blockingSession = new BlockingSession(); |
||||
blockingSession.setOpen(true); |
||||
CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); |
||||
|
||||
int sendTimeLimit = 10 * 1000; |
||||
int bufferSizeLimit = 1024; |
||||
|
||||
final ConcurrentWebSocketSessionDecorator concurrentSession = |
||||
new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit); |
||||
|
||||
Executors.newSingleThreadExecutor().submit(new Runnable() { |
||||
@Override |
||||
public void run() { |
||||
TextMessage textMessage = new TextMessage("slow message"); |
||||
try { |
||||
concurrentSession.sendMessage(textMessage); |
||||
} |
||||
catch (IOException e) { |
||||
e.printStackTrace(); |
||||
} |
||||
} |
||||
}); |
||||
|
||||
assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); |
||||
|
||||
StringBuilder sb = new StringBuilder(); |
||||
for (int i=0 ; i < 1023; i++) { |
||||
sb.append("a"); |
||||
} |
||||
|
||||
TextMessage message = new TextMessage(sb.toString()); |
||||
concurrentSession.sendMessage(message); |
||||
|
||||
assertEquals(1023, concurrentSession.getBufferSize()); |
||||
assertTrue(blockingSession.isOpen()); |
||||
|
||||
concurrentSession.sendMessage(message); |
||||
assertFalse(blockingSession.isOpen()); |
||||
} |
||||
|
||||
|
||||
private static class BlockingSession extends TestWebSocketSession { |
||||
|
||||
private AtomicReference<CountDownLatch> nextMessageLatch = new AtomicReference<>(); |
||||
|
||||
private AtomicReference<CountDownLatch> releaseLatch = new AtomicReference<>(); |
||||
|
||||
|
||||
public CountDownLatch getSentMessageLatch() { |
||||
this.nextMessageLatch.set(new CountDownLatch(1)); |
||||
return this.nextMessageLatch.get(); |
||||
} |
||||
|
||||
@Override |
||||
public void sendMessage(WebSocketMessage<?> message) throws IOException { |
||||
super.sendMessage(message); |
||||
if (this.nextMessageLatch != null) { |
||||
this.nextMessageLatch.get().countDown(); |
||||
} |
||||
block(); |
||||
} |
||||
|
||||
private void block() { |
||||
try { |
||||
this.releaseLatch.set(new CountDownLatch(1)); |
||||
this.releaseLatch.get().await(); |
||||
} |
||||
catch (InterruptedException e) { |
||||
e.printStackTrace(); |
||||
} |
||||
} |
||||
|
||||
public void release() { |
||||
if (this.releaseLatch.get() != null) { |
||||
this.releaseLatch.get().countDown(); |
||||
} |
||||
} |
||||
} |
||||
|
||||
} |
||||
Loading…
Reference in new issue