diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointExporter.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointExporter.java
index 5a601bfe88b..9c287e7071f 100644
--- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointExporter.java
+++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointExporter.java
@@ -15,33 +15,41 @@
*/
package org.springframework.websocket.server.endpoint;
+import java.util.Map;
+
import javax.websocket.DeploymentException;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerContainerProvider;
+import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.BeansException;
+import org.springframework.beans.factory.BeanFactory;
+import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.InitializingBean;
+import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.util.Assert;
+import org.springframework.util.ObjectUtils;
/**
* BeanPostProcessor that detects beans of type
- * {@link javax.websocket.server.ServerEndpointConfig} and registers them with a standard
- * Java WebSocket runtime and also configures the underlying
- * {@link javax.websocket.server.ServerContainer}.
+ * {@link javax.websocket.server.ServerEndpointConfig} and registers the corresponding
+ * {@link javax.websocket.Endpoint} with a standard Java WebSocket runtime.
*
*
If the runtime is a Servlet container, use {@link ServletEndpointExporter}.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
-public class EndpointExporter implements BeanPostProcessor, InitializingBean {
+public class EndpointExporter implements InitializingBean, BeanPostProcessor, BeanFactoryAware {
private static Log logger = LogFactory.getLog(EndpointExporter.class);
+ private Class>[] annotatedEndpointClasses;
+
private Long maxSessionIdleTimeout;
private Integer maxTextMessageBufferSize;
@@ -49,6 +57,14 @@ public class EndpointExporter implements BeanPostProcessor, InitializingBean {
private Integer maxBinaryMessageBufferSize;
+ /**
+ * TODO
+ * @param annotatedEndpointClasses
+ */
+ public void setAnnotatedEndpointClasses(Class>... annotatedEndpointClasses) {
+ this.annotatedEndpointClasses = annotatedEndpointClasses;
+ }
+
/**
* If this property set it is in turn used to configure
* {@link ServerContainer#setDefaultMaxSessionIdleTimeout(long)}.
@@ -85,8 +101,29 @@ public class EndpointExporter implements BeanPostProcessor, InitializingBean {
return this.maxBinaryMessageBufferSize;
}
+ @Override
+ public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
+ if (beanFactory instanceof ListableBeanFactory) {
+ ListableBeanFactory lbf = (ListableBeanFactory) beanFactory;
+ Map annotatedEndpoints = lbf.getBeansWithAnnotation(ServerEndpoint.class);
+ for (String beanName : annotatedEndpoints.keySet()) {
+ Class> beanType = lbf.getType(beanName);
+ try {
+ if (logger.isInfoEnabled()) {
+ logger.info("Detected @ServerEndpoint bean '" + beanName + "', registering it as an endpoint by type");
+ }
+ ServerContainerProvider.getServerContainer().addEndpoint(beanType);
+ }
+ catch (DeploymentException e) {
+ throw new IllegalStateException("Failed to register @ServerEndpoint bean type " + beanName, e);
+ }
+ }
+ }
+ }
+
@Override
public void afterPropertiesSet() throws Exception {
+
ServerContainer serverContainer = ServerContainerProvider.getServerContainer();
Assert.notNull(serverContainer, "javax.websocket.server.ServerContainer not available");
@@ -99,19 +136,33 @@ public class EndpointExporter implements BeanPostProcessor, InitializingBean {
if (this.maxBinaryMessageBufferSize != null) {
serverContainer.setDefaultMaxBinaryMessageBufferSize(this.maxBinaryMessageBufferSize);
}
+
+ if (!ObjectUtils.isEmpty(this.annotatedEndpointClasses)) {
+ for (Class> clazz : this.annotatedEndpointClasses) {
+ try {
+ logger.info("Registering @ServerEndpoint type " + clazz);
+ serverContainer.addEndpoint(clazz);
+ }
+ catch (DeploymentException e) {
+ throw new IllegalStateException("Failed to register @ServerEndpoint type " + clazz, e);
+ }
+ }
+ }
}
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
if (bean instanceof ServerEndpointConfig) {
ServerEndpointConfig sec = (ServerEndpointConfig) bean;
- ServerContainer serverContainer = ServerContainerProvider.getServerContainer();
try {
- logger.debug("Registering javax.websocket.Endpoint for path " + sec.getPath());
- serverContainer.addEndpoint(sec);
+ if (logger.isInfoEnabled()) {
+ logger.info("Registering bean '" + beanName
+ + "' as javax.websocket.Endpoint under path " + sec.getPath());
+ }
+ ServerContainerProvider.getServerContainer().addEndpoint(sec);
}
catch (DeploymentException e) {
- throw new IllegalStateException("Failed to deploy Endpoint " + bean, e);
+ throw new IllegalStateException("Failed to deploy Endpoint bean " + bean, e);
}
}
return bean;
diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java
index 17ee09328bc..385f8ece2b9 100644
--- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java
+++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java
@@ -36,6 +36,7 @@ import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.web.context.ContextLoader;
+import org.springframework.web.context.WebApplicationContext;
import org.springframework.websocket.WebSocketHandler;
import org.springframework.websocket.endpoint.StandardWebSocketHandlerAdapter;
@@ -57,6 +58,8 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw
private final String path;
+ private final Class extends Endpoint> endpointClass;
+
private final Object bean;
private List subprotocols = new ArrayList();
@@ -70,20 +73,33 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw
private final Configurator configurator = new Configurator() {};
- // ContextLoader.getCurrentWebApplicationContext().getAutowireCapableBeanFactory().createBean(Class)
+ /**
+ * Class constructor with the {@code javax.webscoket.Endpoint} class.
+ * TODO
+ *
+ * @param path
+ * @param endpointClass
+ */
+ public EndpointRegistration(String path, Class extends Endpoint> endpointClass) {
+ this(path, endpointClass, null);
+ }
+
+ public EndpointRegistration(String path, Object bean) {
+ this(path, null, bean);
+ }
public EndpointRegistration(String path, String beanName) {
- Assert.hasText(path, "path must not be empty");
- Assert.notNull(beanName, "beanName is required");
- this.path = path;
- this.bean = beanName;
+ this(path, null, beanName);
}
- public EndpointRegistration(String path, Object bean) {
+ private EndpointRegistration(String path, Class extends Endpoint> endpointClass, Object bean) {
Assert.hasText(path, "path must not be empty");
- Assert.notNull(bean, "bean is required");
+ Assert.isTrue((endpointClass != null || bean != null), "Neither endpoint class nor endpoint bean provided");
this.path = path;
+ this.endpointClass = endpointClass;
this.bean = bean;
+ // this will fail if the bean is not a valid Endpoint type
+ getEndpointClass();
}
@Override
@@ -94,20 +110,34 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw
@SuppressWarnings("unchecked")
@Override
public Class extends Endpoint> getEndpointClass() {
+ if (this.endpointClass != null) {
+ return this.endpointClass;
+ }
Class> beanClass = this.bean.getClass();
if (beanClass.equals(String.class)) {
beanClass = this.beanFactory.getType((String) this.bean);
}
beanClass = ClassUtils.getUserClass(beanClass);
- if (WebSocketHandler.class.isAssignableFrom(beanClass)) {
+ if (Endpoint.class.isAssignableFrom(beanClass)) {
+ return (Class extends Endpoint>) beanClass;
+ }
+ else if (WebSocketHandler.class.isAssignableFrom(beanClass)) {
return StandardWebSocketHandlerAdapter.class;
}
else {
- return (Class extends Endpoint>) beanClass;
+ throw new IllegalStateException("Invalid endpoint bean: must be of type ... TODO ");
}
}
public Endpoint getEndpoint() {
+ if (this.endpointClass != null) {
+ WebApplicationContext wac = ContextLoader.getCurrentWebApplicationContext();
+ if (wac == null) {
+ throw new IllegalStateException("Failed to find WebApplicationContext. "
+ + "Was org.springframework.web.context.ContextLoader used to load the WebApplicationContext?");
+ }
+ return wac.getAutowireCapableBeanFactory().createBean(this.endpointClass);
+ }
Object bean = this.bean;
if (this.bean instanceof String) {
bean = this.beanFactory.getBean((String) this.bean);
diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/SpringConfigurator.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/SpringConfigurator.java
new file mode 100644
index 00000000000..d412fd9e299
--- /dev/null
+++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/SpringConfigurator.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2002-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.websocket.server.endpoint;
+
+import java.util.Map;
+
+import javax.websocket.server.ServerEndpoint;
+import javax.websocket.server.ServerEndpointConfig.Configurator;
+
+import org.springframework.web.context.ContextLoader;
+import org.springframework.web.context.WebApplicationContext;
+
+
+/**
+ * This should be used in conjuction with {@link ServerEndpoint @ServerEndpoint} classes.
+ *
+ * For {@link javax.websocket.Endpoint}, see {@link EndpointExporter}.
+ *
+ * @author Rossen Stoyanchev
+ * @since 4.0
+ */
+public class SpringConfigurator extends Configurator {
+
+
+ @Override
+ public T getEndpointInstance(Class endpointClass) throws InstantiationException {
+
+ WebApplicationContext wac = ContextLoader.getCurrentWebApplicationContext();
+ if (wac == null) {
+ throw new IllegalStateException("Failed to find WebApplicationContext. "
+ + "Was org.springframework.web.context.ContextLoader used to load the WebApplicationContext?");
+ }
+
+ Map beans = wac.getBeansOfType(endpointClass);
+ if (beans.isEmpty()) {
+ // Initialize a new bean instance
+ return wac.getAutowireCapableBeanFactory().createBean(endpointClass);
+ }
+ if (beans.size() == 1) {
+ // Return the matching bean instance
+ return beans.values().iterator().next();
+ }
+ else {
+ // This should never happen (@ServerEndpoint has a single path mapping) ..
+ throw new IllegalStateException("Found more than one matching beans of type " + endpointClass);
+ }
+ }
+
+}