Browse Source

Allow multiple executions of ClientHttpRequestInterceptors

Prior to this commit, an `ClientHttpRequestInterceptor` implementation
could delegate HTTP calls to the next `ClientHttpRequestExecution` only
once. Calling the execution would advance to the next interceptor in the
chain in a mutable fashion for the entire lifetime of the current
exchange.

This commit changes the implementation of `InterceptingClientHttpRequest`
so that a `ClientHttpRequestInterceptor` implementation can call
`ClientHttpRequestExecution#execute` multiple times.

This is especially useful for interceptors in case they want to issue
other HTTP requests without needing another `RestTemplate` or
`RestClient` instance provided out of band.

Closes gh-34169
pull/34207/head
Brian Clozel 12 months ago
parent
commit
e8e722fb59
  1. 94
      spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequest.java
  2. 71
      spring-web/src/test/java/org/springframework/http/client/InterceptingClientHttpRequestFactoryTests.java

94
spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequest.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2024 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,8 +19,8 @@ package org.springframework.http.client; @@ -19,8 +19,8 @@ package org.springframework.http.client;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
@ -33,6 +33,7 @@ import org.springframework.util.StreamUtils; @@ -33,6 +33,7 @@ import org.springframework.util.StreamUtils;
* ClientHttpRequestInterceptors}.
*
* @author Arjen Poutsma
* @author Brian Clozel
* @since 3.1
*/
class InterceptingClientHttpRequest extends AbstractBufferingClientHttpRequest {
@ -68,54 +69,71 @@ class InterceptingClientHttpRequest extends AbstractBufferingClientHttpRequest { @@ -68,54 +69,71 @@ class InterceptingClientHttpRequest extends AbstractBufferingClientHttpRequest {
@Override
protected final ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException {
InterceptingRequestExecution requestExecution = new InterceptingRequestExecution();
ClientHttpRequestExecution requestExecution = new DelegatingRequestExecution(this.requestFactory);
ListIterator<ClientHttpRequestInterceptor> iterator = this.interceptors.listIterator(this.interceptors.size());
while (iterator.hasPrevious()) {
ClientHttpRequestInterceptor interceptor = iterator.previous();
requestExecution = new InterceptingRequestExecution(interceptor, requestExecution);
}
return requestExecution.execute(this, bufferedOutput);
}
private class InterceptingRequestExecution implements ClientHttpRequestExecution {
private static class InterceptingRequestExecution implements ClientHttpRequestExecution {
private final ClientHttpRequestInterceptor interceptor;
private final Iterator<ClientHttpRequestInterceptor> iterator;
private final ClientHttpRequestExecution nextExecution;
public InterceptingRequestExecution() {
this.iterator = interceptors.iterator();
public InterceptingRequestExecution(ClientHttpRequestInterceptor interceptor, ClientHttpRequestExecution nextExecution) {
this.interceptor = interceptor;
this.nextExecution = nextExecution;
}
@Override
public ClientHttpResponse execute(HttpRequest request, byte[] body) throws IOException {
if (this.iterator.hasNext()) {
ClientHttpRequestInterceptor nextInterceptor = this.iterator.next();
return nextInterceptor.intercept(request, body, this);
}
else {
HttpMethod method = request.getMethod();
ClientHttpRequest delegate = requestFactory.createRequest(request.getURI(), method);
request.getHeaders().forEach((key, value) -> delegate.getHeaders().addAll(key, value));
request.getAttributes().forEach((key, value) -> delegate.getAttributes().put(key, value));
if (body.length > 0) {
long contentLength = delegate.getHeaders().getContentLength();
if (contentLength > -1 && contentLength != body.length) {
delegate.getHeaders().setContentLength(body.length);
}
if (delegate instanceof StreamingHttpOutputMessage streamingOutputMessage) {
streamingOutputMessage.setBody(new StreamingHttpOutputMessage.Body() {
@Override
public void writeTo(OutputStream outputStream) throws IOException {
StreamUtils.copy(body, outputStream);
}
@Override
public boolean repeatable() {
return true;
}
});
}
else {
StreamUtils.copy(body, delegate.getBody());
}
return this.interceptor.intercept(request, body, this.nextExecution);
}
}
private static class DelegatingRequestExecution implements ClientHttpRequestExecution {
private final ClientHttpRequestFactory requestFactory;
public DelegatingRequestExecution(ClientHttpRequestFactory requestFactory) {
this.requestFactory = requestFactory;
}
@Override
public ClientHttpResponse execute(HttpRequest request, byte[] body) throws IOException {
HttpMethod method = request.getMethod();
ClientHttpRequest delegate = this.requestFactory.createRequest(request.getURI(), method);
request.getHeaders().forEach((key, value) -> delegate.getHeaders().addAll(key, value));
request.getAttributes().forEach((key, value) -> delegate.getAttributes().put(key, value));
if (body.length > 0) {
long contentLength = delegate.getHeaders().getContentLength();
if (contentLength > -1 && contentLength != body.length) {
delegate.getHeaders().setContentLength(body.length);
}
if (delegate instanceof StreamingHttpOutputMessage streamingOutputMessage) {
streamingOutputMessage.setBody(new StreamingHttpOutputMessage.Body() {
@Override
public void writeTo(OutputStream outputStream) throws IOException {
StreamUtils.copy(body, outputStream);
}
@Override
public boolean repeatable() {
return true;
}
});
}
else {
StreamUtils.copy(body, delegate.getBody());
}
return delegate.execute();
}
return delegate.execute();
}
}

71
spring-web/src/test/java/org/springframework/http/client/InterceptingClientHttpRequestFactoryTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2024 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,8 +35,10 @@ import org.springframework.web.testfixture.http.client.MockClientHttpResponse; @@ -35,8 +35,10 @@ import org.springframework.web.testfixture.http.client.MockClientHttpResponse;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link InterceptingClientHttpRequestFactory}
* @author Arjen Poutsma
* @author Juergen Hoeller
* @author Brian Clozel
*/
class InterceptingClientHttpRequestFactoryTests {
@ -54,7 +56,7 @@ class InterceptingClientHttpRequestFactoryTests { @@ -54,7 +56,7 @@ class InterceptingClientHttpRequestFactoryTests {
}
@Test
void basic() throws Exception {
void shouldInvokeInterceptors() throws Exception {
List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>();
interceptors.add(new NoOpInterceptor());
interceptors.add(new NoOpInterceptor());
@ -64,31 +66,30 @@ class InterceptingClientHttpRequestFactoryTests { @@ -64,31 +66,30 @@ class InterceptingClientHttpRequestFactoryTests {
ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET);
ClientHttpResponse response = request.execute();
assertThat(((NoOpInterceptor) interceptors.get(0)).invoked).isTrue();
assertThat(((NoOpInterceptor) interceptors.get(1)).invoked).isTrue();
assertThat(((NoOpInterceptor) interceptors.get(2)).invoked).isTrue();
assertThat(((NoOpInterceptor) interceptors.get(0)).invocationCount).isEqualTo(1);
assertThat(((NoOpInterceptor) interceptors.get(1)).invocationCount).isEqualTo(1);
assertThat(((NoOpInterceptor) interceptors.get(2)).invocationCount).isEqualTo(1);
assertThat(requestMock.isExecuted()).isTrue();
assertThat(response).isSameAs(responseMock);
}
@Test
void noExecution() throws Exception {
void shouldSkipIntercetor() throws Exception {
List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>();
interceptors.add((request, body, execution) -> responseMock);
interceptors.add(new NoOpInterceptor());
requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, interceptors);
ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET);
ClientHttpResponse response = request.execute();
assertThat(((NoOpInterceptor) interceptors.get(1)).invoked).isFalse();
assertThat(((NoOpInterceptor) interceptors.get(1)).invocationCount).isZero();
assertThat(requestMock.isExecuted()).isFalse();
assertThat(response).isSameAs(responseMock);
}
@Test
void changeHeaders() throws Exception {
void interceptorShouldUpdateRequestHeader() throws Exception {
final String headerName = "Foo";
final String headerValue = "Bar";
final String otherValue = "Baz";
@ -98,7 +99,6 @@ class InterceptingClientHttpRequestFactoryTests { @@ -98,7 +99,6 @@ class InterceptingClientHttpRequestFactoryTests {
wrapper.getHeaders().add(headerName, otherValue);
return execution.execute(wrapper, body);
};
requestMock = new MockClientHttpRequest() {
@Override
protected ClientHttpResponse executeInternal() {
@ -110,31 +110,26 @@ class InterceptingClientHttpRequestFactoryTests { @@ -110,31 +110,26 @@ class InterceptingClientHttpRequestFactoryTests {
requestMock.getHeaders().add(headerName, headerValue);
requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor));
ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET);
request.execute();
}
@Test
void changeAttribute() throws Exception {
void interceptorShouldUpdateRequestAttribute() throws Exception {
final String attrName = "Foo";
final String attrValue = "Bar";
ClientHttpRequestInterceptor interceptor = (request, body, execution) -> {
System.out.println("interceptor");
request.getAttributes().put(attrName, attrValue);
return execution.execute(request, body);
};
requestMock = new MockClientHttpRequest() {
@Override
protected ClientHttpResponse executeInternal() {
System.out.println("execute");
assertThat(getAttributes()).containsEntry(attrName, attrValue);
return responseMock;
}
};
requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor));
ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET);
@ -142,7 +137,7 @@ class InterceptingClientHttpRequestFactoryTests { @@ -142,7 +137,7 @@ class InterceptingClientHttpRequestFactoryTests {
}
@Test
void changeURI() throws Exception {
void interceptorShouldUpdateRequestURI() throws Exception {
final URI changedUri = URI.create("https://example.com/2");
ClientHttpRequestInterceptor interceptor = (request, body, execution) -> execution.execute(new HttpRequestWrapper(request) {
@ -152,7 +147,6 @@ class InterceptingClientHttpRequestFactoryTests { @@ -152,7 +147,6 @@ class InterceptingClientHttpRequestFactoryTests {
}
}, body);
requestFactoryMock = new RequestFactoryMock() {
@Override
public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
@ -160,7 +154,6 @@ class InterceptingClientHttpRequestFactoryTests { @@ -160,7 +154,6 @@ class InterceptingClientHttpRequestFactoryTests {
return super.createRequest(uri, httpMethod);
}
};
requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor));
ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET);
@ -168,7 +161,7 @@ class InterceptingClientHttpRequestFactoryTests { @@ -168,7 +161,7 @@ class InterceptingClientHttpRequestFactoryTests {
}
@Test
void changeMethod() throws Exception {
void interceptorShouldUpdateRequestMethod() throws Exception {
final HttpMethod changedMethod = HttpMethod.POST;
ClientHttpRequestInterceptor interceptor = (request, body, execution) -> execution.execute(new HttpRequestWrapper(request) {
@ -178,7 +171,6 @@ class InterceptingClientHttpRequestFactoryTests { @@ -178,7 +171,6 @@ class InterceptingClientHttpRequestFactoryTests {
}
}, body);
requestFactoryMock = new RequestFactoryMock() {
@Override
public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
@ -186,7 +178,6 @@ class InterceptingClientHttpRequestFactoryTests { @@ -186,7 +178,6 @@ class InterceptingClientHttpRequestFactoryTests {
return super.createRequest(uri, httpMethod);
}
};
requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor));
ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET);
@ -194,28 +185,52 @@ class InterceptingClientHttpRequestFactoryTests { @@ -194,28 +185,52 @@ class InterceptingClientHttpRequestFactoryTests {
}
@Test
void changeBody() throws Exception {
void interceptorShouldUpdateRequestBody() throws Exception {
final byte[] changedBody = "Foo".getBytes();
ClientHttpRequestInterceptor interceptor = (request, body, execution) -> execution.execute(request, changedBody);
requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor));
ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET);
request.execute();
assertThat(Arrays.equals(changedBody, requestMock.getBodyAsBytes())).isTrue();
assertThat(requestMock.getHeaders().getContentLength()).isEqualTo(changedBody.length);
}
@Test
void interceptorShouldAlwaysExecuteNextInterceptor() throws Exception {
List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>();
interceptors.add(new MultipleExecutionInterceptor());
interceptors.add(new NoOpInterceptor());
requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, interceptors);
ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET);
ClientHttpResponse response = request.execute();
assertThat(((NoOpInterceptor) interceptors.get(1)).invocationCount).isEqualTo(2);
assertThat(requestMock.isExecuted()).isTrue();
assertThat(response).isSameAs(responseMock);
}
private static class NoOpInterceptor implements ClientHttpRequestInterceptor {
private boolean invoked = false;
private int invocationCount = 0;
@Override
public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution)
throws IOException {
invocationCount++;
return execution.execute(request, body);
}
}
private static class MultipleExecutionInterceptor implements ClientHttpRequestInterceptor {
@Override
public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution)
throws IOException {
invoked = true;
// execute another request first
execution.execute(new MockClientHttpRequest(), body);
return execution.execute(request, body);
}
}

Loading…
Cancel
Save