diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java index 929385ba687..fe845f274ce 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java @@ -128,9 +128,7 @@ public class ResponseBodyEmitter { this.handler = handler; try { - for (DataWithMediaType sendAttempt : this.earlySendAttempts) { - sendInternal(sendAttempt.getData(), sendAttempt.getMediaType()); - } + sendInternal(this.earlySendAttempts); } finally { this.earlySendAttempts.clear(); @@ -194,11 +192,7 @@ public class ResponseBodyEmitter { */ public synchronized void send(Object object, @Nullable MediaType mediaType) throws IOException { Assert.state(!this.complete, () -> "ResponseBodyEmitter has already completed" + - (this.failure != null ? " with error: " + this.failure : "")); - sendInternal(object, mediaType); - } - - private void sendInternal(Object object, @Nullable MediaType mediaType) throws IOException { + (this.failure != null ? " with error: " + this.failure : "")); if (this.handler != null) { try { this.handler.send(object, mediaType); @@ -217,6 +211,43 @@ public class ResponseBodyEmitter { } } + /** + * Write a set of data and MediaType pairs in a batch. + *

Compared to {@link #send(Object, MediaType)}, this batches the write operations + * and flushes to the network at the end. + * @param items the object and media type pairs to write + * @throws IOException raised when an I/O error occurs + * @throws java.lang.IllegalStateException wraps any other errors + * @since 6.0.12 + */ + public synchronized void send(Set items) throws IOException { + Assert.state(!this.complete, () -> "ResponseBodyEmitter has already completed" + + (this.failure != null ? " with error: " + this.failure : "")); + sendInternal(items); + } + + private void sendInternal(Set items) throws IOException { + if (items.isEmpty()) { + return; + } + if (this.handler != null) { + try { + this.handler.send(items); + } + catch (IOException ex) { + this.sendFailed = true; + throw ex; + } + catch (Throwable ex) { + this.sendFailed = true; + throw new IllegalStateException("Failed to send " + items, ex); + } + } + else { + this.earlySendAttempts.addAll(items); + } + } + /** * Complete request processing by performing a dispatch into the servlet * container, where Spring MVC is invoked once more, and completes the @@ -302,8 +333,17 @@ public class ResponseBodyEmitter { */ interface Handler { + /** + * Immediately write and flush the given data to the network. + */ void send(Object data, @Nullable MediaType mediaType) throws IOException; + /** + * Immediately write all data items then flush to the network. + * @since 6.0.12 + */ + void send(Set items) throws IOException; + void complete(); void completeWithError(Throwable failure); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java index 1667c210c5d..ee4eebcf9c2 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.Set; import java.util.function.Consumer; import jakarta.servlet.ServletRequest; @@ -202,6 +203,15 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur @Override public void send(Object data, @Nullable MediaType mediaType) throws IOException { sendInternal(data, mediaType); + this.outputMessage.flush(); + } + + @Override + public void send(Set items) throws IOException { + for (ResponseBodyEmitter.DataWithMediaType item : items) { + sendInternal(item.getData(), item.getMediaType()); + } + this.outputMessage.flush(); } @SuppressWarnings("unchecked") @@ -209,7 +219,6 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur for (HttpMessageConverter converter : ResponseBodyEmitterReturnValueHandler.this.sseMessageConverters) { if (converter.canWrite(data.getClass(), mediaType)) { ((HttpMessageConverter) converter).write(data, mediaType, this.outputMessage); - this.outputMessage.flush(); return; } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java index b1358bc47e1..5067f5b7198 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -123,9 +123,7 @@ public class SseEmitter extends ResponseBodyEmitter { public void send(SseEventBuilder builder) throws IOException { Set dataToSend = builder.build(); synchronized (this) { - for (DataWithMediaType entry : dataToSend) { - super.send(entry.getData(), entry.getMediaType()); - } + super.send(dataToSend); } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java index 12cfd1e55ac..db5a3b043a4 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java @@ -465,6 +465,11 @@ public class ReactiveTypeHandlerTests { this.values.add(data); } + @Override + public void send(Set items) throws IOException { + items.forEach(item -> this.values.add(item.getData())); + } + @Override public void complete() { } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java index a2a389d5bb7..f691ae64615 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java @@ -30,9 +30,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIOException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anySet; import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -52,34 +52,33 @@ public class ResponseBodyEmitterTests { @Test - public void sendBeforeHandlerInitialized() throws Exception { + void sendBeforeHandlerInitialized() throws Exception { this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.send("bar", MediaType.TEXT_PLAIN); this.emitter.complete(); verifyNoMoreInteractions(this.handler); this.emitter.initialize(this.handler); - verify(this.handler).send("foo", MediaType.TEXT_PLAIN); - verify(this.handler).send("bar", MediaType.TEXT_PLAIN); + verify(this.handler).send(anySet()); verify(this.handler).complete(); verifyNoMoreInteractions(this.handler); } @Test - public void sendDuplicateBeforeHandlerInitialized() throws Exception { + void sendDuplicateBeforeHandlerInitialized() throws Exception { this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.complete(); verifyNoMoreInteractions(this.handler); this.emitter.initialize(this.handler); - verify(this.handler, times(2)).send("foo", MediaType.TEXT_PLAIN); + verify(this.handler).send(anySet()); verify(this.handler).complete(); verifyNoMoreInteractions(this.handler); } @Test - public void sendBeforeHandlerInitializedWithError() throws Exception { + void sendBeforeHandlerInitializedWithError() throws Exception { IllegalStateException ex = new IllegalStateException(); this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.send("bar", MediaType.TEXT_PLAIN); @@ -87,21 +86,20 @@ public class ResponseBodyEmitterTests { verifyNoMoreInteractions(this.handler); this.emitter.initialize(this.handler); - verify(this.handler).send("foo", MediaType.TEXT_PLAIN); - verify(this.handler).send("bar", MediaType.TEXT_PLAIN); + verify(this.handler).send(anySet()); verify(this.handler).completeWithError(ex); verifyNoMoreInteractions(this.handler); } @Test - public void sendFailsAfterComplete() throws Exception { + void sendFailsAfterComplete() throws Exception { this.emitter.complete(); assertThatIllegalStateException().isThrownBy(() -> this.emitter.send("foo")); } @Test - public void sendAfterHandlerInitialized() throws Exception { + void sendAfterHandlerInitialized() throws Exception { this.emitter.initialize(this.handler); verify(this.handler).onTimeout(any()); verify(this.handler).onError(any()); @@ -119,7 +117,7 @@ public class ResponseBodyEmitterTests { } @Test - public void sendAfterHandlerInitializedWithError() throws Exception { + void sendAfterHandlerInitializedWithError() throws Exception { this.emitter.initialize(this.handler); verify(this.handler).onTimeout(any()); verify(this.handler).onError(any()); @@ -138,7 +136,7 @@ public class ResponseBodyEmitterTests { } @Test - public void sendWithError() throws Exception { + void sendWithError() throws Exception { this.emitter.initialize(this.handler); verify(this.handler).onTimeout(any()); verify(this.handler).onError(any()); @@ -154,7 +152,7 @@ public class ResponseBodyEmitterTests { } @Test - public void onTimeoutBeforeHandlerInitialized() throws Exception { + void onTimeoutBeforeHandlerInitialized() throws Exception { Runnable runnable = mock(); this.emitter.onTimeout(runnable); this.emitter.initialize(this.handler); @@ -169,7 +167,7 @@ public class ResponseBodyEmitterTests { } @Test - public void onTimeoutAfterHandlerInitialized() throws Exception { + void onTimeoutAfterHandlerInitialized() throws Exception { this.emitter.initialize(this.handler); ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); @@ -185,7 +183,7 @@ public class ResponseBodyEmitterTests { } @Test - public void onCompletionBeforeHandlerInitialized() throws Exception { + void onCompletionBeforeHandlerInitialized() throws Exception { Runnable runnable = mock(); this.emitter.onCompletion(runnable); this.emitter.initialize(this.handler); @@ -200,7 +198,7 @@ public class ResponseBodyEmitterTests { } @Test - public void onCompletionAfterHandlerInitialized() throws Exception { + void onCompletionAfterHandlerInitialized() throws Exception { this.emitter.initialize(this.handler); ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java index 5c81bb55266..570986e056d 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java @@ -20,12 +20,14 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.Set; import java.util.function.Consumer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.http.MediaType; +import org.springframework.lang.Nullable; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event; @@ -60,6 +62,7 @@ public class SseEmitterTests { this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8); this.handler.assertObject(1, "foo"); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } @Test @@ -69,12 +72,14 @@ public class SseEmitterTests { this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8); this.handler.assertObject(1, "foo", MediaType.TEXT_PLAIN); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } @Test public void sendEventEmpty() throws Exception { this.emitter.send(event()); this.handler.assertSentObjectCount(0); + this.handler.assertWriteCount(0); } @Test @@ -84,6 +89,7 @@ public class SseEmitterTests { this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8); this.handler.assertObject(1, "foo"); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } @Test @@ -95,6 +101,7 @@ public class SseEmitterTests { this.handler.assertObject(2, "\ndata:", TEXT_PLAIN_UTF8); this.handler.assertObject(3, "bar"); this.handler.assertObject(4, "\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } @Test @@ -104,6 +111,7 @@ public class SseEmitterTests { this.handler.assertObject(0, ":blah\nevent:test\nretry:5000\nid:1\ndata:", TEXT_PLAIN_UTF8); this.handler.assertObject(1, "foo"); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } @Test @@ -115,14 +123,17 @@ public class SseEmitterTests { this.handler.assertObject(2, "\ndata:", TEXT_PLAIN_UTF8); this.handler.assertObject(3, "bar"); this.handler.assertObject(4, "\nevent:test\nretry:5000\nid:1\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); } private static class TestHandler implements ResponseBodyEmitter.Handler { - private List objects = new ArrayList<>(); + private final List objects = new ArrayList<>(); - private List mediaTypes = new ArrayList<>(); + private final List mediaTypes = new ArrayList<>(); + + private int writeCount; public void assertSentObjectCount(int size) { @@ -139,10 +150,24 @@ public class SseEmitterTests { assertThat(this.mediaTypes.get(index)).isEqualTo(mediaType); } + public void assertWriteCount(int writeCount) { + assertThat(this.writeCount).isEqualTo(writeCount); + } + @Override - public void send(Object data, MediaType mediaType) throws IOException { + public void send(Object data, @Nullable MediaType mediaType) throws IOException { this.objects.add(data); this.mediaTypes.add(mediaType); + this.writeCount++; + } + + @Override + public void send(Set items) throws IOException { + for (ResponseBodyEmitter.DataWithMediaType item : items) { + this.objects.add(item.getData()); + this.mediaTypes.add(item.getMediaType()); + } + this.writeCount++; } @Override