diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java index e8a082a6b62..4878ca78c2f 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java @@ -18,11 +18,15 @@ package org.springframework.web.reactive.function.client; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.io.UncheckedIOException; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import java.net.ServerSocket; +import java.net.Socket; import java.net.URI; import java.nio.charset.StandardCharsets; import java.nio.file.Files; @@ -31,16 +35,21 @@ import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.function.Consumer; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; +import okhttp3.mockwebserver.SocketPolicy; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.springframework.util.SocketUtils; +import org.springframework.web.reactive.function.client.WebClient.ResponseSpec; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.netty.http.client.HttpClient; @@ -83,7 +92,7 @@ class WebClientIntegrationTests { @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.METHOD) - @ParameterizedTest(name = "[{index}] webClient [{0}]") + @ParameterizedTest(name = "[{index}] {displayName} [{0}]") @MethodSource("arguments") @interface ParameterizedWebClientTest { } @@ -113,7 +122,9 @@ class WebClientIntegrationTests { @AfterEach void shutdown() throws IOException { - this.server.shutdown(); + if (server != null) { + this.server.shutdown(); + } } @@ -1209,6 +1220,135 @@ class WebClientIntegrationTests { .verify(); } + static Stream socketFaultArguments() { + Stream.Builder argumentsBuilder = Stream.builder(); + arguments().forEach(arg -> { + argumentsBuilder.accept(Arguments.of(arg, SocketPolicy.DISCONNECT_AT_START)); + argumentsBuilder.accept(Arguments.of(arg, SocketPolicy.DISCONNECT_DURING_REQUEST_BODY)); + argumentsBuilder.accept(Arguments.of(arg, SocketPolicy.DISCONNECT_AFTER_REQUEST)); + }); + return argumentsBuilder.build(); + } + + @ParameterizedTest(name = "[{index}] {displayName} [{0}, {1}]") + @MethodSource("socketFaultArguments") + void prematureClosureFault(ClientHttpConnector connector, SocketPolicy socketPolicy) { + startServer(connector); + + prepareResponse(response -> response + .setSocketPolicy(socketPolicy) + .setStatus("HTTP/1.1 200 OK") + .setHeader("Response-Header-1", "value 1") + .setHeader("Response-Header-2", "value 2") + .setBody("{\"message\": \"Hello, World!\"}")); + + String uri = "/test"; + Mono result = this.webClient + .post() + .uri(uri) + // Random non-empty body to allow us to interrupt. + .bodyValue("{\"action\": \"Say hello!\"}") + .retrieve() + .bodyToMono(String.class); + + StepVerifier.create(result) + .expectErrorSatisfies(throwable -> { + assertThat(throwable).isInstanceOf(WebClientRequestException.class); + WebClientRequestException ex = (WebClientRequestException) throwable; + // Varies between connector providers. + assertThat(ex.getCause()).isInstanceOf(IOException.class); + }) + .verify(); + } + + static Stream malformedResponseChunkArguments() { + return Stream.of( + Arguments.of(new ReactorClientHttpConnector(), true), + Arguments.of(new JettyClientHttpConnector(), true), + // Apache injects the Transfer-Encoding header for us, and complains with an exception if we also + // add it. The other two connectors do not add the header at all. We need this header for the test + // case to work correctly. + Arguments.of(new HttpComponentsClientHttpConnector(), false) + ); + } + + @ParameterizedTest(name = "[{index}] {displayName} [{0}, {1}]") + @MethodSource("malformedResponseChunkArguments") + void malformedResponseChunksOnBodilessEntity(ClientHttpConnector connector, boolean addTransferEncodingHeader) { + Mono result = doMalformedResponseChunks(connector, addTransferEncodingHeader, ResponseSpec::toBodilessEntity); + + StepVerifier.create(result) + .expectErrorSatisfies(throwable -> { + assertThat(throwable).isInstanceOf(WebClientException.class); + WebClientException ex = (WebClientException) throwable; + assertThat(ex.getCause()).isInstanceOf(IOException.class); + }) + .verify(); + } + + @ParameterizedTest(name = "[{index}] {displayName} [{0}, {1}]") + @MethodSource("malformedResponseChunkArguments") + void malformedResponseChunksOnEntityWithBody(ClientHttpConnector connector, boolean addTransferEncodingHeader) { + Mono result = doMalformedResponseChunks(connector, addTransferEncodingHeader, spec -> spec.toEntity(String.class)); + + StepVerifier.create(result) + .expectErrorSatisfies(throwable -> { + assertThat(throwable).isInstanceOf(WebClientException.class); + WebClientException ex = (WebClientException) throwable; + assertThat(ex.getCause()).isInstanceOf(IOException.class); + }) + .verify(); + } + + private Mono doMalformedResponseChunks( + ClientHttpConnector connector, + boolean addTransferEncodingHeader, + Function> responseHandler + ) { + int port = SocketUtils.findAvailableTcpPort(); + + Thread serverThread = new Thread(() -> { + // This exists separately to the main mock server, as I had a really hard time getting that to send the + // chunked responses correctly, flushing the socket each time. This was the only way I was able to replicate + // the issue of the client not handling malformed response chunks correctly. + try (ServerSocket serverSocket = new ServerSocket(port)) { + Socket socket = serverSocket.accept(); + InputStream is = socket.getInputStream(); + + //noinspection ResultOfMethodCallIgnored + is.read(new byte[4096]); + + OutputStream os = socket.getOutputStream(); + os.write("HTTP/1.1 200 OK\r\n".getBytes(StandardCharsets.UTF_8)); + os.write("Transfer-Encoding: chunked\r\n".getBytes(StandardCharsets.UTF_8)); + os.write("\r\n".getBytes(StandardCharsets.UTF_8)); + os.write("lskdu018973t09sylgasjkfg1][]'./.sdlv".getBytes(StandardCharsets.UTF_8)); + socket.close(); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + + serverThread.setDaemon(true); + serverThread.start(); + + ResponseSpec spec = WebClient + .builder() + .clientConnector(connector) + .baseUrl("http://localhost:" + port) + .build() + .post() + .headers(headers -> { + if (addTransferEncodingHeader) { + headers.add(HttpHeaders.TRANSFER_ENCODING, "chunked"); + } + }) + .retrieve(); + + return responseHandler + .apply(spec) + .doFinally(signal -> serverThread.stop()); + } private void prepareResponse(Consumer consumer) { MockResponse response = new MockResponse(); @@ -1252,5 +1392,4 @@ class WebClientIntegrationTests { this.containerValue = containerValue; } } - }