diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java index 1f3e82541ef..29791860a8b 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java @@ -22,6 +22,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.util.Assert; import org.springframework.websocket.CloseStatus; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.TextMessage; import org.springframework.websocket.TextMessageHandler; import org.springframework.websocket.WebSocketHandler; @@ -40,6 +41,8 @@ public abstract class AbstractSockJsSession implements WebSocketSession { private final String sessionId; + private final HandlerProvider handlerProvider; + private final TextMessageHandler handler; private State state = State.NEW; @@ -52,14 +55,17 @@ public abstract class AbstractSockJsSession implements WebSocketSession { /** * * @param sessionId - * @param handler the recipient of SockJS messages + * @param handlerProvider the recipient of SockJS messages */ - public AbstractSockJsSession(String sessionId, WebSocketHandler webSocketHandler) { + public AbstractSockJsSession(String sessionId, HandlerProvider handlerProvider) { Assert.notNull(sessionId, "sessionId is required"); - Assert.notNull(webSocketHandler, "webSocketHandler is required"); - Assert.isInstanceOf(TextMessageHandler.class, webSocketHandler, "Expected a TextMessageHandler"); + Assert.notNull(handlerProvider, "handlerProvider is required"); this.sessionId = sessionId; + + WebSocketHandler webSocketHandler = handlerProvider.getHandler(); + Assert.isInstanceOf(TextMessageHandler.class, webSocketHandler, "Expected a TextMessageHandler"); this.handler = (TextMessageHandler) webSocketHandler; + this.handlerProvider = handlerProvider; } public String getId() { @@ -180,7 +186,12 @@ public abstract class AbstractSockJsSession implements WebSocketSession { } finally { this.state = State.CLOSED; - this.handler.afterConnectionClosed(status, this); + try { + this.handler.afterConnectionClosed(status, this); + } + finally { + this.handlerProvider.destroy(this.handler); + } } } } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionFactory.java b/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionFactory.java index c87ff5a7b9c..4f6051c1f88 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionFactory.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSessionFactory.java @@ -16,6 +16,7 @@ package org.springframework.sockjs; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.WebSocketSession; @@ -28,6 +29,6 @@ import org.springframework.websocket.WebSocketSession; */ public interface SockJsSessionFactory{ - S createSession(String sessionId, WebSocketHandler webSocketHandler); + S createSession(String sessionId, HandlerProvider handler); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSockJsSession.java index 51004660a6b..0c179b9dd0c 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSockJsSession.java @@ -25,6 +25,7 @@ import java.util.concurrent.ScheduledFuture; import org.springframework.sockjs.AbstractSockJsSession; import org.springframework.util.Assert; import org.springframework.websocket.CloseStatus; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.TextMessage; import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.WebSocketMessage; @@ -45,9 +46,9 @@ public abstract class AbstractServerSockJsSession extends AbstractSockJsSession public AbstractServerSockJsSession(String sessionId, SockJsConfiguration config, - WebSocketHandler webSocketHandler) { + HandlerProvider handler) { - super(sessionId, webSocketHandler); + super(sessionId, handler); this.sockJsConfig = config; } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java index 24c3dacb066..0319ff41bc7 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java @@ -39,6 +39,7 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.DigestUtils; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -218,7 +219,7 @@ public abstract class AbstractSockJsService * @throws Exception */ public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - String sockJsPath, WebSocketHandler webSocketHandler) throws Exception { + String sockJsPath, HandlerProvider handler) throws Exception { logger.debug(request.getMethod() + " [" + sockJsPath + "]"); @@ -244,7 +245,7 @@ public abstract class AbstractSockJsService return; } else if (sockJsPath.equals("/websocket")) { - handleRawWebSocketRequest(request, response, webSocketHandler); + handleRawWebSocketRequest(request, response, handler); return; } @@ -264,7 +265,7 @@ public abstract class AbstractSockJsService return; } - handleTransportRequest(request, response, sessionId, TransportType.fromValue(transport), webSocketHandler); + handleTransportRequest(request, response, sessionId, TransportType.fromValue(transport), handler); } finally { response.flush(); @@ -272,10 +273,10 @@ public abstract class AbstractSockJsService } protected abstract void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler) throws Exception; + HandlerProvider handler) throws Exception; protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, - String sessionId, TransportType transportType, WebSocketHandler webSocketHandler) throws Exception; + String sessionId, TransportType transportType, HandlerProvider handler) throws Exception; protected boolean validateRequest(String serverId, String sessionId, String transport) { diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsService.java index 604ddcbca18..eeef0913d19 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsService.java @@ -18,6 +18,7 @@ package org.springframework.sockjs.server; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -30,6 +31,6 @@ public interface SockJsService { void handleRequest(ServerHttpRequest request, ServerHttpResponse response, String sockJsPath, - WebSocketHandler webSocketHandler) throws Exception; + HandlerProvider handler) throws Exception; } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java index 8c662127d66..1d0eedb7664 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/TransportHandler.java @@ -18,6 +18,7 @@ package org.springframework.sockjs.server; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.AbstractSockJsSession; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -31,6 +32,6 @@ public interface TransportHandler { TransportType getTransportType(); void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler, AbstractSockJsSession session) throws Exception; + HandlerProvider handler, AbstractSockJsSession session) throws Exception; } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java index aa35115fe6d..f54b540eb13 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java @@ -21,7 +21,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import org.springframework.beans.factory.InitializingBean; import org.springframework.http.Cookie; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; @@ -43,6 +42,7 @@ import org.springframework.sockjs.server.transport.XhrPollingTransportHandler; import org.springframework.sockjs.server.transport.XhrStreamingTransportHandler; import org.springframework.sockjs.server.transport.XhrTransportHandler; import org.springframework.util.Assert; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.server.DefaultHandshakeHandler; import org.springframework.websocket.server.HandshakeHandler; @@ -54,7 +54,7 @@ import org.springframework.websocket.server.HandshakeHandler; * @author Rossen Stoyanchev * @since 4.0 */ -public class DefaultSockJsService extends AbstractSockJsService implements InitializingBean { +public class DefaultSockJsService extends AbstractSockJsService { private final Map transportHandlers = new HashMap(); @@ -157,13 +157,13 @@ public class DefaultSockJsService extends AbstractSockJsService implements Initi @Override protected void handleRawWebSocketRequest(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler) throws Exception { + HandlerProvider handler) throws Exception { if (isWebSocketEnabled()) { TransportHandler transportHandler = this.transportHandlers.get(TransportType.WEBSOCKET); if (transportHandler != null) { if (transportHandler instanceof HandshakeHandler) { - ((HandshakeHandler) transportHandler).doHandshake(request, response, webSocketHandler); + ((HandshakeHandler) transportHandler).doHandshake(request, response, handler); return; } } @@ -174,7 +174,7 @@ public class DefaultSockJsService extends AbstractSockJsService implements Initi @Override protected void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, - String sessionId, TransportType transportType, WebSocketHandler webSocketHandler) throws Exception { + String sessionId, TransportType transportType, HandlerProvider handler) throws Exception { TransportHandler transportHandler = this.transportHandlers.get(transportType); @@ -201,7 +201,7 @@ public class DefaultSockJsService extends AbstractSockJsService implements Initi return; } - AbstractSockJsSession session = getSockJsSession(sessionId, webSocketHandler, transportHandler); + AbstractSockJsSession session = getSockJsSession(sessionId, handler, transportHandler); if (session != null) { if (transportType.setsNoCacheHeader()) { @@ -220,10 +220,10 @@ public class DefaultSockJsService extends AbstractSockJsService implements Initi } } - transportHandler.handleRequest(request, response, webSocketHandler, session); + transportHandler.handleRequest(request, response, handler, session); } - public AbstractSockJsSession getSockJsSession(String sessionId, WebSocketHandler webSocketHandler, + public AbstractSockJsSession getSockJsSession(String sessionId, HandlerProvider handler, TransportHandler transportHandler) { AbstractSockJsSession session = this.sessions.get(sessionId); @@ -240,7 +240,7 @@ public class DefaultSockJsService extends AbstractSockJsService implements Initi return session; } logger.debug("Creating new session with session id \"" + sessionId + "\""); - session = (AbstractSockJsSession) sessionFactory.createSession(sessionId, webSocketHandler); + session = (AbstractSockJsSession) sessionFactory.createSession(sessionId, handler); this.sessions.put(sessionId, session); return session; } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/SockJsHttpRequestHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/SockJsHttpRequestHandler.java index b5e2b49cafd..4b16cacb0f1 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/SockJsHttpRequestHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/SockJsHttpRequestHandler.java @@ -22,9 +22,6 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import org.springframework.beans.BeansException; -import org.springframework.beans.factory.BeanFactory; -import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.http.server.AsyncServletServerHttpRequest; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -36,6 +33,7 @@ import org.springframework.web.util.NestedServletException; import org.springframework.web.util.UrlPathHelper; import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; +import org.springframework.websocket.support.SimpleHandlerProvider; /** @@ -43,7 +41,7 @@ import org.springframework.websocket.WebSocketHandler; * @author Rossen Stoyanchev * @since 4.0 */ -public class SockJsHttpRequestHandler implements HttpRequestHandler, BeanFactoryAware { +public class SockJsHttpRequestHandler implements HttpRequestHandler { private final String prefix; @@ -61,15 +59,15 @@ public class SockJsHttpRequestHandler implements HttpRequestHandler, BeanFactory * that begins with the specified prefix will be handled by this service. In a * Servlet container this is the path within the current servlet mapping. */ - public SockJsHttpRequestHandler(String prefix, SockJsService sockJsService, WebSocketHandler webSocketHandler) { + public SockJsHttpRequestHandler(String prefix, SockJsService sockJsService, WebSocketHandler handler) { Assert.hasText(prefix, "prefix is required"); Assert.notNull(sockJsService, "sockJsService is required"); - Assert.notNull(webSocketHandler, "webSocketHandler is required"); + Assert.notNull(handler, "webSocketHandler is required"); this.prefix = prefix; this.sockJsService = sockJsService; - this.handlerProvider = new HandlerProvider(webSocketHandler); + this.handlerProvider = new SimpleHandlerProvider(handler); } /** @@ -80,15 +78,15 @@ public class SockJsHttpRequestHandler implements HttpRequestHandler, BeanFactory * Servlet container this is the path within the current servlet mapping. */ public SockJsHttpRequestHandler(String prefix, SockJsService sockJsService, - Class webSocketHandlerClass) { + HandlerProvider handlerProvider) { Assert.hasText(prefix, "prefix is required"); Assert.notNull(sockJsService, "sockJsService is required"); - Assert.notNull(webSocketHandlerClass, "webSocketHandlerClass is required"); + Assert.notNull(handlerProvider, "handlerProvider is required"); this.prefix = prefix; this.sockJsService = sockJsService; - this.handlerProvider = new HandlerProvider(webSocketHandlerClass); + this.handlerProvider = handlerProvider; } public String getPrefix() { @@ -99,11 +97,6 @@ public class SockJsHttpRequestHandler implements HttpRequestHandler, BeanFactory return this.prefix + "/**"; } - @Override - public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - this.handlerProvider.setBeanFactory(beanFactory); - } - @Override public void handleRequest(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { @@ -119,8 +112,7 @@ public class SockJsHttpRequestHandler implements HttpRequestHandler, BeanFactory ServerHttpResponse httpResponse = new ServletServerHttpResponse(response); try { - WebSocketHandler webSocketHandler = this.handlerProvider.getHandler(); - this.sockJsService.handleRequest(httpRequest, httpResponse, sockJsPath, webSocketHandler); + this.sockJsService.handleRequest(httpRequest, httpResponse, sockJsPath, this.handlerProvider); } catch (Exception ex) { // TODO diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java index 8b7e3604a97..d3928730a48 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpReceivingTransportHandler.java @@ -27,6 +27,7 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.AbstractSockJsSession; import org.springframework.sockjs.server.TransportHandler; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; import com.fasterxml.jackson.databind.JsonMappingException; @@ -53,7 +54,7 @@ public abstract class AbstractHttpReceivingTransportHandler implements Transport @Override public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler, AbstractSockJsSession session) throws Exception { + HandlerProvider webSocketHandler, AbstractSockJsSession session) throws Exception { if (session == null) { response.setStatusCode(HttpStatus.NOT_FOUND); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java index 978663598c3..e2218715b01 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java @@ -28,6 +28,7 @@ import org.springframework.sockjs.server.ConfigurableTransportHandler; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; /** @@ -55,7 +56,7 @@ public abstract class AbstractHttpSendingTransportHandler @Override public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler, AbstractSockJsSession session) throws Exception { + HandlerProvider webSocketHandler, AbstractSockJsSession session) throws Exception { // Set content type before writing response.getHeaders().setContentType(getContentType()); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSockJsSession.java index e87006958b4..72aa97e7c11 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSockJsSession.java @@ -29,6 +29,7 @@ import org.springframework.sockjs.server.SockJsFrame.FrameFormat; import org.springframework.sockjs.server.TransportHandler; import org.springframework.util.Assert; import org.springframework.websocket.CloseStatus; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; /** @@ -49,9 +50,9 @@ public abstract class AbstractHttpServerSockJsSession extends AbstractServerSock public AbstractHttpServerSockJsSession(String sessionId, SockJsConfiguration sockJsConfig, - WebSocketHandler webSocketHandler) { + HandlerProvider handler) { - super(sessionId, sockJsConfig, webSocketHandler); + super(sessionId, sockJsConfig, handler); } public void setFrameFormat(FrameFormat frameFormat) { diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java index 0adfd867b1e..ed626a29615 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractStreamingTransportHandler.java @@ -20,6 +20,7 @@ import java.io.IOException; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -33,9 +34,9 @@ public abstract class AbstractStreamingTransportHandler extends AbstractHttpSend @Override - public StreamingServerSockJsSession createSession(String sessionId, WebSocketHandler webSocketHandler) { + public StreamingServerSockJsSession createSession(String sessionId, HandlerProvider handler) { Assert.notNull(getSockJsConfig(), "This transport requires SockJsConfiguration"); - return new StreamingServerSockJsSession(sessionId, getSockJsConfig(), webSocketHandler); + return new StreamingServerSockJsSession(sessionId, getSockJsConfig(), handler); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java index fabd525393e..3f5854d30bb 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/JsonpPollingTransportHandler.java @@ -27,6 +27,7 @@ import org.springframework.sockjs.server.TransportType; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.util.JavaScriptUtils; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -50,9 +51,9 @@ public class JsonpPollingTransportHandler extends AbstractHttpSendingTransportHa } @Override - public PollingServerSockJsSession createSession(String sessionId, WebSocketHandler webSocketHandler) { + public PollingServerSockJsSession createSession(String sessionId, HandlerProvider handler) { Assert.notNull(getSockJsConfig(), "This transport requires SockJsConfiguration"); - return new PollingServerSockJsSession(sessionId, getSockJsConfig(), webSocketHandler); + return new PollingServerSockJsSession(sessionId, getSockJsConfig(), handler); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingServerSockJsSession.java index c9814f09f21..86851ab9d59 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingServerSockJsSession.java @@ -17,15 +17,16 @@ package org.springframework.sockjs.server.transport; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; public class PollingServerSockJsSession extends AbstractHttpServerSockJsSession { public PollingServerSockJsSession(String sessionId, SockJsConfiguration sockJsConfig, - WebSocketHandler webSocketHandler) { + HandlerProvider handler) { - super(sessionId, sockJsConfig, webSocketHandler); + super(sessionId, sockJsConfig, handler); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java index 4e4dc0987e9..0e504569572 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java @@ -17,8 +17,6 @@ package org.springframework.sockjs.server.transport; import java.io.IOException; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -29,6 +27,7 @@ import org.springframework.sockjs.server.SockJsFrame; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.websocket.CloseStatus; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.TextMessage; import org.springframework.websocket.TextMessageHandler; import org.springframework.websocket.WebSocketHandler; @@ -38,8 +37,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; /** - * A SockJS implementation of {@link WebSocketHandler}. Delegates messages to and from a - * {@link SockJsHandler} and adds SockJS message framing. + * A wrapper around a {@link WebSocketHandler} instance that parses and adds SockJS + * messages frames as well as sends SockJS heartbeat messages. * * @author Rossen Stoyanchev * @since 4.0 @@ -50,34 +49,28 @@ public class SockJsWebSocketHandler implements TextMessageHandler { private final SockJsConfiguration sockJsConfig; - private final WebSocketHandler webSocketHandler; + private final HandlerProvider handlerProvider; - private final Map sessions = - new ConcurrentHashMap(); + private AbstractSockJsSession session; // TODO: JSON library used must be configurable private final ObjectMapper objectMapper = new ObjectMapper(); - public SockJsWebSocketHandler(SockJsConfiguration sockJsConfig, WebSocketHandler webSocketHandler) { - Assert.notNull(sockJsConfig, "sockJsConfig is required"); - Assert.notNull(webSocketHandler, "webSocketHandler is required"); - this.sockJsConfig = sockJsConfig; - this.webSocketHandler = webSocketHandler; + public SockJsWebSocketHandler(SockJsConfiguration config, HandlerProvider handlerProvider) { + Assert.notNull(config, "sockJsConfig is required"); + Assert.notNull(handlerProvider, "handlerProvider is required"); + this.sockJsConfig = config; + this.handlerProvider = handlerProvider; } protected SockJsConfiguration getSockJsConfig() { return this.sockJsConfig; } - protected AbstractSockJsSession getSockJsSession(WebSocketSession wsSession) { - return this.sessions.get(wsSession); - } - @Override public void afterConnectionEstablished(WebSocketSession wsSession) throws Exception { - AbstractSockJsSession session = new WebSocketServerSockJsSession(wsSession, getSockJsConfig()); - this.sessions.put(wsSession, session); + this.session = new WebSocketServerSockJsSession(wsSession, getSockJsConfig()); } @Override @@ -89,8 +82,7 @@ public class SockJsWebSocketHandler implements TextMessageHandler { } try { String[] messages = this.objectMapper.readValue(payload, String[].class); - AbstractSockJsSession session = getSockJsSession(wsSession); - session.delegateMessages(messages); + this.session.delegateMessages(messages); } catch (IOException e) { logger.error("Broken data received. Terminating WebSocket connection abruptly", e); @@ -100,14 +92,12 @@ public class SockJsWebSocketHandler implements TextMessageHandler { @Override public void afterConnectionClosed(CloseStatus status, WebSocketSession wsSession) throws Exception { - AbstractSockJsSession session = this.sessions.remove(wsSession); - session.delegateConnectionClosed(status); + this.session.delegateConnectionClosed(status); } @Override public void handleError(Throwable exception, WebSocketSession webSocketSession) { - AbstractSockJsSession session = getSockJsSession(webSocketSession); - session.delegateError(exception); + this.session.delegateError(exception); } private static String getSockJsSessionId(WebSocketSession wsSession) { @@ -127,7 +117,7 @@ public class SockJsWebSocketHandler implements TextMessageHandler { public WebSocketServerSockJsSession(WebSocketSession wsSession, SockJsConfiguration sockJsConfig) throws Exception { - super(getSockJsSessionId(wsSession), sockJsConfig, SockJsWebSocketHandler.this.webSocketHandler); + super(getSockJsSessionId(wsSession), sockJsConfig, SockJsWebSocketHandler.this.handlerProvider); this.wsSession = wsSession; TextMessage message = new TextMessage(SockJsFrame.openFrame().getContent()); this.wsSession.sendMessage(message); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingServerSockJsSession.java index 77cacca7808..094fd115e98 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingServerSockJsSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -29,9 +30,9 @@ public class StreamingServerSockJsSession extends AbstractHttpServerSockJsSessio public StreamingServerSockJsSession(String sessionId, SockJsConfiguration sockJsConfig, - WebSocketHandler webSocketHandler) { + HandlerProvider handler) { - super(sessionId, sockJsConfig, webSocketHandler); + super(sessionId, sockJsConfig, handler); } protected void flushCache() throws Exception { diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketTransportHandler.java index 836e7331f54..361c8b3c14c 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketTransportHandler.java @@ -24,8 +24,10 @@ import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.TransportHandler; import org.springframework.sockjs.server.TransportType; import org.springframework.util.Assert; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.server.HandshakeHandler; +import org.springframework.websocket.support.SimpleHandlerProvider; /** @@ -60,26 +62,19 @@ public class WebSocketTransportHandler implements ConfigurableTransportHandler, @Override public void handleRequest(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler, AbstractSockJsSession session) throws Exception { + HandlerProvider handler, AbstractSockJsSession session) throws Exception { - this.handshakeHandler.doHandshake(request, response, adaptSockJsHandler(webSocketHandler)); - } - - /** - * Adapt the {@link SockJsHandler} to the {@link WebSocketHandler} contract for - * exchanging SockJS message over WebSocket. - */ - protected WebSocketHandler adaptSockJsHandler(WebSocketHandler handler) { - return new SockJsWebSocketHandler(this.sockJsConfig, handler); + WebSocketHandler sockJsWrapper = new SockJsWebSocketHandler(this.sockJsConfig, handler); + this.handshakeHandler.doHandshake(request, response, new SimpleHandlerProvider(sockJsWrapper)); } // HandshakeHandler methods @Override public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler) throws Exception { + HandlerProvider handler) throws Exception { - return this.handshakeHandler.doHandshake(request, response, webSocketHandler); + return this.handshakeHandler.doHandshake(request, response, handler); } } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrPollingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrPollingTransportHandler.java index 57b5fd4e944..06a6aaf37d8 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrPollingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrPollingTransportHandler.java @@ -23,6 +23,7 @@ import org.springframework.sockjs.server.SockJsFrame.DefaultFrameFormat; import org.springframework.sockjs.server.SockJsFrame.FrameFormat; import org.springframework.sockjs.server.TransportType; import org.springframework.util.Assert; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -50,9 +51,9 @@ public class XhrPollingTransportHandler extends AbstractHttpSendingTransportHand return new DefaultFrameFormat("%s\n"); } - public PollingServerSockJsSession createSession(String sessionId, WebSocketHandler webSocketHandler) { + public PollingServerSockJsSession createSession(String sessionId, HandlerProvider handler) { Assert.notNull(getSockJsConfig(), "This transport requires SockJsConfiguration"); - return new PollingServerSockJsSession(sessionId, getSockJsConfig(), webSocketHandler); + return new PollingServerSockJsSession(sessionId, getSockJsConfig(), handler); } } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/HandlerProvider.java b/spring-websocket/src/main/java/org/springframework/websocket/HandlerProvider.java index 135c674d8bd..3c8dc2c455d 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/HandlerProvider.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/HandlerProvider.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,82 +13,36 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.websocket; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.beans.BeansException; -import org.springframework.beans.factory.BeanFactory; -import org.springframework.beans.factory.BeanFactoryAware; -import org.springframework.beans.factory.config.AutowireCapableBeanFactory; -import org.springframework.util.Assert; -import org.springframework.util.ClassUtils; - /** + * A strategy for obtaining a handler instance that is scoped to external lifecycle events + * such as the opening and closing of a WebSocket connection. * * @author Rossen Stoyanchev * @since 4.0 */ -public class HandlerProvider implements BeanFactoryAware { - - private Log logger = LogFactory.getLog(this.getClass()); - - private final T handlerBean; - - private final Class handlerClass; - - private AutowireCapableBeanFactory beanFactory; - - - public HandlerProvider(T handlerBean) { - Assert.notNull(handlerBean, "handlerBean is required"); - this.handlerBean = handlerBean; - this.handlerClass = null; - } - - public HandlerProvider(Class handlerClass) { - Assert.notNull(handlerClass, "handlerClass is required"); - this.handlerBean = null; - this.handlerClass = handlerClass; - } - - @Override - public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - if (beanFactory instanceof AutowireCapableBeanFactory) { - this.beanFactory = (AutowireCapableBeanFactory) beanFactory; - } - } - - public void setLogger(Log logger) { - this.logger = logger; - } - - public boolean isSingleton() { - return (this.handlerBean != null); - } - - @SuppressWarnings("unchecked") - public Class getHandlerType() { - if (this.handlerClass != null) { - return this.handlerClass; - } - return (Class) ClassUtils.getUserClass(this.handlerBean.getClass()); - } - - public T getHandler() { - if (this.handlerBean != null) { - if (logger != null && logger.isTraceEnabled()) { - logger.trace("Returning handler singleton " + this.handlerBean); - } - return this.handlerBean; - } - Assert.isTrue(this.beanFactory != null, "BeanFactory is required to initialize handler instances."); - if (logger != null && logger.isTraceEnabled()) { - logger.trace("Creating handler of type " + this.handlerClass); - } - return this.beanFactory.createBean(this.handlerClass); - } +public interface HandlerProvider { + + /** + * Whether the provided handler is a shared instance or not. + */ + boolean isSingleton(); + + /** + * The type of handler provided. + */ + Class getHandlerType(); + + /** + * Obtain the handler instance, either shared or created every time. + */ + T getHandler(); + + /** + * Callback to destroy a previously created handler instance if it is not shared. + */ + void destroy(T handler); } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/client/WebSocketClient.java b/spring-websocket/src/main/java/org/springframework/websocket/client/WebSocketClient.java index ae82de0dc08..579fcb1d8aa 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/client/WebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/client/WebSocketClient.java @@ -18,6 +18,7 @@ package org.springframework.websocket.client; import java.net.URI; import org.springframework.http.HttpHeaders; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.WebSocketSession; @@ -36,13 +37,10 @@ import org.springframework.websocket.WebSocketSession; public interface WebSocketClient { - WebSocketSession doHandshake(WebSocketHandler handler, String uriTemplate, Object... uriVariables) - throws WebSocketConnectFailureException; - - WebSocketSession doHandshake(WebSocketHandler handler, URI uri) - throws WebSocketConnectFailureException; + WebSocketSession doHandshake(HandlerProvider handler, + String uriTemplate, Object... uriVariables) throws WebSocketConnectFailureException; - WebSocketSession doHandshake(WebSocketHandler handler, HttpHeaders headers, URI uri) + WebSocketSession doHandshake(HandlerProvider handler, HttpHeaders headers, URI uri) throws WebSocketConnectFailureException; } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/client/WebSocketConnectionManager.java b/spring-websocket/src/main/java/org/springframework/websocket/client/WebSocketConnectionManager.java index cb16e5be449..29b465e477c 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/client/WebSocketConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/client/WebSocketConnectionManager.java @@ -21,6 +21,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.WebSocketSession; +import org.springframework.websocket.support.SimpleHandlerProvider; /** @@ -32,7 +33,7 @@ public class WebSocketConnectionManager extends AbstractWebSocketConnectionManag private final WebSocketClient client; - private final HandlerProvider webSocketHandlerProvider; + private final HandlerProvider handlerProvider; private WebSocketSession webSocketSession; @@ -43,8 +44,16 @@ public class WebSocketConnectionManager extends AbstractWebSocketConnectionManag WebSocketHandler webSocketHandler, String uriTemplate, Object... uriVariables) { super(uriTemplate, uriVariables); - this.webSocketHandlerProvider = new HandlerProvider(webSocketHandler); this.client = webSocketClient; + this.handlerProvider = new SimpleHandlerProvider(webSocketHandler); + } + + public WebSocketConnectionManager(WebSocketClient webSocketClient, + HandlerProvider handlerProvider, String uriTemplate, Object... uriVariables) { + + super(uriTemplate, uriVariables); + this.client = webSocketClient; + this.handlerProvider = handlerProvider; } public void setSubProtocols(List subProtocols) { @@ -57,10 +66,9 @@ public class WebSocketConnectionManager extends AbstractWebSocketConnectionManag @Override protected void openConnection() throws Exception { - WebSocketHandler webSocketHandler = this.webSocketHandlerProvider.getHandler(); HttpHeaders headers = new HttpHeaders(); headers.setSecWebSocketProtocol(this.subProtocols); - this.webSocketSession = this.client.doHandshake(webSocketHandler, headers, getUri()); + this.webSocketSession = this.client.doHandshake(this.handlerProvider, headers, getUri()); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/AnnotatedEndpointConnectionManager.java b/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/AnnotatedEndpointConnectionManager.java index 75433591ca4..9bb7fe1d2ad 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/AnnotatedEndpointConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/AnnotatedEndpointConnectionManager.java @@ -22,6 +22,8 @@ import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.websocket.HandlerProvider; +import org.springframework.websocket.support.BeanCreatingHandlerProvider; +import org.springframework.websocket.support.SimpleHandlerProvider; /** @@ -32,28 +34,29 @@ import org.springframework.websocket.HandlerProvider; public class AnnotatedEndpointConnectionManager extends EndpointConnectionManagerSupport implements BeanFactoryAware { - private final HandlerProvider endpointProvider; + private final HandlerProvider handlerProvider; - public AnnotatedEndpointConnectionManager(Class endpointClass, String uriTemplate, Object... uriVariables) { + public AnnotatedEndpointConnectionManager(Object endpointBean, String uriTemplate, Object... uriVariables) { super(uriTemplate, uriVariables); - this.endpointProvider = new HandlerProvider(endpointClass); + this.handlerProvider = new SimpleHandlerProvider(endpointBean); } - public AnnotatedEndpointConnectionManager(Object endpointBean, String uriTemplate, Object... uriVariables) { + public AnnotatedEndpointConnectionManager(Class endpointClass, String uriTemplate, Object... uriVariables) { super(uriTemplate, uriVariables); - this.endpointProvider = new HandlerProvider(endpointBean); + this.handlerProvider = new BeanCreatingHandlerProvider(endpointClass); } @Override public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - this.endpointProvider.setBeanFactory(beanFactory); + if (this.handlerProvider instanceof BeanFactoryAware) { + ((BeanFactoryAware) this.handlerProvider).setBeanFactory(beanFactory); + } } - @Override protected void openConnection() throws Exception { - Object endpoint = this.endpointProvider.getHandler(); + Object endpoint = this.handlerProvider.getHandler(); Session session = getWebSocketContainer().connectToServer(endpoint, getUri()); updateSession(session); } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/EndpointConnectionManager.java b/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/EndpointConnectionManager.java index b7006ab2fc5..1307f55d9e7 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/EndpointConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/EndpointConnectionManager.java @@ -32,6 +32,8 @@ import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.util.Assert; import org.springframework.websocket.HandlerProvider; +import org.springframework.websocket.support.BeanCreatingHandlerProvider; +import org.springframework.websocket.support.SimpleHandlerProvider; /** @@ -43,19 +45,19 @@ public class EndpointConnectionManager extends EndpointConnectionManagerSupport private final ClientEndpointConfig.Builder configBuilder = ClientEndpointConfig.Builder.create(); - private final HandlerProvider endpointProvider; + private final HandlerProvider handlerProvider; public EndpointConnectionManager(Endpoint endpointBean, String uriTemplate, Object... uriVariables) { super(uriTemplate, uriVariables); Assert.notNull(endpointBean, "endpointBean is required"); - this.endpointProvider = new HandlerProvider(endpointBean); + this.handlerProvider = new SimpleHandlerProvider(endpointBean); } public EndpointConnectionManager(Class endpointClass, String uriTemplate, Object... uriVars) { super(uriTemplate, uriVars); Assert.notNull(endpointClass, "endpointClass is required"); - this.endpointProvider = new HandlerProvider(endpointClass); + this.handlerProvider = new BeanCreatingHandlerProvider(endpointClass); } public void setSubProtocols(String... subprotocols) { @@ -80,12 +82,14 @@ public class EndpointConnectionManager extends EndpointConnectionManagerSupport @Override public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - this.endpointProvider.setBeanFactory(beanFactory); + if (this.handlerProvider instanceof BeanFactoryAware) { + ((BeanFactoryAware) this.handlerProvider).setBeanFactory(beanFactory); + } } @Override protected void openConnection() throws Exception { - Endpoint endpoint = this.endpointProvider.getHandler(); + Endpoint endpoint = this.handlerProvider.getHandler(); ClientEndpointConfig endpointConfig = this.configBuilder.build(); Session session = getWebSocketContainer().connectToServer(endpoint, endpointConfig, getUri()); updateSession(session); diff --git a/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/StandardWebSocketClient.java index 522bfcdf1e0..1a15d4bd958 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/StandardWebSocketClient.java @@ -31,6 +31,7 @@ import javax.websocket.WebSocketContainer; import org.springframework.http.HttpHeaders; import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.WebSocketSession; import org.springframework.websocket.client.WebSocketClient; @@ -40,6 +41,7 @@ import org.springframework.websocket.endpoint.WebSocketHandlerEndpoint; /** + * A standard Java {@link WebSocketClient}. * * @author Rossen Stoyanchev * @since 4.0 @@ -57,21 +59,16 @@ public class StandardWebSocketClient implements WebSocketClient { this.webSocketContainer = container; } - public WebSocketSession doHandshake(WebSocketHandler handler, String uriTemplate, - Object... uriVariables) throws WebSocketConnectFailureException { + public WebSocketSession doHandshake(HandlerProvider handler, + String uriTemplate, Object... uriVariables) throws WebSocketConnectFailureException { URI uri = UriComponentsBuilder.fromUriString(uriTemplate).buildAndExpand(uriVariables).encode().toUri(); - return doHandshake(handler, uri); - } - - @Override - public WebSocketSession doHandshake(WebSocketHandler handler, URI uri) throws WebSocketConnectFailureException { return doHandshake(handler, null, uri); } @Override - public WebSocketSession doHandshake(WebSocketHandler handler, final HttpHeaders httpHeaders, URI uri) - throws WebSocketConnectFailureException { + public WebSocketSession doHandshake(HandlerProvider handler, + final HttpHeaders httpHeaders, URI uri) throws WebSocketConnectFailureException { Endpoint endpoint = new WebSocketHandlerEndpoint(handler); diff --git a/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java b/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java index 82d75c098db..bb1048c7916 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java @@ -16,9 +16,6 @@ package org.springframework.websocket.endpoint; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - import javax.websocket.CloseReason; import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; @@ -27,14 +24,15 @@ import javax.websocket.MessageHandler; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.util.Assert; -import org.springframework.websocket.CloseStatus; import org.springframework.websocket.BinaryMessage; import org.springframework.websocket.BinaryMessageHandler; +import org.springframework.websocket.CloseStatus; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.PartialMessageHandler; -import org.springframework.websocket.WebSocketHandler; -import org.springframework.websocket.WebSocketSession; import org.springframework.websocket.TextMessage; import org.springframework.websocket.TextMessageHandler; +import org.springframework.websocket.WebSocketHandler; +import org.springframework.websocket.WebSocketSession; /** @@ -47,14 +45,17 @@ public class WebSocketHandlerEndpoint extends Endpoint { private static Log logger = LogFactory.getLog(WebSocketHandlerEndpoint.class); - private final WebSocketHandler webSocketHandler; + private final HandlerProvider handlerProvider; - private final Map sessions = new ConcurrentHashMap(); + private final WebSocketHandler handler; + private WebSocketSession webSocketSession; - public WebSocketHandlerEndpoint(WebSocketHandler handler) { - Assert.notNull(handler, "webSocketHandler is required"); - this.webSocketHandler = handler; + + public WebSocketHandlerEndpoint(HandlerProvider handlerProvider) { + Assert.notNull(handlerProvider, "handlerProvider is required"); + this.handlerProvider = handlerProvider; + this.handler = handlerProvider.getHandler(); } @Override @@ -63,10 +64,9 @@ public class WebSocketHandlerEndpoint extends Endpoint { logger.debug("Client connected, WebSocket session id=" + session.getId() + ", uri=" + session.getRequestURI()); } try { - WebSocketSession webSocketSession = new StandardWebSocketSession(session); - this.sessions.put(session.getId(), webSocketSession); + this.webSocketSession = new StandardWebSocketSession(session); - if (this.webSocketHandler instanceof TextMessageHandler) { + if (this.handler instanceof TextMessageHandler) { session.addMessageHandler(new MessageHandler.Whole() { @Override public void onMessage(String message) { @@ -74,8 +74,8 @@ public class WebSocketHandlerEndpoint extends Endpoint { } }); } - else if (this.webSocketHandler instanceof BinaryMessageHandler) { - if (this.webSocketHandler instanceof PartialMessageHandler) { + else if (this.handler instanceof BinaryMessageHandler) { + if (this.handler instanceof PartialMessageHandler) { session.addMessageHandler(new MessageHandler.Partial() { @Override public void onMessage(byte[] messagePart, boolean isLast) { @@ -93,10 +93,10 @@ public class WebSocketHandlerEndpoint extends Endpoint { } } else { - logger.warn("WebSocketHandler handles neither text nor binary messages: " + this.webSocketHandler); + logger.warn("WebSocketHandler handles neither text nor binary messages: " + this.handler); } - this.webSocketHandler.afterConnectionEstablished(webSocketSession); + this.handler.afterConnectionEstablished(this.webSocketSession); } catch (Throwable ex) { // TODO @@ -108,11 +108,9 @@ public class WebSocketHandlerEndpoint extends Endpoint { if (logger.isTraceEnabled()) { logger.trace("Received message for WebSocket session id=" + session.getId() + ": " + message); } - WebSocketSession wsSession = getWebSocketSession(session); - Assert.notNull(wsSession, "WebSocketSession not found"); try { TextMessage textMessage = new TextMessage(message); - ((TextMessageHandler) webSocketHandler).handleTextMessage(textMessage, wsSession); + ((TextMessageHandler) handler).handleTextMessage(textMessage, this.webSocketSession); } catch (Throwable ex) { // TODO @@ -124,11 +122,9 @@ public class WebSocketHandlerEndpoint extends Endpoint { if (logger.isTraceEnabled()) { logger.trace("Received binary data for WebSocket session id=" + session.getId()); } - WebSocketSession wsSession = getWebSocketSession(session); - Assert.notNull(wsSession, "WebSocketSession not found"); try { BinaryMessage binaryMessage = new BinaryMessage(message, isLast); - ((BinaryMessageHandler) webSocketHandler).handleBinaryMessage(binaryMessage, wsSession); + ((BinaryMessageHandler) handler).handleBinaryMessage(binaryMessage, this.webSocketSession); } catch (Throwable ex) { // TODO @@ -142,32 +138,23 @@ public class WebSocketHandlerEndpoint extends Endpoint { logger.debug("Client disconnected, WebSocket session id=" + session.getId() + ", " + reason); } try { - WebSocketSession wsSession = this.sessions.remove(session.getId()); - if (wsSession != null) { - CloseStatus closeStatus = new CloseStatus(reason.getCloseCode().getCode(), reason.getReasonPhrase()); - this.webSocketHandler.afterConnectionClosed(closeStatus, wsSession); - } - else { - Assert.notNull(wsSession, "No WebSocket session"); - } + CloseStatus closeStatus = new CloseStatus(reason.getCloseCode().getCode(), reason.getReasonPhrase()); + this.handler.afterConnectionClosed(closeStatus, this.webSocketSession); } catch (Throwable ex) { // TODO logger.error("Error while processing session closing", ex); } + finally { + this.handlerProvider.destroy(this.handler); + } } @Override public void onError(javax.websocket.Session session, Throwable exception) { logger.error("Error for WebSocket session id=" + session.getId(), exception); try { - WebSocketSession wsSession = getWebSocketSession(session); - if (wsSession != null) { - this.webSocketHandler.handleError(exception, wsSession); - } - else { - logger.warn("WebSocketSession not found. Perhaps onError was called after onClose?"); - } + this.handler.handleError(exception, this.webSocketSession); } catch (Throwable ex) { // TODO @@ -175,8 +162,4 @@ public class WebSocketHandlerEndpoint extends Endpoint { } } - private WebSocketSession getWebSocketSession(javax.websocket.Session session) { - return this.sessions.get(session.getId()); - } - } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java index c7a3fdcb875..3ed1ccb4414 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java @@ -35,6 +35,7 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -87,7 +88,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { @Override public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler) throws Exception { + HandlerProvider handler) throws Exception { logger.debug("Starting handshake for " + request.getURI()); @@ -135,10 +136,10 @@ public class DefaultHandshakeHandler implements HandshakeHandler { response.flush(); if (logger.isTraceEnabled()) { - logger.trace("Upgrading with " + webSocketHandler); + logger.trace("Upgrading with " + handler); } - this.requestUpgradeStrategy.upgrade(request, response, selectedProtocol, webSocketHandler); + this.requestUpgradeStrategy.upgrade(request, response, selectedProtocol, handler); return true; } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java index 2af525c6063..61f0e42b332 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/HandshakeHandler.java @@ -18,6 +18,7 @@ package org.springframework.websocket.server; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -29,16 +30,8 @@ import org.springframework.websocket.WebSocketHandler; */ public interface HandshakeHandler { - /** - * - * @param request the HTTP request - * @param response the HTTP response - * @param webSocketMessageHandler the handler to process WebSocket messages with - * @return a boolean indicating whether the handshake negotiation was successful - * - * @throws Exception - */ - boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler webSocketHandler) - throws Exception; + + boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, + HandlerProvider handler) throws Exception; } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/RequestUpgradeStrategy.java index 81c238c4230..6d2827baaa5 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/RequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/RequestUpgradeStrategy.java @@ -18,6 +18,7 @@ package org.springframework.websocket.server; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; @@ -39,9 +40,9 @@ public interface RequestUpgradeStrategy { * Perform runtime specific steps to complete the upgrade. * Invoked only if the handshake is successful. * - * @param webSocketHandler the handler for WebSocket messages + * @param handler the handler for WebSocket messages */ void upgrade(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, - WebSocketHandler webSocketHandler) throws Exception; + HandlerProvider handler) throws Exception; } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java index d9e36a61fb6..c6aadd3028d 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java @@ -37,6 +37,8 @@ import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.util.Assert; import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.endpoint.WebSocketHandlerEndpoint; +import org.springframework.websocket.support.BeanCreatingHandlerProvider; +import org.springframework.websocket.support.SimpleHandlerProvider; /** @@ -54,11 +56,9 @@ import org.springframework.websocket.endpoint.WebSocketHandlerEndpoint; */ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAware { - private static Log logger = LogFactory.getLog(EndpointRegistration.class); - private final String path; - private final HandlerProvider endpointProvider; + private final HandlerProvider handlerProvider; private List> encoders = new ArrayList>(); @@ -84,16 +84,14 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw Assert.hasText(path, "path must not be empty"); Assert.notNull(endpointClass, "endpointClass is required"); this.path = path; - this.endpointProvider = new HandlerProvider(endpointClass); - this.endpointProvider.setLogger(logger); + this.handlerProvider = new BeanCreatingHandlerProvider(endpointClass); } public EndpointRegistration(String path, Endpoint endpointBean) { Assert.hasText(path, "path must not be empty"); Assert.notNull(endpointBean, "endpointBean is required"); this.path = path; - this.endpointProvider = new HandlerProvider(endpointBean); - this.endpointProvider.setLogger(logger); + this.handlerProvider = new SimpleHandlerProvider(endpointBean); } @Override @@ -102,12 +100,13 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw } @Override + @SuppressWarnings("unchecked") public Class getEndpointClass() { - return this.endpointProvider.getHandlerType(); + return (Class) this.handlerProvider.getHandlerType(); } public Endpoint getEndpoint() { - return this.endpointProvider.getHandler(); + return this.handlerProvider.getHandler(); } public void setSubprotocols(List subprotocols) { @@ -193,7 +192,9 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw @Override public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - this.endpointProvider.setBeanFactory(beanFactory); + if (this.handlerProvider instanceof BeanFactoryAware) { + ((BeanFactoryAware) this.handlerProvider).setBeanFactory(beanFactory); + } } } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointContainerFactoryBean.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointContainerFactoryBean.java deleted file mode 100644 index ee6211341be..00000000000 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointContainerFactoryBean.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.websocket.server.support; - -import javax.websocket.WebSocketContainer; - -import org.springframework.beans.factory.FactoryBean; -import org.springframework.beans.factory.InitializingBean; - - -/** -* -* @author Rossen Stoyanchev -* @since 4.0 -*/ -public abstract class AbstractEndpointContainerFactoryBean implements FactoryBean, InitializingBean { - - private WebSocketContainer container; - - - public void setAsyncSendTimeout(long timeoutInMillis) { - this.container.setAsyncSendTimeout(timeoutInMillis); - } - - public long getAsyncSendTimeout() { - return this.container.getDefaultAsyncSendTimeout(); - } - - public void setMaxSessionIdleTimeout(long timeoutInMillis) { - this.container.setDefaultMaxSessionIdleTimeout(timeoutInMillis); - } - - public long getMaxSessionIdleTimeout() { - return this.container.getDefaultMaxSessionIdleTimeout(); - } - - public void setMaxTextMessageBufferSize(int bufferSize) { - this.container.setDefaultMaxTextMessageBufferSize(bufferSize); - } - - public int getMaxTextMessageBufferSize() { - return this.container.getDefaultMaxTextMessageBufferSize(); - } - - public void setMaxBinaryMessageBufferSize(int bufferSize) { - this.container.setDefaultMaxBinaryMessageBufferSize(bufferSize); - } - - public int getMaxBinaryMessageBufferSize() { - return this.container.getDefaultMaxBinaryMessageBufferSize(); - } - - @Override - public void afterPropertiesSet() throws Exception { - this.container = getContainer(); - } - - protected abstract WebSocketContainer getContainer(); - - @Override - public WebSocketContainer getObject() throws Exception { - return this.container; - } - - @Override - public Class getObjectType() { - return WebSocketContainer.class; - } - - @Override - public boolean isSingleton() { - return true; - } - -} diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointUpgradeStrategy.java index b16f3d746d6..f7884838d32 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/support/AbstractEndpointUpgradeStrategy.java @@ -22,6 +22,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.endpoint.WebSocketHandlerEndpoint; import org.springframework.websocket.server.RequestUpgradeStrategy; @@ -41,12 +42,12 @@ public abstract class AbstractEndpointUpgradeStrategy implements RequestUpgradeS @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String protocol, WebSocketHandler webSocketHandler) throws Exception { + String protocol, HandlerProvider handler) throws Exception { - upgradeInternal(request, response, protocol, adaptWebSocketHandler(webSocketHandler)); + upgradeInternal(request, response, protocol, adaptWebSocketHandler(handler)); } - protected Endpoint adaptWebSocketHandler(WebSocketHandler handler) { + protected Endpoint adaptWebSocketHandler(HandlerProvider handler) { return new WebSocketHandlerEndpoint(handler); } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/support/WebSocketHttpRequestHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/server/support/WebSocketHttpRequestHandler.java index 9418afc6893..096ee1ca0b0 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/support/WebSocketHttpRequestHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/support/WebSocketHttpRequestHandler.java @@ -22,9 +22,6 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import org.springframework.beans.BeansException; -import org.springframework.beans.factory.BeanFactory; -import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; @@ -36,6 +33,7 @@ import org.springframework.websocket.HandlerProvider; import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.server.DefaultHandshakeHandler; import org.springframework.websocket.server.HandshakeHandler; +import org.springframework.websocket.support.SimpleHandlerProvider; /** @@ -44,7 +42,7 @@ import org.springframework.websocket.server.HandshakeHandler; * @author Rossen Stoyanchev * @since 4.0 */ -public class WebSocketHttpRequestHandler implements HttpRequestHandler, BeanFactoryAware { +public class WebSocketHttpRequestHandler implements HttpRequestHandler { private HandshakeHandler handshakeHandler; @@ -53,13 +51,13 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, BeanFact public WebSocketHttpRequestHandler(WebSocketHandler webSocketHandler) { Assert.notNull(webSocketHandler, "webSocketHandler is required"); - this.handlerProvider = new HandlerProvider(webSocketHandler); + this.handlerProvider = new SimpleHandlerProvider(webSocketHandler); this.handshakeHandler = new DefaultHandshakeHandler(); } - public WebSocketHttpRequestHandler( Class webSocketHandlerClass) { - Assert.notNull(webSocketHandlerClass, "webSocketHandlerClass is required"); - this.handlerProvider = new HandlerProvider(webSocketHandlerClass); + public WebSocketHttpRequestHandler( HandlerProvider handlerProvider) { + Assert.notNull(handlerProvider, "handlerProvider is required"); + this.handlerProvider = handlerProvider; } public void setHandshakeHandler(HandshakeHandler handshakeHandler) { @@ -67,11 +65,6 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, BeanFact this.handshakeHandler = handshakeHandler; } - @Override - public void setBeanFactory(BeanFactory beanFactory) throws BeansException { - this.handlerProvider.setBeanFactory(beanFactory); - } - @Override public void handleRequest(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { @@ -80,8 +73,7 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, BeanFact ServerHttpResponse httpResponse = new ServletServerHttpResponse(response); try { - WebSocketHandler webSocketHandler = this.handlerProvider.getHandler(); - this.handshakeHandler.doHandshake(httpRequest, httpResponse, webSocketHandler); + this.handshakeHandler.doHandshake(httpRequest, httpResponse, this.handlerProvider); } catch (Exception e) { // TODO diff --git a/spring-websocket/src/main/java/org/springframework/websocket/support/BeanCreatingHandlerProvider.java b/spring-websocket/src/main/java/org/springframework/websocket/support/BeanCreatingHandlerProvider.java new file mode 100644 index 00000000000..1c6567dc9ac --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/support/BeanCreatingHandlerProvider.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.websocket.support; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.beans.BeanUtils; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.config.AutowireCapableBeanFactory; +import org.springframework.util.Assert; +import org.springframework.websocket.HandlerProvider; + + +/** + * A {@link HandlerProvider} that uses {@link AutowireCapableBeanFactory#createBean(Class) + * creating a fresh instance every time #getHandler() is called. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class BeanCreatingHandlerProvider implements HandlerProvider, BeanFactoryAware { + + private static final Log logger = LogFactory.getLog(BeanCreatingHandlerProvider.class); + + private final Class handlerClass; + + private AutowireCapableBeanFactory beanFactory; + + + public BeanCreatingHandlerProvider(Class handlerClass) { + Assert.notNull(handlerClass, "handlerClass is required"); + this.handlerClass = handlerClass; + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + if (beanFactory instanceof AutowireCapableBeanFactory) { + this.beanFactory = (AutowireCapableBeanFactory) beanFactory; + } + } + + public boolean isSingleton() { + return false; + } + + public Class getHandlerType() { + return this.handlerClass; + } + + public T getHandler() { + if (logger.isTraceEnabled()) { + logger.trace("Creating instance for handler type " + this.handlerClass); + } + if (this.beanFactory == null) { + logger.warn("No BeanFactory available, attempting to use default constructor"); + return BeanUtils.instantiate(this.handlerClass); + } + else { + return this.beanFactory.createBean(this.handlerClass); + } + } + + @Override + public void destroy(T handler) { + if (this.beanFactory != null) { + if (logger.isTraceEnabled()) { + logger.trace("Destroying handler instance " + handler); + } + this.beanFactory.destroyBean(handler); + } + } + + @Override + public String toString() { + return "BeanCreatingHandlerProvider [handlerClass=" + handlerClass + "]"; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/websocket/support/SimpleHandlerProvider.java b/spring-websocket/src/main/java/org/springframework/websocket/support/SimpleHandlerProvider.java new file mode 100644 index 00000000000..c880a64802d --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/support/SimpleHandlerProvider.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.websocket.support; + +import org.springframework.util.ClassUtils; +import org.springframework.websocket.HandlerProvider; + + +/** + * A {@link HandlerProvider} that returns a singleton instance. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class SimpleHandlerProvider implements HandlerProvider { + + private final T handler; + + + public SimpleHandlerProvider(T handler) { + this.handler = handler; + } + + @Override + public boolean isSingleton() { + return true; + } + + @Override + public Class getHandlerType() { + return ClassUtils.getUserClass(this.handler); + } + + @Override + public T getHandler() { + return this.handler; + } + + @Override + public void destroy(T handler) { + } + + @Override + public String toString() { + return "SimpleHandlerProvider [handler=" + handler + "]"; + } + +}