From 2f9bd6e075facbf13edd629a98da88115c130b98 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Thu, 6 Jul 2017 15:55:10 +0200 Subject: [PATCH] Add local error handling in WebClient.retrieve This commit introduces a way to customize the WebClientExceptions, as thrown by WebClient.ResponseSpec.bodyTo[Mono|Flux]. The first customization will override the defaults, additional customizations are simply tried in order. Issue: SPR-15724 --- .../org/springframework/http/HttpStatus.java | 10 +++ .../function/client/DefaultWebClient.java | 65 +++++++++++++++---- .../reactive/function/client/WebClient.java | 16 +++++ .../client/WebClientIntegrationTests.java | 21 ++++++ 4 files changed, 100 insertions(+), 12 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/http/HttpStatus.java b/spring-web/src/main/java/org/springframework/http/HttpStatus.java index ac65c21c7ec..780376406cb 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpStatus.java +++ b/spring-web/src/main/java/org/springframework/http/HttpStatus.java @@ -465,6 +465,16 @@ public enum HttpStatus { return Series.SERVER_ERROR.equals(series()); } + /** + * Whether this status code is in the HTTP series + * {@link org.springframework.http.HttpStatus.Series#CLIENT_ERROR} or + * {@link org.springframework.http.HttpStatus.Series#SERVER_ERROR}. + * This is a shortcut for checking the value of {@link #series()}. + */ + public boolean isError() { + return is4xxClientError() || is5xxServerError(); + } + /** * Returns the HTTP status series of this status code. * @see HttpStatus.Series diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java index 6ef4b089a95..c39217adddb 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java @@ -21,12 +21,15 @@ import java.nio.charset.Charset; import java.time.ZoneId; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; +import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Predicate; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -365,11 +368,52 @@ class DefaultWebClient implements WebClient { private static class DefaultResponseSpec implements ResponseSpec { + private static final Function> DEFAULT_STATUS_HANDLER = + clientResponse -> { + HttpStatus statusCode = clientResponse.statusCode(); + if (statusCode.isError()) { + return Optional.of(new WebClientException( + "ClientResponse has erroneous status code: " + statusCode.value() + + " " + statusCode.getReasonPhrase())); + } else { + return Optional.empty(); + } + }; + private final Mono responseMono; + private List>> statusHandlers = + new ArrayList<>(1); + DefaultResponseSpec(Mono responseMono) { this.responseMono = responseMono; + this.statusHandlers.add(DEFAULT_STATUS_HANDLER); + } + + @Override + public ResponseSpec onStatus(Predicate statusPredicate, + Function exceptionFunction) { + + Assert.notNull(statusPredicate, "'statusPredicate' must not be null"); + Assert.notNull(exceptionFunction, "'exceptionFunction' must not be null"); + + if (this.statusHandlers.size() == 1 && this.statusHandlers.get(0) == DEFAULT_STATUS_HANDLER) { + this.statusHandlers.clear(); + } + + Function> statusHandler = + clientResponse -> { + if (statusPredicate.test(clientResponse.statusCode())) { + return Optional.of(exceptionFunction.apply(clientResponse)); + } + else { + return Optional.empty(); + } + }; + this.statusHandlers.add(statusHandler); + + return this; } @Override @@ -388,18 +432,15 @@ class DefaultWebClient implements WebClient { private > T bodyToPublisher(ClientResponse response, BodyExtractor extractor, - Function errorFunction) { - - HttpStatus status = response.statusCode(); - if (status.is4xxClientError() || status.is5xxServerError()) { - WebClientException ex = new WebClientException( - "ClientResponse has erroneous status code: " + status.value() + - " " + status.getReasonPhrase()); - return errorFunction.apply(ex); - } - else { - return response.body(extractor); - } + Function errorFunction) { + + return this.statusHandlers.stream() + .map(statusHandler -> statusHandler.apply(response)) + .filter(Optional::isPresent) + .findFirst() + .map(Optional::get) + .map(errorFunction::apply) + .orElse(response.body(extractor)); } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java index c38786f8108..d092a4fe0ee 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Predicate; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -30,6 +31,7 @@ import reactor.core.publisher.Mono; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.http.client.reactive.ClientHttpConnector; @@ -537,6 +539,20 @@ public interface WebClient { interface ResponseSpec { + /** + * Register a custom error function that gets invoked when the given {@link HttpStatus} + * predicate applies. The exception returned from the function will be returned from + * {@link #bodyToMono(Class)} and {@link #bodyToFlux(Class)}. + *

By default, an error handler is register that throws a {@link WebClientException} + * when the response status code is 4xx or 5xx. + * @param statusPredicate a predicate that indicates whether {@code exceptionFunction} + * applies + * @param exceptionFunction the function that returns the exception + * @return this builder + */ + ResponseSpec onStatus(Predicate statusPredicate, + Function exceptionFunction); + /** * Extract the body to a {@code Mono}. If the response has status code 4xx or 5xx, the * {@code Mono} will contain a {@link WebClientException}. diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java index 3d66178fcb9..7f93b6f2c3d 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java @@ -390,6 +390,27 @@ public class WebClientIntegrationTests { Assert.assertEquals("/greeting?name=Spring", recordedRequest.getPath()); } + @Test + public void retrieveBodyToCustomStatusHandler() throws Exception { + this.server.enqueue(new MockResponse().setResponseCode(500) + .setHeader("Content-Type", "text/plain").setBody("Internal Server error")); + + Mono result = this.webClient.get() + .uri("/greeting?name=Spring") + .retrieve() + .onStatus(HttpStatus::is5xxServerError, response -> new MyException("500 error!")) + .bodyToMono(String.class); + + StepVerifier.create(result) + .expectError(MyException.class) + .verify(Duration.ofSeconds(3)); + + RecordedRequest recordedRequest = server.takeRequest(); + Assert.assertEquals(1, server.getRequestCount()); + Assert.assertEquals("*/*", recordedRequest.getHeader(HttpHeaders.ACCEPT)); + Assert.assertEquals("/greeting?name=Spring", recordedRequest.getPath()); + } + @Test public void retrieveToEntityNotFound() throws Exception { this.server.enqueue(new MockResponse().setResponseCode(404)