From f5c287a6e66a76c12359ccfeb8a89f7495e7c18b Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 27 Aug 2020 10:50:05 +0100 Subject: [PATCH] OrderedMessageSender throughput improvement Before this change messages were sent serially across sessions but ordering is important only within a session. This leads to head of line blocking when a session is slow to send, and also enforcement of send buffer size and time limits is precluded because it happens at a lower level in the transport. This change ensures messages are held up only if there is another from the same session is being sent. This allows messages from each session to flow independent of other. See gh-25581 --- .../simp/broker/OrderedMessageSender.java | 149 ++++++++++++++---- .../broker/OrderedMessageSenderTests.java | 116 ++++++++++---- 2 files changed, 204 insertions(+), 61 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java index c155da7d681..c7652d4eb43 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageSender.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,13 @@ package org.springframework.messaging.simp.broker; +import java.util.Collection; +import java.util.HashSet; import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.logging.Log; @@ -33,8 +38,14 @@ import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.Assert; /** - * Submit messages to an {@link ExecutorSubscribableChannel}, one at a time. - * The channel must have been configured with {@link #configureOutboundChannel}. + * {@code MessageChannel} decorator that ensures messages from the same session + * are sent and processed in the same order. This would not normally be the case + * with an {@code Executor} backed {@code MessageChannel} since the executor + * is free to submit tasks in any order. + * + *

To provide ordering, inbound messages are placed in a queue and sent one + * one at a time per session. Once a message is processed, a callback is used to + * notify that the next message from the same session can be sent through. * * @author Rossen Stoyanchev * @since 5.1 @@ -48,9 +59,7 @@ class OrderedMessageSender implements MessageChannel { private final Log logger; - private final Queue> messages = new ConcurrentLinkedQueue<>(); - - private final AtomicBoolean sendInProgress = new AtomicBoolean(false); + private final Control control = new Control(); public OrderedMessageSender(MessageChannel channel, Log logger) { @@ -66,30 +75,40 @@ class OrderedMessageSender implements MessageChannel { @Override public boolean send(Message message, long timeout) { - this.messages.add(message); + this.control.addMessage(message); trySend(); return true; } private void trySend() { - // Take sendInProgress flag only if queue is not empty - if (this.messages.isEmpty()) { - return; - } - - if (this.sendInProgress.compareAndSet(false, true)) { - sendNextMessage(); + if (this.control.acquireSendLock()) { + sendMessages(); } } - private void sendNextMessage() { - for (;;) { - Message message = this.messages.poll(); - if (message != null) { + private void sendMessages() { + for ( ; ; ) { + Set skipSet = new HashSet<>(); + for (Message message : this.control.getMessagesToSend()) { + String sessionId = SimpMessageHeaderAccessor.getSessionId(message.getHeaders()); + Assert.notNull(sessionId, () -> "No session id in " + message.getHeaders()); + if (skipSet.contains(sessionId)) { + continue; + } + if (!this.control.acquireSessionLock(sessionId)) { + skipSet.add(sessionId); + continue; + } + this.control.removeMessage(message); try { - addCompletionCallback(message); + getMutableAccessor(message).setHeader(COMPLETION_TASK_HEADER, (Runnable) () -> { + this.control.releaseSessionLock(sessionId); + if (this.control.hasRemainingWork()) { + trySend(); + } + }); if (this.channel.send(message)) { - return; + continue; } } catch (Throwable ex) { @@ -97,20 +116,24 @@ class OrderedMessageSender implements MessageChannel { logger.error("Failed to send " + message, ex); } } + // We didn't send + this.control.releaseSessionLock(sessionId); } - else { - // We ran out of messages.. - this.sendInProgress.set(false); - trySend(); - break; + + if (this.control.shouldYield()) { + this.control.releaseSendLock(); + if (!this.control.shouldYield()) { + trySend(); + } + return; } } } - private void addCompletionCallback(Message msg) { - SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(msg, SimpMessageHeaderAccessor.class); + private SimpMessageHeaderAccessor getMutableAccessor(Message message) { + SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class); Assert.isTrue(accessor != null && accessor.isMutable(), "Expected mutable SimpMessageHeaderAccessor"); - accessor.setHeader(COMPLETION_TASK_HEADER, (Runnable) this::sendNextMessage); + return accessor; } @@ -126,13 +149,13 @@ class OrderedMessageSender implements MessageChannel { Assert.isInstanceOf(ExecutorSubscribableChannel.class, channel, "An ExecutorSubscribableChannel is required for `preservePublishOrder`"); ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel; - if (execChannel.getInterceptors().stream().noneMatch(i -> i instanceof CallbackInterceptor)) { - execChannel.addInterceptor(0, new CallbackInterceptor()); + if (execChannel.getInterceptors().stream().noneMatch(i -> i instanceof CompletionTaskInterceptor)) { + execChannel.addInterceptor(0, new CompletionTaskInterceptor()); } } else if (channel instanceof ExecutorSubscribableChannel) { ExecutorSubscribableChannel execChannel = (ExecutorSubscribableChannel) channel; - execChannel.getInterceptors().stream().filter(i -> i instanceof CallbackInterceptor) + execChannel.getInterceptors().stream().filter(i -> i instanceof CompletionTaskInterceptor) .findFirst() .map(execChannel::removeInterceptor); @@ -140,13 +163,71 @@ class OrderedMessageSender implements MessageChannel { } - private static class CallbackInterceptor implements ExecutorChannelInterceptor { + /** + * Provides locks required for ordered message sending and execution within + * a session as well as storage for messages waiting to be sent. + */ + private static class Control { + + private final Queue> messages = new ConcurrentLinkedQueue<>(); + + private final ConcurrentMap sessionsInProgress = new ConcurrentHashMap<>(); + + private final AtomicBoolean workInProgress = new AtomicBoolean(false); + + + public void addMessage(Message message) { + this.messages.add(message); + } + + public void removeMessage(Message message) { + if (!this.messages.remove(message)) { + throw new IllegalStateException( + "Message " + message.getHeaders() + " was expected in the queue."); + } + } + + public Collection> getMessagesToSend() { + return this.messages; + } + + public boolean acquireSendLock() { + return this.workInProgress.compareAndSet(false, true); + } + + public void releaseSendLock() { + this.workInProgress.set(false); + } + + public boolean acquireSessionLock(String sessionId) { + if (this.sessionsInProgress.put(sessionId, Boolean.TRUE) != null) { + return false; + } + return true; + } + + public void releaseSessionLock(String sessionId) { + this.sessionsInProgress.remove(sessionId); + } + + public boolean hasRemainingWork() { + return !this.messages.isEmpty(); + } + + public boolean shouldYield() { + // No remaining work, or others can pick it up + return (!hasRemainingWork() || this.sessionsInProgress.size() > 0); + } + } + + + private static class CompletionTaskInterceptor implements ExecutorChannelInterceptor { @Override public void afterMessageHandled( - Message msg, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) { + Message message, MessageChannel ch, MessageHandler handler, @Nullable Exception ex) { - Runnable task = (Runnable) msg.getHeaders().get(OrderedMessageSender.COMPLETION_TASK_HEADER); + Runnable task = (Runnable) message.getHeaders().get(OrderedMessageSender.COMPLETION_TASK_HEADER); if (task != null) { task.run(); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java index c1fffb0a1a3..cf7914c1935 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageSenderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,10 @@ package org.springframework.messaging.simp.broker; +import java.time.Duration; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -26,7 +30,12 @@ import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.ExecutorSubscribableChannel; @@ -43,6 +52,8 @@ public class OrderedMessageSenderTests { private static final Log logger = LogFactory.getLog(OrderedMessageSenderTests.class); + private static final Random random = new Random(); + private OrderedMessageSender sender; @@ -74,46 +85,97 @@ public class OrderedMessageSenderTests { @Test public void test() throws InterruptedException { - int start = 1; - int end = 1000; + int sessionCount = 25; + int messagesPerSessionCount = 500; + + TestMessageHandler handler = new TestMessageHandler(sessionCount * messagesPerSessionCount); + this.channel.subscribe(handler); + + Publisher>> messageFluxes = + Flux.range(1, sessionCount).map(sessionId -> + Flux.range(1, messagesPerSessionCount) + .map(sequence -> createMessage(sessionId, sequence)) + .delayElements(Duration.ofMillis(Math.abs(random.nextLong()) % 5))); + + Flux.merge(messageFluxes) + .doOnNext(message -> this.sender.send(message)) + .blockLast(); + + handler.await(20, TimeUnit.SECONDS); + + assertThat(handler.getDescription()).isEqualTo("Total processed: " + sessionCount * messagesPerSessionCount); + assertThat(handler.getSequenceBySession()).hasSize(sessionCount); + handler.getSequenceBySession().forEach((key, value) -> + assertThat(value.get()).as(key).isEqualTo(messagesPerSessionCount)); + } + + private static Message createMessage(Integer sessionId, Integer sequence) { + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + accessor.setSessionId("session" + sessionId); + accessor.setHeader("seq", sequence); + accessor.setLeaveMutable(true); + return MessageBuilder.createMessage("payload", accessor.getMessageHeaders()); + } + + + private static class TestMessageHandler implements MessageHandler { + + private final int totalExpected; + + private final Map sequenceBySession = new ConcurrentHashMap<>(); + + private final AtomicReference description = new AtomicReference<>(); + + private final AtomicInteger totalReceived = new AtomicInteger(); - AtomicInteger index = new AtomicInteger(start); - AtomicReference result = new AtomicReference<>(); - CountDownLatch latch = new CountDownLatch(1); + private final CountDownLatch latch = new CountDownLatch(1); - this.channel.subscribe(message -> { - int expected = index.getAndIncrement(); - Integer actual = (Integer) message.getHeaders().getOrDefault("seq", -1); - if (actual != expected) { - result.set("Expected: " + expected + ", but was: " + actual); + TestMessageHandler(int totalExpected) { + this.totalExpected = totalExpected; + } + + public void await(long timeout, TimeUnit timeUnit) throws InterruptedException { + latch.await(timeout, timeUnit); + } + + public Map getSequenceBySession() { + return sequenceBySession; + } + + public String getDescription() { + return description.get(); + } + + @Override + public void handleMessage(Message message) throws MessagingException { + String id = SimpMessageHeaderAccessor.getSessionId(message.getHeaders()); + Integer seq = (Integer) message.getHeaders().getOrDefault("seq", -1); + + AtomicInteger prev = sequenceBySession.computeIfAbsent(id, i -> new AtomicInteger(0)); + if (!prev.compareAndSet(seq - 1, seq)) { + description.set("Out of order, session=" + id + ", prev=" + prev + ", next=" + seq); latch.countDown(); return; } - if (actual == 100 || actual == 200) { + + if (seq == 100) { try { - Thread.sleep(200); + // Processing delay to cause other session messages to queue up + Thread.sleep(50); } catch (InterruptedException ex) { - result.set(ex.toString()); + description.set(ex.toString()); latch.countDown(); + return; } } - if (actual == end) { - result.set("Done"); + + int total = totalReceived.incrementAndGet(); + description.set("Total processed: " + total); + if (total == totalExpected) { latch.countDown(); } - }); - - for (int i = start; i <= end; i++) { - SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); - accessor.setHeader("seq", i); - accessor.setLeaveMutable(true); - this.sender.send(MessageBuilder.createMessage("payload", accessor.getMessageHeaders())); } - - latch.await(10, TimeUnit.SECONDS); - assertThat(result.get()).isEqualTo("Done"); } - }