From 379e5abd833d5126450d218e427460ad652cd5fc Mon Sep 17 00:00:00 2001 From: Juergen Hoeller Date: Thu, 21 Aug 2014 18:40:35 +0200 Subject: [PATCH] ServerEndpointExporter can initialize itself based on a late-provided ServletContext as well (for Boot) Also allows for direct setting of a ServerContainer and for custom triggering of endpoint registration. Issue: SPR-12109 (cherry picked from commit 11805b6) --- .../standard/ServerEndpointExporter.java | 142 ++++++++++-------- .../standard/ServerEndpointExporterTests.java | 64 +++++++- 2 files changed, 134 insertions(+), 72 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointExporter.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointExporter.java index 48b394b9c26..dbf7860de84 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointExporter.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointExporter.java @@ -16,27 +16,21 @@ package org.springframework.web.socket.server.standard; -import java.lang.reflect.Method; -import java.util.ArrayList; import java.util.Arrays; +import java.util.LinkedHashSet; import java.util.List; -import java.util.Map; +import java.util.Set; +import javax.servlet.ServletContext; import javax.websocket.DeploymentException; import javax.websocket.server.ServerContainer; import javax.websocket.server.ServerEndpoint; import javax.websocket.server.ServerEndpointConfig; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -import org.springframework.beans.BeansException; import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.context.ApplicationContext; -import org.springframework.context.ApplicationContextAware; import org.springframework.util.Assert; -import org.springframework.util.ClassUtils; -import org.springframework.util.ReflectionUtils; +import org.springframework.web.context.support.WebApplicationObjectSupport; /** * Detects beans of type {@link javax.websocket.server.ServerEndpointConfig} and registers @@ -50,24 +44,36 @@ import org.springframework.util.ReflectionUtils; * done with the help of the {@code } element in web.xml. * * @author Rossen Stoyanchev + * @author Juergen Hoeller * @since 4.0 * @see ServerEndpointRegistration * @see SpringConfigurator * @see ServletServerContainerFactoryBean */ -public class ServerEndpointExporter implements InitializingBean, BeanPostProcessor, ApplicationContextAware { - - private static final Log logger = LogFactory.getLog(ServerEndpointExporter.class); +public class ServerEndpointExporter extends WebApplicationObjectSupport implements BeanPostProcessor, InitializingBean { + private ServerContainer serverContainer; - private final List> annotatedEndpointClasses = new ArrayList>(); + private List> annotatedEndpointClasses; - private final List> annotatedEndpointBeanTypes = new ArrayList>(); + private Set> annotatedEndpointBeanTypes; - private ApplicationContext applicationContext; - private ServerContainer serverContainer; + /** + * Set the JSR-356 {@link ServerContainer} to use for endpoint registration. + * If not set, the container is going to be retrieved via the {@code ServletContext}. + * @since 4.1 + */ + public void setServerContainer(ServerContainer serverContainer) { + this.serverContainer = serverContainer; + } + /** + * Return the JSR-356 {@link ServerContainer} to use for endpoint registration. + */ + protected ServerContainer getServerContainer() { + return this.serverContainer; + } /** * Explicitly list annotated endpoint types that should be registered on startup. This @@ -76,17 +82,19 @@ public class ServerEndpointExporter implements InitializingBean, BeanPostProcess * @param annotatedEndpointClasses {@link ServerEndpoint}-annotated types */ public void setAnnotatedEndpointClasses(Class... annotatedEndpointClasses) { - this.annotatedEndpointClasses.clear(); - this.annotatedEndpointClasses.addAll(Arrays.asList(annotatedEndpointClasses)); + this.annotatedEndpointClasses = Arrays.asList(annotatedEndpointClasses); } @Override - public void setApplicationContext(ApplicationContext applicationContext) { - this.applicationContext = applicationContext; - this.serverContainer = getServerContainer(); - Map beans = applicationContext.getBeansWithAnnotation(ServerEndpoint.class); - for (String beanName : beans.keySet()) { - Class beanType = applicationContext.getType(beanName); + protected void initApplicationContext(ApplicationContext context) { + // Initializes ServletContext given a WebApplicationContext + super.initApplicationContext(context); + + // Retrieve beans which are annotated with @ServerEndpoint + this.annotatedEndpointBeanTypes = new LinkedHashSet>(); + String[] beanNames = context.getBeanNamesForAnnotation(ServerEndpoint.class); + for (String beanName : beanNames) { + Class beanType = context.getType(beanName); if (logger.isInfoEnabled()) { logger.info("Detected @ServerEndpoint bean '" + beanName + "', registering it as an endpoint by type"); } @@ -94,66 +102,72 @@ public class ServerEndpointExporter implements InitializingBean, BeanPostProcess } } - protected ServerContainer getServerContainer() { - Class servletContextClass; - try { - servletContextClass = ClassUtils.forName("javax.servlet.ServletContext", getClass().getClassLoader()); + @Override + protected void initServletContext(ServletContext servletContext) { + if (this.serverContainer == null) { + this.serverContainer = + (ServerContainer) servletContext.getAttribute("javax.websocket.server.ServerContainer"); + } + } + + + @Override + public void afterPropertiesSet() { + Assert.state(getServerContainer() != null, "javax.websocket.server.ServerContainer not available"); + registerEndpoints(); + } + + /** + * Actually register the endpoints. Called by {@link #afterPropertiesSet()}. + * @since 4.1 + */ + protected void registerEndpoints() { + Set> endpointClasses = new LinkedHashSet>(); + if (this.annotatedEndpointClasses != null) { + endpointClasses.addAll(this.annotatedEndpointClasses); + } + if (this.annotatedEndpointBeanTypes != null) { + endpointClasses.addAll(this.annotatedEndpointBeanTypes); } - catch (Throwable ex) { - return null; + for (Class endpointClass : endpointClasses) { + registerEndpoint(endpointClass); } + } + private void registerEndpoint(Class endpointClass) { try { - Method getter = ReflectionUtils.findMethod(this.applicationContext.getClass(), "getServletContext"); - Object servletContext = getter.invoke(this.applicationContext); - Method attrMethod = ReflectionUtils.findMethod(servletContextClass, "getAttribute", String.class); - return (ServerContainer) attrMethod.invoke(servletContext, "javax.websocket.server.ServerContainer"); + if (logger.isInfoEnabled()) { + logger.info("Registering @ServerEndpoint type: " + endpointClass); + } + getServerContainer().addEndpoint(endpointClass); } - catch (Exception ex) { - throw new IllegalStateException( - "Failed to get javax.websocket.server.ServerContainer via ServletContext attribute", ex); + catch (DeploymentException ex) { + throw new IllegalStateException("Failed to register @ServerEndpoint type " + endpointClass, ex); } } - @Override - public void afterPropertiesSet() throws Exception { - Assert.state(this.serverContainer != null, "javax.websocket.server.ServerContainer not available"); - - List> allClasses = new ArrayList>(this.annotatedEndpointClasses); - allClasses.addAll(this.annotatedEndpointBeanTypes); - for (Class clazz : allClasses) { - try { - logger.info("Registering @ServerEndpoint type " + clazz); - this.serverContainer.addEndpoint(clazz); - } - catch (DeploymentException e) { - throw new IllegalStateException("Failed to register @ServerEndpoint type " + clazz, e); - } - } + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) { + return bean; } @Override - public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { + public Object postProcessAfterInitialization(Object bean, String beanName) { if (bean instanceof ServerEndpointConfig) { - ServerEndpointConfig sec = (ServerEndpointConfig) bean; + ServerEndpointConfig endpointConfig = (ServerEndpointConfig) bean; try { if (logger.isInfoEnabled()) { logger.info("Registering bean '" + beanName + - "' as javax.websocket.Endpoint under path " + sec.getPath()); + "' as javax.websocket.Endpoint under path " + endpointConfig.getPath()); } - getServerContainer().addEndpoint(sec); + getServerContainer().addEndpoint(endpointConfig); } - catch (DeploymentException e) { - throw new IllegalStateException("Failed to deploy Endpoint bean " + bean, e); + catch (DeploymentException ex) { + throw new IllegalStateException("Failed to deploy Endpoint bean with name '" + bean + "'", ex); } } return bean; } - @Override - public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { - return bean; - } - } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/standard/ServerEndpointExporterTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/standard/ServerEndpointExporterTests.java index 9f25850a8cf..2eea45a19de 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/standard/ServerEndpointExporterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/standard/ServerEndpointExporterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.web.socket.server.standard; +import javax.servlet.ServletContext; import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; import javax.websocket.Session; @@ -24,6 +25,7 @@ import javax.websocket.server.ServerEndpoint; import org.junit.Before; import org.junit.Test; + import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.mock.web.test.MockServletContext; @@ -35,37 +37,59 @@ import static org.mockito.Mockito.*; * Test fixture for {@link ServerEndpointExporter}. * * @author Rossen Stoyanchev + * @author Juergen Hoeller */ public class ServerEndpointExporterTests { private ServerContainer serverContainer; - private ServerEndpointExporter exporter; + private ServletContext servletContext; private AnnotationConfigWebApplicationContext webAppContext; + private ServerEndpointExporter exporter; + @Before public void setup() { this.serverContainer = mock(ServerContainer.class); - MockServletContext servletContext = new MockServletContext(); - servletContext.setAttribute("javax.websocket.server.ServerContainer", this.serverContainer); + this.servletContext = new MockServletContext(); + this.servletContext.setAttribute("javax.websocket.server.ServerContainer", this.serverContainer); this.webAppContext = new AnnotationConfigWebApplicationContext(); this.webAppContext.register(Config.class); - this.webAppContext.setServletContext(servletContext); + this.webAppContext.setServletContext(this.servletContext); this.webAppContext.refresh(); this.exporter = new ServerEndpointExporter(); - this.exporter.setApplicationContext(this.webAppContext); } @Test - public void addAnnotatedEndpointBean() throws Exception { - + public void addAnnotatedEndpointBeans() throws Exception { this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class); + this.exporter.setApplicationContext(this.webAppContext); + this.exporter.afterPropertiesSet(); + + verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class); + verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpointBean.class); + } + + @Test + public void addAnnotatedEndpointBeansWithServletContextOnly() throws Exception { + this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class, AnnotatedDummyEndpointBean.class); + this.exporter.setServletContext(this.servletContext); + this.exporter.afterPropertiesSet(); + + verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class); + verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpointBean.class); + } + + @Test + public void addAnnotatedEndpointBeansWithServerContainerOnly() throws Exception { + this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class, AnnotatedDummyEndpointBean.class); + this.exporter.setServerContainer(this.serverContainer); this.exporter.afterPropertiesSet(); verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class); @@ -74,10 +98,31 @@ public class ServerEndpointExporterTests { @Test public void addServerEndpointConfigBean() throws Exception { + this.exporter.setApplicationContext(this.webAppContext); + this.exporter.afterPropertiesSet(); + + ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint()); + this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint"); + verify(this.serverContainer).addEndpoint(endpointRegistration); + } + + @Test + public void addServerEndpointConfigBeanWithServletContextOnly() throws Exception { + this.exporter.setServletContext(this.servletContext); + this.exporter.afterPropertiesSet(); ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint()); this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint"); + verify(this.serverContainer).addEndpoint(endpointRegistration); + } + + @Test + public void addServerEndpointConfigBeanWithServerContainerOnly() throws Exception { + this.exporter.setServerContainer(this.serverContainer); + this.exporter.afterPropertiesSet(); + ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint()); + this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint"); verify(this.serverContainer).addEndpoint(endpointRegistration); } @@ -89,14 +134,17 @@ public class ServerEndpointExporterTests { } } + @ServerEndpoint("/path") private static class AnnotatedDummyEndpoint { } + @ServerEndpoint("/path") private static class AnnotatedDummyEndpointBean { } + @Configuration static class Config {