Browse Source

Polish WebSocketSession

Update methods available on WebSocketSession interface.
Introduce DelegatingWebSocketSession interface.
pull/333/merge
Rossen Stoyanchev 13 years ago
parent
commit
01feae0ad5
  1. 3
      spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java
  2. 13
      spring-messaging/src/test/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandlerTests.java
  3. 9
      spring-web/src/main/java/org/springframework/http/server/ServerHttpRequest.java
  4. 9
      spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java
  5. 19
      spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java
  6. 33
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java
  7. 27
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java
  8. 13
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapter.java
  9. 121
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java
  10. 144
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java
  11. 12
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapter.java
  12. 123
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java
  13. 145
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java
  14. 3
      spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java
  15. 39
      spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java
  16. 51
      spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java
  17. 15
      spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java
  18. 6
      spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java
  19. 22
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java
  20. 17
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java
  21. 43
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java
  22. 4
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java
  23. 3
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpReceivingTransportHandler.java
  24. 6
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java
  25. 5
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java
  26. 2
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandler.java
  27. 57
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java
  28. 43
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java
  29. 71
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java
  30. 12
      spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java
  31. 12
      spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java
  32. 98
      spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java
  33. 8
      spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java
  34. 77
      spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java
  35. 11
      spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java
  36. 45
      spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java

3
spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandler.java

