@ -33,12 +33,11 @@ import org.springframework.util.Assert;
@@ -33,12 +33,11 @@ import org.springframework.util.Assert;
import org.springframework.util.StringUtils ;
import org.springframework.web.messaging.MessageType ;
import org.springframework.web.messaging.PubSubChannelRegistry ;
import org.springframework.web.messaging.PubSubHeaders ;
import org.springframework.web.messaging.converter.CompositeMessageConverter ;
import org.springframework.web.messaging.converter.MessageConverter ;
import org.springframework.web.messaging.service.AbstractPubSubMessageHandler ;
import org.springframework.web.messaging.stomp.StompCommand ;
import org.springframework.web.messaging.stomp.StompHeaders ;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor ;
import reactor.core.Environment ;
import reactor.core.Promise ;
@ -97,7 +96,7 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@@ -97,7 +96,7 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@Override
public void handleConnect ( M message ) {
StompHeaders stompHeaders = StompHeaders . fromMessageHeaders ( message . getHeaders ( ) ) ;
StompHeaderAcce ssor stompHeaders = StompHeaderAccessor . wrap ( message ) ;
String sessionId = stompHeaders . getSessionId ( ) ;
if ( sessionId = = null ) {
logger . error ( "No sessionId in message " + message ) ;
@ -124,7 +123,7 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@@ -124,7 +123,7 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@Override
public void handleDisconnect ( M message ) {
StompHeaders stompHeaders = StompHeaders . fromMessageHeaders ( message . getHeaders ( ) ) ;
StompHeaderAcce ssor stompHeaders = StompHeaderAccessor . wrap ( message ) ;
if ( stompHeaders . getStompCommand ( ) ! = null ) {
forwardMessage ( message , StompCommand . DISCONNECT ) ;
}
@ -137,14 +136,14 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@@ -137,14 +136,14 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@Override
public void handleOther ( M message ) {
StompCommand command = ( StompCommand ) message . getHeaders ( ) . get ( PubSubHeaders . PROTOCOL_MESSAGE_TYPE ) ;
StompCommand command = ( StompCommand ) message . getHeaders ( ) . get ( PubSubHeaderAcce sssor . PROTOCOL_MESSAGE_TYPE ) ;
Assert . notNull ( command , "Expected STOMP command: " + message . getHeaders ( ) ) ;
forwardMessage ( message , command ) ;
}
private void forwardMessage ( M message , StompCommand command ) {
StompHeaders headers = StompHeaders . fromMessageHeaders ( message . getHeaders ( ) ) ;
StompHeaderAcce ssor headers = StompHeaderAccessor . wrap ( message ) ;
headers . setStompCommandIfNotSet ( command ) ;
String sessionId = headers . getSessionId ( ) ;
@ -174,9 +173,10 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@@ -174,9 +173,10 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
private final Object monitor = new Object ( ) ;
private boolean isConnected = false ;
private volatile boolean isConnected = false ;
public RelaySession ( final M message , final StompHeaders stompHeaders ) {
public RelaySession ( final M message , final StompHeaderAccessor stompHeaders ) {
Assert . notNull ( message , "message is required" ) ;
Assert . notNull ( stompHeaders , "stompHeaders is required" ) ;
@ -222,7 +222,7 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@@ -222,7 +222,7 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
logger . trace ( "Reading message " + message ) ;
}
StompHeaders headers = StompHeaders . fromMessageHeaders ( message . getHeaders ( ) ) ;
StompHeaderAcce ssor headers = StompHeaderAccessor . wrap ( message ) ;
if ( StompCommand . CONNECTED = = headers . getStompCommand ( ) ) {
synchronized ( this . monitor ) {
this . isConnected = true ;
@ -240,15 +240,15 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@@ -240,15 +240,15 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
}
private void sendError ( String sessionId , String errorText ) {
StompHeaders stompH eaders = StompHeaders . create ( StompCommand . ERROR ) ;
stompH eaders. setSessionId ( sessionId ) ;
stompH eaders. setMessage ( errorText ) ;
StompHeaderAccessor h eaders = StompHeaderAcce ssor . create ( StompCommand . ERROR ) ;
h eaders. setSessionId ( sessionId ) ;
h eaders. setMessage ( errorText ) ;
@SuppressWarnings ( "unchecked" )
M errorMessage = ( M ) MessageBuilder . fromPayloadAndHeaders ( new byte [ 0 ] , stompHeaders . toMessage Headers( ) ) . build ( ) ;
M errorMessage = ( M ) MessageBuilder . withPayload ( new byte [ 0 ] ) . copyHeaders ( headers . to Headers( ) ) . build ( ) ;
clientChannel . send ( errorMessage ) ;
}
public void forward ( M message , StompHeaders headers ) {
public void forward ( M message , StompHeaderAcce ssor headers ) {
if ( ! this . isConnected ) {
synchronized ( this . monitor ) {
@ -277,21 +277,21 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@@ -277,21 +277,21 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
List < M > messages = new ArrayList < M > ( ) ;
this . messageQueue . drainTo ( messages ) ;
for ( Message < ? > message : messages ) {
StompHeaders headers = StompHeaders . fromMessageHeaders ( message . getHeaders ( ) ) ;
StompHeaderAcce ssor headers = StompHeaderAccessor . wrap ( message ) ;
if ( ! forwardInternal ( message , headers , connection ) ) {
return ;
}
}
}
private boolean forwardInternal ( Message < ? > message , StompHeaders headers , TcpConnection < String , String > connection ) {
private boolean forwardInternal ( Message < ? > message , StompHeaderAcce ssor headers , TcpConnection < String , String > connection ) {
try {
headers . setStompCommandIfNotSet ( StompCommand . SEND ) ;
MediaType contentType = headers . getContentType ( ) ;
byte [ ] payload = payloadConverter . convertToPayload ( message . getPayload ( ) , contentType ) ;
@SuppressWarnings ( "unchecked" )
M byteMessage = ( M ) MessageBuilder . fromPayloadAndHeaders ( payload , headers . toMessage Headers ( ) ) . build ( ) ;
M byteMessage = ( M ) MessageBuilder . withPayload ( payload ) . copyHeaders ( headers . toHeaders ( ) ) . build ( ) ;
if ( logger . isTraceEnabled ( ) ) {
logger . trace ( "Forwarding message " + byteMessage ) ;