diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java
index f391a82b34d..257cb2e7939 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java
@@ -20,6 +20,7 @@ import java.lang.annotation.Annotation;
import java.security.Principal;
import java.util.Collections;
import java.util.Map;
+import java.util.function.Predicate;
import org.jspecify.annotations.Nullable;
@@ -45,7 +46,6 @@ import org.springframework.util.ObjectUtils;
import org.springframework.util.PropertyPlaceholderHelper;
import org.springframework.util.PropertyPlaceholderHelper.PlaceholderResolver;
import org.springframework.util.StringUtils;
-import java.util.function.Predicate;
/**
* A {@link HandlerMethodReturnValueHandler} for sending to destinations specified in a
@@ -137,21 +137,20 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
}
/**
- * Add a filter to determine which headers from the input message should be propagated to the output message.
- * Multiple filters are combined with logical OR.
- *
If not set, no input headers are propagated (default behavior).
+ * Add a filter to determine which headers from the input message should be
+ * propagated to the output message. Multiple filters are combined with
+ * {@link Predicate#or(Predicate)}.
+ * By default, no headers are propagated if this is not set.
+ * @since 7.0.4
*/
public void addHeaderFilter(Predicate filter) {
- Assert.notNull(filter, "Filter predicate must not be null");
- if (this.headerFilter == null) {
- this.headerFilter = filter;
- } else {
- this.headerFilter = this.headerFilter.or(filter);
- }
+ Assert.notNull(filter, "'headerFilter' predicate must not be null");
+ this.headerFilter = (this.headerFilter != null ? this.headerFilter.or(filter) : filter);
}
/**
* Return the configured header filter.
+ * @since 7.0.4
*/
public @Nullable Predicate getHeaderFilter() {
return this.headerFilter;
@@ -263,17 +262,13 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
if (getHeaderInitializer() != null) {
getHeaderInitializer().initHeaders(headerAccessor);
}
-
- if (inputMessage != null && headerFilter != null) {
- Map inputHeaders = inputMessage.getHeaders();
- for (Map.Entry entry : inputHeaders.entrySet()) {
- String name = entry.getKey();
- if (headerFilter.test(name)) {
- headerAccessor.setHeader(name, entry.getValue());
+ if (inputMessage != null && this.headerFilter != null) {
+ inputMessage.getHeaders().forEach((name, value) -> {
+ if (this.headerFilter.test(name)) {
+ headerAccessor.setHeader(name, value);
}
- }
+ });
}
-
if (sessionId != null) {
headerAccessor.setSessionId(sessionId);
}
diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java
index d487e0c1af3..c7ba0a7de47 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java
@@ -16,7 +16,6 @@
package org.springframework.messaging.simp.annotation.support;
-import java.util.Map;
import java.util.function.Predicate;
import org.apache.commons.logging.Log;
@@ -99,21 +98,20 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
}
/**
- * Add a filter to determine which headers from the input message should be propagated to the output message.
- * Multiple filters are combined with logical OR.
- * If not set, no input headers are propagated (default behavior).
+ * Add a filter to determine which headers from the input message should be
+ * propagated to the output message. Multiple filters are combined with
+ * {@link Predicate#or(Predicate)}.
+ * By default, no headers are propagated if this is not set.
+ * @since 7.0.4
*/
public void addHeaderFilter(Predicate filter) {
Assert.notNull(filter, "Filter predicate must not be null");
- if (this.headerFilter == null) {
- this.headerFilter = filter;
- } else {
- this.headerFilter = this.headerFilter.or(filter);
- }
+ this.headerFilter = (this.headerFilter != null ? this.headerFilter.or(filter) : filter);
}
/**
* Return the configured header filter.
+ * @since 7.0.4
*/
public @Nullable Predicate getHeaderFilter() {
return this.headerFilter;
@@ -161,17 +159,13 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
if (getHeaderInitializer() != null) {
getHeaderInitializer().initHeaders(accessor);
}
-
- if (inputMessage != null && headerFilter != null) {
- Map inputHeaders = inputMessage.getHeaders();
- for (Map.Entry entry : inputHeaders.entrySet()) {
- String name = entry.getKey();
- if (headerFilter.test(name)) {
- accessor.setHeader(name, entry.getValue());
+ if (inputMessage != null && this.headerFilter != null) {
+ inputMessage.getHeaders().forEach((name, value) -> {
+ if (this.headerFilter.test(name)) {
+ accessor.setHeader(name, value);
}
- }
+ });
}
-
if (sessionId != null) {
accessor.setSessionId(sessionId);
}
diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java
index 1da2f9b0251..86ded07f1af 100644
--- a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java
+++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java
@@ -298,22 +298,21 @@ public class SendToMethodReturnValueHandlerTests {
given(this.messageChannel.send(any(Message.class))).willReturn(true);
String sessionId = "sess1";
- String customHeaderName = "x-custom-header";
- String customHeaderValue = "custom-value";
+ String headerName = "x-custom-header";
+ String headerValue = "custom-value";
Message> inputMessage = createMessage(sessionId, "sub1", null, null, null);
- inputMessage = MessageBuilder.fromMessage(inputMessage)
- .setHeader(customHeaderName, customHeaderValue)
- .build();
+ inputMessage = MessageBuilder.fromMessage(inputMessage).setHeader(headerName, headerValue).build();
- SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
- handler.addHeaderFilter(name -> name.equals(customHeaderName));
+ SimpMessagingTemplate template = new SimpMessagingTemplate(this.messageChannel);
+ SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(template, true);
+ handler.addHeaderFilter(name -> name.equals(headerName));
handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);
verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());
for (Message> sent : this.messageCaptor.getAllValues()) {
MessageHeaders headers = sent.getHeaders();
- assertThat(headers.get(customHeaderName)).isEqualTo(customHeaderValue);
+ assertThat(headers.get(headerName)).isEqualTo(headerValue);
}
}
@@ -330,7 +329,8 @@ public class SendToMethodReturnValueHandlerTests {
.setHeader(headerB, "B-value")
.build();
- SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
+ SimpMessagingTemplate template = new SimpMessagingTemplate(this.messageChannel);
+ SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(template, true);
handler.addHeaderFilter(name -> name.equals(headerA));
handler.addHeaderFilter(name -> name.equals(headerB));
@@ -344,9 +344,8 @@ public class SendToMethodReturnValueHandlerTests {
}
}
-
- private void assertResponse(MethodParameter methodParameter, String sessionId,
- int index, String destination) {
+ private void assertResponse(
+ MethodParameter methodParameter, String sessionId, int index, String destination) {
SimpMessageHeaderAccessor accessor = getCapturedAccessor(index);
assertThat(accessor.getSessionId()).isEqualTo(sessionId);
diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java
index a44d9354240..b795b4e1636 100644
--- a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java
+++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java
@@ -193,6 +193,7 @@ class SubscriptionMethodReturnValueHandlerTests {
String destination = "/dest";
String customHeaderName = "x-custom-header";
String customHeaderValue = "custom-value";
+
Message> inputMessage = MessageBuilder.withPayload(PAYLOAD)
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
@@ -200,18 +201,16 @@ class SubscriptionMethodReturnValueHandlerTests {
.setHeader(customHeaderName, customHeaderValue)
.build();
- MessageSendingOperations messagingTemplate = mock();
- SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
-
+ MessageSendingOperations template = mock();
+ SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(template);
handler.addHeaderFilter(name -> name.equals(customHeaderName));
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
ArgumentCaptor captor = ArgumentCaptor.forClass(MessageHeaders.class);
- verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
+ verify(template).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
- MessageHeaders sentHeaders = captor.getValue();
- assertThat(sentHeaders.get(customHeaderName)).isEqualTo(customHeaderValue);
+ assertThat(captor.getValue().get(customHeaderName)).isEqualTo(customHeaderValue);
}
@Test
@@ -221,6 +220,7 @@ class SubscriptionMethodReturnValueHandlerTests {
String destination = "/dest";
String headerA = "x-header-a";
String headerB = "x-header-b";
+
Message> inputMessage = MessageBuilder.withPayload(PAYLOAD)
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
@@ -229,23 +229,21 @@ class SubscriptionMethodReturnValueHandlerTests {
.setHeader(headerB, "B-value")
.build();
- MessageSendingOperations messagingTemplate = mock();
- SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
-
+ MessageSendingOperations template = mock();
+ SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(template);
handler.addHeaderFilter(name -> name.equals(headerA));
handler.addHeaderFilter(name -> name.equals(headerB));
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
ArgumentCaptor captor = ArgumentCaptor.forClass(MessageHeaders.class);
- verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
+ verify(template).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
MessageHeaders sentHeaders = captor.getValue();
assertThat(sentHeaders.get(headerA)).isEqualTo("A-value");
assertThat(sentHeaders.get(headerB)).isEqualTo("B-value");
}
-
private Message> createInputMessage(String sessId, String subsId, String dest, Principal principal) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();
headers.setSessionId(sessId);