From 7b469f9c6248fcaa522ea2492317dfaf115836c2 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Fri, 28 Oct 2016 11:39:28 +0200 Subject: [PATCH] Improve WebClient error handling This commit introduces two new `WebClient` methods: `retrieveMono` and `retrieveFlux`, both of which offer direct access to the response body. More importantly, these methods publish a WebClientException if the response status code is in the 4xx or 5xx series. Issue: SPR-14852 --- .../reactive/DefaultWebClientBuilder.java | 31 ++++ .../reactive/ExchangeFilterFunctions.java | 137 ++++++++++++++---- .../web/client/reactive/WebClient.java | 39 ++++- .../client/reactive/WebClientException.java | 47 ++++++ .../ExchangeFilterFunctionsTests.java | 74 ++++++++++ .../reactive/WebClientIntegrationTests.java | 88 +++++++++++ 6 files changed, 385 insertions(+), 31 deletions(-) create mode 100644 spring-web/src/main/java/org/springframework/web/client/reactive/WebClientException.java diff --git a/spring-web/src/main/java/org/springframework/web/client/reactive/DefaultWebClientBuilder.java b/spring-web/src/main/java/org/springframework/web/client/reactive/DefaultWebClientBuilder.java index 26876ce0003..d9bf642757f 100644 --- a/spring-web/src/main/java/org/springframework/web/client/reactive/DefaultWebClientBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/client/reactive/DefaultWebClientBuilder.java @@ -18,9 +18,13 @@ package org.springframework.web.client.reactive; import java.util.logging.Level; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.http.client.reactive.ClientHttpConnector; +import org.springframework.http.client.reactive.ClientHttpResponse; +import org.springframework.http.codec.BodyExtractor; +import org.springframework.http.codec.BodyExtractors; import org.springframework.util.Assert; /** @@ -78,6 +82,33 @@ class DefaultWebClientBuilder implements WebClient.Builder { this.filter = filter; } + @Override + public Mono retrieveMono(ClientRequest request, Class elementClass) { + Assert.notNull(request, "'request' must not be null"); + Assert.notNull(elementClass, "'elementClass' must not be null"); + + return retrieve(request, BodyExtractors.toMono(elementClass)) + .then(m -> m); + } + + @Override + public Flux retrieveFlux(ClientRequest request, Class elementClass) { + Assert.notNull(request, "'request' must not be null"); + Assert.notNull(elementClass, "'elementClass' must not be null"); + + return retrieve(request, BodyExtractors.toFlux(elementClass)) + .flatMap(flux -> flux); + } + + private Mono retrieve(ClientRequest request, + BodyExtractor extractor) { + + ExchangeFilterFunction errorFilter = ExchangeFilterFunctions.clientOrServerError(); + + return errorFilter.filter(request, this::exchange) + .map(clientResponse -> clientResponse.body(extractor)); + } + @Override public Mono exchange(ClientRequest request) { Assert.notNull(request, "'request' must not be null"); diff --git a/spring-web/src/main/java/org/springframework/web/client/reactive/ExchangeFilterFunctions.java b/spring-web/src/main/java/org/springframework/web/client/reactive/ExchangeFilterFunctions.java index 47395369e1b..43a612a1928 100644 --- a/spring-web/src/main/java/org/springframework/web/client/reactive/ExchangeFilterFunctions.java +++ b/spring-web/src/main/java/org/springframework/web/client/reactive/ExchangeFilterFunctions.java @@ -18,15 +18,19 @@ package org.springframework.web.client.reactive; import java.nio.charset.StandardCharsets; import java.util.Base64; +import java.util.Optional; +import java.util.function.Function; +import java.util.function.Predicate; import reactor.core.publisher.Mono; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; import org.springframework.util.Assert; /** * Implementations of {@link ExchangeFilterFunction} that provide various useful request filter - * operations, such as basic authentication. + * operations, such as basic authentication, error handling, etc. * * @author Rob Winch * @author Arjen Poutsma @@ -34,11 +38,97 @@ import org.springframework.util.Assert; */ public abstract class ExchangeFilterFunctions { - private static final Base64.Encoder BASE_64_ENCODER = Base64.getEncoder(); + /** + * Return a filter that will publish a {@link WebClientException} when the + * {@code ClientResponse} has a 4xx status code. + * @return the {@code ExchangeFilterFunction} that publishes a {@code WebClientException} when + * the response has a client error + */ + public static ExchangeFilterFunction clientError() { + return statusError(HttpStatus::is4xxClientError); + } + + /** + * Return a filter that will publish a {@link WebClientException} if the + * {@code ClientResponse} has a 5xx status code. + * @return the {@code ExchangeFilterFunction} that publishes a {@code WebClientException} when + * the response has a server error + */ + public static ExchangeFilterFunction serverError() { + return statusError(HttpStatus::is5xxServerError); + } + /** + * Return a filter that will publish a {@link WebClientException} if the + * {@code ClientResponse} has a 4xx or 5xx status code. + * @return the {@code ExchangeFilterFunction} that publishes a {@code WebClientException} when + * the response has a client or server error + */ + public static ExchangeFilterFunction clientOrServerError() { + return clientError().andThen(serverError()); + } + + private static ExchangeFilterFunction statusError(Predicate predicate) { + Function> mapper = + clientResponse -> { + HttpStatus status = clientResponse.statusCode(); + if (predicate.test(status)) { + return Optional.of(new WebClientException( + "ClientResponse has invalid status code: " + status.value() + + " " + status.getReasonPhrase())); + } + else { + return Optional.empty(); + } + }; + + return errorMapper(mapper); + } /** - * Return a filter that adds an Authorization header for HTTP Basic. + * Return a filter that will publish a {@link WebClientException} if the response satisfies + * the given {@code predicate} function. + * @param predicate the predicate to test the response with + * @return the {@code ExchangeFilterFunction} that publishes a {@code WebClientException} when + * {@code predicate} returns {@code true} + */ + public static ExchangeFilterFunction errorPredicate(Predicate predicate) { + Assert.notNull(predicate, "'predicate' must not be null"); + + Function> mapper = + clientResponse -> { + if (predicate.test(clientResponse)) { + return Optional.of(new WebClientException( + "ClientResponse does not satisfy predicate : " + predicate)); + } + else { + return Optional.empty(); + } + }; + + return errorMapper(mapper); + } + + /** + * Return a filter that maps the response to a potential error. Exceptions returned by + * {@code mapper} will be published as signal in the {@code Mono} return value. + * @param mapper the function that maps from response to optional error + * @return the {@code ExchangeFilterFunction} that propagates the errors provided by + * {@code mapper} + */ + public static ExchangeFilterFunction errorMapper(Function> mapper) { + + Assert.notNull(mapper, "'mapper' must not be null"); + return ExchangeFilterFunction.ofResponseProcessor( + clientResponse -> { + Optional error = mapper.apply(clientResponse); + return error.isPresent() ? Mono.error(error.get()) : Mono.just(clientResponse); + }); + } + + /** + * Return a filter that adds an Authorization header for HTTP Basic Authentication. * @param username the username to use * @param password the password to use * @return the {@link ExchangeFilterFunction} that adds the Authorization header @@ -47,34 +137,23 @@ public abstract class ExchangeFilterFunctions { Assert.notNull(username, "'username' must not be null"); Assert.notNull(password, "'password' must not be null"); - return new ExchangeFilterFunction() { - - @Override - public Mono filter(ClientRequest request, ExchangeFunction next) { - String authorization = authorization(username, password); - ClientRequest authorizedRequest = ClientRequest.from(request) - .header(HttpHeaders.AUTHORIZATION, authorization) - .body(request.inserter()); - - return next.exchange(authorizedRequest); - } - - private String authorization(String username, String password) { - String credentials = username + ":" + password; - return authorization(credentials); - } - - private String authorization(String credentials) { - byte[] credentialBytes = credentials.getBytes(StandardCharsets.ISO_8859_1); - byte[] encodedBytes = BASE_64_ENCODER.encode(credentialBytes); - String encodedCredentials = new String(encodedBytes, StandardCharsets.ISO_8859_1); - return "Basic " + encodedCredentials; - } - }; - + return ExchangeFilterFunction.ofRequestProcessor( + clientRequest -> { + String authorization = authorization(username, password); + ClientRequest authorizedRequest = ClientRequest.from(clientRequest) + .header(HttpHeaders.AUTHORIZATION, authorization) + .body(clientRequest.inserter()); + return Mono.just(authorizedRequest); + }); } - + private static String authorization(String username, String password) { + String credentials = username + ":" + password; + byte[] credentialBytes = credentials.getBytes(StandardCharsets.ISO_8859_1); + byte[] encodedBytes = Base64.getEncoder().encode(credentialBytes); + String encodedCredentials = new String(encodedBytes, StandardCharsets.ISO_8859_1); + return "Basic " + encodedCredentials; + } } diff --git a/spring-web/src/main/java/org/springframework/web/client/reactive/WebClient.java b/spring-web/src/main/java/org/springframework/web/client/reactive/WebClient.java index 70fc8554d3b..b77f73803ed 100644 --- a/spring-web/src/main/java/org/springframework/web/client/reactive/WebClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/reactive/WebClient.java @@ -16,14 +16,17 @@ package org.springframework.web.client.reactive; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.http.client.reactive.ClientHttpConnector; import org.springframework.util.Assert; /** - * Reactive Web client supporting the HTTP/1.1 protocol. Main entry point is throught the - * {@link #exchange(ClientRequest)} method. + * Reactive Web client supporting the HTTP/1.1 protocol. Main entry point is through the + * {@link #exchange(ClientRequest)} method, or through the + * {@link #retrieveMono(ClientRequest, Class)} and {@link #retrieveFlux(ClientRequest, Class)} + * convenience methods. * *

For example: *

@@ -34,6 +37,11 @@ import org.springframework.util.Assert;
  *   .exchange(request)
  *   .then(response -> response.body(BodyExtractors.toMono(String.class)));
  * 
+ *

or, by using {@link #retrieveMono(ClientRequest, Class)}: + *

+ * Mono<String> result = client.retrieveMono(request, String.class);
+ * 
+ * * @author Brian Clozel * @author Arjen Poutsma * @since 5.0 @@ -43,11 +51,38 @@ public interface WebClient { /** * Exchange the given request for a response mono. Invoking this method performs the actual * HTTP request/response exchange. + *

Note that this method will not publish an exception if the response + * has a 4xx or 5xx status code; as opposed to {@link #retrieveMono(ClientRequest, Class)} and + * {@link #retrieveFlux(ClientRequest, Class)}. * @param request the request to exchange * @return the response, wrapped in a {@code Mono} */ Mono exchange(ClientRequest request); + /** + * Retrieve the body of the response as a {@code Mono}. A 4xx or 5xx status + * code in the response will result in a {@link WebClientException} published in the returned + * {@code Mono}. + * @param request the request to exchange + * @param elementClass the class of element in the {@code Mono} + * @param the element type + * @return the response body as a mono + * @see ExchangeFilterFunctions#clientOrServerError() + */ + Mono retrieveMono(ClientRequest request, Class elementClass); + + /** + * Retrieve the body of the response as a {@code Flux}. A 4xx or 5xx status + * code in the response will result in a {@link WebClientException} published in the returned + * {@code Flux}. + * @param request the request to exchange + * @param elementClass the class of element in the {@code Flux} + * @param the element type + * @return the response body as a flux + * @see ExchangeFilterFunctions#clientOrServerError() + */ + Flux retrieveFlux(ClientRequest request, Class elementClass); + /** * Create a new instance of {@code WebClient} with the given connector. This method uses diff --git a/spring-web/src/main/java/org/springframework/web/client/reactive/WebClientException.java b/spring-web/src/main/java/org/springframework/web/client/reactive/WebClientException.java new file mode 100644 index 00000000000..6039f1326a6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/client/reactive/WebClientException.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client.reactive; + +import org.springframework.core.NestedRuntimeException; + +/** + * Exception published by {@link WebClient} in case of errors. + * + * @author Arjen Poutsma + * @since 5.0 + */ +@SuppressWarnings("serial") +public class WebClientException extends NestedRuntimeException { + + /** + * Construct a new instance of {@code WebClientException} with the given message. + * @param msg the message + */ + public WebClientException(String msg) { + super(msg); + } + + /** + * Construct a new instance of {@code WebClientException} with the given message and + * exception. + * @param msg the message + * @param ex the exception + */ + public WebClientException(String msg, Throwable ex) { + super(msg, ex); + } +} diff --git a/spring-web/src/test/java/org/springframework/web/client/reactive/ExchangeFilterFunctionsTests.java b/spring-web/src/test/java/org/springframework/web/client/reactive/ExchangeFilterFunctionsTests.java index 9666ca04ee0..37ce46f3236 100644 --- a/spring-web/src/test/java/org/springframework/web/client/reactive/ExchangeFilterFunctionsTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/reactive/ExchangeFilterFunctionsTests.java @@ -16,15 +16,20 @@ package org.springframework.web.client.reactive; +import java.util.Optional; + import org.junit.Test; import reactor.core.publisher.Mono; +import reactor.test.subscriber.ScriptedSubscriber; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * @author Arjen Poutsma @@ -80,6 +85,75 @@ public class ExchangeFilterFunctionsTests { assertTrue(filterInvoked[0]); } + + @Test + public void clientNoError() throws Exception { + ClientRequest request = ClientRequest.GET("http://example.com").build(); + ClientResponse response = mock(ClientResponse.class); + when(response.statusCode()).thenReturn(HttpStatus.OK); + ExchangeFunction exchange = r -> Mono.just(response); + + ExchangeFilterFunction standardErrors = ExchangeFilterFunctions.clientError(); + + Mono result = standardErrors.filter(request, exchange); + + ScriptedSubscriber.create() + .expectNext(response) + .expectComplete() + .verify(result); + } + + @Test + public void serverError() throws Exception { + ClientRequest request = ClientRequest.GET("http://example.com").build(); + ClientResponse response = mock(ClientResponse.class); + when(response.statusCode()).thenReturn(HttpStatus.INTERNAL_SERVER_ERROR); + ExchangeFunction exchange = r -> Mono.just(response); + + ExchangeFilterFunction standardErrors = ExchangeFilterFunctions.serverError(); + + Mono result = standardErrors.filter(request, exchange); + + ScriptedSubscriber.create() + .expectError(WebClientException.class) + .verify(result); + } + + @Test + public void errorPredicate() throws Exception { + ClientRequest request = ClientRequest.GET("http://example.com").build(); + ClientResponse response = mock(ClientResponse.class); + when(response.statusCode()).thenReturn(HttpStatus.NOT_FOUND); + ExchangeFunction exchange = r -> Mono.just(response); + + ExchangeFilterFunction errorPredicate = ExchangeFilterFunctions + .errorPredicate(clientResponse -> clientResponse.statusCode().is4xxClientError()); + + Mono result = errorPredicate.filter(request, exchange); + + ScriptedSubscriber.create() + .expectError(WebClientException.class) + .verify(result); + } + + + @Test + public void errorMapperFunction() throws Exception { + ClientRequest request = ClientRequest.GET("http://example.com").build(); + ClientResponse response = mock(ClientResponse.class); + ExchangeFunction exchange = r -> Mono.just(response); + + ExchangeFilterFunction errorMapper = ExchangeFilterFunctions + .errorMapper(clientResponse -> Optional.of(new IllegalStateException())); + + Mono result = errorMapper.filter(request, exchange); + + ScriptedSubscriber.create() + .expectError(IllegalStateException.class) + .verify(result); + } + + @Test public void basicAuthentication() throws Exception { ClientRequest request = ClientRequest.GET("http://example.com").build(); diff --git a/spring-web/src/test/java/org/springframework/web/client/reactive/WebClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/reactive/WebClientIntegrationTests.java index e321e8c69a5..653775c3858 100644 --- a/spring-web/src/test/java/org/springframework/web/client/reactive/WebClientIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/reactive/WebClientIntegrationTests.java @@ -111,6 +111,50 @@ public class WebClientIntegrationTests { assertEquals("/greeting?name=Spring", recordedRequest.getPath()); } + @Test + public void retrieveMono() throws Exception { + HttpUrl baseUrl = server.url("/greeting?name=Spring"); + this.server.enqueue(new MockResponse().setBody("Hello Spring!")); + + ClientRequest request = ClientRequest.GET(baseUrl.toString()).build(); + + Mono result = this.webClient + .retrieveMono(request, String.class); + + ScriptedSubscriber + .create() + .expectNext("Hello Spring!") + .expectComplete() + .verify(result); + + RecordedRequest recordedRequest = server.takeRequest(); + assertEquals(1, server.getRequestCount()); + assertEquals("*/*", recordedRequest.getHeader(HttpHeaders.ACCEPT)); + assertEquals("/greeting?name=Spring", recordedRequest.getPath()); + } + + @Test + public void retrieveFlux() throws Exception { + HttpUrl baseUrl = server.url("/greeting?name=Spring"); + this.server.enqueue(new MockResponse().setBody("Hello Spring!")); + + ClientRequest request = ClientRequest.GET(baseUrl.toString()).build(); + + Flux result = this.webClient + .retrieveFlux(request, String.class); + + ScriptedSubscriber + .create() + .expectNext("Hello Spring!") + .expectComplete() + .verify(result); + + RecordedRequest recordedRequest = server.takeRequest(); + assertEquals(1, server.getRequestCount()); + assertEquals("*/*", recordedRequest.getHeader(HttpHeaders.ACCEPT)); + assertEquals("/greeting?name=Spring", recordedRequest.getPath()); + } + @Test public void jsonString() throws Exception { HttpUrl baseUrl = server.url("/json"); @@ -274,6 +318,50 @@ public class WebClientIntegrationTests { assertEquals("/greeting?name=Spring", recordedRequest.getPath()); } + @Test + public void retrieveNotFound() throws Exception { + HttpUrl baseUrl = server.url("/greeting?name=Spring"); + this.server.enqueue(new MockResponse().setResponseCode(404) + .setHeader("Content-Type", "text/plain").setBody("Not Found")); + + ClientRequest request = ClientRequest.GET(baseUrl.toString()).build(); + + Mono result = this.webClient + .retrieveMono(request, String.class); + + ScriptedSubscriber + .create() + .expectError(WebClientException.class) + .verify(result, Duration.ofSeconds(3)); + + RecordedRequest recordedRequest = server.takeRequest(); + assertEquals(1, server.getRequestCount()); + assertEquals("*/*", recordedRequest.getHeader(HttpHeaders.ACCEPT)); + assertEquals("/greeting?name=Spring", recordedRequest.getPath()); + } + + @Test + public void retrieveServerError() throws Exception { + HttpUrl baseUrl = server.url("/greeting?name=Spring"); + this.server.enqueue(new MockResponse().setResponseCode(500) + .setHeader("Content-Type", "text/plain").setBody("Not Found")); + + ClientRequest request = ClientRequest.GET(baseUrl.toString()).build(); + + Mono result = this.webClient + .retrieveMono(request, String.class); + + ScriptedSubscriber + .create() + .expectError(WebClientException.class) + .verify(result, Duration.ofSeconds(3)); + + RecordedRequest recordedRequest = server.takeRequest(); + assertEquals(1, server.getRequestCount()); + assertEquals("*/*", recordedRequest.getHeader(HttpHeaders.ACCEPT)); + assertEquals("/greeting?name=Spring", recordedRequest.getPath()); + } + @Test public void filter() throws Exception { HttpUrl baseUrl = server.url("/greeting?name=Spring");