From f5b4f7d9e89ba21e2fd842f71d8a9af1dcfa0535 Mon Sep 17 00:00:00 2001 From: Juergen Hoeller Date: Thu, 28 Dec 2023 23:42:04 +0100 Subject: [PATCH] Support for SSLContext configuration on StandardWebSocketClient Closes gh-30680 --- .../client/StandardWebSocketClient.java | 32 +++++++++++---- .../standard/StandardWebSocketClient.java | 41 +++++++++++++++---- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java index a57ccc063e3..c9b417b3c25 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/client/StandardWebSocketClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -114,7 +114,7 @@ public class StandardWebSocketClient implements WebSocketClient { return Mono.error(ex); } }) - .subscribeOn(Schedulers.boundedElastic()); // connectToServer is blocking + .subscribeOn(Schedulers.boundedElastic()); // connectToServer is blocking } private StandardWebSocketHandlerAdapter createEndpoint(URI url, WebSocketHandler handler, @@ -130,24 +130,38 @@ public class StandardWebSocketClient implements WebSocketClient { return new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol); } + /** + * Create the {@link StandardWebSocketSession} for the given Jakarta WebSocket Session. + * @see #bufferFactory() + */ protected StandardWebSocketSession createWebSocketSession( Session session, HandshakeInfo info, Sinks.Empty completionSink) { - return new StandardWebSocketSession( - session, info, DefaultDataBufferFactory.sharedInstance, completionSink); + return new StandardWebSocketSession(session, info, bufferFactory(), completionSink); + } + + /** + * Return the {@link DataBufferFactory} to use. + * @see #createWebSocketSession + */ + protected DataBufferFactory bufferFactory() { + return DefaultDataBufferFactory.sharedInstance; } - private ClientEndpointConfig createEndpointConfig(Configurator configurator, List subProtocols) { + /** + * Create the {@link ClientEndpointConfig} for the given configurator. + * Can be overridden to add extensions or an SSL context. + * @param configurator the configurator to apply + * @param subProtocols the preferred sub-protocols + * @since 6.1.3 + */ + protected ClientEndpointConfig createEndpointConfig(Configurator configurator, List subProtocols) { return ClientEndpointConfig.Builder.create() .configurator(configurator) .preferredSubprotocols(subProtocols) .build(); } - protected DataBufferFactory bufferFactory() { - return DefaultDataBufferFactory.sharedInstance; - } - private static final class DefaultConfigurator extends Configurator { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java index 801a2894399..6106a7631fe 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,8 @@ import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; +import javax.net.ssl.SSLContext; + import jakarta.websocket.ClientEndpointConfig; import jakarta.websocket.ClientEndpointConfig.Configurator; import jakarta.websocket.ContainerProvider; @@ -54,6 +56,7 @@ import org.springframework.web.socket.client.AbstractWebSocketClient; * A WebSocketClient based on the standard Jakarta WebSocket API. * * @author Rossen Stoyanchev + * @author Juergen Hoeller * @since 4.0 */ public class StandardWebSocketClient extends AbstractWebSocketClient { @@ -62,6 +65,9 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { private final Map userProperties = new HashMap<>(); + @Nullable + private SSLContext sslContext; + @Nullable private AsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); @@ -100,12 +106,29 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { } /** - * The configured user properties. + * Return the configured user properties. */ public Map getUserProperties() { return this.userProperties; } + /** + * Set the {@link SSLContext} to use for {@link ClientEndpointConfig#getSSLContext()}. + * @since 6.1.3 + */ + public void setSslContext(@Nullable SSLContext sslContext) { + this.sslContext = sslContext; + } + + /** + * Return the {@link SSLContext} to use. + * @since 6.1.3 + */ + @Nullable + public SSLContext getSslContext() { + return this.sslContext; + } + /** * Set an {@link AsyncTaskExecutor} to use when opening connections. *

If this property is set to {@code null}, calls to any of the @@ -134,17 +157,19 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port); InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port); - final StandardWebSocketSession session = new StandardWebSocketSession(headers, + StandardWebSocketSession session = new StandardWebSocketSession(headers, attributes, localAddress, remoteAddress); - final ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create() + ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create() .configurator(new StandardWebSocketClientConfigurator(headers)) .preferredSubprotocols(protocols) - .extensions(adaptExtensions(extensions)).build(); + .extensions(adaptExtensions(extensions)) + .sslContext(getSslContext()) + .build(); endpointConfig.getUserProperties().putAll(getUserProperties()); - final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session); + Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session); Callable connectTask = () -> { this.webSocketContainer.connectToServer(endpoint, endpointConfig, uri); @@ -167,7 +192,7 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { return result; } - private InetAddress getLocalHost() { + private static InetAddress getLocalHost() { try { return InetAddress.getLocalHost(); } @@ -176,7 +201,7 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { } } - private int getPort(URI uri) { + private static int getPort(URI uri) { if (uri.getPort() == -1) { String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH); return ("wss".equals(scheme) ? 443 : 80);