Browse Source

Add PrincipalMessageArgumentResolver

pull/286/merge
Rossen Stoyanchev 13 years ago
parent
commit
210be9cde4
  1. 54
      spring-messaging/src/main/java/org/springframework/messaging/handler/method/InvalidMessageMethodParameterException.java
  2. 20
      spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java
  3. 51
      spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMessageArgumentResolver.java
  4. 10
      spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java
  5. 6
      spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java
  6. 8
      spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java
  7. 34
      spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java
  8. 15
      spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java

54
spring-messaging/src/main/java/org/springframework/messaging/handler/method/InvalidMessageMethodParameterException.java

@ -0,0 +1,54 @@ @@ -0,0 +1,54 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.handler.method;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessagingException;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class InvalidMessageMethodParameterException extends MessagingException {
private static final long serialVersionUID = -6905878930083523161L;
private final MethodParameter parameter;
public InvalidMessageMethodParameterException(Message<?> message, String description,
MethodParameter parameter, Throwable cause) {
super(message, description, cause);
this.parameter = parameter;
}
public InvalidMessageMethodParameterException(Message<?> message, String description,
MethodParameter parameter) {
super(message, description);
this.parameter = parameter;
}
public MethodParameter getParameter() {
return this.parameter;
}
}

20
spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java

@ -16,12 +16,14 @@ @@ -16,12 +16,14 @@
package org.springframework.messaging.simp;
import java.security.Principal;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.NativeMessageHeaderAccessor;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@ -43,9 +45,6 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { @@ -43,9 +45,6 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String DESTINATIONS = "destinations";
// TODO
public static final String CONTENT_TYPE = "contentType";
public static final String MESSAGE_TYPE = "messageType";
public static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType";
@ -54,6 +53,8 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { @@ -54,6 +53,8 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String SUBSCRIPTION_ID = "subscriptionId";
public static final String USER = "user";
/**
* A constructor for creating new message headers.
@ -140,12 +141,11 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { @@ -140,12 +141,11 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
}
public MediaType getContentType() {
return (MediaType) getHeader(CONTENT_TYPE);
return (MediaType) getHeader(MessageHeaders.CONTENT_TYPE);
}
public void setContentType(MediaType contentType) {
Assert.notNull(contentType, "contentType is required");
setHeader(CONTENT_TYPE, contentType);
setHeader(MessageHeaders.CONTENT_TYPE, contentType);
}
public String getSubscriptionId() {
@ -164,4 +164,12 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { @@ -164,4 +164,12 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
setHeader(SESSION_ID, sessionId);
}
public Principal getUser() {
return (Principal) getHeader(USER);
}
public void setUser(Principal principal) {
setHeader(USER, principal);
}
}

51
spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMessageArgumentResolver.java

@ -0,0 +1,51 @@ @@ -0,0 +1,51 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.annotation.support;
import java.security.Principal;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.handler.method.InvalidMessageMethodParameterException;
import org.springframework.messaging.handler.method.MessageArgumentResolver;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public class PrincipalMessageArgumentResolver implements MessageArgumentResolver {
@Override
public boolean supportsParameter(MethodParameter parameter) {
Class<?> paramType = parameter.getParameterType();
return Principal.class.isAssignableFrom(paramType);
}
@Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
Principal user = headers.getUser();
if (user == null) {
throw new InvalidMessageMethodParameterException(message, "User not available", parameter);
}
return user;
}
}

10
spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java

@ -37,15 +37,16 @@ import org.springframework.messaging.MessageChannel; @@ -37,15 +37,16 @@ import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.support.MessageBodyArgumentResolver;
import org.springframework.messaging.handler.annotation.support.MessageExceptionHandlerMethodResolver;
import org.springframework.messaging.handler.method.MessageArgumentResolverComposite;
import org.springframework.messaging.handler.method.InvocableMessageHandlerMethod;
import org.springframework.messaging.handler.method.MessageArgumentResolverComposite;
import org.springframework.messaging.handler.method.MessageReturnValueHandlerComposite;
import org.springframework.messaging.simp.annotation.SubscribeEvent;
import org.springframework.messaging.simp.annotation.UnsubscribeEvent;
import org.springframework.messaging.simp.annotation.support.MessageSendingReturnValueHandler;
import org.springframework.messaging.simp.MessageHolder;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.annotation.SubscribeEvent;
import org.springframework.messaging.simp.annotation.UnsubscribeEvent;
import org.springframework.messaging.simp.annotation.support.MessageSendingReturnValueHandler;
import org.springframework.messaging.simp.annotation.support.PrincipalMessageArgumentResolver;
import org.springframework.messaging.support.converter.MessageConverter;
import org.springframework.stereotype.Controller;
import org.springframework.util.Assert;
@ -113,6 +114,7 @@ public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler @@ -113,6 +114,7 @@ public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler
initHandlerMethods();
this.argumentResolvers.addResolver(new PrincipalMessageArgumentResolver());
this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverter));
this.returnValueHandlers.addHandler(

6
spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java

@ -46,7 +46,7 @@ public class StompMessageConverter { @@ -46,7 +46,7 @@ public class StompMessageConverter {
/**
* @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String.
*/
public Message<?> toMessage(Object stompContent, String sessionId) {
public Message<?> toMessage(Object stompContent) {
byte[] byteContent = null;
if (stompContent instanceof String) {
@ -91,12 +91,10 @@ public class StompMessageConverter { @@ -91,12 +91,10 @@ public class StompMessageConverter {
}
}
StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers);
stompHeaders.setSessionId(sessionId);
byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers);
return MessageBuilder.withPayload(payload).copyHeaders(stompHeaders.toMap()).build();
}

8
spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java

@ -29,9 +29,9 @@ import java.util.concurrent.TimeUnit; @@ -29,9 +29,9 @@ import java.util.concurrent.TimeUnit;
import org.springframework.context.SmartLifecycle;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.handler.AbstractSimpMessageHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.AbstractSimpMessageHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@ -350,7 +350,7 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme @@ -350,7 +350,7 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
return;
}
Message<?> message = stompMessageConverter.toMessage(stompFrame, this.sessionId);
Message<?> message = stompMessageConverter.toMessage(stompFrame);
if (logger.isTraceEnabled()) {
logger.trace("Reading message " + message);
}
@ -369,6 +369,10 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme @@ -369,6 +369,10 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
}
relaySessions.remove(this.sessionId);
}
headers.setSessionId(this.sessionId);
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
sendMessageToClient(message);
}

