diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/TextMessage.java b/spring-websocket/src/main/java/org/springframework/web/socket/TextMessage.java index 370f433953d..5c078988ac0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/TextMessage.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/TextMessage.java @@ -16,8 +16,6 @@ package org.springframework.web.socket; -import java.io.Reader; -import java.io.StringReader; /** * A {@link WebSocketMessage} that contains a textual {@link String} payload. @@ -46,13 +44,6 @@ public final class TextMessage extends WebSocketMessage { super(payload.toString(), isLast); } - /** - * Returns access to the message payload as a {@link Reader}. - */ - public Reader getReader() { - return new StringReader(getPayload()); - } - @Override protected int getPayloadSize() { return getPayload().length(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java index ab5dc2d60b4..599e94f91cf 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java @@ -24,7 +24,6 @@ import org.springframework.http.HttpHeaders; import org.springframework.util.CollectionUtils; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; -import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator; import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator; /** @@ -61,11 +60,9 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { /** * Decorate the WebSocketHandler provided to the class constructor. *

- * By default {@link ExceptionWebSocketHandlerDecorator} and - * {@link LoggingWebSocketHandlerDecorator} are applied are added. + * By default {@link LoggingWebSocketHandlerDecorator} is added. */ protected WebSocketHandler decorateWebSocketHandler(WebSocketHandler handler) { - handler = new ExceptionWebSocketHandlerDecorator(handler); return new LoggingWebSocketHandlerDecorator(handler); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java index d81bcff5c22..8dabcface61 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java @@ -33,6 +33,7 @@ import javax.websocket.WebSocketContainer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.StandardEndpointAdapter; @@ -53,49 +54,57 @@ public class StandardWebSocketClient implements WebSocketClient { private static final Log logger = LogFactory.getLog(StandardWebSocketClient.class); - private WebSocketContainer webSocketContainer; + private final WebSocketContainer webSocketContainer; - public WebSocketContainer getWebSocketContainer() { - if (this.webSocketContainer == null) { - this.webSocketContainer = ContainerProvider.getWebSocketContainer(); - } - return this.webSocketContainer; + public StandardWebSocketClient() { + this.webSocketContainer = ContainerProvider.getWebSocketContainer(); } - public void setWebSocketContainer(WebSocketContainer container) { - this.webSocketContainer = container; + public StandardWebSocketClient(WebSocketContainer webSocketContainer) { + Assert.notNull(webSocketContainer, "webSocketContainer is required"); + this.webSocketContainer = webSocketContainer; } @Override public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVariables) throws WebSocketConnectFailureException { + Assert.notNull(uriTemplate, "uriTemplate is required"); UriComponents uriComponents = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVariables).encode(); - return doHandshake(webSocketHandler, null, uriComponents); + return doHandshake(webSocketHandler, null, uriComponents.toUri()); } @Override public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, HttpHeaders httpHeaders, URI uri) throws WebSocketConnectFailureException { + Assert.notNull(webSocketHandler, "webSocketHandler is required"); + Assert.notNull(uri, "uri is required"); + + httpHeaders = (httpHeaders != null) ? httpHeaders : new HttpHeaders(); + + if (logger.isDebugEnabled()) { + logger.debug("Connecting to " + uri); + } + StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter(); session.setUri(uri); session.setRemoteHostName(uri.getHost()); - Endpoint endpoint = new StandardEndpointAdapter(webSocketHandler, session); ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create(); - if (httpHeaders != null) { - List protocols = httpHeaders.getSecWebSocketProtocol(); - if (!protocols.isEmpty()) { - configBuidler.preferredSubprotocols(protocols); - } - configBuidler.configurator(new StandardWebSocketClientConfigurator(httpHeaders)); + configBuidler.configurator(new StandardWebSocketClientConfigurator(httpHeaders)); + + List protocols = httpHeaders.getSecWebSocketProtocol(); + if (!protocols.isEmpty()) { + configBuidler.preferredSubprotocols(protocols); } try { // TODO: do not block + Endpoint endpoint = new StandardEndpointAdapter(webSocketHandler, session); this.webSocketContainer.connectToServer(endpoint, configBuidler.build(), uri); + return session; } catch (Exception e) { @@ -128,14 +137,14 @@ public class StandardWebSocketClient implements WebSocketClient { headers.put(headerName, value); } } - if (logger.isTraceEnabled()) { - logger.trace("Handshake request headers: " + headers); + if (logger.isDebugEnabled()) { + logger.debug("Handshake request headers: " + headers); } } @Override public void afterResponse(HandshakeResponse handshakeResponse) { - if (logger.isTraceEnabled()) { - logger.trace("Handshake response headers: " + handshakeResponse.getHeaders()); + if (logger.isDebugEnabled()) { + logger.debug("Handshake response headers: " + handshakeResponse.getHeaders()); } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java index d5bb04ddfd3..0203203e8b6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java @@ -22,6 +22,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.context.SmartLifecycle; import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter; @@ -126,13 +127,20 @@ public class JettyWebSocketClient implements WebSocketClient, SmartLifecycle { throws WebSocketConnectFailureException { UriComponents uriComponents = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVariables).encode(); - return doHandshake(webSocketHandler, null, uriComponents); + return doHandshake(webSocketHandler, null, uriComponents.toUri()); } @Override public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri) throws WebSocketConnectFailureException { + Assert.notNull(webSocketHandler, "webSocketHandler is required"); + Assert.notNull(uri, "uri is required"); + + if (logger.isDebugEnabled()) { + logger.debug("Connecting to " + uri); + } + // TODO: populate headers JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java index 118ef0d40ab..0d8ed6dcdd6 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java @@ -27,7 +27,6 @@ import org.springframework.http.HttpHeaders; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; -import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator; import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator; import org.springframework.web.socket.support.WebSocketHandlerDecorator; @@ -69,11 +68,7 @@ public class WebSocketConnectionManagerTests { WebSocketHandlerDecorator loggingHandler = captor.getValue(); assertEquals(LoggingWebSocketHandlerDecorator.class, loggingHandler.getClass()); - WebSocketHandlerDecorator exceptionHandler = (WebSocketHandlerDecorator) loggingHandler.getDelegate(); - assertNotNull(exceptionHandler); - assertEquals(ExceptionWebSocketHandlerDecorator.class, exceptionHandler.getClass()); - - assertSame(handler, exceptionHandler.getDelegate()); + assertSame(handler, loggingHandler.getDelegate()); } @Test diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java index 54187c39912..5e16a560f7c 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java @@ -34,7 +34,6 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.StandardEndpointAdapter; import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; -import org.springframework.web.socket.client.endpoint.StandardWebSocketClient; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -60,18 +59,15 @@ public class StandardWebSocketClientTests { WebSocketHandler handler = new WebSocketHandlerAdapter(); WebSocketContainer webSocketContainer = mock(WebSocketContainer.class); - StandardWebSocketClient client = new StandardWebSocketClient(); - client.setWebSocketContainer(webSocketContainer); + StandardWebSocketClient client = new StandardWebSocketClient(webSocketContainer); WebSocketSession session = client.doHandshake(handler, headers, uri); ArgumentCaptor endpointArg = ArgumentCaptor.forClass(Endpoint.class); ArgumentCaptor configArg = ArgumentCaptor.forClass(ClientEndpointConfig.class); ArgumentCaptor uriArg = ArgumentCaptor.forClass(URI.class); - verify(webSocketContainer).connectToServer(endpointArg.capture(), configArg.capture(), uriArg.capture()); - assertNotNull(endpointArg.getValue()); assertEquals(StandardEndpointAdapter.class, endpointArg.getValue().getClass()); @@ -86,4 +82,5 @@ public class StandardWebSocketClientTests { assertEquals(uri, session.getUri()); assertEquals("example.com", session.getRemoteHostName()); } + }