@ -30,6 +30,7 @@ import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException; import org.springframework.messaging.MessagingException;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketMessage;
@ -133,7 +134,7 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHan
protected final SubProtocolHandler getProtocolHandler(WebSocketSession session) { protected final SubProtocolHandler getProtocolHandler(WebSocketSession session) {
SubProtocolHandler handler; SubProtocolHandler handler;
String protocol = session.getAcceptedProtocol(); String protocol = session.getAcceptedProtocol();
if (protocol != null) { if (!StringUtils.isEmpty(protocol)) {
handler = this.protocolHandlers.get(protocol); handler = this.protocolHandlers.get(protocol);
Assert.state(handler != null, Assert.state(handler != null,
"No handler for sub-protocol '" + protocol + "', handlers=" + this.protocolHandlers); "No handler for sub-protocol '" + protocol + "', handlers=" + this.protocolHandlers);

13
spring-messaging/src/test/java/org/springframework/messaging/handler/websocket/SubProtocolWebSocketHandlerTests.java

@ -91,7 +91,18 @@ public class SubProtocolWebSocketHandlerTests {
} }
@Test @Test
public void noSubProtocol() throws Exception { public void nullSubProtocol() throws Exception {
this.webSocketHandler.setDefaultProtocolHandler(defaultHandler);
this.webSocketHandler.afterConnectionEstablished(session);
verify(this.defaultHandler).afterSessionStarted(session, this.channel);
verify(this.stompHandler, times(0)).afterSessionStarted(session, this.channel);
verify(this.mqttHandler, times(0)).afterSessionStarted(session, this.channel);
}
@Test
public void emptySubProtocol() throws Exception {
this.session.setAcceptedProtocol("");
this.webSocketHandler.setDefaultProtocolHandler(defaultHandler); this.webSocketHandler.setDefaultProtocolHandler(defaultHandler);
this.webSocketHandler.afterConnectionEstablished(session); this.webSocketHandler.afterConnectionEstablished(session);

9
spring-web/src/main/java/org/springframework/http/server/ServerHttpRequest.java

@ -16,6 +16,7 @@
package org.springframework.http.server; package org.springframework.http.server;
import java.net.InetSocketAddress;
import java.security.Principal; import java.security.Principal;
import java.util.Map; import java.util.Map;
@ -51,14 +52,14 @@ public interface ServerHttpRequest extends HttpRequest, HttpInputMessage {
Principal getPrincipal(); Principal getPrincipal();
/** /**
* Return the host name of the endpoint on the other end. * Return the address on which the request was received.
*/ */
String getRemoteHostName(); InetSocketAddress getLocalAddress();
/** /**
* Return the IP address of the endpoint on the other end. * Return the address of the remote client.
*/ */
String getRemoteAddress(); InetSocketAddress getRemoteAddress();
/** /**
* Return a control that allows putting the request in asynchronous mode so the * Return a control that allows putting the request in asynchronous mode so the

9
spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java

@ -22,6 +22,7 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStreamWriter; import java.io.OutputStreamWriter;
import java.io.Writer; import java.io.Writer;
import java.net.InetSocketAddress;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.net.URLEncoder; import java.net.URLEncoder;
@ -147,13 +148,13 @@ public class ServletServerHttpRequest implements ServerHttpRequest {
} }
@Override @Override
public String getRemoteHostName() { public InetSocketAddress getLocalAddress() {
return this.servletRequest.getRemoteHost(); return new InetSocketAddress(this.servletRequest.getLocalName(), this.servletRequest.getLocalPort());
} }
@Override @Override
public String getRemoteAddress() { public InetSocketAddress getRemoteAddress() {
return this.servletRequest.getRemoteAddr(); return new InetSocketAddress(this.servletRequest.getRemoteHost(), this.servletRequest.getRemotePort());
} }
@Override @Override

19
spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java

@ -17,9 +17,12 @@
package org.springframework.web.socket; package org.springframework.web.socket;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI; import java.net.URI;
import java.security.Principal; import java.security.Principal;
import org.springframework.http.HttpHeaders;
/** /**
* A WebSocket session abstraction. Allows sending messages over a WebSocket connection * A WebSocket session abstraction. Allows sending messages over a WebSocket connection
* and closing it. * and closing it.
@ -29,6 +32,7 @@ import java.security.Principal;
*/ */
public interface WebSocketSession { public interface WebSocketSession {
/** /**
* Return a unique session identifier. * Return a unique session identifier.
*/ */
@ -40,9 +44,9 @@ public interface WebSocketSession {
URI getUri(); URI getUri();
/** /**
* Return whether the underlying socket is using a secure transport. * Return the headers used in the handshake request.
*/ */
boolean isSecure(); HttpHeaders getHandshakeHeaders();
/** /**
* Return a {@link java.security.Principal} instance containing the name of the * Return a {@link java.security.Principal} instance containing the name of the
@ -52,17 +56,18 @@ public interface WebSocketSession {
Principal getPrincipal(); Principal getPrincipal();
/** /**
* Return the host name of the endpoint on the other end. * Return the address on which the request was received.
*/ */
String getRemoteHostName(); InetSocketAddress getLocalAddress();
/** /**
* Return the IP address of the endpoint on the other end. * Return the address of the remote client.
*/ */
String getRemoteAddress(); InetSocketAddress getRemoteAddress();
/** /**
* Return the negotiated sub-protocol or {@code null} if none was specified. * Return the negotiated sub-protocol or {@code null} if none was specified or
* negotiated successfully.
*/ */
String getAcceptedProtocol(); String getAcceptedProtocol();

33
spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssionAdapter.java → spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSesssion.java

@ -32,19 +32,41 @@ import org.springframework.web.socket.WebSocketSession;
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 4.0 * @since 4.0
*/ */
public abstract class AbstractWebSocketSesssionAdapter<T> implements ConfigurableWebSocketSession { public abstract class AbstractWebSocketSesssion<T> implements DelegatingWebSocketSession<T> {
protected final Log logger = LogFactory.getLog(getClass()); protected final Log logger = LogFactory.getLog(getClass());
private T delegateSession;
public abstract void initSession(T session);
/**
* @return the WebSocket session to delegate to
*/
public T getDelegateSession() {
return this.delegateSession;
}
@Override
public void afterSessionInitialized(T session) {
Assert.notNull(session, "session must not be null");
this.delegateSession = session;
}
protected final void checkDelegateSessionInitialized() {
Assert.state(this.delegateSession != null, "WebSocket session is not yet initialized");
}
@Override @Override
public final void sendMessage(WebSocketMessage message) throws IOException { public final void sendMessage(WebSocketMessage message) throws IOException {
checkDelegateSessionInitialized();
Assert.isTrue(isOpen(), "Cannot send message after connection closed.");
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Sending " + message + ", " + this); logger.trace("Sending " + message + ", " + this);
} }
Assert.isTrue(isOpen(), "Cannot send message after connection closed.");
if (message instanceof TextMessage) { if (message instanceof TextMessage) {
sendTextMessage((TextMessage) message); sendTextMessage((TextMessage) message);
} }
@ -60,13 +82,15 @@ public abstract class AbstractWebSocketSesssionAdapter<T> implements Configurabl
protected abstract void sendBinaryMessage(BinaryMessage message) throws IOException ; protected abstract void sendBinaryMessage(BinaryMessage message) throws IOException ;
@Override @Override
public void close() throws IOException { public final void close() throws IOException {
close(CloseStatus.NORMAL); close(CloseStatus.NORMAL);
} }
@Override @Override
public final void close(CloseStatus status) throws IOException { public final void close(CloseStatus status) throws IOException {
checkDelegateSessionInitialized();
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Closing " + this); logger.debug("Closing " + this);
} }
@ -75,6 +99,7 @@ public abstract class AbstractWebSocketSesssionAdapter<T> implements Configurabl
protected abstract void closeInternal(CloseStatus status) throws IOException; protected abstract void closeInternal(CloseStatus status) throws IOException;
@Override @Override
public String toString() { public String toString() {
return "WebSocket session id=" + getId(); return "WebSocket session id=" + getId();

27
spring-websocket/src/main/java/org/springframework/web/socket/adapter/ConfigurableWebSocketSession.java → spring-websocket/src/main/java/org/springframework/web/socket/adapter/DelegatingWebSocketSession.java

@ -16,34 +16,25 @@
package org.springframework.web.socket.adapter; package org.springframework.web.socket.adapter;
import java.net.URI;
import java.security.Principal;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
/** /**
* A WebSocketSession with configurable properties. * A contract for {@link WebSocketSession} implementations that delegate to another
* WebSocket session (e.g. a native session).
*
* @param T the type of the delegate WebSocket session
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 4.0 * @since 4.0
*/ */
public interface ConfigurableWebSocketSession extends WebSocketSession { public interface DelegatingWebSocketSession<T> extends WebSocketSession {
void setUri(URI uri);
void setRemoteHostName(String name);
void setRemoteAddress(String address);
void setPrincipal(Principal principal);
/** /**
* Set the protocol accepted as part of the WebSocket handshake. This property can be * Invoked when the delegate WebSocket session has been initialized.
* used when the WebSocket handshake is performed through
* {@link DefaultHandshakeHandler} rather than the underlying WebSocket runtime, or
* when there is no WebSocket handshake (e.g. SockJS HTTP fallback options)
*/ */
void setAcceptedProtocol(String protocol); void afterSessionInitialized(T session);
} }

13
spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketListenerAdapter.java → spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapter.java

@ -28,21 +28,22 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator; import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator;
/** /**
* Adapts {@link WebSocketHandler} to the Jetty 9 {@link WebSocketListener}. * Adapts {@link WebSocketHandler} to the Jetty 9 WebSocket API.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Rossen Stoyanchev
* @since 4.0 * @since 4.0
*/ */
public class JettyWebSocketListenerAdapter implements WebSocketListener { public class JettyWebSocketHandlerAdapter implements WebSocketListener {
private static final Log logger = LogFactory.getLog(JettyWebSocketListenerAdapter.class); private static final Log logger = LogFactory.getLog(JettyWebSocketHandlerAdapter.class);
private final WebSocketHandler webSocketHandler; private final WebSocketHandler webSocketHandler;
private final JettyWebSocketSessionAdapter wsSession; private final JettyWebSocketSession wsSession;
public JettyWebSocketListenerAdapter(WebSocketHandler webSocketHandler, JettyWebSocketSessionAdapter wsSession) { public JettyWebSocketHandlerAdapter(WebSocketHandler webSocketHandler, JettyWebSocketSession wsSession) {
Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); Assert.notNull(webSocketHandler, "webSocketHandler must not be null");
Assert.notNull(wsSession, "wsSession must not be null"); Assert.notNull(wsSession, "wsSession must not be null");
this.webSocketHandler = webSocketHandler; this.webSocketHandler = webSocketHandler;
@ -52,8 +53,8 @@ public class JettyWebSocketListenerAdapter implements WebSocketListener {
@Override @Override
public void onWebSocketConnect(Session session) { public void onWebSocketConnect(Session session) {
this.wsSession.initSession(session);
try { try {
this.wsSession.afterSessionInitialized(session);
this.webSocketHandler.afterConnectionEstablished(this.wsSession); this.webSocketHandler.afterConnectionEstablished(this.wsSession);
} }
catch (Throwable t) { catch (Throwable t) {

121
spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java

@ -0,0 +1,121 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import org.springframework.http.HttpHeaders;
import org.springframework.util.ObjectUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* A {@link WebSocketSession} for use with the Jetty 9 WebSocket API.
*
* @author Phillip Webb
* @author Rossen Stoyanchev
* @since 4.0
*/
public class JettyWebSocketSession extends AbstractWebSocketSesssion<org.eclipse.jetty.websocket.api.Session> {
private HttpHeaders headers;
private final Principal principal;
/**
* Class constructor.
*
* @param principal the user associated with the session, or {@code null}
*/
public JettyWebSocketSession(Principal principal) {
this.principal = principal;
}
@Override
public String getId() {
checkDelegateSessionInitialized();
return ObjectUtils.getIdentityHexString(getDelegateSession());
}
@Override
public URI getUri() {
checkDelegateSessionInitialized();
return getDelegateSession().getUpgradeRequest().getRequestURI();
}
@Override
public HttpHeaders getHandshakeHeaders() {
checkDelegateSessionInitialized();
if (this.headers == null) {
this.headers = new HttpHeaders();
this.headers.putAll(getDelegateSession().getUpgradeRequest().getHeaders());
this.headers = HttpHeaders.readOnlyHttpHeaders(headers);
}
return this.headers;
}
@Override
public Principal getPrincipal() {
return this.principal;
}
@Override
public InetSocketAddress getLocalAddress() {
checkDelegateSessionInitialized();
return getDelegateSession().getLocalAddress();
}
@Override
public InetSocketAddress getRemoteAddress() {
checkDelegateSessionInitialized();
return getDelegateSession().getRemoteAddress();
}
@Override
public String getAcceptedProtocol() {
checkDelegateSessionInitialized();
return getDelegateSession().getUpgradeResponse().getAcceptedSubProtocol();
}
@Override
public boolean isOpen() {
return ((getDelegateSession() != null) && getDelegateSession().isOpen());
}
@Override
protected void sendTextMessage(TextMessage message) throws IOException {
getDelegateSession().getRemote().sendString(message.getPayload());
}
@Override
protected void sendBinaryMessage(BinaryMessage message) throws IOException {
getDelegateSession().getRemote().sendBytes(message.getPayload());
}
@Override
protected void closeInternal(CloseStatus status) throws IOException {
getDelegateSession().close(status.getCode(), status.getReason());
}
}

144
spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java

@ -1,144 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* Adapts a Jetty {@link org.eclipse.jetty.websocket.api.Session} to
* {@link WebSocketSession}.
*
* @author Phillip Webb
* @author Rossen Stoyanchev
* @since 4.0
*/
public class JettyWebSocketSessionAdapter
extends AbstractWebSocketSesssionAdapter<org.eclipse.jetty.websocket.api.Session> {
private Session session;
private Principal principal;
private String protocol;
@Override
public void initSession(Session session) {
Assert.notNull(session, "session must not be null");
this.session = session;
if (this.protocol == null) {
UpgradeResponse response = session.getUpgradeResponse();
if ((response != null) && response.getAcceptedSubProtocol() != null) {
this.protocol = response.getAcceptedSubProtocol();
}
}
}
@Override
public String getId() {
return ObjectUtils.getIdentityHexString(this.session);
}
@Override
public boolean isSecure() {
return this.session.isSecure();
}
@Override
public URI getUri() {
return this.session.getUpgradeRequest().getRequestURI();
}
@Override
public void setUri(URI uri) {
}
@Override
public Principal getPrincipal() {
return this.principal;
}
@Override
public void setPrincipal(Principal principal) {
this.principal = principal;
}
@Override
public String getRemoteHostName() {
return this.session.getRemoteAddress().getHostName();
}
@Override
public void setRemoteHostName(String address) {
// ignore
}
@Override
public String getRemoteAddress() {
InetSocketAddress address = this.session.getRemoteAddress();
return address.isUnresolved() ? null : address.getAddress().getHostAddress();
}
@Override
public void setRemoteAddress(String address) {
// ignore
}
@Override
public String getAcceptedProtocol() {
return this.protocol;
}
@Override
public void setAcceptedProtocol(String protocol) {
this.protocol = protocol;
}
@Override
public boolean isOpen() {
return this.session.isOpen();
}
@Override
protected void sendTextMessage(TextMessage message) throws IOException {
this.session.getRemote().sendString(message.getPayload());
}
@Override
protected void sendBinaryMessage(BinaryMessage message) throws IOException {
this.session.getRemote().sendBytes(message.getPayload());
}
@Override
protected void closeInternal(CloseStatus status) throws IOException {
this.session.close(status.getCode(), status.getReason());
}
}

12
spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardEndpointAdapter.java → spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapter.java

@ -33,21 +33,21 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator; import org.springframework.web.socket.support.ExceptionWebSocketHandlerDecorator;
/** /**
* Adapts a {@link WebSocketHandler} to a standard {@link Endpoint}. * Adapts a {@link WebSocketHandler} to the standard WebSocket for Java API.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 4.0 * @since 4.0
*/ */
public class StandardEndpointAdapter extends Endpoint { public class StandardWebSocketHandlerAdapter extends Endpoint {
private static final Log logger = LogFactory.getLog(StandardEndpointAdapter.class); private static final Log logger = LogFactory.getLog(StandardWebSocketHandlerAdapter.class);
private final WebSocketHandler handler; private final WebSocketHandler handler;
private final StandardWebSocketSessionAdapter wsSession; private final StandardWebSocketSession wsSession;
public StandardEndpointAdapter(WebSocketHandler handler, StandardWebSocketSessionAdapter wsSession) { public StandardWebSocketHandlerAdapter(WebSocketHandler handler, StandardWebSocketSession wsSession) {
Assert.notNull(handler, "handler must not be null"); Assert.notNull(handler, "handler must not be null");
Assert.notNull(wsSession, "wsSession must not be null"); Assert.notNull(wsSession, "wsSession must not be null");
this.handler = handler; this.handler = handler;
@ -58,7 +58,7 @@ public class StandardEndpointAdapter extends Endpoint {
@Override @Override
public void onOpen(final javax.websocket.Session session, EndpointConfig config) { public void onOpen(final javax.websocket.Session session, EndpointConfig config) {
this.wsSession.initSession(session); this.wsSession.afterSessionInitialized(session);
if (this.handler.supportsPartialMessages()) { if (this.handler.supportsPartialMessages()) {
session.addMessageHandler(new MessageHandler.Partial<String>() { session.addMessageHandler(new MessageHandler.Partial<String>() {

123
spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java

@ -0,0 +1,123 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import org.springframework.http.HttpHeaders;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* A {@link WebSocketSession} for use with the standard WebSocket for Java API.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StandardWebSocketSession extends AbstractWebSocketSesssion<javax.websocket.Session> {
private final HttpHeaders headers;
private final InetSocketAddress localAddress;
private final InetSocketAddress remoteAddress;
/**
* Class constructor.
*
* @param handshakeHeaders the headers of the handshake request
*/
public StandardWebSocketSession(HttpHeaders handshakeHeaders, InetSocketAddress localAddress,
InetSocketAddress remoteAddress) {
handshakeHeaders = (handshakeHeaders != null) ? handshakeHeaders : new HttpHeaders();
this.headers = HttpHeaders.readOnlyHttpHeaders(handshakeHeaders);
this.localAddress = localAddress;
this.remoteAddress = remoteAddress;
}
@Override
public String getId() {
checkDelegateSessionInitialized();
return getDelegateSession().getId();
}
@Override
public URI getUri() {
checkDelegateSessionInitialized();
return getDelegateSession().getRequestURI();
}
@Override
public HttpHeaders getHandshakeHeaders() {
return this.headers;
}
@Override
public Principal getPrincipal() {
checkDelegateSessionInitialized();
return getDelegateSession().getUserPrincipal();
}
@Override
public InetSocketAddress getLocalAddress() {
return this.localAddress;
}
@Override
public InetSocketAddress getRemoteAddress() {
return this.remoteAddress;
}
@Override
public String getAcceptedProtocol() {
checkDelegateSessionInitialized();
String protocol = getDelegateSession().getNegotiatedSubprotocol();
return StringUtils.isEmpty(protocol)? null : protocol;
}
@Override
public boolean isOpen() {
return ((getDelegateSession() != null) && getDelegateSession().isOpen());
}
@Override
protected void sendTextMessage(TextMessage message) throws IOException {
getDelegateSession().getBasicRemote().sendText(message.getPayload(), message.isLast());
}
@Override
protected void sendBinaryMessage(BinaryMessage message) throws IOException {
getDelegateSession().getBasicRemote().sendBinary(message.getPayload(), message.isLast());
}
@Override
protected void closeInternal(CloseStatus status) throws IOException {
getDelegateSession().close(new CloseReason(CloseCodes.getCloseCode(status.getCode()), status.getReason()));
}
}

145
spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java

@ -1,145 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.adapter;
import java.io.IOException;
import java.net.URI;
import java.security.Principal;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* Adapts a standard {@link javax.websocket.Session} to {@link WebSocketSession}.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StandardWebSocketSessionAdapter extends AbstractWebSocketSesssionAdapter<javax.websocket.Session> {
private javax.websocket.Session session;
private URI uri;
private String remoteHostName;
private String remoteAddress;
private String protocol;
@Override
public void initSession(javax.websocket.Session session) {
Assert.notNull(session, "session must not be null");
this.session = session;
if (this.protocol == null) {
if (StringUtils.hasText(session.getNegotiatedSubprotocol())) {
this.protocol = session.getNegotiatedSubprotocol();
}
}
}
@Override
public String getId() {
return this.session.getId();
}
@Override
public URI getUri() {
return this.uri;
}
@Override
public void setUri(URI uri) {
this.uri = uri;
}
@Override
public boolean isSecure() {
return this.session.isSecure();
}
@Override
public Principal getPrincipal() {
return this.session.getUserPrincipal();
}
@Override
public void setPrincipal(Principal principal) {
// ignore
}
@Override
public String getRemoteHostName() {
return this.remoteHostName;
}
@Override
public void setRemoteHostName(String name) {
this.remoteHostName = name;
}
@Override
public String getRemoteAddress() {
return this.remoteAddress;
}
@Override
public void setRemoteAddress(String address) {
this.remoteAddress = address;
}
@Override
public String getAcceptedProtocol() {
return this.protocol;
}
@Override
public void setAcceptedProtocol(String protocol) {
this.protocol = protocol;
}
@Override
public boolean isOpen() {
return this.session.isOpen();
}
@Override
protected void sendTextMessage(TextMessage message) throws IOException {
this.session.getBasicRemote().sendText(message.getPayload(), message.isLast());
}
@Override
protected void sendBinaryMessage(BinaryMessage message) throws IOException {
this.session.getBasicRemote().sendBinary(message.getPayload(), message.isLast());
}
@Override
protected void closeInternal(CloseStatus status) throws IOException {
this.session.close(new CloseReason(CloseCodes.getCloseCode(status.getCode()), status.getReason()));
}
}

3
spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java

@ -73,6 +73,9 @@ public abstract class AbstractWebSocketClient implements WebSocketClient {
Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); Assert.notNull(webSocketHandler, "webSocketHandler must not be null");
Assert.notNull(uri, "uri must not be null"); Assert.notNull(uri, "uri must not be null");
String scheme = uri.getScheme();
Assert.isTrue(((scheme != null) && ("ws".equals(scheme) || "wss".equals(scheme))), "Invalid scheme: " + scheme);
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Connecting to " + uri); logger.debug("Connecting to " + uri);
} }

39
spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java

@ -16,12 +16,10 @@
package org.springframework.web.socket.client; package org.springframework.web.socket.client;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.springframework.context.SmartLifecycle; import org.springframework.context.SmartLifecycle;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.util.CollectionUtils;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator; import org.springframework.web.socket.support.LoggingWebSocketHandlerDecorator;
@ -43,9 +41,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
private WebSocketSession webSocketSession; private WebSocketSession webSocketSession;
private final List<String> protocols = new ArrayList<String>(); private HttpHeaders headers = new HttpHeaders();
private HttpHeaders headers;
private final boolean syncClientLifecycle; private final boolean syncClientLifecycle;
@ -76,24 +72,36 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
* any. * any.
*/ */
public void setSubProtocols(List<String> protocols) { public void setSubProtocols(List<String> protocols) {
this.protocols.clear(); this.headers.setSecWebSocketProtocol(protocols);
if (!CollectionUtils.isEmpty(protocols)) {
this.protocols.addAll(protocols);
}
} }
/** /**
* Return the configured sub-protocols to use. * Return the configured sub-protocols to use.
*/ */
public List<String> getSubProtocols() { public List<String> getSubProtocols() {
return this.protocols; return this.headers.getSecWebSocketProtocol();
}
/**
* Set the origin to use.
*/
public void setOrigin(String origin) {
this.headers.setOrigin(origin);
}
/**
* @return the configured origin.
*/
public String getOrigin() {
return this.headers.getOrigin();
} }
/** /**
* Provide default headers to add to the WebSocket handshake request. * Provide default headers to add to the WebSocket handshake request.
*/ */
public void setHeaders(HttpHeaders headers) { public void setHeaders(HttpHeaders headers) {
this.headers = headers; this.headers.clear();
this.headers.putAll(headers);
} }
/** /**
@ -122,14 +130,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
@Override @Override
protected void openConnection() throws Exception { protected void openConnection() throws Exception {
this.webSocketSession = this.client.doHandshake(this.webSocketHandler, this.headers, getUri());
HttpHeaders headers = new HttpHeaders();
if (this.headers != null) {
headers.putAll(this.headers);
}
headers.setSecWebSocketProtocol(this.protocols);
this.webSocketSession = this.client.doHandshake(this.webSocketHandler, headers, getUri());
} }
@Override @Override

51
spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClient.java

@ -16,8 +16,12 @@
package org.springframework.web.socket.client.endpoint; package org.springframework.web.socket.client.endpoint;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI; import java.net.URI;
import java.net.UnknownHostException;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Map; import java.util.Map;
import javax.websocket.ClientEndpointConfig; import javax.websocket.ClientEndpointConfig;
@ -31,8 +35,8 @@ import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.StandardEndpointAdapter; import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.StandardWebSocketSessionAdapter; import org.springframework.web.socket.adapter.StandardWebSocketSession;
import org.springframework.web.socket.client.AbstractWebSocketClient; import org.springframework.web.socket.client.AbstractWebSocketClient;
import org.springframework.web.socket.client.WebSocketConnectFailureException; import org.springframework.web.socket.client.WebSocketConnectFailureException;
@ -60,19 +64,21 @@ public class StandardWebSocketClient extends AbstractWebSocketClient {
@Override @Override
protected WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, protected WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler,
HttpHeaders httpHeaders, URI uri, List<String> protocols) throws WebSocketConnectFailureException { HttpHeaders headers, URI uri, List<String> protocols) throws WebSocketConnectFailureException {
StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter(); int port = getPort(uri);
session.setUri(uri); InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port);
session.setRemoteHostName(uri.getHost()); InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), port);
StandardWebSocketSession session = new StandardWebSocketSession(headers, localAddress, remoteAddress);
ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create(); ClientEndpointConfig.Builder configBuidler = ClientEndpointConfig.Builder.create();
configBuidler.configurator(new StandardWebSocketClientConfigurator(httpHeaders)); configBuidler.configurator(new StandardWebSocketClientConfigurator(headers));
configBuidler.preferredSubprotocols(protocols); configBuidler.preferredSubprotocols(protocols);
try { try {
// TODO: do not block // TODO: do not block
Endpoint endpoint = new StandardEndpointAdapter(webSocketHandler, session); Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
this.webSocketContainer.connectToServer(endpoint, configBuidler.build(), uri); this.webSocketContainer.connectToServer(endpoint, configBuidler.build(), uri);
return session; return session;
@ -82,21 +88,38 @@ public class StandardWebSocketClient extends AbstractWebSocketClient {
} }
} }
private InetAddress getLocalHost() {
try {
return InetAddress.getLocalHost();
}
catch (UnknownHostException e) {
return InetAddress.getLoopbackAddress();
}
}
private int getPort(URI uri) {
if (uri.getPort() == -1) {
String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH);
return "wss".equals(scheme) ? 443 : 80;
}
return uri.getPort();
}
private class StandardWebSocketClientConfigurator extends Configurator { private class StandardWebSocketClientConfigurator extends Configurator {
private final HttpHeaders httpHeaders; private final HttpHeaders headers;
public StandardWebSocketClientConfigurator(HttpHeaders httpHeaders) { public StandardWebSocketClientConfigurator(HttpHeaders headers) {
this.httpHeaders = httpHeaders; this.headers = headers;
} }
@Override @Override
public void beforeRequest(Map<String, List<String>> headers) { public void beforeRequest(Map<String, List<String>> requestHeaders) {
headers.putAll(this.httpHeaders); requestHeaders.putAll(this.headers);
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Handshake request headers: " + headers); logger.debug("Handshake request headers: " + requestHeaders);
} }
} }
@Override @Override

15
spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java

@ -24,8 +24,8 @@ import org.springframework.context.SmartLifecycle;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession;
import org.springframework.web.socket.client.AbstractWebSocketClient; import org.springframework.web.socket.client.AbstractWebSocketClient;
import org.springframework.web.socket.client.WebSocketConnectFailureException; import org.springframework.web.socket.client.WebSocketConnectFailureException;
import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponents;
@ -130,7 +130,7 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma
} }
@Override @Override
public WebSocketSession doHandshakeInternal(WebSocketHandler webSocketHandler, HttpHeaders headers, public WebSocketSession doHandshakeInternal(WebSocketHandler wsHandler, HttpHeaders headers,
URI uri, List<String> protocols) throws WebSocketConnectFailureException { URI uri, List<String> protocols) throws WebSocketConnectFailureException {
ClientUpgradeRequest request = new ClientUpgradeRequest(); ClientUpgradeRequest request = new ClientUpgradeRequest();
@ -140,16 +140,13 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma
request.setHeader(header, headers.get(header)); request.setHeader(header, headers.get(header));
} }
JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter(); JettyWebSocketSession wsSession = new JettyWebSocketSession(null);
session.setUri(uri); JettyWebSocketHandlerAdapter listener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession);
session.setRemoteHostName(uri.getHost());
JettyWebSocketListenerAdapter listener = new JettyWebSocketListenerAdapter(webSocketHandler, session);
try { try {
// TODO: do not block // TODO: do not block
this.client.connect(listener, uri, request).get(); this.client.connect(listener, uri, request).get();
return session; return wsSession;
} }
catch (Exception e) { catch (Exception e) {
throw new WebSocketConnectFailureException("Failed to connect to " + uri, e); throw new WebSocketConnectFailureException("Failed to connect to " + uri, e);

6
spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java

@ -201,10 +201,14 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
protected String selectProtocol(List<String> requestedProtocols) { protected String selectProtocol(List<String> requestedProtocols) {
if (requestedProtocols != null) { if (requestedProtocols != null) {
if (logger.isDebugEnabled()) {
logger.debug("Requested sub-protocol(s): " + requestedProtocols
+ ", supported sub-protocol(s): " + this.supportedProtocols);
}
for (String protocol : requestedProtocols) { for (String protocol : requestedProtocols) {
if (this.supportedProtocols.contains(protocol.toLowerCase())) { if (this.supportedProtocols.contains(protocol.toLowerCase())) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Selected sub-protocol '" + protocol + "'"); logger.debug("Selected sub-protocol: '" + protocol + "'");
} }
return protocol; return protocol;
} }

22
spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java

@ -17,16 +17,18 @@
package org.springframework.web.socket.server.support; package org.springframework.web.socket.server.support;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import javax.websocket.Endpoint; import javax.websocket.Endpoint;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.adapter.StandardEndpointAdapter; import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.StandardWebSocketSessionAdapter; import org.springframework.web.socket.adapter.StandardWebSocketSession;
import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.RequestUpgradeStrategy; import org.springframework.web.socket.server.RequestUpgradeStrategy;
@ -40,17 +42,19 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS
protected final Log logger = LogFactory.getLog(getClass()); protected final Log logger = LogFactory.getLog(getClass());
private final ServerWebSocketSessionInitializer wsSessionInitializer = new ServerWebSocketSessionInitializer();
@Override @Override
public void upgrade(ServerHttpRequest request, ServerHttpResponse response, public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
String protocol, WebSocketHandler handler) throws IOException, HandshakeFailureException { String acceptedProtocol, WebSocketHandler wsHandler) throws IOException, HandshakeFailureException {
HttpHeaders headers = request.getHeaders();
InetSocketAddress localAddress = request.getLocalAddress();
InetSocketAddress remoteAddress = request.getRemoteAddress();
StandardWebSocketSession wsSession = new StandardWebSocketSession(headers, localAddress, remoteAddress);
StandardWebSocketHandlerAdapter endpoint = new StandardWebSocketHandlerAdapter(wsHandler, wsSession);
StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter(); upgradeInternal(request, response, acceptedProtocol, endpoint);
this.wsSessionInitializer.initialize(request, response, protocol, session);
StandardEndpointAdapter endpoint = new StandardEndpointAdapter(handler, session);
upgradeInternal(request, response, protocol, endpoint);
} }
protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,

17
spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java

@ -33,8 +33,8 @@ import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession;
import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.RequestUpgradeStrategy; import org.springframework.web.socket.server.RequestUpgradeStrategy;
@ -59,8 +59,6 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
private WebSocketServerFactory factory; private WebSocketServerFactory factory;
private final ServerWebSocketSessionInitializer wsSessionInitializer = new ServerWebSocketSessionInitializer();
public JettyRequestUpgradeStrategy() { public JettyRequestUpgradeStrategy() {
this.factory = new WebSocketServerFactory(); this.factory = new WebSocketServerFactory();
@ -87,7 +85,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
@Override @Override
public void upgrade(ServerHttpRequest request, ServerHttpResponse response, public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
String protocol, WebSocketHandler webSocketHandler) throws IOException { String protocol, WebSocketHandler wsHandler) throws IOException {
Assert.isInstanceOf(ServletServerHttpRequest.class, request); Assert.isInstanceOf(ServletServerHttpRequest.class, request);
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
@ -100,14 +98,13 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
throw new HandshakeFailureException("Not a WebSocket request"); throw new HandshakeFailureException("Not a WebSocket request");
} }
JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter(); JettyWebSocketSession wsSession = new JettyWebSocketSession(request.getPrincipal());
this.wsSessionInitializer.initialize(request, response, protocol, session); JettyWebSocketHandlerAdapter wsListener = new JettyWebSocketHandlerAdapter(wsHandler, wsSession);
JettyWebSocketListenerAdapter listener = new JettyWebSocketListenerAdapter(webSocketHandler, session);
servletRequest.setAttribute(WEBSOCKET_LISTENER_ATTR_NAME, listener); servletRequest.setAttribute(WEBSOCKET_LISTENER_ATTR_NAME, wsListener);
if (!this.factory.acceptWebSocket(servletRequest, servletResponse)) { if (!this.factory.acceptWebSocket(servletRequest, servletResponse)) {
// should never happen // should not happen
throw new HandshakeFailureException("WebSocket request not accepted by Jetty"); throw new HandshakeFailureException("WebSocket request not accepted by Jetty");
} }
} }

43
spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java

@ -1,43 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.support;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.ConfigurableWebSocketSession;
/**
* Copies information from the handshake HTTP request and response to a given
* {@link WebSocketSession}.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ServerWebSocketSessionInitializer {
public void initialize(ServerHttpRequest request, ServerHttpResponse response,
String protocol, ConfigurableWebSocketSession session) {
session.setUri(request.getURI());
session.setRemoteHostName(request.getRemoteHostName());
session.setRemoteAddress(request.getRemoteAddress());
session.setPrincipal(request.getPrincipal());
session.setAcceptedProtocol(protocol);
}
}

4
spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java

@ -52,7 +52,7 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg
@Override @Override
public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
String selectedProtocol, Endpoint endpoint) throws IOException { String acceptedProtocol, Endpoint endpoint) throws IOException {
Assert.isTrue(request instanceof ServletServerHttpRequest); Assert.isTrue(request instanceof ServletServerHttpRequest);
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
@ -82,7 +82,7 @@ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrateg
ServerEndpointConfig endpointConfig = new ServerEndpointRegistration("/shouldntmatter", endpoint); ServerEndpointConfig endpointConfig = new ServerEndpointRegistration("/shouldntmatter", endpoint);
upgradeHandler.preInit(endpoint, endpointConfig, serverContainer, webSocketRequest, upgradeHandler.preInit(endpoint, endpointConfig, serverContainer, webSocketRequest,
selectedProtocol, Collections.<String, String> emptyMap(), servletRequest.isSecure()); acceptedProtocol, Collections.<String, String> emptyMap(), servletRequest.isSecure());
} }
} }

3
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpReceivingTransportHandler.java

@ -47,9 +47,6 @@ public abstract class AbstractHttpReceivingTransportHandler
public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, WebSocketSession wsSession) throws SockJsException { WebSocketHandler wsHandler, WebSocketSession wsSession) throws SockJsException {
// TODO: check "Sec-WebSocket-Protocol" header
// https://github.com/sockjs/sockjs-client/issues/130
Assert.notNull(wsSession, "No session"); Assert.notNull(wsSession, "No session");
AbstractHttpSockJsSession sockJsSession = (AbstractHttpSockJsSession) wsSession; AbstractHttpSockJsSession sockJsSession = (AbstractHttpSockJsSession) wsSession;

6
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java

@ -43,10 +43,14 @@ public abstract class AbstractHttpSendingTransportHandler extends TransportHandl
public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, WebSocketSession wsSession) throws SockJsException { WebSocketHandler wsHandler, WebSocketSession wsSession) throws SockJsException {
AbstractHttpSockJsSession sockJsSession = (AbstractHttpSockJsSession) wsSession;
String protocol = null; // TODO: https://github.com/sockjs/sockjs-client/issues/130
sockJsSession.setAcceptedProtocol(protocol);
// Set content type before writing // Set content type before writing
response.getHeaders().setContentType(getContentType()); response.getHeaders().setContentType(getContentType());
AbstractHttpSockJsSession sockJsSession = (AbstractHttpSockJsSession) wsSession;
handleRequestInternal(request, response, sockJsSession); handleRequestInternal(request, response, sockJsSession);
} }

5
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsService.java

@ -42,7 +42,6 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.server.DefaultHandshakeHandler; import org.springframework.web.socket.server.DefaultHandshakeHandler;
import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.support.ServerWebSocketSessionInitializer;
import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsException;
import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.support.AbstractSockJsService; import org.springframework.web.socket.sockjs.support.AbstractSockJsService;
@ -77,8 +76,6 @@ public class DefaultSockJsService extends AbstractSockJsService {
private final Map<String, AbstractSockJsSession> sessions = new ConcurrentHashMap<String, AbstractSockJsSession>(); private final Map<String, AbstractSockJsSession> sessions = new ConcurrentHashMap<String, AbstractSockJsSession>();
private final ServerWebSocketSessionInitializer sessionInitializer = new ServerWebSocketSessionInitializer();
private ScheduledFuture sessionCleanupTask; private ScheduledFuture sessionCleanupTask;
@ -279,8 +276,6 @@ public class DefaultSockJsService extends AbstractSockJsService {
} }
logger.debug("Creating new session with session id \"" + sessionId + "\""); logger.debug("Creating new session with session id \"" + sessionId + "\"");
session = sessionFactory.createSession(sessionId, handler); session = sessionFactory.createSession(sessionId, handler);
String protocol = null; // TODO: https://github.com/sockjs/sockjs-client/issues/130
this.sessionInitializer.initialize(request, response, protocol, session);
this.sessions.put(sessionId, session); this.sessions.put(sessionId, session);
return session; return session;
} }

2
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/SockJsWebSocketHandler.java

@ -69,7 +69,7 @@ public class SockJsWebSocketHandler extends TextWebSocketHandlerAdapter {
@Override @Override
public void afterConnectionEstablished(WebSocketSession wsSession) throws Exception { public void afterConnectionEstablished(WebSocketSession wsSession) throws Exception {
Assert.isTrue(this.sessionCount.compareAndSet(0, 1), "Unexpected connection"); Assert.isTrue(this.sessionCount.compareAndSet(0, 1), "Unexpected connection");
this.sockJsSession.initWebSocketSession(wsSession); this.sockJsSession.afterSessionInitialized(wsSession);
} }
@Override @Override

57
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java

@ -17,9 +17,12 @@
package org.springframework.web.socket.sockjs.transport.session; package org.springframework.web.socket.sockjs.transport.session;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.Principal;
import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpAsyncRequestControl; import org.springframework.http.server.ServerHttpAsyncRequestControl;
import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServerHttpResponse;
@ -51,11 +54,55 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession {
private String protocol; private String protocol;
private HttpHeaders handshakeHeaders;
public AbstractHttpSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { private Principal principal;
super(sessionId, config, handler);
private InetSocketAddress localAddress;
private InetSocketAddress remoteAddress;
public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler) {
super(id, config, wsHandler);
}
@Override
public HttpHeaders getHandshakeHeaders() {
return this.handshakeHeaders;
}
protected void setHandshakeHeaders(HttpHeaders handshakeHeaders) {
this.handshakeHeaders = handshakeHeaders;
}
@Override
public Principal getPrincipal() {
return this.principal;
} }
protected void setPrincipal(Principal principal) {
this.principal = principal;
}
@Override
public InetSocketAddress getLocalAddress() {
return this.localAddress;
}
protected void setLocalAddress(InetSocketAddress localAddress) {
this.localAddress = localAddress;
}
@Override
public InetSocketAddress getRemoteAddress() {
return this.remoteAddress;
}
protected void setRemoteAddress(InetSocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
/** /**
* Unlike WebSocket where sub-protocol negotiation is part of the * Unlike WebSocket where sub-protocol negotiation is part of the
@ -87,6 +134,12 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession {
tryCloseWithSockJsTransportError(t, CloseStatus.SERVER_ERROR); tryCloseWithSockJsTransportError(t, CloseStatus.SERVER_ERROR);
throw new SockJsTransportFailureException("Failed to send \"open\" frame", getId(), t); throw new SockJsTransportFailureException("Failed to send \"open\" frame", getId(), t);
} }
this.handshakeHeaders = request.getHeaders();
this.principal = request.getPrincipal();
this.localAddress = request.getLocalAddress();
this.remoteAddress = request.getRemoteAddress();
try { try {
delegateConnectionEstablished(); delegateConnectionEstablished();
} }

43
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java

@ -35,7 +35,6 @@ import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.ConfigurableWebSocketSession;
import org.springframework.web.socket.sockjs.SockJsMessageDeliveryException; import org.springframework.web.socket.sockjs.SockJsMessageDeliveryException;
import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.SockJsTransportFailureException;
import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
@ -46,7 +45,7 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 4.0 * @since 4.0
*/ */
public abstract class AbstractSockJsSession implements ConfigurableWebSocketSession { public abstract class AbstractSockJsSession implements WebSocketSession {
protected final Log logger = LogFactory.getLog(getClass()); protected final Log logger = LogFactory.getLog(getClass());
@ -97,46 +96,6 @@ public abstract class AbstractSockJsSession implements ConfigurableWebSocketSess
return this.uri; return this.uri;
} }
@Override
public void setUri(URI uri) {
this.uri = uri;
}
@Override
public boolean isSecure() {
return "wss".equals(this.uri.getSchemeSpecificPart());
}
@Override
public String getRemoteHostName() {
return this.remoteHostName;
}
@Override
public void setRemoteHostName(String remoteHostName) {
this.remoteHostName = remoteHostName;
}
@Override
public String getRemoteAddress() {
return this.remoteAddress;
}
@Override
public void setRemoteAddress(String remoteAddress) {
this.remoteAddress = remoteAddress;
}
@Override
public Principal getPrincipal() {
return this.principal;
}
@Override
public void setPrincipal(Principal principal) {
this.principal = principal;
}
public SockJsServiceConfig getSockJsServiceConfig() { public SockJsServiceConfig getSockJsServiceConfig() {
return this.sockJsServiceConfig; return this.sockJsServiceConfig;
} }

71
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java

@ -17,12 +17,17 @@
package org.springframework.web.socket.sockjs.transport.session; package org.springframework.web.socket.sockjs.transport.session;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.Principal;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.DelegatingWebSocketSession;
import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.SockJsTransportFailureException;
import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
import org.springframework.web.socket.sockjs.support.frame.SockJsMessageCodec; import org.springframework.web.socket.sockjs.support.frame.SockJsMessageCodec;
@ -33,47 +38,69 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsMessageCodec;
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 4.0 * @since 4.0
*/ */
public class WebSocketServerSockJsSession extends AbstractSockJsSession { public class WebSocketServerSockJsSession extends AbstractSockJsSession
implements DelegatingWebSocketSession<WebSocketSession> {
private WebSocketSession webSocketSession; private WebSocketSession wsSession;
public WebSocketServerSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler handler) { public WebSocketServerSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler) {
super(sessionId, config, handler); super(id, config, wsHandler);
} }
@Override
public HttpHeaders getHandshakeHeaders() {
checkDelegateSessionInitialized();
return this.wsSession.getHandshakeHeaders();
}
@Override @Override
public String getAcceptedProtocol() { public Principal getPrincipal() {
if (this.webSocketSession == null) { checkDelegateSessionInitialized();
logger.warn("getAcceptedProtocol() invoked before WebSocketSession has been initialized."); return this.wsSession.getPrincipal();
return null;
}
return this.webSocketSession.getAcceptedProtocol();
} }
@Override @Override
public void setAcceptedProtocol(String protocol) { public InetSocketAddress getLocalAddress() {
// ignore, webSocketSession should have it checkDelegateSessionInitialized();
return this.wsSession.getLocalAddress();
} }
public void initWebSocketSession(WebSocketSession session) throws Exception { @Override
this.webSocketSession = session; public InetSocketAddress getRemoteAddress() {
checkDelegateSessionInitialized();
return this.wsSession.getRemoteAddress();
}
@Override
public String getAcceptedProtocol() {
checkDelegateSessionInitialized();
return this.wsSession.getAcceptedProtocol();
}
private void checkDelegateSessionInitialized() {
Assert.state(this.wsSession != null, "WebSocketSession not yet initialized");
}
@Override
public void afterSessionInitialized(WebSocketSession session) {
this.wsSession = session;
try { try {
TextMessage message = new TextMessage(SockJsFrame.openFrame().getContent()); TextMessage message = new TextMessage(SockJsFrame.openFrame().getContent());
this.webSocketSession.sendMessage(message); this.wsSession.sendMessage(message);
scheduleHeartbeat();
delegateConnectionEstablished();
} }
catch (IOException ex) { catch (Exception ex) {
tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR);
return; return;
} }
scheduleHeartbeat();
delegateConnectionEstablished();
} }
@Override @Override
public boolean isActive() { public boolean isActive() {
return ((this.webSocketSession != null) && this.webSocketSession.isOpen()); return ((this.wsSession != null) && this.wsSession.isOpen());
} }
public void handleMessage(TextMessage message, WebSocketSession wsSession) throws Exception { public void handleMessage(TextMessage message, WebSocketSession wsSession) throws Exception {
@ -109,13 +136,13 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession {
logger.trace("Write " + frame); logger.trace("Write " + frame);
} }
TextMessage message = new TextMessage(frame.getContent()); TextMessage message = new TextMessage(frame.getContent());
this.webSocketSession.sendMessage(message); this.wsSession.sendMessage(message);
} }
@Override @Override
protected void disconnect(CloseStatus status) throws IOException { protected void disconnect(CloseStatus status) throws IOException {
if (this.webSocketSession != null) { if (this.wsSession != null) {
this.webSocketSession.close(status); this.wsSession.close(status);
} }
} }

12
spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketListenerAdapterTests.java → spring-websocket/src/test/java/org/springframework/web/socket/adapter/JettyWebSocketHandlerAdapterTests.java

@ -25,17 +25,17 @@ import org.springframework.web.socket.WebSocketHandler;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
/** /**
* Test fixture for {@link JettyWebSocketListenerAdapter}. * Test fixture for {@link JettyWebSocketHandlerAdapter}.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public class JettyWebSocketListenerAdapterTests { public class JettyWebSocketHandlerAdapterTests {
private JettyWebSocketListenerAdapter adapter; private JettyWebSocketHandlerAdapter adapter;
private WebSocketHandler webSocketHandler; private WebSocketHandler webSocketHandler;
private JettyWebSocketSessionAdapter webSocketSession; private JettyWebSocketSession webSocketSession;
private Session session; private Session session;
@ -44,8 +44,8 @@ public class JettyWebSocketListenerAdapterTests {
public void setup() { public void setup() {
this.session = mock(Session.class); this.session = mock(Session.class);
this.webSocketHandler = mock(WebSocketHandler.class); this.webSocketHandler = mock(WebSocketHandler.class);
this.webSocketSession = new JettyWebSocketSessionAdapter(); this.webSocketSession = new JettyWebSocketSession(null);
this.adapter = new JettyWebSocketListenerAdapter(this.webSocketHandler, this.webSocketSession); this.adapter = new JettyWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession);
} }
@Test @Test

12
spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardEndpointAdapterTests.java → spring-websocket/src/test/java/org/springframework/web/socket/adapter/StandardWebSocketHandlerAdapterTests.java

@ -31,17 +31,17 @@ import static org.mockito.Matchers.*;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
/** /**
* Test fixture for {@link StandardEndpointAdapter}. * Test fixture for {@link StandardWebSocketHandlerAdapter}.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
*/ */
public class StandardEndpointAdapterTests { public class StandardWebSocketHandlerAdapterTests {
private StandardEndpointAdapter adapter; private StandardWebSocketHandlerAdapter adapter;
private WebSocketHandler webSocketHandler; private WebSocketHandler webSocketHandler;
private StandardWebSocketSessionAdapter webSocketSession; private StandardWebSocketSession webSocketSession;
private Session session; private Session session;
@ -50,8 +50,8 @@ public class StandardEndpointAdapterTests {
public void setup() { public void setup() {
this.session = mock(Session.class); this.session = mock(Session.class);
this.webSocketHandler = mock(WebSocketHandler.class); this.webSocketHandler = mock(WebSocketHandler.class);
this.webSocketSession = new StandardWebSocketSessionAdapter(); this.webSocketSession = new StandardWebSocketSession(null, null, null);
this.adapter = new StandardEndpointAdapter(this.webSocketHandler, this.webSocketSession); this.adapter = new StandardWebSocketHandlerAdapter(this.webSocketHandler, this.webSocketSession);
} }
@Test @Test

98
spring-websocket/src/test/java/org/springframework/web/socket/client/endpoint/StandardWebSocketClientTests.java

@ -27,12 +27,12 @@ import javax.websocket.ClientEndpointConfig;
import javax.websocket.Endpoint; import javax.websocket.Endpoint;
import javax.websocket.WebSocketContainer; import javax.websocket.WebSocketContainer;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.StandardEndpointAdapter;
import org.springframework.web.socket.adapter.WebSocketHandlerAdapter; import org.springframework.web.socket.adapter.WebSocketHandlerAdapter;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -45,40 +45,92 @@ import static org.mockito.Mockito.*;
*/ */
public class StandardWebSocketClientTests { public class StandardWebSocketClientTests {
private StandardWebSocketClient wsClient;
private WebSocketContainer wsContainer;
private WebSocketHandler wsHandler;
private HttpHeaders headers;
@Before
public void setup() {
this.headers = new HttpHeaders();
this.wsHandler = new WebSocketHandlerAdapter();
this.wsContainer = mock(WebSocketContainer.class);
this.wsClient = new StandardWebSocketClient(this.wsContainer);
}
@Test
public void localAddress() throws Exception {
URI uri = new URI("ws://example.com/abc");
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
assertNotNull(session.getLocalAddress());
assertEquals(80, session.getLocalAddress().getPort());
}
@Test
public void localAddressWss() throws Exception {
URI uri = new URI("wss://example.com/abc");
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
assertNotNull(session.getLocalAddress());
assertEquals(443, session.getLocalAddress().getPort());
}
@Test(expected=IllegalArgumentException.class)
public void localAddressNoScheme() throws Exception {
URI uri = new URI("example.com/abc");
this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
}
@Test
public void remoteAddress() throws Exception {
URI uri = new URI("wss://example.com/abc");
WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
assertNotNull(session.getRemoteAddress());
assertEquals("example.com", session.getRemoteAddress().getHostName());
assertEquals(443, session.getLocalAddress().getPort());
}
@Test @Test
public void doHandshake() throws Exception { public void headersWebSocketSession() throws Exception {
URI uri = new URI("ws://example.com/abc"); URI uri = new URI("ws://example.com/abc");
List<String> subprotocols = Arrays.asList("abc"); List<String> protocols = Arrays.asList("abc");
this.headers.setSecWebSocketProtocol(protocols);
this.headers.add("foo", "bar");
HttpHeaders headers = new HttpHeaders(); WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
headers.setSecWebSocketProtocol(subprotocols);
headers.add("foo", "bar");
WebSocketHandler handler = new WebSocketHandlerAdapter(); assertEquals(Collections.singletonMap("foo", Arrays.asList("bar")), session.getHandshakeHeaders());
WebSocketContainer webSocketContainer = mock(WebSocketContainer.class); }
StandardWebSocketClient client = new StandardWebSocketClient(webSocketContainer);
WebSocketSession session = client.doHandshake(handler, headers, uri);
ArgumentCaptor<Endpoint> endpointArg = ArgumentCaptor.forClass(Endpoint.class); @Test
ArgumentCaptor<ClientEndpointConfig> configArg = ArgumentCaptor.forClass(ClientEndpointConfig.class); public void headersClientEndpointConfigurator() throws Exception {
ArgumentCaptor<URI> uriArg = ArgumentCaptor.forClass(URI.class);
verify(webSocketContainer).connectToServer(endpointArg.capture(), configArg.capture(), uriArg.capture()); URI uri = new URI("ws://example.com/abc");
List<String> protocols = Arrays.asList("abc");
this.headers.setSecWebSocketProtocol(protocols);
this.headers.add("foo", "bar");
this.wsClient.doHandshake(this.wsHandler, this.headers, uri);
assertNotNull(endpointArg.getValue()); ArgumentCaptor<Endpoint> arg1 = ArgumentCaptor.forClass(Endpoint.class);
assertEquals(StandardEndpointAdapter.class, endpointArg.getValue().getClass()); ArgumentCaptor<ClientEndpointConfig> arg2 = ArgumentCaptor.forClass(ClientEndpointConfig.class);
ArgumentCaptor<URI> arg3 = ArgumentCaptor.forClass(URI.class);
verify(this.wsContainer).connectToServer(arg1.capture(), arg2.capture(), arg3.capture());
ClientEndpointConfig config = configArg.getValue(); ClientEndpointConfig endpointConfig = arg2.getValue();
assertEquals(subprotocols, config.getPreferredSubprotocols()); assertEquals(protocols, endpointConfig.getPreferredSubprotocols());
Map<String, List<String>> map = new HashMap<>(); Map<String, List<String>> map = new HashMap<>();
config.getConfigurator().beforeRequest(map); endpointConfig.getConfigurator().beforeRequest(map);
assertEquals(Collections.singletonMap("foo", Arrays.asList("bar")), map); assertEquals(Collections.singletonMap("foo", Arrays.asList("bar")), map);
assertEquals(uri, uriArg.getValue());
assertEquals(uri, session.getUri());
assertEquals("example.com", session.getRemoteHostName());
} }
} }

8
spring-websocket/src/test/java/org/springframework/web/socket/client/jetty/JettyWebSocketClientTests.java

@ -33,8 +33,8 @@ import org.springframework.util.CollectionUtils;
import org.springframework.util.SocketUtils; import org.springframework.util.SocketUtils;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.JettyWebSocketListenerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSessionAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession;
import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -113,8 +113,8 @@ public class JettyWebSocketClientTests {
resp.setAcceptedSubProtocol(req.getSubProtocols().get(0)); resp.setAcceptedSubProtocol(req.getSubProtocols().get(0));
} }
JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter(); JettyWebSocketSession session = new JettyWebSocketSession(null);
return new JettyWebSocketListenerAdapter(webSocketHandler, session); return new JettyWebSocketHandlerAdapter(webSocketHandler, session);
} }
}); });
} }

77
spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java

@ -17,9 +17,12 @@
package org.springframework.web.socket.sockjs.transport.session; package org.springframework.web.socket.sockjs.transport.session;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.Principal;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
@ -29,6 +32,14 @@ import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
*/ */
public class TestSockJsSession extends AbstractSockJsSession { public class TestSockJsSession extends AbstractSockJsSession {
private HttpHeaders headers;
private Principal principal;
private InetSocketAddress localAddress;
private InetSocketAddress remoteAddress;
private boolean active; private boolean active;
private final List<SockJsFrame> sockJsFrames = new ArrayList<>(); private final List<SockJsFrame> sockJsFrames = new ArrayList<>();
@ -48,12 +59,76 @@ public class TestSockJsSession extends AbstractSockJsSession {
super(sessionId, config, handler); super(sessionId, config, handler);
} }
@Override
public HttpHeaders getHandshakeHeaders() {
return this.headers;
}
/**
* @return the headers
*/
public HttpHeaders getHeaders() {
return this.headers;
}
/**
* @param headers the headers to set
*/
public void setHeaders(HttpHeaders headers) {
this.headers = headers;
}
/**
* @return the principal
*/
@Override
public Principal getPrincipal() {
return this.principal;
}
/**
* @param principal the principal to set
*/
public void setPrincipal(Principal principal) {
this.principal = principal;
}
/**
* @return the localAddress
*/
@Override
public InetSocketAddress getLocalAddress() {
return this.localAddress;
}
/**
* @param remoteAddress the remoteAddress to set
*/
public void setLocalAddress(InetSocketAddress localAddress) {
this.localAddress = localAddress;
}
/**
* @return the remoteAddress
*/
@Override
public InetSocketAddress getRemoteAddress() {
return this.remoteAddress;
}
/**
* @param remoteAddress the remoteAddress to set
*/
public void setRemoteAddress(InetSocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
@Override @Override
public String getAcceptedProtocol() { public String getAcceptedProtocol() {
return this.subProtocol; return this.subProtocol;
} }
@Override
public void setAcceptedProtocol(String protocol) { public void setAcceptedProtocol(String protocol) {
this.subProtocol = protocol; this.subProtocol = protocol;
} }

11
spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSessionTests.java

@ -27,7 +27,6 @@ import org.junit.Test;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession;
import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSessionTests.TestWebSocketServerSockJsSession; import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSessionTests.TestWebSocketServerSockJsSession;
import org.springframework.web.socket.support.TestWebSocketSession; import org.springframework.web.socket.support.TestWebSocketSession;
@ -61,7 +60,7 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession
public void isActive() throws Exception { public void isActive() throws Exception {
assertFalse(this.session.isActive()); assertFalse(this.session.isActive());
this.session.initWebSocketSession(this.webSocketSession); this.session.afterSessionInitialized(this.webSocketSession);
assertTrue(this.session.isActive()); assertTrue(this.session.isActive());
this.webSocketSession.setOpen(false); this.webSocketSession.setOpen(false);
@ -69,9 +68,9 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession
} }
@Test @Test
public void initWebSocketSession() throws Exception { public void afterSessionInitialized() throws Exception {
this.session.initWebSocketSession(this.webSocketSession); this.session.afterSessionInitialized(this.webSocketSession);
assertEquals("Open frame not sent", assertEquals("Open frame not sent",
Collections.singletonList(new TextMessage("o")), this.webSocketSession.getSentMessages()); Collections.singletonList(new TextMessage("o")), this.webSocketSession.getSentMessages());
@ -110,7 +109,7 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession
@Test @Test
public void sendMessageInternal() throws Exception { public void sendMessageInternal() throws Exception {
this.session.initWebSocketSession(this.webSocketSession); this.session.afterSessionInitialized(this.webSocketSession);
this.session.sendMessageInternal("x"); this.session.sendMessageInternal("x");
assertEquals(Arrays.asList(new TextMessage("o"), new TextMessage("a[\"x\"]")), assertEquals(Arrays.asList(new TextMessage("o"), new TextMessage("a[\"x\"]")),
@ -122,7 +121,7 @@ public class WebSocketServerSockJsSessionTests extends BaseAbstractSockJsSession
@Test @Test
public void disconnect() throws Exception { public void disconnect() throws Exception {
this.session.initWebSocketSession(this.webSocketSession); this.session.afterSessionInitialized(this.webSocketSession);
this.session.close(CloseStatus.NOT_ACCEPTABLE); this.session.close(CloseStatus.NOT_ACCEPTABLE);
assertEquals(CloseStatus.NOT_ACCEPTABLE, this.webSocketSession.getCloseStatus()); assertEquals(CloseStatus.NOT_ACCEPTABLE, this.webSocketSession.getCloseStatus());

45
spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java

@ -17,11 +17,13 @@
package org.springframework.web.socket.support; package org.springframework.web.socket.support;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI; import java.net.URI;
import java.security.Principal; import java.security.Principal;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
@ -37,13 +39,11 @@ public class TestWebSocketSession implements WebSocketSession {
private URI uri; private URI uri;
private boolean secure;
private Principal principal; private Principal principal;
private String remoteHostName; private InetSocketAddress localAddress;
private String remoteAddress; private InetSocketAddress remoteAddress;
private String protocol; private String protocol;
@ -53,6 +53,8 @@ public class TestWebSocketSession implements WebSocketSession {
private CloseStatus status; private CloseStatus status;
private HttpHeaders headers;
/** /**
* @return the id * @return the id
@ -84,19 +86,24 @@ public class TestWebSocketSession implements WebSocketSession {
this.uri = uri; this.uri = uri;
} }
@Override
public HttpHeaders getHandshakeHeaders() {
return this.headers;
}
/** /**
* @return the secure * @return the headers
*/ */
@Override public HttpHeaders getHeaders() {
public boolean isSecure() { return this.headers;
return this.secure;
} }
/** /**
* @param secure the secure to set * @param headers the headers to set
*/ */
public void setSecure(boolean secure) { public void setHeaders(HttpHeaders headers) {
this.secure = secure; this.headers = headers;
} }
/** /**
@ -115,32 +122,32 @@ public class TestWebSocketSession implements WebSocketSession {
} }
/** /**
* @return the remoteHostName * @return the localAddress
*/ */
@Override @Override
public String getRemoteHostName() { public InetSocketAddress getLocalAddress() {
return this.remoteHostName; return this.localAddress;
} }
/** /**
* @param remoteHostName the remoteHostName to set * @param remoteAddress the remoteAddress to set
*/ */
public void setRemoteHostName(String remoteHostName) { public void setLocalAddress(InetSocketAddress localAddress) {
this.remoteHostName = remoteHostName; this.localAddress = localAddress;
} }
/** /**
* @return the remoteAddress * @return the remoteAddress
*/ */
@Override @Override
public String getRemoteAddress() { public InetSocketAddress getRemoteAddress() {
return this.remoteAddress; return this.remoteAddress;
} }
/** /**
* @param remoteAddress the remoteAddress to set * @param remoteAddress the remoteAddress to set
*/ */
public void setRemoteAddress(String remoteAddress) { public void setRemoteAddress(InetSocketAddress remoteAddress) {
this.remoteAddress = remoteAddress; this.remoteAddress = remoteAddress;
} }

Loading…
Cancel
Save