Browse Source

Add header propagation predicate support to message return value handlers

See gh-36179

Signed-off-by: Junhwan Kim <musoyou1085@gmail.com>
pull/36273/head
김준환 2 months ago committed by rstoyanchev
parent
commit
e5f8c5b7ae
  1. 43
      spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java
  2. 41
      spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java
  3. 53
      spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java
  4. 59
      spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java

43
spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java

@ -45,6 +45,7 @@ import org.springframework.util.ObjectUtils; @@ -45,6 +45,7 @@ import org.springframework.util.ObjectUtils;
import org.springframework.util.PropertyPlaceholderHelper;
import org.springframework.util.PropertyPlaceholderHelper.PlaceholderResolver;
import org.springframework.util.StringUtils;
import java.util.function.Predicate;
/**
* A {@link HandlerMethodReturnValueHandler} for sending to destinations specified in a
@ -73,6 +74,8 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH @@ -73,6 +74,8 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
private @Nullable MessageHeaderInitializer headerInitializer;
private @Nullable Predicate<String> headerFilter;
public SendToMethodReturnValueHandler(SimpMessageSendingOperations messagingTemplate, boolean annotationRequired) {
Assert.notNull(messagingTemplate, "'messagingTemplate' must not be null");
@ -133,6 +136,27 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH @@ -133,6 +136,27 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
return this.headerInitializer;
}
/**
* Add a filter to determine which headers from the input message should be propagated to the output message.
* Multiple filters are combined with logical OR.
* <p>If not set, no input headers are propagated (default behavior).</p>
*/
public void addHeaderFilter(Predicate<String> filter) {
Assert.notNull(filter, "Filter predicate must not be null");
if (this.headerFilter == null) {
this.headerFilter = filter;
} else {
this.headerFilter = this.headerFilter.or(filter);
}
}
/**
* Return the configured header filter.
*/
public @Nullable Predicate<String> getHeaderFilter() {
return this.headerFilter;
}
@Override
public boolean supportsReturnType(MethodParameter returnType) {
@ -171,11 +195,11 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH @@ -171,11 +195,11 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
destination = destinationHelper.expandTemplateVars(destination);
if (broadcast) {
this.messagingTemplate.convertAndSendToUser(
user, destination, returnValue, createHeaders(null, returnType));
user, destination, returnValue, createHeaders(null, returnType, message));
}
else {
this.messagingTemplate.convertAndSendToUser(
user, destination, returnValue, createHeaders(sessionId, returnType));
user, destination, returnValue, createHeaders(sessionId, returnType, message));
}
}
}
@ -185,7 +209,7 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH @@ -185,7 +209,7 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
String[] destinations = getTargetDestinations(sendTo, message, this.defaultDestinationPrefix);
for (String destination : destinations) {
destination = destinationHelper.expandTemplateVars(destination);
this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, returnType));
this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, returnType, message));
}
}
}
@ -234,11 +258,22 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH @@ -234,11 +258,22 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
new String[] {defaultPrefix + destination} : new String[] {defaultPrefix + '/' + destination});
}
private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter returnType) {
private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter returnType, @Nullable Message<?> inputMessage) {
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
if (getHeaderInitializer() != null) {
getHeaderInitializer().initHeaders(headerAccessor);
}
if (inputMessage != null && headerFilter != null) {
Map<String, Object> inputHeaders = inputMessage.getHeaders();
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
String name = entry.getKey();
if (headerFilter.test(name)) {
headerAccessor.setHeader(name, entry.getValue());
}
}
}
if (sessionId != null) {
headerAccessor.setSessionId(sessionId);
}

41
spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java

