diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointExporter.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointExporter.java index dbf7860de84..952102d5d1d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointExporter.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointExporter.java @@ -19,6 +19,7 @@ package org.springframework.web.socket.server.standard; import java.util.Arrays; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Set; import javax.servlet.ServletContext; import javax.websocket.DeploymentException; @@ -27,7 +28,7 @@ import javax.websocket.server.ServerEndpoint; import javax.websocket.server.ServerEndpointConfig; import org.springframework.beans.factory.InitializingBean; -import org.springframework.beans.factory.config.BeanPostProcessor; +import org.springframework.beans.factory.SmartInitializingSingleton; import org.springframework.context.ApplicationContext; import org.springframework.util.Assert; import org.springframework.web.context.support.WebApplicationObjectSupport; @@ -41,7 +42,7 @@ import org.springframework.web.context.support.WebApplicationObjectSupport; * *

When this class is used, by declaring it in Spring configuration, it should be * possible to turn off a Servlet container's scan for WebSocket endpoints. This can be - * done with the help of the {@code } element in web.xml. + * done with the help of the {@code } element in {@code web.xml}. * * @author Rossen Stoyanchev * @author Juergen Hoeller @@ -50,19 +51,27 @@ import org.springframework.web.context.support.WebApplicationObjectSupport; * @see SpringConfigurator * @see ServletServerContainerFactoryBean */ -public class ServerEndpointExporter extends WebApplicationObjectSupport implements BeanPostProcessor, InitializingBean { - - private ServerContainer serverContainer; +public class ServerEndpointExporter extends WebApplicationObjectSupport + implements InitializingBean, SmartInitializingSingleton { private List> annotatedEndpointClasses; - private Set> annotatedEndpointBeanTypes; + private ServerContainer serverContainer; + + /** + * Explicitly list annotated endpoint types that should be registered on startup. This + * can be done if you wish to turn off a Servlet container's scan for endpoints, which + * goes through all 3rd party jars in the, and rely on Spring configuration instead. + * @param annotatedEndpointClasses {@link ServerEndpoint}-annotated types + */ + public void setAnnotatedEndpointClasses(Class... annotatedEndpointClasses) { + this.annotatedEndpointClasses = Arrays.asList(annotatedEndpointClasses); + } /** * Set the JSR-356 {@link ServerContainer} to use for endpoint registration. * If not set, the container is going to be retrieved via the {@code ServletContext}. - * @since 4.1 */ public void setServerContainer(ServerContainer serverContainer) { this.serverContainer = serverContainer; @@ -75,33 +84,6 @@ public class ServerEndpointExporter extends WebApplicationObjectSupport implemen return this.serverContainer; } - /** - * Explicitly list annotated endpoint types that should be registered on startup. This - * can be done if you wish to turn off a Servlet container's scan for endpoints, which - * goes through all 3rd party jars in the, and rely on Spring configuration instead. - * @param annotatedEndpointClasses {@link ServerEndpoint}-annotated types - */ - public void setAnnotatedEndpointClasses(Class... annotatedEndpointClasses) { - this.annotatedEndpointClasses = Arrays.asList(annotatedEndpointClasses); - } - - @Override - protected void initApplicationContext(ApplicationContext context) { - // Initializes ServletContext given a WebApplicationContext - super.initApplicationContext(context); - - // Retrieve beans which are annotated with @ServerEndpoint - this.annotatedEndpointBeanTypes = new LinkedHashSet>(); - String[] beanNames = context.getBeanNamesForAnnotation(ServerEndpoint.class); - for (String beanName : beanNames) { - Class beanType = context.getType(beanName); - if (logger.isInfoEnabled()) { - logger.info("Detected @ServerEndpoint bean '" + beanName + "', registering it as an endpoint by type"); - } - this.annotatedEndpointBeanTypes.add(beanType); - } - } - @Override protected void initServletContext(ServletContext servletContext) { if (this.serverContainer == null) { @@ -110,64 +92,76 @@ public class ServerEndpointExporter extends WebApplicationObjectSupport implemen } } + @Override + protected boolean isContextRequired() { + return false; + } @Override public void afterPropertiesSet() { Assert.state(getServerContainer() != null, "javax.websocket.server.ServerContainer not available"); + } + + @Override + public void afterSingletonsInstantiated() { registerEndpoints(); } + /** - * Actually register the endpoints. Called by {@link #afterPropertiesSet()}. - * @since 4.1 + * Actually register the endpoints. Called by {@link #afterSingletonsInstantiated()}. */ protected void registerEndpoints() { Set> endpointClasses = new LinkedHashSet>(); if (this.annotatedEndpointClasses != null) { endpointClasses.addAll(this.annotatedEndpointClasses); } - if (this.annotatedEndpointBeanTypes != null) { - endpointClasses.addAll(this.annotatedEndpointBeanTypes); + + ApplicationContext context = getApplicationContext(); + if (context != null) { + String[] endpointNames = context.getBeanNamesForAnnotation(ServerEndpoint.class); + for (String beanName : endpointNames) { + Class beanType = context.getType(beanName); + endpointClasses.add(beanType); + } } + for (Class endpointClass : endpointClasses) { registerEndpoint(endpointClass); } + + if (context != null) { + Map endpointConfigMap = context.getBeansOfType(ServerEndpointConfig.class); + for (Map.Entry configEntry : endpointConfigMap.entrySet()) { + String beanName = configEntry.getKey(); + ServerEndpointConfig endpointConfig = configEntry.getValue(); + registerEndpoint(endpointConfig); + } + } } private void registerEndpoint(Class endpointClass) { try { if (logger.isInfoEnabled()) { - logger.info("Registering @ServerEndpoint type: " + endpointClass); + logger.info("Registering @ServerEndpoint class: " + endpointClass); } getServerContainer().addEndpoint(endpointClass); } catch (DeploymentException ex) { - throw new IllegalStateException("Failed to register @ServerEndpoint type " + endpointClass, ex); + throw new IllegalStateException("Failed to register @ServerEndpoint class: " + endpointClass, ex); } } - - @Override - public Object postProcessBeforeInitialization(Object bean, String beanName) { - return bean; - } - - @Override - public Object postProcessAfterInitialization(Object bean, String beanName) { - if (bean instanceof ServerEndpointConfig) { - ServerEndpointConfig endpointConfig = (ServerEndpointConfig) bean; - try { - if (logger.isInfoEnabled()) { - logger.info("Registering bean '" + beanName + - "' as javax.websocket.Endpoint under path " + endpointConfig.getPath()); - } - getServerContainer().addEndpoint(endpointConfig); - } - catch (DeploymentException ex) { - throw new IllegalStateException("Failed to deploy Endpoint bean with name '" + bean + "'", ex); + private void registerEndpoint(ServerEndpointConfig endpointConfig) { + try { + if (logger.isInfoEnabled()) { + logger.info("Registering ServerEndpointConfig: " + endpointConfig); } + getServerContainer().addEndpoint(endpointConfig); + } + catch (DeploymentException ex) { + throw new IllegalStateException("Failed to register ServerEndpointConfig: " + endpointConfig, ex); } - return bean; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointRegistration.java index b70f9377bfc..f20993d9d8d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointRegistration.java @@ -73,7 +73,7 @@ public class ServerEndpointRegistration extends ServerEndpointConfig.Configurato /** * Create a new {@link ServerEndpointRegistration} instance from an - * {@code javax.webscoket.Endpoint} class. + * {@code javax.websocket.Endpoint} class. * @param path the endpoint path * @param endpointClass the endpoint class */ @@ -202,4 +202,9 @@ public class ServerEndpointRegistration extends ServerEndpointConfig.Configurato return super.getNegotiatedExtensions(installed, requested); } + + @Override + public String toString() { + return "ServerEndpointRegistration for path '" + getPath() + "': " + getEndpointClass(); + } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/standard/ServerEndpointExporterTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/standard/ServerEndpointExporterTests.java index 2eea45a19de..0a32d8ea03d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/standard/ServerEndpointExporterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/standard/ServerEndpointExporterTests.java @@ -67,30 +67,33 @@ public class ServerEndpointExporterTests { @Test - public void addAnnotatedEndpointBeans() throws Exception { + public void addAnnotatedEndpointClasses() throws Exception { this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class); this.exporter.setApplicationContext(this.webAppContext); this.exporter.afterPropertiesSet(); + this.exporter.afterSingletonsInstantiated(); verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class); verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpointBean.class); } @Test - public void addAnnotatedEndpointBeansWithServletContextOnly() throws Exception { + public void addAnnotatedEndpointClassesWithServletContextOnly() throws Exception { this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class, AnnotatedDummyEndpointBean.class); this.exporter.setServletContext(this.servletContext); this.exporter.afterPropertiesSet(); + this.exporter.afterSingletonsInstantiated(); verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class); verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpointBean.class); } @Test - public void addAnnotatedEndpointBeansWithServerContainerOnly() throws Exception { + public void addAnnotatedEndpointClassesWithExplicitServerContainerOnly() throws Exception { this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class, AnnotatedDummyEndpointBean.class); this.exporter.setServerContainer(this.serverContainer); this.exporter.afterPropertiesSet(); + this.exporter.afterSingletonsInstantiated(); verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class); verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpointBean.class); @@ -98,31 +101,40 @@ public class ServerEndpointExporterTests { @Test public void addServerEndpointConfigBean() throws Exception { + ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint()); + this.webAppContext.getBeanFactory().registerSingleton("dummyEndpoint", endpointRegistration); + this.exporter.setApplicationContext(this.webAppContext); this.exporter.afterPropertiesSet(); + this.exporter.afterSingletonsInstantiated(); - ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint()); - this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint"); verify(this.serverContainer).addEndpoint(endpointRegistration); } @Test - public void addServerEndpointConfigBeanWithServletContextOnly() throws Exception { + public void addServerEndpointConfigBeanWithExplicitServletContext() throws Exception { + ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint()); + this.webAppContext.getBeanFactory().registerSingleton("dummyEndpoint", endpointRegistration); + this.exporter.setServletContext(this.servletContext); + this.exporter.setApplicationContext(this.webAppContext); this.exporter.afterPropertiesSet(); + this.exporter.afterSingletonsInstantiated(); - ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint()); - this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint"); verify(this.serverContainer).addEndpoint(endpointRegistration); } @Test - public void addServerEndpointConfigBeanWithServerContainerOnly() throws Exception { + public void addServerEndpointConfigBeanWithExplicitServerContainer() throws Exception { + ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint()); + this.webAppContext.getBeanFactory().registerSingleton("dummyEndpoint", endpointRegistration); + this.servletContext.removeAttribute("javax.websocket.server.ServerContainer"); + this.exporter.setServerContainer(this.serverContainer); + this.exporter.setApplicationContext(this.webAppContext); this.exporter.afterPropertiesSet(); + this.exporter.afterSingletonsInstantiated(); - ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint()); - this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint"); verify(this.serverContainer).addEndpoint(endpointRegistration); }