@ -16,21 +16,12 @@
@@ -16,21 +16,12 @@
package org.springframework.web.messaging.stomp.support ;
import java.io.BufferedInputStream ;
import java.io.BufferedOutputStream ;
import java.io.ByteArrayOutputStream ;
import java.io.IOException ;
import java.io.InputStream ;
import java.io.OutputStream ;
import java.net.Socket ;
import java.nio.charset.Charset ;
import java.util.Collection ;
import java.util.List ;
import java.util.Map ;
import java.util.concurrent.ConcurrentHashMap ;
import javax.net.SocketFactory ;
import org.springframework.core.task.TaskExecutor ;
import org.springframework.http.MediaType ;
import org.springframework.messaging.GenericMessage ;
import org.springframework.messaging.Message ;
@ -45,6 +36,15 @@ import org.springframework.web.messaging.service.AbstractPubSubMessageHandler;
@@ -45,6 +36,15 @@ import org.springframework.web.messaging.service.AbstractPubSubMessageHandler;
import org.springframework.web.messaging.stomp.StompCommand ;
import org.springframework.web.messaging.stomp.StompHeaders ;
import reactor.core.Environment ;
import reactor.core.Promise ;
import reactor.fn.Consumer ;
import reactor.tcp.TcpClient ;
import reactor.tcp.TcpConnection ;
import reactor.tcp.encoding.DelimitedCodec ;
import reactor.tcp.encoding.StandardCodecs ;
import reactor.tcp.netty.NettyTcpClient ;
/ * *
* @author Rossen Stoyanchev
@ -57,19 +57,22 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
@@ -57,19 +57,22 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
private MessageConverter payloadConverter ;
private final TaskExecutor taskExecutor ;
private final TcpClient < String , String > tcpClient ;
private Map < String , RelaySession > relaySessions = new ConcurrentHashMap < String , RelaySession > ( ) ;
private final Map < String , TcpConnection < String , String > > connections =
new ConcurrentHashMap < String , TcpConnection < String , String > > ( ) ;
/ * *
* @param executor
* /
public StompRelayPubSubMessageHandler ( SubscribableChannel publishChannel , MessageChannel clientChannel ,
TaskExecutor executor ) {
public StompRelayPubSubMessageHandler ( SubscribableChannel publishChannel , MessageChannel clientChannel ) {
super ( publishChannel , clientChannel ) ;
this . taskExecutor = executor ; // For now, a naive way to manage socket reading
this . tcpClient = new TcpClient . Spec < String , String > ( NettyTcpClient . class )
. using ( new Environment ( ) )
. codec ( new DelimitedCodec < String , String > ( ( byte ) 0 , StandardCodecs . STRING_CODEC ) )
. connect ( "127.0.0.1" , 61613 )
. get ( ) ;
this . payloadConverter = new CompositeMessageConverter ( null ) ;
}
@ -84,34 +87,52 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
@@ -84,34 +87,52 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
}
@Override
public void handleConnect ( Message < ? > message ) {
public void handleConnect ( final Message < ? > message ) {
String sessionId = ( String ) message . getHeaders ( ) . get ( PubSubHeaders . SESSION_ID ) ;
final String sessionId = ( String ) message . getHeaders ( ) . get ( PubSubHeaders . SESSION_ID ) ;
RelaySession session = new RelaySession ( ) ;
this . relaySessions . put ( sessionId , session ) ;
Promise < TcpConnection < String , String > > promise = this . tcpClient . open ( ) ;
try {
Socket socket = SocketFactory . getDefault ( ) . createSocket ( "127.0.0.1" , 61613 ) ;
session . setSocket ( socket ) ;
promise . onSuccess ( new Consumer < TcpConnection < String , String > > ( ) {
@Override
public void accept ( TcpConnection < String , String > connection ) {
connections . put ( sessionId , connection ) ;
forwardMessage ( message , StompCommand . CONNECT ) ;
}
} ) ;
promise . consume ( new Consumer < TcpConnection < String , String > > ( ) {
@Override
public void accept ( TcpConnection < String , String > connection ) {
connection . in ( ) . consume ( new Consumer < String > ( ) {
@Override
public void accept ( String stompFrame ) {
if ( stompFrame . isEmpty ( ) ) {
// TODO: why are we getting empty frames?
return ;
}
Message < byte [ ] > message = stompMessageConverter . toMessage ( stompFrame , sessionId ) ;
getClientChannel ( ) . send ( message ) ;
}
} ) ;
}
} ) ;
forwardMessage ( message , StompCommand . CONNECT ) ;
// TODO: ATM no way to detect closed socket
// StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR);
// stompHeaders.setMessage("Socket closed, STOMP session=" + sessionId);
// stompHeaders.setSessionId(sessionId);
// Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.toMessageHeaders());
// getClientChannel().send(errorMessage);
RelayReadTask readTask = new RelayReadTask ( sessionId , session ) ;
this . taskExecutor . execute ( readTask ) ;
}
catch ( Throwable t ) {
t . printStackTrace ( ) ;
clearRelaySession ( sessionId ) ;
}
}
private void forwardMessage ( Message < ? > message , StompCommand command ) {
StompHeaders stompHeaders = StompHeaders . fromMessageHeaders ( message . getHeaders ( ) ) ;
String sessionId = stompHeaders . getSessionId ( ) ;
RelaySession session = StompRelayPubSubMessageHandler . this . relaySessions . get ( sessionId ) ;
Assert . notNull ( session , "RelaySession not found" ) ;
byte [ ] bytesToWrite ;
try {
stompHeaders . setStompCommandIfNotSet ( StompCommand . SEND ) ;
@ -119,32 +140,46 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
@@ -119,32 +140,46 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
MediaType contentType = stompHeaders . getContentType ( ) ;
byte [ ] payload = this . payloadConverter . convertToPayload ( message . getPayload ( ) , contentType ) ;
Message < byte [ ] > byteMessage = new GenericMessage < byte [ ] > ( payload , stompHeaders . toMessageHeaders ( ) ) ;
bytesToWrite = this . stompMessageConverter . fromMessage ( byteMessage ) ;
}
catch ( Throwable ex ) {
logger . error ( "Failed to forward message " + message , ex ) ;
return ;
}
TcpConnection < String , String > connection = getConnection ( sessionId ) ;
Assert . notNull ( connection , "TCP connection to message broker not found, sessionId=" + sessionId ) ;
try {
if ( logger . isTraceEnabled ( ) ) {
logger . trace ( "Forwarding STOMP " + stompHeaders . getStompCommand ( ) + " message" ) ;
}
byte [ ] bytes = this . stompMessageConverter . fromMessage ( byteMessage ) ;
session . getOutputStream ( ) . write ( bytes ) ;
session . getOutputStream ( ) . flush ( ) ;
connection . out ( ) . accept ( new String ( bytesToWrite , Charset . forName ( "UTF-8" ) ) ) ;
}
catch ( Exception ex ) {
logger . error ( "Couldn't forward message " + message , ex ) ;
clearRelaySession ( sessionId ) ;
catch ( Throwable ex ) {
logger . error ( "Could not get TCP connection " + sessionId , ex ) ;
try {
if ( connection ! = null ) {
connection . close ( ) ;
}
}
catch ( Throwable t ) {
// ignore
}
}
}
private void clearRelaySession ( String stompSessionId ) {
RelaySession relaySession = this . relaySessions . remove ( stompSessionId ) ;
if ( relaySession ! = null ) {
// TODO: raise failure event so client session can be closed
private TcpConnection < String , String > getConnection ( String sessionId ) {
TcpConnection < String , String > connection = this . connections . get ( sessionId ) ;
if ( connection = = null ) {
try {
relaySession . getSocket ( ) . close ( ) ;
Thread . sleep ( 1000 ) ;
}
catch ( IO Exception e ) {
// ignore
catch ( Interrupted Exception e ) {
return null ;
}
}
connection = this . connections . get ( sessionId ) ;
return connection ;
}
@Override
@ -174,6 +209,8 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
@@ -174,6 +209,8 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
forwardMessage ( message , command ) ;
}
// TODO:
/ * @Override
public void handleClientConnectionClosed ( String sessionId ) {
if ( logger . isDebugEnabled ( ) ) {
@ -183,81 +220,4 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
@@ -183,81 +220,4 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
}
* /
private final static class RelaySession {
private Socket socket ;
private InputStream inputStream ;
private OutputStream outputStream ;
public void setSocket ( Socket socket ) throws IOException {
this . socket = socket ;
this . inputStream = new BufferedInputStream ( socket . getInputStream ( ) ) ;
this . outputStream = new BufferedOutputStream ( socket . getOutputStream ( ) ) ;
}
public Socket getSocket ( ) {
return this . socket ;
}
public InputStream getInputStream ( ) {
return this . inputStream ;
}
public OutputStream getOutputStream ( ) {
return this . outputStream ;
}
}
private final class RelayReadTask implements Runnable {
private final String sessionId ;
private final RelaySession session ;
private RelayReadTask ( String sessionId , RelaySession session ) {
this . sessionId = sessionId ;
this . session = session ;
}
@Override
public void run ( ) {
try {
ByteArrayOutputStream out = new ByteArrayOutputStream ( ) ;
while ( ! session . getSocket ( ) . isClosed ( ) ) {
int b = session . getInputStream ( ) . read ( ) ;
if ( b = = - 1 ) {
break ;
}
else if ( b = = 0x00 ) {
byte [ ] bytes = out . toByteArray ( ) ;
Message < byte [ ] > message = stompMessageConverter . toMessage ( bytes , sessionId ) ;
getClientChannel ( ) . send ( message ) ;
out . reset ( ) ;
}
else {
out . write ( b ) ;
}
}
logger . debug ( "Socket closed, STOMP session=" + sessionId ) ;
sendErrorMessage ( "Lost connection" ) ;
}
catch ( IOException e ) {
logger . error ( "Socket error: " + e . getMessage ( ) ) ;
clearRelaySession ( sessionId ) ;
sendErrorMessage ( "Lost connection" ) ;
}
}
private void sendErrorMessage ( String message ) {
StompHeaders stompHeaders = StompHeaders . create ( StompCommand . ERROR ) ;
stompHeaders . setMessage ( message ) ;
stompHeaders . setSessionId ( this . sessionId ) ;
Message < byte [ ] > errorMessage = new GenericMessage < byte [ ] > ( new byte [ 0 ] , stompHeaders . toMessageHeaders ( ) ) ;
getClientChannel ( ) . send ( errorMessage ) ;
}
}
}