diff --git a/spring-web/src/main/java/org/springframework/web/client/reactive/DefaultWebClientBuilder.java b/spring-web/src/main/java/org/springframework/web/client/reactive/DefaultWebClientBuilder.java index 26876ce0003..a6eb33ea1dc 100644 --- a/spring-web/src/main/java/org/springframework/web/client/reactive/DefaultWebClientBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/client/reactive/DefaultWebClientBuilder.java @@ -95,6 +95,14 @@ class DefaultWebClientBuilder implements WebClient.Builder { this.strategies)); } + @Override + public WebClient filter(ExchangeFilterFunction filter) { + Assert.notNull(filter, "'filter' must not be null"); + + ExchangeFilterFunction composedFilter = filter.andThen(this.filter); + + return new DefaultWebClient(this.clientHttpConnector, this.strategies, composedFilter); + } } private class NoOpFilter implements ExchangeFilterFunction { diff --git a/spring-web/src/main/java/org/springframework/web/client/reactive/WebClient.java b/spring-web/src/main/java/org/springframework/web/client/reactive/WebClient.java index afdef85611d..38d138b63e0 100644 --- a/spring-web/src/main/java/org/springframework/web/client/reactive/WebClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/reactive/WebClient.java @@ -50,6 +50,15 @@ public interface WebClient extends ExchangeFunction { @Override Mono exchange(ClientRequest request); + /** + * Filters this client with the given {@code ExchangeFilterFunction}, resulting in a filtered + * {@code WebClient}. + * @param filterFunction the filter to apply to this client + * @return the filtered client + * @see ExchangeFilterFunction#apply(ExchangeFunction) + */ + WebClient filter(ExchangeFilterFunction filterFunction); + /** * Create a new instance of {@code WebClient} with the given connector. This method uses diff --git a/spring-web/src/test/java/org/springframework/web/client/reactive/WebClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/reactive/WebClientIntegrationTests.java index 1597c628c06..a2c773c0cf8 100644 --- a/spring-web/src/test/java/org/springframework/web/client/reactive/WebClientIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/reactive/WebClientIntegrationTests.java @@ -268,7 +268,7 @@ public class WebClientIntegrationTests { } @Test - public void filter() throws Exception { + public void buildFilter() throws Exception { HttpUrl baseUrl = server.url("/greeting?name=Spring"); this.server.enqueue(new MockResponse().setHeader("Content-Type", "text/plain").setBody("Hello Spring!")); @@ -296,6 +296,35 @@ public class WebClientIntegrationTests { } + @Test + public void filter() throws Exception { + HttpUrl baseUrl = server.url("/greeting?name=Spring"); + this.server.enqueue(new MockResponse().setHeader("Content-Type", "text/plain").setBody("Hello Spring!")); + + ExchangeFilterFunction filter = (request, next) -> { + ClientRequest filteredRequest = ClientRequest.from(request) + .header("foo", "bar").build(); + return next.exchange(filteredRequest); + }; + WebClient client = WebClient.create(new ReactorClientHttpConnector()); + WebClient filteredClient = client.filter(filter); + + ClientRequest request = ClientRequest.GET(baseUrl.toString()).build(); + + Mono result = filteredClient.exchange(request) + .then(response -> response.body(toMono(String.class))); + + StepVerifier.create(result) + .expectNext("Hello Spring!") + .expectComplete() + .verify(); + + RecordedRequest recordedRequest = server.takeRequest(); + assertEquals(1, server.getRequestCount()); + assertEquals("bar", recordedRequest.getHeader("foo")); + + } + @After public void tearDown() throws Exception { this.server.shutdown();