From e83793ba7f804c30f183a3295e4fd051914b1fbe Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Fri, 4 Aug 2023 10:08:50 +0200 Subject: [PATCH] Batch SSE events writes when possible Prior to this commit, the `SseEventBuilder` would be used to create SSE events and write them to the connection using the `ResponseBodyEmitter`. This would send each data item one by one, effectively writing and flushing to the network for each. Since multiple data lines are prepared by the `SseEventBuilder`, a typical write of an SSE event performs multiple flushes operations. This commit adds a method on `ResponseBodyEmitter` to perform batch writes (given a `Set`) and only flush once all elements of the set have been written. This also applies in case of early writes, where now all buffered elements are written then flushed altogether. Fixes gh-30912 --- .../annotation/ResponseBodyEmitter.java | 56 ++++++++++++++++--- ...ResponseBodyEmitterReturnValueHandler.java | 11 +++- .../mvc/method/annotation/SseEmitter.java | 6 +- .../annotation/ReactiveTypeHandlerTests.java | 5 ++ .../annotation/ResponseBodyEmitterTests.java | 32 +++++------ .../method/annotation/SseEmitterTests.java | 31 +++++++++- 6 files changed, 108 insertions(+), 33 deletions(-) diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java index 929385ba687..fe845f274ce 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java @@ -128,9 +128,7 @@ public class ResponseBodyEmitter { this.handler = handler; try { - for (DataWithMediaType sendAttempt : this.earlySendAttempts) { - sendInternal(sendAttempt.getData(), sendAttempt.getMediaType()); - } + sendInternal(this.earlySendAttempts); } finally { this.earlySendAttempts.clear(); @@ -194,11 +192,7 @@ public class ResponseBodyEmitter { */ public synchronized void send(Object object, @Nullable MediaType mediaType) throws IOException { Assert.state(!this.complete, () -> "ResponseBodyEmitter has already completed" + - (this.failure != null ? " with error: " + this.failure : "")); - sendInternal(object, mediaType); - } - - private void sendInternal(Object object, @Nullable MediaType mediaType) throws IOException { + (this.failure != null ? " with error: " + this.failure : "")); if (this.handler != null) { try { this.handler.send(object, mediaType); @@ -217,6 +211,43 @@ public class ResponseBodyEmitter { } } + /** + * Write a set of data and MediaType pairs in a batch. + *

Compared to {@link #send(Object, MediaType)}, this batches the write operations + * and flushes to the network at the end. + * @param items the object and media type pairs to write + * @throws IOException raised when an I/O error occurs + * @throws java.lang.IllegalStateException wraps any other errors + * @since 6.0.12 + */ + public synchronized void send(Set items) throws IOException { + Assert.state(!this.complete, () -> "ResponseBodyEmitter has already completed" + + (this.failure != null ? " with error: " + this.failure : "")); + sendInternal(items); + } + + private void sendInternal(Set items) throws IOException { + if (items.isEmpty()) { + return; + } + if (this.handler != null) { + try { + this.handler.send(items); + } + catch (IOException ex) { + this.sendFailed = true; + throw ex; + } + catch (Throwable ex) { + this.sendFailed = true; + throw new IllegalStateException("Failed to send " + items, ex); + } + } + else { + this.earlySendAttempts.addAll(items); + } + } + /** * Complete request processing by performing a dispatch into the servlet * container, where Spring MVC is invoked once more, and completes the @@ -302,8 +333,17 @@ public class ResponseBodyEmitter { */ interface Handler { + /** + * Immediately write and flush the given data to the network. + */ void send(Object data, @Nullable MediaType mediaType) throws IOException; + /** + * Immediately write all data items then flush to the network. + * @since 6.0.12 + */ + void send(Set items) throws IOException; + void complete(); void completeWithError(Throwable failure); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java index 1667c210c5d..ee4eebcf9c2 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.Set; import java.util.function.Consumer; import jakarta.servlet.ServletRequest; @@ -202,6 +203,15 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur @Override public void send(Object data, @Nullable MediaType mediaType) throws IOException { sendInternal(data, mediaType); + this.outputMessage.flush(); + } + + @Override + public void send(Set items) throws IOException { + for (ResponseBodyEmitter.DataWithMediaType item : items) { + sendInternal(item.getData(), item.getMediaType()); + } + this.outputMessage.flush(); } @SuppressWarnings("unchecked") @@ -209,7 +219,6 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur for (HttpMessageConverter converter : ResponseBodyEmitterReturnValueHandler.this.sseMessageConverters) { if (converter.canWrite(data.getClass(), mediaType)) { ((HttpMessageConverter) converter).write(data, mediaType, this.outputMessage); - this.outputMessage.flush(); return; } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java index b1358bc47e1..5067f5b7198 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -123,9 +123,7 @@ public class SseEmitter extends ResponseBodyEmitter { public void send(SseEventBuilder builder) throws IOException { Set dataToSend = builder.build(); synchronized (this) { - for (DataWithMediaType entry : dataToSend) { - super.send(entry.getData(), entry.getMediaType()); - } + super.send(dataToSend); } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java index d7952e28a09..793c77493bf 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java @@ -365,6 +365,11 @@ public class ReactiveTypeHandlerTests { this.values.add(data); } + @Override + public void send(Set items) throws IOException { + items.forEach(item -> this.values.add(item.getData())); + } + @Override public void complete() { } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java index a2a389d5bb7..f691ae64615 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java @@ -30,9 +30,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIOException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anySet; import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -52,34 +52,33 @@ public class ResponseBodyEmitterTests { @Test - public void sendBeforeHandlerInitialized() throws Exception { + void sendBeforeHandlerInitialized() throws Exception { this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.send("bar", MediaType.TEXT_PLAIN); this.emitter.complete(); verifyNoMoreInteractions(this.handler); this.emitter.initialize(this.handler); - verify(this.handler).send("foo", MediaType.TEXT_PLAIN); - verify(this.handler).send("bar", MediaType.TEXT_PLAIN); + verify(this.handler).send(anySet()); verify(this.handler).complete(); verifyNoMoreInteractions(this.handler); } @Test - public void sendDuplicateBeforeHandlerInitialized() throws Exception { + void sendDuplicateBeforeHandlerInitialized() throws Exception { this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.complete(); verifyNoMoreInteractions(this.handler); this.emitter.initialize(this.handler); - verify(this.handler, times(2)).send("foo", MediaType.TEXT_PLAIN); + verify(this.handler).send(anySet()); verify(this.handler).complete(); verifyNoMoreInteractions(this.handler); } @Test - public void sendBeforeHandlerInitializedWithError() throws Exception { + void sendBeforeHandlerInitializedWithError() throws Exception { IllegalStateException ex = new IllegalStateException(); this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.send("bar", MediaType.TEXT_PLAIN); @@ -87,21 +86,20 @@ public class ResponseBodyEmitterTests { verifyNoMoreInteractions(this.handler); this.emitter.initialize(this.handler); - verify(this.handler).send("foo", MediaType.TEXT_PLAIN); - verify(this.handler).send("bar", MediaType.TEXT_PLAIN); + verify(this.handler).send(anySet()); verify(this.handler).completeWithError(ex); verifyNoMoreInteractions(this.handler); } @Test - public void sendFailsAfterComplete() throws Exception { + void sendFailsAfterComplete() throws Exception { this.emitter.complete(); assertThatIllegalStateException().isThrownBy(() -> this.emitter.send("foo")); } @Test - public void sendAfterHandlerInitialized() throws Exception { + void sendAfterHandlerInitialized() throws Exception { this.emitter.initialize(this.handler); verify(this.handler).onTimeout(any()); verify(this.handler).onError(any()); @@ -119,7 +117,7 @@ public class ResponseBodyEmitterTests { } @Test - public void sendAfterHandlerInitializedWithError() throws Exception { + void sendAfterHandlerInitializedWithError() throws Exception { this.emitter.initialize(this.handler); verify(this.handler).onTimeout(any()); verify(this.handler).onError(any()); @@ -138,7 +136,7 @@ public class ResponseBodyEmitterTests { } @Test - public void sendWithError() throws Exception { + void sendWithError() throws Exception { this.emitter.initialize(this.handler); verify(this.handler).onTimeout(any()); verify(this.handler).onError(any()); @@ -154,7 +152,7 @@ public class ResponseBodyEmitterTests { } @Test - public void onTimeoutBeforeHandlerInitialized() throws Exception { + void onTimeoutBeforeHandlerInitialized() throws Exception { Runnable runnable = mock(); this.emitter.onTimeout(runnable); this.emitter.initialize(this.handler); @@ -169,7 +167,7 @@ public class ResponseBodyEmitterTests { } @Test - public void onTimeoutAfterHandlerInitialized() throws Exception { + void onTimeoutAfterHandlerInitialized() throws Exception { this.emitter.initialize(this.handler); ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); @@ -185,7 +183,7 @@ public class ResponseBodyEmitterTests { } @Test - public void onCompletionBeforeHandlerInitialized() throws Exception { + void onCompletionBeforeHandlerInitialized() throws Exception { Runnable runnable = mock(); this.emitter.onCompletion(runnable); this.emitter.initialize(this.handler); @@ -200,7 +198,7 @@ public class ResponseBodyEmitterTests { } @Test - public void onCompletionAfterHandlerInitialized() throws Exception { + void onCompletionAfterHandlerInitialized() throws Exception { this.emitter.initialize(this.handler); ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java index 5c81bb55266..570986e056d 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java @@ -20,12 +20,14 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.Set; import java.util.function.Consumer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event; @@ -60,6 +62,7 @@ public class SseEmitterTests { this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8); this.handler.assertObject(1, "foo"); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } @Test @@ -69,12 +72,14 @@ public class SseEmitterTests { this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8); this.handler.assertObject(1, "foo", MediaType.TEXT_PLAIN); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } @Test public void sendEventEmpty() throws Exception { this.emitter.send(event()); this.handler.assertSentObjectCount(0); + this.handler.assertWriteCount(0); } @Test @@ -84,6 +89,7 @@ public class SseEmitterTests { this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8); this.handler.assertObject(1, "foo"); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } @Test @@ -95,6 +101,7 @@ public class SseEmitterTests { this.handler.assertObject(2, "\ndata:", TEXT_PLAIN_UTF8); this.handler.assertObject(3, "bar"); this.handler.assertObject(4, "\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } @Test @@ -104,6 +111,7 @@ public class SseEmitterTests { this.handler.assertObject(0, ":blah\nevent:test\nretry:5000\nid:1\ndata:", TEXT_PLAIN_UTF8); this.handler.assertObject(1, "foo"); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } @Test @@ -115,14 +123,17 @@ public class SseEmitterTests { this.handler.assertObject(2, "\ndata:", TEXT_PLAIN_UTF8); this.handler.assertObject(3, "bar"); this.handler.assertObject(4, "\nevent:test\nretry:5000\nid:1\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } private static class TestHandler implements ResponseBodyEmitter.Handler { - private List objects = new ArrayList<>(); + private final List objects = new ArrayList<>(); - private List mediaTypes = new ArrayList<>(); + private final List mediaTypes = new ArrayList<>(); + + private int writeCount; public void assertSentObjectCount(int size) { @@ -139,10 +150,24 @@ public class SseEmitterTests { assertThat(this.mediaTypes.get(index)).isEqualTo(mediaType); } + public void assertWriteCount(int writeCount) { + assertThat(this.writeCount).isEqualTo(writeCount); + } + @Override - public void send(Object data, MediaType mediaType) throws IOException { + public void send(Object data, @Nullable MediaType mediaType) throws IOException { this.objects.add(data); this.mediaTypes.add(mediaType); + this.writeCount++; + } + + @Override + public void send(Set items) throws IOException { + for (ResponseBodyEmitter.DataWithMediaType item : items) { + this.objects.add(item.getData()); + this.mediaTypes.add(item.getMediaType()); + } + this.writeCount++; } @Override