diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java index ef27e95f40c..6816b796d88 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java @@ -16,10 +16,8 @@ package org.springframework.messaging.rsocket; -import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; import io.netty.buffer.ByteBuf; @@ -59,8 +57,6 @@ final class DefaultRSocketRequester implements RSocketRequester { static final MimeType ROUTING = new MimeType("message", "x.rsocket.routing.v0"); - static final List METADATA_MIME_TYPES = Arrays.asList(COMPOSITE_METADATA, ROUTING); - private static final Map EMPTY_HINTS = Collections.emptyMap(); @@ -85,9 +81,6 @@ final class DefaultRSocketRequester implements RSocketRequester { Assert.notNull(metadataMimeType, "'metadataMimeType' is required"); Assert.notNull(strategies, "RSocketStrategies is required"); - Assert.isTrue(METADATA_MIME_TYPES.contains(metadataMimeType), - () -> "Unexpected metadatata mime type: '" + metadataMimeType + "'"); - this.rsocket = rsocket; this.dataMimeType = dataMimeType; this.metadataMimeType = metadataMimeType; @@ -113,7 +106,13 @@ final class DefaultRSocketRequester implements RSocketRequester { @Override public RequestSpec route(String route) { - return new DefaultRequestSpec(route); + Assert.notNull(route, "'route' is required"); + return new DefaultRequestSpec(route, metadataMimeType().equals(COMPOSITE_METADATA) ? ROUTING : null); + } + + @Override + public RequestSpec metadata(Object metadata, @Nullable MimeType mimeType) { + return new DefaultRequestSpec(metadata, mimeType); } @@ -131,16 +130,22 @@ final class DefaultRSocketRequester implements RSocketRequester { private final Map metadata = new LinkedHashMap<>(4); - public DefaultRequestSpec(String route) { - Assert.notNull(route, "'route' is required"); - metadata(route, ROUTING); + public DefaultRequestSpec(Object metadata, @Nullable MimeType mimeType) { + mimeType = (mimeType == null && !isCompositeMetadata() ? metadataMimeType() : mimeType); + Assert.notNull(mimeType, "MimeType is required for composite metadata"); + metadata(metadata, mimeType); } + private boolean isCompositeMetadata() { + return metadataMimeType().equals(COMPOSITE_METADATA); + } @Override public RequestSpec metadata(Object metadata, MimeType mimeType) { - Assert.isTrue(this.metadata.isEmpty() || metadataMimeType().equals(COMPOSITE_METADATA), - "Additional metadata entries supported only with composite metadata"); + Assert.notNull(metadata, "Metadata content is required"); + Assert.notNull(mimeType, "MimeType is required"); + Assert.isTrue(this.metadata.isEmpty() || isCompositeMetadata(), + "Composite metadata required for multiple metadata entries."); this.metadata.put(metadata, mimeType); return this; } @@ -250,22 +255,27 @@ final class DefaultRSocketRequester implements RSocketRequester { } private DataBuffer getMetadata() { - if (metadataMimeType().equals(COMPOSITE_METADATA)) { + if (isCompositeMetadata()) { CompositeByteBuf metadata = getAllocator().compositeBuffer(); - this.metadata.forEach((key, value) -> { - DataBuffer dataBuffer = encodeMetadata(key, value); - CompositeMetadataFlyweight.encodeAndAddMetadata(metadata, getAllocator(), value.toString(), + this.metadata.forEach((value, mimeType) -> { + DataBuffer dataBuffer = encodeMetadata(value, mimeType); + CompositeMetadataFlyweight.encodeAndAddMetadata(metadata, getAllocator(), mimeType.toString(), dataBuffer instanceof NettyDataBuffer ? ((NettyDataBuffer) dataBuffer).getNativeBuffer() : Unpooled.wrappedBuffer(dataBuffer.asByteBuffer())); }); return asDataBuffer(metadata); } - Assert.isTrue(this.metadata.size() < 2, "Composite metadata required for multiple entries"); - Map.Entry entry = this.metadata.entrySet().iterator().next(); - Assert.isTrue(metadataMimeType().equals(entry.getValue()), - () -> "Expected metadata MimeType '" + metadataMimeType() + "', actual " + this.metadata); - return encodeMetadata(entry.getKey(), entry.getValue()); + else { + Assert.isTrue(this.metadata.size() == 1, "Composite metadata required for multiple entries"); + Map.Entry entry = this.metadata.entrySet().iterator().next(); + if (!metadataMimeType().equals(entry.getValue())) { + throw new IllegalArgumentException( + "Connection configured for metadata mime type " + + "'" + metadataMimeType() + "', but actual is `" + this.metadata + "`"); + } + return encodeMetadata(entry.getKey(), entry.getValue()); + } } @SuppressWarnings("unchecked") diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java index ab9f22b0f68..5481dd6bf83 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java @@ -67,14 +67,25 @@ public interface RSocketRequester { /** - * Begin to specify a new request with the given route to a handler on the - * remote side. The route will be encoded in the metadata of the first - * payload. + * Begin to specify a new request with the given route to a remote handler. + *

If the connection is set to use composite metadata, the route is + * encoded as {@code "message/x.rsocket.routing.v0"}. Otherwise the route + * is encoded according to the mime type for the connection. * @param route the route to a handler * @return a spec for further defining and executing the request */ RequestSpec route(String route); + /** + * Begin to specify a new request with the given metadata. + *

If using composite metadata then the mime type argument is required. + * Otherwise the mime type should be {@code null}, or it must match the + * mime type for the connection. + * @param metadata the metadata value to encode + * @param mimeType the mime type that describes the metadata; + */ + RequestSpec metadata(Object metadata, @Nullable MimeType mimeType); + /** * Obtain a builder for an {@link RSocketRequester} by connecting to an @@ -110,24 +121,24 @@ public interface RSocketRequester { interface Builder { /** - * Configure the MimeType to use for payload data. This is then - * specified on the {@code SETUP} frame for the whole connection. - *

By default this is set to the first concrete MimeType supported + * Configure the MimeType for payload data which is then specified + * on the {@code SETUP} frame and applies to the whole connection. + *

By default this is set to the first concrete mime type supported * by the configured encoders and decoders. * @param mimeType the data MimeType to use */ RSocketRequester.Builder dataMimeType(@Nullable MimeType mimeType); /** - * Configure the MimeType to use for payload metadata. This is then - * specified on the {@code SETUP} frame for the whole connection. - *

At present the metadata MimeType must be - * {@code "message/x.rsocket.routing.v0"} to allow the request - * {@link RSocketRequester#route(String) route} to be encoded, or it - * could also be {@code "message/x.rsocket.composite-metadata.v0"} in - * which case the route can be encoded along with other metadata entries. + * Configure the MimeType for payload metadata which is then specified + * on the {@code SETUP} frame and applies to the whole connection. *

By default this is set to - * {@code "message/x.rsocket.composite-metadata.v0"}. + * {@code "message/x.rsocket.composite-metadata.v0"} in which case the + * route, if provided, is encoded as a + * {@code "message/x.rsocket.routing.v0"} metadata entry, potentially + * with other metadata entries added too. If this is set to any other + * mime type, and a route is provided, it is assumed the mime type is + * for the route. * @param mimeType the data MimeType to use */ RSocketRequester.Builder metadataMimeType(MimeType mimeType); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractor.java new file mode 100644 index 00000000000..bc008862d13 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractor.java @@ -0,0 +1,209 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.rsocket.annotation.support; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiConsumer; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.metadata.CompositeMetadata; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Decoder; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.lang.Nullable; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +/** + * Default {@link MetadataExtractor} implementation that relies on {@link Decoder}s + * to deserialize the content of metadata entries. + * + *

By default only {@code "message/x.rsocket.routing.v0""} is extracted and + * saved under {@link MetadataExtractor#ROUTE_KEY}. Use the + * {@code metadataToExtract} methods to specify other metadata mime types of + * interest to extract. + * + * @author Rossen Stoyanchev + * @since 5.2 + */ +public class DefaultMetadataExtractor implements MetadataExtractor { + + private final RSocketStrategies rsocketStrategies; + + private final Map> entryProcessors = new HashMap<>(); + + + /** + * Default constructor with {@link RSocketStrategies}. + */ + public DefaultMetadataExtractor(RSocketStrategies strategies) { + Assert.notNull(strategies, "RSocketStrategies is required"); + this.rsocketStrategies = strategies; + // TODO: remove when rsocket-core API available + metadataToExtract(MessagingRSocket.ROUTING, String.class, ROUTE_KEY); + } + + + /** + * Decode metadata entries with the given {@link MimeType} to the specified + * target class, and store the decoded value in the output map under the + * given name. + * @param mimeType the mime type of metadata entries to extract + * @param targetType the target value type to decode to + * @param name assign a name for the decoded value; if not provided, then + * the mime type is used as the key + */ + public void metadataToExtract( + MimeType mimeType, Class targetType, @Nullable String name) { + + String key = name != null ? name : mimeType.toString(); + metadataToExtract(mimeType, targetType, (value, map) -> map.put(key, value)); + } + + /** + * Variant of {@link #metadataToExtract(MimeType, Class, String)} that accepts + * {@link ParameterizedTypeReference} instead of {@link Class} for + * specifying a target type with generic parameters. + */ + public void metadataToExtract( + MimeType mimeType, ParameterizedTypeReference targetType, @Nullable String name) { + + String key = name != null ? name : mimeType.toString(); + metadataToExtract(mimeType, targetType, (value, map) -> map.put(key, value)); + } + + /** + * Variant of {@link #metadataToExtract(MimeType, Class, String)} that allows + * custom logic to be used to map the decoded value to any number of values + * in the output map. + * @param mimeType the mime type of metadata entries to extract + * @param targetType the target value type to decode to + * @param mapper custom logic to add the decoded value to the output map + * @param the target value type + */ + public void metadataToExtract( + MimeType mimeType, Class targetType, + BiConsumer> mapper) { + + EntryProcessor spec = new EntryProcessor<>(mimeType, targetType, mapper); + this.entryProcessors.put(mimeType.toString(), spec); + } + + /** + * Variant of {@link #metadataToExtract(MimeType, Class, BiConsumer)} that + * accepts {@link ParameterizedTypeReference} instead of {@link Class} for + * specifying a target type with generic parameters. + * @param mimeType the mime type of metadata entries to extract + * @param targetType the target value type to decode to + * @param mapper custom logic to add the decoded value to the output map + * @param the target value type + */ + public void metadataToExtract( + MimeType mimeType, ParameterizedTypeReference targetType, + BiConsumer> mapper) { + + EntryProcessor spec = new EntryProcessor<>(mimeType, targetType, mapper); + this.entryProcessors.put(mimeType.toString(), spec); + } + + + @Override + public Map extract(Payload payload, MimeType metadataMimeType) { + Map result = new HashMap<>(); + if (metadataMimeType.equals(MessagingRSocket.COMPOSITE_METADATA)) { + for (CompositeMetadata.Entry entry : new CompositeMetadata(payload.metadata(), false)) { + processEntry(entry.getContent(), entry.getMimeType(), result); + } + } + else { + processEntry(payload.metadata(), metadataMimeType.toString(), result); + } + return result; + } + + private void processEntry(ByteBuf content, @Nullable String mimeType, Map result) { + EntryProcessor entryProcessor = this.entryProcessors.get(mimeType); + if (entryProcessor != null) { + content.retain(); + entryProcessor.process(content, result); + return; + } + if (MessagingRSocket.ROUTING.toString().equals(mimeType)) { + // TODO: use rsocket-core API when available + } + } + + + /** + * Helps to decode a metadata entry and add the resulting value to the + * output map. + */ + private class EntryProcessor { + + private final MimeType mimeType; + + private final ResolvableType targetType; + + private final BiConsumer> accumulator; + + private final Decoder decoder; + + + public EntryProcessor( + MimeType mimeType, Class targetType, + BiConsumer> accumulator) { + + this(mimeType, ResolvableType.forClass(targetType), accumulator); + } + + public EntryProcessor( + MimeType mimeType, ParameterizedTypeReference targetType, + BiConsumer> accumulator) { + + this(mimeType, ResolvableType.forType(targetType), accumulator); + } + + private EntryProcessor( + MimeType mimeType, ResolvableType targetType, + BiConsumer> accumulator) { + + this.mimeType = mimeType; + this.targetType = targetType; + this.accumulator = accumulator; + this.decoder = rsocketStrategies.decoder(targetType, mimeType); + } + + + public void process(ByteBuf byteBuf, Map result) { + DataBufferFactory factory = rsocketStrategies.dataBufferFactory(); + DataBuffer buffer = factory instanceof NettyDataBufferFactory ? + ((NettyDataBufferFactory) factory).wrap(byteBuf) : + factory.wrap(byteBuf.nioBuffer()); + + T value = this.decoder.decode(buffer, this.targetType, this.mimeType, Collections.emptyMap()); + this.accumulator.accept(value, result); + } + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java index 70eabdf59d7..7585993634b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java @@ -16,9 +16,7 @@ package org.springframework.messaging.rsocket.annotation.support; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; @@ -26,7 +24,6 @@ import io.rsocket.AbstractRSocket; import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.metadata.CompositeMetadata; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -39,6 +36,7 @@ import org.springframework.core.io.buffer.NettyDataBuffer; import org.springframework.lang.Nullable; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.ReactiveMessageHandler; import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; import org.springframework.messaging.rsocket.PayloadUtils; @@ -63,42 +61,42 @@ class MessagingRSocket extends AbstractRSocket { static final MimeType COMPOSITE_METADATA = new MimeType("message", "x.rsocket.composite-metadata.v0"); - private static final MimeType ROUTING = new MimeType("message", "x.rsocket.routing.v0"); + static final MimeType ROUTING = new MimeType("message", "x.rsocket.routing.v0"); - private static final List METADATA_MIME_TYPES = Arrays.asList(COMPOSITE_METADATA, ROUTING); + private final MimeType dataMimeType; - private final RSocketMessageHandler messageHandler; + private final MimeType metadataMimeType; - private final RouteMatcher routeMatcher; + private final MetadataExtractor metadataExtractor; - private final RSocketRequester requester; + private final ReactiveMessageHandler messageHandler; - private final MimeType dataMimeType; + private final RouteMatcher routeMatcher; - private final MimeType metadataMimeType; + private final RSocketRequester requester; private final DataBufferFactory bufferFactory; - MessagingRSocket(RSocketMessageHandler messageHandler, RouteMatcher routeMatcher, - RSocketRequester requester, MimeType dataMimeType, MimeType metadataMimeType, - DataBufferFactory bufferFactory) { + MessagingRSocket(MimeType dataMimeType, MimeType metadataMimeType, MetadataExtractor metadataExtractor, + RSocketRequester requester, ReactiveMessageHandler messageHandler, + RouteMatcher routeMatcher, DataBufferFactory bufferFactory) { - Assert.notNull(messageHandler, "'messageHandler' is required"); - Assert.notNull(routeMatcher, "'routeMatcher' is required"); - Assert.notNull(requester, "'requester' is required"); Assert.notNull(dataMimeType, "'dataMimeType' is required"); Assert.notNull(metadataMimeType, "'metadataMimeType' is required"); + Assert.notNull(metadataExtractor, "'metadataExtractor' is required"); + Assert.notNull(requester, "'requester' is required"); + Assert.notNull(messageHandler, "'messageHandler' is required"); + Assert.notNull(routeMatcher, "'routeMatcher' is required"); + Assert.notNull(bufferFactory, "'bufferFactory' is required"); - Assert.isTrue(METADATA_MIME_TYPES.contains(metadataMimeType), - () -> "Unexpected metadatata mime type: '" + metadataMimeType + "'"); - - this.messageHandler = messageHandler; - this.routeMatcher = routeMatcher; - this.requester = requester; this.dataMimeType = dataMimeType; this.metadataMimeType = metadataMimeType; + this.metadataExtractor = metadataExtractor; + this.requester = requester; + this.messageHandler = messageHandler; + this.routeMatcher = routeMatcher; this.bufferFactory = bufferFactory; } @@ -149,8 +147,7 @@ class MessagingRSocket extends AbstractRSocket { private Mono handle(Payload payload) { - String destination = getDestination(payload); - MessageHeaders headers = createHeaders(destination, null); + MessageHeaders headers = createHeaders(payload, null); DataBuffer dataBuffer = retainDataAndReleasePayload(payload); int refCount = refCount(dataBuffer); Message message = MessageBuilder.createMessage(dataBuffer, headers); @@ -169,8 +166,7 @@ class MessagingRSocket extends AbstractRSocket { private Flux handleAndReply(Payload firstPayload, Flux payloads) { MonoProcessor> replyMono = MonoProcessor.create(); - String destination = getDestination(firstPayload); - MessageHeaders headers = createHeaders(destination, replyMono); + MessageHeaders headers = createHeaders(firstPayload, replyMono); AtomicBoolean read = new AtomicBoolean(); Flux buffers = payloads.map(this::retainDataAndReleasePayload).doOnSubscribe(s -> read.set(true)); @@ -188,39 +184,33 @@ class MessagingRSocket extends AbstractRSocket { Mono.error(new IllegalStateException("Something went wrong: reply Mono not set")))); } - private String getDestination(Payload payload) { - if (this.metadataMimeType.equals(COMPOSITE_METADATA)) { - CompositeMetadata metadata = new CompositeMetadata(payload.metadata(), false); - for (CompositeMetadata.Entry entry : metadata) { - String mimeType = entry.getMimeType(); - if (ROUTING.toString().equals(mimeType)) { - return entry.getContent().toString(StandardCharsets.UTF_8); - } - } - return ""; - } - else if (this.metadataMimeType.equals(ROUTING)) { - return payload.getMetadataUtf8(); - } - // Should not happen (given constructor assertions) - throw new IllegalArgumentException("Unexpected metadataMimeType"); - } - private DataBuffer retainDataAndReleasePayload(Payload payload) { return PayloadUtils.retainDataAndReleasePayload(payload, this.bufferFactory); } - private MessageHeaders createHeaders(String destination, @Nullable MonoProcessor replyMono) { + private MessageHeaders createHeaders(Payload payload, @Nullable MonoProcessor replyMono) { MessageHeaderAccessor headers = new MessageHeaderAccessor(); headers.setLeaveMutable(true); - RouteMatcher.Route route = this.routeMatcher.parseRoute(destination); - headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, route); + + Map metadataValues = this.metadataExtractor.extract(payload, this.metadataMimeType); + metadataValues.putIfAbsent(MetadataExtractor.ROUTE_KEY, ""); + for (Map.Entry entry : metadataValues.entrySet()) { + if (entry.getKey().equals(MetadataExtractor.ROUTE_KEY)) { + RouteMatcher.Route route = this.routeMatcher.parseRoute((String) entry.getValue()); + headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, route); + } + else { + headers.setHeader(entry.getKey(), entry.getValue()); + } + } + headers.setContentType(this.dataMimeType); headers.setHeader(RSocketRequesterMethodArgumentResolver.RSOCKET_REQUESTER_HEADER, this.requester); if (replyMono != null) { headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono); } headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, this.bufferFactory); + return headers.getMessageHeaders(); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MetadataExtractor.java new file mode 100644 index 00000000000..ad48da526da --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MetadataExtractor.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.rsocket.annotation.support; + +import java.util.Map; + +import io.rsocket.Payload; + +import org.springframework.util.MimeType; + +/** + * Strategy to extract a map of values from the metadata of a {@link Payload}. + * This includes decoding metadata entries based on their mime type and + * assigning a name to the decoded value. The resulting name-value pairs can + * be added to the headers of a + * {@link org.springframework.messaging.Message Message}. + * + * @author Rossen Stoyanchev + * @since 5.2 + */ +public interface MetadataExtractor { + + /** + * The key to assign to the extracted "route" of the payload. + */ + String ROUTE_KEY = "route"; + + + /** + * Extract a map of values from the given {@link Payload} metadata. + *

Metadata may be composite and consist of multiple entries + * Implementations are free to extract any number of name-value pairs per + * metadata entry. The Payload "route" should be saved under the + * {@link #ROUTE_KEY}. + * @param payload the payload whose metadata should be read + * @param metadataMimeType the mime type of the metadata; this is what was + * specified by the client at the start of the RSocket connection. + * @return a map of 0 or more decoded metadata values with assigned names + */ + Map extract(Payload payload, MimeType metadataMimeType); + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java index c869c2542f3..817cc7a33a9 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java @@ -60,6 +60,9 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { @Nullable private RSocketStrategies rsocketStrategies; + @Nullable + private MetadataExtractor metadataExtractor; + @Nullable private MimeType defaultDataMimeType; @@ -88,31 +91,33 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { * {@link RSocketStrategies} encapsulates required configuration for re-use. * @param rsocketStrategies the strategies to use */ - public void setRSocketStrategies(RSocketStrategies rsocketStrategies) { - Assert.notNull(rsocketStrategies, "RSocketStrategies must not be null"); + public void setRSocketStrategies(@Nullable RSocketStrategies rsocketStrategies) { this.rsocketStrategies = rsocketStrategies; - setDecoders(rsocketStrategies.decoders()); - setEncoders(rsocketStrategies.encoders()); - setReactiveAdapterRegistry(rsocketStrategies.reactiveAdapterRegistry()); + if (rsocketStrategies != null) { + setDecoders(rsocketStrategies.decoders()); + setEncoders(rsocketStrategies.encoders()); + setReactiveAdapterRegistry(rsocketStrategies.reactiveAdapterRegistry()); + } } - /** - * Return the {@code RSocketStrategies} instance provided via - * {@link #setRSocketStrategies rsocketStrategies}, or - * otherwise initialize it with the configured {@link #setEncoders(List) - * encoders}, {@link #setDecoders(List) decoders}, and others. - */ + @Nullable public RSocketStrategies getRSocketStrategies() { - if (this.rsocketStrategies == null) { - this.rsocketStrategies = RSocketStrategies.builder() - .decoder(getDecoders().toArray(new Decoder[0])) - .encoder(getEncoders().toArray(new Encoder[0])) - .reactiveAdapterStrategy(getReactiveAdapterRegistry()) - .build(); - } return this.rsocketStrategies; } + /** + * Configure a {@link MetadataExtractor} to extract the route and possibly + * other metadata from the first payload of incoming requests. + *

By default this is a {@link DefaultMetadataExtractor} with the + * configured {@link RSocketStrategies} (and decoders), extracting a route + * from {@code "message/x.rsocket.routing.v0"} or {@code "text/plain"} + * metadata entries. + * @param extractor the extractor to use + */ + public void setMetadataExtractor(MetadataExtractor extractor) { + this.metadataExtractor = extractor; + } + /** * Configure the default content type to use for data payloads if the * {@code SETUP} frame did not specify one. @@ -137,6 +142,18 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { @Override public void afterPropertiesSet() { + if (this.rsocketStrategies == null) { + this.rsocketStrategies = RSocketStrategies.builder() + .decoder(getDecoders().toArray(new Decoder[0])) + .encoder(getEncoders().toArray(new Encoder[0])) + .reactiveAdapterStrategy(getReactiveAdapterRegistry()) + .build(); + } + if (this.metadataExtractor == null) { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(this.rsocketStrategies); + extractor.metadataToExtract(MimeTypeUtils.TEXT_PLAIN, String.class, MetadataExtractor.ROUTE_KEY); + this.metadataExtractor = extractor; + } getArgumentResolverConfigurer().addCustomResolver(new RSocketRequesterMethodArgumentResolver()); super.afterPropertiesSet(); } @@ -201,11 +218,14 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { MimeType metaMimeType = StringUtils.hasText(s) ? MimeTypeUtils.parseMimeType(s) : this.defaultMetadataMimeType; Assert.notNull(dataMimeType, "No `metadataMimeType` in ConnectionSetupPayload and no default value"); - RSocketRequester requester = RSocketRequester.wrap( - rsocket, dataMimeType, metaMimeType, getRSocketStrategies()); + RSocketStrategies strategies = this.rsocketStrategies; + Assert.notNull(strategies, "No RSocketStrategies. Was afterPropertiesSet not called?"); + RSocketRequester requester = RSocketRequester.wrap(rsocket, dataMimeType, metaMimeType, strategies); + + Assert.notNull(this.metadataExtractor, () -> "No MetadataExtractor. Was afterPropertiesSet not called?"); - return new MessagingRSocket(this, getRouteMatcher(), requester, - dataMimeType, metaMimeType, getRSocketStrategies().dataBufferFactory()); + return new MessagingRSocket(dataMimeType, metaMimeType, this.metadataExtractor, requester, + this, getRouteMatcher(), strategies.dataBufferFactory()); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java index 915499aba64..546844ffeaf 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java @@ -43,11 +43,14 @@ import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.lang.Nullable; import org.springframework.messaging.rsocket.RSocketRequester.RequestSpec; import org.springframework.messaging.rsocket.RSocketRequester.ResponseSpec; -import org.springframework.util.MimeTypeUtils; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.springframework.messaging.rsocket.DefaultRSocketRequester.COMPOSITE_METADATA; +import static org.springframework.messaging.rsocket.DefaultRSocketRequester.ROUTING; +import static org.springframework.util.MimeTypeUtils.TEXT_PLAIN; /** * Unit tests for {@link DefaultRSocketRequester}. @@ -75,9 +78,7 @@ public class DefaultRSocketRequesterTests { .encoder(CharSequenceEncoder.allMimeTypes()) .build(); this.rsocket = new TestRSocket(); - this.requester = RSocketRequester.wrap(this.rsocket, - MimeTypeUtils.TEXT_PLAIN, DefaultRSocketRequester.ROUTING, - this.strategies); + this.requester = RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, TEXT_PLAIN, this.strategies); } @@ -143,13 +144,32 @@ public class DefaultRSocketRequesterTests { } @Test - public void sendCompositeMetadata() { - RSocketRequester requester = RSocketRequester.wrap(this.rsocket, - MimeTypeUtils.TEXT_PLAIN, DefaultRSocketRequester.COMPOSITE_METADATA, - this.strategies); + public void metadataCompositeWithRoute() { + + RSocketRequester requester = RSocketRequester.wrap( + this.rsocket, TEXT_PLAIN, COMPOSITE_METADATA, this.strategies); + + requester.route("toA").data("bodyA").send().block(Duration.ofSeconds(5)); + + CompositeMetadata entries = new CompositeMetadata(this.rsocket.getSavedPayload().metadata(), false); + Iterator iterator = entries.iterator(); + + assertThat(iterator.hasNext()).isTrue(); + CompositeMetadata.Entry entry = iterator.next(); + assertThat(entry.getMimeType()).isEqualTo(ROUTING.toString()); + assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("toA"); + + assertThat(iterator.hasNext()).isFalse(); + } + + @Test + public void metadataCompositeWithRouteAndTextEntry() { + + RSocketRequester requester = RSocketRequester.wrap( + this.rsocket, TEXT_PLAIN, COMPOSITE_METADATA, this.strategies); requester.route("toA") - .metadata("My metadata", MimeTypeUtils.TEXT_PLAIN).data("bodyA") + .metadata("My metadata", TEXT_PLAIN).data("bodyA") .send() .block(Duration.ofSeconds(5)); @@ -158,27 +178,46 @@ public class DefaultRSocketRequesterTests { assertThat(iterator.hasNext()).isTrue(); CompositeMetadata.Entry entry = iterator.next(); - assertThat(entry.getMimeType()).isEqualTo(DefaultRSocketRequester.ROUTING.toString()); + assertThat(entry.getMimeType()).isEqualTo(ROUTING.toString()); assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("toA"); assertThat(iterator.hasNext()).isTrue(); entry = iterator.next(); - assertThat(entry.getMimeType()).isEqualTo(MimeTypeUtils.TEXT_PLAIN.toString()); + assertThat(entry.getMimeType()).isEqualTo(TEXT_PLAIN.toString()); assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("My metadata"); assertThat(iterator.hasNext()).isFalse(); } + @Test + public void metadataRouteAsText() { + RSocketRequester requester = RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, TEXT_PLAIN, this.strategies); + requester.route("toA").data("bodyA").send().block(Duration.ofSeconds(5)); + assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("toA"); + } + + @Test + public void metadataAsText() { + RSocketRequester requester = RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, TEXT_PLAIN, this.strategies); + requester.metadata("toA", null).data("bodyA").send().block(Duration.ofSeconds(5)); + assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("toA"); + } + + @Test + public void metadataMimeTypeMismatch() { + RSocketRequester requester = RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, TEXT_PLAIN, this.strategies); + assertThatThrownBy(() -> requester.metadata("toA", ROUTING).data("bodyA").send().block()) + .hasMessageStartingWith("Connection configured for metadata mime type"); + } + @Test public void supportedMetadataMimeTypes() { - RSocketRequester.wrap(this.rsocket, MimeTypeUtils.TEXT_PLAIN, - DefaultRSocketRequester.COMPOSITE_METADATA, this.strategies); - RSocketRequester.wrap(this.rsocket, MimeTypeUtils.TEXT_PLAIN, - DefaultRSocketRequester.ROUTING, this.strategies); + RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, + COMPOSITE_METADATA, this.strategies); - assertThatIllegalArgumentException().isThrownBy(() -> RSocketRequester.wrap( - this.rsocket, MimeTypeUtils.TEXT_PLAIN, MimeTypeUtils.TEXT_PLAIN, this.strategies)); + RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, + ROUTING, this.strategies); } @Test diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/LeakAwareNettyDataBufferFactory.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/LeakAwareNettyDataBufferFactory.java new file mode 100644 index 00000000000..321ff4e3845 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/LeakAwareNettyDataBufferFactory.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.rsocket; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.NettyDataBuffer; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.core.io.buffer.PooledDataBuffer; +import org.springframework.util.ObjectUtils; + +/** + * Unlike {@link org.springframework.core.io.buffer.LeakAwareDataBufferFactory} + * this one is an instance of {@link NettyDataBufferFactory} which is necessary + * since {@link PayloadUtils} does instanceof checks, and that also allows + * intercepting {@link NettyDataBufferFactory#wrap(ByteBuf)}. + */ +public class LeakAwareNettyDataBufferFactory extends NettyDataBufferFactory { + + private final List created = new ArrayList<>(); + + + public LeakAwareNettyDataBufferFactory(ByteBufAllocator byteBufAllocator) { + super(byteBufAllocator); + } + + + public void checkForLeaks(Duration duration) throws InterruptedException { + Instant start = Instant.now(); + while (true) { + try { + this.created.forEach(info -> { + if (((PooledDataBuffer) info.getDataBuffer()).isAllocated()) { + throw info.getError(); + } + }); + break; + } + catch (AssertionError ex) { + if (Instant.now().isAfter(start.plus(duration))) { + throw ex; + } + } + Thread.sleep(50); + } + } + + public void reset() { + this.created.clear(); + } + + + @Override + public NettyDataBuffer allocateBuffer() { + return (NettyDataBuffer) recordHint(super.allocateBuffer()); + } + + @Override + public NettyDataBuffer allocateBuffer(int initialCapacity) { + return (NettyDataBuffer) recordHint(super.allocateBuffer(initialCapacity)); + } + + @Override + public NettyDataBuffer wrap(ByteBuf byteBuf) { + NettyDataBuffer dataBuffer = super.wrap(byteBuf); + if (byteBuf != Unpooled.EMPTY_BUFFER) { + recordHint(dataBuffer); + } + return dataBuffer; + } + + @Override + public DataBuffer join(List dataBuffers) { + return recordHint(super.join(dataBuffers)); + } + + private DataBuffer recordHint(DataBuffer buffer) { + AssertionError error = new AssertionError(String.format( + "DataBuffer leak: {%s} {%s} not released.%nStacktrace at buffer creation: ", buffer, + ObjectUtils.getIdentityHexString(((NettyDataBuffer) buffer).getNativeBuffer()))); + this.created.add(new DataBufferLeakInfo(buffer, error)); + return buffer; + } + + + private static class DataBufferLeakInfo { + + private final DataBuffer dataBuffer; + + private final AssertionError error; + + DataBufferLeakInfo(DataBuffer dataBuffer, AssertionError error) { + this.dataBuffer = dataBuffer; + this.error = error; + } + + DataBuffer getDataBuffer() { + return this.dataBuffer; + } + + AssertionError getError() { + return this.error; + } + } +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java index 215e0f31a8e..93e34ad77e2 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java @@ -18,14 +18,10 @@ package org.springframework.messaging.rsocket; import java.time.Duration; import java.time.Instant; -import java.util.ArrayList; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.PooledByteBufAllocator; -import io.netty.buffer.Unpooled; import io.netty.util.ReferenceCounted; import io.rsocket.AbstractRSocket; import io.rsocket.RSocket; @@ -51,16 +47,11 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.Resource; -import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.NettyDataBuffer; -import org.springframework.core.io.buffer.NettyDataBufferFactory; -import org.springframework.core.io.buffer.PooledDataBuffer; import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.Payload; import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; import org.springframework.stereotype.Controller; -import org.springframework.util.ObjectUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -232,100 +223,6 @@ public class RSocketBufferLeakTests { } - /** - * Unlike {@link org.springframework.core.io.buffer.LeakAwareDataBufferFactory} - * this one is an instance of {@link NettyDataBufferFactory} which is necessary - * since {@link PayloadUtils} does instanceof checks, and that also allows - * intercepting {@link NettyDataBufferFactory#wrap(ByteBuf)}. - */ - private static class LeakAwareNettyDataBufferFactory extends NettyDataBufferFactory { - - private final List created = new ArrayList<>(); - - LeakAwareNettyDataBufferFactory(ByteBufAllocator byteBufAllocator) { - super(byteBufAllocator); - } - - void checkForLeaks(Duration duration) throws InterruptedException { - Instant start = Instant.now(); - while (true) { - try { - this.created.forEach(info -> { - if (((PooledDataBuffer) info.getDataBuffer()).isAllocated()) { - throw info.getError(); - } - }); - break; - } - catch (AssertionError ex) { - if (Instant.now().isAfter(start.plus(duration))) { - throw ex; - } - } - Thread.sleep(50); - } - } - - void reset() { - this.created.clear(); - } - - - @Override - public NettyDataBuffer allocateBuffer() { - return (NettyDataBuffer) recordHint(super.allocateBuffer()); - } - - @Override - public NettyDataBuffer allocateBuffer(int initialCapacity) { - return (NettyDataBuffer) recordHint(super.allocateBuffer(initialCapacity)); - } - - @Override - public NettyDataBuffer wrap(ByteBuf byteBuf) { - NettyDataBuffer dataBuffer = super.wrap(byteBuf); - if (byteBuf != Unpooled.EMPTY_BUFFER) { - recordHint(dataBuffer); - } - return dataBuffer; - } - - @Override - public DataBuffer join(List dataBuffers) { - return recordHint(super.join(dataBuffers)); - } - - private DataBuffer recordHint(DataBuffer buffer) { - AssertionError error = new AssertionError(String.format( - "DataBuffer leak: {%s} {%s} not released.%nStacktrace at buffer creation: ", buffer, - ObjectUtils.getIdentityHexString(((NettyDataBuffer) buffer).getNativeBuffer()))); - this.created.add(new DataBufferLeakInfo(buffer, error)); - return buffer; - } - } - - - private static class DataBufferLeakInfo { - - private final DataBuffer dataBuffer; - - private final AssertionError error; - - DataBufferLeakInfo(DataBuffer dataBuffer, AssertionError error) { - this.dataBuffer = dataBuffer; - this.error = error; - } - - DataBuffer getDataBuffer() { - return this.dataBuffer; - } - - AssertionError getError() { - return this.error; - } - } - - /** * Store all intercepted incoming and outgoing payloads and then use * {@link #checkForLeaks()} at the end to check reference counts. diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractorTests.java new file mode 100644 index 00000000000..44e13c69c64 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/DefaultMetadataExtractorTests.java @@ -0,0 +1,178 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.rsocket.annotation.support; + +import java.time.Duration; +import java.util.Map; + +import io.netty.buffer.PooledByteBufAllocator; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.BDDMockito; +import reactor.core.publisher.Mono; + +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.StringDecoder; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.messaging.rsocket.LeakAwareNettyDataBufferFactory; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.messaging.rsocket.annotation.support.MessagingRSocket.COMPOSITE_METADATA; +import static org.springframework.messaging.rsocket.annotation.support.MessagingRSocket.ROUTING; +import static org.springframework.messaging.rsocket.annotation.support.MetadataExtractor.ROUTE_KEY; +import static org.springframework.util.MimeTypeUtils.TEXT_HTML; +import static org.springframework.util.MimeTypeUtils.TEXT_PLAIN; +import static org.springframework.util.MimeTypeUtils.TEXT_XML; + + +/** + * Unit tests for {@link DefaultMetadataExtractor}. + * @author Rossen Stoyanchev + */ +public class DefaultMetadataExtractorTests { + + private RSocketStrategies strategies; + + private ArgumentCaptor captor; + + private RSocket rsocket; + + private DefaultMetadataExtractor extractor; + + + @Before + public void setUp() { + this.strategies = RSocketStrategies.builder() + .decoder(StringDecoder.allMimeTypes()) + .encoder(CharSequenceEncoder.allMimeTypes()) + .dataBufferFactory(new LeakAwareNettyDataBufferFactory(PooledByteBufAllocator.DEFAULT)) + .build(); + + this.rsocket = BDDMockito.mock(RSocket.class); + this.captor = ArgumentCaptor.forClass(Payload.class); + BDDMockito.when(this.rsocket.fireAndForget(captor.capture())).thenReturn(Mono.empty()); + + this.extractor = new DefaultMetadataExtractor(this.strategies); + } + + @After + public void tearDown() throws InterruptedException { + DataBufferFactory bufferFactory = this.strategies.dataBufferFactory(); + ((LeakAwareNettyDataBufferFactory) bufferFactory).checkForLeaks(Duration.ofSeconds(5)); + } + + + @Test + public void compositeMetadataWithDefaultSettings() { + + requester(COMPOSITE_METADATA).route("toA") + .metadata("text data", TEXT_PLAIN) + .metadata("html data", TEXT_HTML) + .metadata("xml data", TEXT_XML) + .data("data") + .send().block(); + + Payload payload = this.captor.getValue(); + Map result = this.extractor.extract(payload, COMPOSITE_METADATA); + payload.release(); + + assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); + } + + @Test + public void compositeMetadataWithMimeTypeRegistrations() { + + this.extractor.metadataToExtract(TEXT_PLAIN, String.class, "text-entry"); + this.extractor.metadataToExtract(TEXT_HTML, String.class, "html-entry"); + this.extractor.metadataToExtract(TEXT_XML, String.class, "xml-entry"); + + requester(COMPOSITE_METADATA).route("toA") + .metadata("text data", TEXT_PLAIN) + .metadata("html data", TEXT_HTML) + .metadata("xml data", TEXT_XML) + .data("data") + .send() + .block(); + + Payload payload = this.captor.getValue(); + Map result = this.extractor.extract(payload, COMPOSITE_METADATA); + payload.release(); + + assertThat(result).hasSize(4) + .containsEntry(ROUTE_KEY, "toA") + .containsEntry("text-entry", "text data") + .containsEntry("html-entry", "html data") + .containsEntry("xml-entry", "xml data"); + } + + @Test + public void route() { + + requester(ROUTING).route("toA").data("data").send().block(); + Payload payload = this.captor.getValue(); + Map result = this.extractor.extract(payload, ROUTING); + payload.release(); + + assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); + } + + @Test + public void routeAsText() { + + this.extractor.metadataToExtract(TEXT_PLAIN, String.class, ROUTE_KEY); + + requester(TEXT_PLAIN).route("toA").data("data").send().block(); + Payload payload = this.captor.getValue(); + Map result = this.extractor.extract(payload, TEXT_PLAIN); + payload.release(); + + assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); + } + + @Test + public void routeWithCustomFormatting() { + + this.extractor.metadataToExtract(TEXT_PLAIN, String.class, (text, result) -> { + String[] items = text.split(":"); + Assert.isTrue(items.length == 2, "Expected two items"); + result.put(ROUTE_KEY, items[0]); + result.put("entry1", items[1]); + }); + + requester(TEXT_PLAIN).metadata("toA:text data", null).data("data").send().block(); + Payload payload = this.captor.getValue(); + Map result = this.extractor.extract(payload, TEXT_PLAIN); + payload.release(); + + assertThat(result).hasSize(2) + .containsEntry(ROUTE_KEY, "toA") + .containsEntry("entry1", "text data"); + } + + + private RSocketRequester requester(MimeType metadataMimeType) { + return RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, metadataMimeType, this.strategies); + } + +}