diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java index 8315c46cb0d..ebea1c16371 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,8 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; import reactor.util.context.Context; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -279,13 +281,20 @@ public class ChannelSendOperator extends Mono implements Scannable { } private boolean emitCachedSignals() { - if (this.item != null) { - requiredWriteSubscriber().onNext(this.item); - } if (this.error != null) { - requiredWriteSubscriber().onError(this.error); + try { + requiredWriteSubscriber().onError(this.error); + } + finally { + releaseCachedItem(); + } return true; } + T item = this.item; + this.item = null; + if (item != null) { + requiredWriteSubscriber().onNext(item); + } if (this.completed) { requiredWriteSubscriber().onComplete(); return true; @@ -298,7 +307,22 @@ public class ChannelSendOperator extends Mono implements Scannable { Subscription s = this.subscription; if (s != null) { this.subscription = null; - s.cancel(); + try { + s.cancel(); + } + finally { + releaseCachedItem(); + } + } + } + + private void releaseCachedItem() { + synchronized (this) { + Object item = this.item; + if (item instanceof DataBuffer) { + DataBufferUtils.release((DataBuffer) item); + } + this.item = null; } } diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java index 5adeeab9898..7e2cd98cdc1 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,25 +16,29 @@ package org.springframework.http.server.reactive; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import io.netty.buffer.ByteBufAllocator; import org.junit.Before; import org.junit.Test; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Signal; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.LeakAwareDataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBufferFactory; + +import static org.junit.Assert.*; /** * @author Rossen Stoyanchev @@ -50,9 +54,6 @@ public class ChannelSendOperatorTests { this.writer = new OneByOneAsyncWriter(); } - private Mono sendOperator(Publisher source){ - return new ChannelSendOperator<>(source, writer::send); - } @Test public void errorBeforeFirstItem() throws Exception { @@ -130,6 +131,66 @@ public class ChannelSendOperatorTests { assertSame(error, this.writer.error); } + @Test // gh-22720 + public void cancelWhileItemCached() { + NettyDataBufferFactory delegate = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT); + LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(delegate); + + ChannelSendOperator operator = new ChannelSendOperator<>( + Mono.fromCallable(() -> { + DataBuffer dataBuffer = bufferFactory.allocateBuffer(); + dataBuffer.write("foo", StandardCharsets.UTF_8); + return dataBuffer; + }), + publisher -> { + ZeroDemandSubscriber subscriber = new ZeroDemandSubscriber(); + publisher.subscribe(subscriber); + return Mono.never(); + }); + + BaseSubscriber subscriber = new BaseSubscriber() {}; + operator.subscribe(subscriber); + subscriber.cancel(); + + bufferFactory.checkForLeaks(); + } + + @Test // gh-22720 + public void errorWhileItemCached() { + NettyDataBufferFactory delegate = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT); + LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(delegate); + ZeroDemandSubscriber writeSubscriber = new ZeroDemandSubscriber(); + + ChannelSendOperator operator = new ChannelSendOperator<>( + Flux.create(sink -> { + DataBuffer dataBuffer = bufferFactory.allocateBuffer(); + dataBuffer.write("foo", StandardCharsets.UTF_8); + sink.next(dataBuffer); + sink.error(new IllegalStateException("err")); + }), + publisher -> { + publisher.subscribe(writeSubscriber); + return Mono.never(); + }); + + + operator.subscribe(new BaseSubscriber() {}); + try { + writeSubscriber.signalDemand(1); // Let cached signals ("foo" and error) be published.. + } + catch (Throwable ex) { + assertNotNull(ex.getCause()); + assertEquals("err", ex.getCause().getMessage()); + } + + bufferFactory.checkForLeaks(); + } + + + private Mono sendOperator(Publisher source){ + return new ChannelSendOperator<>(source, writer::send); + } + private static class OneByOneAsyncWriter { @@ -182,4 +243,18 @@ public class ChannelSendOperatorTests { } } + + private static class ZeroDemandSubscriber extends BaseSubscriber { + + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // Just subscribe without requesting + } + + public void signalDemand(long demand) { + upstream().request(demand); + } + } + }