diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerResponseBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerResponseBuilder.java index dc61bc9e028..72c32d308fa 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerResponseBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerResponseBuilder.java @@ -209,6 +209,10 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder { .build(); } + @Override + public ServerResponse stream(Consumer streamConsumer) { + return StreamingServerResponse.create(this.statusCode, this.headers, this.cookies, streamConsumer, null); + } private static class WriteFunctionResponse extends AbstractServerResponse { diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java index ede9aceb212..58fa4a21770 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java @@ -547,6 +547,87 @@ public interface ServerResponse { * @return the built response */ ServerResponse render(String name, Map model); + + /** + * Create a low-level streaming response; for SSE support, see {@link #sse(Consumer)}. + *

The {@link StreamBuilder} provided to the {@code streamConsumer} can + * be used to write to the response in a streaming fashion. Note, the builder is + * responsible for flushing the buffered content to the network. + *

For example: + *

+		 * public ServerResponse handleStream(ServerRequest request) {
+		 *     return ServerResponse.ok()
+		 *       .contentType(MediaType.APPLICATION_ND_JSON)
+		 *       .stream(stream -> {
+		 *         try {
+		 *           // Write and flush a first item
+		 *           stream.write(new Person("John", 51), MediaType.APPLICATION_JSON)
+		 *             .write(new byte[]{'\n'})
+		 *             .flush();
+		 *           // Write and complete with the last item
+		 *           stream.write(new Person("Jane", 42), MediaType.APPLICATION_JSON)
+		 *             .write(new byte[]{'\n'})
+		 *             .complete();
+		 *         }
+		 *         catch (IOException ex) {
+		 *           throw new UncheckedIOException(ex);
+		 *         }
+		 *     });
+		 * }
+		 * 
+ * @param streamConsumer consumer that will be provided with a stream builder + * @return the server-side streaming response + * @since 6.2 + */ + ServerResponse stream(Consumer streamConsumer); + + } + + /** + * Defines a builder for async response bodies. + * @since 6.2 + * @param the builder subclass + */ + interface AsyncBuilder> { + + /** + * Completes the stream with the given error. + * + *

The throwable is dispatched back into Spring MVC, and passed to + * its exception handling mechanism. Since the response has + * been committed by this point, the response status can not change. + * @param t the throwable to dispatch + */ + void error(Throwable t); + + /** + * Completes the stream. + */ + void complete(); + + /** + * Register a callback to be invoked when a request times + * out. + * @param onTimeout the callback to invoke on timeout + * @return this builder + */ + B onTimeout(Runnable onTimeout); + + /** + * Register a callback to be invoked when an error occurs during + * processing. + * @param onError the callback to invoke on error + * @return this builder + */ + B onError(Consumer onError); + + /** + * Register a callback to be invoked when the request completes. + * @param onCompletion the callback to invoked on completion + * @return this builder + */ + B onComplete(Runnable onCompletion); + } @@ -555,7 +636,7 @@ public interface ServerResponse { * * @since 5.3.2 */ - interface SseBuilder { + interface SseBuilder extends AsyncBuilder { /** * Sends the given object as a server-sent event. @@ -618,45 +699,45 @@ public interface ServerResponse { */ void data(Object object) throws IOException; - /** - * Completes the event stream with the given error. - * - *

The throwable is dispatched back into Spring MVC, and passed to - * its exception handling mechanism. Since the response has - * been committed by this point, the response status can not change. - * @param t the throwable to dispatch - */ - void error(Throwable t); + } - /** - * Completes the event stream. - */ - void complete(); + /** + * Defines a builder for a streaming response body. + * + * @since 6.2 + */ + interface StreamBuilder extends AsyncBuilder { /** - * Register a callback to be invoked when an SSE request times - * out. - * @param onTimeout the callback to invoke on timeout + * Write the given object to the response stream, without flushing. + * Strings will be sent as UTF-8 encoded bytes, byte arrays will be sent as-is, + * and other objects will be converted into JSON using + * {@linkplain HttpMessageConverter message converters}. + * @param object the object to send as data * @return this builder + * @throws IOException in case of I/O errors */ - SseBuilder onTimeout(Runnable onTimeout); + StreamBuilder write(Object object) throws IOException; /** - * Register a callback to be invoked when an error occurs during SSE - * processing. - * @param onError the callback to invoke on error + * Write the given object to the response stream, without flushing. + * Strings will be sent as UTF-8 encoded bytes, byte arrays will be sent as-is, + * and other objects will be converted into JSON using + * {@linkplain HttpMessageConverter message converters}. + * @param object the object to send as data + * @param mediaType the media type to use for encoding the provided data * @return this builder + * @throws IOException in case of I/O errors */ - SseBuilder onError(Consumer onError); + StreamBuilder write(Object object, @Nullable MediaType mediaType) throws IOException; /** - * Register a callback to be invoked when the SSE request completes. - * @param onCompletion the callback to invoked on completion - * @return this builder + * Flush the buffered response stream content to the network. + * @throws IOException in case of I/O errors */ - SseBuilder onComplete(Runnable onCompletion); - } + void flush() throws IOException; + } /** * Defines the context used during the {@link #writeTo(HttpServletRequest, HttpServletResponse, Context)}. diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/StreamingServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/StreamingServerResponse.java new file mode 100644 index 00000000000..ee4f2001a83 --- /dev/null +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/StreamingServerResponse.java @@ -0,0 +1,223 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.servlet.function; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.List; +import java.util.function.Consumer; + +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.MediaType; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.DelegatingServerHttpResponse; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.web.context.request.async.DeferredResult; +import org.springframework.web.servlet.ModelAndView; + +/** + * Implementation of {@link ServerResponse} for sending streaming response bodies. + * + * @author Brian Clozel + */ +final class StreamingServerResponse extends AbstractServerResponse { + + private final Consumer streamConsumer; + + @Nullable + private final Duration timeout; + + private StreamingServerResponse(HttpStatusCode statusCode, HttpHeaders headers, MultiValueMap cookies, + Consumer streamConsumer, @Nullable Duration timeout) { + super(statusCode, headers, cookies); + this.streamConsumer = streamConsumer; + this.timeout = timeout; + } + + static ServerResponse create(HttpStatusCode statusCode, HttpHeaders headers, MultiValueMap cookies, + Consumer streamConsumer, @Nullable Duration timeout) { + Assert.notNull(statusCode, "statusCode must not be null"); + Assert.notNull(headers, "headers must not be null"); + Assert.notNull(cookies, "cookies must not be null"); + Assert.notNull(streamConsumer, "streamConsumer must not be null"); + return new StreamingServerResponse(statusCode, headers, cookies, streamConsumer, timeout); + } + + @Nullable + @Override + protected ModelAndView writeToInternal(HttpServletRequest request, HttpServletResponse response, Context context) throws Exception { + DeferredResult result; + if (this.timeout != null) { + result = new DeferredResult<>(this.timeout.toMillis()); + } + else { + result = new DeferredResult<>(); + } + DefaultAsyncServerResponse.writeAsync(request, response, result); + this.streamConsumer.accept(new DefaultStreamBuilder(response, context, result, this.headers())); + return null; + } + + private static class DefaultStreamBuilder implements StreamBuilder { + + private final ServerHttpResponse outputMessage; + + private final DeferredResult deferredResult; + + private final List> messageConverters; + + private final HttpHeaders httpHeaders; + + private boolean sendFailed; + + + public DefaultStreamBuilder(HttpServletResponse response, Context context, DeferredResult deferredResult, + HttpHeaders httpHeaders) { + this.outputMessage = new ServletServerHttpResponse(response); + this.deferredResult = deferredResult; + this.messageConverters = context.messageConverters(); + this.httpHeaders = httpHeaders; + } + + @Override + public StreamBuilder write(Object object) throws IOException { + write(object, null); + return this; + } + + @Override + public StreamBuilder write(Object object, @Nullable MediaType mediaType) throws IOException { + Assert.notNull(object, "data must not be null"); + try { + if (object instanceof byte[] bytes) { + this.outputMessage.getBody().write(bytes); + } + else if (object instanceof String str) { + this.outputMessage.getBody().write(str.getBytes(StandardCharsets.UTF_8)); + } + else { + writeObject(object, mediaType); + } + } + catch (IOException ex) { + this.sendFailed = true; + throw ex; + } + return this; + } + + @SuppressWarnings("unchecked") + private void writeObject(Object data, @Nullable MediaType mediaType) throws IOException { + Class elementClass = data.getClass(); + for (HttpMessageConverter converter : this.messageConverters) { + if (converter.canWrite(elementClass, mediaType)) { + HttpMessageConverter objectConverter = (HttpMessageConverter) converter; + ServerHttpResponse response = new MutableHeadersServerHttpResponse(this.outputMessage, this.httpHeaders); + objectConverter.write(data, mediaType, response); + return; + } + } + } + + @Override + public void flush() throws IOException { + if (this.sendFailed) { + return; + } + try { + this.outputMessage.flush(); + } + catch (IOException ex) { + this.sendFailed = true; + throw ex; + } + } + + @Override + public void error(Throwable t) { + if (this.sendFailed) { + return; + } + this.deferredResult.setErrorResult(t); + } + + @Override + public void complete() { + if (this.sendFailed) { + return; + } + try { + this.outputMessage.flush(); + this.deferredResult.setResult(null); + } + catch (IOException ex) { + this.deferredResult.setErrorResult(ex); + } + } + + @Override + public StreamBuilder onTimeout(Runnable onTimeout) { + this.deferredResult.onTimeout(onTimeout); + return this; + } + + @Override + public StreamBuilder onError(Consumer onError) { + this.deferredResult.onError(onError); + return this; + } + + @Override + public StreamBuilder onComplete(Runnable onCompletion) { + this.deferredResult.onCompletion(onCompletion); + return this; + } + + /** + * Wrap to silently ignore header changes HttpMessageConverter's that would + * otherwise cause HttpHeaders to raise exceptions. + */ + private static final class MutableHeadersServerHttpResponse extends DelegatingServerHttpResponse { + + private final HttpHeaders mutableHeaders = new HttpHeaders(); + + public MutableHeadersServerHttpResponse(ServerHttpResponse delegate, HttpHeaders headers) { + super(delegate); + this.mutableHeaders.putAll(delegate.getHeaders()); + this.mutableHeaders.putAll(headers); + } + + @Override + public HttpHeaders getHeaders() { + return this.mutableHeaders; + } + + } + + } + +} diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/StreamingServerResponseTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/StreamingServerResponseTests.java new file mode 100644 index 00000000000..73e63194dd2 --- /dev/null +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/StreamingServerResponseTests.java @@ -0,0 +1,131 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.servlet.function; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.http.CacheControl; +import org.springframework.http.MediaType; +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.web.servlet.ModelAndView; +import org.springframework.web.testfixture.servlet.MockHttpServletRequest; +import org.springframework.web.testfixture.servlet.MockHttpServletResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link StreamingServerResponse}. + * @author Brian Clozel + */ +class StreamingServerResponseTests { + + private MockHttpServletRequest mockRequest; + + private MockHttpServletResponse mockResponse; + + @BeforeEach + void setUp() { + this.mockRequest = new MockHttpServletRequest("GET", "https://example.com"); + this.mockRequest.setAsyncSupported(true); + this.mockResponse = new MockHttpServletResponse(); + } + + @Test + void writeSingleString() throws Exception { + String body = "data: foo bar\n\n"; + ServerResponse response = ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .stream(stream -> { + try { + stream.write(body).complete(); + } + catch (IOException ex) { + throw new UncheckedIOException(ex); + } + }); + + ServerResponse.Context context = Collections::emptyList; + ModelAndView mav = response.writeTo(this.mockRequest, this.mockResponse, context); + assertThat(mav).isNull(); + assertThat(this.mockResponse.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM.toString()); + assertThat(this.mockResponse.getContentAsString()).isEqualTo(body); + } + + @Test + void writeBytes() throws Exception { + String body = "data: foo bar\n\n"; + ServerResponse response = ServerResponse + .ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .cacheControl(CacheControl.noCache()) + .stream(stream -> { + try { + stream.write(body.getBytes(StandardCharsets.UTF_8)).complete(); + } + catch (IOException ex) { + throw new UncheckedIOException(ex); + } + }); + ServerResponse.Context context = Collections::emptyList; + ModelAndView mav = response.writeTo(this.mockRequest, this.mockResponse, context); + assertThat(mav).isNull(); + assertThat(this.mockResponse.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM.toString()); + assertThat(this.mockResponse.getContentAsString()).isEqualTo(body); + } + + @Test + void writeWithConverters() throws Exception { + ServerResponse response = ServerResponse + .ok() + .contentType(MediaType.APPLICATION_NDJSON) + .cacheControl(CacheControl.noCache()) + .stream(stream -> { + try { + stream.write(new Person("John", 51), MediaType.APPLICATION_JSON) + .write(new byte[]{'\n'}) + .flush(); + stream.write(new Person("Jane", 42), MediaType.APPLICATION_JSON) + .write(new byte[]{'\n'}) + .complete(); + } + catch (IOException ex) { + throw new UncheckedIOException(ex); + } + }); + + ServerResponse.Context context = () -> Collections.singletonList(new MappingJackson2HttpMessageConverter()); + ModelAndView mav = response.writeTo(this.mockRequest, this.mockResponse, context); + assertThat(mav).isNull(); + assertThat(this.mockResponse.getContentType()).isEqualTo(MediaType.APPLICATION_NDJSON.toString()); + assertThat(this.mockResponse.getContentAsString()).isEqualTo(""" + {"name":"John","age":51} + {"name":"Jane","age":42} + """); + } + + + record Person(String name, int age) { + + } + +}