Browse Source

Propagate context to send for SSE Flux

Closes gh-32813
pull/33097/head
rstoyanchev 2 years ago
parent
commit
6c2f602369
  1. 52
      spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java
  2. 4
      spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java
  3. 99
      spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java

52
spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java

@ -40,7 +40,9 @@ import org.springframework.core.ReactiveAdapter; @@ -40,7 +40,9 @@ import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.core.task.TaskDecorator;
import org.springframework.core.task.TaskExecutor;
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.http.server.ServerHttpResponse;
@ -91,18 +93,25 @@ class ReactiveTypeHandler { @@ -91,18 +93,25 @@ class ReactiveTypeHandler {
private final ContentNegotiationManager contentNegotiationManager;
private final ContextSnapshotFactory contextSnapshotFactory;
public ReactiveTypeHandler() {
this(ReactiveAdapterRegistry.getSharedInstance(), new SyncTaskExecutor(), new ContentNegotiationManager());
this(ReactiveAdapterRegistry.getSharedInstance(), new SyncTaskExecutor(), new ContentNegotiationManager(), null);
}
ReactiveTypeHandler(ReactiveAdapterRegistry registry, TaskExecutor executor, ContentNegotiationManager manager) {
ReactiveTypeHandler(
ReactiveAdapterRegistry registry, TaskExecutor executor, ContentNegotiationManager manager,
@Nullable ContextSnapshotFactory contextSnapshotFactory) {
Assert.notNull(registry, "ReactiveAdapterRegistry is required");
Assert.notNull(executor, "TaskExecutor is required");
Assert.notNull(manager, "ContentNegotiationManager is required");
this.adapterRegistry = registry;
this.taskExecutor = executor;
this.contentNegotiationManager = manager;
this.contextSnapshotFactory = (contextSnapshotFactory != null ?
contextSnapshotFactory : ContextSnapshotFactory.builder().build());
}
@ -129,8 +138,10 @@ class ReactiveTypeHandler { @@ -129,8 +138,10 @@ class ReactiveTypeHandler {
ReactiveAdapter adapter = this.adapterRegistry.getAdapter(clazz);
Assert.state(adapter != null, () -> "Unexpected return value type: " + clazz);
TaskDecorator taskDecorator = null;
if (isContextPropagationPresent) {
returnValue = ContextSnapshotHelper.writeReactorContext(returnValue);
returnValue = ContextSnapshotHelper.writeReactorContext(returnValue, this.contextSnapshotFactory);
taskDecorator = ContextSnapshotHelper.getTaskDecorator(this.contextSnapshotFactory);
}
ResolvableType elementType = ResolvableType.forMethodParameter(returnType).getGeneric();
@ -143,7 +154,7 @@ class ReactiveTypeHandler { @@ -143,7 +154,7 @@ class ReactiveTypeHandler {
if (mediaTypes.stream().anyMatch(MediaType.TEXT_EVENT_STREAM::includes) ||
ServerSentEvent.class.isAssignableFrom(elementClass)) {
SseEmitter emitter = new SseEmitter(STREAMING_TIMEOUT_VALUE);
new SseEmitterSubscriber(emitter, this.taskExecutor).connect(adapter, returnValue);
new SseEmitterSubscriber(emitter, this.taskExecutor, taskDecorator).connect(adapter, returnValue);
return emitter;
}
if (CharSequence.class.isAssignableFrom(elementClass)) {
@ -247,9 +258,14 @@ class ReactiveTypeHandler { @@ -247,9 +258,14 @@ class ReactiveTypeHandler {
private volatile boolean done;
protected AbstractEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) {
private final Runnable sendTask;
protected AbstractEmitterSubscriber(
ResponseBodyEmitter emitter, TaskExecutor executor, @Nullable TaskDecorator taskDecorator) {
this.emitter = emitter;
this.taskExecutor = executor;
this.sendTask = (taskDecorator != null ? taskDecorator.decorate(this) : this);
}
public void connect(ReactiveAdapter adapter, Object returnValue) {
@ -302,7 +318,7 @@ class ReactiveTypeHandler { @@ -302,7 +318,7 @@ class ReactiveTypeHandler {
private void schedule() {
try {
this.taskExecutor.execute(this);
this.taskExecutor.execute(this.sendTask);
}
catch (Throwable ex) {
try {
@ -380,8 +396,8 @@ class ReactiveTypeHandler { @@ -380,8 +396,8 @@ class ReactiveTypeHandler {
private static class SseEmitterSubscriber extends AbstractEmitterSubscriber {
SseEmitterSubscriber(SseEmitter sseEmitter, TaskExecutor executor) {
super(sseEmitter, executor);
SseEmitterSubscriber(SseEmitter sseEmitter, TaskExecutor executor, @Nullable TaskDecorator taskDecorator) {
super(sseEmitter, executor, taskDecorator);
}
@Override
@ -423,8 +439,10 @@ class ReactiveTypeHandler { @@ -423,8 +439,10 @@ class ReactiveTypeHandler {
private static class JsonEmitterSubscriber extends AbstractEmitterSubscriber {
JsonEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) {
super(emitter, executor);
JsonEmitterSubscriber(
ResponseBodyEmitter emitter, TaskExecutor executor) {
super(emitter, executor, null);
}
@Override
@ -438,7 +456,7 @@ class ReactiveTypeHandler { @@ -438,7 +456,7 @@ class ReactiveTypeHandler {
private static class TextEmitterSubscriber extends AbstractEmitterSubscriber {
TextEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) {
super(emitter, executor);
super(emitter, executor, null);
}
@Override
@ -518,22 +536,24 @@ class ReactiveTypeHandler { @@ -518,22 +536,24 @@ class ReactiveTypeHandler {
private static class ContextSnapshotHelper {
private static final ContextSnapshotFactory factory = ContextSnapshotFactory.builder().build();
@SuppressWarnings("ReactiveStreamsUnusedPublisher")
public static Object writeReactorContext(Object returnValue) {
public static Object writeReactorContext(Object returnValue, ContextSnapshotFactory snapshotFactory) {
if (Mono.class.isAssignableFrom(returnValue.getClass())) {
ContextSnapshot snapshot = factory.captureAll();
ContextSnapshot snapshot = snapshotFactory.captureAll();
return ((Mono<?>) returnValue).contextWrite(snapshot::updateContext);
}
else if (Flux.class.isAssignableFrom(returnValue.getClass())) {
ContextSnapshot snapshot = factory.captureAll();
ContextSnapshot snapshot = snapshotFactory.captureAll();
return ((Flux<?>) returnValue).contextWrite(snapshot::updateContext);
}
else {
return returnValue;
}
}
public static TaskDecorator getTaskDecorator(ContextSnapshotFactory snapshotFactory) {
return new ContextPropagatingTaskDecorator(snapshotFactory);
}
}
}

4
spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -91,7 +91,7 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur @@ -91,7 +91,7 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur
Assert.notEmpty(messageConverters, "HttpMessageConverter List must not be empty");
this.sseMessageConverters = initSseConverters(messageConverters);
this.reactiveHandler = new ReactiveTypeHandler(registry, executor, manager);
this.reactiveHandler = new ReactiveTypeHandler(registry, executor, manager, null);
}
private static List<HttpMessageConverter<?>> initSseConverters(List<HttpMessageConverter<?>> converters) {

99
spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java

@ -22,28 +22,40 @@ import java.util.Arrays; @@ -22,28 +22,40 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import io.micrometer.context.ContextRegistry;
import io.micrometer.context.ContextSnapshotFactory;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.core.SingleEmitter;
import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.util.context.ReactorContextAccessor;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.web.accept.ContentNegotiationManager;
import org.springframework.web.accept.ContentNegotiationManagerFactoryBean;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestAttributesThreadLocalAccessor;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.context.request.async.AsyncWebRequest;
import org.springframework.web.context.request.async.StandardServletAsyncWebRequest;
@ -75,12 +87,18 @@ class ReactiveTypeHandlerTests { @@ -75,12 +87,18 @@ class ReactiveTypeHandlerTests {
@BeforeEach
void setup() throws Exception {
this.handler = initHandler(new SyncTaskExecutor(), null);
resetRequest();
}
private static ReactiveTypeHandler initHandler(
TaskExecutor taskExecutor, @Nullable ContextSnapshotFactory snapshotFactory) {
ContentNegotiationManagerFactoryBean factoryBean = new ContentNegotiationManagerFactoryBean();
factoryBean.afterPropertiesSet();
ContentNegotiationManager manager = factoryBean.getObject();
ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry.getSharedInstance();
this.handler = new ReactiveTypeHandler(adapterRegistry, new SyncTaskExecutor(), manager);
resetRequest();
return new ReactiveTypeHandler(adapterRegistry, taskExecutor, manager, snapshotFactory);
}
private void resetRequest() {
@ -414,6 +432,42 @@ class ReactiveTypeHandlerTests { @@ -414,6 +432,42 @@ class ReactiveTypeHandlerTests {
testEmitterContentType("application/json");
}
@Test
void contextPropagation() throws Exception {
ContextRegistry registry = new ContextRegistry();
registry.registerThreadLocalAccessor(new RequestAttributesThreadLocalAccessor());
registry.registerContextAccessor(new ReactorContextAccessor());
ContextSnapshotFactory snapshotFactory = ContextSnapshotFactory.builder().contextRegistry(registry).build();
ModelAndViewContainer mavContainer = new ModelAndViewContainer();
MethodParameter returnType = on(TestController.class).resolveReturnType(Flux.class, forClass(String.class));
ReactiveTypeHandler handler = initHandler(new SimpleAsyncTaskExecutor(), snapshotFactory);
this.servletRequest.addHeader("Accept", MediaType.TEXT_EVENT_STREAM_VALUE);
this.servletRequest.setAttribute("key", "context value");
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.servletRequest));
try {
Sinks.Many<String> sink = Sinks.many().unicast().onBackpressureBuffer();
ResponseBodyEmitter emitter = handler.handleValue(sink.asFlux(), returnType, mavContainer, this.webRequest);
ContextEmitterHandler emitterHandler = new ContextEmitterHandler();
emitter.initialize(emitterHandler);
sink.tryEmitNext("emitted value");
emitterHandler.awaitMessageCount(1);
sink.tryEmitComplete();
assertThat(emitterHandler.getValuesAsText()).isEqualTo("data:emitted value\n\n");
assertThat(emitterHandler.getSavedRequest()).isSameAs(this.servletRequest);
}
finally {
RequestContextHolder.resetRequestAttributes();
}
}
private void testEmitterContentType(String expected) throws Exception {
ServletServerHttpResponse message = new ServletServerHttpResponse(this.servletResponse);
ResponseBodyEmitter emitter = handleValue(Flux.empty(), Flux.class, forClass(String.class));
@ -541,6 +595,47 @@ class ReactiveTypeHandlerTests { @@ -541,6 +595,47 @@ class ReactiveTypeHandlerTests {
}
}
private static class ContextEmitterHandler extends EmitterHandler {
private final AtomicInteger count = new AtomicInteger();
private HttpServletRequest savedRequest;
public HttpServletRequest getSavedRequest() {
return this.savedRequest;
}
@Override
public void send(Object data, MediaType mediaType) throws IOException {
saveRequest();
super.send(data, mediaType);
this.count.addAndGet(1);
}
@Override
public void send(Set<ResponseBodyEmitter.DataWithMediaType> items) throws IOException {
saveRequest();
for (ResponseBodyEmitter.DataWithMediaType item : items) {
super.send(item.getData(), item.getMediaType());
}
this.count.addAndGet(1);
}
private void saveRequest() {
RequestAttributes attributes = RequestContextHolder.currentRequestAttributes();
this.savedRequest = ((ServletRequestAttributes) attributes).getRequest();
}
public void awaitMessageCount(int count) throws InterruptedException {
for (int i = 0; i < 10 && this.count.get() < count; i++) {
Thread.sleep(10);
}
assertThat(this.count.get()).isGreaterThanOrEqualTo(count);
}
}
private static class Bar {
private final String value;

Loading…
Cancel
Save