diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/AbstractDataBufferAllocatingTestCase.java b/spring-core/src/test/java/org/springframework/core/io/buffer/AbstractDataBufferAllocatingTestCase.java index 87404da6802..c6be1b7537c 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/AbstractDataBufferAllocatingTestCase.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/AbstractDataBufferAllocatingTestCase.java @@ -17,6 +17,8 @@ package org.springframework.core.io.buffer; import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; import java.util.Arrays; import java.util.List; import java.util.function.Consumer; @@ -36,7 +38,11 @@ import org.springframework.core.io.buffer.support.DataBufferTestUtils; import static org.junit.Assert.*; /** + * Base class for tests that read or write data buffers with a rule to check + * that allocated buffers have been released. + * * @author Arjen Poutsma + * @author Rossen Stoyanchev */ @RunWith(Parameterized.class) public abstract class AbstractDataBufferAllocatingTestCase { @@ -61,6 +67,7 @@ public abstract class AbstractDataBufferAllocatingTestCase { @Rule public final Verifier leakDetector = new LeakDetector(); + protected DataBuffer createDataBuffer(int capacity) { return this.bufferFactory.allocateBuffer(capacity); } @@ -85,30 +92,45 @@ public abstract class AbstractDataBufferAllocatingTestCase { }; } - - private class LeakDetector extends Verifier { - - @Override - protected void verify() throws Throwable { - if (bufferFactory instanceof NettyDataBufferFactory) { - ByteBufAllocator byteBufAllocator = - ((NettyDataBufferFactory) bufferFactory).getByteBufAllocator(); - if (byteBufAllocator instanceof PooledByteBufAllocator) { - PooledByteBufAllocator pooledByteBufAllocator = - (PooledByteBufAllocator) byteBufAllocator; - PooledByteBufAllocatorMetric metric = pooledByteBufAllocator.metric(); - long allocations = calculateAllocations(metric.directArenas()) + - calculateAllocations(metric.heapArenas()); - assertTrue("ByteBuf leak detected: " + allocations + - " allocations were not released", allocations == 0); - } + /** + * Wait until allocations are at 0, or the given duration elapses. + */ + protected void waitForDataBufferRelease(Duration duration) throws InterruptedException { + Instant start = Instant.now(); + while (Instant.now().isBefore(start.plus(duration))) { + try { + verifyAllocations(); + break; + } + catch (AssertionError ex) { + // ignore; } + Thread.sleep(50); } + } - private long calculateAllocations(List metrics) { - return metrics.stream().mapToLong(PoolArenaMetric::numActiveAllocations).sum(); + private void verifyAllocations() { + if (this.bufferFactory instanceof NettyDataBufferFactory) { + ByteBufAllocator allocator = ((NettyDataBufferFactory) this.bufferFactory).getByteBufAllocator(); + if (allocator instanceof PooledByteBufAllocator) { + PooledByteBufAllocatorMetric metric = ((PooledByteBufAllocator) allocator).metric(); + long total = getAllocations(metric.directArenas()) + getAllocations(metric.heapArenas()); + assertEquals("ByteBuf Leak: " + total + " unreleased allocations", 0, total); + } } + } + + private static long getAllocations(List metrics) { + return metrics.stream().mapToLong(PoolArenaMetric::numActiveAllocations).sum(); + } + + protected class LeakDetector extends Verifier { + + @Override + public void verify() { + AbstractDataBufferAllocatingTestCase.this.verifyAllocations(); + } } } diff --git a/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpResponse.java index adbf6b1b5f2..c050e2cdda9 100644 --- a/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpResponse.java @@ -17,6 +17,7 @@ package org.springframework.http.client.reactive; import java.util.Collection; +import java.util.concurrent.atomic.AtomicBoolean; import reactor.core.publisher.Flux; import reactor.ipc.netty.http.client.HttpClientResponse; @@ -26,6 +27,7 @@ import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseCookie; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -43,6 +45,8 @@ class ReactorClientHttpResponse implements ClientHttpResponse { private final HttpClientResponse response; + private final AtomicBoolean bodyConsumed = new AtomicBoolean(); + public ReactorClientHttpResponse(HttpClientResponse response) { this.response = response; @@ -53,6 +57,13 @@ class ReactorClientHttpResponse implements ClientHttpResponse { @Override public Flux getBody() { return response.receive() + .doOnSubscribe(s -> + // WebClient's onStatus handling tries to drain the body, which may + // have also been done by application code in the onStatus callback. + // That relies on the 2nd subscriber being rejected but FluxReceive + // isn't consistent in doing so and may hang without completion. + Assert.state(this.bodyConsumed.compareAndSet(false, true), + "The client response body can only be consumed once.")) .map(buf -> { buf.retain(); return dataBufferFactory.wrap(buf); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java index d6d50c5e007..ab01f00d495 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java @@ -433,12 +433,22 @@ class DefaultWebClient implements WebClient { private > T bodyToPublisher(ClientResponse response, T bodyPublisher, Function, T> errorFunction) { - return this.statusHandlers.stream() - .filter(statusHandler -> statusHandler.test(response.statusCode())) - .findFirst() - .map(statusHandler -> statusHandler.apply(response)) - .map(errorFunction::apply) - .orElse(bodyPublisher); + for (StatusHandler handler : this.statusHandlers) { + if (handler.test(response.statusCode())) { + Mono exMono = handler.apply(response); + exMono = exMono.flatMap(ex -> drainBody(response, ex)); + exMono = exMono.onErrorResume(ex -> drainBody(response, ex)); + return errorFunction.apply(exMono); + } + } + return bodyPublisher; + } + + @SuppressWarnings("unchecked") + private Mono drainBody(ClientResponse response, Throwable ex) { + // Ensure the body is drained, even if the StatusHandler didn't consume it, + // but ignore errors in case it did consume it. + return (Mono) response.bodyToMono(Void.class).onErrorMap(ex2 -> ex).thenReturn(ex); } private static Mono createResponseException(ClientResponse response) { diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java index cc2105b162e..01043d58969 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java @@ -596,6 +596,9 @@ public interface WebClient { * {@link WebClientResponseException} when the response status code is 4xx or 5xx. * @param statusPredicate a predicate that indicates whether {@code exceptionFunction} * applies + *

NOTE: if the response is expected to have content, + * the exceptionFunction should consume it. If not, the content will be + * automatically drained to ensure resources are released. * @param exceptionFunction the function that returns the exception * @return this builder */ diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientDataBufferAllocatingTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientDataBufferAllocatingTests.java new file mode 100644 index 00000000000..075666f322d --- /dev/null +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientDataBufferAllocatingTests.java @@ -0,0 +1,141 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.reactive.function.client; + +import java.time.Duration; +import java.util.function.Function; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelOption; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.io.buffer.AbstractDataBufferAllocatingTestCase; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.client.reactive.ReactorClientHttpConnector; + +import static org.junit.Assert.*; + +/** + * WebClient integration tests focusing on data buffer management. + * @author Rossen Stoyanchev + */ +public class WebClientDataBufferAllocatingTests extends AbstractDataBufferAllocatingTestCase { + + private static final Duration DELAY = Duration.ofSeconds(5); + + + private MockWebServer server; + + private WebClient webClient; + + + @Before + public void setUp() { + this.server = new MockWebServer(); + this.webClient = WebClient + .builder() + .clientConnector(initConnector()) + .baseUrl(this.server.url("/").toString()) + .build(); + } + + private ReactorClientHttpConnector initConnector() { + if (bufferFactory instanceof NettyDataBufferFactory) { + ByteBufAllocator allocator = ((NettyDataBufferFactory) bufferFactory).getByteBufAllocator(); + return new ReactorClientHttpConnector(builder -> builder.option(ChannelOption.ALLOCATOR, allocator)); + } + else { + return new ReactorClientHttpConnector(); + } + } + + @After + public void shutDown() throws InterruptedException { + waitForDataBufferRelease(Duration.ofSeconds(2)); + } + + + @Test + public void bodyToMonoVoid() { + + this.server.enqueue(new MockResponse() + .setResponseCode(201) + .setHeader("Content-Type", "application/json") + .setChunkedBody("{\"foo\" : {\"bar\" : \"123\", \"baz\" : \"456\"}}", 5)); + + Mono mono = this.webClient.get() + .uri("/json").accept(MediaType.APPLICATION_JSON) + .retrieve() + .bodyToMono(Void.class); + + StepVerifier.create(mono).expectComplete().verify(Duration.ofSeconds(3)); + assertEquals(1, this.server.getRequestCount()); + } + + + @Test + public void onStatusWithBodyNotConsumed() { + RuntimeException ex = new RuntimeException("response error"); + testOnStatus(ex, response -> Mono.just(ex)); + } + + @Test + public void onStatusWithBodyConsumed() { + RuntimeException ex = new RuntimeException("response error"); + testOnStatus(ex, response -> response.bodyToMono(Void.class).thenReturn(ex)); + } + + @Test // SPR-17473 + public void onStatusWithMonoErrorAndBodyNotConsumed() { + RuntimeException ex = new RuntimeException("response error"); + testOnStatus(ex, response -> Mono.error(ex)); + } + + @Test + public void onStatusWithMonoErrorAndBodyConsumed() { + RuntimeException ex = new RuntimeException("response error"); + testOnStatus(ex, response -> response.bodyToMono(Void.class).then(Mono.error(ex))); + } + + private void testOnStatus(Throwable expected, + Function> exceptionFunction) { + + HttpStatus errorStatus = HttpStatus.BAD_GATEWAY; + + this.server.enqueue(new MockResponse() + .setResponseCode(errorStatus.value()) + .setHeader("Content-Type", "application/json") + .setChunkedBody("{\"error\" : {\"status\" : 502, \"message\" : \"Bad gateway.\"}}", 5)); + + Mono mono = this.webClient.get() + .uri("/json").accept(MediaType.APPLICATION_JSON) + .retrieve() + .onStatus(status -> status.equals(errorStatus), exceptionFunction) + .bodyToMono(String.class); + + StepVerifier.create(mono).expectErrorSatisfies(actual -> assertSame(expected, actual)).verify(DELAY); + assertEquals(1, this.server.getRequestCount()); + } + +} diff --git a/src/docs/asciidoc/web/webflux-webclient.adoc b/src/docs/asciidoc/web/webflux-webclient.adoc index 85a8b3fb501..fa652874c05 100644 --- a/src/docs/asciidoc/web/webflux-webclient.adoc +++ b/src/docs/asciidoc/web/webflux-webclient.adoc @@ -71,6 +71,10 @@ By default, responses with 4xx or 5xx status codes result in an error of type .bodyToMono(Person.class); ---- +When `onStatus` is used, if the response is expected to have content, then the `onStatus` +callback should consume it. If not, the content will be automatically drained to ensure +resources are released. +