From 23b39ad27be8fa1c8bb058d4e99903dd9ae4a39a Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 27 Feb 2019 12:08:51 -0500 Subject: [PATCH] Explicit handling of void return values Do give HandlerMethodReturnValueHandler's a chance to handle return values so the RSocket reply header is always set. See gh-21987 --- ...stractEncoderMethodReturnValueHandler.java | 13 ++++ .../invocation/reactive/InvocableHelper.java | 2 + .../messaging/rsocket/MessagingRSocket.java | 5 +- .../rsocket/RSocketMessageHandler.java | 17 ++++++ .../RSocketPayloadReturnValueHandler.java | 28 ++++++--- .../TestEncoderMethodReturnValueHandler.java | 6 ++ ...RSocketClientToServerIntegrationTests.java | 60 ++++++++++--------- 7 files changed, 93 insertions(+), 38 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java index 5c185ce45fe..28387c292d8 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java @@ -104,6 +104,10 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler public Mono handleReturnValue( @Nullable Object returnValue, MethodParameter returnType, Message message) { + if (returnValue == null) { + return handleNoContent(returnType, message); + } + DataBufferFactory bufferFactory = (DataBufferFactory) message.getHeaders() .getOrDefault(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, this.defaultBufferFactory); @@ -202,4 +206,13 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler protected abstract Mono handleEncodedContent( Flux encodedContent, MethodParameter returnType, Message message); + /** + * Invoked for a {@code null} return value, which could mean a void method + * or method returning an async type parameterized by void. + * @param returnType return type of the handler method that produced the data + * @param message the input message handled by the handler method + * @return completion {@code Mono} for the handling + */ + protected abstract Mono handleNoContent(MethodParameter returnType, Message message); + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHelper.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHelper.java index 9bdfe47f12e..84e5e766e48 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHelper.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHelper.java @@ -182,6 +182,7 @@ class InvocableHelper { logger.debug("Invoking " + invocable.getShortLogMessage()); } return invocable.invoke(message) + .switchIfEmpty(Mono.defer(() -> handleReturnValue(null, invocable, message))) .flatMap(returnValue -> handleReturnValue(returnValue, invocable, message)) .onErrorResume(ex -> { InvocableHandlerMethod exHandler = initExceptionHandlerMethod(handlerMethod, ex); @@ -192,6 +193,7 @@ class InvocableHelper { logger.debug("Invoking " + exHandler.getShortLogMessage()); } return exHandler.invoke(message, ex) + .switchIfEmpty(Mono.defer(() -> handleReturnValue(null, exHandler, message))) .flatMap(returnValue -> handleReturnValue(returnValue, exHandler, message)); }); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessagingRSocket.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessagingRSocket.java index 63953937393..2cfde5aa411 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessagingRSocket.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessagingRSocket.java @@ -32,7 +32,6 @@ import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.PooledDataBuffer; import org.springframework.lang.Nullable; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageDeliveryException; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; @@ -123,6 +122,7 @@ class MessagingRSocket extends AbstractRSocket { private Mono handle(Payload payload) { Message message = MessageBuilder.createMessage( Mono.fromCallable(() -> wrapPayloadData(payload)), createHeaders(payload, null)); + return this.handler.apply(message); } @@ -131,10 +131,11 @@ class MessagingRSocket extends AbstractRSocket { Message message = MessageBuilder.createMessage( payloads.map(this::wrapPayloadData).doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release), createHeaders(firstPayload, replyMono)); + return this.handler.apply(message) .thenMany(Flux.defer(() -> replyMono.isTerminated() ? replyMono.flatMapMany(Function.identity()) : - Mono.error(new MessageDeliveryException("RSocket request not handled")))); + Mono.error(new IllegalStateException("Something went wrong: reply Mono not set")))); } private MessageHeaders createHeaders(Payload payload, @Nullable MonoProcessor replyMono) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketMessageHandler.java index cac3ee2f4a8..3e038598aca 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketMessageHandler.java @@ -21,9 +21,12 @@ import java.util.List; import org.springframework.core.codec.Decoder; import org.springframework.core.codec.Encoder; import org.springframework.lang.Nullable; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageDeliveryException; import org.springframework.messaging.handler.annotation.support.reactive.MessageMappingMessageHandler; import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * RSocket-specific extension of {@link MessageMappingMessageHandler}. @@ -105,4 +108,18 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { return handlers; } + @Override + protected void handleNoMatch(@Nullable String destination, Message message) { + + // MessagingRSocket will raise an error anyway if reply Mono is expected + // Here we raise a more helpful message a destination is present + + // It is OK if some messages (ConnectionSetupPayload, metadataPush) are not handled + // We need a better way to avoid raising errors for those + + if (StringUtils.hasText(destination)) { + throw new MessageDeliveryException("No handler for destination '" + destination + "'"); + } + } + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketPayloadReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketPayloadReturnValueHandler.java index c841736a7c8..4c671c4e13a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketPayloadReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketPayloadReturnValueHandler.java @@ -26,6 +26,7 @@ import org.springframework.core.MethodParameter; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.codec.Encoder; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.lang.Nullable; import org.springframework.messaging.Message; import org.springframework.messaging.handler.invocation.reactive.AbstractEncoderMethodReturnValueHandler; import org.springframework.util.Assert; @@ -58,15 +59,28 @@ public class RSocketPayloadReturnValueHandler extends AbstractEncoderMethodRetur protected Mono handleEncodedContent( Flux encodedContent, MethodParameter returnType, Message message) { - Object headerValue = message.getHeaders().get(RESPONSE_HEADER); - Assert.notNull(headerValue, "Missing '" + RESPONSE_HEADER + "'"); - Assert.isInstanceOf(MonoProcessor.class, headerValue, "Expected MonoProcessor"); - - MonoProcessor> monoProcessor = (MonoProcessor>) headerValue; - monoProcessor.onNext(encodedContent.map(PayloadUtils::createPayload)); - monoProcessor.onComplete(); + MonoProcessor> replyMono = getReplyMono(message); + Assert.notNull(replyMono, "Missing '" + RESPONSE_HEADER + "'"); + replyMono.onNext(encodedContent.map(PayloadUtils::createPayload)); + replyMono.onComplete(); + return Mono.empty(); + } + @Override + protected Mono handleNoContent(MethodParameter returnType, Message message) { + MonoProcessor> replyMono = getReplyMono(message); + if (replyMono != null) { + replyMono.onComplete(); + } return Mono.empty(); } + @Nullable + @SuppressWarnings("unchecked") + private MonoProcessor> getReplyMono(Message message) { + Object headerValue = message.getHeaders().get(RESPONSE_HEADER); + Assert.state(headerValue == null || headerValue instanceof MonoProcessor, "Expected MonoProcessor"); + return (MonoProcessor>) headerValue; + } + } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/TestEncoderMethodReturnValueHandler.java b/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/TestEncoderMethodReturnValueHandler.java index 3a47d53af9b..2d07a4ad86d 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/TestEncoderMethodReturnValueHandler.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/TestEncoderMethodReturnValueHandler.java @@ -60,4 +60,10 @@ public class TestEncoderMethodReturnValueHandler extends AbstractEncoderMethodRe this.encodedContent = encodedContent.cache(); return this.encodedContent.then(); } + + @Override + protected Mono handleNoContent(MethodParameter returnType, Message message) { + this.encodedContent = Flux.empty(); + return Mono.empty(); + } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java index a37ce6b2766..6e189e241ed 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java @@ -111,83 +111,73 @@ public class RSocketClientToServerIntegrationTests { @Test public void echo() { - Flux result = Flux.range(1, 3).concatMap(i -> requester.route("echo").data("Hello " + i).retrieveMono(String.class)); StepVerifier.create(result) - .expectNext("Hello 1") - .expectNext("Hello 2") - .expectNext("Hello 3") + .expectNext("Hello 1").expectNext("Hello 2").expectNext("Hello 3") .verifyComplete(); } @Test public void echoAsync() { - Flux result = Flux.range(1, 3).concatMap(i -> requester.route("echo-async").data("Hello " + i).retrieveMono(String.class)); StepVerifier.create(result) - .expectNext("Hello 1 async") - .expectNext("Hello 2 async") - .expectNext("Hello 3 async") + .expectNext("Hello 1 async").expectNext("Hello 2 async").expectNext("Hello 3 async") .verifyComplete(); } @Test public void echoStream() { - Flux result = requester.route("echo-stream").data("Hello").retrieveFlux(String.class); StepVerifier.create(result) - .expectNext("Hello 0") - .expectNextCount(5) - .expectNext("Hello 6") - .expectNext("Hello 7") + .expectNext("Hello 0").expectNextCount(6).expectNext("Hello 7") .thenCancel() .verify(); } @Test public void echoChannel() { - Flux result = requester.route("echo-channel") .data(Flux.range(1, 10).map(i -> "Hello " + i), String.class) .retrieveFlux(String.class); StepVerifier.create(result) - .expectNext("Hello 1 async") - .expectNextCount(7) - .expectNext("Hello 9 async") - .expectNext("Hello 10 async") + .expectNext("Hello 1 async").expectNextCount(8).expectNext("Hello 10 async") .verifyComplete(); } @Test - public void handleWithThrownException() { + public void voidReturnValue() { + Flux result = requester.route("void-return-value").data("Hello").retrieveFlux(String.class); + StepVerifier.create(result).verifyComplete(); + } - Mono result = requester.route("thrown-exception").data("a").retrieveMono(String.class); + @Test + public void voidReturnValueFromExceptionHandler() { + Flux result = requester.route("void-return-value").data("bad").retrieveFlux(String.class); + StepVerifier.create(result).verifyComplete(); + } - StepVerifier.create(result) - .expectNext("Invalid input error handled") - .verifyComplete(); + @Test + public void handleWithThrownException() { + Mono result = requester.route("thrown-exception").data("a").retrieveMono(String.class); + StepVerifier.create(result).expectNext("Invalid input error handled").verifyComplete(); } @Test public void handleWithErrorSignal() { - Mono result = requester.route("error-signal").data("a").retrieveMono(String.class); - - StepVerifier.create(result) - .expectNext("Invalid input error handled") - .verifyComplete(); + StepVerifier.create(result).expectNext("Invalid input error handled").verifyComplete(); } @Test public void noMatchingRoute() { Mono result = requester.route("invalid").data("anything").retrieveMono(String.class); - StepVerifier.create(result).verifyErrorMessage("RSocket request not handled"); + StepVerifier.create(result).verifyErrorMessage("No handler for destination 'invalid'"); } @@ -232,10 +222,22 @@ public class RSocketClientToServerIntegrationTests { return Mono.error(new IllegalArgumentException("Invalid input error")); } + @MessageMapping("void-return-value") + Mono voidReturnValue(String payload) { + return !payload.equals("bad") ? + Mono.delay(Duration.ofMillis(10)).then(Mono.empty()) : + Mono.error(new IllegalStateException("bad")); + } + @MessageExceptionHandler Mono handleException(IllegalArgumentException ex) { return Mono.delay(Duration.ofMillis(10)).map(aLong -> ex.getMessage() + " handled"); } + + @MessageExceptionHandler + Mono handleExceptionWithVoidReturnValue(IllegalStateException ex) { + return Mono.delay(Duration.ofMillis(10)).then(Mono.empty()); + } }