Browse Source

Pass headers to STOMP receipt callbacks

See gh-28715
pull/30296/head
Napster 4 years ago committed by rstoyanchev
parent
commit
d42f950a36
  1. 36
      spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java
  2. 11
      spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java
  3. 22
      spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java
  4. 2
      src/docs/asciidoc/web/websocket.adoc

36
spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -27,6 +27,7 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
@ -441,7 +442,7 @@ public class DefaultStompSession implements ConnectionHandlingStompSession {
String receiptId = headers.getReceiptId(); String receiptId = headers.getReceiptId();
ReceiptHandler handler = this.receiptHandlers.get(receiptId); ReceiptHandler handler = this.receiptHandlers.get(receiptId);
if (handler != null) { if (handler != null) {
handler.handleReceiptReceived(); handler.handleReceiptReceived(headers);
} }
else if (logger.isDebugEnabled()) { else if (logger.isDebugEnabled()) {
logger.debug("No matching receipt: " + accessor.getDetailedLogMessage(message.getPayload())); logger.debug("No matching receipt: " + accessor.getDetailedLogMessage(message.getPayload()));
@ -546,9 +547,9 @@ public class DefaultStompSession implements ConnectionHandlingStompSession {
@Nullable @Nullable
private final String receiptId; private final String receiptId;
private final List<Runnable> receiptCallbacks = new ArrayList<>(2); private final List<Consumer<StompHeaders>> receiptCallbacks = new ArrayList<>(2);
private final List<Runnable> receiptLostCallbacks = new ArrayList<>(2); private final List<Consumer<StompHeaders>> receiptLostCallbacks = new ArrayList<>(2);
@Nullable @Nullable
private ScheduledFuture<?> future; private ScheduledFuture<?> future;
@ -556,6 +557,9 @@ public class DefaultStompSession implements ConnectionHandlingStompSession {
@Nullable @Nullable
private Boolean result; private Boolean result;
@Nullable
private StompHeaders receiptHeaders;
public ReceiptHandler(@Nullable String receiptId) { public ReceiptHandler(@Nullable String receiptId) {
this.receiptId = receiptId; this.receiptId = receiptId;
if (receiptId != null) { if (receiptId != null) {
@ -578,15 +582,20 @@ public class DefaultStompSession implements ConnectionHandlingStompSession {
@Override @Override
public void addReceiptTask(Runnable task) { public void addReceiptTask(Runnable task) {
addTask(h -> task.run(), true);
}
@Override
public void addReceiptTask(Consumer<StompHeaders> task) {
addTask(task, true); addTask(task, true);
} }
@Override @Override
public void addReceiptLostTask(Runnable task) { public void addReceiptLostTask(Runnable task) {
addTask(task, false); addTask(h -> task.run(), false);
} }
private void addTask(Runnable task, boolean successTask) { private void addTask(Consumer<StompHeaders> task, boolean successTask) {
Assert.notNull(this.receiptId, Assert.notNull(this.receiptId,
"To track receipts, set autoReceiptEnabled=true or add 'receiptId' header"); "To track receipts, set autoReceiptEnabled=true or add 'receiptId' header");
synchronized (this) { synchronized (this) {
@ -604,10 +613,10 @@ public class DefaultStompSession implements ConnectionHandlingStompSession {
} }
} }
private void invoke(List<Runnable> callbacks) { private void invoke(List<Consumer<StompHeaders>> callbacks) {
for (Runnable runnable : callbacks) { for (Consumer<StompHeaders> consumer : callbacks) {
try { try {
runnable.run(); consumer.accept(this.receiptHeaders);
} }
catch (Throwable ex) { catch (Throwable ex) {
// ignore // ignore
@ -615,20 +624,21 @@ public class DefaultStompSession implements ConnectionHandlingStompSession {
} }
} }
public void handleReceiptReceived() { public void handleReceiptReceived(StompHeaders receiptHeaders) {
handleInternal(true); handleInternal(true, receiptHeaders);
} }
public void handleReceiptNotReceived() { public void handleReceiptNotReceived() {
handleInternal(false); handleInternal(false, null);
} }
private void handleInternal(boolean result) { private void handleInternal(boolean result, @Nullable StompHeaders receiptHeaders) {
synchronized (this) { synchronized (this) {
if (this.result != null) { if (this.result != null) {
return; return;
} }
this.result = result; this.result = result;
this.receiptHeaders = receiptHeaders;
invoke(result ? this.receiptCallbacks : this.receiptLostCallbacks); invoke(result ? this.receiptCallbacks : this.receiptLostCallbacks);
DefaultStompSession.this.receiptHandlers.remove(this.receiptId); DefaultStompSession.this.receiptHandlers.remove(this.receiptId);
if (this.future != null) { if (this.future != null) {

11
spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,8 @@
package org.springframework.messaging.simp.stomp; package org.springframework.messaging.simp.stomp;
import java.util.function.Consumer;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
/** /**
@ -143,6 +145,13 @@ public interface StompSession {
*/ */
void addReceiptTask(Runnable runnable); void addReceiptTask(Runnable runnable);
/**
* Consumer to invoke when a receipt is received. Accepts the headers of the received RECEIPT frame.
* @throws java.lang.IllegalArgumentException if the receiptId is {@code null}
* @since TBD
*/
void addReceiptTask(Consumer<StompHeaders> task);
/** /**
* Task to invoke when a receipt is not received in the configured time. * Task to invoke when a receipt is not received in the configured time.
* @throws java.lang.IllegalArgumentException if the receiptId is {@code null} * @throws java.lang.IllegalArgumentException if the receiptId is {@code null}

22
spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -576,22 +576,30 @@ public class DefaultStompSessionTests {
this.session.setTaskScheduler(mock(TaskScheduler.class)); this.session.setTaskScheduler(mock(TaskScheduler.class));
AtomicReference<Boolean> received = new AtomicReference<>(); AtomicReference<Boolean> received = new AtomicReference<>();
AtomicReference<StompHeaders> receivedHeaders = new AtomicReference<>();
StompHeaders headers = new StompHeaders(); StompHeaders headers = new StompHeaders();
headers.setDestination("/topic/foo"); headers.setDestination("/topic/foo");
headers.setReceipt("my-receipt"); headers.setReceipt("my-receipt");
Subscription subscription = this.session.subscribe(headers, mock(StompFrameHandler.class)); Subscription subscription = this.session.subscribe(headers, mock(StompFrameHandler.class));
subscription.addReceiptTask(() -> received.set(true)); subscription.addReceiptTask(receiptHeaders -> {
received.set(true);
receivedHeaders.set(receiptHeaders);
});
assertThat((Object) received.get()).isNull(); assertThat((Object) received.get()).isNull();
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.RECEIPT); StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.RECEIPT);
accessor.setReceiptId("my-receipt"); accessor.setReceiptId("my-receipt");
accessor.setNativeHeader("foo", "bar");
accessor.setLeaveMutable(true); accessor.setLeaveMutable(true);
this.session.handleMessage(MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders())); this.session.handleMessage(MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()));
assertThat(received.get()).isNotNull(); assertThat(received.get()).isNotNull();
assertThat(received.get()).isTrue(); assertThat(received.get()).isTrue();
assertThat(receivedHeaders.get()).isNotNull();
assertThat(receivedHeaders.get().get("foo").size()).isEqualTo(1);
assertThat(receivedHeaders.get().get("foo").get(0)).isEqualTo("bar");
} }
@Test @Test
@ -600,6 +608,7 @@ public class DefaultStompSessionTests {
this.session.setTaskScheduler(mock(TaskScheduler.class)); this.session.setTaskScheduler(mock(TaskScheduler.class));
AtomicReference<Boolean> received = new AtomicReference<>(); AtomicReference<Boolean> received = new AtomicReference<>();
AtomicReference<StompHeaders> receivedHeaders = new AtomicReference<>();
StompHeaders headers = new StompHeaders(); StompHeaders headers = new StompHeaders();
headers.setDestination("/topic/foo"); headers.setDestination("/topic/foo");
@ -608,13 +617,20 @@ public class DefaultStompSessionTests {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.RECEIPT); StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.RECEIPT);
accessor.setReceiptId("my-receipt"); accessor.setReceiptId("my-receipt");
accessor.setNativeHeader("foo", "bar");
accessor.setLeaveMutable(true); accessor.setLeaveMutable(true);
this.session.handleMessage(MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders())); this.session.handleMessage(MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()));
subscription.addReceiptTask(() -> received.set(true)); subscription.addReceiptTask(receiptHeaders -> {
received.set(true);
receivedHeaders.set(receiptHeaders);
});
assertThat(received.get()).isNotNull(); assertThat(received.get()).isNotNull();
assertThat(received.get()).isTrue(); assertThat(received.get()).isTrue();
assertThat(receivedHeaders.get()).isNotNull();
assertThat(receivedHeaders.get().get("foo").size()).isEqualTo(1);
assertThat(receivedHeaders.get().get("foo").get(0)).isEqualTo("bar");
} }
@Test @Test

2
src/docs/asciidoc/web/websocket.adoc

@ -1347,7 +1347,7 @@ receipt if the server supports it (simple broker does not). For example, with th
headers.setDestination("/topic/..."); headers.setDestination("/topic/...");
headers.setReceipt("r1"); headers.setReceipt("r1");
FrameHandler handler = ...; FrameHandler handler = ...;
stompSession.subscribe(headers, handler).addReceiptTask(() -> { stompSession.subscribe(headers, handler).addReceiptTask(receiptHeaders -> {
// Subscription ready... // Subscription ready...
}); });
---- ----

Loading…
Cancel
Save