@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2018 the original author or authors .
* Copyright 2002 - 2020 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -16,8 +16,13 @@
@@ -16,8 +16,13 @@
package org.springframework.messaging.simp.broker ;
import java.util.Collection ;
import java.util.HashSet ;
import java.util.Queue ;
import java.util.Set ;
import java.util.concurrent.ConcurrentHashMap ;
import java.util.concurrent.ConcurrentLinkedQueue ;
import java.util.concurrent.ConcurrentMap ;
import java.util.concurrent.atomic.AtomicBoolean ;
import org.apache.commons.logging.Log ;
@ -33,8 +38,14 @@ import org.springframework.messaging.support.MessageHeaderAccessor;
@@ -33,8 +38,14 @@ import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert ;
/ * *
* Submit messages to an { @link ExecutorSubscribableChannel } , one at a time .
* The channel must have been configured with { @link # configureOutboundChannel } .
* { @code MessageChannel } decorator that ensures messages from the same session
* are sent and processed in the same order . This would not normally be the case
* with an { @code Executor } backed { @code MessageChannel } since the executor
* is free to submit tasks in any order .
*
* < p > To provide ordering , inbound messages are placed in a queue and sent one
* one at a time per session . Once a message is processed , a callback is used to
* notify that the next message from the same session can be sent through .
*
* @author Rossen Stoyanchev
* @since 5 . 1
@ -48,9 +59,7 @@ class OrderedMessageSender implements MessageChannel {
@@ -48,9 +59,7 @@ class OrderedMessageSender implements MessageChannel {
private final Log logger ;
private final Queue < Message < ? > > messages = new ConcurrentLinkedQueue < > ( ) ;
private final AtomicBoolean sendInProgress = new AtomicBoolean ( false ) ;
private final Control control = new Control ( ) ;
public OrderedMessageSender ( MessageChannel channel , Log logger ) {
@ -66,30 +75,40 @@ class OrderedMessageSender implements MessageChannel {
@@ -66,30 +75,40 @@ class OrderedMessageSender implements MessageChannel {
@Override
public boolean send ( Message < ? > message , long timeout ) {
this . messages . add ( message ) ;
this . control . addMessage ( message ) ;
trySend ( ) ;
return true ;
}
private void trySend ( ) {
// Take sendInProgress flag only if queue is not empty
if ( this . messages . isEmpty ( ) ) {
return ;
}
if ( this . sendInProgress . compareAndSet ( false , true ) ) {
sendNextMessage ( ) ;
if ( this . control . acquireSendLock ( ) ) {
sendMessages ( ) ;
}
}
private void sendNextMessage ( ) {
for ( ; ; ) {
Message < ? > message = this . messages . poll ( ) ;
if ( message ! = null ) {
private void sendMessages ( ) {
for ( ; ; ) {
Set < String > skipSet = new HashSet < > ( ) ;
for ( Message < ? > message : this . control . getMessagesToSend ( ) ) {
String sessionId = SimpMessageHeaderAccessor . getSessionId ( message . getHeaders ( ) ) ;
Assert . notNull ( sessionId , ( ) - > "No session id in " + message . getHeaders ( ) ) ;
if ( skipSet . contains ( sessionId ) ) {
continue ;
}
if ( ! this . control . acquireSessionLock ( sessionId ) ) {
skipSet . add ( sessionId ) ;
continue ;
}
this . control . removeMessage ( message ) ;
try {
addCompletionCallback ( message ) ;
getMutableAccessor ( message ) . setHeader ( COMPLETION_TASK_HEADER , ( Runnable ) ( ) - > {
this . control . releaseSessionLock ( sessionId ) ;
if ( this . control . hasRemainingWork ( ) ) {
trySend ( ) ;
}
} ) ;
if ( this . channel . send ( message ) ) {
return ;
continue ;
}
}
catch ( Throwable ex ) {
@ -97,20 +116,24 @@ class OrderedMessageSender implements MessageChannel {
@@ -97,20 +116,24 @@ class OrderedMessageSender implements MessageChannel {
logger . error ( "Failed to send " + message , ex ) ;
}
}
// We didn't send
this . control . releaseSessionLock ( sessionId ) ;
}
else {
// We ran out of messages..
this . sendInProgress . set ( false ) ;
trySend ( ) ;
break ;
if ( this . control . shouldYield ( ) ) {
this . control . releaseSendLock ( ) ;
if ( ! this . control . shouldYield ( ) ) {
trySend ( ) ;
}
return ;
}
}
}
private void addCompletionCallback ( Message < ? > msg ) {
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor . getAccessor ( msg , SimpMessageHeaderAccessor . class ) ;
private SimpMessageHeaderAccessor getMutableAccessor ( Message < ? > mes sa ge ) {
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor . getAccessor ( mes sa ge , SimpMessageHeaderAccessor . class ) ;
Assert . isTrue ( accessor ! = null & & accessor . isMutable ( ) , "Expected mutable SimpMessageHeaderAccessor" ) ;
accessor . setHeader ( COMPLETION_TASK_HEADER , ( Runnable ) this : : sendNextMessage ) ;
return accessor ;
}
@ -126,13 +149,13 @@ class OrderedMessageSender implements MessageChannel {
@@ -126,13 +149,13 @@ class OrderedMessageSender implements MessageChannel {
Assert . isInstanceOf ( ExecutorSubscribableChannel . class , channel ,
"An ExecutorSubscribableChannel is required for `preservePublishOrder`" ) ;
ExecutorSubscribableChannel execChannel = ( ExecutorSubscribableChannel ) channel ;
if ( execChannel . getInterceptors ( ) . stream ( ) . noneMatch ( i - > i instanceof Callbac kInterceptor ) ) {
execChannel . addInterceptor ( 0 , new Callbac kInterceptor ( ) ) ;
if ( execChannel . getInterceptors ( ) . stream ( ) . noneMatch ( i - > i instanceof CompletionTas kInterceptor ) ) {
execChannel . addInterceptor ( 0 , new CompletionTas kInterceptor ( ) ) ;
}
}
else if ( channel instanceof ExecutorSubscribableChannel ) {
ExecutorSubscribableChannel execChannel = ( ExecutorSubscribableChannel ) channel ;
execChannel . getInterceptors ( ) . stream ( ) . filter ( i - > i instanceof Callbac kInterceptor )
execChannel . getInterceptors ( ) . stream ( ) . filter ( i - > i instanceof CompletionTas kInterceptor )
. findFirst ( )
. map ( execChannel : : removeInterceptor ) ;
@ -140,13 +163,71 @@ class OrderedMessageSender implements MessageChannel {
@@ -140,13 +163,71 @@ class OrderedMessageSender implements MessageChannel {
}
private static class CallbackInterceptor implements ExecutorChannelInterceptor {
/ * *
* Provides locks required for ordered message sending and execution within
* a session as well as storage for messages waiting to be sent .
* /
private static class Control {
private final Queue < Message < ? > > messages = new ConcurrentLinkedQueue < > ( ) ;
private final ConcurrentMap < String , Boolean > sessionsInProgress = new ConcurrentHashMap < > ( ) ;
private final AtomicBoolean workInProgress = new AtomicBoolean ( false ) ;
public void addMessage ( Message < ? > message ) {
this . messages . add ( message ) ;
}
public void removeMessage ( Message < ? > message ) {
if ( ! this . messages . remove ( message ) ) {
throw new IllegalStateException (
"Message " + message . getHeaders ( ) + " was expected in the queue." ) ;
}
}
public Collection < Message < ? > > getMessagesToSend ( ) {
return this . messages ;
}
public boolean acquireSendLock ( ) {
return this . workInProgress . compareAndSet ( false , true ) ;
}
public void releaseSendLock ( ) {
this . workInProgress . set ( false ) ;
}
public boolean acquireSessionLock ( String sessionId ) {
if ( this . sessionsInProgress . put ( sessionId , Boolean . TRUE ) ! = null ) {
return false ;
}
return true ;
}
public void releaseSessionLock ( String sessionId ) {
this . sessionsInProgress . remove ( sessionId ) ;
}
public boolean hasRemainingWork ( ) {
return ! this . messages . isEmpty ( ) ;
}
public boolean shouldYield ( ) {
// No remaining work, or others can pick it up
return ( ! hasRemainingWork ( ) | | this . sessionsInProgress . size ( ) > 0 ) ;
}
}
private static class CompletionTaskInterceptor implements ExecutorChannelInterceptor {
@Override
public void afterMessageHandled (
Message < ? > msg , MessageChannel ch , MessageHandler handler , @Nullable Exception ex ) {
Message < ? > mes sa ge , MessageChannel ch , MessageHandler handler , @Nullable Exception ex ) {
Runnable task = ( Runnable ) msg . getHeaders ( ) . get ( OrderedMessageSender . COMPLETION_TASK_HEADER ) ;
Runnable task = ( Runnable ) mes sa ge . getHeaders ( ) . get ( OrderedMessageSender . COMPLETION_TASK_HEADER ) ;
if ( task ! = null ) {
task . run ( ) ;
}