diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubChannelRegistry.java b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubChannelRegistry.java index 8cd4143343e..6dc6db2bd0d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubChannelRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubChannelRegistry.java @@ -28,10 +28,21 @@ import org.springframework.messaging.SubscribableChannel; @SuppressWarnings("rawtypes") public interface PubSubChannelRegistry> { + + /** + * A channel for messaging arriving from clients. + */ SubscribableChannel getClientInputChannel(); + /** + * A channel for sending direct messages to a client. The client must be have + * previously subscribed to the destination of the message. + */ SubscribableChannel getClientOutputChannel(); + /** + * A channel for broadcasting messages through a message broker. + */ SubscribableChannel getMessageBrokerChannel(); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java index e0b59661a00..f27c1f9e943 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java @@ -43,6 +43,7 @@ import org.springframework.web.messaging.annotation.SubscribeEvent; import org.springframework.web.messaging.annotation.UnsubscribeEvent; import org.springframework.web.messaging.converter.MessageConverter; import org.springframework.web.messaging.service.AbstractPubSubMessageHandler; +import org.springframework.web.messaging.support.MessageHolder; import org.springframework.web.messaging.support.PubSubHeaderAccesssor; import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.HandlerMethodSelector; @@ -197,6 +198,8 @@ public class AnnotationPubSubMessageHandler extends AbstractP invocableHandlerMethod.setMessageMethodArgumentResolvers(this.argumentResolvers); try { + MessageHolder.setMessage(message); + Object value = invocableHandlerMethod.invoke(message); MethodParameter returnType = handlerMethod.getReturnType(); @@ -205,12 +208,14 @@ public class AnnotationPubSubMessageHandler extends AbstractP } this.returnValueHandlers.handleReturnValue(value, returnType, message); - } catch (Throwable e) { // TODO: send error message, or add @ExceptionHandler-like capability e.printStackTrace(); } + finally { + MessageHolder.reset(); + } } protected HandlerMethod getHandlerMethod(String destination, Map handlerMethods) { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java index 4429421320d..b79b774da8b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java @@ -20,8 +20,6 @@ import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.util.Assert; -import org.springframework.web.messaging.support.PubSubHeaderAccesssor; -import org.springframework.web.messaging.support.SessionMessageChannel; /** @@ -46,9 +44,7 @@ public class MessageChannelArgumentResolver implements Argume @Override public Object resolveArgument(MethodParameter parameter, M message) throws Exception { - Assert.notNull(this.messageBrokerChannel, "messageBrokerChannel is required"); - final String sessionId = PubSubHeaderAccesssor.wrap(message).getSessionId(); - return new SessionMessageChannel(this.messageBrokerChannel, sessionId); + return this.messageBrokerChannel; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java index 0d525864789..caecd29b973 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java @@ -16,6 +16,8 @@ package org.springframework.web.messaging.service.method; +import java.util.Map; + import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; @@ -68,33 +70,30 @@ public class MessageReturnValueHandler implements ReturnValue return; } - returnMessage = updateReturnMessage(returnMessage, message); + returnMessage = processReturnMessage(returnMessage, message); this.clientChannel.send(returnMessage); } - protected M updateReturnMessage(M returnMessage, M message) { + protected M processReturnMessage(M returnMessage, M message) { PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message); - String sessionId = headers.getSessionId(); - String subscriptionId = headers.getSubscriptionId(); - - Assert.notNull(subscriptionId, "No subscription id: " + message); + Assert.notNull(headers.getSubscriptionId(), "No subscription id: " + message); PubSubHeaderAccesssor returnHeaders = PubSubHeaderAccesssor.wrap(returnMessage); - returnHeaders.setSessionId(sessionId); - returnHeaders.setSubscriptionId(subscriptionId); + returnHeaders.setSessionId(headers.getSessionId()); + returnHeaders.setSubscriptionId(headers.getSubscriptionId()); if (returnHeaders.getDestination() == null) { returnHeaders.setDestination(headers.getDestination()); } - return createMessage(returnHeaders, returnMessage.getPayload()); + return createMessage(returnMessage.getPayload(), returnHeaders.toHeaders()); } @SuppressWarnings("unchecked") - private M createMessage(PubSubHeaderAccesssor returnHeaders, Object payload) { - return (M) MessageBuilder.withPayload(payload).copyHeaders(returnHeaders.toHeaders()).build(); + private M createMessage(Object payload, Map headers) { + return (M) MessageBuilder.withPayload(payload).copyHeaders(headers).build(); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/MessageHolder.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/MessageHolder.java new file mode 100644 index 00000000000..0e2af7f0277 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/MessageHolder.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.support; + +import org.springframework.core.NamedThreadLocal; +import org.springframework.messaging.Message; + + +/** + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class MessageHolder { + + private static final NamedThreadLocal> messageHolder = + new NamedThreadLocal>("Current message"); + + + public static void setMessage(Message message) { + messageHolder.set(message); + } + + public static Message getMessage() { + return messageHolder.get(); + } + + public static void reset() { + messageHolder.remove(); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubMessageBuilder.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubMessageBuilder.java new file mode 100644 index 00000000000..47144ba82fd --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubMessageBuilder.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.support; + +import org.springframework.http.MediaType; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; + +import reactor.util.Assert; + + +/** + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class PubSubMessageBuilder { + + private final PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.create(); + + private final T payload; + + + private PubSubMessageBuilder(T payload) { + Assert.notNull(payload, " is required"); + this.payload = payload; + } + + + public static PubSubMessageBuilder withPayload(T payload) { + return new PubSubMessageBuilder(payload); + } + + + public PubSubMessageBuilder destination(String destination) { + Assert.notNull(destination, "destination is required"); + this.headers.setDestination(destination); + return this; + } + + public PubSubMessageBuilder contentType(MediaType contentType) { + Assert.notNull(contentType, "contentType is required"); + this.headers.setContentType(contentType); + return this; + } + + public PubSubMessageBuilder contentType(String contentType) { + Assert.notNull(contentType, "contentType is required"); + this.headers.setContentType(MediaType.parseMediaType(contentType)); + return this; + } + + public Message build() { + + Message message = MessageHolder.getMessage(); + if (message != null) { + String sessionId = PubSubHeaderAccesssor.wrap(message).getSessionId(); + this.headers.setSessionId(sessionId); + } + + return MessageBuilder.withPayload(this.payload).copyHeaders(this.headers.toHeaders()).build(); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/SessionMessageChannel.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/SessionMessageChannel.java deleted file mode 100644 index 6ae8fc59f70..00000000000 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/support/SessionMessageChannel.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.messaging.support; - -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.support.MessageBuilder; - -import reactor.util.Assert; - - -/** - * @author Rossen Stoyanchev - * @since 4.0 - */ -@SuppressWarnings("rawtypes") -public class SessionMessageChannel implements MessageChannel { - - private MessageChannel delegate; - - private final String sessionId; - - - public SessionMessageChannel(MessageChannel delegate, String sessionId) { - Assert.notNull(delegate, "delegate is required"); - Assert.notNull(sessionId, "sessionId is required"); - this.sessionId = sessionId; - this.delegate = delegate; - } - - @Override - public boolean send(M message) { - return send(message, -1); - } - - @Override - public boolean send(M message, long timeout) { - PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message); - headers.setSessionId(this.sessionId); - Object payload = message.getPayload(); - @SuppressWarnings("unchecked") - M messageToSend = (M) MessageBuilder.withPayload(payload).copyHeaders(headers.toHeaders()).build(); - this.delegate.send(messageToSend); - return true; - } -}