diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java index 8485afa4058..a715933a0d9 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java @@ -15,11 +15,11 @@ */ package org.springframework.messaging.rsocket; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.function.BiConsumer; @@ -28,6 +28,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.Payload; import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.RoutingMetadata; import io.rsocket.metadata.WellKnownMimeType; import org.springframework.core.ParameterizedTypeReference; @@ -36,6 +37,7 @@ import org.springframework.core.codec.Decoder; import org.springframework.core.io.buffer.NettyDataBuffer; import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.MimeType; /** @@ -166,14 +168,19 @@ public class DefaultMetadataExtractor implements MetadataExtractor { } private void extractEntry(ByteBuf content, @Nullable String mimeType, Map result) { + if (content.readableBytes() == 0) { + return; + } EntryExtractor extractor = this.registrations.get(mimeType); if (extractor != null) { extractor.extract(content, result); return; } if (mimeType != null && mimeType.equals(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString())) { - // TODO: use rsocket-core API when available - result.put(MetadataExtractor.ROUTE_KEY, content.toString(StandardCharsets.UTF_8)); + Iterator iterator = new RoutingMetadata(content).iterator(); + if (iterator.hasNext()) { + result.put(MetadataExtractor.ROUTE_KEY, iterator.next()); + } } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java index f5f2e289f44..df5f3407d84 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java @@ -15,16 +15,17 @@ */ package org.springframework.messaging.rsocket; -import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; +import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; import io.rsocket.metadata.CompositeMetadataFlyweight; +import io.rsocket.metadata.TaggingMetadataFlyweight; import io.rsocket.metadata.WellKnownMimeType; import org.springframework.core.ResolvableType; @@ -161,25 +162,15 @@ final class MetadataEncoder { public DataBuffer encode() { if (this.isComposite) { CompositeByteBuf composite = this.allocator.compositeBuffer(); - if (this.route != null) { - CompositeMetadataFlyweight.encodeAndAddMetadata(composite, this.allocator, - WellKnownMimeType.MESSAGE_RSOCKET_ROUTING, - PayloadUtils.asByteBuf(bufferFactory().wrap(this.route.getBytes(StandardCharsets.UTF_8)))); - } try { - this.metadata.forEach((value, mimeType) -> { - DataBuffer buffer = encodeEntry(value, mimeType); - CompositeMetadataFlyweight.encodeAndAddMetadata( - composite, this.allocator, mimeType.toString(), PayloadUtils.asByteBuf(buffer)); - }); - if (bufferFactory() instanceof NettyDataBufferFactory) { - return ((NettyDataBufferFactory) bufferFactory()).wrap(composite); - } - else { - DataBuffer buffer = bufferFactory().wrap(composite.nioBuffer()); - composite.release(); - return buffer; + if (this.route != null) { + CompositeMetadataFlyweight.encodeAndAddMetadata(composite, this.allocator, + WellKnownMimeType.MESSAGE_RSOCKET_ROUTING, encodeRoute()); } + this.metadata.forEach((value, mimeType) -> + CompositeMetadataFlyweight.encodeAndAddMetadata(composite, this.allocator, + mimeType.toString(), PayloadUtils.asByteBuf(encodeEntry(value, mimeType)))); + return asDataBuffer(composite); } catch (Throwable ex) { composite.release(); @@ -189,7 +180,7 @@ final class MetadataEncoder { else if (this.route != null) { Assert.isTrue(this.metadata.isEmpty(), "Composite metadata required for route and other entries"); return this.metadataMimeType.toString().equals(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()) ? - bufferFactory().wrap(this.route.getBytes(StandardCharsets.UTF_8)) : + asDataBuffer(encodeRoute()) : encodeEntry(this.route, this.metadataMimeType); } else { @@ -204,6 +195,11 @@ final class MetadataEncoder { } } + private ByteBuf encodeRoute() { + return TaggingMetadataFlyweight.createRoutingMetadata( + this.allocator, Collections.singletonList(this.route)).getContent(); + } + @SuppressWarnings("unchecked") private DataBuffer encodeEntry(Object metadata, MimeType mimeType) { if (metadata instanceof DataBuffer) { @@ -215,4 +211,14 @@ final class MetadataEncoder { return encoder.encodeValue((T) metadata, bufferFactory(), type, mimeType, Collections.emptyMap()); } + private DataBuffer asDataBuffer(ByteBuf byteBuf) { + if (bufferFactory() instanceof NettyDataBufferFactory) { + return ((NettyDataBufferFactory) bufferFactory()).wrap(byteBuf); + } + else { + DataBuffer buffer = bufferFactory().wrap(byteBuf.nioBuffer()); + byteBuf.release(); + return buffer; + } + } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java index 41a9204e38d..23da144de34 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java @@ -19,6 +19,7 @@ import java.util.Map; import io.rsocket.Payload; +import org.springframework.core.codec.DecodingException; import org.springframework.util.MimeType; /** @@ -45,6 +46,8 @@ public interface MetadataExtractor { * @param payload the payload whose metadata should be read * @param metadataMimeType the metadata MimeType for the connection. * @return name values pairs extracted from the metadata + * @throws DecodingException if the metadata cannot be decoded + * @throws IllegalArgumentException if routing metadata cannot be decoded */ Map extract(Payload payload, MimeType metadataMimeType); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java index a664e685fd3..00c474a5aaa 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java @@ -23,6 +23,7 @@ import java.util.Map; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.RoutingMetadata; import io.rsocket.metadata.WellKnownMimeType; import org.junit.jupiter.api.Test; @@ -63,8 +64,7 @@ public class MetadataEncoderTests { assertThat(iterator.hasNext()).isTrue(); CompositeMetadata.Entry entry = iterator.next(); assertThat(entry.getMimeType()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); - assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("toA"); - + assertRoute("toA", entry.getContent()); assertThat(iterator.hasNext()).isFalse(); } @@ -82,7 +82,7 @@ public class MetadataEncoderTests { assertThat(iterator.hasNext()).isTrue(); CompositeMetadata.Entry entry = iterator.next(); assertThat(entry.getMimeType()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); - assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("toA"); + assertRoute("toA", entry.getContent()); assertThat(iterator.hasNext()).isTrue(); entry = iterator.next(); @@ -102,7 +102,7 @@ public class MetadataEncoderTests { .route("toA") .encode(); - assertThat(dumpString(buffer)).isEqualTo("toA"); + assertRoute("toA", ((NettyDataBuffer) buffer).getNativeBuffer()); } @Test @@ -196,12 +196,19 @@ public class MetadataEncoderTests { assertThat(iterator.hasNext()).isTrue(); CompositeMetadata.Entry entry = iterator.next(); assertThat(entry.getMimeType()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); - assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("toA"); + assertRoute("toA", entry.getContent()); assertThat(iterator.hasNext()).isFalse(); } + private void assertRoute(String route, ByteBuf metadata) { + Iterator tags = new RoutingMetadata(metadata).iterator(); + assertThat(tags.hasNext()).isTrue(); + assertThat(tags.next()).isEqualTo(route); + assertThat(tags.hasNext()).isFalse(); + } + private String dumpString(DataBuffer buffer) { return DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java index 93bb125ab9d..3b823757d4e 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java @@ -21,6 +21,7 @@ import java.time.Duration; import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.metadata.WellKnownMimeType; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import org.junit.jupiter.api.AfterAll; @@ -38,6 +39,8 @@ import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler; import org.springframework.stereotype.Controller; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -61,6 +64,9 @@ public class RSocketClientToServerIntegrationTests { @SuppressWarnings("ConstantConditions") public static void setupOnce() { + MimeType metadataMimeType = MimeTypeUtils.parseMimeType( + WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + context = new AnnotationConfigApplicationContext(ServerConfig.class); RSocketMessageHandler messageHandler = context.getBean(RSocketMessageHandler.class); SocketAcceptor responder = messageHandler.responder(); @@ -74,6 +80,7 @@ public class RSocketClientToServerIntegrationTests { .block(); requester = RSocketRequester.builder() + .metadataMimeType(metadataMimeType) .rsocketStrategies(context.getBean(RSocketStrategies.class)) .connectTcp("localhost", 7000) .block();