diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java index 25fff972a95..af06e8439eb 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/DefaultStompSession.java @@ -291,8 +291,7 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { @Override public Subscription subscribe(StompHeaders stompHeaders, StompFrameHandler handler) { - String destination = stompHeaders.getDestination(); - Assert.hasText(destination, "Destination header is required"); + Assert.hasText(stompHeaders.getDestination(), "Destination header is required"); Assert.notNull(handler, "StompFrameHandler must not be null"); String subscriptionId = stompHeaders.getId(); @@ -300,8 +299,8 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { subscriptionId = String.valueOf(DefaultStompSession.this.subscriptionIndex.getAndIncrement()); stompHeaders.setId(subscriptionId); } - String receiptId = checkOrAddReceipt(stompHeaders); - Subscription subscription = new DefaultSubscription(subscriptionId, destination, receiptId, handler); + checkOrAddReceipt(stompHeaders); + Subscription subscription = new DefaultSubscription(stompHeaders, handler); StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.SUBSCRIBE); accessor.addNativeHeaders(stompHeaders); @@ -333,8 +332,11 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { return receiptable; } - private void unsubscribe(String id) { + private void unsubscribe(String id, StompHeaders stompHeaders) { StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.UNSUBSCRIBE); + if (stompHeaders != null) { + accessor.addNativeHeaders(stompHeaders); + } accessor.setSubscriptionId(id); Message message = createMessage(accessor, EMPTY_PAYLOAD); execute(message); @@ -600,29 +602,27 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { private class DefaultSubscription extends ReceiptHandler implements Subscription { - private final String id; - - private final String destination; + private final StompHeaders headers; private final StompFrameHandler handler; - public DefaultSubscription(String id, String destination, String receiptId, StompFrameHandler handler) { - super(receiptId); - Assert.notNull(destination, "Destination must not be null"); + public DefaultSubscription(StompHeaders headers, StompFrameHandler handler) { + super(headers.getReceipt()); + Assert.notNull(headers.getDestination(), "Destination must not be null"); Assert.notNull(handler, "StompFrameHandler must not be null"); - this.id = id; - this.destination = destination; + this.headers = headers; this.handler = handler; - DefaultStompSession.this.subscriptions.put(id, this); + DefaultStompSession.this.subscriptions.put(headers.getId(), this); } @Override public String getSubscriptionId() { - return this.id; + return this.headers.getId(); } - public String getDestination() { - return this.destination; + @Override + public StompHeaders getSubscriptionHeaders() { + return this.headers; } public StompFrameHandler getHandler() { @@ -631,13 +631,20 @@ public class DefaultStompSession implements ConnectionHandlingStompSession { @Override public void unsubscribe() { - DefaultStompSession.this.subscriptions.remove(getSubscriptionId()); - DefaultStompSession.this.unsubscribe(getSubscriptionId()); + unsubscribe(null); + } + + @Override + public void unsubscribe(StompHeaders stompHeaders) { + String id = this.headers.getId(); + DefaultStompSession.this.subscriptions.remove(id); + DefaultStompSession.this.unsubscribe(id, stompHeaders); } @Override public String toString() { - return "Subscription [id=" + getSubscriptionId() + ", destination='" + getDestination() + + return "Subscription [id=" + getSubscriptionId() + + ", destination='" + this.headers.getDestination() + "', receiptId='" + getReceiptId() + "', handler=" + getHandler() + "]"; } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java index c8f2f5eeae9..b348da11a36 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -141,10 +141,22 @@ public interface StompSession { */ String getSubscriptionId(); + /** + * Return the headers used on the SUBSCRIBE frame. + */ + StompHeaders getSubscriptionHeaders(); + /** * Remove the subscription by sending an UNSUBSCRIBE frame. */ void unsubscribe(); + + /** + * Alternative to {@link #unsubscribe()} with additional custom headers + * to send to the server. + *

Note: There is no need to set the subscription id. + */ + void unsubscribe(StompHeaders stompHeaders); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java index 5744ac1ce2c..44bfea5a00e 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/DefaultStompSessionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -471,6 +471,34 @@ public class DefaultStompSessionTests { assertEquals(subscription.getSubscriptionId(), stompHeaders.getId()); } + @Test // SPR-15131 + public void unsubscribeWithCustomHeader() throws Exception { + this.session.afterConnected(this.connection); + assertTrue(this.session.isConnected()); + + String headerName = "durable-subscription-name"; + String headerValue = "123"; + + StompHeaders subscribeHeaders = new StompHeaders(); + subscribeHeaders.setDestination("/topic/foo"); + subscribeHeaders.set(headerName, headerValue); + StompFrameHandler frameHandler = mock(StompFrameHandler.class); + Subscription subscription = this.session.subscribe(subscribeHeaders, frameHandler); + + StompHeaders unsubscribeHeaders = new StompHeaders(); + unsubscribeHeaders.set(headerName, subscription.getSubscriptionHeaders().getFirst(headerName)); + subscription.unsubscribe(unsubscribeHeaders); + + Message message = this.messageCaptor.getValue(); + StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + assertEquals(StompCommand.UNSUBSCRIBE, accessor.getCommand()); + + StompHeaders stompHeaders = StompHeaders.readOnlyStompHeaders(accessor.getNativeHeaders()); + assertEquals(stompHeaders.toString(), 2, stompHeaders.size()); + assertEquals(subscription.getSubscriptionId(), stompHeaders.getId()); + assertEquals(headerValue, stompHeaders.getFirst(headerName)); + } + @Test public void ack() throws Exception { this.session.afterConnected(this.connection);