From c820e44a995b7a6c4b76af67e51e86f90358c561 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Fri, 19 Jan 2024 15:37:19 +0100 Subject: [PATCH] Unwrap request factory when creating RestClient from RestTemplate This commit makes sure that, when building a RestClient based on the configuration of a RestTemplate, the request factory is unwrapped if it is a InterceptingClientHttpRequestFactory. Closes gh-32038 --- ...stractClientHttpRequestFactoryWrapper.java | 11 ++- .../web/client/DefaultRestClientBuilder.java | 13 ++- .../web/client/RestClientBuilderTests.java | 88 +++++++++++++++++++ 3 files changed, 109 insertions(+), 3 deletions(-) create mode 100644 spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java diff --git a/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequestFactoryWrapper.java b/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequestFactoryWrapper.java index 2f8a6446ab0..757f3f4410b 100644 --- a/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequestFactoryWrapper.java +++ b/spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequestFactoryWrapper.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ import org.springframework.util.Assert; /** * Abstract base class for {@link ClientHttpRequestFactory} implementations - * that decorate another request factory. + * that decorate another delegate request factory. * * @author Arjen Poutsma * @since 3.1 @@ -54,6 +54,13 @@ public abstract class AbstractClientHttpRequestFactoryWrapper implements ClientH return createRequest(uri, httpMethod, this.requestFactory); } + /** + * Return the delegate request factory. + */ + public ClientHttpRequestFactory getDelegate() { + return this.requestFactory; + } + /** * Create a new {@link ClientHttpRequest} for the specified URI and HTTP method * by using the passed-on request factory. diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java index 2dfd0d7bdc8..1f06d2b2fa2 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java @@ -32,6 +32,7 @@ import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInitializer; import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.client.JdkClientHttpRequestFactory; import org.springframework.http.client.JettyClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; @@ -177,7 +178,7 @@ final class DefaultRestClientBuilder implements RestClient.Builder { this.statusHandlers = new ArrayList<>(); this.statusHandlers.add(StatusHandler.fromErrorHandler(restTemplate.getErrorHandler())); - this.requestFactory = restTemplate.getRequestFactory(); + this.requestFactory = getRequestFactory(restTemplate); this.messageConverters = new ArrayList<>(restTemplate.getMessageConverters()); if (!CollectionUtils.isEmpty(restTemplate.getInterceptors())) { @@ -190,6 +191,16 @@ final class DefaultRestClientBuilder implements RestClient.Builder { this.observationConvention = restTemplate.getObservationConvention(); } + private static ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) { + ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); + if (requestFactory instanceof InterceptingClientHttpRequestFactory interceptingClientHttpRequestFactory) { + return interceptingClientHttpRequestFactory.getDelegate(); + } + else { + return requestFactory; + } + } + @Override public RestClient.Builder baseUrl(String baseUrl) { diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java new file mode 100644 index 00000000000..7839c992c84 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.lang.reflect.Field; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.http.client.ClientHttpRequestInitializer; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.JettyClientHttpRequestFactory; +import org.springframework.http.client.support.BasicAuthenticationInterceptor; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.StringHttpMessageConverter; +import org.springframework.web.util.DefaultUriBuilderFactory; + +import static org.assertj.core.api.Assertions.assertThat; + + +/** + * @author Arjen Poutsma + */ +public class RestClientBuilderTests { + + @SuppressWarnings("unchecked") + @Test + void createFromRestTemplate() throws NoSuchFieldException, IllegalAccessException { + JettyClientHttpRequestFactory requestFactory = new JettyClientHttpRequestFactory(); + DefaultUriBuilderFactory uriBuilderFactory = new DefaultUriBuilderFactory(); + ResponseErrorHandler errorHandler = new DefaultResponseErrorHandler(); + List> restTemplateMessageConverters = List.of(new StringHttpMessageConverter()); + ClientHttpRequestInterceptor interceptor = new BasicAuthenticationInterceptor("foo", "bar"); + ClientHttpRequestInitializer initializer = request -> {}; + + RestTemplate restTemplate = new RestTemplate(requestFactory); + restTemplate.setUriTemplateHandler(uriBuilderFactory); + restTemplate.setErrorHandler(errorHandler); + restTemplate.setMessageConverters(restTemplateMessageConverters); + restTemplate.setInterceptors(List.of(interceptor)); + restTemplate.setClientHttpRequestInitializers(List.of(initializer)); + + RestClient.Builder builder = RestClient.builder(restTemplate); + assertThat(builder).isInstanceOf(DefaultRestClientBuilder.class); + DefaultRestClientBuilder defaultBuilder = (DefaultRestClientBuilder) builder; + + assertThat(fieldValue("requestFactory", defaultBuilder)).isSameAs(requestFactory); + assertThat(fieldValue("uriBuilderFactory", defaultBuilder)).isSameAs(uriBuilderFactory); + + List statusHandlers = (List) fieldValue("statusHandlers", defaultBuilder); + assertThat(statusHandlers).hasSize(1); + + List> restClientMessageConverters = + (List>) fieldValue("messageConverters", defaultBuilder); + assertThat(restClientMessageConverters).containsExactlyElementsOf(restClientMessageConverters); + + List interceptors = + (List) fieldValue("interceptors", defaultBuilder); + assertThat(interceptors).containsExactly(interceptor); + + List initializers = + (List) fieldValue("initializers", defaultBuilder); + assertThat(initializers).containsExactly(initializer); + } + + private static Object fieldValue(String name, DefaultRestClientBuilder instance) + throws NoSuchFieldException, IllegalAccessException { + + Field field = DefaultRestClientBuilder.class.getDeclaredField(name); + field.setAccessible(true); + + return field.get(instance); + } +}