@ -16,6 +16,9 @@ @@ -16,6 +16,9 @@
package org.springframework.messaging.simp.annotation.support;
import java.util.Map;
import java.util.function.Predicate;
import org.apache.commons.logging.Log;
import org.jspecify.annotations.Nullable;
@ -65,6 +68,8 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn @@ -65,6 +68,8 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
private @Nullable MessageHeaderInitializer headerInitializer;
private @Nullable Predicate<String> headerFilter;
/**
* Construct a new SubscriptionMethodReturnValueHandler.
@ -93,6 +98,27 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn @@ -93,6 +98,27 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
return this.headerInitializer;
}
/**
* Add a filter to determine which headers from the input message should be propagated to the output message.
* Multiple filters are combined with logical OR.
* <p>If not set, no input headers are propagated (default behavior).</p>
*/
public void addHeaderFilter(Predicate<String> filter) {
Assert.notNull(filter, "Filter predicate must not be null");
if (this.headerFilter == null) {
this.headerFilter = filter;
} else {
this.headerFilter = this.headerFilter.or(filter);
}
}
/**
* Return the configured header filter.
*/
public @Nullable Predicate<String> getHeaderFilter() {
return this.headerFilter;
}
@Override
public boolean supportsReturnType(MethodParameter returnType) {
@ -126,15 +152,26 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn @@ -126,15 +152,26 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
if (logger.isDebugEnabled()) {
logger.debug("Reply to @SubscribeMapping: " + returnValue);
}
MessageHeaders headersToSend = createHeaders(sessionId, subscriptionId, returnType);
MessageHeaders headersToSend = createHeaders(sessionId, subscriptionId, returnType, message);
this.messagingTemplate.convertAndSend(destination, returnValue, headersToSend);
}
private MessageHeaders createHeaders(@Nullable String sessionId, String subscriptionId, MethodParameter returnType) {
private MessageHeaders createHeaders(@Nullable String sessionId, String subscriptionId, MethodParameter returnType, @Nullable Message<?> inputMessage) {
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
if (getHeaderInitializer() != null) {
getHeaderInitializer().initHeaders(accessor);
}
if (inputMessage != null && headerFilter != null) {
Map<String, Object> inputHeaders = inputMessage.getHeaders();
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
String name = entry.getKey();
if (headerFilter.test(name)) {
accessor.setHeader(name, entry.getValue());
}
}
}
if (sessionId != null) {
accessor.setSessionId(sessionId);
}

53
spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java

@ -293,9 +293,60 @@ public class SendToMethodReturnValueHandlerTests { @@ -293,9 +293,60 @@ public class SendToMethodReturnValueHandlerTests {
assertResponse(parameter, sessionId, 1, "/dest4");
}
@Test
void sendToWithHeaderFilterSinglePredicate() throws Exception {
given(this.messageChannel.send(any(Message.class))).willReturn(true);
String sessionId = "sess1";
String customHeaderName = "x-custom-header";
String customHeaderValue = "custom-value";
Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null);
inputMessage = MessageBuilder.fromMessage(inputMessage)
.setHeader(customHeaderName, customHeaderValue)
.build();
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
handler.addHeaderFilter(name -> name.equals(customHeaderName));
handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);
verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());
for (Message<?> sent : this.messageCaptor.getAllValues()) {
MessageHeaders headers = sent.getHeaders();
assertThat(headers.get(customHeaderName)).isEqualTo(customHeaderValue);
}
}
@Test
void sendToWithHeaderFilterMultiplePredicates() throws Exception {
given(this.messageChannel.send(any(Message.class))).willReturn(true);
String sessionId = "sess1";
String headerA = "x-header-a";
String headerB = "x-header-b";
Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null);
inputMessage = MessageBuilder.fromMessage(inputMessage)
.setHeader(headerA, "A-value")
.setHeader(headerB, "B-value")
.build();
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
handler.addHeaderFilter(name -> name.equals(headerA));
handler.addHeaderFilter(name -> name.equals(headerB));
handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);
verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());
for (Message<?> sent : this.messageCaptor.getAllValues()) {
MessageHeaders headers = sent.getHeaders();
assertThat(headers.get(headerA)).isEqualTo("A-value");
assertThat(headers.get(headerB)).isEqualTo("B-value");
}
}
private void assertResponse(MethodParameter methodParameter, String sessionId,
int index, String destination) {
int index, String destination) {
SimpMessageHeaderAccessor accessor = getCapturedAccessor(index);
assertThat(accessor.getSessionId()).isEqualTo(sessionId);

59
spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java

@ -186,6 +186,65 @@ class SubscriptionMethodReturnValueHandlerTests { @@ -186,6 +186,65 @@ class SubscriptionMethodReturnValueHandlerTests {
assertThat(new String((byte[]) message.getPayload(), StandardCharsets.UTF_8)).isEqualTo("{\"withView1\":\"with\"}");
}
@Test
void testHeaderFilterSinglePredicate() throws Exception {
String sessionId = "sess1";
String subscriptionId = "subs1";
String destination = "/dest";
String customHeaderName = "x-custom-header";
String customHeaderValue = "custom-value";
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination)
.setHeader(customHeaderName, customHeaderValue)
.build();
MessageSendingOperations messagingTemplate = mock();
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
handler.addHeaderFilter(name -> name.equals(customHeaderName));
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
MessageHeaders sentHeaders = captor.getValue();
assertThat(sentHeaders.get(customHeaderName)).isEqualTo(customHeaderValue);
}
@Test
void testHeaderFilterMultiplePredicates() throws Exception {
String sessionId = "sess1";
String subscriptionId = "subs1";
String destination = "/dest";
String headerA = "x-header-a";
String headerB = "x-header-b";
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination)
.setHeader(headerA, "A-value")
.setHeader(headerB, "B-value")
.build();
MessageSendingOperations messagingTemplate = mock();
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
handler.addHeaderFilter(name -> name.equals(headerA));
handler.addHeaderFilter(name -> name.equals(headerB));
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
MessageHeaders sentHeaders = captor.getValue();
assertThat(sentHeaders.get(headerA)).isEqualTo("A-value");
assertThat(sentHeaders.get(headerB)).isEqualTo("B-value");
}
private Message<?> createInputMessage(String sessId, String subsId, String dest, Principal principal) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();

Loading…
Cancel
Save