Browse Source

MessageHeaderAccessor handle self-copy correctly

1. Revert changes in setHeader from 5.2.9 that caused regression on self-copy.
2. Update copyHeaders/IfAbsent to ensure a copy of native headers.
3. Exit if source and target are the same instance, as an optimization.

Closes gh-26155
pull/26273/head
Rossen Stoyanchev 5 years ago
parent
commit
3703be5aaf
  1. 26
      spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java
  2. 72
      spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java
  3. 43
      spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java

26
spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java

@ -378,13 +378,14 @@ public class MessageHeaderAccessor {
* {@link #copyHeadersIfAbsent(Map)} to avoid overwriting values. * {@link #copyHeadersIfAbsent(Map)} to avoid overwriting values.
*/ */
public void copyHeaders(@Nullable Map<String, ?> headersToCopy) { public void copyHeaders(@Nullable Map<String, ?> headersToCopy) {
if (headersToCopy != null) { if (headersToCopy == null || this.headers == headersToCopy) {
headersToCopy.forEach((key, value) -> { return;
if (!isReadOnly(key)) {
setHeader(key, value);
}
});
} }
headersToCopy.forEach((key, value) -> {
if (!isReadOnly(key)) {
setHeader(key, value);
}
});
} }
/** /**
@ -392,13 +393,14 @@ public class MessageHeaderAccessor {
* <p>This operation will <em>not</em> overwrite any existing values. * <p>This operation will <em>not</em> overwrite any existing values.
*/ */
public void copyHeadersIfAbsent(@Nullable Map<String, ?> headersToCopy) { public void copyHeadersIfAbsent(@Nullable Map<String, ?> headersToCopy) {
if (headersToCopy != null) { if (headersToCopy == null || this.headers == headersToCopy) {
headersToCopy.forEach((key, value) -> { return;
if (!isReadOnly(key)) {
setHeaderIfAbsent(key, value);
}
});
} }
headersToCopy.forEach((key, value) -> {
if (!isReadOnly(key)) {
setHeaderIfAbsent(key, value);
}
});
} }
protected boolean isReadOnly(String headerName) { protected boolean isReadOnly(String headerName) {

72
spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java

@ -16,6 +16,7 @@
package org.springframework.messaging.support; package org.springframework.messaging.support;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@ -75,6 +76,8 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Map<String, List<String>> map = (Map<String, List<String>>) getHeader(NATIVE_HEADERS); Map<String, List<String>> map = (Map<String, List<String>>) getHeader(NATIVE_HEADERS);
if (map != null) { if (map != null) {
// setHeader checks for equality but we need copy of native headers
setHeader(NATIVE_HEADERS, null);
setHeader(NATIVE_HEADERS, new LinkedMultiValueMap<>(map)); setHeader(NATIVE_HEADERS, new LinkedMultiValueMap<>(map));
} }
} }
@ -103,6 +106,8 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor {
if (isMutable()) { if (isMutable()) {
Map<String, List<String>> map = getNativeHeaders(); Map<String, List<String>> map = getNativeHeaders();
if (map != null) { if (map != null) {
// setHeader checks for equality but we need immutable wrapper
setHeader(NATIVE_HEADERS, null);
setHeader(NATIVE_HEADERS, Collections.unmodifiableMap(map)); setHeader(NATIVE_HEADERS, Collections.unmodifiableMap(map));
} }
super.setImmutable(); super.setImmutable();
@ -110,31 +115,34 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor {
} }
@Override @Override
public void setHeader(String name, @Nullable Object value) { public void copyHeaders(@Nullable Map<String, ?> headersToCopy) {
if (name.equalsIgnoreCase(NATIVE_HEADERS)) { if (headersToCopy == null) {
// Force removal since setHeader checks for equality return;
super.setHeader(NATIVE_HEADERS, null); }
@SuppressWarnings("unchecked")
Map<String, List<String>> map = (Map<String, List<String>>) headersToCopy.get(NATIVE_HEADERS);
if (map != null && map != getNativeHeaders()) {
map.forEach(this::setNativeHeaderValues);
} }
super.setHeader(name, value);
// setHeader checks for equality, native headers should be equal by now
super.copyHeaders(headersToCopy);
} }
@Override @Override
@SuppressWarnings("unchecked") public void copyHeadersIfAbsent(@Nullable Map<String, ?> headersToCopy) {
public void copyHeaders(@Nullable Map<String, ?> headersToCopy) { if (headersToCopy == null) {
if (headersToCopy != null) { return;
Map<String, List<String>> nativeHeaders = getNativeHeaders(); }
Map<String, List<String>> map = (Map<String, List<String>>) headersToCopy.get(NATIVE_HEADERS);
if (map != null) { @SuppressWarnings("unchecked")
if (nativeHeaders != null) { Map<String, List<String>> map = (Map<String, List<String>>) headersToCopy.get(NATIVE_HEADERS);
nativeHeaders.putAll(map); if (map != null && getNativeHeaders() == null) {
} map.forEach(this::setNativeHeaderValues);
else {
nativeHeaders = new LinkedMultiValueMap<>(map);
}
}
super.copyHeaders(headersToCopy);
setHeader(NATIVE_HEADERS, nativeHeaders);
} }
super.copyHeadersIfAbsent(headersToCopy);
} }
/** /**
@ -201,6 +209,30 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor {
} }
} }
/**
* Variant of {@link #addNativeHeader(String, String)} for all values.
* @since 5.2.12
*/
public void setNativeHeaderValues(String name, @Nullable List<String> values) {
Assert.state(isMutable(), "Already immutable");
Map<String, List<String>> map = getNativeHeaders();
if (values == null) {
if (map != null && map.get(name) != null) {
setModified(true);
map.remove(name);
}
return;
}
if (map == null) {
map = new LinkedMultiValueMap<>(3);
setHeader(NATIVE_HEADERS, map);
}
if (!ObjectUtils.nullSafeEquals(values, getHeader(name))) {
setModified(true);
map.put(name, new ArrayList<>(values));
}
}
/** /**
* Add the specified native header value to existing values. * Add the specified native header value to existing values.
* <p>In order for this to work, the accessor must be {@link #isMutable() * <p>In order for this to work, the accessor must be {@link #isMutable()

43
spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java

@ -226,19 +226,46 @@ public class NativeMessageHeaderAccessorTests {
@Test // gh-25821 @Test // gh-25821
void copyImmutableToMutable() { void copyImmutableToMutable() {
NativeMessageHeaderAccessor source = new NativeMessageHeaderAccessor(); NativeMessageHeaderAccessor sourceAccessor = new NativeMessageHeaderAccessor();
source.addNativeHeader("foo", "bar"); sourceAccessor.addNativeHeader("foo", "bar");
Message<String> message = MessageBuilder.createMessage("payload", source.getMessageHeaders()); Message<String> source = MessageBuilder.createMessage("payload", sourceAccessor.getMessageHeaders());
NativeMessageHeaderAccessor target = new NativeMessageHeaderAccessor(); NativeMessageHeaderAccessor targetAccessor = new NativeMessageHeaderAccessor();
target.copyHeaders(message.getHeaders()); targetAccessor.copyHeaders(source.getHeaders());
target.setLeaveMutable(true); targetAccessor.setLeaveMutable(true);
message = MessageBuilder.createMessage(message.getPayload(), target.getMessageHeaders()); Message<?> target = MessageBuilder.createMessage(source.getPayload(), targetAccessor.getMessageHeaders());
MessageHeaderAccessor accessor = MessageHeaderAccessor.getMutableAccessor(message); MessageHeaderAccessor accessor = MessageHeaderAccessor.getMutableAccessor(target);
assertThat(accessor.isMutable()); assertThat(accessor.isMutable());
((NativeMessageHeaderAccessor) accessor).addNativeHeader("foo", "baz"); ((NativeMessageHeaderAccessor) accessor).addNativeHeader("foo", "baz");
assertThat(((NativeMessageHeaderAccessor) accessor).getNativeHeader("foo")).containsExactly("bar", "baz"); assertThat(((NativeMessageHeaderAccessor) accessor).getNativeHeader("foo")).containsExactly("bar", "baz");
} }
@Test // gh-25821
void copyIfAbsentImmutableToMutable() {
NativeMessageHeaderAccessor sourceAccessor = new NativeMessageHeaderAccessor();
sourceAccessor.addNativeHeader("foo", "bar");
Message<String> source = MessageBuilder.createMessage("payload", sourceAccessor.getMessageHeaders());
MessageHeaderAccessor targetAccessor = new NativeMessageHeaderAccessor();
targetAccessor.copyHeadersIfAbsent(source.getHeaders());
targetAccessor.setLeaveMutable(true);
Message<?> target = MessageBuilder.createMessage(source.getPayload(), targetAccessor.getMessageHeaders());
MessageHeaderAccessor accessor = MessageHeaderAccessor.getMutableAccessor(target);
assertThat(accessor.isMutable());
((NativeMessageHeaderAccessor) accessor).addNativeHeader("foo", "baz");
assertThat(((NativeMessageHeaderAccessor) accessor).getNativeHeader("foo")).containsExactly("bar", "baz");
}
@Test // gh-26155
void copySelf() {
NativeMessageHeaderAccessor accessor = new NativeMessageHeaderAccessor();
accessor.addNativeHeader("foo", "bar");
accessor.setHeader("otherHeader", "otherHeaderValue");
accessor.setLeaveMutable(true);
// Does not fail with ConcurrentModificationException
accessor.copyHeaders(accessor.getMessageHeaders());
}
} }

Loading…
Cancel
Save