Browse Source

Polishing contribution

See gh-36179
pull/36273/head
rstoyanchev 1 month ago
parent
commit
7fc619acdc
  1. 33
      spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java
  2. 30
      spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java
  3. 23
      spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java
  4. 20
      spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java

33
spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java

@ -20,6 +20,7 @@ import java.lang.annotation.Annotation; @@ -20,6 +20,7 @@ import java.lang.annotation.Annotation;
import java.security.Principal;
import java.util.Collections;
import java.util.Map;
import java.util.function.Predicate;
import org.jspecify.annotations.Nullable;
@ -45,7 +46,6 @@ import org.springframework.util.ObjectUtils; @@ -45,7 +46,6 @@ import org.springframework.util.ObjectUtils;
import org.springframework.util.PropertyPlaceholderHelper;
import org.springframework.util.PropertyPlaceholderHelper.PlaceholderResolver;
import org.springframework.util.StringUtils;
import java.util.function.Predicate;
/**
* A {@link HandlerMethodReturnValueHandler} for sending to destinations specified in a
@ -137,21 +137,20 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH @@ -137,21 +137,20 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
}
/**
* Add a filter to determine which headers from the input message should be propagated to the output message.
* Multiple filters are combined with logical OR.
* <p>If not set, no input headers are propagated (default behavior).</p>
* Add a filter to determine which headers from the input message should be
* propagated to the output message. Multiple filters are combined with
* {@link Predicate#or(Predicate)}.
* <p>By default, no headers are propagated if this is not set.
* @since 7.0.4
*/
public void addHeaderFilter(Predicate<String> filter) {
Assert.notNull(filter, "Filter predicate must not be null");
if (this.headerFilter == null) {
this.headerFilter = filter;
} else {
this.headerFilter = this.headerFilter.or(filter);
}
Assert.notNull(filter, "'headerFilter' predicate must not be null");
this.headerFilter = (this.headerFilter != null ? this.headerFilter.or(filter) : filter);
}
/**
* Return the configured header filter.
* @since 7.0.4
*/
public @Nullable Predicate<String> getHeaderFilter() {
return this.headerFilter;
@ -263,17 +262,13 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH @@ -263,17 +262,13 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
if (getHeaderInitializer() != null) {
getHeaderInitializer().initHeaders(headerAccessor);
}
if (inputMessage != null && headerFilter != null) {
Map<String, Object> inputHeaders = inputMessage.getHeaders();
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
String name = entry.getKey();
if (headerFilter.test(name)) {
headerAccessor.setHeader(name, entry.getValue());
if (inputMessage != null && this.headerFilter != null) {
inputMessage.getHeaders().forEach((name, value) -> {
if (this.headerFilter.test(name)) {
headerAccessor.setHeader(name, value);
}
}
});
}
if (sessionId != null) {
headerAccessor.setSessionId(sessionId);
}

30
spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java

@ -16,7 +16,6 @@ @@ -16,7 +16,6 @@
package org.springframework.messaging.simp.annotation.support;
import java.util.Map;
import java.util.function.Predicate;
import org.apache.commons.logging.Log;
@ -99,21 +98,20 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn @@ -99,21 +98,20 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
}
/**
* Add a filter to determine which headers from the input message should be propagated to the output message.
* Multiple filters are combined with logical OR.
* <p>If not set, no input headers are propagated (default behavior).</p>
* Add a filter to determine which headers from the input message should be
* propagated to the output message. Multiple filters are combined with
* {@link Predicate#or(Predicate)}.
* <p>By default, no headers are propagated if this is not set.
* @since 7.0.4
*/
public void addHeaderFilter(Predicate<String> filter) {
Assert.notNull(filter, "Filter predicate must not be null");
if (this.headerFilter == null) {
this.headerFilter = filter;
} else {
this.headerFilter = this.headerFilter.or(filter);
}
this.headerFilter = (this.headerFilter != null ? this.headerFilter.or(filter) : filter);
}
/**
* Return the configured header filter.
* @since 7.0.4
*/
public @Nullable Predicate<String> getHeaderFilter() {
return this.headerFilter;
@ -161,17 +159,13 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn @@ -161,17 +159,13 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
if (getHeaderInitializer() != null) {
getHeaderInitializer().initHeaders(accessor);
}
if (inputMessage != null && headerFilter != null) {
Map<String, Object> inputHeaders = inputMessage.getHeaders();
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
String name = entry.getKey();
if (headerFilter.test(name)) {
accessor.setHeader(name, entry.getValue());
if (inputMessage != null && this.headerFilter != null) {
inputMessage.getHeaders().forEach((name, value) -> {
if (this.headerFilter.test(name)) {
accessor.setHeader(name, value);
}
}
});
}
if (sessionId != null) {
accessor.setSessionId(sessionId);
}

23
spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java

@ -298,22 +298,21 @@ public class SendToMethodReturnValueHandlerTests { @@ -298,22 +298,21 @@ public class SendToMethodReturnValueHandlerTests {
given(this.messageChannel.send(any(Message.class))).willReturn(true);
String sessionId = "sess1";
String customHeaderName = "x-custom-header";
String customHeaderValue = "custom-value";
String headerName = "x-custom-header";
String headerValue = "custom-value";
Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null);
inputMessage = MessageBuilder.fromMessage(inputMessage)
.setHeader(customHeaderName, customHeaderValue)
.build();
inputMessage = MessageBuilder.fromMessage(inputMessage).setHeader(headerName, headerValue).build();
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
handler.addHeaderFilter(name -> name.equals(customHeaderName));
SimpMessagingTemplate template = new SimpMessagingTemplate(this.messageChannel);
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(template, true);
handler.addHeaderFilter(name -> name.equals(headerName));
handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);
verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());
for (Message<?> sent : this.messageCaptor.getAllValues()) {
MessageHeaders headers = sent.getHeaders();
assertThat(headers.get(customHeaderName)).isEqualTo(customHeaderValue);
assertThat(headers.get(headerName)).isEqualTo(headerValue);
}
}
@ -330,7 +329,8 @@ public class SendToMethodReturnValueHandlerTests { @@ -330,7 +329,8 @@ public class SendToMethodReturnValueHandlerTests {
.setHeader(headerB, "B-value")
.build();
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
SimpMessagingTemplate template = new SimpMessagingTemplate(this.messageChannel);
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(template, true);
handler.addHeaderFilter(name -> name.equals(headerA));
handler.addHeaderFilter(name -> name.equals(headerB));
@ -344,9 +344,8 @@ public class SendToMethodReturnValueHandlerTests { @@ -344,9 +344,8 @@ public class SendToMethodReturnValueHandlerTests {
}
}
private void assertResponse(MethodParameter methodParameter, String sessionId,
int index, String destination) {
private void assertResponse(
MethodParameter methodParameter, String sessionId, int index, String destination) {
SimpMessageHeaderAccessor accessor = getCapturedAccessor(index);
assertThat(accessor.getSessionId()).isEqualTo(sessionId);

20
spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java

@ -193,6 +193,7 @@ class SubscriptionMethodReturnValueHandlerTests { @@ -193,6 +193,7 @@ class SubscriptionMethodReturnValueHandlerTests {
String destination = "/dest";
String customHeaderName = "x-custom-header";
String customHeaderValue = "custom-value";
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
@ -200,18 +201,16 @@ class SubscriptionMethodReturnValueHandlerTests { @@ -200,18 +201,16 @@ class SubscriptionMethodReturnValueHandlerTests {
.setHeader(customHeaderName, customHeaderValue)
.build();
MessageSendingOperations messagingTemplate = mock();
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
MessageSendingOperations template = mock();
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(template);
handler.addHeaderFilter(name -> name.equals(customHeaderName));
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
verify(template).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
MessageHeaders sentHeaders = captor.getValue();
assertThat(sentHeaders.get(customHeaderName)).isEqualTo(customHeaderValue);
assertThat(captor.getValue().get(customHeaderName)).isEqualTo(customHeaderValue);
}
@Test
@ -221,6 +220,7 @@ class SubscriptionMethodReturnValueHandlerTests { @@ -221,6 +220,7 @@ class SubscriptionMethodReturnValueHandlerTests {
String destination = "/dest";
String headerA = "x-header-a";
String headerB = "x-header-b";
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
@ -229,23 +229,21 @@ class SubscriptionMethodReturnValueHandlerTests { @@ -229,23 +229,21 @@ class SubscriptionMethodReturnValueHandlerTests {
.setHeader(headerB, "B-value")
.build();
MessageSendingOperations messagingTemplate = mock();
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
MessageSendingOperations template = mock();
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(template);
handler.addHeaderFilter(name -> name.equals(headerA));
handler.addHeaderFilter(name -> name.equals(headerB));
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
verify(template).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
MessageHeaders sentHeaders = captor.getValue();
assertThat(sentHeaders.get(headerA)).isEqualTo("A-value");
assertThat(sentHeaders.get(headerB)).isEqualTo("B-value");
}
private Message<?> createInputMessage(String sessId, String subsId, String dest, Principal principal) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();
headers.setSessionId(sessId);

Loading…
Cancel
Save