diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointExporter.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointExporter.java index 5a601bfe88b..9c287e7071f 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointExporter.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointExporter.java @@ -15,33 +15,41 @@ */ package org.springframework.websocket.server.endpoint; +import java.util.Map; + import javax.websocket.DeploymentException; import javax.websocket.server.ServerContainer; import javax.websocket.server.ServerContainerProvider; +import javax.websocket.server.ServerEndpoint; import javax.websocket.server.ServerEndpointConfig; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; /** * BeanPostProcessor that detects beans of type - * {@link javax.websocket.server.ServerEndpointConfig} and registers them with a standard - * Java WebSocket runtime and also configures the underlying - * {@link javax.websocket.server.ServerContainer}. + * {@link javax.websocket.server.ServerEndpointConfig} and registers the corresponding + * {@link javax.websocket.Endpoint} with a standard Java WebSocket runtime. * *

If the runtime is a Servlet container, use {@link ServletEndpointExporter}. * * @author Rossen Stoyanchev * @since 4.0 */ -public class EndpointExporter implements BeanPostProcessor, InitializingBean { +public class EndpointExporter implements InitializingBean, BeanPostProcessor, BeanFactoryAware { private static Log logger = LogFactory.getLog(EndpointExporter.class); + private Class[] annotatedEndpointClasses; + private Long maxSessionIdleTimeout; private Integer maxTextMessageBufferSize; @@ -49,6 +57,14 @@ public class EndpointExporter implements BeanPostProcessor, InitializingBean { private Integer maxBinaryMessageBufferSize; + /** + * TODO + * @param annotatedEndpointClasses + */ + public void setAnnotatedEndpointClasses(Class... annotatedEndpointClasses) { + this.annotatedEndpointClasses = annotatedEndpointClasses; + } + /** * If this property set it is in turn used to configure * {@link ServerContainer#setDefaultMaxSessionIdleTimeout(long)}. @@ -85,8 +101,29 @@ public class EndpointExporter implements BeanPostProcessor, InitializingBean { return this.maxBinaryMessageBufferSize; } + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + if (beanFactory instanceof ListableBeanFactory) { + ListableBeanFactory lbf = (ListableBeanFactory) beanFactory; + Map annotatedEndpoints = lbf.getBeansWithAnnotation(ServerEndpoint.class); + for (String beanName : annotatedEndpoints.keySet()) { + Class beanType = lbf.getType(beanName); + try { + if (logger.isInfoEnabled()) { + logger.info("Detected @ServerEndpoint bean '" + beanName + "', registering it as an endpoint by type"); + } + ServerContainerProvider.getServerContainer().addEndpoint(beanType); + } + catch (DeploymentException e) { + throw new IllegalStateException("Failed to register @ServerEndpoint bean type " + beanName, e); + } + } + } + } + @Override public void afterPropertiesSet() throws Exception { + ServerContainer serverContainer = ServerContainerProvider.getServerContainer(); Assert.notNull(serverContainer, "javax.websocket.server.ServerContainer not available"); @@ -99,19 +136,33 @@ public class EndpointExporter implements BeanPostProcessor, InitializingBean { if (this.maxBinaryMessageBufferSize != null) { serverContainer.setDefaultMaxBinaryMessageBufferSize(this.maxBinaryMessageBufferSize); } + + if (!ObjectUtils.isEmpty(this.annotatedEndpointClasses)) { + for (Class clazz : this.annotatedEndpointClasses) { + try { + logger.info("Registering @ServerEndpoint type " + clazz); + serverContainer.addEndpoint(clazz); + } + catch (DeploymentException e) { + throw new IllegalStateException("Failed to register @ServerEndpoint type " + clazz, e); + } + } + } } @Override public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { if (bean instanceof ServerEndpointConfig) { ServerEndpointConfig sec = (ServerEndpointConfig) bean; - ServerContainer serverContainer = ServerContainerProvider.getServerContainer(); try { - logger.debug("Registering javax.websocket.Endpoint for path " + sec.getPath()); - serverContainer.addEndpoint(sec); + if (logger.isInfoEnabled()) { + logger.info("Registering bean '" + beanName + + "' as javax.websocket.Endpoint under path " + sec.getPath()); + } + ServerContainerProvider.getServerContainer().addEndpoint(sec); } catch (DeploymentException e) { - throw new IllegalStateException("Failed to deploy Endpoint " + bean, e); + throw new IllegalStateException("Failed to deploy Endpoint bean " + bean, e); } } return bean; diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java index 17ee09328bc..385f8ece2b9 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java @@ -36,6 +36,7 @@ import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.web.context.ContextLoader; +import org.springframework.web.context.WebApplicationContext; import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.endpoint.StandardWebSocketHandlerAdapter; @@ -57,6 +58,8 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw private final String path; + private final Class endpointClass; + private final Object bean; private List subprotocols = new ArrayList(); @@ -70,20 +73,33 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw private final Configurator configurator = new Configurator() {}; - // ContextLoader.getCurrentWebApplicationContext().getAutowireCapableBeanFactory().createBean(Class) + /** + * Class constructor with the {@code javax.webscoket.Endpoint} class. + * TODO + * + * @param path + * @param endpointClass + */ + public EndpointRegistration(String path, Class endpointClass) { + this(path, endpointClass, null); + } + + public EndpointRegistration(String path, Object bean) { + this(path, null, bean); + } public EndpointRegistration(String path, String beanName) { - Assert.hasText(path, "path must not be empty"); - Assert.notNull(beanName, "beanName is required"); - this.path = path; - this.bean = beanName; + this(path, null, beanName); } - public EndpointRegistration(String path, Object bean) { + private EndpointRegistration(String path, Class endpointClass, Object bean) { Assert.hasText(path, "path must not be empty"); - Assert.notNull(bean, "bean is required"); + Assert.isTrue((endpointClass != null || bean != null), "Neither endpoint class nor endpoint bean provided"); this.path = path; + this.endpointClass = endpointClass; this.bean = bean; + // this will fail if the bean is not a valid Endpoint type + getEndpointClass(); } @Override @@ -94,20 +110,34 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw @SuppressWarnings("unchecked") @Override public Class getEndpointClass() { + if (this.endpointClass != null) { + return this.endpointClass; + } Class beanClass = this.bean.getClass(); if (beanClass.equals(String.class)) { beanClass = this.beanFactory.getType((String) this.bean); } beanClass = ClassUtils.getUserClass(beanClass); - if (WebSocketHandler.class.isAssignableFrom(beanClass)) { + if (Endpoint.class.isAssignableFrom(beanClass)) { + return (Class) beanClass; + } + else if (WebSocketHandler.class.isAssignableFrom(beanClass)) { return StandardWebSocketHandlerAdapter.class; } else { - return (Class) beanClass; + throw new IllegalStateException("Invalid endpoint bean: must be of type ... TODO "); } } public Endpoint getEndpoint() { + if (this.endpointClass != null) { + WebApplicationContext wac = ContextLoader.getCurrentWebApplicationContext(); + if (wac == null) { + throw new IllegalStateException("Failed to find WebApplicationContext. " + + "Was org.springframework.web.context.ContextLoader used to load the WebApplicationContext?"); + } + return wac.getAutowireCapableBeanFactory().createBean(this.endpointClass); + } Object bean = this.bean; if (this.bean instanceof String) { bean = this.beanFactory.getBean((String) this.bean); diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/SpringConfigurator.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/SpringConfigurator.java new file mode 100644 index 00000000000..d412fd9e299 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/SpringConfigurator.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.websocket.server.endpoint; + +import java.util.Map; + +import javax.websocket.server.ServerEndpoint; +import javax.websocket.server.ServerEndpointConfig.Configurator; + +import org.springframework.web.context.ContextLoader; +import org.springframework.web.context.WebApplicationContext; + + +/** + * This should be used in conjuction with {@link ServerEndpoint @ServerEndpoint} classes. + * + *

For {@link javax.websocket.Endpoint}, see {@link EndpointExporter}. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class SpringConfigurator extends Configurator { + + + @Override + public T getEndpointInstance(Class endpointClass) throws InstantiationException { + + WebApplicationContext wac = ContextLoader.getCurrentWebApplicationContext(); + if (wac == null) { + throw new IllegalStateException("Failed to find WebApplicationContext. " + + "Was org.springframework.web.context.ContextLoader used to load the WebApplicationContext?"); + } + + Map beans = wac.getBeansOfType(endpointClass); + if (beans.isEmpty()) { + // Initialize a new bean instance + return wac.getAutowireCapableBeanFactory().createBean(endpointClass); + } + if (beans.size() == 1) { + // Return the matching bean instance + return beans.values().iterator().next(); + } + else { + // This should never happen (@ServerEndpoint has a single path mapping) .. + throw new IllegalStateException("Found more than one matching beans of type " + endpointClass); + } + } + +}