@ -40,7 +40,9 @@ import org.springframework.core.ReactiveAdapter;
@@ -40,7 +40,9 @@ import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ReactiveAdapterRegistry ;
import org.springframework.core.ResolvableType ;
import org.springframework.core.task.SyncTaskExecutor ;
import org.springframework.core.task.TaskDecorator ;
import org.springframework.core.task.TaskExecutor ;
import org.springframework.core.task.support.ContextPropagatingTaskDecorator ;
import org.springframework.http.MediaType ;
import org.springframework.http.codec.ServerSentEvent ;
import org.springframework.http.server.ServerHttpResponse ;
@ -91,18 +93,25 @@ class ReactiveTypeHandler {
@@ -91,18 +93,25 @@ class ReactiveTypeHandler {
private final ContentNegotiationManager contentNegotiationManager ;
private final ContextSnapshotFactory contextSnapshotFactory ;
public ReactiveTypeHandler ( ) {
this ( ReactiveAdapterRegistry . getSharedInstance ( ) , new SyncTaskExecutor ( ) , new ContentNegotiationManager ( ) ) ;
this ( ReactiveAdapterRegistry . getSharedInstance ( ) , new SyncTaskExecutor ( ) , new ContentNegotiationManager ( ) , null ) ;
}
ReactiveTypeHandler ( ReactiveAdapterRegistry registry , TaskExecutor executor , ContentNegotiationManager manager ) {
ReactiveTypeHandler (
ReactiveAdapterRegistry registry , TaskExecutor executor , ContentNegotiationManager manager ,
@Nullable ContextSnapshotFactory contextSnapshotFactory ) {
Assert . notNull ( registry , "ReactiveAdapterRegistry is required" ) ;
Assert . notNull ( executor , "TaskExecutor is required" ) ;
Assert . notNull ( manager , "ContentNegotiationManager is required" ) ;
this . adapterRegistry = registry ;
this . taskExecutor = executor ;
this . contentNegotiationManager = manager ;
this . contextSnapshotFactory = ( contextSnapshotFactory ! = null ?
contextSnapshotFactory : ContextSnapshotFactory . builder ( ) . build ( ) ) ;
}
@ -129,8 +138,10 @@ class ReactiveTypeHandler {
@@ -129,8 +138,10 @@ class ReactiveTypeHandler {
ReactiveAdapter adapter = this . adapterRegistry . getAdapter ( clazz ) ;
Assert . state ( adapter ! = null , ( ) - > "Unexpected return value type: " + clazz ) ;
TaskDecorator taskDecorator = null ;
if ( isContextPropagationPresent ) {
returnValue = ContextSnapshotHelper . writeReactorContext ( returnValue ) ;
returnValue = ContextSnapshotHelper . writeReactorContext ( returnValue , this . contextSnapshotFactory ) ;
taskDecorator = ContextSnapshotHelper . getTaskDecorator ( this . contextSnapshotFactory ) ;
}
ResolvableType elementType = ResolvableType . forMethodParameter ( returnType ) . getGeneric ( ) ;
@ -143,7 +154,7 @@ class ReactiveTypeHandler {
@@ -143,7 +154,7 @@ class ReactiveTypeHandler {
if ( mediaTypes . stream ( ) . anyMatch ( MediaType . TEXT_EVENT_STREAM : : includes ) | |
ServerSentEvent . class . isAssignableFrom ( elementClass ) ) {
SseEmitter emitter = new SseEmitter ( STREAMING_TIMEOUT_VALUE ) ;
new SseEmitterSubscriber ( emitter , this . taskExecutor ) . connect ( adapter , returnValue ) ;
new SseEmitterSubscriber ( emitter , this . taskExecutor , taskDecorator ) . connect ( adapter , returnValue ) ;
return emitter ;
}
if ( CharSequence . class . isAssignableFrom ( elementClass ) ) {
@ -247,9 +258,14 @@ class ReactiveTypeHandler {
@@ -247,9 +258,14 @@ class ReactiveTypeHandler {
private volatile boolean done ;
protected AbstractEmitterSubscriber ( ResponseBodyEmitter emitter , TaskExecutor executor ) {
private final Runnable sendTask ;
protected AbstractEmitterSubscriber (
ResponseBodyEmitter emitter , TaskExecutor executor , @Nullable TaskDecorator taskDecorator ) {
this . emitter = emitter ;
this . taskExecutor = executor ;
this . sendTask = ( taskDecorator ! = null ? taskDecorator . decorate ( this ) : this ) ;
}
public void connect ( ReactiveAdapter adapter , Object returnValue ) {
@ -302,7 +318,7 @@ class ReactiveTypeHandler {
@@ -302,7 +318,7 @@ class ReactiveTypeHandler {
private void schedule ( ) {
try {
this . taskExecutor . execute ( this ) ;
this . taskExecutor . execute ( this . sendTask ) ;
}
catch ( Throwable ex ) {
try {
@ -380,8 +396,8 @@ class ReactiveTypeHandler {
@@ -380,8 +396,8 @@ class ReactiveTypeHandler {
private static class SseEmitterSubscriber extends AbstractEmitterSubscriber {
SseEmitterSubscriber ( SseEmitter sseEmitter , TaskExecutor executor ) {
super ( sseEmitter , executor ) ;
SseEmitterSubscriber ( SseEmitter sseEmitter , TaskExecutor executor , @Nullable TaskDecorator taskDecorator ) {
super ( sseEmitter , executor , taskDecorator ) ;
}
@Override
@ -423,8 +439,10 @@ class ReactiveTypeHandler {
@@ -423,8 +439,10 @@ class ReactiveTypeHandler {
private static class JsonEmitterSubscriber extends AbstractEmitterSubscriber {
JsonEmitterSubscriber ( ResponseBodyEmitter emitter , TaskExecutor executor ) {
super ( emitter , executor ) ;
JsonEmitterSubscriber (
ResponseBodyEmitter emitter , TaskExecutor executor ) {
super ( emitter , executor , null ) ;
}
@Override
@ -438,7 +456,7 @@ class ReactiveTypeHandler {
@@ -438,7 +456,7 @@ class ReactiveTypeHandler {
private static class TextEmitterSubscriber extends AbstractEmitterSubscriber {
TextEmitterSubscriber ( ResponseBodyEmitter emitter , TaskExecutor executor ) {
super ( emitter , executor ) ;
super ( emitter , executor , null ) ;
}
@Override
@ -518,22 +536,24 @@ class ReactiveTypeHandler {
@@ -518,22 +536,24 @@ class ReactiveTypeHandler {
private static class ContextSnapshotHelper {
private static final ContextSnapshotFactory factory = ContextSnapshotFactory . builder ( ) . build ( ) ;
@SuppressWarnings ( "ReactiveStreamsUnusedPublisher" )
public static Object writeReactorContext ( Object returnValue ) {
public static Object writeReactorContext ( Object returnValue , ContextSnapshotFactory snapshotFactory ) {
if ( Mono . class . isAssignableFrom ( returnValue . getClass ( ) ) ) {
ContextSnapshot snapshot = f actory. captureAll ( ) ;
ContextSnapshot snapshot = snapshotF actory. captureAll ( ) ;
return ( ( Mono < ? > ) returnValue ) . contextWrite ( snapshot : : updateContext ) ;
}
else if ( Flux . class . isAssignableFrom ( returnValue . getClass ( ) ) ) {
ContextSnapshot snapshot = f actory. captureAll ( ) ;
ContextSnapshot snapshot = snapshotF actory. captureAll ( ) ;
return ( ( Flux < ? > ) returnValue ) . contextWrite ( snapshot : : updateContext ) ;
}
else {
return returnValue ;
}
}
public static TaskDecorator getTaskDecorator ( ContextSnapshotFactory snapshotFactory ) {
return new ContextPropagatingTaskDecorator ( snapshotFactory ) ;
}
}
}