diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/ChannelSendOperator.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/ChannelSendOperator.java index 36d1c4e7c11..627b3276c6e 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/ChannelSendOperator.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/ChannelSendOperator.java @@ -176,7 +176,9 @@ class ChannelSendOperator extends Mono implements Scannable { requiredWriteSubscriber().onNext(item); return; } - // FIXME revisit in case of reentrant sync deadlock + + boolean invokeWriteFunction = false; + synchronized (this) { if (this.state == State.READY_TO_WRITE) { requiredWriteSubscriber().onNext(item); @@ -184,15 +186,7 @@ class ChannelSendOperator extends Mono implements Scannable { else if (this.state == State.NEW) { this.item = item; this.state = State.FIRST_SIGNAL_RECEIVED; - Publisher result; - try { - result = writeFunction.apply(this); - } - catch (Throwable ex) { - this.writeCompletionBarrier.onError(ex); - return; - } - result.subscribe(this.writeCompletionBarrier); + invokeWriteFunction = true; } else { if (this.subscription != null) { @@ -201,6 +195,20 @@ class ChannelSendOperator extends Mono implements Scannable { this.writeCompletionBarrier.onError(new IllegalStateException("Unexpected item.")); } } + + if (!invokeWriteFunction) { + return; + } + + Publisher result; + try { + result = writeFunction.apply(this); + } + catch (Throwable ex) { + this.writeCompletionBarrier.onError(ex); + return; + } + result.subscribe(this.writeCompletionBarrier); } private Subscriber requiredWriteSubscriber() { @@ -234,6 +242,9 @@ class ChannelSendOperator extends Mono implements Scannable { requiredWriteSubscriber().onComplete(); return; } + + boolean invokeWriteFunction = false; + synchronized (this) { if (this.state == State.READY_TO_WRITE) { requiredWriteSubscriber().onComplete(); @@ -241,20 +252,26 @@ class ChannelSendOperator extends Mono implements Scannable { else if (this.state == State.NEW) { this.completed = true; this.state = State.FIRST_SIGNAL_RECEIVED; - Publisher result; - try { - result = writeFunction.apply(this); - } - catch (Throwable ex) { - this.writeCompletionBarrier.onError(ex); - return; - } - result.subscribe(this.writeCompletionBarrier); + invokeWriteFunction = true; } else { this.completed = true; } } + + if (!invokeWriteFunction) { + return; + } + + Publisher result; + try { + result = writeFunction.apply(this); + } + catch (Throwable ex) { + this.writeCompletionBarrier.onError(ex); + return; + } + result.subscribe(this.writeCompletionBarrier); } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java index 1baf031b20a..8c5f8893b64 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java @@ -168,7 +168,9 @@ public class ChannelSendOperator extends Mono implements Scannable { requiredWriteSubscriber().onNext(item); return; } - // FIXME revisit in case of reentrant sync deadlock + + boolean invokeWriteFunction = false; + synchronized (this) { if (this.state == State.READY_TO_WRITE) { requiredWriteSubscriber().onNext(item); @@ -176,15 +178,7 @@ public class ChannelSendOperator extends Mono implements Scannable { else if (this.state == State.NEW) { this.item = item; this.state = State.FIRST_SIGNAL_RECEIVED; - Publisher result; - try { - result = writeFunction.apply(this); - } - catch (Throwable ex) { - this.writeCompletionBarrier.onError(ex); - return; - } - result.subscribe(this.writeCompletionBarrier); + invokeWriteFunction = true; } else { if (this.subscription != null) { @@ -193,6 +187,20 @@ public class ChannelSendOperator extends Mono implements Scannable { this.writeCompletionBarrier.onError(new IllegalStateException("Unexpected item.")); } } + + if (!invokeWriteFunction) { + return; + } + + Publisher result; + try { + result = writeFunction.apply(this); + } + catch (Throwable ex) { + this.writeCompletionBarrier.onError(ex); + return; + } + result.subscribe(this.writeCompletionBarrier); } private Subscriber requiredWriteSubscriber() { @@ -226,6 +234,9 @@ public class ChannelSendOperator extends Mono implements Scannable { requiredWriteSubscriber().onComplete(); return; } + + boolean invokeWriteFunction = false; + synchronized (this) { if (this.state == State.READY_TO_WRITE) { requiredWriteSubscriber().onComplete(); @@ -233,20 +244,26 @@ public class ChannelSendOperator extends Mono implements Scannable { else if (this.state == State.NEW) { this.completed = true; this.state = State.FIRST_SIGNAL_RECEIVED; - Publisher result; - try { - result = writeFunction.apply(this); - } - catch (Throwable ex) { - this.writeCompletionBarrier.onError(ex); - return; - } - result.subscribe(this.writeCompletionBarrier); + invokeWriteFunction = true; } else { this.completed = true; } } + + if (!invokeWriteFunction) { + return; + } + + Publisher result; + try { + result = writeFunction.apply(this); + } + catch (Throwable ex) { + this.writeCompletionBarrier.onError(ex); + return; + } + result.subscribe(this.writeCompletionBarrier); } @Override diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java index ef730a3f6ed..f2d972ce9e1 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java @@ -21,6 +21,7 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -216,6 +217,44 @@ class ChannelSendOperatorTests { .verify(Duration.ofMillis(5000)); } + @Test + void writeFunctionIsNotInvokedUnderMonitorLock() { + ChannelSendOperator operator = new ChannelSendOperator<>( + Mono.just("one"), + publisher -> { + CountDownLatch acquired = new CountDownLatch(1); + Thread t = new Thread(() -> { + synchronized (publisher) { + acquired.countDown(); + } + }); + t.start(); + + try { + if (!acquired.await(1, TimeUnit.SECONDS)) { + throw new IllegalStateException("writeFunction appears to be invoked under monitor lock"); + } + } + catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw new IllegalStateException(ex); + } + finally { + try { + t.join(1_000); + } + catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + } + } + + return Mono.empty(); + }); + + StepVerifier.create(operator) + .expectComplete() + .verify(Duration.ofSeconds(2)); + } private Mono sendOperator(Publisher source){ return new ChannelSendOperator<>(source, writer::send);