diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java index 7a49b9c2813..bdaea4dff5a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java @@ -27,6 +27,7 @@ import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.client.WebsocketClientTransport; import reactor.core.publisher.Mono; +import org.springframework.lang.Nullable; import org.springframework.util.MimeType; /** @@ -39,6 +40,9 @@ final class DefaultRSocketRequesterBuilder implements RSocketRequester.Builder { private List> factoryConfigurers = new ArrayList<>(); + @Nullable + private RSocketStrategies strategies; + private List> strategiesConfigurers = new ArrayList<>(); @@ -48,6 +52,12 @@ final class DefaultRSocketRequesterBuilder implements RSocketRequester.Builder { return this; } + @Override + public RSocketRequester.Builder rsocketStrategies(@Nullable RSocketStrategies strategies) { + this.strategies = strategies; + return this; + } + @Override public RSocketRequester.Builder rsocketStrategies(Consumer configurer) { this.strategiesConfigurers.add(configurer); @@ -55,28 +65,54 @@ final class DefaultRSocketRequesterBuilder implements RSocketRequester.Builder { } @Override - public Mono connectTcp(String host, int port, MimeType dataMimeType) { - return connect(TcpClientTransport.create(host, port), dataMimeType); + public Mono connectTcp(String host, int port) { + return connect(TcpClientTransport.create(host, port)); } @Override - public Mono connectWebSocket(URI uri, MimeType dataMimeType) { - return connect(WebsocketClientTransport.create(uri), dataMimeType); + public Mono connectWebSocket(URI uri) { + return connect(WebsocketClientTransport.create(uri)); } @Override - public Mono connect(ClientTransport transport, MimeType dataMimeType) { + public Mono connect(ClientTransport transport) { return Mono.defer(() -> { - String mimeType = dataMimeType.toString(); - RSocketFactory.ClientRSocketFactory factory = RSocketFactory.connect().dataMimeType(mimeType); - this.factoryConfigurers.forEach(configurer -> configurer.accept(factory)); + RSocketStrategies strategies = getRSocketStrategies(); + MimeType dataMimeType = getDefaultDataMimeType(strategies); - RSocketStrategies.Builder builder = RSocketStrategies.builder(); - this.strategiesConfigurers.forEach(configurer -> configurer.accept(builder)); + RSocketFactory.ClientRSocketFactory factory = RSocketFactory.connect(); + if (dataMimeType != null) { + factory.dataMimeType(dataMimeType.toString()); + } + this.factoryConfigurers.forEach(configurer -> configurer.accept(factory)); return factory.transport(transport).start() - .map(rsocket -> RSocketRequester.create(rsocket, dataMimeType, builder.build())); + .map(rsocket -> RSocketRequester.create(rsocket, dataMimeType, strategies)); }); } + private RSocketStrategies getRSocketStrategies() { + if (this.strategiesConfigurers.isEmpty()) { + return this.strategies != null ? this.strategies : RSocketStrategies.builder().build(); + } + RSocketStrategies.Builder strategiesBuilder = this.strategies != null ? + this.strategies.mutate() : RSocketStrategies.builder(); + this.strategiesConfigurers.forEach(configurer -> configurer.accept(strategiesBuilder)); + return strategiesBuilder.build(); + } + + @Nullable + private MimeType getDefaultDataMimeType(RSocketStrategies strategies) { + return strategies.encoders().stream() + .flatMap(encoder -> encoder.getEncodableMimeTypes().stream()) + .filter(MimeType::isConcrete) + .findFirst() + .orElseGet(() -> + strategies.decoders().stream() + .flatMap(encoder -> encoder.getDecodableMimeTypes().stream()) + .filter(MimeType::isConcrete) + .findFirst() + .orElse(null)); + } + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java index 8c23f499d7f..5b18f14cb5b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java @@ -91,15 +91,25 @@ public interface RSocketRequester { interface Builder { /** - * Configure the {@code ClientRSocketFactory} to customize protocol - * options, register RSocket plugins (interceptors), and more. + * Configure the {@code ClientRSocketFactory}. + *

Note there is typically no need to set a data MimeType explicitly. + * By default a data MimeType is picked by taking the first concrete + * MimeType supported by the configured encoders and decoders. * @param configurer the configurer to apply */ RSocketRequester.Builder rsocketFactory(Consumer configurer); /** - * Configure the builder for {@link RSocketStrategies}. - *

The builder starts with an empty {@code RSocketStrategies}. + * Set the {@link RSocketStrategies} instance. + * @param strategies the strategies to use + */ + RSocketRequester.Builder rsocketStrategies(@Nullable RSocketStrategies strategies); + + /** + * Customize the {@link RSocketStrategies}. + *

By default this starts out with an empty builder, i.e. + * {@link RSocketStrategies#builder()}, but the strategies can also be + * set via {@link #rsocketStrategies(RSocketStrategies)}. * @param configurer the configurer to apply */ RSocketRequester.Builder rsocketStrategies(Consumer configurer); @@ -108,25 +118,23 @@ public interface RSocketRequester { * Connect to the RSocket server over TCP. * @param host the server host * @param port the server port - * @param dataMimeType the data MimeType for the connection * @return an {@code RSocketRequester} for the connection */ - Mono connectTcp(String host, int port, MimeType dataMimeType); + Mono connectTcp(String host, int port); /** * Connect to the RSocket server over WebSocket. * @param uri the RSocket server endpoint URI - * @param dataMimeType the data MimeType * @return an {@code RSocketRequester} for the connection */ - Mono connectWebSocket(URI uri, MimeType dataMimeType); + Mono connectWebSocket(URI uri); /** * Connect to the RSocket server with the given {@code ClientTransport}. * @param transport the client transport to use * @return an {@code RSocketRequester} for the connection */ - Mono connect(ClientTransport transport, MimeType dataMimeType); + Mono connect(ClientTransport transport); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java index c282accdd86..fa2393c13de 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java @@ -28,14 +28,8 @@ import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import org.springframework.util.MimeTypeUtils; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; /** * Unit tests for {@link DefaultRSocketRequesterBuilder}. @@ -46,6 +40,7 @@ public class DefaultRSocketRequesterBuilderTests { private ClientTransport transport; + @Before public void setup() { this.transport = mock(ClientTransport.class); @@ -57,10 +52,10 @@ public class DefaultRSocketRequesterBuilderTests { public void shouldApplyCustomizationsAtSubscription() { Consumer factoryConfigurer = mock(Consumer.class); Consumer strategiesConfigurer = mock(Consumer.class); - Mono requester = RSocketRequester.builder() + RSocketRequester.builder() .rsocketFactory(factoryConfigurer) .rsocketStrategies(strategiesConfigurer) - .connect(this.transport, MimeTypeUtils.APPLICATION_JSON); + .connect(this.transport); verifyZeroInteractions(this.transport, factoryConfigurer, strategiesConfigurer); } @@ -69,10 +64,10 @@ public class DefaultRSocketRequesterBuilderTests { public void shouldApplyCustomizations() { Consumer factoryConfigurer = mock(Consumer.class); Consumer strategiesConfigurer = mock(Consumer.class); - RSocketRequester requester = RSocketRequester.builder() + RSocketRequester.builder() .rsocketFactory(factoryConfigurer) .rsocketStrategies(strategiesConfigurer) - .connect(this.transport, MimeTypeUtils.APPLICATION_JSON) + .connect(this.transport) .block(); verify(this.transport).connect(anyInt()); verify(factoryConfigurer).accept(any(RSocketFactory.ClientRSocketFactory.class)); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java index 70fd9d29a7b..cd800d67428 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java @@ -32,7 +32,6 @@ import io.rsocket.RSocket; import io.rsocket.RSocketFactory; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.plugins.RSocketInterceptor; -import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import org.junit.After; @@ -60,7 +59,6 @@ import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.Payload; import org.springframework.stereotype.Controller; -import org.springframework.util.MimeTypeUtils; import org.springframework.util.ObjectUtils; import static org.junit.Assert.*; @@ -78,8 +76,6 @@ public class RSocketBufferLeakTests { private static CloseableChannel server; - private static RSocket client; - private static RSocketRequester requester; @@ -96,21 +92,19 @@ public class RSocketBufferLeakTests { .start() .block(); - client = RSocketFactory.connect() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .addClientPlugin(payloadInterceptor) // intercept outgoing requests - .dataMimeType(MimeTypeUtils.TEXT_PLAIN_VALUE) - .transport(TcpClientTransport.create("localhost", 7000)) - .start() + requester = RSocketRequester.builder() + .rsocketFactory(factory -> { + factory.frameDecoder(PayloadDecoder.ZERO_COPY); + factory.addClientPlugin(payloadInterceptor); // intercept outgoing requests + }) + .rsocketStrategies(context.getBean(RSocketStrategies.class)) + .connectTcp("localhost", 7000) .block(); - - requester = RSocketRequester.create( - client, MimeTypeUtils.TEXT_PLAIN, context.getBean(RSocketStrategies.class)); } @AfterClass public static void tearDownOnce() { - client.dispose(); + requester.rsocket().dispose(); server.dispose(); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java index 3639622f9dd..dea21477d97 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java @@ -40,7 +40,6 @@ import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.stereotype.Controller; -import org.springframework.util.MimeTypeUtils; import static org.junit.Assert.*; @@ -75,11 +74,8 @@ public class RSocketClientToServerIntegrationTests { requester = RSocketRequester.builder() .rsocketFactory(factory -> factory.frameDecoder(PayloadDecoder.ZERO_COPY)) - .rsocketStrategies(strategies -> strategies - .decoder(StringDecoder.allMimeTypes()) - .encoder(CharSequenceEncoder.allMimeTypes()) - .dataBufferFactory(new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT))) - .connectTcp("localhost", 7000, MimeTypeUtils.TEXT_PLAIN) + .rsocketStrategies(context.getBean(RSocketStrategies.class)) + .connectTcp("localhost", 7000) .block(); }