From 9c48d63082e371da4d0870f6e222db35a4412362 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 1 Apr 2019 17:14:49 -0400 Subject: [PATCH] Release cached item in ChannelSendOperator 1. If the write Subscriber cancels with the item cached, release it. 2. If the write Publisher emits an error while the item is cached, when the write Subscriber subscribes, release the cached item and emit the error signal. Closes gh-22720 --- .../server/reactive/ChannelSendOperator.java | 36 ++++++-- .../reactive/ChannelSendOperatorTests.java | 91 +++++++++++++++++-- 2 files changed, 113 insertions(+), 14 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java index 8315c46cb0d..ebea1c16371 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,8 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; import reactor.util.context.Context; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -279,13 +281,20 @@ public class ChannelSendOperator extends Mono implements Scannable { } private boolean emitCachedSignals() { - if (this.item != null) { - requiredWriteSubscriber().onNext(this.item); - } if (this.error != null) { - requiredWriteSubscriber().onError(this.error); + try { + requiredWriteSubscriber().onError(this.error); + } + finally { + releaseCachedItem(); + } return true; } + T item = this.item; + this.item = null; + if (item != null) { + requiredWriteSubscriber().onNext(item); + } if (this.completed) { requiredWriteSubscriber().onComplete(); return true; @@ -298,7 +307,22 @@ public class ChannelSendOperator extends Mono implements Scannable { Subscription s = this.subscription; if (s != null) { this.subscription = null; - s.cancel(); + try { + s.cancel(); + } + finally { + releaseCachedItem(); + } + } + } + + private void releaseCachedItem() { + synchronized (this) { + Object item = this.item; + if (item instanceof DataBuffer) { + DataBufferUtils.release((DataBuffer) item); + } + this.item = null; } } diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java index 5adeeab9898..7e2cd98cdc1 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,25 +16,29 @@ package org.springframework.http.server.reactive; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import io.netty.buffer.ByteBufAllocator; import org.junit.Before; import org.junit.Test; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Signal; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.LeakAwareDataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBufferFactory; + +import static org.junit.Assert.*; /** * @author Rossen Stoyanchev @@ -50,9 +54,6 @@ public class ChannelSendOperatorTests { this.writer = new OneByOneAsyncWriter(); } - private Mono sendOperator(Publisher source){ - return new ChannelSendOperator<>(source, writer::send); - } @Test public void errorBeforeFirstItem() throws Exception { @@ -130,6 +131,66 @@ public class ChannelSendOperatorTests { assertSame(error, this.writer.error); } + @Test // gh-22720 + public void cancelWhileItemCached() { + NettyDataBufferFactory delegate = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT); + LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(delegate); + + ChannelSendOperator operator = new ChannelSendOperator<>( + Mono.fromCallable(() -> { + DataBuffer dataBuffer = bufferFactory.allocateBuffer(); + dataBuffer.write("foo", StandardCharsets.UTF_8); + return dataBuffer; + }), + publisher -> { + ZeroDemandSubscriber subscriber = new ZeroDemandSubscriber(); + publisher.subscribe(subscriber); + return Mono.never(); + }); + + BaseSubscriber subscriber = new BaseSubscriber() {}; + operator.subscribe(subscriber); + subscriber.cancel(); + + bufferFactory.checkForLeaks(); + } + + @Test // gh-22720 + public void errorWhileItemCached() { + NettyDataBufferFactory delegate = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT); + LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(delegate); + ZeroDemandSubscriber writeSubscriber = new ZeroDemandSubscriber(); + + ChannelSendOperator operator = new ChannelSendOperator<>( + Flux.create(sink -> { + DataBuffer dataBuffer = bufferFactory.allocateBuffer(); + dataBuffer.write("foo", StandardCharsets.UTF_8); + sink.next(dataBuffer); + sink.error(new IllegalStateException("err")); + }), + publisher -> { + publisher.subscribe(writeSubscriber); + return Mono.never(); + }); + + + operator.subscribe(new BaseSubscriber() {}); + try { + writeSubscriber.signalDemand(1); // Let cached signals ("foo" and error) be published.. + } + catch (Throwable ex) { + assertNotNull(ex.getCause()); + assertEquals("err", ex.getCause().getMessage()); + } + + bufferFactory.checkForLeaks(); + } + + + private Mono sendOperator(Publisher source){ + return new ChannelSendOperator<>(source, writer::send); + } + private static class OneByOneAsyncWriter { @@ -182,4 +243,18 @@ public class ChannelSendOperatorTests { } } + + private static class ZeroDemandSubscriber extends BaseSubscriber { + + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // Just subscribe without requesting + } + + public void signalDemand(long demand) { + upstream().request(demand); + } + } + }