diff --git a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/websocket/WebSocketAutoConfiguration.java b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/websocket/WebSocketAutoConfiguration.java index 6887a8f8db6..0091282208e 100644 --- a/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/websocket/WebSocketAutoConfiguration.java +++ b/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/websocket/WebSocketAutoConfiguration.java @@ -16,15 +16,19 @@ package org.springframework.boot.autoconfigure.websocket; -import java.util.Collection; import java.util.HashMap; import java.util.Map; import javax.servlet.ServletContainerInitializer; import org.apache.catalina.Context; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.beans.BeanUtils; import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.boot.autoconfigure.AutoConfigureBefore; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -33,13 +37,11 @@ import org.springframework.boot.autoconfigure.web.EmbeddedServletContainerAutoCo import org.springframework.boot.context.embedded.tomcat.TomcatEmbeddedServletContainerFactory; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.ClassUtils; -import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.WebSocketHandler; -import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler; -import org.springframework.web.socket.sockjs.SockJsService; -import org.springframework.web.socket.sockjs.support.AbstractSockJsService; +import org.springframework.web.socket.server.config.EnableWebSocket; +import org.springframework.web.socket.server.config.WebSocketConfigurer; +import org.springframework.web.socket.server.config.WebSocketHandlerRegistry; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; /** @@ -54,12 +56,27 @@ import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsServ @Configuration @ConditionalOnClass({ WebSocketHandler.class }) @AutoConfigureBefore(EmbeddedServletContainerAutoConfiguration.class) +@ConditionalOnMissingBean(WebSocketConfigurer.class) +@EnableWebSocket public class WebSocketAutoConfiguration { - private static class WebSocketEndpointPostProcessor implements BeanPostProcessor { + private static Log logger = LogFactory.getLog(WebSocketAutoConfiguration.class); + + // Nested class to avoid having to load WebSocketConfigurer before conditions are + // evaluated + @Configuration + protected static class WebSocketRegistrationConfiguration implements + BeanPostProcessor, BeanFactoryAware, WebSocketConfigurer { private Map prefixes = new HashMap(); + private ListableBeanFactory beanFactory; + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.beanFactory = (ListableBeanFactory) beanFactory; + } + @Override public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { @@ -75,58 +92,24 @@ public class WebSocketAutoConfiguration { return bean; } - public WebSocketHandler getHandler(String prefix) { + private WebSocketHandler getHandler(String prefix) { return this.prefixes.get(prefix); } - public String[] getPrefixes() { + private String[] getPrefixes() { return this.prefixes.keySet().toArray(new String[this.prefixes.size()]); } - } - - @Bean - public WebSocketEndpointPostProcessor webSocketEndpointPostProcessor() { - return new WebSocketEndpointPostProcessor(); - } - - @Bean - @ConditionalOnMissingBean(SockJsService.class) - public DefaultSockJsService sockJsService() { - DefaultSockJsService service = new DefaultSockJsService(sockJsTaskScheduler()); - service.setSockJsClientLibraryUrl("https://cdn.sockjs.org/sockjs-0.3.4.min.js"); - service.setWebSocketsEnabled(true); - return service; - } - - @Bean - public SimpleUrlHandlerMapping handlerMapping(SockJsService sockJsService, - Collection handlers) { - - WebSocketEndpointPostProcessor processor = webSocketEndpointPostProcessor(); - Map urlMap = new HashMap(); - for (String prefix : webSocketEndpointPostProcessor().getPrefixes()) { - urlMap.put(prefix + "/**", new SockJsHttpRequestHandler(sockJsService, - processor.getHandler(prefix))); - } - - if (sockJsService instanceof AbstractSockJsService) { - ((AbstractSockJsService) sockJsService).setValidSockJsPrefixes(processor - .getPrefixes()); + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + // Force initialization of WebSocketHandler beans + this.beanFactory.getBeansOfType(WebSocketHandler.class); + for (String prefix : getPrefixes()) { + logger.info("Adding SockJS handler: " + prefix); + registry.addHandler(getHandler(prefix), prefix).withSockJS(); + } } - SimpleUrlHandlerMapping handlerMapping = new SimpleUrlHandlerMapping(); - handlerMapping.setOrder(-1); - handlerMapping.setUrlMap(urlMap); - - return handlerMapping; - } - @Bean - @ConditionalOnMissingBean(name = "sockJsTaskScheduler") - public ThreadPoolTaskScheduler sockJsTaskScheduler() { - ThreadPoolTaskScheduler taskScheduler = new ThreadPoolTaskScheduler(); - taskScheduler.setThreadNamePrefix("SockJS-"); - return taskScheduler; } @Configuration