From 5bf070a857e80333f59c7f071b5fcf7b9e4fedad Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 28 Jun 2019 21:53:30 +0100 Subject: [PATCH] More flexible RSocket metadata support The responding side now relies on a new MetadataExtractor which decodes metadata entries of interest, and adds them to an output map whose values are then added as Message headers, and are hence accessible to controller methods. Decoded metadata entry values can be added to the output map one for one, or translated to any number of values (e.g. JSON properties), as long as one of the resulting pairs has a key called "route". On the requesting side, now any metadata can be sent, and a String route for example is not required to be provided explicitly. Instead an application could create any metadata (e.g. JSON properties) as long as the server can work out the route from it. The commit contains further refinements on the requesting side so that any mime type can be used, not only composite or routing metadata, e.g. a route in an "text/plain" entry. Closes gh-23157 --- .../rsocket/DefaultRSocketRequester.java | 54 +++-- .../messaging/rsocket/RSocketRequester.java | 39 ++-- .../support/DefaultMetadataExtractor.java | 209 ++++++++++++++++++ .../annotation/support/MessagingRSocket.java | 86 ++++--- .../annotation/support/MetadataExtractor.java | 55 +++++ .../support/RSocketMessageHandler.java | 64 ++++-- .../rsocket/DefaultRSocketRequesterTests.java | 73 ++++-- .../LeakAwareNettyDataBufferFactory.java | 126 +++++++++++ .../rsocket/RSocketBufferLeakTests.java | 103 --------- .../DefaultMetadataExtractorTests.java | 178 +++++++++++++++ 10 files changed, 761 insertions(+), 226 deletions(-) create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractor.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MetadataExtractor.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/rsocket/LeakAwareNettyDataBufferFactory.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractorTests.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java index ef27e95f40c..6816b796d88 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java @@ -16,10 +16,8 @@ package org.springframework.messaging.rsocket; -import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; import io.netty.buffer.ByteBuf; @@ -59,8 +57,6 @@ final class DefaultRSocketRequester implements RSocketRequester { static final MimeType ROUTING = new MimeType("message", "x.rsocket.routing.v0"); - static final List METADATA_MIME_TYPES = Arrays.asList(COMPOSITE_METADATA, ROUTING); - private static final Map EMPTY_HINTS = Collections.emptyMap(); @@ -85,9 +81,6 @@ final class DefaultRSocketRequester implements RSocketRequester { Assert.notNull(metadataMimeType, "'metadataMimeType' is required"); Assert.notNull(strategies, "RSocketStrategies is required"); - Assert.isTrue(METADATA_MIME_TYPES.contains(metadataMimeType), - () -> "Unexpected metadatata mime type: '" + metadataMimeType + "'"); - this.rsocket = rsocket; this.dataMimeType = dataMimeType; this.metadataMimeType = metadataMimeType; @@ -113,7 +106,13 @@ final class DefaultRSocketRequester implements RSocketRequester { @Override public RequestSpec route(String route) { - return new DefaultRequestSpec(route); + Assert.notNull(route, "'route' is required"); + return new DefaultRequestSpec(route, metadataMimeType().equals(COMPOSITE_METADATA) ? ROUTING : null); + } + + @Override + public RequestSpec metadata(Object metadata, @Nullable MimeType mimeType) { + return new DefaultRequestSpec(metadata, mimeType); } @@ -131,16 +130,22 @@ final class DefaultRSocketRequester implements RSocketRequester { private final Map metadata = new LinkedHashMap<>(4); - public DefaultRequestSpec(String route) { - Assert.notNull(route, "'route' is required"); - metadata(route, ROUTING); + public DefaultRequestSpec(Object metadata, @Nullable MimeType mimeType) { + mimeType = (mimeType == null && !isCompositeMetadata() ? metadataMimeType() : mimeType); + Assert.notNull(mimeType, "MimeType is required for composite metadata"); + metadata(metadata, mimeType); } + private boolean isCompositeMetadata() { + return metadataMimeType().equals(COMPOSITE_METADATA); + } @Override public RequestSpec metadata(Object metadata, MimeType mimeType) { - Assert.isTrue(this.metadata.isEmpty() || metadataMimeType().equals(COMPOSITE_METADATA), - "Additional metadata entries supported only with composite metadata"); + Assert.notNull(metadata, "Metadata content is required"); + Assert.notNull(mimeType, "MimeType is required"); + Assert.isTrue(this.metadata.isEmpty() || isCompositeMetadata(), + "Composite metadata required for multiple metadata entries."); this.metadata.put(metadata, mimeType); return this; } @@ -250,22 +255,27 @@ final class DefaultRSocketRequester implements RSocketRequester { } private DataBuffer getMetadata() { - if (metadataMimeType().equals(COMPOSITE_METADATA)) { + if (isCompositeMetadata()) { CompositeByteBuf metadata = getAllocator().compositeBuffer(); - this.metadata.forEach((key, value) -> { - DataBuffer dataBuffer = encodeMetadata(key, value); - CompositeMetadataFlyweight.encodeAndAddMetadata(metadata, getAllocator(), value.toString(), + this.metadata.forEach((value, mimeType) -> { + DataBuffer dataBuffer = encodeMetadata(value, mimeType); + CompositeMetadataFlyweight.encodeAndAddMetadata(metadata, getAllocator(), mimeType.toString(), dataBuffer instanceof NettyDataBuffer ? ((NettyDataBuffer) dataBuffer).getNativeBuffer() : Unpooled.wrappedBuffer(dataBuffer.asByteBuffer())); }); return asDataBuffer(metadata); } - Assert.isTrue(this.metadata.size() < 2, "Composite metadata required for multiple entries"); - Map.Entry entry = this.metadata.entrySet().iterator().next(); - Assert.isTrue(metadataMimeType().equals(entry.getValue()), - () -> "Expected metadata MimeType '" + metadataMimeType() + "', actual " + this.metadata); - return encodeMetadata(entry.getKey(), entry.getValue()); + else { + Assert.isTrue(this.metadata.size() == 1, "Composite metadata required for multiple entries"); + Map.Entry entry = this.metadata.entrySet().iterator().next(); + if (!metadataMimeType().equals(entry.getValue())) { + throw new IllegalArgumentException( + "Connection configured for metadata mime type " + + "'" + metadataMimeType() + "', but actual is `" + this.metadata + "`"); + } + return encodeMetadata(entry.getKey(), entry.getValue()); + } } @SuppressWarnings("unchecked") diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java index ab9f22b0f68..5481dd6bf83 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java @@ -67,14 +67,25 @@ public interface RSocketRequester { /** - * Begin to specify a new request with the given route to a handler on the - * remote side. The route will be encoded in the metadata of the first - * payload. + * Begin to specify a new request with the given route to a remote handler. + *

If the connection is set to use composite metadata, the route is + * encoded as {@code "message/x.rsocket.routing.v0"}. Otherwise the route + * is encoded according to the mime type for the connection. * @param route the route to a handler * @return a spec for further defining and executing the request */ RequestSpec route(String route); + /** + * Begin to specify a new request with the given metadata. + *

If using composite metadata then the mime type argument is required. + * Otherwise the mime type should be {@code null}, or it must match the + * mime type for the connection. + * @param metadata the metadata value to encode + * @param mimeType the mime type that describes the metadata; + */ + RequestSpec metadata(Object metadata, @Nullable MimeType mimeType); + /** * Obtain a builder for an {@link RSocketRequester} by connecting to an @@ -110,24 +121,24 @@ public interface RSocketRequester { interface Builder { /** - * Configure the MimeType to use for payload data. This is then - * specified on the {@code SETUP} frame for the whole connection. - *

By default this is set to the first concrete MimeType supported + * Configure the MimeType for payload data which is then specified + * on the {@code SETUP} frame and applies to the whole connection. + *

By default this is set to the first concrete mime type supported * by the configured encoders and decoders. * @param mimeType the data MimeType to use */ RSocketRequester.Builder dataMimeType(@Nullable MimeType mimeType); /** - * Configure the MimeType to use for payload metadata. This is then - * specified on the {@code SETUP} frame for the whole connection. - *

At present the metadata MimeType must be - * {@code "message/x.rsocket.routing.v0"} to allow the request - * {@link RSocketRequester#route(String) route} to be encoded, or it - * could also be {@code "message/x.rsocket.composite-metadata.v0"} in - * which case the route can be encoded along with other metadata entries. + * Configure the MimeType for payload metadata which is then specified + * on the {@code SETUP} frame and applies to the whole connection. *

By default this is set to - * {@code "message/x.rsocket.composite-metadata.v0"}. + * {@code "message/x.rsocket.composite-metadata.v0"} in which case the + * route, if provided, is encoded as a + * {@code "message/x.rsocket.routing.v0"} metadata entry, potentially + * with other metadata entries added too. If this is set to any other + * mime type, and a route is provided, it is assumed the mime type is + * for the route. * @param mimeType the data MimeType to use */ RSocketRequester.Builder metadataMimeType(MimeType mimeType); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractor.java new file mode 100644 index 00000000000..bc008862d13 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractor.java @@ -0,0 +1,209 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.rsocket.annotation.support; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiConsumer; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.metadata.CompositeMetadata; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Decoder; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.lang.Nullable; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +/** + * Default {@link MetadataExtractor} implementation that relies on {@link Decoder}s + * to deserialize the content of metadata entries. + * + *

By default only {@code "message/x.rsocket.routing.v0""} is extracted and + * saved under {@link MetadataExtractor#ROUTE_KEY}. Use the + * {@code metadataToExtract} methods to specify other metadata mime types of + * interest to extract. + * + * @author Rossen Stoyanchev + * @since 5.2 + */ +public class DefaultMetadataExtractor implements MetadataExtractor { + + private final RSocketStrategies rsocketStrategies; + + private final Map> entryProcessors = new HashMap<>(); + + + /** + * Default constructor with {@link RSocketStrategies}. + */ + public DefaultMetadataExtractor(RSocketStrategies strategies) { + Assert.notNull(strategies, "RSocketStrategies is required"); + this.rsocketStrategies = strategies; + // TODO: remove when rsocket-core API available + metadataToExtract(MessagingRSocket.ROUTING, String.class, ROUTE_KEY); + } + + + /** + * Decode metadata entries with the given {@link MimeType} to the specified + * target class, and store the decoded value in the output map under the + * given name. + * @param mimeType the mime type of metadata entries to extract + * @param targetType the target value type to decode to + * @param name assign a name for the decoded value; if not provided, then + * the mime type is used as the key + */ + public void metadataToExtract( + MimeType mimeType, Class targetType, @Nullable String name) { + + String key = name != null ? name : mimeType.toString(); + metadataToExtract(mimeType, targetType, (value, map) -> map.put(key, value)); + } + + /** + * Variant of {@link #metadataToExtract(MimeType, Class, String)} that accepts + * {@link ParameterizedTypeReference} instead of {@link Class} for + * specifying a target type with generic parameters. + */ + public void metadataToExtract( + MimeType mimeType, ParameterizedTypeReference targetType, @Nullable String name) { + + String key = name != null ? name : mimeType.toString(); + metadataToExtract(mimeType, targetType, (value, map) -> map.put(key, value)); + } + + /** + * Variant of {@link #metadataToExtract(MimeType, Class, String)} that allows + * custom logic to be used to map the decoded value to any number of values + * in the output map. + * @param mimeType the mime type of metadata entries to extract + * @param targetType the target value type to decode to + * @param mapper custom logic to add the decoded value to the output map + * @param the target value type + */ + public void metadataToExtract( + MimeType mimeType, Class targetType, + BiConsumer> mapper) { + + EntryProcessor spec = new EntryProcessor<>(mimeType, targetType, mapper); + this.entryProcessors.put(mimeType.toString(), spec); + } + + /** + * Variant of {@link #metadataToExtract(MimeType, Class, BiConsumer)} that + * accepts {@link ParameterizedTypeReference} instead of {@link Class} for + * specifying a target type with generic parameters. + * @param mimeType the mime type of metadata entries to extract + * @param targetType the target value type to decode to + * @param mapper custom logic to add the decoded value to the output map + * @param the target value type + */ + public void metadataToExtract( + MimeType mimeType, ParameterizedTypeReference targetType, + BiConsumer> mapper) { + + EntryProcessor spec = new EntryProcessor<>(mimeType, targetType, mapper); + this.entryProcessors.put(mimeType.toString(), spec); + } + + + @Override + public Map extract(Payload payload, MimeType metadataMimeType) { + Map result = new HashMap<>(); + if (metadataMimeType.equals(MessagingRSocket.COMPOSITE_METADATA)) { + for (CompositeMetadata.Entry entry : new CompositeMetadata(payload.metadata(), false)) { + processEntry(entry.getContent(), entry.getMimeType(), result); + } + } + else { + processEntry(payload.metadata(), metadataMimeType.toString(), result); + } + return result; + } + + private void processEntry(ByteBuf content, @Nullable String mimeType, Map result) { + EntryProcessor entryProcessor = this.entryProcessors.get(mimeType); + if (entryProcessor != null) { + content.retain(); + entryProcessor.process(content, result); + return; + } + if (MessagingRSocket.ROUTING.toString().equals(mimeType)) { + // TODO: use rsocket-core API when available + } + } + + + /** + * Helps to decode a metadata entry and add the resulting value to the + * output map. + */ + private class EntryProcessor { + + private final MimeType mimeType; + + private final ResolvableType targetType; + + private final BiConsumer> accumulator; + + private final Decoder decoder; + + + public EntryProcessor( + MimeType mimeType, Class targetType, + BiConsumer> accumulator) { + + this(mimeType, ResolvableType.forClass(targetType), accumulator); + } + + public EntryProcessor( + MimeType mimeType, ParameterizedTypeReference targetType, + BiConsumer> accumulator) { + + this(mimeType, ResolvableType.forType(targetType), accumulator); + } + + private EntryProcessor( + MimeType mimeType, ResolvableType targetType, + BiConsumer> accumulator) { + + this.mimeType = mimeType; + this.targetType = targetType; + this.accumulator = accumulator; + this.decoder = rsocketStrategies.decoder(targetType, mimeType); + } + + + public void process(ByteBuf byteBuf, Map result) { + DataBufferFactory factory = rsocketStrategies.dataBufferFactory(); + DataBuffer buffer = factory instanceof NettyDataBufferFactory ? + ((NettyDataBufferFactory) factory).wrap(byteBuf) : + factory.wrap(byteBuf.nioBuffer()); + + T value = this.decoder.decode(buffer, this.targetType, this.mimeType, Collections.emptyMap()); + this.accumulator.accept(value, result); + } + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java index 70eabdf59d7..7585993634b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java @@ -16,9 +16,7 @@ package org.springframework.messaging.rsocket.annotation.support; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; @@ -26,7 +24,6 @@ import io.rsocket.AbstractRSocket; import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.metadata.CompositeMetadata; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -39,6 +36,7 @@ import org.springframework.core.io.buffer.NettyDataBuffer; import org.springframework.lang.Nullable; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.ReactiveMessageHandler; import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; import org.springframework.messaging.rsocket.PayloadUtils; @@ -63,42 +61,42 @@ class MessagingRSocket extends AbstractRSocket { static final MimeType COMPOSITE_METADATA = new MimeType("message", "x.rsocket.composite-metadata.v0"); - private static final MimeType ROUTING = new MimeType("message", "x.rsocket.routing.v0"); + static final MimeType ROUTING = new MimeType("message", "x.rsocket.routing.v0"); - private static final List METADATA_MIME_TYPES = Arrays.asList(COMPOSITE_METADATA, ROUTING); + private final MimeType dataMimeType; - private final RSocketMessageHandler messageHandler; + private final MimeType metadataMimeType; - private final RouteMatcher routeMatcher; + private final MetadataExtractor metadataExtractor; - private final RSocketRequester requester; + private final ReactiveMessageHandler messageHandler; - private final MimeType dataMimeType; + private final RouteMatcher routeMatcher; - private final MimeType metadataMimeType; + private final RSocketRequester requester; private final DataBufferFactory bufferFactory; - MessagingRSocket(RSocketMessageHandler messageHandler, RouteMatcher routeMatcher, - RSocketRequester requester, MimeType dataMimeType, MimeType metadataMimeType, - DataBufferFactory bufferFactory) { + MessagingRSocket(MimeType dataMimeType, MimeType metadataMimeType, MetadataExtractor metadataExtractor, + RSocketRequester requester, ReactiveMessageHandler messageHandler, + RouteMatcher routeMatcher, DataBufferFactory bufferFactory) { - Assert.notNull(messageHandler, "'messageHandler' is required"); - Assert.notNull(routeMatcher, "'routeMatcher' is required"); - Assert.notNull(requester, "'requester' is required"); Assert.notNull(dataMimeType, "'dataMimeType' is required"); Assert.notNull(metadataMimeType, "'metadataMimeType' is required"); + Assert.notNull(metadataExtractor, "'metadataExtractor' is required"); + Assert.notNull(requester, "'requester' is required"); + Assert.notNull(messageHandler, "'messageHandler' is required"); + Assert.notNull(routeMatcher, "'routeMatcher' is required"); + Assert.notNull(bufferFactory, "'bufferFactory' is required"); - Assert.isTrue(METADATA_MIME_TYPES.contains(metadataMimeType), - () -> "Unexpected metadatata mime type: '" + metadataMimeType + "'"); - - this.messageHandler = messageHandler; - this.routeMatcher = routeMatcher; - this.requester = requester; this.dataMimeType = dataMimeType; this.metadataMimeType = metadataMimeType; + this.metadataExtractor = metadataExtractor; + this.requester = requester; + this.messageHandler = messageHandler; + this.routeMatcher = routeMatcher; this.bufferFactory = bufferFactory; } @@ -149,8 +147,7 @@ class MessagingRSocket extends AbstractRSocket { private Mono handle(Payload payload) { - String destination = getDestination(payload); - MessageHeaders headers = createHeaders(destination, null); + MessageHeaders headers = createHeaders(payload, null); DataBuffer dataBuffer = retainDataAndReleasePayload(payload); int refCount = refCount(dataBuffer); Message message = MessageBuilder.createMessage(dataBuffer, headers); @@ -169,8 +166,7 @@ class MessagingRSocket extends AbstractRSocket { private Flux handleAndReply(Payload firstPayload, Flux payloads) { MonoProcessor> replyMono = MonoProcessor.create(); - String destination = getDestination(firstPayload); - MessageHeaders headers = createHeaders(destination, replyMono); + MessageHeaders headers = createHeaders(firstPayload, replyMono); AtomicBoolean read = new AtomicBoolean(); Flux buffers = payloads.map(this::retainDataAndReleasePayload).doOnSubscribe(s -> read.set(true)); @@ -188,39 +184,33 @@ class MessagingRSocket extends AbstractRSocket { Mono.error(new IllegalStateException("Something went wrong: reply Mono not set")))); } - private String getDestination(Payload payload) { - if (this.metadataMimeType.equals(COMPOSITE_METADATA)) { - CompositeMetadata metadata = new CompositeMetadata(payload.metadata(), false); - for (CompositeMetadata.Entry entry : metadata) { - String mimeType = entry.getMimeType(); - if (ROUTING.toString().equals(mimeType)) { - return entry.getContent().toString(StandardCharsets.UTF_8); - } - } - return ""; - } - else if (this.metadataMimeType.equals(ROUTING)) { - return payload.getMetadataUtf8(); - } - // Should not happen (given constructor assertions) - throw new IllegalArgumentException("Unexpected metadataMimeType"); - } - private DataBuffer retainDataAndReleasePayload(Payload payload) { return PayloadUtils.retainDataAndReleasePayload(payload, this.bufferFactory); } - private MessageHeaders createHeaders(String destination, @Nullable MonoProcessor replyMono) { + private MessageHeaders createHeaders(Payload payload, @Nullable MonoProcessor replyMono) { MessageHeaderAccessor headers = new MessageHeaderAccessor(); headers.setLeaveMutable(true); - RouteMatcher.Route route = this.routeMatcher.parseRoute(destination); - headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, route); + + Map metadataValues = this.metadataExtractor.extract(payload, this.metadataMimeType); + metadataValues.putIfAbsent(MetadataExtractor.ROUTE_KEY, ""); + for (Map.Entry entry : metadataValues.entrySet()) { + if (entry.getKey().equals(MetadataExtractor.ROUTE_KEY)) { + RouteMatcher.Route route = this.routeMatcher.parseRoute((String) entry.getValue()); + headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, route); + } + else { + headers.setHeader(entry.getKey(), entry.getValue()); + } + } + headers.setContentType(this.dataMimeType); headers.setHeader(RSocketRequesterMethodArgumentResolver.RSOCKET_REQUESTER_HEADER, this.requester); if (replyMono != null) { headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono); } headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, this.bufferFactory); + return headers.getMessageHeaders(); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MetadataExtractor.java new file mode 100644 index 00000000000..ad48da526da --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MetadataExtractor.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.rsocket.annotation.support; + +import java.util.Map; + +import io.rsocket.Payload; + +import org.springframework.util.MimeType; + +/** + * Strategy to extract a map of values from the metadata of a {@link Payload}. + * This includes decoding metadata entries based on their mime type and + * assigning a name to the decoded value. The resulting name-value pairs can + * be added to the headers of a + * {@link org.springframework.messaging.Message Message}. + * + * @author Rossen Stoyanchev + * @since 5.2 + */ +public interface MetadataExtractor { + + /** + * The key to assign to the extracted "route" of the payload. + */ + String ROUTE_KEY = "route"; + + + /** + * Extract a map of values from the given {@link Payload} metadata. + *

Metadata may be composite and consist of multiple entries + * Implementations are free to extract any number of name-value pairs per + * metadata entry. The Payload "route" should be saved under the + * {@link #ROUTE_KEY}. + * @param payload the payload whose metadata should be read + * @param metadataMimeType the mime type of the metadata; this is what was + * specified by the client at the start of the RSocket connection. + * @return a map of 0 or more decoded metadata values with assigned names + */ + Map extract(Payload payload, MimeType metadataMimeType); + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java index c869c2542f3..817cc7a33a9 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java @@ -60,6 +60,9 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { @Nullable private RSocketStrategies rsocketStrategies; + @Nullable + private MetadataExtractor metadataExtractor; + @Nullable private MimeType defaultDataMimeType; @@ -88,31 +91,33 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { * {@link RSocketStrategies} encapsulates required configuration for re-use. * @param rsocketStrategies the strategies to use */ - public void setRSocketStrategies(RSocketStrategies rsocketStrategies) { - Assert.notNull(rsocketStrategies, "RSocketStrategies must not be null"); + public void setRSocketStrategies(@Nullable RSocketStrategies rsocketStrategies) { this.rsocketStrategies = rsocketStrategies; - setDecoders(rsocketStrategies.decoders()); - setEncoders(rsocketStrategies.encoders()); - setReactiveAdapterRegistry(rsocketStrategies.reactiveAdapterRegistry()); + if (rsocketStrategies != null) { + setDecoders(rsocketStrategies.decoders()); + setEncoders(rsocketStrategies.encoders()); + setReactiveAdapterRegistry(rsocketStrategies.reactiveAdapterRegistry()); + } } - /** - * Return the {@code RSocketStrategies} instance provided via - * {@link #setRSocketStrategies rsocketStrategies}, or - * otherwise initialize it with the configured {@link #setEncoders(List) - * encoders}, {@link #setDecoders(List) decoders}, and others. - */ + @Nullable public RSocketStrategies getRSocketStrategies() { - if (this.rsocketStrategies == null) { - this.rsocketStrategies = RSocketStrategies.builder() - .decoder(getDecoders().toArray(new Decoder[0])) - .encoder(getEncoders().toArray(new Encoder[0])) - .reactiveAdapterStrategy(getReactiveAdapterRegistry()) - .build(); - } return this.rsocketStrategies; } + /** + * Configure a {@link MetadataExtractor} to extract the route and possibly + * other metadata from the first payload of incoming requests. + *

By default this is a {@link DefaultMetadataExtractor} with the + * configured {@link RSocketStrategies} (and decoders), extracting a route + * from {@code "message/x.rsocket.routing.v0"} or {@code "text/plain"} + * metadata entries. + * @param extractor the extractor to use + */ + public void setMetadataExtractor(MetadataExtractor extractor) { + this.metadataExtractor = extractor; + } + /** * Configure the default content type to use for data payloads if the * {@code SETUP} frame did not specify one. @@ -137,6 +142,18 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { @Override public void afterPropertiesSet() { + if (this.rsocketStrategies == null) { + this.rsocketStrategies = RSocketStrategies.builder() + .decoder(getDecoders().toArray(new Decoder[0])) + .encoder(getEncoders().toArray(new Encoder[0])) + .reactiveAdapterStrategy(getReactiveAdapterRegistry()) + .build(); + } + if (this.metadataExtractor == null) { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(this.rsocketStrategies); + extractor.metadataToExtract(MimeTypeUtils.TEXT_PLAIN, String.class, MetadataExtractor.ROUTE_KEY); + this.metadataExtractor = extractor; + } getArgumentResolverConfigurer().addCustomResolver(new RSocketRequesterMethodArgumentResolver()); super.afterPropertiesSet(); } @@ -201,11 +218,14 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { MimeType metaMimeType = StringUtils.hasText(s) ? MimeTypeUtils.parseMimeType(s) : this.defaultMetadataMimeType; Assert.notNull(dataMimeType, "No `metadataMimeType` in ConnectionSetupPayload and no default value"); - RSocketRequester requester = RSocketRequester.wrap( - rsocket, dataMimeType, metaMimeType, getRSocketStrategies()); + RSocketStrategies strategies = this.rsocketStrategies; + Assert.notNull(strategies, "No RSocketStrategies. Was afterPropertiesSet not called?"); + RSocketRequester requester = RSocketRequester.wrap(rsocket, dataMimeType, metaMimeType, strategies); + + Assert.notNull(this.metadataExtractor, () -> "No MetadataExtractor. Was afterPropertiesSet not called?"); - return new MessagingRSocket(this, getRouteMatcher(), requester, - dataMimeType, metaMimeType, getRSocketStrategies().dataBufferFactory()); + return new MessagingRSocket(dataMimeType, metaMimeType, this.metadataExtractor, requester, + this, getRouteMatcher(), strategies.dataBufferFactory()); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java index 915499aba64..546844ffeaf 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java @@ -43,11 +43,14 @@ import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.lang.Nullable; import org.springframework.messaging.rsocket.RSocketRequester.RequestSpec; import org.springframework.messaging.rsocket.RSocketRequester.ResponseSpec; -import org.springframework.util.MimeTypeUtils; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.springframework.messaging.rsocket.DefaultRSocketRequester.COMPOSITE_METADATA; +import static org.springframework.messaging.rsocket.DefaultRSocketRequester.ROUTING; +import static org.springframework.util.MimeTypeUtils.TEXT_PLAIN; /** * Unit tests for {@link DefaultRSocketRequester}. @@ -75,9 +78,7 @@ public class DefaultRSocketRequesterTests { .encoder(CharSequenceEncoder.allMimeTypes()) .build(); this.rsocket = new TestRSocket(); - this.requester = RSocketRequester.wrap(this.rsocket, - MimeTypeUtils.TEXT_PLAIN, DefaultRSocketRequester.ROUTING, - this.strategies); + this.requester = RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, TEXT_PLAIN, this.strategies); } @@ -143,13 +144,32 @@ public class DefaultRSocketRequesterTests { } @Test - public void sendCompositeMetadata() { - RSocketRequester requester = RSocketRequester.wrap(this.rsocket, - MimeTypeUtils.TEXT_PLAIN, DefaultRSocketRequester.COMPOSITE_METADATA, - this.strategies); + public void metadataCompositeWithRoute() { + + RSocketRequester requester = RSocketRequester.wrap( + this.rsocket, TEXT_PLAIN, COMPOSITE_METADATA, this.strategies); + + requester.route("toA").data("bodyA").send().block(Duration.ofSeconds(5)); + + CompositeMetadata entries = new CompositeMetadata(this.rsocket.getSavedPayload().metadata(), false); + Iterator iterator = entries.iterator(); + + assertThat(iterator.hasNext()).isTrue(); + CompositeMetadata.Entry entry = iterator.next(); + assertThat(entry.getMimeType()).isEqualTo(ROUTING.toString()); + assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("toA"); + + assertThat(iterator.hasNext()).isFalse(); + } + + @Test + public void metadataCompositeWithRouteAndTextEntry() { + + RSocketRequester requester = RSocketRequester.wrap( + this.rsocket, TEXT_PLAIN, COMPOSITE_METADATA, this.strategies); requester.route("toA") - .metadata("My metadata", MimeTypeUtils.TEXT_PLAIN).data("bodyA") + .metadata("My metadata", TEXT_PLAIN).data("bodyA") .send() .block(Duration.ofSeconds(5)); @@ -158,27 +178,46 @@ public class DefaultRSocketRequesterTests { assertThat(iterator.hasNext()).isTrue(); CompositeMetadata.Entry entry = iterator.next(); - assertThat(entry.getMimeType()).isEqualTo(DefaultRSocketRequester.ROUTING.toString()); + assertThat(entry.getMimeType()).isEqualTo(ROUTING.toString()); assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("toA"); assertThat(iterator.hasNext()).isTrue(); entry = iterator.next(); - assertThat(entry.getMimeType()).isEqualTo(MimeTypeUtils.TEXT_PLAIN.toString()); + assertThat(entry.getMimeType()).isEqualTo(TEXT_PLAIN.toString()); assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("My metadata"); assertThat(iterator.hasNext()).isFalse(); } + @Test + public void metadataRouteAsText() { + RSocketRequester requester = RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, TEXT_PLAIN, this.strategies); + requester.route("toA").data("bodyA").send().block(Duration.ofSeconds(5)); + assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("toA"); + } + + @Test + public void metadataAsText() { + RSocketRequester requester = RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, TEXT_PLAIN, this.strategies); + requester.metadata("toA", null).data("bodyA").send().block(Duration.ofSeconds(5)); + assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("toA"); + } + + @Test + public void metadataMimeTypeMismatch() { + RSocketRequester requester = RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, TEXT_PLAIN, this.strategies); + assertThatThrownBy(() -> requester.metadata("toA", ROUTING).data("bodyA").send().block()) + .hasMessageStartingWith("Connection configured for metadata mime type"); + } + @Test public void supportedMetadataMimeTypes() { - RSocketRequester.wrap(this.rsocket, MimeTypeUtils.TEXT_PLAIN, - DefaultRSocketRequester.COMPOSITE_METADATA, this.strategies); - RSocketRequester.wrap(this.rsocket, MimeTypeUtils.TEXT_PLAIN, - DefaultRSocketRequester.ROUTING, this.strategies); + RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, + COMPOSITE_METADATA, this.strategies); - assertThatIllegalArgumentException().isThrownBy(() -> RSocketRequester.wrap( - this.rsocket, MimeTypeUtils.TEXT_PLAIN, MimeTypeUtils.TEXT_PLAIN, this.strategies)); + RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, + ROUTING, this.strategies); } @Test diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/LeakAwareNettyDataBufferFactory.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/LeakAwareNettyDataBufferFactory.java new file mode 100644 index 00000000000..321ff4e3845 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/LeakAwareNettyDataBufferFactory.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.rsocket; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.NettyDataBuffer; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.core.io.buffer.PooledDataBuffer; +import org.springframework.util.ObjectUtils; + +/** + * Unlike {@link org.springframework.core.io.buffer.LeakAwareDataBufferFactory} + * this one is an instance of {@link NettyDataBufferFactory} which is necessary + * since {@link PayloadUtils} does instanceof checks, and that also allows + * intercepting {@link NettyDataBufferFactory#wrap(ByteBuf)}. + */ +public class LeakAwareNettyDataBufferFactory extends NettyDataBufferFactory { + + private final List created = new ArrayList<>(); + + + public LeakAwareNettyDataBufferFactory(ByteBufAllocator byteBufAllocator) { + super(byteBufAllocator); + } + + + public void checkForLeaks(Duration duration) throws InterruptedException { + Instant start = Instant.now(); + while (true) { + try { + this.created.forEach(info -> { + if (((PooledDataBuffer) info.getDataBuffer()).isAllocated()) { + throw info.getError(); + } + }); + break; + } + catch (AssertionError ex) { + if (Instant.now().isAfter(start.plus(duration))) { + throw ex; + } + } + Thread.sleep(50); + } + } + + public void reset() { + this.created.clear(); + } + + + @Override + public NettyDataBuffer allocateBuffer() { + return (NettyDataBuffer) recordHint(super.allocateBuffer()); + } + + @Override + public NettyDataBuffer allocateBuffer(int initialCapacity) { + return (NettyDataBuffer) recordHint(super.allocateBuffer(initialCapacity)); + } + + @Override + public NettyDataBuffer wrap(ByteBuf byteBuf) { + NettyDataBuffer dataBuffer = super.wrap(byteBuf); + if (byteBuf != Unpooled.EMPTY_BUFFER) { + recordHint(dataBuffer); + } + return dataBuffer; + } + + @Override + public DataBuffer join(List dataBuffers) { + return recordHint(super.join(dataBuffers)); + } + + private DataBuffer recordHint(DataBuffer buffer) { + AssertionError error = new AssertionError(String.format( + "DataBuffer leak: {%s} {%s} not released.%nStacktrace at buffer creation: ", buffer, + ObjectUtils.getIdentityHexString(((NettyDataBuffer) buffer).getNativeBuffer()))); + this.created.add(new DataBufferLeakInfo(buffer, error)); + return buffer; + } + + + private static class DataBufferLeakInfo { + + private final DataBuffer dataBuffer; + + private final AssertionError error; + + DataBufferLeakInfo(DataBuffer dataBuffer, AssertionError error) { + this.dataBuffer = dataBuffer; + this.error = error; + } + + DataBuffer getDataBuffer() { + return this.dataBuffer; + } + + AssertionError getError() { + return this.error; + } + } +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java index 215e0f31a8e..93e34ad77e2 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java @@ -18,14 +18,10 @@ package org.springframework.messaging.rsocket; import java.time.Duration; import java.time.Instant; -import java.util.ArrayList; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.PooledByteBufAllocator; -import io.netty.buffer.Unpooled; import io.netty.util.ReferenceCounted; import io.rsocket.AbstractRSocket; import io.rsocket.RSocket; @@ -51,16 +47,11 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.Resource; -import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.NettyDataBuffer; -import org.springframework.core.io.buffer.NettyDataBufferFactory; -import org.springframework.core.io.buffer.PooledDataBuffer; import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.Payload; import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; import org.springframework.stereotype.Controller; -import org.springframework.util.ObjectUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -232,100 +223,6 @@ public class RSocketBufferLeakTests { } - /** - * Unlike {@link org.springframework.core.io.buffer.LeakAwareDataBufferFactory} - * this one is an instance of {@link NettyDataBufferFactory} which is necessary - * since {@link PayloadUtils} does instanceof checks, and that also allows - * intercepting {@link NettyDataBufferFactory#wrap(ByteBuf)}. - */ - private static class LeakAwareNettyDataBufferFactory extends NettyDataBufferFactory { - - private final List created = new ArrayList<>(); - - LeakAwareNettyDataBufferFactory(ByteBufAllocator byteBufAllocator) { - super(byteBufAllocator); - } - - void checkForLeaks(Duration duration) throws InterruptedException { - Instant start = Instant.now(); - while (true) { - try { - this.created.forEach(info -> { - if (((PooledDataBuffer) info.getDataBuffer()).isAllocated()) { - throw info.getError(); - } - }); - break; - } - catch (AssertionError ex) { - if (Instant.now().isAfter(start.plus(duration))) { - throw ex; - } - } - Thread.sleep(50); - } - } - - void reset() { - this.created.clear(); - } - - - @Override - public NettyDataBuffer allocateBuffer() { - return (NettyDataBuffer) recordHint(super.allocateBuffer()); - } - - @Override - public NettyDataBuffer allocateBuffer(int initialCapacity) { - return (NettyDataBuffer) recordHint(super.allocateBuffer(initialCapacity)); - } - - @Override - public NettyDataBuffer wrap(ByteBuf byteBuf) { - NettyDataBuffer dataBuffer = super.wrap(byteBuf); - if (byteBuf != Unpooled.EMPTY_BUFFER) { - recordHint(dataBuffer); - } - return dataBuffer; - } - - @Override - public DataBuffer join(List dataBuffers) { - return recordHint(super.join(dataBuffers)); - } - - private DataBuffer recordHint(DataBuffer buffer) { - AssertionError error = new AssertionError(String.format( - "DataBuffer leak: {%s} {%s} not released.%nStacktrace at buffer creation: ", buffer, - ObjectUtils.getIdentityHexString(((NettyDataBuffer) buffer).getNativeBuffer()))); - this.created.add(new DataBufferLeakInfo(buffer, error)); - return buffer; - } - } - - - private static class DataBufferLeakInfo { - - private final DataBuffer dataBuffer; - - private final AssertionError error; - - DataBufferLeakInfo(DataBuffer dataBuffer, AssertionError error) { - this.dataBuffer = dataBuffer; - this.error = error; - } - - DataBuffer getDataBuffer() { - return this.dataBuffer; - } - - AssertionError getError() { - return this.error; - } - } - - /** * Store all intercepted incoming and outgoing payloads and then use * {@link #checkForLeaks()} at the end to check reference counts. diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractorTests.java new file mode 100644 index 00000000000..44e13c69c64 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractorTests.java @@ -0,0 +1,178 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.rsocket.annotation.support; + +import java.time.Duration; +import java.util.Map; + +import io.netty.buffer.PooledByteBufAllocator; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.BDDMockito; +import reactor.core.publisher.Mono; + +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.StringDecoder; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.messaging.rsocket.LeakAwareNettyDataBufferFactory; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.messaging.rsocket.annotation.support.MessagingRSocket.COMPOSITE_METADATA; +import static org.springframework.messaging.rsocket.annotation.support.MessagingRSocket.ROUTING; +import static org.springframework.messaging.rsocket.annotation.support.MetadataExtractor.ROUTE_KEY; +import static org.springframework.util.MimeTypeUtils.TEXT_HTML; +import static org.springframework.util.MimeTypeUtils.TEXT_PLAIN; +import static org.springframework.util.MimeTypeUtils.TEXT_XML; + + +/** + * Unit tests for {@link DefaultMetadataExtractor}. + * @author Rossen Stoyanchev + */ +public class DefaultMetadataExtractorTests { + + private RSocketStrategies strategies; + + private ArgumentCaptor captor; + + private RSocket rsocket; + + private DefaultMetadataExtractor extractor; + + + @Before + public void setUp() { + this.strategies = RSocketStrategies.builder() + .decoder(StringDecoder.allMimeTypes()) + .encoder(CharSequenceEncoder.allMimeTypes()) + .dataBufferFactory(new LeakAwareNettyDataBufferFactory(PooledByteBufAllocator.DEFAULT)) + .build(); + + this.rsocket = BDDMockito.mock(RSocket.class); + this.captor = ArgumentCaptor.forClass(Payload.class); + BDDMockito.when(this.rsocket.fireAndForget(captor.capture())).thenReturn(Mono.empty()); + + this.extractor = new DefaultMetadataExtractor(this.strategies); + } + + @After + public void tearDown() throws InterruptedException { + DataBufferFactory bufferFactory = this.strategies.dataBufferFactory(); + ((LeakAwareNettyDataBufferFactory) bufferFactory).checkForLeaks(Duration.ofSeconds(5)); + } + + + @Test + public void compositeMetadataWithDefaultSettings() { + + requester(COMPOSITE_METADATA).route("toA") + .metadata("text data", TEXT_PLAIN) + .metadata("html data", TEXT_HTML) + .metadata("xml data", TEXT_XML) + .data("data") + .send().block(); + + Payload payload = this.captor.getValue(); + Map result = this.extractor.extract(payload, COMPOSITE_METADATA); + payload.release(); + + assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); + } + + @Test + public void compositeMetadataWithMimeTypeRegistrations() { + + this.extractor.metadataToExtract(TEXT_PLAIN, String.class, "text-entry"); + this.extractor.metadataToExtract(TEXT_HTML, String.class, "html-entry"); + this.extractor.metadataToExtract(TEXT_XML, String.class, "xml-entry"); + + requester(COMPOSITE_METADATA).route("toA") + .metadata("text data", TEXT_PLAIN) + .metadata("html data", TEXT_HTML) + .metadata("xml data", TEXT_XML) + .data("data") + .send() + .block(); + + Payload payload = this.captor.getValue(); + Map result = this.extractor.extract(payload, COMPOSITE_METADATA); + payload.release(); + + assertThat(result).hasSize(4) + .containsEntry(ROUTE_KEY, "toA") + .containsEntry("text-entry", "text data") + .containsEntry("html-entry", "html data") + .containsEntry("xml-entry", "xml data"); + } + + @Test + public void route() { + + requester(ROUTING).route("toA").data("data").send().block(); + Payload payload = this.captor.getValue(); + Map result = this.extractor.extract(payload, ROUTING); + payload.release(); + + assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); + } + + @Test + public void routeAsText() { + + this.extractor.metadataToExtract(TEXT_PLAIN, String.class, ROUTE_KEY); + + requester(TEXT_PLAIN).route("toA").data("data").send().block(); + Payload payload = this.captor.getValue(); + Map result = this.extractor.extract(payload, TEXT_PLAIN); + payload.release(); + + assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); + } + + @Test + public void routeWithCustomFormatting() { + + this.extractor.metadataToExtract(TEXT_PLAIN, String.class, (text, result) -> { + String[] items = text.split(":"); + Assert.isTrue(items.length == 2, "Expected two items"); + result.put(ROUTE_KEY, items[0]); + result.put("entry1", items[1]); + }); + + requester(TEXT_PLAIN).metadata("toA:text data", null).data("data").send().block(); + Payload payload = this.captor.getValue(); + Map result = this.extractor.extract(payload, TEXT_PLAIN); + payload.release(); + + assertThat(result).hasSize(2) + .containsEntry(ROUTE_KEY, "toA") + .containsEntry("entry1", "text data"); + } + + + private RSocketRequester requester(MimeType metadataMimeType) { + return RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, metadataMimeType, this.strategies); + } + +}