diff --git a/spring-web/src/main/java/org/springframework/http/codec/EncoderHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/EncoderHttpMessageWriter.java index 5a63145b4be..db0f0ec04c3 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/EncoderHttpMessageWriter.java +++ b/spring-web/src/main/java/org/springframework/http/codec/EncoderHttpMessageWriter.java @@ -128,7 +128,8 @@ public class EncoderHttpMessageWriter implements HttpMessageWriter { message.getHeaders().setContentLength(buffer.readableByteCount()); return message.writeWith(Mono.just(buffer) .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release)); - }); + }) + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); } if (isStreamingMediaType(contentType)) { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java index b6c7f98037d..272c12ddffa 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java @@ -203,10 +203,28 @@ public abstract class AbstractServerHttpResponse implements ServerHttpResponse { // We must resolve value first however, for a chance to handle potential error. if (body instanceof Mono) { return ((Mono) body) - .flatMap(buffer -> doCommit(() -> - writeWithInternal(Mono.fromCallable(() -> buffer) - .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release)))) - .doOnError(t -> getHeaders().clearContentHeaders()); + .flatMap(buffer -> { + AtomicReference subscribed = new AtomicReference<>(false); + return doCommit( + () -> { + try { + return writeWithInternal(Mono.fromCallable(() -> buffer) + .doOnSubscribe(s -> subscribed.set(true)) + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release)); + } + catch (Throwable ex) { + return Mono.error(ex); + } + }) + .doOnError(ex -> DataBufferUtils.release(buffer)) + .doOnCancel(() -> { + if (!subscribed.get()) { + DataBufferUtils.release(buffer); + } + }); + }) + .doOnError(t -> getHeaders().clearContentHeaders()) + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); } else { return new ChannelSendOperator<>(body, inner -> doCommit(() -> writeWithInternal(inner))) diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java index cdb4225381a..de08959af00 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java @@ -19,7 +19,9 @@ package org.springframework.http.server.reactive; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.function.Consumer; import java.util.function.Supplier; @@ -27,14 +29,22 @@ import org.junit.jupiter.api.Test; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.netty.channel.AbortedException; import reactor.test.StepVerifier; +import org.springframework.core.ResolvableType; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DefaultDataBuffer; import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.testfixture.io.buffer.LeakAwareDataBufferFactory; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseCookie; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest; +import org.springframework.web.testfixture.http.server.reactive.MockServerHttpResponse; import static org.assertj.core.api.Assertions.assertThat; @@ -176,6 +186,24 @@ public class ServerHttpResponseTests { }); } + @Test // gh-26232 + void monoResponseShouldNotLeakIfCancelled() { + LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(); + MockServerHttpRequest request = MockServerHttpRequest.get("/").build(); + MockServerHttpResponse response = new MockServerHttpResponse(bufferFactory); + response.setWriteHandler(flux -> { + throw AbortedException.beforeSend(); + }); + + HttpMessageWriter messageWriter = new EncoderHttpMessageWriter<>(new Jackson2JsonEncoder()); + Mono result = messageWriter.write(Mono.just(Collections.singletonMap("foo", "bar")), + ResolvableType.forClass(Mono.class), ResolvableType.forClass(Map.class), null, + request, response, Collections.emptyMap()); + + StepVerifier.create(result).expectError(AbortedException.class).verify(); + + bufferFactory.checkForLeaks(); + } private DefaultDataBuffer wrap(String a) { return new DefaultDataBufferFactory().wrap(ByteBuffer.wrap(a.getBytes(StandardCharsets.UTF_8)));