From 3703be5aafabfca3693cf62524f9ba510a660a6c Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 26 Nov 2020 21:51:25 +0000 Subject: [PATCH] MessageHeaderAccessor handle self-copy correctly 1. Revert changes in setHeader from 5.2.9 that caused regression on self-copy. 2. Update copyHeaders/IfAbsent to ensure a copy of native headers. 3. Exit if source and target are the same instance, as an optimization. Closes gh-26155 --- .../support/MessageHeaderAccessor.java | 26 +++---- .../support/NativeMessageHeaderAccessor.java | 72 +++++++++++++------ .../NativeMessageHeaderAccessorTests.java | 43 ++++++++--- 3 files changed, 101 insertions(+), 40 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java index ef7130f54a9..95e04b63f48 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java @@ -378,13 +378,14 @@ public class MessageHeaderAccessor { * {@link #copyHeadersIfAbsent(Map)} to avoid overwriting values. */ public void copyHeaders(@Nullable Map headersToCopy) { - if (headersToCopy != null) { - headersToCopy.forEach((key, value) -> { - if (!isReadOnly(key)) { - setHeader(key, value); - } - }); + if (headersToCopy == null || this.headers == headersToCopy) { + return; } + headersToCopy.forEach((key, value) -> { + if (!isReadOnly(key)) { + setHeader(key, value); + } + }); } /** @@ -392,13 +393,14 @@ public class MessageHeaderAccessor { *

This operation will not overwrite any existing values. */ public void copyHeadersIfAbsent(@Nullable Map headersToCopy) { - if (headersToCopy != null) { - headersToCopy.forEach((key, value) -> { - if (!isReadOnly(key)) { - setHeaderIfAbsent(key, value); - } - }); + if (headersToCopy == null || this.headers == headersToCopy) { + return; } + headersToCopy.forEach((key, value) -> { + if (!isReadOnly(key)) { + setHeaderIfAbsent(key, value); + } + }); } protected boolean isReadOnly(String headerName) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java index fbd70f3598d..1ccded25c8c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java @@ -16,6 +16,7 @@ package org.springframework.messaging.support; +import java.util.ArrayList; import java.util.Collections; import java.util.LinkedList; import java.util.List; @@ -75,6 +76,8 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { @SuppressWarnings("unchecked") Map> map = (Map>) getHeader(NATIVE_HEADERS); if (map != null) { + // setHeader checks for equality but we need copy of native headers + setHeader(NATIVE_HEADERS, null); setHeader(NATIVE_HEADERS, new LinkedMultiValueMap<>(map)); } } @@ -103,6 +106,8 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { if (isMutable()) { Map> map = getNativeHeaders(); if (map != null) { + // setHeader checks for equality but we need immutable wrapper + setHeader(NATIVE_HEADERS, null); setHeader(NATIVE_HEADERS, Collections.unmodifiableMap(map)); } super.setImmutable(); @@ -110,31 +115,34 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { } @Override - public void setHeader(String name, @Nullable Object value) { - if (name.equalsIgnoreCase(NATIVE_HEADERS)) { - // Force removal since setHeader checks for equality - super.setHeader(NATIVE_HEADERS, null); + public void copyHeaders(@Nullable Map headersToCopy) { + if (headersToCopy == null) { + return; + } + + @SuppressWarnings("unchecked") + Map> map = (Map>) headersToCopy.get(NATIVE_HEADERS); + if (map != null && map != getNativeHeaders()) { + map.forEach(this::setNativeHeaderValues); } - super.setHeader(name, value); + + // setHeader checks for equality, native headers should be equal by now + super.copyHeaders(headersToCopy); } @Override - @SuppressWarnings("unchecked") - public void copyHeaders(@Nullable Map headersToCopy) { - if (headersToCopy != null) { - Map> nativeHeaders = getNativeHeaders(); - Map> map = (Map>) headersToCopy.get(NATIVE_HEADERS); - if (map != null) { - if (nativeHeaders != null) { - nativeHeaders.putAll(map); - } - else { - nativeHeaders = new LinkedMultiValueMap<>(map); - } - } - super.copyHeaders(headersToCopy); - setHeader(NATIVE_HEADERS, nativeHeaders); + public void copyHeadersIfAbsent(@Nullable Map headersToCopy) { + if (headersToCopy == null) { + return; + } + + @SuppressWarnings("unchecked") + Map> map = (Map>) headersToCopy.get(NATIVE_HEADERS); + if (map != null && getNativeHeaders() == null) { + map.forEach(this::setNativeHeaderValues); } + + super.copyHeadersIfAbsent(headersToCopy); } /** @@ -201,6 +209,30 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { } } + /** + * Variant of {@link #addNativeHeader(String, String)} for all values. + * @since 5.2.12 + */ + public void setNativeHeaderValues(String name, @Nullable List values) { + Assert.state(isMutable(), "Already immutable"); + Map> map = getNativeHeaders(); + if (values == null) { + if (map != null && map.get(name) != null) { + setModified(true); + map.remove(name); + } + return; + } + if (map == null) { + map = new LinkedMultiValueMap<>(3); + setHeader(NATIVE_HEADERS, map); + } + if (!ObjectUtils.nullSafeEquals(values, getHeader(name))) { + setModified(true); + map.put(name, new ArrayList<>(values)); + } + } + /** * Add the specified native header value to existing values. *

In order for this to work, the accessor must be {@link #isMutable() diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java index 726f80c256e..73eba06caa4 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java @@ -226,19 +226,46 @@ public class NativeMessageHeaderAccessorTests { @Test // gh-25821 void copyImmutableToMutable() { - NativeMessageHeaderAccessor source = new NativeMessageHeaderAccessor(); - source.addNativeHeader("foo", "bar"); - Message message = MessageBuilder.createMessage("payload", source.getMessageHeaders()); + NativeMessageHeaderAccessor sourceAccessor = new NativeMessageHeaderAccessor(); + sourceAccessor.addNativeHeader("foo", "bar"); + Message source = MessageBuilder.createMessage("payload", sourceAccessor.getMessageHeaders()); - NativeMessageHeaderAccessor target = new NativeMessageHeaderAccessor(); - target.copyHeaders(message.getHeaders()); - target.setLeaveMutable(true); - message = MessageBuilder.createMessage(message.getPayload(), target.getMessageHeaders()); + NativeMessageHeaderAccessor targetAccessor = new NativeMessageHeaderAccessor(); + targetAccessor.copyHeaders(source.getHeaders()); + targetAccessor.setLeaveMutable(true); + Message target = MessageBuilder.createMessage(source.getPayload(), targetAccessor.getMessageHeaders()); - MessageHeaderAccessor accessor = MessageHeaderAccessor.getMutableAccessor(message); + MessageHeaderAccessor accessor = MessageHeaderAccessor.getMutableAccessor(target); assertThat(accessor.isMutable()); ((NativeMessageHeaderAccessor) accessor).addNativeHeader("foo", "baz"); assertThat(((NativeMessageHeaderAccessor) accessor).getNativeHeader("foo")).containsExactly("bar", "baz"); } + @Test // gh-25821 + void copyIfAbsentImmutableToMutable() { + NativeMessageHeaderAccessor sourceAccessor = new NativeMessageHeaderAccessor(); + sourceAccessor.addNativeHeader("foo", "bar"); + Message source = MessageBuilder.createMessage("payload", sourceAccessor.getMessageHeaders()); + + MessageHeaderAccessor targetAccessor = new NativeMessageHeaderAccessor(); + targetAccessor.copyHeadersIfAbsent(source.getHeaders()); + targetAccessor.setLeaveMutable(true); + Message target = MessageBuilder.createMessage(source.getPayload(), targetAccessor.getMessageHeaders()); + + MessageHeaderAccessor accessor = MessageHeaderAccessor.getMutableAccessor(target); + assertThat(accessor.isMutable()); + ((NativeMessageHeaderAccessor) accessor).addNativeHeader("foo", "baz"); + assertThat(((NativeMessageHeaderAccessor) accessor).getNativeHeader("foo")).containsExactly("bar", "baz"); + } + + @Test // gh-26155 + void copySelf() { + NativeMessageHeaderAccessor accessor = new NativeMessageHeaderAccessor(); + accessor.addNativeHeader("foo", "bar"); + accessor.setHeader("otherHeader", "otherHeaderValue"); + accessor.setLeaveMutable(true); + + // Does not fail with ConcurrentModificationException + accessor.copyHeaders(accessor.getMessageHeaders()); + } }