diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java index d098c164d19..3743bb294aa 100644 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java @@ -25,7 +25,9 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.ByteBuffer; import java.time.Duration; -import java.util.List; +import java.util.Collections; +import java.util.Set; +import java.util.TreeSet; import java.util.concurrent.Executor; import java.util.concurrent.Flow; @@ -33,6 +35,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.lang.Nullable; import org.springframework.util.StreamUtils; +import org.springframework.util.StringUtils; /** * {@link ClientHttpRequest} implementation based the Java {@link HttpClient}. @@ -44,12 +47,27 @@ import org.springframework.util.StreamUtils; */ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest { - /* - * The JDK HttpRequest doesn't allow all headers to be set. The named headers are taken from the default - * implementation for HttpRequest. + private static final Set DISALLOWED_HEADERS = disallowedHeaders(); + + /** + * By default, {@link HttpRequest} does not allow {@code Connection}, + * {@code Content-Length}, {@code Expect}, {@code Host}, or {@code Upgrade} + * headers to be set, but this can be overriden with the + * {@code jdk.httpclient.allowRestrictedHeaders} system property. + * @see jdk.internal.net.http.common.Utils#getDisallowedHeaders() */ - private static final List DISALLOWED_HEADERS = - List.of("connection", "content-length", "expect", "host", "upgrade"); + private static Set disallowedHeaders() { + TreeSet headers = new TreeSet<>(String.CASE_INSENSITIVE_ORDER); + headers.addAll(Set.of("connection", "content-length", "expect", "host", "upgrade")); + + String headersToAllow = System.getProperty("jdk.httpclient.allowRestrictedHeaders"); + if (headersToAllow != null) { + Set toAllow = StringUtils.commaDelimitedListToSet(headersToAllow); + headers.removeAll(toAllow); + } + return Collections.unmodifiableSet(headers); + } + private final HttpClient httpClient; @@ -109,11 +127,9 @@ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest { } headers.forEach((headerName, headerValues) -> { - if (!headerName.equalsIgnoreCase(HttpHeaders.CONTENT_LENGTH)) { - if (!DISALLOWED_HEADERS.contains(headerName.toLowerCase())) { - for (String headerValue : headerValues) { - builder.header(headerName, headerValue); - } + if (!DISALLOWED_HEADERS.contains(headerName.toLowerCase())) { + for (String headerValue : headerValues) { + builder.header(headerName, headerValue); } } }); diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java index cbef27287d4..1577c81e610 100644 --- a/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java @@ -79,6 +79,11 @@ public abstract class AbstractMockWebServerTests { else if(request.getPath().equals("/status/notfound")) { return new MockResponse().setResponseCode(404); } + else if (request.getPath().equals("/status/299")) { + assertThat(request.getHeader("Expect")) + .contains("299"); + return new MockResponse().setResponseCode(299); + } else if(request.getPath().startsWith("/params")) { assertThat(request.getPath()).contains("param1=value"); assertThat(request.getPath()).contains("param2=value1¶m2=value2"); diff --git a/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java index f882e2bfa47..cce3a37989e 100644 --- a/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java @@ -16,15 +16,43 @@ package org.springframework.http.client; +import java.io.IOException; +import java.net.URI; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatusCode; +import org.springframework.lang.Nullable; + +import static org.assertj.core.api.Assertions.assertThat; /** * @author Marten Deinum */ public class JdkClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTests { + @Nullable + private static String originalPropertyValue; + + @BeforeAll + public static void setProperty() { + originalPropertyValue = System.getProperty("jdk.httpclient.allowRestrictedHeaders"); + System.setProperty("jdk.httpclient.allowRestrictedHeaders", "expect"); + } + + @AfterAll + public static void restoreProperty() { + if (originalPropertyValue != null) { + System.setProperty("jdk.httpclient.allowRestrictedHeaders", originalPropertyValue); + } + else { + System.clearProperty("jdk.httpclient.allowRestrictedHeaders"); + } + } + @Override protected ClientHttpRequestFactory createRequestFactory() { return new JdkClientHttpRequestFactory(); @@ -37,4 +65,14 @@ public class JdkClientHttpRequestFactoryTests extends AbstractHttpRequestFactory assertHttpMethod("patch", HttpMethod.PATCH); } + @Test + public void customizeDisallowedHeaders() throws IOException { + ClientHttpRequest request = factory.createRequest(URI.create(this.baseUrl + "/status/299"), HttpMethod.PUT); + request.getHeaders().set("Expect", "299"); + + try (ClientHttpResponse response = request.execute()) { + assertThat(response.getStatusCode()).as("Invalid status code").isEqualTo(HttpStatusCode.valueOf(299)); + } + } + }