diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/WiretapConnector.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/WiretapConnector.java index bfb88b57bfe..4a0aa9516a5 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/WiretapConnector.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/WiretapConnector.java @@ -31,6 +31,7 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.HttpMethod; import org.springframework.http.client.reactive.ClientHttpConnector; @@ -196,11 +197,21 @@ class WiretapConnector implements ClientHttpConnector { // 1. Mock server never consumed request body (for example, error before read) // 2. FluxExchangeResult: getResponseBodyContent called before getResponseBody //noinspection ConstantConditions - (this.publisher != null ? this.publisher : this.publisherNested) - .onErrorMap(ex -> new IllegalStateException( - "Content has not been consumed, and " + - "an error was raised while attempting to produce it.", ex)) - .subscribe(); + if (this.publisher != null) { + this.publisher.doOnNext(DataBufferUtils::release) + .onErrorMap(ex -> new IllegalStateException( + "Content has not been consumed, and " + + "an error was raised while attempting to produce it.", ex)) + .subscribe(); + } + else if (this.publisherNested != null) { + this.publisherNested + .map(pub -> Flux.from(pub).doOnNext(DataBufferUtils::release)) + .onErrorMap(ex -> new IllegalStateException( + "Content has not been consumed, and " + + "an error was raised while attempting to produce it.", ex)) + .subscribe(); + } } return this.content.asMono(); }); diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/WiretapConnectorTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/WiretapConnectorTests.java index 4e0e44379f2..24866a8df0c 100644 --- a/spring-test/src/test/java/org/springframework/test/web/reactive/server/WiretapConnectorTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/WiretapConnectorTests.java @@ -17,11 +17,18 @@ package org.springframework.test.web.reactive.server; import java.net.URI; +import java.nio.charset.StandardCharsets; import java.time.Duration; +import io.netty.buffer.PooledByteBufAllocator; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.core.testfixture.io.buffer.LeakAwareDataBufferFactory; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.client.reactive.ClientHttpConnector; @@ -30,6 +37,7 @@ import org.springframework.http.client.reactive.ClientHttpResponse; import org.springframework.mock.http.client.reactive.MockClientHttpRequest; import org.springframework.mock.http.client.reactive.MockClientHttpResponse; import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFunction; import org.springframework.web.reactive.function.client.ExchangeFunctions; @@ -44,11 +52,19 @@ import static org.assertj.core.api.Assertions.assertThat; */ public class WiretapConnectorTests { + private final LeakAwareDataBufferFactory bufferFactory = + new LeakAwareDataBufferFactory(new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT)); + + @AfterEach + void tearDown() { + this.bufferFactory.checkForLeaks(); + } + @Test public void captureAndClaim() { ClientHttpRequest request = new MockClientHttpRequest(HttpMethod.GET, "/test"); ClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); - ClientHttpConnector connector = (method, uri, fn) -> fn.apply(request).then(Mono.just(response)); + ClientHttpConnector connector = createConnector(request, response); ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("/test")) .header(WebTestClient.WEBTESTCLIENT_REQUEST_ID, "1").build(); @@ -62,4 +78,47 @@ public class WiretapConnectorTests { assertThat(result.getUrl().toString()).isEqualTo("/test"); } + @Test + void shouldReleaseBuffers() { + MockClientHttpRequest request = new MockClientHttpRequest(HttpMethod.GET, "/test"); + MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); + response.setBody(Flux.just(toDataBuffer("Hello Spring"))); + ClientHttpConnector connector = createConnector(request, response); + + ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("/test")) + .header(WebTestClient.WEBTESTCLIENT_REQUEST_ID, "1").build(); + + WiretapConnector wiretapConnector = new WiretapConnector(connector, null); + ExchangeFunction function = ExchangeFunctions.create(wiretapConnector); + function.exchange(clientRequest).block(ofMillis(0)); + ExchangeResult result = wiretapConnector.getExchangeResult("1", null, Duration.ZERO); + result.getResponseBodyContent(); + } + + @Test + void shouldReleaseBuffersOnlyOnce() { + MockClientHttpRequest request = new MockClientHttpRequest(HttpMethod.GET, "/test"); + MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); + response.setBody(Flux.just(toDataBuffer("Hello Spring"), toDataBuffer("Hello Spring"), toDataBuffer("Hello Spring"), toDataBuffer("Hello Spring"))); + ClientHttpConnector connector = createConnector(request, response); + + ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("/test")) + .header(WebTestClient.WEBTESTCLIENT_REQUEST_ID, "1").build(); + + WiretapConnector wiretapConnector = new WiretapConnector(connector, null); + ExchangeFunction function = ExchangeFunctions.create(wiretapConnector); + function.exchange(clientRequest).flatMap(ClientResponse::releaseBody).block(ofMillis(0)); + ExchangeResult result = wiretapConnector.getExchangeResult("1", null, Duration.ZERO); + result.getResponseBodyContent(); + } + + private ClientHttpConnector createConnector(ClientHttpRequest request, ClientHttpResponse response) { + return (method, uri, fn) -> fn.apply(request).then(Mono.just(response)); + } + + private DataBuffer toDataBuffer(String s) { + DataBuffer buffer = this.bufferFactory.allocateBuffer(256); + buffer.write(s.getBytes(StandardCharsets.UTF_8)); + return buffer; + } }