Browse Source

Add header support to WebSocketConnectionManager

Issue: SPR-10796
pull/331/head
Rossen Stoyanchev 13 years ago
parent
commit
172a0b9f5d
  1. 111
      spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java
  2. 35
      spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java
  3. 61
      spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java
  4. 27
      spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java
  5. 26
      spring-websocket/src/test/java/org/springframework/web/socket/client/AbstractWebSocketClientTests.java
  6. 9
      spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java
  7. 133
      spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java

111
spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java

@ -0,0 +1,111 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.client;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.util.UriComponentsBuilder;
/**
* Abstract base class for {@link WebSocketClient} implementations.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractWebSocketClient implements WebSocketClient {
protected final Log logger = LogFactory.getLog(getClass());
private static final Set<String> disallowedHeaders = new HashSet<String>();
static {
disallowedHeaders.add("cache-control");
disallowedHeaders.add("cookie");
disallowedHeaders.add("connection");
disallowedHeaders.add("host");
disallowedHeaders.add("sec-websocket-extensions");
disallowedHeaders.add("sec-websocket-key");
disallowedHeaders.add("sec-websocket-protocol");
disallowedHeaders.add("sec-websocket-version");
disallowedHeaders.add("pragma");
disallowedHeaders.add("upgrade");
}
@Override
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate,
Object... uriVars) throws WebSocketConnectFailureException {
Assert.notNull(uriTemplate, "uriTemplate must not be null");
URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVars).encode().toUri();
return doHandshake(webSocketHandler, null, uri);
}
@Override
public final WebSocketSession doHandshake(WebSocketHandler webSocketHandler,
HttpHeaders headers, URI uri) throws WebSocketConnectFailureException {
Assert.notNull(webSocketHandler, "webSocketHandler must not be null");
Assert.notNull(uri, "uri must not be null");
if (logger.isDebugEnabled()) {
logger.debug("Connecting to " + uri);
}
HttpHeaders headersToUse = new HttpHeaders();
if (headers != null) {
for (String header : headers.keySet()) {
if (!disallowedHeaders.contains(header.toLowerCase())) {
headersToUse.put(header, headers.get(header));
}
}
}
List<String> subProtocols = new ArrayList<String>();
if ((headers != null) && (headers.getSecWebSocketProtocol() != null)) {
subProtocols.addAll(headers.getSecWebSocketProtocol());
}
return doHandshakeInternal(webSocketHandler, headersToUse, uri, subProtocols);
}
/**
*
*
* @param webSocketHandler the client-side handler for WebSocket messages
* @param headers HTTP headers to use for the handshake, with unwanted (forbidden)
* headers filtered out, never {@code null}
* @param uri the target URI for the handshake, never {@code null}
* @param subProtocols requested sub-protocols, or an empty list
* @return the established WebSocket session
* @throws WebSocketConnectFailureException
*/
protected abstract WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler,
HttpHeaders headers, URI uri, List<String> subProtocols) throws WebSocketConnectFailureException;
}

35
spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java

