Browse Source

ServerEndpointExporter can initialize itself based on a late-provided ServletContext as well (for Boot)

Also allows for direct setting of a ServerContainer and for custom triggering of endpoint registration.

Issue: SPR-12109
(cherry picked from commit 11805b6)
pull/642/head
Juergen Hoeller 11 years ago
parent
commit
379e5abd83
  1. 142
      spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointExporter.java
  2. 64
      spring-websocket/src/test/java/org/springframework/web/socket/server/standard/ServerEndpointExporterTests.java

142
spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointExporter.java

@ -16,27 +16,21 @@ @@ -16,27 +16,21 @@
package org.springframework.web.socket.server.standard;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.servlet.ServletContext;
import javax.websocket.DeploymentException;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.context.support.WebApplicationObjectSupport;
/**
* Detects beans of type {@link javax.websocket.server.ServerEndpointConfig} and registers
@ -50,24 +44,36 @@ import org.springframework.util.ReflectionUtils; @@ -50,24 +44,36 @@ import org.springframework.util.ReflectionUtils;
* done with the help of the {@code <absolute-ordering>} element in web.xml.
*
* @author Rossen Stoyanchev
* @author Juergen Hoeller
* @since 4.0
* @see ServerEndpointRegistration
* @see SpringConfigurator
* @see ServletServerContainerFactoryBean
*/
public class ServerEndpointExporter implements InitializingBean, BeanPostProcessor, ApplicationContextAware {
private static final Log logger = LogFactory.getLog(ServerEndpointExporter.class);
public class ServerEndpointExporter extends WebApplicationObjectSupport implements BeanPostProcessor, InitializingBean {
private ServerContainer serverContainer;
private final List<Class<?>> annotatedEndpointClasses = new ArrayList<Class<?>>();
private List<Class<?>> annotatedEndpointClasses;
private final List<Class<?>> annotatedEndpointBeanTypes = new ArrayList<Class<?>>();
private Set<Class<?>> annotatedEndpointBeanTypes;
private ApplicationContext applicationContext;
private ServerContainer serverContainer;
/**
* Set the JSR-356 {@link ServerContainer} to use for endpoint registration.
* If not set, the container is going to be retrieved via the {@code ServletContext}.
* @since 4.1
*/
public void setServerContainer(ServerContainer serverContainer) {
this.serverContainer = serverContainer;
}
/**
* Return the JSR-356 {@link ServerContainer} to use for endpoint registration.
*/
protected ServerContainer getServerContainer() {
return this.serverContainer;
}
/**
* Explicitly list annotated endpoint types that should be registered on startup. This
@ -76,17 +82,19 @@ public class ServerEndpointExporter implements InitializingBean, BeanPostProcess @@ -76,17 +82,19 @@ public class ServerEndpointExporter implements InitializingBean, BeanPostProcess
* @param annotatedEndpointClasses {@link ServerEndpoint}-annotated types
*/
public void setAnnotatedEndpointClasses(Class<?>... annotatedEndpointClasses) {
this.annotatedEndpointClasses.clear();
this.annotatedEndpointClasses.addAll(Arrays.asList(annotatedEndpointClasses));
this.annotatedEndpointClasses = Arrays.asList(annotatedEndpointClasses);
}
@Override
public void setApplicationContext(ApplicationContext applicationContext) {
this.applicationContext = applicationContext;
this.serverContainer = getServerContainer();
Map<String, Object> beans = applicationContext.getBeansWithAnnotation(ServerEndpoint.class);
for (String beanName : beans.keySet()) {
Class<?> beanType = applicationContext.getType(beanName);
protected void initApplicationContext(ApplicationContext context) {
// Initializes ServletContext given a WebApplicationContext
super.initApplicationContext(context);
// Retrieve beans which are annotated with @ServerEndpoint
this.annotatedEndpointBeanTypes = new LinkedHashSet<Class<?>>();
String[] beanNames = context.getBeanNamesForAnnotation(ServerEndpoint.class);
for (String beanName : beanNames) {
Class<?> beanType = context.getType(beanName);
if (logger.isInfoEnabled()) {
logger.info("Detected @ServerEndpoint bean '" + beanName + "', registering it as an endpoint by type");
}
@ -94,66 +102,72 @@ public class ServerEndpointExporter implements InitializingBean, BeanPostProcess @@ -94,66 +102,72 @@ public class ServerEndpointExporter implements InitializingBean, BeanPostProcess
}
}
protected ServerContainer getServerContainer() {
Class<?> servletContextClass;
try {
servletContextClass = ClassUtils.forName("javax.servlet.ServletContext", getClass().getClassLoader());
@Override
protected void initServletContext(ServletContext servletContext) {
if (this.serverContainer == null) {
this.serverContainer =
(ServerContainer) servletContext.getAttribute("javax.websocket.server.ServerContainer");
}
}
@Override
public void afterPropertiesSet() {
Assert.state(getServerContainer() != null, "javax.websocket.server.ServerContainer not available");
registerEndpoints();
}
/**
* Actually register the endpoints. Called by {@link #afterPropertiesSet()}.
* @since 4.1
*/
protected void registerEndpoints() {
Set<Class<?>> endpointClasses = new LinkedHashSet<Class<?>>();
if (this.annotatedEndpointClasses != null) {
endpointClasses.addAll(this.annotatedEndpointClasses);
}
if (this.annotatedEndpointBeanTypes != null) {
endpointClasses.addAll(this.annotatedEndpointBeanTypes);
}
catch (Throwable ex) {
return null;
for (Class<?> endpointClass : endpointClasses) {
registerEndpoint(endpointClass);
}
}
private void registerEndpoint(Class<?> endpointClass) {
try {
Method getter = ReflectionUtils.findMethod(this.applicationContext.getClass(), "getServletContext");
Object servletContext = getter.invoke(this.applicationContext);
Method attrMethod = ReflectionUtils.findMethod(servletContextClass, "getAttribute", String.class);
return (ServerContainer) attrMethod.invoke(servletContext, "javax.websocket.server.ServerContainer");
if (logger.isInfoEnabled()) {
logger.info("Registering @ServerEndpoint type: " + endpointClass);
}
getServerContainer().addEndpoint(endpointClass);
}
catch (Exception ex) {
throw new IllegalStateException(
"Failed to get javax.websocket.server.ServerContainer via ServletContext attribute", ex);
catch (DeploymentException ex) {
throw new IllegalStateException("Failed to register @ServerEndpoint type " + endpointClass, ex);
}
}
@Override
public void afterPropertiesSet() throws Exception {
Assert.state(this.serverContainer != null, "javax.websocket.server.ServerContainer not available");
List<Class<?>> allClasses = new ArrayList<Class<?>>(this.annotatedEndpointClasses);
allClasses.addAll(this.annotatedEndpointBeanTypes);
for (Class<?> clazz : allClasses) {
try {
logger.info("Registering @ServerEndpoint type " + clazz);
this.serverContainer.addEndpoint(clazz);
}
catch (DeploymentException e) {
throw new IllegalStateException("Failed to register @ServerEndpoint type " + clazz, e);
}
}
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) {
return bean;
}
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
public Object postProcessAfterInitialization(Object bean, String beanName) {
if (bean instanceof ServerEndpointConfig) {
ServerEndpointConfig sec = (ServerEndpointConfig) bean;
ServerEndpointConfig endpointConfig = (ServerEndpointConfig) bean;
try {
if (logger.isInfoEnabled()) {
logger.info("Registering bean '" + beanName +
"' as javax.websocket.Endpoint under path " + sec.getPath());
"' as javax.websocket.Endpoint under path " + endpointConfig.getPath());
}
getServerContainer().addEndpoint(sec);
getServerContainer().addEndpoint(endpointConfig);
}
catch (DeploymentException e) {
throw new IllegalStateException("Failed to deploy Endpoint bean " + bean, e);
catch (DeploymentException ex) {
throw new IllegalStateException("Failed to deploy Endpoint bean with name '" + bean + "'", ex);
}
}
return bean;
}
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
return bean;
}
}

64
spring-websocket/src/test/java/org/springframework/web/socket/server/standard/ServerEndpointExporterTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
package org.springframework.web.socket.server.standard;
import javax.servlet.ServletContext;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Session;
@ -24,6 +25,7 @@ import javax.websocket.server.ServerEndpoint; @@ -24,6 +25,7 @@ import javax.websocket.server.ServerEndpoint;
import org.junit.Before;
import org.junit.Test;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.mock.web.test.MockServletContext;
@ -35,37 +37,59 @@ import static org.mockito.Mockito.*; @@ -35,37 +37,59 @@ import static org.mockito.Mockito.*;
* Test fixture for {@link ServerEndpointExporter}.
*
* @author Rossen Stoyanchev
* @author Juergen Hoeller
*/
public class ServerEndpointExporterTests {
private ServerContainer serverContainer;
private ServerEndpointExporter exporter;
private ServletContext servletContext;
private AnnotationConfigWebApplicationContext webAppContext;
private ServerEndpointExporter exporter;
@Before
public void setup() {
this.serverContainer = mock(ServerContainer.class);
MockServletContext servletContext = new MockServletContext();
servletContext.setAttribute("javax.websocket.server.ServerContainer", this.serverContainer);
this.servletContext = new MockServletContext();
this.servletContext.setAttribute("javax.websocket.server.ServerContainer", this.serverContainer);
this.webAppContext = new AnnotationConfigWebApplicationContext();
this.webAppContext.register(Config.class);
this.webAppContext.setServletContext(servletContext);
this.webAppContext.setServletContext(this.servletContext);
this.webAppContext.refresh();
this.exporter = new ServerEndpointExporter();
this.exporter.setApplicationContext(this.webAppContext);
}
@Test
public void addAnnotatedEndpointBean() throws Exception {
public void addAnnotatedEndpointBeans() throws Exception {
this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class);
this.exporter.setApplicationContext(this.webAppContext);
this.exporter.afterPropertiesSet();
verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class);
verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpointBean.class);
}
@Test
public void addAnnotatedEndpointBeansWithServletContextOnly() throws Exception {
this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class, AnnotatedDummyEndpointBean.class);
this.exporter.setServletContext(this.servletContext);
this.exporter.afterPropertiesSet();
verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class);
verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpointBean.class);
}
@Test
public void addAnnotatedEndpointBeansWithServerContainerOnly() throws Exception {
this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class, AnnotatedDummyEndpointBean.class);
this.exporter.setServerContainer(this.serverContainer);
this.exporter.afterPropertiesSet();
verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class);
@ -74,10 +98,31 @@ public class ServerEndpointExporterTests { @@ -74,10 +98,31 @@ public class ServerEndpointExporterTests {
@Test
public void addServerEndpointConfigBean() throws Exception {
this.exporter.setApplicationContext(this.webAppContext);
this.exporter.afterPropertiesSet();
ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint());
this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint");
verify(this.serverContainer).addEndpoint(endpointRegistration);
}
@Test
public void addServerEndpointConfigBeanWithServletContextOnly() throws Exception {
this.exporter.setServletContext(this.servletContext);
this.exporter.afterPropertiesSet();
ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint());
this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint");
verify(this.serverContainer).addEndpoint(endpointRegistration);
}
@Test
public void addServerEndpointConfigBeanWithServerContainerOnly() throws Exception {
this.exporter.setServerContainer(this.serverContainer);
this.exporter.afterPropertiesSet();
ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint());
this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint");
verify(this.serverContainer).addEndpoint(endpointRegistration);
}
@ -89,14 +134,17 @@ public class ServerEndpointExporterTests { @@ -89,14 +134,17 @@ public class ServerEndpointExporterTests {
}
}
@ServerEndpoint("/path")
private static class AnnotatedDummyEndpoint {
}
@ServerEndpoint("/path")
private static class AnnotatedDummyEndpointBean {
}
@Configuration
static class Config {

Loading…
Cancel
Save