From 6c2f6023695823ef5ccb338d4bed40d4833c137d Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Mon, 24 Jun 2024 15:13:08 +0100 Subject: [PATCH] Propagate context to send for SSE Flux Closes gh-32813 --- .../annotation/ReactiveTypeHandler.java | 52 +++++++--- ...ResponseBodyEmitterReturnValueHandler.java | 4 +- .../annotation/ReactiveTypeHandlerTests.java | 99 ++++++++++++++++++- 3 files changed, 135 insertions(+), 20 deletions(-) diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java index 90afb4f6c15..b2d58cd7a3b 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java @@ -40,7 +40,9 @@ import org.springframework.core.ReactiveAdapter; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.ResolvableType; import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.core.task.TaskDecorator; import org.springframework.core.task.TaskExecutor; +import org.springframework.core.task.support.ContextPropagatingTaskDecorator; import org.springframework.http.MediaType; import org.springframework.http.codec.ServerSentEvent; import org.springframework.http.server.ServerHttpResponse; @@ -91,18 +93,25 @@ class ReactiveTypeHandler { private final ContentNegotiationManager contentNegotiationManager; + private final ContextSnapshotFactory contextSnapshotFactory; + public ReactiveTypeHandler() { - this(ReactiveAdapterRegistry.getSharedInstance(), new SyncTaskExecutor(), new ContentNegotiationManager()); + this(ReactiveAdapterRegistry.getSharedInstance(), new SyncTaskExecutor(), new ContentNegotiationManager(), null); } - ReactiveTypeHandler(ReactiveAdapterRegistry registry, TaskExecutor executor, ContentNegotiationManager manager) { + ReactiveTypeHandler( + ReactiveAdapterRegistry registry, TaskExecutor executor, ContentNegotiationManager manager, + @Nullable ContextSnapshotFactory contextSnapshotFactory) { + Assert.notNull(registry, "ReactiveAdapterRegistry is required"); Assert.notNull(executor, "TaskExecutor is required"); Assert.notNull(manager, "ContentNegotiationManager is required"); this.adapterRegistry = registry; this.taskExecutor = executor; this.contentNegotiationManager = manager; + this.contextSnapshotFactory = (contextSnapshotFactory != null ? + contextSnapshotFactory : ContextSnapshotFactory.builder().build()); } @@ -129,8 +138,10 @@ class ReactiveTypeHandler { ReactiveAdapter adapter = this.adapterRegistry.getAdapter(clazz); Assert.state(adapter != null, () -> "Unexpected return value type: " + clazz); + TaskDecorator taskDecorator = null; if (isContextPropagationPresent) { - returnValue = ContextSnapshotHelper.writeReactorContext(returnValue); + returnValue = ContextSnapshotHelper.writeReactorContext(returnValue, this.contextSnapshotFactory); + taskDecorator = ContextSnapshotHelper.getTaskDecorator(this.contextSnapshotFactory); } ResolvableType elementType = ResolvableType.forMethodParameter(returnType).getGeneric(); @@ -143,7 +154,7 @@ class ReactiveTypeHandler { if (mediaTypes.stream().anyMatch(MediaType.TEXT_EVENT_STREAM::includes) || ServerSentEvent.class.isAssignableFrom(elementClass)) { SseEmitter emitter = new SseEmitter(STREAMING_TIMEOUT_VALUE); - new SseEmitterSubscriber(emitter, this.taskExecutor).connect(adapter, returnValue); + new SseEmitterSubscriber(emitter, this.taskExecutor, taskDecorator).connect(adapter, returnValue); return emitter; } if (CharSequence.class.isAssignableFrom(elementClass)) { @@ -247,9 +258,14 @@ class ReactiveTypeHandler { private volatile boolean done; - protected AbstractEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) { + private final Runnable sendTask; + + protected AbstractEmitterSubscriber( + ResponseBodyEmitter emitter, TaskExecutor executor, @Nullable TaskDecorator taskDecorator) { + this.emitter = emitter; this.taskExecutor = executor; + this.sendTask = (taskDecorator != null ? taskDecorator.decorate(this) : this); } public void connect(ReactiveAdapter adapter, Object returnValue) { @@ -302,7 +318,7 @@ class ReactiveTypeHandler { private void schedule() { try { - this.taskExecutor.execute(this); + this.taskExecutor.execute(this.sendTask); } catch (Throwable ex) { try { @@ -380,8 +396,8 @@ class ReactiveTypeHandler { private static class SseEmitterSubscriber extends AbstractEmitterSubscriber { - SseEmitterSubscriber(SseEmitter sseEmitter, TaskExecutor executor) { - super(sseEmitter, executor); + SseEmitterSubscriber(SseEmitter sseEmitter, TaskExecutor executor, @Nullable TaskDecorator taskDecorator) { + super(sseEmitter, executor, taskDecorator); } @Override @@ -423,8 +439,10 @@ class ReactiveTypeHandler { private static class JsonEmitterSubscriber extends AbstractEmitterSubscriber { - JsonEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) { - super(emitter, executor); + JsonEmitterSubscriber( + ResponseBodyEmitter emitter, TaskExecutor executor) { + + super(emitter, executor, null); } @Override @@ -438,7 +456,7 @@ class ReactiveTypeHandler { private static class TextEmitterSubscriber extends AbstractEmitterSubscriber { TextEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) { - super(emitter, executor); + super(emitter, executor, null); } @Override @@ -518,22 +536,24 @@ class ReactiveTypeHandler { private static class ContextSnapshotHelper { - private static final ContextSnapshotFactory factory = ContextSnapshotFactory.builder().build(); - @SuppressWarnings("ReactiveStreamsUnusedPublisher") - public static Object writeReactorContext(Object returnValue) { + public static Object writeReactorContext(Object returnValue, ContextSnapshotFactory snapshotFactory) { if (Mono.class.isAssignableFrom(returnValue.getClass())) { - ContextSnapshot snapshot = factory.captureAll(); + ContextSnapshot snapshot = snapshotFactory.captureAll(); return ((Mono) returnValue).contextWrite(snapshot::updateContext); } else if (Flux.class.isAssignableFrom(returnValue.getClass())) { - ContextSnapshot snapshot = factory.captureAll(); + ContextSnapshot snapshot = snapshotFactory.captureAll(); return ((Flux) returnValue).contextWrite(snapshot::updateContext); } else { return returnValue; } } + + public static TaskDecorator getTaskDecorator(ContextSnapshotFactory snapshotFactory) { + return new ContextPropagatingTaskDecorator(snapshotFactory); + } } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java index ee4eebcf9c2..fd23bf24d97 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -91,7 +91,7 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur Assert.notEmpty(messageConverters, "HttpMessageConverter List must not be empty"); this.sseMessageConverters = initSseConverters(messageConverters); - this.reactiveHandler = new ReactiveTypeHandler(registry, executor, manager); + this.reactiveHandler = new ReactiveTypeHandler(registry, executor, manager, null); } private static List> initSseConverters(List> converters) { diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java index e5422c7b8e8..3c4fdf2d0a1 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java @@ -22,28 +22,40 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.stream.Collectors; +import io.micrometer.context.ContextRegistry; +import io.micrometer.context.ContextSnapshotFactory; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.core.SingleEmitter; +import jakarta.servlet.http.HttpServletRequest; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; +import reactor.util.context.ReactorContextAccessor; import org.springframework.core.MethodParameter; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.ResolvableType; +import org.springframework.core.task.SimpleAsyncTaskExecutor; import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.core.task.TaskExecutor; import org.springframework.http.MediaType; import org.springframework.http.codec.ServerSentEvent; import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.lang.Nullable; import org.springframework.web.accept.ContentNegotiationManager; import org.springframework.web.accept.ContentNegotiationManagerFactoryBean; import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestAttributesThreadLocalAccessor; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.context.request.ServletWebRequest; import org.springframework.web.context.request.async.AsyncWebRequest; import org.springframework.web.context.request.async.StandardServletAsyncWebRequest; @@ -75,12 +87,18 @@ class ReactiveTypeHandlerTests { @BeforeEach void setup() throws Exception { + this.handler = initHandler(new SyncTaskExecutor(), null); + resetRequest(); + } + + private static ReactiveTypeHandler initHandler( + TaskExecutor taskExecutor, @Nullable ContextSnapshotFactory snapshotFactory) { + ContentNegotiationManagerFactoryBean factoryBean = new ContentNegotiationManagerFactoryBean(); factoryBean.afterPropertiesSet(); ContentNegotiationManager manager = factoryBean.getObject(); ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry.getSharedInstance(); - this.handler = new ReactiveTypeHandler(adapterRegistry, new SyncTaskExecutor(), manager); - resetRequest(); + return new ReactiveTypeHandler(adapterRegistry, taskExecutor, manager, snapshotFactory); } private void resetRequest() { @@ -414,6 +432,42 @@ class ReactiveTypeHandlerTests { testEmitterContentType("application/json"); } + @Test + void contextPropagation() throws Exception { + + ContextRegistry registry = new ContextRegistry(); + registry.registerThreadLocalAccessor(new RequestAttributesThreadLocalAccessor()); + registry.registerContextAccessor(new ReactorContextAccessor()); + ContextSnapshotFactory snapshotFactory = ContextSnapshotFactory.builder().contextRegistry(registry).build(); + + ModelAndViewContainer mavContainer = new ModelAndViewContainer(); + MethodParameter returnType = on(TestController.class).resolveReturnType(Flux.class, forClass(String.class)); + ReactiveTypeHandler handler = initHandler(new SimpleAsyncTaskExecutor(), snapshotFactory); + + this.servletRequest.addHeader("Accept", MediaType.TEXT_EVENT_STREAM_VALUE); + this.servletRequest.setAttribute("key", "context value"); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.servletRequest)); + + try { + Sinks.Many sink = Sinks.many().unicast().onBackpressureBuffer(); + ResponseBodyEmitter emitter = handler.handleValue(sink.asFlux(), returnType, mavContainer, this.webRequest); + + ContextEmitterHandler emitterHandler = new ContextEmitterHandler(); + emitter.initialize(emitterHandler); + + sink.tryEmitNext("emitted value"); + emitterHandler.awaitMessageCount(1); + + sink.tryEmitComplete(); + + assertThat(emitterHandler.getValuesAsText()).isEqualTo("data:emitted value\n\n"); + assertThat(emitterHandler.getSavedRequest()).isSameAs(this.servletRequest); + } + finally { + RequestContextHolder.resetRequestAttributes(); + } + } + private void testEmitterContentType(String expected) throws Exception { ServletServerHttpResponse message = new ServletServerHttpResponse(this.servletResponse); ResponseBodyEmitter emitter = handleValue(Flux.empty(), Flux.class, forClass(String.class)); @@ -541,6 +595,47 @@ class ReactiveTypeHandlerTests { } } + + private static class ContextEmitterHandler extends EmitterHandler { + + private final AtomicInteger count = new AtomicInteger(); + + private HttpServletRequest savedRequest; + + public HttpServletRequest getSavedRequest() { + return this.savedRequest; + } + + @Override + public void send(Object data, MediaType mediaType) throws IOException { + saveRequest(); + super.send(data, mediaType); + this.count.addAndGet(1); + } + + @Override + public void send(Set items) throws IOException { + saveRequest(); + for (ResponseBodyEmitter.DataWithMediaType item : items) { + super.send(item.getData(), item.getMediaType()); + } + this.count.addAndGet(1); + } + + private void saveRequest() { + RequestAttributes attributes = RequestContextHolder.currentRequestAttributes(); + this.savedRequest = ((ServletRequestAttributes) attributes).getRequest(); + } + + public void awaitMessageCount(int count) throws InterruptedException { + for (int i = 0; i < 10 && this.count.get() < count; i++) { + Thread.sleep(10); + } + assertThat(this.count.get()).isGreaterThanOrEqualTo(count); + } + } + + private static class Bar { private final String value;