diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEvent.java b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEvent.java index ab0b1bda4a9..1daf5607cca 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEvent.java +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEvent.java @@ -19,6 +19,7 @@ package org.springframework.http.codec; import java.time.Duration; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; @@ -253,16 +254,23 @@ public final class ServerSentEvent { @Override public Builder id(String id) { + checkEvent(id); this.id = id; return this; } @Override public Builder event(String event) { + checkEvent(event); this.event = event; return this; } + private static void checkEvent(String content) { + Assert.isTrue(content.indexOf('\n') == -1 && content.indexOf('\r') == -1, + "illegal character '\\n' or '\\r' in event content"); + } + @Override public Builder retry(Duration retry) { this.retry = retry; diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java index 1847cf28b9d..822a50c5d29 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java @@ -40,7 +40,6 @@ import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; /** * {@code HttpMessageWriter} for {@code "text/event-stream"} responses. @@ -48,6 +47,7 @@ import org.springframework.util.StringUtils; * @author Sebastien Deleuze * @author Arjen Poutsma * @author Rossen Stoyanchev + * @author Brian Clozel * @since 5.0 */ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter { @@ -131,8 +131,9 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter Flux encodeEvent(CharSequence sseText, T data, ResolvableType dataType, MediaType mediaType, DataBufferFactory factory, Map hints) { diff --git a/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageWriterTests.java index 9368f302104..9b6b70dabad 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageWriterTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageWriterTests.java @@ -110,12 +110,13 @@ class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAllocating super.bufferFactory = bufferFactory; MockServerHttpResponse outputMessage = new MockServerHttpResponse(super.bufferFactory); - Flux source = Flux.just("foo\nbar", "foo\nbaz"); + Flux source = Flux.just("first\nsecond", "first\rsecond", "first\r\nsecond"); testWrite(source, outputMessage, String.class); StepVerifier.create(outputMessage.getBody()) - .consumeNextWith(stringConsumer("data:foo\ndata:bar\n\n")) - .consumeNextWith(stringConsumer("data:foo\ndata:baz\n\n")) + .consumeNextWith(stringConsumer("data:first\ndata:second\n\n")) + .consumeNextWith(stringConsumer("data:first\ndata:second\n\n")) + .consumeNextWith(stringConsumer("data:first\ndata:second\n\n")) .expectComplete() .verify(); } diff --git a/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventTests.java b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventTests.java new file mode 100644 index 00000000000..a106ea07337 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec; + +import java.util.stream.Stream; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link ServerSentEvent}. + * @author Brian Clozel + */ +class ServerSentEventTests { + + @ParameterizedTest(name = "{1}") + @MethodSource("newLineCharacters") + void rejectsInvalidId(String newLine, String description) { + assertThatIllegalArgumentException().isThrownBy(() -> + ServerSentEvent.builder().id("first" + newLine + "second").build()); + } + + @ParameterizedTest(name = "{1}") + @MethodSource("newLineCharacters") + void rejectsInvalidEvent(String newLine, String description) { + assertThatIllegalArgumentException().isThrownBy(() -> + ServerSentEvent.builder().event("first" + newLine + "second").build()); + } + + private static Stream newLineCharacters() { + return Stream.of( + Arguments.of("\n", "LF"), + Arguments.of("\r", "CR"), + Arguments.of("\r\n", "CRLF") + ); + } + +} diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java index 85752eee9e8..94be3f280c5 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java @@ -26,6 +26,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpResponse; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; import org.springframework.web.servlet.ModelAndView; @@ -195,19 +196,20 @@ public class SseEmitter extends ResponseBodyEmitter { private final Set dataToSend = new LinkedHashSet<>(4); - @Nullable - private StringBuilder sb; + private final StringBuilder sb = new StringBuilder(); private boolean hasName; @Override public SseEventBuilder id(String id) { + checkEvent(id); append("id:").append(id).append('\n'); return this; } @Override public SseEventBuilder name(String name) { + checkEvent(name); this.hasName = true; append("event:").append(name).append('\n'); return this; @@ -221,7 +223,7 @@ public class SseEmitter extends ResponseBodyEmitter { @Override public SseEventBuilder comment(String comment) { - append(':').append(comment).append('\n'); + append(':').append(StringUtils.replace(comment, "\n", "\n:")).append('\n'); return this; } @@ -236,27 +238,53 @@ public class SseEmitter extends ResponseBodyEmitter { name(mav.getViewName()); } append("data:"); - saveAppendedText(); + saveAppendedText(TEXT_PLAIN); if (object instanceof String text) { - object = StringUtils.replace(text, "\n", "\ndata:"); + writeStringData(text, mediaType); + } + else { + this.dataToSend.add(new DataWithMediaType(object, mediaType)); } - this.dataToSend.add(new DataWithMediaType(object, mediaType)); append('\n'); return this; } - SseEventBuilderImpl append(String text) { - if (this.sb == null) { - this.sb = new StringBuilder(); + private static void checkEvent(String content) { + Assert.isTrue(content.indexOf('\n') == -1 && content.indexOf('\r') == -1, + "illegal character '\\n' or '\\r' in event content"); + } + + private void writeStringData(String input, @Nullable MediaType mediaType) { + if (input.indexOf('\n') == -1 && input.indexOf('\r') == -1) { + this.dataToSend.add(new DataWithMediaType(input, mediaType)); + } + else { + int length = input.length(); + for (int i = 0; i < length; i++) { + char c = input.charAt(i); + if (c == '\r') { + if (i + 1 < length && input.charAt(i + 1) == '\n') { + i++; + } + this.sb.append("\ndata:"); + } + else if (c == '\n') { + this.sb.append("\ndata:"); + } + else { + this.sb.append(c); + } + } + saveAppendedText(mediaType); } + } + + SseEventBuilderImpl append(String text) { this.sb.append(text); return this; } SseEventBuilderImpl append(char ch) { - if (this.sb == null) { - this.sb = new StringBuilder(); - } this.sb.append(ch); return this; } @@ -267,14 +295,14 @@ public class SseEmitter extends ResponseBodyEmitter { return Collections.emptySet(); } append('\n'); - saveAppendedText(); + saveAppendedText(TEXT_PLAIN); return this.dataToSend; } - private void saveAppendedText() { - if (this.sb != null) { - this.dataToSend.add(new DataWithMediaType(this.sb.toString(), TEXT_PLAIN)); - this.sb = null; + private void saveAppendedText(@Nullable MediaType mediaType) { + if (StringUtils.hasLength(this.sb)) { + this.dataToSend.add(new DataWithMediaType(this.sb.toString(), mediaType)); + this.sb.setLength(0); } } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java index 4a0aa857b34..3c86732eba7 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterTests.java @@ -22,14 +22,19 @@ import java.util.ArrayList; import java.util.List; import java.util.Set; import java.util.function.Consumer; +import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.springframework.http.MediaType; import org.springframework.lang.Nullable; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event; @@ -105,9 +110,10 @@ class SseEmitterTests { this.handler.assertWriteCount(1); } - @Test - void sendEventWithMultiline() throws Exception { - this.emitter.send(event().data("foo\nbar\nbaz")); + @ParameterizedTest(name = "{1}") + @MethodSource("newLineCharacters") + void sendEventWithMultiline(String newLineChars, String description) throws Exception { + this.emitter.send(event().data("foo" + newLineChars + "bar" + newLineChars + "baz")); this.handler.assertSentObjectCount(3); this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8); this.handler.assertObject(1, "foo\ndata:bar\ndata:baz"); @@ -115,6 +121,17 @@ class SseEmitterTests { this.handler.assertWriteCount(1); } + @ParameterizedTest(name = "{1}") + @MethodSource("newLineCharacters") + void sendEventWithMultilineWithMediaType(String newLineChars, String description) throws Exception { + this.emitter.send(event().data("foo" + newLineChars + "bar" + newLineChars + "baz", MediaType.TEXT_PLAIN)); + this.handler.assertSentObjectCount(3); + this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8); + this.handler.assertObject(1, "foo\ndata:bar\ndata:baz", MediaType.TEXT_PLAIN); + this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + this.handler.assertWriteCount(1); + } + @Test void sendEventFull() throws Exception { this.emitter.send(event().comment("blah").name("test").reconnectTime(5000L).id("1").data("foo")); @@ -137,6 +154,28 @@ class SseEmitterTests { this.handler.assertWriteCount(1); } + @ParameterizedTest(name = "{1}") + @MethodSource("newLineCharacters") + void rejectInvalidId(String newLineChars, String description) { + assertThatIllegalArgumentException().isThrownBy(() -> this.emitter + .send(event().id("first" + newLineChars + "second"))); + } + + @ParameterizedTest(name = "{1}") + @MethodSource("newLineCharacters") + void rejectInvalidName(String newLineChars, String description) { + assertThatIllegalArgumentException().isThrownBy(() -> this.emitter + .send(event().name("first" + newLineChars + "second"))); + } + + private static Stream newLineCharacters() { + return Stream.of( + Arguments.of("\n", "LF"), + Arguments.of("\r", "CR"), + Arguments.of("\r\n", "CRLF") + ); + } + private static class TestHandler implements ResponseBodyEmitter.Handler {