@ -19,6 +19,9 @@ package org.springframework.messaging.rsocket.annotation.support;
import java.lang.reflect.AnnotatedElement ;
import java.lang.reflect.AnnotatedElement ;
import java.util.ArrayList ;
import java.util.ArrayList ;
import java.util.List ;
import java.util.List ;
import java.util.Set ;
import java.util.function.Predicate ;
import java.util.stream.Collectors ;
import io.rsocket.ConnectionSetupPayload ;
import io.rsocket.ConnectionSetupPayload ;
import io.rsocket.RSocket ;
import io.rsocket.RSocket ;
@ -28,6 +31,8 @@ import io.rsocket.metadata.WellKnownMimeType;
import reactor.core.publisher.Mono ;
import reactor.core.publisher.Mono ;
import org.springframework.beans.BeanUtils ;
import org.springframework.beans.BeanUtils ;
import org.springframework.core.MethodParameter ;
import org.springframework.core.ReactiveAdapter ;
import org.springframework.core.ReactiveAdapterRegistry ;
import org.springframework.core.ReactiveAdapterRegistry ;
import org.springframework.core.annotation.AnnotatedElementUtils ;
import org.springframework.core.annotation.AnnotatedElementUtils ;
import org.springframework.core.codec.Decoder ;
import org.springframework.core.codec.Decoder ;
@ -37,8 +42,11 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageDeliveryException ;
import org.springframework.messaging.MessageDeliveryException ;
import org.springframework.messaging.handler.CompositeMessageCondition ;
import org.springframework.messaging.handler.CompositeMessageCondition ;
import org.springframework.messaging.handler.DestinationPatternsMessageCondition ;
import org.springframework.messaging.handler.DestinationPatternsMessageCondition ;
import org.springframework.messaging.handler.HandlerMethod ;
import org.springframework.messaging.handler.MessageCondition ;
import org.springframework.messaging.handler.annotation.MessageMapping ;
import org.springframework.messaging.handler.annotation.MessageMapping ;
import org.springframework.messaging.handler.annotation.reactive.MessageMappingMessageHandler ;
import org.springframework.messaging.handler.annotation.reactive.MessageMappingMessageHandler ;
import org.springframework.messaging.handler.annotation.reactive.PayloadMethodArgumentResolver ;
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler ;
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler ;
import org.springframework.messaging.rsocket.ClientRSocketFactoryConfigurer ;
import org.springframework.messaging.rsocket.ClientRSocketFactoryConfigurer ;
import org.springframework.messaging.rsocket.MetadataExtractor ;
import org.springframework.messaging.rsocket.MetadataExtractor ;
@ -55,12 +63,27 @@ import org.springframework.util.StringUtils;
* Extension of { @link MessageMappingMessageHandler } for handling RSocket
* Extension of { @link MessageMappingMessageHandler } for handling RSocket
* requests with { @link ConnectMapping @ConnectMapping } and
* requests with { @link ConnectMapping @ConnectMapping } and
* { @link MessageMapping @MessageMapping } methods .
* { @link MessageMapping @MessageMapping } methods .
* < p > Use { @link # responder ( ) } to obtain a { @link SocketAcceptor } adapter to
*
* plug in as responder into an { @link io . rsocket . RSocketFactory } .
* < p > For server scenarios this class can be declared as a bean in Spring
* < p > Use { @link # clientResponder ( RSocketStrategies , Object . . . ) } to obtain a
* configuration and that would detect { @code @MessageMapping } methods in
* client responder configurer
* { @code @Controller } beans . What beans are checked can be changed through a
* { @link # setHandlerPredicate ( Predicate ) handlerPredicate } . Given an instance
* of this class , you can then use { @link # responder ( ) } to obtain a
* { @link SocketAcceptor } adapter to register with the
* { @link io . rsocket . RSocketFactory } .
*
* < p > For client scenarios , possibly in the same process as a server , consider
* consider using the static factory method
* { @link # clientResponder ( RSocketStrategies , Object . . . ) } to obtain a client
* responder to be registered with an
* { @link org . springframework . messaging . rsocket . RSocketRequester . Builder # rsocketFactory
* { @link org . springframework . messaging . rsocket . RSocketRequester . Builder # rsocketFactory
* RSocketRequester } .
* RSocketRequester . Builder } .
*
* < p > For { @code @MessageMapping } methods , this class automatically determines
* the RSocket interaction type based on the input and output cardinality of the
* method . See the
* < a href = "https://docs.spring.io/spring/docs/current/spring-framework-reference/web-reactive.html#rsocket-annot-responders" >
* "Annotated Responders" < / a > section of the Spring Framework reference for more details .
*
*
* @author Rossen Stoyanchev
* @author Rossen Stoyanchev
* @since 5 . 2
* @since 5 . 2
@ -263,6 +286,17 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler {
getArgumentResolverConfigurer ( ) . addCustomResolver ( new RSocketRequesterMethodArgumentResolver ( ) ) ;
getArgumentResolverConfigurer ( ) . addCustomResolver ( new RSocketRequesterMethodArgumentResolver ( ) ) ;
super . afterPropertiesSet ( ) ;
super . afterPropertiesSet ( ) ;
getHandlerMethods ( ) . forEach ( ( composite , handler ) - > {
if ( composite . getMessageConditions ( ) . contains ( RSocketFrameTypeMessageCondition . CONNECT_CONDITION ) ) {
MethodParameter returnType = handler . getReturnType ( ) ;
if ( getCardinality ( returnType ) > 0 ) {
throw new IllegalStateException (
"Invalid @ConnectMapping method. " +
"Return type must be void or a void async type: " + handler ) ;
}
}
} ) ;
}
}
@Override
@Override
@ -279,10 +313,9 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler {
protected CompositeMessageCondition getCondition ( AnnotatedElement element ) {
protected CompositeMessageCondition getCondition ( AnnotatedElement element ) {
MessageMapping ann1 = AnnotatedElementUtils . findMergedAnnotation ( element , MessageMapping . class ) ;
MessageMapping ann1 = AnnotatedElementUtils . findMergedAnnotation ( element , MessageMapping . class ) ;
if ( ann1 ! = null & & ann1 . value ( ) . length > 0 ) {
if ( ann1 ! = null & & ann1 . value ( ) . length > 0 ) {
String [ ] patterns = processDestinations ( ann1 . value ( ) ) ;
return new CompositeMessageCondition (
return new CompositeMessageCondition (
RSocketFrameTypeMessageCondition . REQUEST _CONDITION,
RSocketFrameTypeMessageCondition . EMPTY _CONDITION,
new DestinationPatternsMessageCondition ( patterns , obtainRouteMatcher ( ) ) ) ;
new DestinationPatternsMessageCondition ( processDestinations ( ann1 . value ( ) ) , obtainRouteMatcher ( ) ) ) ;
}
}
ConnectMapping ann2 = AnnotatedElementUtils . findMergedAnnotation ( element , ConnectMapping . class ) ;
ConnectMapping ann2 = AnnotatedElementUtils . findMergedAnnotation ( element , ConnectMapping . class ) ;
if ( ann2 ! = null ) {
if ( ann2 ! = null ) {
@ -294,6 +327,45 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler {
return null ;
return null ;
}
}
@Override
protected CompositeMessageCondition extendMapping ( CompositeMessageCondition composite , HandlerMethod handler ) {
List < MessageCondition < ? > > conditions = composite . getMessageConditions ( ) ;
Assert . isTrue ( conditions . size ( ) = = 2 & &
conditions . get ( 0 ) instanceof RSocketFrameTypeMessageCondition & &
conditions . get ( 1 ) instanceof DestinationPatternsMessageCondition ,
"Unexpected message condition types" ) ;
if ( conditions . get ( 0 ) ! = RSocketFrameTypeMessageCondition . EMPTY_CONDITION ) {
return composite ;
}
int responseCardinality = getCardinality ( handler . getReturnType ( ) ) ;
int requestCardinality = 0 ;
for ( MethodParameter parameter : handler . getMethodParameters ( ) ) {
if ( getArgumentResolvers ( ) . getArgumentResolver ( parameter ) instanceof PayloadMethodArgumentResolver ) {
requestCardinality = getCardinality ( parameter ) ;
}
}
return new CompositeMessageCondition (
RSocketFrameTypeMessageCondition . getCondition ( requestCardinality , responseCardinality ) ,
conditions . get ( 1 ) ) ;
}
private int getCardinality ( MethodParameter parameter ) {
Class < ? > clazz = parameter . getParameterType ( ) ;
ReactiveAdapter adapter = getReactiveAdapterRegistry ( ) . getAdapter ( clazz ) ;
if ( adapter = = null ) {
return clazz . equals ( void . class ) ? 0 : 1 ;
}
else if ( parameter . nested ( ) . getNestedParameterType ( ) . equals ( Void . class ) ) {
return 0 ;
}
else {
return adapter . isMultiValue ( ) ? 2 : 1 ;
}
}
@Override
@Override
protected void handleNoMatch ( @Nullable RouteMatcher . Route destination , Message < ? > message ) {
protected void handleNoMatch ( @Nullable RouteMatcher . Route destination , Message < ? > message ) {
@ -306,7 +378,18 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler {
logger . warn ( "No handler for fireAndForget to '" + destination + "'" ) ;
logger . warn ( "No handler for fireAndForget to '" + destination + "'" ) ;
return ;
return ;
}
}
throw new MessageDeliveryException ( "No handler for destination '" + destination + "'" ) ;
Set < FrameType > frameTypes = getHandlerMethods ( ) . keySet ( ) . stream ( )
. map ( CompositeMessageCondition : : getMessageConditions )
. filter ( conditions - > conditions . get ( 1 ) . getMatchingCondition ( message ) ! = null )
. map ( conditions - > ( RSocketFrameTypeMessageCondition ) conditions . get ( 0 ) )
. flatMap ( condition - > condition . getFrameTypes ( ) . stream ( ) )
. collect ( Collectors . toSet ( ) ) ;
throw new MessageDeliveryException ( frameTypes . isEmpty ( ) ?
"No handler for destination '" + destination + "'" :
"Destination '" + destination + "' does not support " + frameType + ". " +
"Supported interaction(s): " + frameTypes ) ;
}
}
/ * *
/ * *