Browse Source

Polishing contribution

See gh-36179
pull/36273/head
rstoyanchev 1 month ago
parent
commit
7fc619acdc
  1. 33
      spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java
  2. 30
      spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java
  3. 23
      spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java
  4. 20
      spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java

33
spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java

@ -20,6 +20,7 @@ import java.lang.annotation.Annotation;
import java.security.Principal; import java.security.Principal;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
import java.util.function.Predicate;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
@ -45,7 +46,6 @@ import org.springframework.util.ObjectUtils;
import org.springframework.util.PropertyPlaceholderHelper; import org.springframework.util.PropertyPlaceholderHelper;
import org.springframework.util.PropertyPlaceholderHelper.PlaceholderResolver; import org.springframework.util.PropertyPlaceholderHelper.PlaceholderResolver;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import java.util.function.Predicate;
/** /**
* A {@link HandlerMethodReturnValueHandler} for sending to destinations specified in a * A {@link HandlerMethodReturnValueHandler} for sending to destinations specified in a
@ -137,21 +137,20 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
} }
/** /**
* Add a filter to determine which headers from the input message should be propagated to the output message. * Add a filter to determine which headers from the input message should be
* Multiple filters are combined with logical OR. * propagated to the output message. Multiple filters are combined with
* <p>If not set, no input headers are propagated (default behavior).</p> * {@link Predicate#or(Predicate)}.
* <p>By default, no headers are propagated if this is not set.
* @since 7.0.4
*/ */
public void addHeaderFilter(Predicate<String> filter) { public void addHeaderFilter(Predicate<String> filter) {
Assert.notNull(filter, "Filter predicate must not be null"); Assert.notNull(filter, "'headerFilter' predicate must not be null");
if (this.headerFilter == null) { this.headerFilter = (this.headerFilter != null ? this.headerFilter.or(filter) : filter);
this.headerFilter = filter;
} else {
this.headerFilter = this.headerFilter.or(filter);
}
} }
/** /**
* Return the configured header filter. * Return the configured header filter.
* @since 7.0.4
*/ */
public @Nullable Predicate<String> getHeaderFilter() { public @Nullable Predicate<String> getHeaderFilter() {
return this.headerFilter; return this.headerFilter;
@ -263,17 +262,13 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
if (getHeaderInitializer() != null) { if (getHeaderInitializer() != null) {
getHeaderInitializer().initHeaders(headerAccessor); getHeaderInitializer().initHeaders(headerAccessor);
} }
if (inputMessage != null && this.headerFilter != null) {
if (inputMessage != null && headerFilter != null) { inputMessage.getHeaders().forEach((name, value) -> {
Map<String, Object> inputHeaders = inputMessage.getHeaders(); if (this.headerFilter.test(name)) {
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) { headerAccessor.setHeader(name, value);
String name = entry.getKey();
if (headerFilter.test(name)) {
headerAccessor.setHeader(name, entry.getValue());
} }
} });
} }
if (sessionId != null) { if (sessionId != null) {
headerAccessor.setSessionId(sessionId); headerAccessor.setSessionId(sessionId);
} }

30
spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java

@ -16,7 +16,6 @@
package org.springframework.messaging.simp.annotation.support; package org.springframework.messaging.simp.annotation.support;
import java.util.Map;
import java.util.function.Predicate; import java.util.function.Predicate;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
@ -99,21 +98,20 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
} }
/** /**
* Add a filter to determine which headers from the input message should be propagated to the output message. * Add a filter to determine which headers from the input message should be
* Multiple filters are combined with logical OR. * propagated to the output message. Multiple filters are combined with
* <p>If not set, no input headers are propagated (default behavior).</p> * {@link Predicate#or(Predicate)}.
* <p>By default, no headers are propagated if this is not set.
* @since 7.0.4
*/ */
public void addHeaderFilter(Predicate<String> filter) { public void addHeaderFilter(Predicate<String> filter) {
Assert.notNull(filter, "Filter predicate must not be null"); Assert.notNull(filter, "Filter predicate must not be null");
if (this.headerFilter == null) { this.headerFilter = (this.headerFilter != null ? this.headerFilter.or(filter) : filter);
this.headerFilter = filter;
} else {
this.headerFilter = this.headerFilter.or(filter);
}
} }
/** /**
* Return the configured header filter. * Return the configured header filter.
* @since 7.0.4
*/ */
public @Nullable Predicate<String> getHeaderFilter() { public @Nullable Predicate<String> getHeaderFilter() {
return this.headerFilter; return this.headerFilter;
@ -161,17 +159,13 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
if (getHeaderInitializer() != null) { if (getHeaderInitializer() != null) {
getHeaderInitializer().initHeaders(accessor); getHeaderInitializer().initHeaders(accessor);
} }
if (inputMessage != null && this.headerFilter != null) {
if (inputMessage != null && headerFilter != null) { inputMessage.getHeaders().forEach((name, value) -> {
Map<String, Object> inputHeaders = inputMessage.getHeaders(); if (this.headerFilter.test(name)) {
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) { accessor.setHeader(name, value);
String name = entry.getKey();
if (headerFilter.test(name)) {
accessor.setHeader(name, entry.getValue());
} }
} });
} }
if (sessionId != null) { if (sessionId != null) {
accessor.setSessionId(sessionId); accessor.setSessionId(sessionId);
} }

23
spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java

@ -298,22 +298,21 @@ public class SendToMethodReturnValueHandlerTests {
given(this.messageChannel.send(any(Message.class))).willReturn(true); given(this.messageChannel.send(any(Message.class))).willReturn(true);
String sessionId = "sess1"; String sessionId = "sess1";
String customHeaderName = "x-custom-header"; String headerName = "x-custom-header";
String customHeaderValue = "custom-value"; String headerValue = "custom-value";
Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null); Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null);
inputMessage = MessageBuilder.fromMessage(inputMessage) inputMessage = MessageBuilder.fromMessage(inputMessage).setHeader(headerName, headerValue).build();
.setHeader(customHeaderName, customHeaderValue)
.build();
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true); SimpMessagingTemplate template = new SimpMessagingTemplate(this.messageChannel);
handler.addHeaderFilter(name -> name.equals(customHeaderName)); SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(template, true);
handler.addHeaderFilter(name -> name.equals(headerName));
handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage); handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);
verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());
for (Message<?> sent : this.messageCaptor.getAllValues()) { for (Message<?> sent : this.messageCaptor.getAllValues()) {
MessageHeaders headers = sent.getHeaders(); MessageHeaders headers = sent.getHeaders();
assertThat(headers.get(customHeaderName)).isEqualTo(customHeaderValue); assertThat(headers.get(headerName)).isEqualTo(headerValue);
} }
} }
@ -330,7 +329,8 @@ public class SendToMethodReturnValueHandlerTests {
.setHeader(headerB, "B-value") .setHeader(headerB, "B-value")
.build(); .build();
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true); SimpMessagingTemplate template = new SimpMessagingTemplate(this.messageChannel);
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(template, true);
handler.addHeaderFilter(name -> name.equals(headerA)); handler.addHeaderFilter(name -> name.equals(headerA));
handler.addHeaderFilter(name -> name.equals(headerB)); handler.addHeaderFilter(name -> name.equals(headerB));
@ -344,9 +344,8 @@ public class SendToMethodReturnValueHandlerTests {
} }
} }
private void assertResponse(
private void assertResponse(MethodParameter methodParameter, String sessionId, MethodParameter methodParameter, String sessionId, int index, String destination) {
int index, String destination) {
SimpMessageHeaderAccessor accessor = getCapturedAccessor(index); SimpMessageHeaderAccessor accessor = getCapturedAccessor(index);
assertThat(accessor.getSessionId()).isEqualTo(sessionId); assertThat(accessor.getSessionId()).isEqualTo(sessionId);

20
spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java

@ -193,6 +193,7 @@ class SubscriptionMethodReturnValueHandlerTests {
String destination = "/dest"; String destination = "/dest";
String customHeaderName = "x-custom-header"; String customHeaderName = "x-custom-header";
String customHeaderValue = "custom-value"; String customHeaderValue = "custom-value";
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD) Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId) .setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId) .setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
@ -200,18 +201,16 @@ class SubscriptionMethodReturnValueHandlerTests {
.setHeader(customHeaderName, customHeaderValue) .setHeader(customHeaderName, customHeaderValue)
.build(); .build();
MessageSendingOperations messagingTemplate = mock(); MessageSendingOperations template = mock();
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate); SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(template);
handler.addHeaderFilter(name -> name.equals(customHeaderName)); handler.addHeaderFilter(name -> name.equals(customHeaderName));
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage); handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class); ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture()); verify(template).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
MessageHeaders sentHeaders = captor.getValue(); assertThat(captor.getValue().get(customHeaderName)).isEqualTo(customHeaderValue);
assertThat(sentHeaders.get(customHeaderName)).isEqualTo(customHeaderValue);
} }
@Test @Test
@ -221,6 +220,7 @@ class SubscriptionMethodReturnValueHandlerTests {
String destination = "/dest"; String destination = "/dest";
String headerA = "x-header-a"; String headerA = "x-header-a";
String headerB = "x-header-b"; String headerB = "x-header-b";
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD) Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId) .setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId) .setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
@ -229,23 +229,21 @@ class SubscriptionMethodReturnValueHandlerTests {
.setHeader(headerB, "B-value") .setHeader(headerB, "B-value")
.build(); .build();
MessageSendingOperations messagingTemplate = mock(); MessageSendingOperations template = mock();
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate); SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(template);
handler.addHeaderFilter(name -> name.equals(headerA)); handler.addHeaderFilter(name -> name.equals(headerA));
handler.addHeaderFilter(name -> name.equals(headerB)); handler.addHeaderFilter(name -> name.equals(headerB));
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage); handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class); ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture()); verify(template).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
MessageHeaders sentHeaders = captor.getValue(); MessageHeaders sentHeaders = captor.getValue();
assertThat(sentHeaders.get(headerA)).isEqualTo("A-value"); assertThat(sentHeaders.get(headerA)).isEqualTo("A-value");
assertThat(sentHeaders.get(headerB)).isEqualTo("B-value"); assertThat(sentHeaders.get(headerB)).isEqualTo("B-value");
} }
private Message<?> createInputMessage(String sessId, String subsId, String dest, Principal principal) { private Message<?> createInputMessage(String sessId, String subsId, String dest, Principal principal) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();
headers.setSessionId(sessId); headers.setSessionId(sessId);

Loading…
Cancel
Save