diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java index ada40d92eae..ee296fd6e85 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java @@ -31,7 +31,6 @@ import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; import org.springframework.util.MimeType; /** @@ -48,17 +47,13 @@ import org.springframework.util.MimeType; */ public class DefaultMetadataExtractor implements MetadataExtractor { - private final RSocketStrategies rsocketStrategies; - private final Map> entryProcessors = new HashMap<>(); /** * Default constructor with {@link RSocketStrategies}. */ - public DefaultMetadataExtractor(RSocketStrategies strategies) { - Assert.notNull(strategies, "RSocketStrategies is required"); - this.rsocketStrategies = strategies; + public DefaultMetadataExtractor() { // TODO: remove when rsocket-core API available metadataToExtract(MetadataExtractor.ROUTING, String.class, ROUTE_KEY); } @@ -128,24 +123,26 @@ public class DefaultMetadataExtractor implements MetadataExtractor { @Override - public Map extract(Payload payload, MimeType metadataMimeType) { + public Map extract(Payload payload, MimeType metadataMimeType, RSocketStrategies strategies) { Map result = new HashMap<>(); if (metadataMimeType.equals(COMPOSITE_METADATA)) { for (CompositeMetadata.Entry entry : new CompositeMetadata(payload.metadata(), false)) { - processEntry(entry.getContent(), entry.getMimeType(), result); + processEntry(entry.getContent(), entry.getMimeType(), result, strategies); } } else { - processEntry(payload.metadata(), metadataMimeType.toString(), result); + processEntry(payload.metadata(), metadataMimeType.toString(), result, strategies); } return result; } - private void processEntry(ByteBuf content, @Nullable String mimeType, Map result) { + private void processEntry(ByteBuf content, + @Nullable String mimeType, Map result, RSocketStrategies strategies) { + EntryProcessor entryProcessor = this.entryProcessors.get(mimeType); if (entryProcessor != null) { content.retain(); - entryProcessor.process(content, result); + entryProcessor.process(content, result, strategies); return; } if (MetadataExtractor.ROUTING.toString().equals(mimeType)) { @@ -166,8 +163,6 @@ public class DefaultMetadataExtractor implements MetadataExtractor { private final BiConsumer> accumulator; - private final Decoder decoder; - public EntryProcessor( MimeType mimeType, Class targetType, @@ -190,17 +185,17 @@ public class DefaultMetadataExtractor implements MetadataExtractor { this.mimeType = mimeType; this.targetType = targetType; this.accumulator = accumulator; - this.decoder = rsocketStrategies.decoder(targetType, mimeType); } - public void process(ByteBuf byteBuf, Map result) { - DataBufferFactory factory = rsocketStrategies.dataBufferFactory(); + public void process(ByteBuf byteBuf, Map result, RSocketStrategies strategies) { + DataBufferFactory factory = strategies.dataBufferFactory(); DataBuffer buffer = factory instanceof NettyDataBufferFactory ? ((NettyDataBufferFactory) factory).wrap(byteBuf) : factory.wrap(byteBuf.nioBuffer()); - T value = this.decoder.decode(buffer, this.targetType, this.mimeType, Collections.emptyMap()); + Decoder decoder = strategies.decoder(this.targetType, this.mimeType); + T value = decoder.decode(buffer, this.targetType, this.mimeType, Collections.emptyMap()); this.accumulator.accept(value, result); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java index 580b702da57..86e86a9bc36 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java @@ -58,8 +58,9 @@ public interface MetadataExtractor { * @param payload the payload whose metadata should be read * @param metadataMimeType the mime type of the metadata; this is what was * specified by the client at the start of the RSocket connection. + * @param strategies for access to codecs and a DataBufferFactory * @return a map of 0 or more decoded metadata values with assigned names */ - Map extract(Payload payload, MimeType metadataMimeType); + Map extract(Payload payload, MimeType metadataMimeType, RSocketStrategies strategies); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java index 1f5f7bd1367..d205e597f79 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java @@ -31,7 +31,6 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.NettyDataBuffer; import org.springframework.lang.Nullable; @@ -43,6 +42,7 @@ import org.springframework.messaging.handler.invocation.reactive.HandlerMethodRe import org.springframework.messaging.rsocket.MetadataExtractor; import org.springframework.messaging.rsocket.PayloadUtils; import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketStrategies; import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.Assert; @@ -73,12 +73,12 @@ class MessagingRSocket extends AbstractRSocket { private final RSocketRequester requester; - private final DataBufferFactory bufferFactory; + private final RSocketStrategies strategies; MessagingRSocket(MimeType dataMimeType, MimeType metadataMimeType, MetadataExtractor metadataExtractor, RSocketRequester requester, ReactiveMessageHandler messageHandler, - RouteMatcher routeMatcher, DataBufferFactory bufferFactory) { + RouteMatcher routeMatcher, RSocketStrategies strategies) { Assert.notNull(dataMimeType, "'dataMimeType' is required"); Assert.notNull(metadataMimeType, "'metadataMimeType' is required"); @@ -86,7 +86,7 @@ class MessagingRSocket extends AbstractRSocket { Assert.notNull(requester, "'requester' is required"); Assert.notNull(messageHandler, "'messageHandler' is required"); Assert.notNull(routeMatcher, "'routeMatcher' is required"); - Assert.notNull(bufferFactory, "'bufferFactory' is required"); + Assert.notNull(strategies, "RSocketStrategies is required"); this.dataMimeType = dataMimeType; this.metadataMimeType = metadataMimeType; @@ -94,7 +94,7 @@ class MessagingRSocket extends AbstractRSocket { this.requester = requester; this.messageHandler = messageHandler; this.routeMatcher = routeMatcher; - this.bufferFactory = bufferFactory; + this.strategies = strategies; } @@ -183,7 +183,7 @@ class MessagingRSocket extends AbstractRSocket { } private DataBuffer retainDataAndReleasePayload(Payload payload) { - return PayloadUtils.retainDataAndReleasePayload(payload, this.bufferFactory); + return PayloadUtils.retainDataAndReleasePayload(payload, this.strategies.dataBufferFactory()); } private MessageHeaders createHeaders(Payload payload, FrameType frameType, @@ -192,7 +192,9 @@ class MessagingRSocket extends AbstractRSocket { MessageHeaderAccessor headers = new MessageHeaderAccessor(); headers.setLeaveMutable(true); - Map metadataValues = this.metadataExtractor.extract(payload, this.metadataMimeType); + Map metadataValues = + this.metadataExtractor.extract(payload, this.metadataMimeType, this.strategies); + metadataValues.putIfAbsent(MetadataExtractor.ROUTE_KEY, ""); for (Map.Entry entry : metadataValues.entrySet()) { if (entry.getKey().equals(MetadataExtractor.ROUTE_KEY)) { @@ -210,7 +212,8 @@ class MessagingRSocket extends AbstractRSocket { if (replyMono != null) { headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono); } - headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, this.bufferFactory); + headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, + this.strategies.dataBufferFactory()); return headers.getMessageHeaders(); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java index dd9678272bf..7f7ba9d9a1e 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java @@ -208,7 +208,7 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { .build(); } if (this.metadataExtractor == null) { - DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(this.rsocketStrategies); + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); extractor.metadataToExtract(MimeTypeUtils.TEXT_PLAIN, String.class, MetadataExtractor.ROUTE_KEY); this.metadataExtractor = extractor; } @@ -318,7 +318,7 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { Assert.notNull(this.metadataExtractor, () -> "No MetadataExtractor. Was afterPropertiesSet not called?"); return new MessagingRSocket(dataMimeType, metadataMimeType, this.metadataExtractor, requester, - this, getRouteMatcher(), strategies.dataBufferFactory()); + this, getRouteMatcher(), strategies); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java index 8e489721340..a8af5523406 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java @@ -70,7 +70,7 @@ public class DefaultMetadataExtractorTests { this.captor = ArgumentCaptor.forClass(Payload.class); BDDMockito.when(this.rsocket.fireAndForget(captor.capture())).thenReturn(Mono.empty()); - this.extractor = new DefaultMetadataExtractor(this.strategies); + this.extractor = new DefaultMetadataExtractor(); } @After @@ -91,7 +91,7 @@ public class DefaultMetadataExtractorTests { .send().block(); Payload payload = this.captor.getValue(); - Map result = this.extractor.extract(payload, COMPOSITE_METADATA); + Map result = this.extractor.extract(payload, COMPOSITE_METADATA, this.strategies); payload.release(); assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); @@ -113,7 +113,7 @@ public class DefaultMetadataExtractorTests { .block(); Payload payload = this.captor.getValue(); - Map result = this.extractor.extract(payload, COMPOSITE_METADATA); + Map result = this.extractor.extract(payload, COMPOSITE_METADATA, this.strategies); payload.release(); assertThat(result).hasSize(4) @@ -128,7 +128,7 @@ public class DefaultMetadataExtractorTests { requester(ROUTING).route("toA").data("data").send().block(); Payload payload = this.captor.getValue(); - Map result = this.extractor.extract(payload, ROUTING); + Map result = this.extractor.extract(payload, ROUTING, this.strategies); payload.release(); assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); @@ -141,7 +141,7 @@ public class DefaultMetadataExtractorTests { requester(TEXT_PLAIN).route("toA").data("data").send().block(); Payload payload = this.captor.getValue(); - Map result = this.extractor.extract(payload, TEXT_PLAIN); + Map result = this.extractor.extract(payload, TEXT_PLAIN, this.strategies); payload.release(); assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); @@ -159,7 +159,7 @@ public class DefaultMetadataExtractorTests { requester(TEXT_PLAIN).metadata("toA:text data", null).data("data").send().block(); Payload payload = this.captor.getValue(); - Map result = this.extractor.extract(payload, TEXT_PLAIN); + Map result = this.extractor.extract(payload, TEXT_PLAIN, this.strategies); payload.release(); assertThat(result).hasSize(2)