@ -45,6 +45,8 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
private final List<String> protocols = new ArrayList<String>(); private final List<String> protocols = new ArrayList<String>();
private HttpHeaders headers;
private final boolean syncClientLifecycle; private final boolean syncClientLifecycle;
@ -67,17 +69,41 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
return new LoggingWebSocketHandlerDecorator(handler); return new LoggingWebSocketHandlerDecorator(handler);
} }
public void setSupportedProtocols(List<String> protocols) { /**
* Set the sub-protocols to use. If configured, specified sub-protocols will be
* requested in the handshake through the {@code Sec-WebSocket-Protocol} header. The
* resulting WebSocket session will contain the protocol accepted by the server, if
* any.
*/
public void setSubProtocols(List<String> protocols) {
this.protocols.clear(); this.protocols.clear();
if (!CollectionUtils.isEmpty(protocols)) { if (!CollectionUtils.isEmpty(protocols)) {
this.protocols.addAll(protocols); this.protocols.addAll(protocols);
} }
} }
public List<String> getSupportedProtocols() { /**
* Return the configured sub-protocols to use.
*/
public List<String> getSubProtocols() {
return this.protocols; return this.protocols;
} }
/**
* Provide default headers to add to the WebSocket handshake request.
*/
public void setHeaders(HttpHeaders headers) {
this.headers = headers;
}
/**
* Return the default headers for the WebSocket handshake request.
*/
public HttpHeaders getHeaders() {
return this.headers;
}
@Override @Override
public void startInternal() { public void startInternal() {
if (this.syncClientLifecycle) { if (this.syncClientLifecycle) {
@ -96,8 +122,13 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
@Override @Override
protected void openConnection() throws Exception { protected void openConnection() throws Exception {
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
if (this.headers != null) {
headers.putAll(this.headers);
}
headers.setSecWebSocketProtocol(this.protocols); headers.setSecWebSocketProtocol(this.protocols);
this.webSocketSession = this.client.doHandshake(this.webSocketHandler, headers, getUri()); this.webSocketSession = this.client.doHandshake(this.webSocketHandler, headers, getUri());
} }

61
spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java

@ -17,11 +17,8 @@
package org.springframework.web.socket.client.endpoint; package org.springframework.web.socket.client.endpoint;
import java.net.URI; import java.net.URI;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import javax.websocket.ClientEndpointConfig; import javax.websocket.ClientEndpointConfig;
import javax.websocket.ClientEndpointConfig.Configurator; import javax.websocket.ClientEndpointConfig.Configurator;
@ -30,18 +27,14 @@ import javax.websocket.Endpoint;
import javax.websocket.HandshakeResponse; import javax.websocket.HandshakeResponse;
import javax.websocket.WebSocketContainer; import javax.websocket.WebSocketContainer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.StandardEndpointAdapter; import org.springframework.web.socket.adapter.StandardEndpointAdapter;
import org.springframework.web.socket.adapter.StandardWebSocketSessionAdapter; import org.springframework.web.socket.adapter.StandardWebSocketSessionAdapter;
import org.springframework.web.socket.client.WebSocketClient; import org.springframework.web.socket.client.AbstractWebSocketClient;
import org.springframework.web.socket.client.WebSocketConnectFailureException; import org.springframework.web.socket.client.WebSocketConnectFailureException;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
/** /**
* Initiates WebSocket requests to a WebSocket server programatically through the standard * Initiates WebSocket requests to a WebSocket server programatically through the standard
@ -50,9 +43,7 @@ import org.springframework.web.util.UriComponentsBuilder;
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 4.0 * @since 4.0
*/ */
public class StandardWebSocketClient implements WebSocketClient { public class StandardWebSocketClient extends AbstractWebSocketClient {
private static final Log logger = LogFactory.getLog(StandardWebSocketClient.class);
private final WebSocketContainer webSocketContainer; private final WebSocketContainer webSocketContainer;
@ -68,26 +59,8 @@ public class StandardWebSocketClient implements WebSocketClient {
@Override @Override
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVariables) protected WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler,
throws WebSocketConnectFailureException { HttpHeaders httpHeaders, URI uri, List<String> protocols) throws WebSocketConnectFailureException {
Assert.notNull(uriTemplate, "uriTemplate must not be null");
UriComponents uriComponents = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVariables).encode();
return doHandshake(webSocketHandler, null, uriComponents.toUri());
}
@Override
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, HttpHeaders httpHeaders, URI uri)
throws WebSocketConnectFailureException {
Assert.notNull(webSocketHandler, "webSocketHandler must not be null");
Assert.notNull(uri, "uri must not be null");
httpHeaders = (httpHeaders != null) ? httpHeaders : new HttpHeaders();
if (logger.isDebugEnabled()) {
logger.debug("Connecting to " + uri);
}
StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter(); StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter();
session.setUri(uri); session.setUri(uri);
@ -95,11 +68,7 @@ public class StandardWebSocketClient implements WebSocketClient {
ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create(); ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create();
configBuidler.configurator(new StandardWebSocketClientConfigurator(httpHeaders)); configBuidler.configurator(new StandardWebSocketClientConfigurator(httpHeaders));
configBuidler.preferredSubprotocols(protocols);
List<String> protocols = httpHeaders.getSecWebSocketProtocol();
if (!protocols.isEmpty()) {
configBuidler.preferredSubprotocols(protocols);
}
try { try {
// TODO: do not block // TODO: do not block
@ -114,11 +83,7 @@ public class StandardWebSocketClient implements WebSocketClient {
} }
private static class StandardWebSocketClientConfigurator extends Configurator { private class StandardWebSocketClientConfigurator extends Configurator {
private static final Set<String> EXCLUDED_HEADERS = new HashSet<String>(
Arrays.asList("Sec-WebSocket-Accept", "Sec-WebSocket-Extensions", "Sec-WebSocket-Key",
"Sec-WebSocket-Protocol", "Sec-WebSocket-Version"));
private final HttpHeaders httpHeaders; private final HttpHeaders httpHeaders;
@ -129,23 +94,15 @@ public class StandardWebSocketClient implements WebSocketClient {
@Override @Override
public void beforeRequest(Map<String, List<String>> headers) { public void beforeRequest(Map<String, List<String>> headers) {
for (String headerName : this.httpHeaders.keySet()) { headers.putAll(this.httpHeaders);
if (!EXCLUDED_HEADERS.contains(headerName)) {
List<String> value = this.httpHeaders.get(headerName);
if (logger.isTraceEnabled()) {
logger.trace("Adding header [" + headerName + "=" + value + "]");
}
headers.put(headerName, value);
}
}
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Handshake request headers: " + headers); logger.debug("Handshake request headers: " + headers);
} }
} }
@Override @Override
public void afterResponse(HandshakeResponse handshakeResponse) { public void afterResponse(HandshakeResponse response) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Handshake response headers: " + handshakeResponse.getHeaders()); logger.debug("Handshake response headers: " + response.getHeaders());
} }
} }
} }

27
spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java

@ -17,17 +17,16 @@
package org.springframework.web.socket.client.jetty; package org.springframework.web.socket.client.jetty;
import java.net.URI; import java.net.URI;
import java.util.List;
import org.apache.commons.logging.Log; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.SmartLifecycle; import org.springframework.context.SmartLifecycle;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter;
import org.springframework.web.socket.client.WebSocketClient; import org.springframework.web.socket.client.AbstractWebSocketClient;
import org.springframework.web.socket.client.WebSocketConnectFailureException; import org.springframework.web.socket.client.WebSocketConnectFailureException;
import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
@ -39,9 +38,7 @@ import org.springframework.web.util.UriComponentsBuilder;
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 4.0 * @since 4.0
*/ */
public class JettyWebSocketClient implements WebSocketClient, SmartLifecycle { public class JettyWebSocketClient extends AbstractWebSocketClient implements SmartLifecycle {
private static final Log logger = LogFactory.getLog(JettyWebSocketClient.class);
private final org.eclipse.jetty.websocket.client.WebSocketClient client; private final org.eclipse.jetty.websocket.client.WebSocketClient client;
@ -133,18 +130,16 @@ public class JettyWebSocketClient implements WebSocketClient, SmartLifecycle {
} }
@Override @Override
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri) public WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, HttpHeaders headers,
throws WebSocketConnectFailureException { URI uri, List<String> protocols) throws WebSocketConnectFailureException {
Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); ClientUpgradeRequest request = new ClientUpgradeRequest();
Assert.notNull(uri, "uri must not be null"); request.setSubProtocols(protocols);
if (logger.isDebugEnabled()) { for (String header : headers.keySet()) {
logger.debug("Connecting to " + uri); request.setHeader(header, headers.get(header));
} }
// TODO: populate headers
JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter(); JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter();
session.setUri(uri); session.setUri(uri);
session.setRemoteHostName(uri.getHost()); session.setRemoteHostName(uri.getHost());
@ -153,7 +148,7 @@ public class JettyWebSocketClient implements WebSocketClient, SmartLifecycle {
try { try {
// TODO: do not block // TODO: do not block
this.client.connect(listener, uri).get(); this.client.connect(listener, uri, request).get();
return session; return session;
} }
catch (Exception e) { catch (Exception e) {

26
spring-websocket/src/test/java/org/springframework/web/socket/client/AbstractWebSocketClientTests.java

@ -0,0 +1,26 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.client;
/**
* Test fixture for {@link AbstractWebSocketClient}.
* @author Rossen Stoyanchev
*/
public class AbstractWebSocketClientTests {
}

9
spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java

@ -44,14 +44,12 @@ public class WebSocketConnectionManagerTests {
public void openConnection() throws Exception { public void openConnection() throws Exception {
List<String> subprotocols = Arrays.asList("abc"); List<String> subprotocols = Arrays.asList("abc");
HttpHeaders headers = new HttpHeaders();
headers.setSecWebSocketProtocol(subprotocols);
WebSocketClient client = mock(WebSocketClient.class); WebSocketClient client = mock(WebSocketClient.class);
WebSocketHandler handler = new WebSocketHandlerAdapter(); WebSocketHandler handler = new WebSocketHandlerAdapter();
WebSocketConnectionManager manager = new WebSocketConnectionManager(client, handler , "/path/{id}", "123"); WebSocketConnectionManager manager = new WebSocketConnectionManager(client, handler , "/path/{id}", "123");
manager.setSupportedProtocols(subprotocols); manager.setSubProtocols(subprotocols);
manager.openConnection(); manager.openConnection();
ArgumentCaptor<WebSocketHandlerDecorator> captor = ArgumentCaptor.forClass(WebSocketHandlerDecorator.class); ArgumentCaptor<WebSocketHandlerDecorator> captor = ArgumentCaptor.forClass(WebSocketHandlerDecorator.class);
@ -60,7 +58,10 @@ public class WebSocketConnectionManagerTests {
verify(client).doHandshake(captor.capture(), headersCaptor.capture(), uriCaptor.capture()); verify(client).doHandshake(captor.capture(), headersCaptor.capture(), uriCaptor.capture());
assertEquals(headers, headersCaptor.getValue()); HttpHeaders expectedHeaders = new HttpHeaders();
expectedHeaders.setSecWebSocketProtocol(subprotocols);
assertEquals(expectedHeaders, headersCaptor.getValue());
assertEquals(new URI("/path/123"), uriCaptor.getValue()); assertEquals(new URI("/path/123"), uriCaptor.getValue());
WebSocketHandlerDecorator loggingHandler = captor.getValue(); WebSocketHandlerDecorator loggingHandler = captor.getValue();

133
spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java

@ -0,0 +1,133 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.client.jetty;
import java.net.URI;
import java.util.Arrays;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.util.CollectionUtils;
import org.springframework.util.SocketUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter;
import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter;
import static org.junit.Assert.*;
/**
* Tests for {@link JettyWebSocketClient}.
* @author Rossen Stoyanchev
*/
public class JettyWebSocketClientTests {
private JettyWebSocketClient client;
private TestJettyWebSocketServer server;
private String wsUrl;
private WebSocketSession wsSession;
@Before
public void setup() throws Exception {
int port = SocketUtils.findAvailableTcpPort();
this.server = new TestJettyWebSocketServer(port, new TextWebSocketHandlerAdapter());
this.server.start();
this.client = new JettyWebSocketClient();
this.client.start();
this.wsUrl = "ws://localhost:" + port + "/test";
}
@After
public void teardown() throws Exception {
this.wsSession.close();
this.client.stop();
this.server.stop();
}
@Test
public void doHandshake() throws Exception {
HttpHeaders headers = new HttpHeaders();
headers.setSecWebSocketProtocol(Arrays.asList("echo"));
this.wsSession = this.client.doHandshake(new TextWebSocketHandlerAdapter(), headers, new URI(this.wsUrl));
assertEquals(this.wsUrl, this.wsSession.getUri().toString());
assertEquals("echo", this.wsSession.getAcceptedProtocol());
}
private static class TestJettyWebSocketServer {
private final Server server;
public TestJettyWebSocketServer(int port, final WebSocketHandler webSocketHandler) {
this.server = new Server();
ServerConnector connector = new ServerConnector(this.server);
connector.setPort(port);
this.server.addConnector(connector);
this.server.setHandler(new org.eclipse.jetty.websocket.server.WebSocketHandler() {
@Override
public void configure(WebSocketServletFactory factory) {
factory.setCreator(new WebSocketCreator() {
@Override
public Object createWebSocket(UpgradeRequest req, UpgradeResponse resp) {
if (!CollectionUtils.isEmpty(req.getSubProtocols())) {
resp.setAcceptedSubProtocol(req.getSubProtocols().get(0));
}
JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter();
return new JettyWebSocketListenerAdapter(webSocketHandler, session);
}
});
}
});
}
public void start() throws Exception {
this.server.start();
}
public void stop() throws Exception {
this.server.stop();
}
}
}
Loading…
Cancel
Save