From e8e722fb597af7e84fdfbbc796e3c1b87983a578 Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Mon, 6 Jan 2025 19:29:09 +0100 Subject: [PATCH] Allow multiple executions of ClientHttpRequestInterceptors Prior to this commit, an `ClientHttpRequestInterceptor` implementation could delegate HTTP calls to the next `ClientHttpRequestExecution` only once. Calling the execution would advance to the next interceptor in the chain in a mutable fashion for the entire lifetime of the current exchange. This commit changes the implementation of `InterceptingClientHttpRequest` so that a `ClientHttpRequestInterceptor` implementation can call `ClientHttpRequestExecution#execute` multiple times. This is especially useful for interceptors in case they want to issue other HTTP requests without needing another `RestTemplate` or `RestClient` instance provided out of band. Closes gh-34169 --- .../client/InterceptingClientHttpRequest.java | 94 +++++++++++-------- ...rceptingClientHttpRequestFactoryTests.java | 71 ++++++++------ 2 files changed, 99 insertions(+), 66 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequest.java index 84ee77f3d70..9a7181ce834 100644 --- a/spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/InterceptingClientHttpRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,8 +19,8 @@ package org.springframework.http.client; import java.io.IOException; import java.io.OutputStream; import java.net.URI; -import java.util.Iterator; import java.util.List; +import java.util.ListIterator; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -33,6 +33,7 @@ import org.springframework.util.StreamUtils; * ClientHttpRequestInterceptors}. * * @author Arjen Poutsma + * @author Brian Clozel * @since 3.1 */ class InterceptingClientHttpRequest extends AbstractBufferingClientHttpRequest { @@ -68,54 +69,71 @@ class InterceptingClientHttpRequest extends AbstractBufferingClientHttpRequest { @Override protected final ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException { - InterceptingRequestExecution requestExecution = new InterceptingRequestExecution(); + ClientHttpRequestExecution requestExecution = new DelegatingRequestExecution(this.requestFactory); + ListIterator iterator = this.interceptors.listIterator(this.interceptors.size()); + while (iterator.hasPrevious()) { + ClientHttpRequestInterceptor interceptor = iterator.previous(); + requestExecution = new InterceptingRequestExecution(interceptor, requestExecution); + } return requestExecution.execute(this, bufferedOutput); } - private class InterceptingRequestExecution implements ClientHttpRequestExecution { + private static class InterceptingRequestExecution implements ClientHttpRequestExecution { + + private final ClientHttpRequestInterceptor interceptor; - private final Iterator iterator; + private final ClientHttpRequestExecution nextExecution; - public InterceptingRequestExecution() { - this.iterator = interceptors.iterator(); + public InterceptingRequestExecution(ClientHttpRequestInterceptor interceptor, ClientHttpRequestExecution nextExecution) { + this.interceptor = interceptor; + this.nextExecution = nextExecution; } @Override public ClientHttpResponse execute(HttpRequest request, byte[] body) throws IOException { - if (this.iterator.hasNext()) { - ClientHttpRequestInterceptor nextInterceptor = this.iterator.next(); - return nextInterceptor.intercept(request, body, this); - } - else { - HttpMethod method = request.getMethod(); - ClientHttpRequest delegate = requestFactory.createRequest(request.getURI(), method); - request.getHeaders().forEach((key, value) -> delegate.getHeaders().addAll(key, value)); - request.getAttributes().forEach((key, value) -> delegate.getAttributes().put(key, value)); - if (body.length > 0) { - long contentLength = delegate.getHeaders().getContentLength(); - if (contentLength > -1 && contentLength != body.length) { - delegate.getHeaders().setContentLength(body.length); - } - if (delegate instanceof StreamingHttpOutputMessage streamingOutputMessage) { - streamingOutputMessage.setBody(new StreamingHttpOutputMessage.Body() { - @Override - public void writeTo(OutputStream outputStream) throws IOException { - StreamUtils.copy(body, outputStream); - } - - @Override - public boolean repeatable() { - return true; - } - }); - } - else { - StreamUtils.copy(body, delegate.getBody()); - } + return this.interceptor.intercept(request, body, this.nextExecution); + } + + } + + private static class DelegatingRequestExecution implements ClientHttpRequestExecution { + + private final ClientHttpRequestFactory requestFactory; + + public DelegatingRequestExecution(ClientHttpRequestFactory requestFactory) { + this.requestFactory = requestFactory; + } + + @Override + public ClientHttpResponse execute(HttpRequest request, byte[] body) throws IOException { + HttpMethod method = request.getMethod(); + ClientHttpRequest delegate = this.requestFactory.createRequest(request.getURI(), method); + request.getHeaders().forEach((key, value) -> delegate.getHeaders().addAll(key, value)); + request.getAttributes().forEach((key, value) -> delegate.getAttributes().put(key, value)); + if (body.length > 0) { + long contentLength = delegate.getHeaders().getContentLength(); + if (contentLength > -1 && contentLength != body.length) { + delegate.getHeaders().setContentLength(body.length); + } + if (delegate instanceof StreamingHttpOutputMessage streamingOutputMessage) { + streamingOutputMessage.setBody(new StreamingHttpOutputMessage.Body() { + @Override + public void writeTo(OutputStream outputStream) throws IOException { + StreamUtils.copy(body, outputStream); + } + + @Override + public boolean repeatable() { + return true; + } + }); + } + else { + StreamUtils.copy(body, delegate.getBody()); } - return delegate.execute(); } + return delegate.execute(); } } diff --git a/spring-web/src/test/java/org/springframework/http/client/InterceptingClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/InterceptingClientHttpRequestFactoryTests.java index f8248d0e085..fd605e0705c 100644 --- a/spring-web/src/test/java/org/springframework/http/client/InterceptingClientHttpRequestFactoryTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/InterceptingClientHttpRequestFactoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,8 +35,10 @@ import org.springframework.web.testfixture.http.client.MockClientHttpResponse; import static org.assertj.core.api.Assertions.assertThat; /** + * Tests for {@link InterceptingClientHttpRequestFactory} * @author Arjen Poutsma * @author Juergen Hoeller + * @author Brian Clozel */ class InterceptingClientHttpRequestFactoryTests { @@ -54,7 +56,7 @@ class InterceptingClientHttpRequestFactoryTests { } @Test - void basic() throws Exception { + void shouldInvokeInterceptors() throws Exception { List interceptors = new ArrayList<>(); interceptors.add(new NoOpInterceptor()); interceptors.add(new NoOpInterceptor()); @@ -64,31 +66,30 @@ class InterceptingClientHttpRequestFactoryTests { ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET); ClientHttpResponse response = request.execute(); - assertThat(((NoOpInterceptor) interceptors.get(0)).invoked).isTrue(); - assertThat(((NoOpInterceptor) interceptors.get(1)).invoked).isTrue(); - assertThat(((NoOpInterceptor) interceptors.get(2)).invoked).isTrue(); + assertThat(((NoOpInterceptor) interceptors.get(0)).invocationCount).isEqualTo(1); + assertThat(((NoOpInterceptor) interceptors.get(1)).invocationCount).isEqualTo(1); + assertThat(((NoOpInterceptor) interceptors.get(2)).invocationCount).isEqualTo(1); assertThat(requestMock.isExecuted()).isTrue(); assertThat(response).isSameAs(responseMock); } @Test - void noExecution() throws Exception { + void shouldSkipIntercetor() throws Exception { List interceptors = new ArrayList<>(); interceptors.add((request, body, execution) -> responseMock); - interceptors.add(new NoOpInterceptor()); requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, interceptors); ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET); ClientHttpResponse response = request.execute(); - assertThat(((NoOpInterceptor) interceptors.get(1)).invoked).isFalse(); + assertThat(((NoOpInterceptor) interceptors.get(1)).invocationCount).isZero(); assertThat(requestMock.isExecuted()).isFalse(); assertThat(response).isSameAs(responseMock); } @Test - void changeHeaders() throws Exception { + void interceptorShouldUpdateRequestHeader() throws Exception { final String headerName = "Foo"; final String headerValue = "Bar"; final String otherValue = "Baz"; @@ -98,7 +99,6 @@ class InterceptingClientHttpRequestFactoryTests { wrapper.getHeaders().add(headerName, otherValue); return execution.execute(wrapper, body); }; - requestMock = new MockClientHttpRequest() { @Override protected ClientHttpResponse executeInternal() { @@ -110,31 +110,26 @@ class InterceptingClientHttpRequestFactoryTests { requestMock.getHeaders().add(headerName, headerValue); requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor)); - ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET); request.execute(); } @Test - void changeAttribute() throws Exception { + void interceptorShouldUpdateRequestAttribute() throws Exception { final String attrName = "Foo"; final String attrValue = "Bar"; ClientHttpRequestInterceptor interceptor = (request, body, execution) -> { - System.out.println("interceptor"); request.getAttributes().put(attrName, attrValue); return execution.execute(request, body); }; - requestMock = new MockClientHttpRequest() { @Override protected ClientHttpResponse executeInternal() { - System.out.println("execute"); assertThat(getAttributes()).containsEntry(attrName, attrValue); return responseMock; } }; - requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor)); ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET); @@ -142,7 +137,7 @@ class InterceptingClientHttpRequestFactoryTests { } @Test - void changeURI() throws Exception { + void interceptorShouldUpdateRequestURI() throws Exception { final URI changedUri = URI.create("https://example.com/2"); ClientHttpRequestInterceptor interceptor = (request, body, execution) -> execution.execute(new HttpRequestWrapper(request) { @@ -152,7 +147,6 @@ class InterceptingClientHttpRequestFactoryTests { } }, body); - requestFactoryMock = new RequestFactoryMock() { @Override public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { @@ -160,7 +154,6 @@ class InterceptingClientHttpRequestFactoryTests { return super.createRequest(uri, httpMethod); } }; - requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor)); ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET); @@ -168,7 +161,7 @@ class InterceptingClientHttpRequestFactoryTests { } @Test - void changeMethod() throws Exception { + void interceptorShouldUpdateRequestMethod() throws Exception { final HttpMethod changedMethod = HttpMethod.POST; ClientHttpRequestInterceptor interceptor = (request, body, execution) -> execution.execute(new HttpRequestWrapper(request) { @@ -178,7 +171,6 @@ class InterceptingClientHttpRequestFactoryTests { } }, body); - requestFactoryMock = new RequestFactoryMock() { @Override public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { @@ -186,7 +178,6 @@ class InterceptingClientHttpRequestFactoryTests { return super.createRequest(uri, httpMethod); } }; - requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor)); ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET); @@ -194,28 +185,52 @@ class InterceptingClientHttpRequestFactoryTests { } @Test - void changeBody() throws Exception { + void interceptorShouldUpdateRequestBody() throws Exception { final byte[] changedBody = "Foo".getBytes(); - ClientHttpRequestInterceptor interceptor = (request, body, execution) -> execution.execute(request, changedBody); - requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor)); - ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET); request.execute(); + assertThat(Arrays.equals(changedBody, requestMock.getBodyAsBytes())).isTrue(); assertThat(requestMock.getHeaders().getContentLength()).isEqualTo(changedBody.length); } + @Test + void interceptorShouldAlwaysExecuteNextInterceptor() throws Exception { + List interceptors = new ArrayList<>(); + interceptors.add(new MultipleExecutionInterceptor()); + interceptors.add(new NoOpInterceptor()); + requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, interceptors); + + ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET); + ClientHttpResponse response = request.execute(); + + assertThat(((NoOpInterceptor) interceptors.get(1)).invocationCount).isEqualTo(2); + assertThat(requestMock.isExecuted()).isTrue(); + assertThat(response).isSameAs(responseMock); + } + private static class NoOpInterceptor implements ClientHttpRequestInterceptor { - private boolean invoked = false; + private int invocationCount = 0; + + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) + throws IOException { + invocationCount++; + return execution.execute(request, body); + } + } + + private static class MultipleExecutionInterceptor implements ClientHttpRequestInterceptor { @Override public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException { - invoked = true; + // execute another request first + execution.execute(new MockClientHttpRequest(), body); return execution.execute(request, body); } }