diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java index 7a3fcee55e0..767b27bf126 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java @@ -21,6 +21,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import org.springframework.context.ApplicationContext; import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.scheduling.TaskScheduler; import org.springframework.util.Assert; @@ -90,6 +91,9 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry { return (SubProtocolWebSocketHandler) actual; } + protected void setApplicationContext(ApplicationContext applicationContext) { + this.stompHandler.setApplicationEventPublisher(applicationContext); + } @Override public StompWebSocketEndpointRegistration addEndpoint(String... paths) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java index 80696daec31..c3553c0cd35 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java @@ -18,7 +18,6 @@ package org.springframework.web.socket.config.annotation; import org.springframework.context.annotation.Bean; import org.springframework.messaging.simp.config.AbstractMessageBrokerConfiguration; -import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.socket.WebSocketHandler; @@ -46,16 +45,11 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac @Bean public HandlerMapping stompWebSocketHandlerMapping() { - WebSocketHandler webSocketHandler = subProtocolWebSocketHandler(); - UserSessionRegistry sessionRegistry = userSessionRegistry(); - WebSocketTransportRegistration transportRegistration = getTransportRegistration(); - ThreadPoolTaskScheduler taskScheduler = messageBrokerSockJsTaskScheduler(); - - WebMvcStompEndpointRegistry registry = new WebMvcStompEndpointRegistry( - webSocketHandler, transportRegistration, sessionRegistry, taskScheduler); + WebMvcStompEndpointRegistry registry = new WebMvcStompEndpointRegistry(subProtocolWebSocketHandler(), + getTransportRegistration(), userSessionRegistry(), messageBrokerSockJsTaskScheduler()); + registry.setApplicationContext(getApplicationContext()); registerStompEndpoints(registry); - return registry.getHandlerMapping(); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index 19b4ec6530f..041e49f29dc 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -65,7 +65,7 @@ import org.springframework.web.socket.handler.SessionLimitExceededException; * @since 4.0 */ public class SubProtocolWebSocketHandler implements WebSocketHandler, - SubProtocolCapable, MessageHandler, SmartLifecycle, ApplicationEventPublisherAware { + SubProtocolCapable, MessageHandler, SmartLifecycle { /** * Sessions connected to this handler use a sub-protocol. Hence we expect to @@ -97,12 +97,10 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, private final ReentrantLock sessionCheckLock = new ReentrantLock(); - private Object lifecycleMonitor = new Object(); + private final Object lifecycleMonitor = new Object(); private volatile boolean running = false; - private ApplicationEventPublisher eventPublisher; - public SubProtocolWebSocketHandler(MessageChannel clientInboundChannel, SubscribableChannel clientOutboundChannel) { Assert.notNull(clientInboundChannel, "ClientInboundChannel must not be null"); @@ -147,10 +145,6 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, + " to protocol '" + protocol + "', it is already mapped to handler " + replaced); } } - - if (handler instanceof ApplicationEventPublisherAware) { - ((ApplicationEventPublisherAware) handler).setApplicationEventPublisher(this.eventPublisher); - } } /** @@ -203,11 +197,6 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, return sendBufferSizeLimit; } - @Override - public void setApplicationEventPublisher(ApplicationEventPublisher eventPublisher) { - this.eventPublisher = eventPublisher; - } - @Override public boolean isAutoStartup() { return true; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 9b636fda6ff..481c0675dd2 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -118,6 +118,8 @@ public class MessageBrokerBeanDefinitionParserTests { assertNotNull(stompHandler); assertEquals(128 * 1024, stompHandler.getMessageSizeLimit()); + assertNotNull(new DirectFieldAccessor(stompHandler).getPropertyValue("eventPublisher")); + httpRequestHandler = (HttpRequestHandler) suhm.getUrlMap().get("/test/**"); assertNotNull(httpRequestHandler); assertThat(httpRequestHandler, Matchers.instanceOf(SockJsHttpRequestHandler.class));