diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java index 391325f443d..e9ce9a1a17c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; import org.apache.commons.logging.Log; @@ -441,7 +442,7 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { String receiptId = headers.getReceiptId(); ReceiptHandler handler = this.receiptHandlers.get(receiptId); if (handler != null) { - handler.handleReceiptReceived(); + handler.handleReceiptReceived(headers); } else if (logger.isDebugEnabled()) { logger.debug("No matching receipt: " + accessor.getDetailedLogMessage(message.getPayload())); @@ -546,9 +547,9 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { @Nullable private final String receiptId; - private final List receiptCallbacks = new ArrayList<>(2); + private final List> receiptCallbacks = new ArrayList<>(2); - private final List receiptLostCallbacks = new ArrayList<>(2); + private final List> receiptLostCallbacks = new ArrayList<>(2); @Nullable private ScheduledFuture future; @@ -556,6 +557,9 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { @Nullable private Boolean result; + @Nullable + private StompHeaders receiptHeaders; + public ReceiptHandler(@Nullable String receiptId) { this.receiptId = receiptId; if (receiptId != null) { @@ -578,15 +582,20 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { @Override public void addReceiptTask(Runnable task) { + addTask(h -> task.run(), true); + } + + @Override + public void addReceiptTask(Consumer task) { addTask(task, true); } @Override public void addReceiptLostTask(Runnable task) { - addTask(task, false); + addTask(h -> task.run(), false); } - private void addTask(Runnable task, boolean successTask) { + private void addTask(Consumer task, boolean successTask) { Assert.notNull(this.receiptId, "To track receipts, set autoReceiptEnabled=true or add 'receiptId' header"); synchronized (this) { @@ -604,10 +613,10 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { } } - private void invoke(List callbacks) { - for (Runnable runnable : callbacks) { + private void invoke(List> callbacks) { + for (Consumer consumer : callbacks) { try { - runnable.run(); + consumer.accept(this.receiptHeaders); } catch (Throwable ex) { // ignore @@ -615,20 +624,21 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { } } - public void handleReceiptReceived() { - handleInternal(true); + public void handleReceiptReceived(StompHeaders receiptHeaders) { + handleInternal(true, receiptHeaders); } public void handleReceiptNotReceived() { - handleInternal(false); + handleInternal(false, null); } - private void handleInternal(boolean result) { + private void handleInternal(boolean result, @Nullable StompHeaders receiptHeaders) { synchronized (this) { if (this.result != null) { return; } this.result = result; + this.receiptHeaders = receiptHeaders; invoke(result ? this.receiptCallbacks : this.receiptLostCallbacks); DefaultStompSession.this.receiptHandlers.remove(this.receiptId); if (this.future != null) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java index 01ad81b8d44..9d1cebb391d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.messaging.simp.stomp; +import java.util.function.Consumer; + import org.springframework.lang.Nullable; /** @@ -143,6 +145,13 @@ public interface StompSession { */ void addReceiptTask(Runnable runnable); + /** + * Consumer to invoke when a receipt is received. Accepts the headers of the received RECEIPT frame. + * @throws java.lang.IllegalArgumentException if the receiptId is {@code null} + * @since TBD + */ + void addReceiptTask(Consumer task); + /** * Task to invoke when a receipt is not received in the configured time. * @throws java.lang.IllegalArgumentException if the receiptId is {@code null} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java index 7f979b54213..f4303835d10 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -576,22 +576,30 @@ public class DefaultStompSessionTests { this.session.setTaskScheduler(mock(TaskScheduler.class)); AtomicReference received = new AtomicReference<>(); + AtomicReference receivedHeaders = new AtomicReference<>(); StompHeaders headers = new StompHeaders(); headers.setDestination("/topic/foo"); headers.setReceipt("my-receipt"); Subscription subscription = this.session.subscribe(headers, mock(StompFrameHandler.class)); - subscription.addReceiptTask(() -> received.set(true)); + subscription.addReceiptTask(receiptHeaders -> { + received.set(true); + receivedHeaders.set(receiptHeaders); + }); assertThat((Object) received.get()).isNull(); StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.RECEIPT); accessor.setReceiptId("my-receipt"); + accessor.setNativeHeader("foo", "bar"); accessor.setLeaveMutable(true); this.session.handleMessage(MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders())); assertThat(received.get()).isNotNull(); assertThat(received.get()).isTrue(); + assertThat(receivedHeaders.get()).isNotNull(); + assertThat(receivedHeaders.get().get("foo").size()).isEqualTo(1); + assertThat(receivedHeaders.get().get("foo").get(0)).isEqualTo("bar"); } @Test @@ -600,6 +608,7 @@ public class DefaultStompSessionTests { this.session.setTaskScheduler(mock(TaskScheduler.class)); AtomicReference received = new AtomicReference<>(); + AtomicReference receivedHeaders = new AtomicReference<>(); StompHeaders headers = new StompHeaders(); headers.setDestination("/topic/foo"); @@ -608,13 +617,20 @@ public class DefaultStompSessionTests { StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.RECEIPT); accessor.setReceiptId("my-receipt"); + accessor.setNativeHeader("foo", "bar"); accessor.setLeaveMutable(true); this.session.handleMessage(MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders())); - subscription.addReceiptTask(() -> received.set(true)); + subscription.addReceiptTask(receiptHeaders -> { + received.set(true); + receivedHeaders.set(receiptHeaders); + }); assertThat(received.get()).isNotNull(); assertThat(received.get()).isTrue(); + assertThat(receivedHeaders.get()).isNotNull(); + assertThat(receivedHeaders.get().get("foo").size()).isEqualTo(1); + assertThat(receivedHeaders.get().get("foo").get(0)).isEqualTo("bar"); } @Test diff --git a/src/docs/asciidoc/web/websocket.adoc b/src/docs/asciidoc/web/websocket.adoc index b67b13d0d0c..00f485b721f 100644 --- a/src/docs/asciidoc/web/websocket.adoc +++ b/src/docs/asciidoc/web/websocket.adoc @@ -1347,7 +1347,7 @@ receipt if the server supports it (simple broker does not). For example, with th headers.setDestination("/topic/..."); headers.setReceipt("r1"); FrameHandler handler = ...; - stompSession.subscribe(headers, handler).addReceiptTask(() -> { + stompSession.subscribe(headers, handler).addReceiptTask(receiptHeaders -> { // Subscription ready... }); ----