diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponse.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponse.java index 927fcdf205d..0931bf21407 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponse.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponse.java @@ -203,7 +203,7 @@ class DefaultClientResponse implements ClientResponse { return bytes; }) .defaultIfEmpty(EMPTY) - .onErrorReturn(IllegalStateException.class::isInstance, EMPTY) + .onErrorReturn(ex -> !(ex instanceof Error), EMPTY) .map(bodyBytes -> { HttpRequest request = this.requestSupplier.get(); Charset charset = headers().contentType().map(MimeType::getCharset).orElse(null); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java index aebc7760ac0..efaf758ef43 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java @@ -612,7 +612,9 @@ class DefaultWebClient implements WebClient { public Mono> toBodilessEntity() { return this.responseMono.flatMap(response -> WebClientUtils.mapToEntity(response, handleBodyMono(response, Mono.empty())) - .flatMap(entity -> response.releaseBody().thenReturn(entity)) + .flatMap(entity -> response.releaseBody() + .onErrorResume(WebClientUtils.WRAP_EXCEPTION_PREDICATE, exceptionWrappingFunction(response)) + .thenReturn(entity)) ); } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java index e8a082a6b62..e7fa002ff51 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java @@ -18,11 +18,15 @@ package org.springframework.web.reactive.function.client; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.io.UncheckedIOException; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import java.net.ServerSocket; +import java.net.Socket; import java.net.URI; import java.nio.charset.StandardCharsets; import java.nio.file.Files; @@ -31,6 +35,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.function.Consumer; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -64,7 +69,9 @@ import org.springframework.http.client.reactive.ClientHttpConnector; import org.springframework.http.client.reactive.HttpComponentsClientHttpConnector; import org.springframework.http.client.reactive.JettyClientHttpConnector; import org.springframework.http.client.reactive.ReactorClientHttpConnector; +import org.springframework.util.SocketUtils; import org.springframework.web.reactive.function.BodyExtractors; +import org.springframework.web.reactive.function.client.WebClient.ResponseSpec; import org.springframework.web.testfixture.xml.Pojo; import static org.assertj.core.api.Assertions.assertThat; @@ -83,7 +90,7 @@ class WebClientIntegrationTests { @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.METHOD) - @ParameterizedTest(name = "[{index}] webClient [{0}]") + @ParameterizedTest(name = "[{index}] {displayName} [{0}]") @MethodSource("arguments") @interface ParameterizedWebClientTest { } @@ -113,7 +120,9 @@ class WebClientIntegrationTests { @AfterEach void shutdown() throws IOException { - this.server.shutdown(); + if (server != null) { + this.server.shutdown(); + } } @@ -1209,6 +1218,65 @@ class WebClientIntegrationTests { .verify(); } + @ParameterizedWebClientTest + void malformedResponseChunksOnBodilessEntity(ClientHttpConnector connector) { + Mono result = doMalformedChunkedResponseTest(connector, ResponseSpec::toBodilessEntity); + StepVerifier.create(result) + .expectErrorSatisfies(throwable -> { + assertThat(throwable).isInstanceOf(WebClientException.class); + WebClientException ex = (WebClientException) throwable; + assertThat(ex.getCause()).isInstanceOf(IOException.class); + }) + .verify(); + } + + @ParameterizedWebClientTest + void malformedResponseChunksOnEntityWithBody(ClientHttpConnector connector) { + Mono result = doMalformedChunkedResponseTest(connector, spec -> spec.toEntity(String.class)); + StepVerifier.create(result) + .expectErrorSatisfies(throwable -> { + assertThat(throwable).isInstanceOf(WebClientException.class); + WebClientException ex = (WebClientException) throwable; + assertThat(ex.getCause()).isInstanceOf(IOException.class); + }) + .verify(); + } + + private Mono doMalformedChunkedResponseTest( + ClientHttpConnector connector, Function> handler) { + + int port = SocketUtils.findAvailableTcpPort(); + + Thread serverThread = new Thread(() -> { + // No way to simulate a malformed chunked response through MockWebServer. + try (ServerSocket serverSocket = new ServerSocket(port)) { + Socket socket = serverSocket.accept(); + InputStream is = socket.getInputStream(); + + //noinspection ResultOfMethodCallIgnored + is.read(new byte[4096]); + + OutputStream os = socket.getOutputStream(); + os.write("HTTP/1.1 200 OK\r\n".getBytes(StandardCharsets.UTF_8)); + os.write("Transfer-Encoding: chunked\r\n".getBytes(StandardCharsets.UTF_8)); + os.write("\r\n".getBytes(StandardCharsets.UTF_8)); + os.write("lskdu018973t09sylgasjkfg1][]'./.sdlv".getBytes(StandardCharsets.UTF_8)); + socket.close(); + } + catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + + serverThread.start(); + + WebClient client = WebClient.builder() + .clientConnector(connector) + .baseUrl("http://localhost:" + port) + .build(); + + return handler.apply(client.post().retrieve()); + } private void prepareResponse(Consumer consumer) { MockResponse response = new MockResponse(); @@ -1252,5 +1320,4 @@ class WebClientIntegrationTests { this.containerValue = containerValue; } } - }