From 7ad60d385bf4af507fba530a6e77e4fc4e3fded6 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 30 Sep 2020 21:42:12 +0100 Subject: [PATCH] Fix issue with copying headers in NativeMessageHeaderAccessor Closes gh-25821 --- .../support/NativeMessageHeaderAccessor.java | 32 ++++++++++++++++--- .../NativeMessageHeaderAccessorTests.java | 17 ++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java index b9ecf08b5aa..251bd88931a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java @@ -75,8 +75,6 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { @SuppressWarnings("unchecked") Map> map = (Map>) getHeader(NATIVE_HEADERS); if (map != null) { - // Force removal since setHeader checks for equality - removeHeader(NATIVE_HEADERS); setHeader(NATIVE_HEADERS, new LinkedMultiValueMap<>(map)); } } @@ -105,14 +103,40 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { if (isMutable()) { Map> map = getNativeHeaders(); if (map != null) { - // Force removal since setHeader checks for equality - removeHeader(NATIVE_HEADERS); setHeader(NATIVE_HEADERS, Collections.unmodifiableMap(map)); } super.setImmutable(); } } + @Override + public void setHeader(String name, @Nullable Object value) { + if (name.equalsIgnoreCase(NATIVE_HEADERS)) { + // Force removal since setHeader checks for equality + removeHeader(NATIVE_HEADERS); + } + super.setHeader(name, value); + } + + @Override + @SuppressWarnings("unchecked") + public void copyHeaders(@Nullable Map headersToCopy) { + if (headersToCopy != null) { + Map> nativeHeaders = getNativeHeaders(); + Map> map = (Map>) headersToCopy.get(NATIVE_HEADERS); + if (map != null) { + if (nativeHeaders != null) { + nativeHeaders.putAll(map); + } + else { + nativeHeaders = new LinkedMultiValueMap<>(map); + } + } + super.copyHeaders(headersToCopy); + setHeader(NATIVE_HEADERS, nativeHeaders); + } + } + /** * Whether the native header map contains the give header name. */ diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java index 5f87e5699ac..726f80c256e 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java @@ -224,4 +224,21 @@ public class NativeMessageHeaderAccessorTests { headerAccessor.setImmutable(); } + @Test // gh-25821 + void copyImmutableToMutable() { + NativeMessageHeaderAccessor source = new NativeMessageHeaderAccessor(); + source.addNativeHeader("foo", "bar"); + Message message = MessageBuilder.createMessage("payload", source.getMessageHeaders()); + + NativeMessageHeaderAccessor target = new NativeMessageHeaderAccessor(); + target.copyHeaders(message.getHeaders()); + target.setLeaveMutable(true); + message = MessageBuilder.createMessage(message.getPayload(), target.getMessageHeaders()); + + MessageHeaderAccessor accessor = MessageHeaderAccessor.getMutableAccessor(message); + assertThat(accessor.isMutable()); + ((NativeMessageHeaderAccessor) accessor).addNativeHeader("foo", "baz"); + assertThat(((NativeMessageHeaderAccessor) accessor).getNativeHeader("foo")).containsExactly("bar", "baz"); + } + }