Browse Source

Unwrap request factory when creating RestClient from RestTemplate

This commit makes sure that, when building a RestClient based on the
configuration of a RestTemplate, the request factory is unwrapped if it
is a InterceptingClientHttpRequestFactory.

Closes gh-32038
pull/32069/head
Arjen Poutsma 2 years ago
parent
commit
c820e44a99
  1. 11
      spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequestFactoryWrapper.java
  2. 13
      spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java
  3. 88
      spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java

11
spring-web/src/main/java/org/springframework/http/client/AbstractClientHttpRequestFactoryWrapper.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2015 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,7 +24,7 @@ import org.springframework.util.Assert; @@ -24,7 +24,7 @@ import org.springframework.util.Assert;
/**
* Abstract base class for {@link ClientHttpRequestFactory} implementations
* that decorate another request factory.
* that decorate another delegate request factory.
*
* @author Arjen Poutsma
* @since 3.1
@ -54,6 +54,13 @@ public abstract class AbstractClientHttpRequestFactoryWrapper implements ClientH @@ -54,6 +54,13 @@ public abstract class AbstractClientHttpRequestFactoryWrapper implements ClientH
return createRequest(uri, httpMethod, this.requestFactory);
}
/**
* Return the delegate request factory.
*/
public ClientHttpRequestFactory getDelegate() {
return this.requestFactory;
}
/**
* Create a new {@link ClientHttpRequest} for the specified URI and HTTP method
* by using the passed-on request factory.

13
spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java

@ -32,6 +32,7 @@ import org.springframework.http.client.ClientHttpRequestFactory; @@ -32,6 +32,7 @@ import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInitializer;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.InterceptingClientHttpRequestFactory;
import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.http.client.JettyClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
@ -177,7 +178,7 @@ final class DefaultRestClientBuilder implements RestClient.Builder { @@ -177,7 +178,7 @@ final class DefaultRestClientBuilder implements RestClient.Builder {
this.statusHandlers = new ArrayList<>();
this.statusHandlers.add(StatusHandler.fromErrorHandler(restTemplate.getErrorHandler()));
this.requestFactory = restTemplate.getRequestFactory();
this.requestFactory = getRequestFactory(restTemplate);
this.messageConverters = new ArrayList<>(restTemplate.getMessageConverters());
if (!CollectionUtils.isEmpty(restTemplate.getInterceptors())) {
@ -190,6 +191,16 @@ final class DefaultRestClientBuilder implements RestClient.Builder { @@ -190,6 +191,16 @@ final class DefaultRestClientBuilder implements RestClient.Builder {
this.observationConvention = restTemplate.getObservationConvention();
}
private static ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) {
ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory();
if (requestFactory instanceof InterceptingClientHttpRequestFactory interceptingClientHttpRequestFactory) {
return interceptingClientHttpRequestFactory.getDelegate();
}
else {
return requestFactory;
}
}
@Override
public RestClient.Builder baseUrl(String baseUrl) {

88
spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java

@ -0,0 +1,88 @@ @@ -0,0 +1,88 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.client;
import java.lang.reflect.Field;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.springframework.http.client.ClientHttpRequestInitializer;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.JettyClientHttpRequestFactory;
import org.springframework.http.client.support.BasicAuthenticationInterceptor;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.web.util.DefaultUriBuilderFactory;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Arjen Poutsma
*/
public class RestClientBuilderTests {
@SuppressWarnings("unchecked")
@Test
void createFromRestTemplate() throws NoSuchFieldException, IllegalAccessException {
JettyClientHttpRequestFactory requestFactory = new JettyClientHttpRequestFactory();
DefaultUriBuilderFactory uriBuilderFactory = new DefaultUriBuilderFactory();
ResponseErrorHandler errorHandler = new DefaultResponseErrorHandler();
List<HttpMessageConverter<?>> restTemplateMessageConverters = List.of(new StringHttpMessageConverter());
ClientHttpRequestInterceptor interceptor = new BasicAuthenticationInterceptor("foo", "bar");
ClientHttpRequestInitializer initializer = request -> {};
RestTemplate restTemplate = new RestTemplate(requestFactory);
restTemplate.setUriTemplateHandler(uriBuilderFactory);
restTemplate.setErrorHandler(errorHandler);
restTemplate.setMessageConverters(restTemplateMessageConverters);
restTemplate.setInterceptors(List.of(interceptor));
restTemplate.setClientHttpRequestInitializers(List.of(initializer));
RestClient.Builder builder = RestClient.builder(restTemplate);
assertThat(builder).isInstanceOf(DefaultRestClientBuilder.class);
DefaultRestClientBuilder defaultBuilder = (DefaultRestClientBuilder) builder;
assertThat(fieldValue("requestFactory", defaultBuilder)).isSameAs(requestFactory);
assertThat(fieldValue("uriBuilderFactory", defaultBuilder)).isSameAs(uriBuilderFactory);
List<StatusHandler> statusHandlers = (List<StatusHandler>) fieldValue("statusHandlers", defaultBuilder);
assertThat(statusHandlers).hasSize(1);
List<HttpMessageConverter<?>> restClientMessageConverters =
(List<HttpMessageConverter<?>>) fieldValue("messageConverters", defaultBuilder);
assertThat(restClientMessageConverters).containsExactlyElementsOf(restClientMessageConverters);
List<ClientHttpRequestInterceptor> interceptors =
(List<ClientHttpRequestInterceptor>) fieldValue("interceptors", defaultBuilder);
assertThat(interceptors).containsExactly(interceptor);
List<ClientHttpRequestInitializer> initializers =
(List<ClientHttpRequestInitializer>) fieldValue("initializers", defaultBuilder);
assertThat(initializers).containsExactly(initializer);
}
private static Object fieldValue(String name, DefaultRestClientBuilder instance)
throws NoSuchFieldException, IllegalAccessException {
Field field = DefaultRestClientBuilder.class.getDeclaredField(name);
field.setAccessible(true);
return field.get(instance);
}
}
Loading…
Cancel
Save