diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java index cdd5cc8df34..448d3dc152b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java @@ -52,6 +52,7 @@ public enum StompCommand { private static Collection destinationRequired = Arrays.asList(SEND, SUBSCRIBE, MESSAGE); private static Collection subscriptionIdRequired = Arrays.asList(SUBSCRIBE, UNSUBSCRIBE, MESSAGE); + private static Collection contentLengthRequired = Arrays.asList(SEND, MESSAGE, ERROR); private static Collection bodyAllowed = Arrays.asList(SEND, MESSAGE, ERROR); static { @@ -77,6 +78,10 @@ public enum StompCommand { return subscriptionIdRequired.contains(this); } + public boolean requiresContentLength() { + return contentLengthRequired.contains(this); + } + public boolean isBodyAllowed() { return bodyAllowed.contains(this); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java index 2edd0202963..44a083414e1 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java @@ -54,7 +54,6 @@ public class StompDecoder { private final Log logger = LogFactory.getLog(StompDecoder.class); - /** * Decodes one or more STOMP frames from the given {@code ByteBuffer} into a * list of {@link Message}s. If the input buffer contains any incplcontains partial STOMP frame content, or additional @@ -201,11 +200,41 @@ public class StompDecoder { } } - private String unescape(String input) { - return input.replaceAll("\\\\n", "\n") - .replaceAll("\\\\r", "\r") - .replaceAll("\\\\c", ":") - .replaceAll("\\\\\\\\", "\\\\"); + /** + * See STOMP Spec 1.2: + * "Value Encoding". + */ + private String unescape(String inString) { + + StringBuilder sb = new StringBuilder(); + int pos = 0; // position in the old string + int index = inString.indexOf("\\"); + + while (index >= 0) { + sb.append(inString.substring(pos, index)); + Character c = inString.charAt(index + 1); + if (c == 'r') { + sb.append('\r'); + } + else if (c == 'n') { + sb.append('\n'); + } + else if (c == 'c') { + sb.append(':'); + } + else if (c == '\\') { + sb.append('\\'); + } + else { + // should never happen + throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString); + } + pos = index + 2; + index = inString.indexOf("\\", pos); + } + + sb.append(inString.substring(pos)); + return sb.toString(); } private byte[] readPayload(ByteBuffer buffer, MultiValueMap headers) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java index e7f94ddd9e9..ae7e54315a8 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java @@ -34,6 +34,7 @@ import org.springframework.messaging.simp.SimpMessageType; * An encoder for STOMP frames. * * @author Andy Wilkinson + * @author Rossen Stoyanchev * @since 4.0 */ public final class StompEncoder { @@ -54,19 +55,21 @@ public final class StompEncoder { */ public byte[] encode(Message message) { try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(256); + ByteArrayOutputStream baos = new ByteArrayOutputStream(128 + message.getPayload().length); DataOutputStream output = new DataOutputStream(baos); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - if (isHeartbeat(headers)) { + if (SimpMessageType.HEARTBEAT == headers.getMessageType()) { + logger.trace("Encoded heartbeat"); output.write(message.getPayload()); } else { - writeCommand(headers, output); + output.write(headers.getCommand().toString().getBytes(UTF8_CHARSET)); + output.write(LF); writeHeaders(headers, message, output); output.write(LF); writeBody(message, output); - output.write((byte)0); + output.write((byte) 0); } return baos.toByteArray(); @@ -76,61 +79,68 @@ public final class StompEncoder { } } - private boolean isHeartbeat(StompHeaderAccessor headers) { - return (headers.getMessageType() == SimpMessageType.HEARTBEAT); - } - - private void writeCommand(StompHeaderAccessor headers, DataOutputStream output) throws IOException { - output.write(headers.getCommand().toString().getBytes(UTF8_CHARSET)); - output.write(LF); - } - private void writeHeaders(StompHeaderAccessor headers, Message message, DataOutputStream output) throws IOException { + StompCommand command = headers.getCommand(); Map> stompHeaders = headers.toStompHeaderMap(); - if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) { - logger.trace("Encoded heartbeat"); - } - else if (logger.isDebugEnabled()) { - logger.debug("Encoded STOMP command=" + headers.getCommand() + " headers=" + stompHeaders); + boolean shouldEscape = (command != StompCommand.CONNECT && command != StompCommand.CONNECTED); + + if (logger.isDebugEnabled()) { + logger.debug("Encoded STOMP " + command + ", headers=" + stompHeaders); } + for (Entry> entry : stompHeaders.entrySet()) { - byte[] key = getUtf8BytesEscapingIfNecessary(entry.getKey(), headers); + byte[] key = encodeHeaderString(entry.getKey(), shouldEscape); for (String value : entry.getValue()) { output.write(key); output.write(COLON); - output.write(getUtf8BytesEscapingIfNecessary(value, headers)); + output.write(encodeHeaderString(value, shouldEscape)); output.write(LF); } } - if ((headers.getCommand() == StompCommand.SEND) || (headers.getCommand() == StompCommand.MESSAGE) || - (headers.getCommand() == StompCommand.ERROR)) { - + if (command.requiresContentLength()) { + int contentLength = message.getPayload().length; output.write("content-length:".getBytes(UTF8_CHARSET)); - output.write(Integer.toString(message.getPayload().length).getBytes(UTF8_CHARSET)); + output.write(Integer.toString(contentLength).getBytes(UTF8_CHARSET)); output.write(LF); } } - private void writeBody(Message message, DataOutputStream output) throws IOException { - output.write(message.getPayload()); + private byte[] encodeHeaderString(String input, boolean escape) { + input = escape ? escape(input) : input; + return input.getBytes(UTF8_CHARSET); } - private byte[] getUtf8BytesEscapingIfNecessary(String input, StompHeaderAccessor headers) { - if (headers.getCommand() != StompCommand.CONNECT && headers.getCommand() != StompCommand.CONNECTED) { - return escape(input).getBytes(UTF8_CHARSET); - } - else { - return input.getBytes(UTF8_CHARSET); + /** + * See STOMP Spec 1.2: + * "Value Encoding". + */ + private String escape(String inString) { + StringBuilder sb = new StringBuilder(inString.length()); + for (int i = 0; i < inString.length(); i++) { + char c = inString.charAt(i); + if (c == '\\') { + sb.append("\\\\"); + } + else if (c == ':') { + sb.append("\\c"); + } + else if (c == '\n') { + sb.append("\\n"); + } + else if (c == '\r') { + sb.append("\\r"); + } + else { + sb.append(c); + } } + return sb.toString(); } - private String escape(String input) { - return input.replaceAll("\\\\", "\\\\\\\\") - .replaceAll(":", "\\\\c") - .replaceAll("\n", "\\\\n") - .replaceAll("\r", "\\\\r"); + private void writeBody(Message message, DataOutputStream output) throws IOException { + output.write(message.getPayload()); } }