diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClient.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClient.java index 98bf4361bcc..ef6bde25d26 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClient.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClient.java @@ -22,6 +22,7 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.ZonedDateTime; import java.util.Arrays; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -46,12 +47,17 @@ import org.springframework.test.util.AssertionErrors; import org.springframework.test.util.JsonExpectationsHelper; import org.springframework.test.util.XmlExpectationsHelper; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MimeType; import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyInserter; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; -import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.ExchangeFunction; import org.springframework.web.util.UriBuilder; +import org.springframework.web.util.UriBuilderFactory; /** * Default implementation of {@link WebTestClient}. @@ -61,30 +67,42 @@ import org.springframework.web.util.UriBuilder; */ class DefaultWebTestClient implements WebTestClient { - private final WebClient webClient; - private final WiretapConnector wiretapConnector; - private final Duration timeout; + private final ExchangeFunction exchangeFunction; + + private final UriBuilderFactory uriBuilderFactory; + + @Nullable + private final HttpHeaders defaultHeaders; + + @Nullable + private final MultiValueMap defaultCookies; + + private final Duration responseTimeout; private final DefaultWebTestClientBuilder builder; private final AtomicLong requestIndex = new AtomicLong(); - DefaultWebTestClient(WebClient.Builder clientBuilder, ClientHttpConnector connector, - @Nullable Duration timeout, DefaultWebTestClientBuilder webTestClientBuilder) { + DefaultWebTestClient(ClientHttpConnector connector, + Function exchangeFactory, UriBuilderFactory uriBuilderFactory, + @Nullable HttpHeaders headers, @Nullable MultiValueMap cookies, + @Nullable Duration responseTimeout, DefaultWebTestClientBuilder clientBuilder) { - Assert.notNull(clientBuilder, "WebClient.Builder is required"); this.wiretapConnector = new WiretapConnector(connector); - this.webClient = clientBuilder.clientConnector(this.wiretapConnector).build(); - this.timeout = (timeout != null ? timeout : Duration.ofSeconds(5)); - this.builder = webTestClientBuilder; + this.exchangeFunction = exchangeFactory.apply(this.wiretapConnector); + this.uriBuilderFactory = uriBuilderFactory; + this.defaultHeaders = headers; + this.defaultCookies = cookies; + this.responseTimeout = (responseTimeout != null ? responseTimeout : Duration.ofSeconds(5)); + this.builder = clientBuilder; } - private Duration getTimeout() { - return this.timeout; + private Duration getResponseTimeout() { + return this.responseTimeout; } @@ -124,12 +142,12 @@ class DefaultWebTestClient implements WebTestClient { } @Override - public RequestBodyUriSpec method(HttpMethod method) { - return methodInternal(method); + public RequestBodyUriSpec method(HttpMethod httpMethod) { + return methodInternal(httpMethod); } - private RequestBodyUriSpec methodInternal(HttpMethod method) { - return new DefaultRequestBodyUriSpec(this.webClient.method(method)); + private RequestBodyUriSpec methodInternal(HttpMethod httpMethod) { + return new DefaultRequestBodyUriSpec(httpMethod); } @Override @@ -145,154 +163,180 @@ class DefaultWebTestClient implements WebTestClient { private class DefaultRequestBodyUriSpec implements RequestBodyUriSpec { - private final WebClient.RequestBodyUriSpec bodySpec; + private final HttpMethod httpMethod; + + @Nullable + private URI uri; + + private final HttpHeaders headers; + + @Nullable + private MultiValueMap cookies; + + @Nullable + private BodyInserter inserter; + + private final Map attributes = new LinkedHashMap<>(4); + + @Nullable + private Consumer httpRequestConsumer; @Nullable private String uriTemplate; private final String requestId; - DefaultRequestBodyUriSpec(WebClient.RequestBodyUriSpec spec) { - this.bodySpec = spec; + DefaultRequestBodyUriSpec(HttpMethod httpMethod) { + this.httpMethod = httpMethod; this.requestId = String.valueOf(requestIndex.incrementAndGet()); - this.bodySpec.header(WebTestClient.WEBTESTCLIENT_REQUEST_ID, this.requestId); + this.headers = new HttpHeaders(); + this.headers.add(WebTestClient.WEBTESTCLIENT_REQUEST_ID, this.requestId); } @Override public RequestBodySpec uri(String uriTemplate, Object... uriVariables) { - this.bodySpec.uri(uriTemplate, uriVariables); this.uriTemplate = uriTemplate; - return this; + return uri(uriBuilderFactory.expand(uriTemplate, uriVariables)); } @Override public RequestBodySpec uri(String uriTemplate, Map uriVariables) { - this.bodySpec.uri(uriTemplate, uriVariables); this.uriTemplate = uriTemplate; - return this; + return uri(uriBuilderFactory.expand(uriTemplate, uriVariables)); } @Override public RequestBodySpec uri(Function uriFunction) { - this.bodySpec.uri(uriFunction); this.uriTemplate = null; - return this; + return uri(uriFunction.apply(uriBuilderFactory.builder())); } @Override public RequestBodySpec uri(URI uri) { - this.bodySpec.uri(uri); this.uriTemplate = null; + this.uri = uri; return this; } + private HttpHeaders getHeaders() { + return this.headers; + } + + private MultiValueMap getCookies() { + if (this.cookies == null) { + this.cookies = new LinkedMultiValueMap<>(3); + } + return this.cookies; + } + @Override public RequestBodySpec header(String headerName, String... headerValues) { - this.bodySpec.header(headerName, headerValues); + for (String headerValue : headerValues) { + getHeaders().add(headerName, headerValue); + } return this; } @Override public RequestBodySpec headers(Consumer headersConsumer) { - this.bodySpec.headers(headersConsumer); + headersConsumer.accept(getHeaders()); return this; } @Override public RequestBodySpec attribute(String name, Object value) { - this.bodySpec.attribute(name, value); + this.attributes.put(name, value); return this; } @Override - public RequestBodySpec attributes( - Consumer> attributesConsumer) { - this.bodySpec.attributes(attributesConsumer); + public RequestBodySpec attributes(Consumer> attributesConsumer) { + attributesConsumer.accept(this.attributes); return this; } @Override public RequestBodySpec accept(MediaType... acceptableMediaTypes) { - this.bodySpec.accept(acceptableMediaTypes); + getHeaders().setAccept(Arrays.asList(acceptableMediaTypes)); return this; } @Override public RequestBodySpec acceptCharset(Charset... acceptableCharsets) { - this.bodySpec.acceptCharset(acceptableCharsets); + getHeaders().setAcceptCharset(Arrays.asList(acceptableCharsets)); return this; } @Override public RequestBodySpec contentType(MediaType contentType) { - this.bodySpec.contentType(contentType); + getHeaders().setContentType(contentType); return this; } @Override public RequestBodySpec contentLength(long contentLength) { - this.bodySpec.contentLength(contentLength); + getHeaders().setContentLength(contentLength); return this; } @Override public RequestBodySpec cookie(String name, String value) { - this.bodySpec.cookie(name, value); + getCookies().add(name, value); return this; } @Override - public RequestBodySpec cookies( - Consumer> cookiesConsumer) { - this.bodySpec.cookies(cookiesConsumer); + public RequestBodySpec cookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(getCookies()); return this; } @Override public RequestBodySpec ifModifiedSince(ZonedDateTime ifModifiedSince) { - this.bodySpec.ifModifiedSince(ifModifiedSince); + getHeaders().setIfModifiedSince(ifModifiedSince); return this; } @Override public RequestBodySpec ifNoneMatch(String... ifNoneMatches) { - this.bodySpec.ifNoneMatch(ifNoneMatches); + getHeaders().setIfNoneMatch(Arrays.asList(ifNoneMatches)); return this; } @Override public RequestHeadersSpec bodyValue(Object body) { - this.bodySpec.bodyValue(body); + this.inserter = BodyInserters.fromValue(body); return this; } @Override - public > RequestHeadersSpec body(S publisher, Class elementClass) { - this.bodySpec.body(publisher, elementClass); + public > RequestHeadersSpec body( + P publisher, ParameterizedTypeReference elementTypeRef) { + this.inserter = BodyInserters.fromPublisher(publisher, elementTypeRef); return this; } @Override - public > RequestHeadersSpec body(S publisher, ParameterizedTypeReference elementTypeRef) { - this.bodySpec.body(publisher, elementTypeRef); + public > RequestHeadersSpec body(P publisher, Class elementClass) { + this.inserter = BodyInserters.fromPublisher(publisher, elementClass); return this; } @Override public RequestHeadersSpec body(Object producer, Class elementClass) { - this.bodySpec.body(producer, elementClass); + this.inserter = BodyInserters.fromProducer(producer, elementClass); return this; } @Override public RequestHeadersSpec body(Object producer, ParameterizedTypeReference elementTypeRef) { - this.bodySpec.body(producer, elementTypeRef); + this.inserter = BodyInserters.fromProducer(producer, elementTypeRef); return this; } @Override public RequestHeadersSpec body(BodyInserter inserter) { - this.bodySpec.body(inserter); + this.inserter = inserter; return this; } @@ -304,10 +348,57 @@ class DefaultWebTestClient implements WebTestClient { @Override public ResponseSpec exchange() { - ClientResponse clientResponse = this.bodySpec.exchange().block(getTimeout()); - Assert.state(clientResponse != null, "No ClientResponse"); - ExchangeResult result = wiretapConnector.getExchangeResult(this.requestId, this.uriTemplate, getTimeout()); - return new DefaultResponseSpec(result, clientResponse, getTimeout()); + ClientRequest request = (this.inserter != null ? + initRequestBuilder().body(this.inserter).build() : + initRequestBuilder().build()); + + ClientResponse response = exchangeFunction.exchange(request).block(getResponseTimeout()); + Assert.state(response != null, "No ClientResponse"); + + ExchangeResult result = wiretapConnector.getExchangeResult( + this.requestId, this.uriTemplate, getResponseTimeout()); + + return new DefaultResponseSpec(result, response, getResponseTimeout()); + } + + private ClientRequest.Builder initRequestBuilder() { + ClientRequest.Builder builder = ClientRequest.create(this.httpMethod, initUri()) + .headers(headers -> headers.addAll(initHeaders())) + .cookies(cookies -> cookies.addAll(initCookies())) + .attributes(attributes -> attributes.putAll(this.attributes)); + if (this.httpRequestConsumer != null) { + builder.httpRequest(this.httpRequestConsumer); + } + return builder; + } + + private URI initUri() { + return (this.uri != null ? this.uri : uriBuilderFactory.expand("")); + } + + private HttpHeaders initHeaders() { + if (CollectionUtils.isEmpty(defaultHeaders)) { + return this.headers; + } + HttpHeaders result = new HttpHeaders(); + result.putAll(defaultHeaders); + result.putAll(this.headers); + return result; + } + + private MultiValueMap initCookies() { + if (CollectionUtils.isEmpty(this.cookies)) { + return (defaultCookies != null ? defaultCookies : new LinkedMultiValueMap<>()); + } + else if (CollectionUtils.isEmpty(defaultCookies)) { + return this.cookies; + } + else { + MultiValueMap result = new LinkedMultiValueMap<>(); + result.putAll(defaultCookies); + result.putAll(this.cookies); + return result; + } } } diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClientBuilder.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClientBuilder.java index 69279499fe7..8e1904f3114 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClientBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClientBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,20 +17,30 @@ package org.springframework.test.web.reactive.server; import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.function.Consumer; +import java.util.function.Function; import org.springframework.http.HttpHeaders; import org.springframework.http.client.reactive.ClientHttpConnector; +import org.springframework.http.client.reactive.HttpComponentsClientHttpConnector; +import org.springframework.http.client.reactive.JettyClientHttpConnector; import org.springframework.http.client.reactive.ReactorClientHttpConnector; import org.springframework.http.codec.ClientCodecConfigurer; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.ExchangeFunction; +import org.springframework.web.reactive.function.client.ExchangeFunctions; import org.springframework.web.reactive.function.client.ExchangeStrategies; -import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; +import org.springframework.web.util.DefaultUriBuilderFactory; import org.springframework.web.util.UriBuilderFactory; /** @@ -41,7 +51,21 @@ import org.springframework.web.util.UriBuilderFactory; */ class DefaultWebTestClientBuilder implements WebTestClient.Builder { - private final WebClient.Builder webClientBuilder; + private static final boolean reactorClientPresent; + + private static final boolean jettyClientPresent; + + private static final boolean httpComponentsClientPresent; + + static { + ClassLoader loader = DefaultWebTestClientBuilder.class.getClassLoader(); + reactorClientPresent = ClassUtils.isPresent("reactor.netty.http.client.HttpClient", loader); + jettyClientPresent = ClassUtils.isPresent("org.eclipse.jetty.client.HttpClient", loader); + httpComponentsClientPresent = + ClassUtils.isPresent("org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient", loader) && + ClassUtils.isPresent("org.apache.hc.core5.reactive.ReactiveDataConsumer", loader); + } + @Nullable private final WebHttpHandlerBuilder httpHandlerBuilder; @@ -49,136 +73,243 @@ class DefaultWebTestClientBuilder implements WebTestClient.Builder { @Nullable private final ClientHttpConnector connector; + @Nullable + private String baseUrl; + + @Nullable + private UriBuilderFactory uriBuilderFactory; + + @Nullable + private HttpHeaders defaultHeaders; + + @Nullable + private MultiValueMap defaultCookies; + + @Nullable + private List filters; + + @Nullable + private ExchangeStrategies strategies; + + @Nullable + private List> strategiesConfigurers; + @Nullable private Duration responseTimeout; - /** Connect to server via Reactor Netty. */ + /** Determine connector via classpath detection. */ DefaultWebTestClientBuilder() { - this(new ReactorClientHttpConnector()); - } - - /** Connect to server through the given connector. */ - DefaultWebTestClientBuilder(ClientHttpConnector connector) { - this(null, null, connector, null); + this(null, null); } - /** Connect to given mock server with mock request and response. */ + /** Use HttpHandlerConnector with mock server. */ DefaultWebTestClientBuilder(WebHttpHandlerBuilder httpHandlerBuilder) { - this(null, httpHandlerBuilder, null, null); + this(httpHandlerBuilder, null); } - /** Copy constructor. */ - DefaultWebTestClientBuilder(DefaultWebTestClientBuilder other) { - this(other.webClientBuilder.clone(), other.httpHandlerBuilder, other.connector, - other.responseTimeout); + /** Use given connector. */ + DefaultWebTestClientBuilder(ClientHttpConnector connector) { + this(null, connector); } - private DefaultWebTestClientBuilder(@Nullable WebClient.Builder webClientBuilder, - @Nullable WebHttpHandlerBuilder httpHandlerBuilder, @Nullable ClientHttpConnector connector, - @Nullable Duration responseTimeout) { + DefaultWebTestClientBuilder( + @Nullable WebHttpHandlerBuilder httpHandlerBuilder, @Nullable ClientHttpConnector connector) { - Assert.isTrue(httpHandlerBuilder != null || connector != null, - "Either WebHttpHandlerBuilder or ClientHttpConnector must be provided"); + Assert.isTrue(httpHandlerBuilder == null || connector == null, + "Expected WebHttpHandlerBuilder or ClientHttpConnector but not both."); - this.webClientBuilder = (webClientBuilder != null ? webClientBuilder : WebClient.builder()); - this.httpHandlerBuilder = (httpHandlerBuilder != null ? httpHandlerBuilder.clone() : null); this.connector = connector; - this.responseTimeout = responseTimeout; + this.httpHandlerBuilder = (httpHandlerBuilder != null ? httpHandlerBuilder.clone() : null); + } + + /** Copy constructor. */ + DefaultWebTestClientBuilder(DefaultWebTestClientBuilder other) { + this.httpHandlerBuilder = (other.httpHandlerBuilder != null ? other.httpHandlerBuilder.clone() : null); + this.connector = other.connector; + this.responseTimeout = other.responseTimeout; + + this.baseUrl = other.baseUrl; + this.uriBuilderFactory = other.uriBuilderFactory; + if (other.defaultHeaders != null) { + this.defaultHeaders = new HttpHeaders(); + this.defaultHeaders.putAll(other.defaultHeaders); + } + else { + this.defaultHeaders = null; + } + this.defaultCookies = (other.defaultCookies != null ? + new LinkedMultiValueMap<>(other.defaultCookies) : null); + this.filters = (other.filters != null ? new ArrayList<>(other.filters) : null); + this.strategies = other.strategies; + this.strategiesConfigurers = (other.strategiesConfigurers != null ? + new ArrayList<>(other.strategiesConfigurers) : null); } @Override public WebTestClient.Builder baseUrl(String baseUrl) { - this.webClientBuilder.baseUrl(baseUrl); + this.baseUrl = baseUrl; return this; } @Override public WebTestClient.Builder uriBuilderFactory(UriBuilderFactory uriBuilderFactory) { - this.webClientBuilder.uriBuilderFactory(uriBuilderFactory); + this.uriBuilderFactory = uriBuilderFactory; return this; } @Override - public WebTestClient.Builder defaultHeader(String headerName, String... headerValues) { - this.webClientBuilder.defaultHeader(headerName, headerValues); + public WebTestClient.Builder defaultHeader(String header, String... values) { + initHeaders().put(header, Arrays.asList(values)); return this; } @Override public WebTestClient.Builder defaultHeaders(Consumer headersConsumer) { - this.webClientBuilder.defaultHeaders(headersConsumer); + headersConsumer.accept(initHeaders()); return this; } + private HttpHeaders initHeaders() { + if (this.defaultHeaders == null) { + this.defaultHeaders = new HttpHeaders(); + } + return this.defaultHeaders; + } + @Override - public WebTestClient.Builder defaultCookie(String cookieName, String... cookieValues) { - this.webClientBuilder.defaultCookie(cookieName, cookieValues); + public WebTestClient.Builder defaultCookie(String cookie, String... values) { + initCookies().addAll(cookie, Arrays.asList(values)); return this; } @Override - public WebTestClient.Builder defaultCookies( - Consumer> cookiesConsumer) { - this.webClientBuilder.defaultCookies(cookiesConsumer); + public WebTestClient.Builder defaultCookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(initCookies()); return this; } + private MultiValueMap initCookies() { + if (this.defaultCookies == null) { + this.defaultCookies = new LinkedMultiValueMap<>(3); + } + return this.defaultCookies; + } + @Override public WebTestClient.Builder filter(ExchangeFilterFunction filter) { - this.webClientBuilder.filter(filter); + Assert.notNull(filter, "ExchangeFilterFunction must not be null"); + initFilters().add(filter); return this; } @Override public WebTestClient.Builder filters(Consumer> filtersConsumer) { - this.webClientBuilder.filters(filtersConsumer); + filtersConsumer.accept(initFilters()); return this; } + private List initFilters() { + if (this.filters == null) { + this.filters = new ArrayList<>(); + } + return this.filters; + } + @Override public WebTestClient.Builder codecs(Consumer configurer) { - this.webClientBuilder.codecs(configurer); + if (this.strategiesConfigurers == null) { + this.strategiesConfigurers = new ArrayList<>(4); + } + this.strategiesConfigurers.add(builder -> builder.codecs(configurer)); return this; } @Override public WebTestClient.Builder exchangeStrategies(ExchangeStrategies strategies) { - this.webClientBuilder.exchangeStrategies(strategies); + this.strategies = strategies; return this; } - @SuppressWarnings("deprecation") @Override + @SuppressWarnings("deprecation") public WebTestClient.Builder exchangeStrategies(Consumer configurer) { - this.webClientBuilder.exchangeStrategies(configurer); + if (this.strategiesConfigurers == null) { + this.strategiesConfigurers = new ArrayList<>(4); + } + this.strategiesConfigurers.add(configurer); return this; } @Override - public WebTestClient.Builder responseTimeout(Duration timeout) { - this.responseTimeout = timeout; + public WebTestClient.Builder apply(WebTestClientConfigurer configurer) { + configurer.afterConfigurerAdded(this, this.httpHandlerBuilder, this.connector); return this; } @Override - public WebTestClient.Builder apply(WebTestClientConfigurer configurer) { - configurer.afterConfigurerAdded(this, this.httpHandlerBuilder, this.connector); + public WebTestClient.Builder responseTimeout(Duration timeout) { + this.responseTimeout = timeout; return this; } - @Override public WebTestClient build() { ClientHttpConnector connectorToUse = this.connector; if (connectorToUse == null) { - Assert.state(this.httpHandlerBuilder != null, "No WebHttpHandlerBuilder available"); - connectorToUse = new HttpHandlerConnector(this.httpHandlerBuilder.build()); + if (this.httpHandlerBuilder != null) { + connectorToUse = new HttpHandlerConnector(this.httpHandlerBuilder.build()); + } + } + if (connectorToUse == null) { + connectorToUse = initConnector(); + } + Function exchangeFactory = connector -> { + ExchangeFunction exchange = ExchangeFunctions.create(connector, initExchangeStrategies()); + if (CollectionUtils.isEmpty(this.filters)) { + return exchange; + } + return this.filters.stream() + .reduce(ExchangeFilterFunction::andThen) + .map(filter -> filter.apply(exchange)) + .orElse(exchange); + + }; + return new DefaultWebTestClient(connectorToUse, exchangeFactory, initUriBuilderFactory(), + this.defaultHeaders != null ? HttpHeaders.readOnlyHttpHeaders(this.defaultHeaders) : null, + this.defaultCookies != null ? CollectionUtils.unmodifiableMultiValueMap(this.defaultCookies) : null, + this.responseTimeout, new DefaultWebTestClientBuilder(this)); + } + + private static ClientHttpConnector initConnector() { + if (reactorClientPresent) { + return new ReactorClientHttpConnector(); + } + else if (jettyClientPresent) { + return new JettyClientHttpConnector(); + } + else if (httpComponentsClientPresent) { + return new HttpComponentsClientHttpConnector(); } + throw new IllegalStateException("No suitable default ClientHttpConnector found"); + } - return new DefaultWebTestClient(this.webClientBuilder, - connectorToUse, this.responseTimeout, new DefaultWebTestClientBuilder(this)); + private ExchangeStrategies initExchangeStrategies() { + if (CollectionUtils.isEmpty(this.strategiesConfigurers)) { + return (this.strategies != null ? this.strategies : ExchangeStrategies.withDefaults()); + } + ExchangeStrategies.Builder builder = + (this.strategies != null ? this.strategies.mutate() : ExchangeStrategies.builder()); + this.strategiesConfigurers.forEach(configurer -> configurer.accept(builder)); + return builder.build(); } + private UriBuilderFactory initUriBuilderFactory() { + if (this.uriBuilderFactory != null) { + return this.uriBuilderFactory; + } + return (this.baseUrl != null ? + new DefaultUriBuilderFactory(this.baseUrl) : new DefaultUriBuilderFactory()); + } } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java index b8c849c29af..9431b263b4b 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java @@ -49,7 +49,6 @@ import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.BodyInserters; -import org.springframework.web.util.DefaultUriBuilderFactory; import org.springframework.web.util.UriBuilder; import org.springframework.web.util.UriBuilderFactory; @@ -85,12 +84,12 @@ class DefaultWebClient implements WebClient { private final DefaultWebClientBuilder builder; - DefaultWebClient(ExchangeFunction exchangeFunction, @Nullable UriBuilderFactory factory, + DefaultWebClient(ExchangeFunction exchangeFunction, UriBuilderFactory uriBuilderFactory, @Nullable HttpHeaders defaultHeaders, @Nullable MultiValueMap defaultCookies, @Nullable Consumer> defaultRequest, DefaultWebClientBuilder builder) { this.exchangeFunction = exchangeFunction; - this.uriBuilderFactory = (factory != null ? factory : new DefaultUriBuilderFactory()); + this.uriBuilderFactory = uriBuilderFactory; this.defaultHeaders = defaultHeaders; this.defaultCookies = defaultCookies; this.defaultRequest = defaultRequest; @@ -251,13 +250,6 @@ class DefaultWebClient implements WebClient { return this; } - @Override - public RequestBodySpec httpRequest(Consumer requestConsumer) { - this.httpRequestConsumer = (this.httpRequestConsumer != null ? - this.httpRequestConsumer.andThen(requestConsumer) : requestConsumer); - return this; - } - @Override public DefaultRequestBodyUriSpec accept(MediaType... acceptableMediaTypes) { getHeaders().setAccept(Arrays.asList(acceptableMediaTypes)); @@ -306,6 +298,13 @@ class DefaultWebClient implements WebClient { return this; } + @Override + public RequestBodySpec httpRequest(Consumer requestConsumer) { + this.httpRequestConsumer = (this.httpRequestConsumer != null ? + this.httpRequestConsumer.andThen(requestConsumer) : requestConsumer); + return this; + } + @Override public RequestHeadersSpec bodyValue(Object body) { this.inserter = BodyInserters.fromValue(body); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientBuilder.java index 80e79090b0e..c8c0727cadf 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientBuilder.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientBuilder.java @@ -262,24 +262,26 @@ final class DefaultWebClientBuilder implements WebClient.Builder { @Override public WebClient build() { + ClientHttpConnector connectorToUse = + (this.connector != null ? this.connector : initConnector()); + ExchangeFunction exchange = (this.exchangeFunction == null ? - ExchangeFunctions.create(getOrInitConnector(), initExchangeStrategies()) : + ExchangeFunctions.create(connectorToUse, initExchangeStrategies()) : this.exchangeFunction); + ExchangeFunction filteredExchange = (this.filters != null ? this.filters.stream() .reduce(ExchangeFilterFunction::andThen) .map(filter -> filter.apply(exchange)) .orElse(exchange) : exchange); + return new DefaultWebClient(filteredExchange, initUriBuilderFactory(), this.defaultHeaders != null ? HttpHeaders.readOnlyHttpHeaders(this.defaultHeaders) : null, this.defaultCookies != null ? CollectionUtils.unmodifiableMultiValueMap(this.defaultCookies) : null, this.defaultRequest, new DefaultWebClientBuilder(this)); } - private ClientHttpConnector getOrInitConnector() { - if (this.connector != null) { - return this.connector; - } - else if (reactorClientPresent) { + private ClientHttpConnector initConnector() { + if (reactorClientPresent) { return new ReactorClientHttpConnector(); } else if (jettyClientPresent) { diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java index e5f2df9e252..db58a16560a 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java @@ -43,11 +43,11 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.when; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; /** * Unit tests for {@link DefaultWebClient}. diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt index c5ec1002b5c..2811cf6cc5c 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt @@ -80,6 +80,7 @@ class WebClientExtensionsTests { } @Test + @Suppress("DEPRECATION") fun awaitExchange() { val response = mockk() every { requestBodySpec.exchange() } returns Mono.just(response)