Browse Source

Support Flux<ServerSentEvent<Fragment>> in WebFlux

Closes gh-33975
pull/34074/head
rstoyanchev 1 year ago
parent
commit
d45e6ec197
  1. 85
      spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java
  2. 67
      spring-webflux/src/test/java/org/springframework/web/reactive/result/view/FragmentViewResolutionResultHandlerTests.java
  3. 5
      spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java

85
spring-webflux/src/main/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandler.java

@ -45,6 +45,7 @@ import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode; import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator; import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
@ -101,7 +102,7 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
private final List<View> defaultViews = new ArrayList<>(4); private final List<View> defaultViews = new ArrayList<>(4);
private final List<StreamHandler> streamHandlers = List.of(new SseStreamHandler()); private final SseStreamHandler sseHandler = new SseStreamHandler();
/** /**
@ -175,7 +176,7 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
returnType = returnType.getNested(2); returnType = returnType.getNested(2);
if (adapter.isMultiValue()) { if (adapter.isMultiValue()) {
return Fragment.class.isAssignableFrom(type); return (Fragment.class.isAssignableFrom(type) || isSseFragmentStream(returnType));
} }
} }
@ -194,8 +195,13 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
} }
private static boolean isFragmentCollection(ResolvableType returnType) { private static boolean isFragmentCollection(ResolvableType returnType) {
Class<?> clazz = returnType.resolve(Object.class); return (Collection.class.isAssignableFrom(returnType.resolve(Object.class)) &&
return (Collection.class.isAssignableFrom(clazz) && Fragment.class.equals(returnType.getNested(2).resolve())); Fragment.class.equals(returnType.getNested(2).resolve()));
}
private static boolean isSseFragmentStream(ResolvableType returnType) {
return (ServerSentEvent.class.equals(returnType.resolve()) &&
Fragment.class.equals(returnType.getNested(2).resolve()));
} }
@Override @Override
@ -204,9 +210,15 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
Mono<Object> valueMono; Mono<Object> valueMono;
ResolvableType valueType; ResolvableType valueType;
ReactiveAdapter adapter = getAdapter(result); ReactiveAdapter adapter = getAdapter(result);
BindingContext bindingContext = result.getBindingContext();
Locale locale = LocaleContextHolder.getLocale(exchange.getLocaleContext());
if (adapter != null) { if (adapter != null) {
if (adapter.isMultiValue()) { if (adapter.isMultiValue()) {
if (isSseFragmentStream(result.getReturnType().getNested(2))) {
return handleSseFragmentStream(exchange, result, adapter, locale, bindingContext);
}
valueMono = (result.getReturnValue() != null ? valueMono = (result.getReturnValue() != null ?
Mono.just(FragmentsRendering.fragmentsPublisher(adapter.toPublisher(result.getReturnValue())).build()) : Mono.just(FragmentsRendering.fragmentsPublisher(adapter.toPublisher(result.getReturnValue())).build()) :
Mono.empty()); Mono.empty());
@ -233,8 +245,6 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
Mono<List<View>> viewsMono; Mono<List<View>> viewsMono;
Model model = result.getModel(); Model model = result.getModel();
MethodParameter parameter = result.getReturnTypeSource(); MethodParameter parameter = result.getReturnTypeSource();
BindingContext bindingContext = result.getBindingContext();
Locale locale = LocaleContextHolder.getLocale(exchange.getLocaleContext());
Class<?> clazz = valueType.toClass(); Class<?> clazz = valueType.toClass();
if (clazz == Object.class) { if (clazz == Object.class) {
@ -277,13 +287,15 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
response.getHeaders().putAll(render.headers()); response.getHeaders().putAll(render.headers());
bindingContext.updateModel(exchange); bindingContext.updateModel(exchange);
StreamHandler streamHandler = getStreamHandler(exchange); StreamHandler streamHandler =
(this.sseHandler.supports(exchange.getRequest()) ? this.sseHandler : null);
if (streamHandler != null) { if (streamHandler != null) {
streamHandler.updateResponse(exchange); streamHandler.updateResponse(exchange);
} }
Flux<Flux<DataBuffer>> renderFlux = render.fragments() Flux<Flux<DataBuffer>> renderFlux = render.fragments()
.concatMap(fragment -> renderFragment(fragment, streamHandler, locale, bindingContext, exchange)) .concatMap(fragment -> renderFragment(fragment, null, streamHandler, locale, bindingContext, exchange))
.doOnDiscard(DataBuffer.class, DataBufferUtils::release); .doOnDiscard(DataBuffer.class, DataBufferUtils::release);
return response.writeAndFlushWith(renderFlux); return response.writeAndFlushWith(renderFlux);
@ -338,9 +350,29 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
}); });
} }
private Mono<Void> handleSseFragmentStream(
ServerWebExchange exchange, HandlerResult result, ReactiveAdapter adapter, Locale locale,
BindingContext bindingContext) {
this.sseHandler.updateResponse(exchange);
Flux<ServerSentEvent<Fragment>> eventFlux =
Flux.from(adapter.toPublisher(result.getReturnValue()));
Flux<Flux<DataBuffer>> dataBufferFlux = eventFlux
.concatMap(event -> renderFragment(event.data(), event, this.sseHandler, locale, bindingContext, exchange))
.doOnDiscard(DataBuffer.class, DataBufferUtils::release);
return exchange.getResponse().writeAndFlushWith(dataBufferFlux);
}
private Mono<Flux<DataBuffer>> renderFragment( private Mono<Flux<DataBuffer>> renderFragment(
Fragment fragment, @Nullable StreamHandler streamHandler, Locale locale, @Nullable Fragment fragment, @Nullable Object streamingHints, @Nullable StreamHandler streamHandler,
BindingContext bindingContext, ServerWebExchange exchange) { Locale locale, BindingContext bindingContext, ServerWebExchange exchange) {
if (fragment == null) {
return Mono.empty();
}
// Merge attributes from top-level model // Merge attributes from top-level model
fragment.mergeAttributes(bindingContext.getModel()); fragment.mergeAttributes(bindingContext.getModel());
@ -355,8 +387,11 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
Map<String, Object> model = fragment.model(); Map<String, Object> model = fragment.model();
if (streamHandler != null) { if (streamHandler != null) {
return selectedViews.flatMap(views -> render(views, model, MediaType.TEXT_HTML, bindingContext, mutatedExchange)) return selectedViews
.then(Mono.fromSupplier(() -> streamHandler.format(response.getBodyFlux(), fragment, exchange))); .flatMap(views ->
render(views, model, MediaType.TEXT_HTML, bindingContext, mutatedExchange))
.then(Mono.fromSupplier(() -> streamHandler.format(
response.getBodyFlux(), fragment, streamingHints, exchange)));
} }
else { else {
return selectedViews.flatMap(views -> render(views, model, null, bindingContext, mutatedExchange)) return selectedViews.flatMap(views -> render(views, model, null, bindingContext, mutatedExchange))
@ -364,16 +399,6 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
} }
} }
@Nullable
private StreamHandler getStreamHandler(ServerWebExchange exchange) {
for (StreamHandler handler : this.streamHandlers) {
if (handler.supports(exchange.getRequest())) {
return handler;
}
}
return null;
}
private String getNameForReturnValue(MethodParameter returnType) { private String getNameForReturnValue(MethodParameter returnType) {
return Optional.ofNullable(returnType.getMethodAnnotation(ModelAttribute.class)) return Optional.ofNullable(returnType.getMethodAnnotation(ModelAttribute.class))
.filter(ann -> StringUtils.hasText(ann.value())) .filter(ann -> StringUtils.hasText(ann.value()))
@ -499,10 +524,13 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
* Format the given fragment. * Format the given fragment.
* @param fragmentContent the fragment serialized to data buffers * @param fragmentContent the fragment serialized to data buffers
* @param fragment the fragment being rendered * @param fragment the fragment being rendered
* @param streamingHints extra hints for the stream format (e.g. ServerSentEvent wrapper)
* @param exchange the current exchange * @param exchange the current exchange
* @return the formatted fragment * @return the formatted fragment
*/ */
Flux<DataBuffer> format(Flux<DataBuffer> fragmentContent, Fragment fragment, ServerWebExchange exchange); Flux<DataBuffer> format(
Flux<DataBuffer> fragmentContent, Fragment fragment, @Nullable Object streamingHints,
ServerWebExchange exchange);
} }
@ -540,16 +568,21 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
@Override @Override
public Flux<DataBuffer> format( public Flux<DataBuffer> format(
Flux<DataBuffer> fragmentFlux, Fragment fragment, ServerWebExchange exchange) { Flux<DataBuffer> fragmentFlux, Fragment fragment, @Nullable Object hints,
ServerWebExchange exchange) {
MediaType mediaType = exchange.getResponse().getHeaders().getContentType(); MediaType mediaType = exchange.getResponse().getHeaders().getContentType();
Charset charset = (mediaType != null && mediaType.getCharset() != null ? Charset charset = (mediaType != null && mediaType.getCharset() != null ?
mediaType.getCharset() : StandardCharsets.UTF_8); mediaType.getCharset() : StandardCharsets.UTF_8);
Assert.state(hints == null || hints instanceof ServerSentEvent, "Expected ServerSentEvent");
DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory(); DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
String eventLine = (fragment.viewName() != null ? "event:" + fragment.viewName() + "\n" : ""); ServerSentEvent<?> sse = (ServerSentEvent<?>) hints;
DataBuffer prefix = encodeText(eventLine + "data:", charset, bufferFactory); CharSequence eventText = (sse != null ? sse.format() :
(fragment.viewName() != null ? "event:" + fragment.viewName() + "\n" : "") + "data:");
DataBuffer prefix = encodeText(eventText.toString(), charset, bufferFactory);
DataBuffer suffix = encodeText("\n\n", charset, bufferFactory); DataBuffer suffix = encodeText("\n\n", charset, bufferFactory);
Mono<DataBuffer> content = DataBufferUtils.join(fragmentFlux) Mono<DataBuffer> content = DataBufferUtils.join(fragmentFlux)

67
spring-webflux/src/test/java/org/springframework/web/reactive/result/view/FragmentViewResolutionResultHandlerTests.java

@ -34,7 +34,9 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.reactive.accept.HeaderContentTypeResolver; import org.springframework.web.reactive.accept.HeaderContentTypeResolver;
@ -99,7 +101,51 @@ public class FragmentViewResolutionResultHandlerTests {
} }
@Test @Test
void renderSse() { void renderFragmentStream() {
testSse(Flux.just(fragment1, fragment2),
on(Handler.class).resolveReturnType(Flux.class, Fragment.class),
"""
event:fragment1
data:<p>
data: Hello Foo
data:</p>
event:fragment2
data:<p>
data: Hello Bar
data:</p>
""");
}
@Test
void renderServerSentEventFragmentStream() {
ServerSentEvent<Fragment> event1 = ServerSentEvent.builder(fragment1).id("id1").event("event1").build();
ServerSentEvent<Fragment> event2 = ServerSentEvent.builder(fragment2).id("id2").event("event2").build();
MethodParameter returnType = on(Handler.class).resolveReturnType(
Flux.class, ResolvableType.forClassWithGenerics(ServerSentEvent.class, Fragment.class));
testSse(Flux.just(event1, event2), returnType,
"""
id:id1
event:event1
data:<p>
data: Hello Foo
data:</p>
id:id2
event:event2
data:<p>
data: Hello Bar
data:</p>
""");
}
private void testSse(Flux<?> dataFlux, MethodParameter returnType, String output) {
MockServerHttpRequest request = MockServerHttpRequest.get("/") MockServerHttpRequest request = MockServerHttpRequest.get("/")
.accept(MediaType.TEXT_EVENT_STREAM) .accept(MediaType.TEXT_EVENT_STREAM)
.acceptLanguageAsLocales(Locale.ENGLISH) .acceptLanguageAsLocales(Locale.ENGLISH)
@ -110,8 +156,8 @@ public class FragmentViewResolutionResultHandlerTests {
HandlerResult result = new HandlerResult( HandlerResult result = new HandlerResult(
new Handler(), new Handler(),
Flux.just(fragment1, fragment2).subscribeOn(Schedulers.boundedElastic()), dataFlux.subscribeOn(Schedulers.boundedElastic()),
on(Handler.class).resolveReturnType(Flux.class, Fragment.class), returnType,
new BindingContext()); new BindingContext());
String body = initHandler().handleResult(exchange, result) String body = initHandler().handleResult(exchange, result)
@ -119,18 +165,7 @@ public class FragmentViewResolutionResultHandlerTests {
.block(Duration.ofSeconds(60)); .block(Duration.ofSeconds(60));
assertThat(response.getHeaders().getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM); assertThat(response.getHeaders().getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM);
assertThat(body).isEqualTo(""" assertThat(body).isEqualTo(output);
event:fragment1
data:<p>
data: Hello Foo
data:</p>
event:fragment2
data:<p>
data: Hello Bar
data:</p>
""");
} }
private ViewResolutionResultHandler initHandler() { private ViewResolutionResultHandler initHandler() {
@ -155,6 +190,8 @@ public class FragmentViewResolutionResultHandlerTests {
Flux<Fragment> renderFlux() { return null; } Flux<Fragment> renderFlux() { return null; }
Flux<ServerSentEvent<Fragment>> renderSseFlux() { return null; }
List<Fragment> renderList() { return null; } List<Fragment> renderList() { return null; }
} }

5
spring-webflux/src/test/java/org/springframework/web/reactive/result/view/ViewResolutionResultHandlerTests.java

@ -41,6 +41,7 @@ import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.ui.ConcurrentModel; import org.springframework.ui.ConcurrentModel;
@ -84,6 +85,9 @@ class ViewResolutionResultHandlerTests {
testSupports(on(Handler.class).resolveReturnType(FragmentsRendering.class)); testSupports(on(Handler.class).resolveReturnType(FragmentsRendering.class));
testSupports(on(Handler.class).resolveReturnType(Flux.class, Fragment.class)); testSupports(on(Handler.class).resolveReturnType(Flux.class, Fragment.class));
testSupports(on(Handler.class).resolveReturnType(
Flux.class, ResolvableType.forClassWithGenerics(ServerSentEvent.class, Fragment.class)));
testSupports(on(Handler.class).resolveReturnType(List.class, Fragment.class)); testSupports(on(Handler.class).resolveReturnType(List.class, Fragment.class));
testSupports(on(Handler.class).resolveReturnType( testSupports(on(Handler.class).resolveReturnType(
Mono.class, ResolvableType.forClassWithGenerics(List.class, Fragment.class))); Mono.class, ResolvableType.forClassWithGenerics(List.class, Fragment.class)));
@ -457,6 +461,7 @@ class ViewResolutionResultHandlerTests {
FragmentsRendering fragmentsRendering() { return null; } FragmentsRendering fragmentsRendering() { return null; }
Flux<Fragment> fragmentFlux() { return null; } Flux<Fragment> fragmentFlux() { return null; }
Flux<ServerSentEvent<Fragment>> fragmentServerSentEventFlux() { return null; }
Mono<List<Fragment>> monoFragmentList() { return null; } Mono<List<Fragment>> monoFragmentList() { return null; }
List<Fragment> fragmentList() { return null; } List<Fragment> fragmentList() { return null; }

Loading…
Cancel
Save