From 1d92755cc74127c7371c2438f272d431ee43bc4f Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 30 Jul 2019 10:52:13 +0100 Subject: [PATCH] Remove custom handling of byte[] in DefaultStompSession Closes gh-23358 --- .../simp/stomp/DefaultStompSession.java | 10 ++-- .../simp/stomp/DefaultStompSessionTests.java | 50 ++++++++++++++++--- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java index c2512edc030..46723c87bac 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java @@ -256,12 +256,9 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { private Message createMessage(StompHeaderAccessor accessor, @Nullable Object payload) { accessor.updateSimpMessageHeadersFromStompHeaders(); Message message; - if (payload == null) { + if (isEmpty(payload)) { message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); } - else if (payload instanceof byte[]) { - message = MessageBuilder.createMessage((byte[]) payload, accessor.getMessageHeaders()); - } else { message = (Message) getMessageConverter().toMessage(payload, accessor.getMessageHeaders()); accessor.updateStompHeadersFromSimpMessageHeaders(); @@ -274,6 +271,11 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { return message; } + private boolean isEmpty(@Nullable Object payload) { + return payload == null || StringUtils.isEmpty(payload) || + (payload instanceof byte[] && ((byte[]) payload).length == 0); + } + private void execute(Message message) { if (logger.isTraceEnabled()) { StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java index 78d50000ed7..76500c01ed2 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.messaging.simp.stomp; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.Date; import java.util.Map; import java.util.concurrent.ScheduledFuture; @@ -34,6 +35,8 @@ import org.mockito.MockitoAnnotations; import org.springframework.messaging.Message; import org.springframework.messaging.MessageDeliveryException; +import org.springframework.messaging.converter.ByteArrayMessageConverter; +import org.springframework.messaging.converter.CompositeMessageConverter; import org.springframework.messaging.converter.MessageConversionException; import org.springframework.messaging.converter.StringMessageConverter; import org.springframework.messaging.simp.stomp.StompSession.Receiptable; @@ -46,10 +49,23 @@ import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; import org.springframework.util.concurrent.SettableListenableFuture; -import static org.hamcrest.Matchers.*; -import static org.junit.Assert.*; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.notNull; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; /** * Unit tests for {@link DefaultStompSession}. @@ -82,7 +98,9 @@ public class DefaultStompSessionTests { this.sessionHandler = mock(StompSessionHandler.class); this.connectHeaders = new StompHeaders(); this.session = new DefaultStompSession(this.sessionHandler, this.connectHeaders); - this.session.setMessageConverter(new StringMessageConverter()); + this.session.setMessageConverter( + new CompositeMessageConverter( + Arrays.asList(new StringMessageConverter(), new ByteArrayMessageConverter()))); SettableListenableFuture future = new SettableListenableFuture<>(); future.set(null); @@ -110,7 +128,7 @@ public class DefaultStompSessionTests { @Test // SPR-16844 public void afterConnectedWithSpecificVersion() { assertFalse(this.session.isConnected()); - this.connectHeaders.setAcceptVersion(new String[] {"1.1"}); + this.connectHeaders.setAcceptVersion("1.1"); this.session.afterConnected(this.connection); @@ -388,6 +406,26 @@ public class DefaultStompSessionTests { assertEquals("my-receipt", accessor.getReceipt()); } + @Test // gh-23358 + public void sendByteArray() { + this.session.afterConnected(this.connection); + assertTrue(this.session.isConnected()); + + String destination = "/topic/foo"; + String payload = "sample payload"; + this.session.send(destination, payload.getBytes(StandardCharsets.UTF_8)); + + Message message = this.messageCaptor.getValue(); + StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + + StompHeaders stompHeaders = StompHeaders.readOnlyStompHeaders(accessor.getNativeHeaders()); + assertEquals(stompHeaders.toString(), 2, stompHeaders.size()); + + assertEquals(destination, stompHeaders.getDestination()); + assertEquals(MimeTypeUtils.APPLICATION_OCTET_STREAM, stompHeaders.getContentType()); + assertEquals(payload, new String(message.getPayload(), StandardCharsets.UTF_8)); + } + @Test public void sendWithConversionException() { this.session.afterConnected(this.connection);