34
spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java

@ -81,7 +81,11 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement @@ -81,7 +81,11 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) {
try {
String payload = textMessage.getPayload();
Message<?> message = this.stompMessageConverter.toMessage(payload, session.getId());
Message<?> message = this.stompMessageConverter.toMessage(payload);
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setSessionId(session.getId());
headers.setUser(session.getPrincipal());
// TODO: validate size limits
// http://stomp.github.io/stomp-specification-1.2.html#Size_Limits
@ -96,18 +100,8 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement @@ -96,18 +100,8 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
if (SimpMessageType.CONNECT.equals(messageType)) {
handleConnect(session, message);
}
else if (SimpMessageType.MESSAGE.equals(messageType)) {
handlePublish(message);
}
else if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
handleSubscribe(message);
}
else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
handleUnsubscribe(message);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
handleDisconnect(message);
}
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
this.outputChannel.send(message);
}
catch (Throwable t) {
@ -124,7 +118,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement @@ -124,7 +118,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
}
}
protected void handleConnect(final WebSocketSession session, Message<?> message) throws IOException {
protected void handleConnect(WebSocketSession session, Message<?> message) throws IOException {
StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(message);
StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
@ -152,18 +146,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement @@ -152,18 +146,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
protected void handlePublish(Message<?> stompMessage) {
}
protected void handleSubscribe(Message<?> message) {
}
protected void handleUnsubscribe(Message<?> message) {
}
protected void handleDisconnect(Message<?> message) {
}
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);

15
spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java

@ -24,9 +24,6 @@ import org.springframework.messaging.Message; @@ -24,9 +24,6 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompMessageConverter;
import static org.junit.Assert.*;
@ -51,17 +48,16 @@ public class StompMessageConverterTests { @@ -51,17 +48,16 @@ public class StompMessageConverterTests {
String accept = "accept-version:1.1\n";
String host = "host:github.org\n";
String frame = "\n\n\nCONNECT\n" + accept + host + "\n";
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);
MessageHeaders headers = message.getHeaders();
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
Map<String, Object> map = stompHeaders.toMap();
assertEquals(6, map.size());
assertEquals(5, map.size());
assertNotNull(map.get(MessageHeaders.ID));
assertNotNull(map.get(MessageHeaders.TIMESTAMP));
assertNotNull(map.get(SimpMessageHeaderAccessor.SESSION_ID));
assertNotNull(map.get(SimpMessageHeaderAccessor.NATIVE_HEADERS));
assertNotNull(map.get(SimpMessageHeaderAccessor.MESSAGE_TYPE));
assertNotNull(map.get(SimpMessageHeaderAccessor.PROTOCOL_MESSAGE_TYPE));
@ -71,7 +67,6 @@ public class StompMessageConverterTests { @@ -71,7 +67,6 @@ public class StompMessageConverterTests {
assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand());
assertEquals("session-123", stompHeaders.getSessionId());
assertNotNull(headers.get(MessageHeaders.ID));
assertNotNull(headers.get(MessageHeaders.TIMESTAMP));
@ -89,7 +84,7 @@ public class StompMessageConverterTests { @@ -89,7 +84,7 @@ public class StompMessageConverterTests {
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String frame = "CONNECT\n" + accept + host + "\n";
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);
@ -111,7 +106,7 @@ public class StompMessageConverterTests { @@ -111,7 +106,7 @@ public class StompMessageConverterTests {
String host = "host:github.org\n";
String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);
@ -133,7 +128,7 @@ public class StompMessageConverterTests { @@ -133,7 +128,7 @@ public class StompMessageConverterTests {
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);

Loading…
Cancel
Save