diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java index f301a9691ea..de2aeb899e8 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.IntPredicate; import java.util.function.Predicate; import java.util.function.Supplier; @@ -416,6 +417,8 @@ class DefaultWebClient implements WebClient { private static class DefaultResponseSpec implements ResponseSpec { + private static final IntPredicate STATUS_CODE_ERROR = value -> value >= 400; + private final Mono responseMono; private final Supplier requestSupplier; @@ -425,17 +428,30 @@ class DefaultWebClient implements WebClient { DefaultResponseSpec(Mono responseMono, Supplier requestSupplier) { this.responseMono = responseMono; this.requestSupplier = requestSupplier; - this.statusHandlers.add(new StatusHandler(HttpStatus::isError, ClientResponse::createException)); + this.statusHandlers.add(new StatusHandler(STATUS_CODE_ERROR, ClientResponse::createException)); } @Override public ResponseSpec onStatus(Predicate statusPredicate, Function> exceptionFunction) { + return onRawStatus(toIntPredicate(statusPredicate), exceptionFunction); + } - Assert.notNull(statusPredicate, "StatusPredicate must not be null"); + private static IntPredicate toIntPredicate(Predicate predicate) { + return value -> { + HttpStatus status = HttpStatus.resolve(value); + return (status != null) && predicate.test(status); + }; + } + + @Override + public ResponseSpec onRawStatus(IntPredicate statusCodePredicate, + Function> exceptionFunction) { + + Assert.notNull(statusCodePredicate, "StatusCodePredicate must not be null"); Assert.notNull(exceptionFunction, "Function must not be null"); - this.statusHandlers.add(0, new StatusHandler(statusPredicate, exceptionFunction)); + this.statusHandlers.add(0, new StatusHandler(statusCodePredicate, exceptionFunction)); return this; } @@ -452,17 +468,12 @@ class DefaultWebClient implements WebClient { } private Mono handleBodyMono(ClientResponse response, Mono bodyPublisher) { - if (HttpStatus.resolve(response.rawStatusCode()) != null) { - Mono result = statusHandlers(response); - if (result != null) { - return result.switchIfEmpty(bodyPublisher); - } - else { - return bodyPublisher; - } + Mono result = statusHandlers(response); + if (result != null) { + return result.switchIfEmpty(bodyPublisher); } else { - return response.createException().flatMap(Mono::error); + return bodyPublisher; } } @@ -479,24 +490,20 @@ class DefaultWebClient implements WebClient { } private Publisher handleBodyFlux(ClientResponse response, Flux bodyPublisher) { - if (HttpStatus.resolve(response.rawStatusCode()) != null) { - Mono result = statusHandlers(response); - if (result != null) { - return result.flux().switchIfEmpty(bodyPublisher); - } - else { - return bodyPublisher; - } + Mono result = statusHandlers(response); + if (result != null) { + return result.flux().switchIfEmpty(bodyPublisher); } else { - return response.createException().flatMap(Mono::error); + return bodyPublisher; } } @Nullable private Mono statusHandlers(ClientResponse response) { + int statusCode = response.rawStatusCode(); for (StatusHandler handler : this.statusHandlers) { - if (handler.test(response.statusCode())) { + if (handler.test(statusCode)) { Mono exMono; try { exMono = handler.apply(response); @@ -508,7 +515,7 @@ class DefaultWebClient implements WebClient { } Mono result = exMono.flatMap(Mono::error); HttpRequest request = this.requestSupplier.get(); - return insertCheckpoint(result, response.statusCode(), request); + return insertCheckpoint(result, statusCode, request); } } return null; @@ -522,10 +529,10 @@ class DefaultWebClient implements WebClient { .onErrorResume(ex2 -> Mono.empty()).thenReturn(ex); } - private Mono insertCheckpoint(Mono result, HttpStatus status, HttpRequest request) { + private Mono insertCheckpoint(Mono result, int statusCode, HttpRequest request) { String httpMethod = request.getMethodValue(); URI uri = request.getURI(); - String description = status + " from " + httpMethod + " " + uri + " [DefaultWebClient]"; + String description = statusCode + " from " + httpMethod + " " + uri + " [DefaultWebClient]"; return result.checkpoint(description); } @@ -558,24 +565,25 @@ class DefaultWebClient implements WebClient { private static class StatusHandler { - private final Predicate predicate; + private final IntPredicate predicate; private final Function> exceptionFunction; - public StatusHandler(Predicate predicate, + public StatusHandler(IntPredicate predicate, Function> exceptionFunction) { this.predicate = predicate; this.exceptionFunction = exceptionFunction; } - public boolean test(HttpStatus status) { + public boolean test(int status) { return this.predicate.test(status); } public Mono apply(ClientResponse response) { return this.exceptionFunction.apply(response); } + } } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java index 5f95d47a1d0..16e08df70a5 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.IntPredicate; import java.util.function.Predicate; import org.reactivestreams.Publisher; @@ -691,6 +692,24 @@ public interface WebClient { ResponseSpec onStatus(Predicate statusPredicate, Function> exceptionFunction); + /** + * Register a custom error function that gets invoked when the given raw status code + * predicate applies. The exception returned from the function will be returned from + * {@link #bodyToMono(Class)} and {@link #bodyToFlux(Class)}. + *

By default, an error handler is registered that throws a + * {@link WebClientResponseException} when the response status code is 4xx or 5xx. + * @param statusCodePredicate a predicate of the raw status code that indicates + * whether {@code exceptionFunction} applies. + *

NOTE: if the response is expected to have content, + * the exceptionFunction should consume it. If not, the content will be + * automatically drained to ensure resources are released. + * @param exceptionFunction the function that returns the exception + * @return this builder + * @since 5.1.9 + */ + ResponseSpec onRawStatus(IntPredicate statusCodePredicate, + Function> exceptionFunction); + /** * Extract the body to a {@code Mono}. By default, if the response has status code 4xx or * 5xx, the {@code Mono} will contain a {@link WebClientException}. This can be overridden diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java index bfa362341ed..d78d4d7bbc1 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java @@ -730,6 +730,28 @@ public class WebClientIntegrationTests { }); } + @Test + public void shouldApplyCustomRawStatusHandler() { + prepareResponse(response -> response.setResponseCode(500) + .setHeader("Content-Type", "text/plain").setBody("Internal Server error")); + + Mono result = this.webClient.get() + .uri("/greeting?name=Spring") + .retrieve() + .onRawStatus(value -> value >= 500 && value < 600, response -> Mono.just(new MyException("500 error!"))) + .bodyToMono(String.class); + + StepVerifier.create(result) + .expectError(MyException.class) + .verify(Duration.ofSeconds(3)); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("*/*"); + assertThat(request.getPath()).isEqualTo("/greeting?name=Spring"); + }); + } + @Test public void shouldApplyCustomStatusHandlerParameterizedTypeReference() { prepareResponse(response -> response.setResponseCode(500)