diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java index d098c164d19..7001df92a78 100644 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java @@ -25,7 +25,9 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.ByteBuffer; import java.time.Duration; -import java.util.List; +import java.util.Collections; +import java.util.Set; +import java.util.TreeSet; import java.util.concurrent.Executor; import java.util.concurrent.Flow; @@ -33,6 +35,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.lang.Nullable; import org.springframework.util.StreamUtils; +import org.springframework.util.StringUtils; /** * {@link ClientHttpRequest} implementation based the Java {@link HttpClient}. @@ -48,8 +51,19 @@ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest { * The JDK HttpRequest doesn't allow all headers to be set. The named headers are taken from the default * implementation for HttpRequest. */ - private static final List DISALLOWED_HEADERS = - List.of("connection", "content-length", "expect", "host", "upgrade"); + protected static final Set DISALLOWED_HEADERS = getDisallowedHeaders(); + + private static Set getDisallowedHeaders() { + TreeSet headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER); + headers.addAll(Set.of("connection", "content-length", "expect", "host", "upgrade")); + + String headersToAllow = System.getProperty("jdk.httpclient.allowRestrictedHeaders"); + if (headersToAllow != null) { + Set toAllow = StringUtils.commaDelimitedListToSet(headersToAllow); + headers.removeAll(toAllow); + } + return Collections.unmodifiableSet(headers); + } private final HttpClient httpClient; diff --git a/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java index f882e2bfa47..d03fca2f77b 100644 --- a/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java @@ -16,10 +16,17 @@ package org.springframework.http.client; +import java.net.URI; +import java.net.http.HttpClient; +import java.time.Duration; +import java.util.concurrent.Executor; + import org.junit.jupiter.api.Test; import org.springframework.http.HttpMethod; +import static org.assertj.core.api.Assertions.assertThat; + /** * @author Marten Deinum */ @@ -37,4 +44,26 @@ public class JdkClientHttpRequestFactoryTests extends AbstractHttpRequestFactory assertHttpMethod("patch", HttpMethod.PATCH); } + @Test + public void customizeDisallowedHeaders() { + String original = System.getProperty("jdk.httpclient.allowRestrictedHeaders"); + System.setProperty("jdk.httpclient.allowRestrictedHeaders", "host"); + + assertThat(TestJdkClientHttpRequest.DISALLOWED_HEADERS).doesNotContain("host"); + + if (original != null) { + System.setProperty("jdk.httpclient.allowRestrictedHeaders", original); + } + else { + System.clearProperty("jdk.httpclient.allowRestrictedHeaders"); + } + } + + static class TestJdkClientHttpRequest extends JdkClientHttpRequest { + + public TestJdkClientHttpRequest(HttpClient httpClient, URI uri, HttpMethod method, Executor executor, Duration readTimeout) { + super(httpClient, uri, method, executor, readTimeout); + } + } + }