diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index 196f173246e..62400f6606f 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -181,7 +181,11 @@ final class DefaultRestClient implements RestClient { } private RequestBodyUriSpec methodInternal(HttpMethod httpMethod) { - return new DefaultRequestBodyUriSpec(httpMethod); + DefaultRequestBodyUriSpec spec = new DefaultRequestBodyUriSpec(httpMethod); + if (this.defaultRequest != null) { + this.defaultRequest.accept(spec); + } + return spec; } @Override @@ -456,9 +460,6 @@ final class DefaultRestClient implements RestClient { Observation observation = null; URI uri = null; try { - if (DefaultRestClient.this.defaultRequest != null) { - DefaultRestClient.this.defaultRequest.accept(this); - } uri = initUri(); HttpHeaders headers = initHeaders(); ClientHttpRequest clientRequest = createRequest(uri); diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java index 673ae08cc79..0b0542785f7 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java @@ -900,6 +900,29 @@ class RestClientIntegrationTests { expectRequest(request -> assertThat(request.getHeader("foo")).isEqualTo("bar")); } + @ParameterizedRestClientTest + void defaultRequestOverride(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setHeader("Content-Type", "text/plain") + .setBody("Hello Spring!")); + + RestClient headersClient = this.restClient.mutate() + .defaultRequest(request -> request.accept(MediaType.APPLICATION_JSON)) + .build(); + + String result = headersClient.get() + .uri("/greeting") + .accept(MediaType.TEXT_PLAIN) + .retrieve() + .body(String.class); + + assertThat(result).isEqualTo("Hello Spring!"); + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getHeader("Accept")).isEqualTo(MediaType.TEXT_PLAIN_VALUE)); + } + private void prepareResponse(Consumer consumer) { MockResponse response = new MockResponse(); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java index 9518978298e..55d1c710682 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java @@ -177,7 +177,11 @@ final class DefaultWebClient implements WebClient { } private RequestBodyUriSpec methodInternal(HttpMethod httpMethod) { - return new DefaultRequestBodyUriSpec(httpMethod); + DefaultRequestBodyUriSpec spec = new DefaultRequestBodyUriSpec(httpMethod); + if (this.defaultRequest != null) { + this.defaultRequest.accept(spec); + } + return spec; } @Override @@ -479,9 +483,6 @@ final class DefaultWebClient implements WebClient { } private ClientRequest.Builder initRequestBuilder() { - if (defaultRequest != null) { - defaultRequest.accept(this); - } ClientRequest.Builder builder = ClientRequest.create(this.httpMethod, initUri()) .headers(this::initHeaders) .cookies(this::initCookies) diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java index 2434d9a3f8e..16c433429f4 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java @@ -528,6 +528,23 @@ public class DefaultWebClientTests { StepVerifier.create(responsePublisher).expectError(WebClientResponseException.class).verify(); } + @Test // gh-32053 + void defaultRequestOverride() { + WebClient client = this.builder + .defaultRequest(spec -> spec.accept(MediaType.APPLICATION_JSON)) + .build(); + + client.get().uri("/path") + .accept(MediaType.IMAGE_PNG) + .retrieve() + .bodyToMono(Void.class) + .block(Duration.ofSeconds(3)); + + ClientRequest request = verifyAndGetRequest(); + assertThat(request.headers().getAccept()).containsExactly(MediaType.IMAGE_PNG); + } + + private ClientRequest verifyAndGetRequest() { ClientRequest request = this.captor.getValue();