diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientStrategiesBuilder.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientStrategiesBuilder.java index dd583d20146..bc6cb3b01ad 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientStrategiesBuilder.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientStrategiesBuilder.java @@ -36,6 +36,7 @@ import org.springframework.http.codec.EncoderHttpMessageWriter; import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageWriter; import org.springframework.http.codec.ResourceHttpMessageWriter; +import org.springframework.http.codec.ServerSentEventHttpMessageReader; import org.springframework.http.codec.json.Jackson2JsonDecoder; import org.springframework.http.codec.json.Jackson2JsonEncoder; import org.springframework.http.codec.xml.Jaxb2XmlDecoder; @@ -70,13 +71,17 @@ class DefaultWebClientStrategiesBuilder implements WebClientStrategies.Builder { public void defaultConfiguration() { messageReader(new DecoderHttpMessageReader<>(new ByteArrayDecoder())); messageReader(new DecoderHttpMessageReader<>(new ByteBufferDecoder())); + if (jackson2Present) { + messageReader(new ServerSentEventHttpMessageReader(Collections.singletonList(new Jackson2JsonDecoder()))); + } + else { + messageReader(new ServerSentEventHttpMessageReader(Collections.emptyList())); + } messageReader(new DecoderHttpMessageReader<>(new StringDecoder(false))); - messageWriter(new EncoderHttpMessageWriter<>(new ByteArrayEncoder())); messageWriter(new EncoderHttpMessageWriter<>(new ByteBufferEncoder())); messageWriter(new EncoderHttpMessageWriter<>(new CharSequenceEncoder())); messageWriter(new ResourceHttpMessageWriter()); - if (jaxb2Present) { messageReader(new DecoderHttpMessageReader<>(new Jaxb2XmlDecoder())); messageWriter(new EncoderHttpMessageWriter<>(new Jaxb2XmlEncoder())); diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/function/server/SseHandlerFunctionIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/function/server/SseHandlerFunctionIntegrationTests.java index e95e220ee51..23618990b5c 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/function/server/SseHandlerFunctionIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/function/server/SseHandlerFunctionIntegrationTests.java @@ -18,16 +18,19 @@ package org.springframework.web.reactive.function.server; import java.time.Duration; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import org.junit.Before; import org.junit.Test; +import static org.springframework.http.MediaType.TEXT_EVENT_STREAM; +import static org.springframework.web.reactive.function.BodyExtractors.toFlux; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import org.springframework.http.MediaType; +import org.springframework.core.ResolvableType; import org.springframework.http.client.reactive.ReactorClientHttpConnector; import org.springframework.http.codec.ServerSentEvent; -import org.springframework.web.reactive.function.BodyExtractors; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.WebClient; @@ -37,10 +40,8 @@ import static org.springframework.web.reactive.function.server.RouterFunctions.r /** * @author Arjen Poutsma */ -public class SseHandlerFunctionIntegrationTests - extends AbstractRouterFunctionIntegrationTests { +public class SseHandlerFunctionIntegrationTests extends AbstractRouterFunctionIntegrationTests { - private static final MediaType EVENT_STREAM = new MediaType("text", "event-stream"); private WebClient webClient; @@ -57,49 +58,40 @@ public class SseHandlerFunctionIntegrationTests .and(route(RequestPredicates.GET("/event"), sseHandler::sse)); } - @Test public void sseAsString() throws Exception { - ClientRequest request = - ClientRequest + ClientRequest request = ClientRequest .GET("http://localhost:{port}/string", this.port) - .accept(EVENT_STREAM) + .accept(TEXT_EVENT_STREAM) .build(); Flux result = this.webClient .exchange(request) - .flatMap(response -> response.body(BodyExtractors.toFlux(String.class))) - .filter(s -> !s.equals("\n")) - .map(s -> (s.replace("\n", ""))) - .take(2); + .flatMap(response -> response.body(toFlux(String.class))); StepVerifier.create(result) - .expectNext("data:foo 0") - .expectNext("data:foo 1") + .expectNext("foo 0") + .expectNext("foo 1") .expectComplete() - .verify(Duration.ofSeconds(5)); + .verify(Duration.ofSeconds(5L)); } - @Test public void sseAsPerson() throws Exception { ClientRequest request = ClientRequest .GET("http://localhost:{port}/person", this.port) - .accept(EVENT_STREAM) + .accept(TEXT_EVENT_STREAM) .build(); - Mono result = this.webClient + Flux result = this.webClient .exchange(request) - .flatMap(response -> response.body(BodyExtractors.toFlux(String.class))) - .filter(s -> !s.equals("\n")) - .map(s -> s.replace("\n", "")) - .takeUntil(s -> s.endsWith("foo 1\"}")) - .reduce((s1, s2) -> s1 + s2); + .flatMap(response -> response.body(toFlux(Person.class))); StepVerifier.create(result) - .expectNext("data:{\"name\":\"foo 0\"}data:{\"name\":\"foo 1\"}") + .expectNext(new Person("foo 0")) + .expectNext(new Person("foo 1")) .expectComplete() - .verify(Duration.ofSeconds(5)); + .verify(Duration.ofSeconds(5L)); } @Test @@ -107,21 +99,31 @@ public class SseHandlerFunctionIntegrationTests ClientRequest request = ClientRequest .GET("http://localhost:{port}/event", this.port) - .accept(EVENT_STREAM) + .accept(TEXT_EVENT_STREAM) .build(); - Flux result = this.webClient + ResolvableType type = ResolvableType.forClassWithGenerics(ServerSentEvent.class, String.class); + Flux> result = this.webClient .exchange(request) - .flatMap(response -> response.body(BodyExtractors.toFlux(String.class))) - .filter(s -> !s.equals("\n")) - .map(s -> s.replace("\n", "")) - .take(2); + .flatMap(response -> response.body(toFlux(type))); StepVerifier.create(result) - .expectNext("id:0:bardata:foo") - .expectNext("id:1:bardata:foo") + .consumeNextWith( event -> { + assertEquals("0", event.id().get()); + assertEquals("foo", event.data().get()); + assertEquals("bar", event.comment().get()); + assertFalse(event.event().isPresent()); + assertFalse(event.retry().isPresent()); + }) + .consumeNextWith( event -> { + assertEquals("1", event.id().get()); + assertEquals("foo", event.data().get()); + assertEquals("bar", event.comment().get()); + assertFalse(event.event().isPresent()); + assertFalse(event.retry().isPresent()); + }) .expectComplete() - .verify(Duration.ofSeconds(5)); + .verify(Duration.ofSeconds(5L)); } private static class SseHandler { diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/SseIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/SseIntegrationTests.java index 316fb950388..676f30c5648 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/SseIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/SseIntegrationTests.java @@ -18,16 +18,20 @@ package org.springframework.web.reactive.result.method.annotation; import java.time.Duration; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import org.junit.Before; import org.junit.Test; +import static org.springframework.http.MediaType.TEXT_EVENT_STREAM; +import static org.springframework.web.reactive.function.BodyExtractors.toFlux; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; + import reactor.test.StepVerifier; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.http.MediaType; +import org.springframework.core.ResolvableType; import org.springframework.http.client.reactive.ReactorClientHttpConnector; import org.springframework.http.codec.ServerSentEvent; import org.springframework.http.server.reactive.AbstractHttpHandlerIntegrationTests; @@ -36,7 +40,6 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.reactive.DispatcherHandler; import org.springframework.web.reactive.config.EnableWebReactive; -import org.springframework.web.reactive.function.BodyExtractors; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; @@ -47,9 +50,6 @@ import org.springframework.web.server.adapter.WebHttpHandlerBuilder; */ public class SseIntegrationTests extends AbstractHttpHandlerIntegrationTests { - private static final MediaType EVENT_STREAM = new MediaType("text", "event-stream"); - - private AnnotationConfigApplicationContext wac; private WebClient webClient; @@ -74,22 +74,18 @@ public class SseIntegrationTests extends AbstractHttpHandlerIntegrationTests { @Test public void sseAsString() throws Exception { - ClientRequest request = - ClientRequest + ClientRequest request = ClientRequest .GET("http://localhost:{port}/sse/string", this.port) - .accept(EVENT_STREAM) + .accept(TEXT_EVENT_STREAM) .build(); Flux result = this.webClient .exchange(request) - .flatMap(response -> response.body(BodyExtractors.toFlux(String.class))) - .filter(s -> !s.equals("\n")) - .map(s -> (s.replace("\n", ""))) - .take(2); + .flatMap(response -> response.body(toFlux(String.class))); StepVerifier.create(result) - .expectNext("data:foo 0") - .expectNext("data:foo 1") + .expectNext("foo 0") + .expectNext("foo 1") .expectComplete() .verify(Duration.ofSeconds(5L)); } @@ -98,19 +94,16 @@ public class SseIntegrationTests extends AbstractHttpHandlerIntegrationTests { ClientRequest request = ClientRequest .GET("http://localhost:{port}/sse/person", this.port) - .accept(EVENT_STREAM) + .accept(TEXT_EVENT_STREAM) .build(); - Mono result = this.webClient + Flux result = this.webClient .exchange(request) - .flatMap(response -> response.body(BodyExtractors.toFlux(String.class))) - .filter(s -> !s.equals("\n")) - .map(s -> s.replace("\n", "")) - .takeUntil(s -> s.endsWith("foo 1\"}")) - .reduce((s1, s2) -> s1 + s2); + .flatMap(response -> response.body(toFlux(Person.class))); StepVerifier.create(result) - .expectNext("data:{\"name\":\"foo 0\"}data:{\"name\":\"foo 1\"}") + .expectNext(new Person("foo 0")) + .expectNext(new Person("foo 1")) .expectComplete() .verify(Duration.ofSeconds(5L)); } @@ -120,18 +113,29 @@ public class SseIntegrationTests extends AbstractHttpHandlerIntegrationTests { ClientRequest request = ClientRequest .GET("http://localhost:{port}/sse/event", this.port) - .accept(EVENT_STREAM) + .accept(TEXT_EVENT_STREAM) .build(); - Flux result = this.webClient + + ResolvableType type = ResolvableType.forClassWithGenerics(ServerSentEvent.class, String.class); + Flux> result = this.webClient .exchange(request) - .flatMap(response -> response.body(BodyExtractors.toFlux(String.class))) - .filter(s -> !s.equals("\n")) - .map(s -> s.replace("\n", "")) - .take(2); + .flatMap(response -> response.body(toFlux(type))); StepVerifier.create(result) - .expectNext("id:0:bardata:foo") - .expectNext("id:1:bardata:foo") + .consumeNextWith( event -> { + assertEquals("0", event.id().get()); + assertEquals("foo", event.data().get()); + assertEquals("bar", event.comment().get()); + assertFalse(event.event().isPresent()); + assertFalse(event.retry().isPresent()); + }) + .consumeNextWith( event -> { + assertEquals("1", event.id().get()); + assertEquals("foo", event.data().get()); + assertEquals("bar", event.comment().get()); + assertFalse(event.event().isPresent()); + assertFalse(event.retry().isPresent()); + }) .expectComplete() .verify(Duration.ofSeconds(5L)); } @@ -141,19 +145,28 @@ public class SseIntegrationTests extends AbstractHttpHandlerIntegrationTests { ClientRequest request = ClientRequest .GET("http://localhost:{port}/sse/event", this.port) - .accept(EVENT_STREAM) + .accept(TEXT_EVENT_STREAM) .build(); - Flux result = this.webClient + Flux> result = this.webClient .exchange(request) - .flatMap(response -> response.body(BodyExtractors.toFlux(String.class))) - .filter(s -> !s.equals("\n")) - .map(s -> s.replace("\n", "")) - .take(2); + .flatMap(response -> response.body(toFlux(ResolvableType.forClassWithGenerics(ServerSentEvent.class, String.class)))); StepVerifier.create(result) - .expectNext("id:0:bardata:foo") - .expectNext("id:1:bardata:foo") + .consumeNextWith( event -> { + assertEquals("0", event.id().get()); + assertEquals("foo", event.data().get()); + assertEquals("bar", event.comment().get()); + assertFalse(event.event().isPresent()); + assertFalse(event.retry().isPresent()); + }) + .consumeNextWith( event -> { + assertEquals("1", event.id().get()); + assertEquals("foo", event.data().get()); + assertEquals("bar", event.comment().get()); + assertFalse(event.event().isPresent()); + assertFalse(event.retry().isPresent()); + }) .expectComplete() .verify(Duration.ofSeconds(5L)); }