diff --git a/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpRequest.java index a14d8ef8cbe..d2559a800df 100644 --- a/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/HttpComponentsClientHttpRequest.java @@ -30,6 +30,8 @@ import org.apache.hc.core5.http.ClassicHttpRequest; import org.apache.hc.core5.http.ClassicHttpResponse; import org.apache.hc.core5.http.Header; import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.Method; +import org.apache.hc.core5.http.io.entity.NullEntity; import org.apache.hc.core5.http.protocol.HttpContext; import org.springframework.http.HttpHeaders; @@ -89,8 +91,10 @@ final class HttpComponentsClientHttpRequest extends AbstractStreamingClientHttpR addHeaders(this.httpRequest, headers); if (body != null) { - HttpEntity requestEntity = new BodyEntity(headers, body); - this.httpRequest.setEntity(requestEntity); + this.httpRequest.setEntity(new BodyEntity(headers, body)); + } + else if (!Method.isSafe(this.httpRequest.getMethod())) { + this.httpRequest.setEntity(NullEntity.INSTANCE); } ClassicHttpResponse httpResponse = this.httpClient.executeOpen(null, this.httpRequest, this.httpContext); return new HttpComponentsClientHttpResponse(httpResponse); diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java index c2237dd61d8..c8c88be4ec7 100644 --- a/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java @@ -106,6 +106,10 @@ public abstract class AbstractMockWebServerTests { assertThat(request.getMethod()).isEqualTo(expectedMethod); return new MockResponse(); } + else if(request.getPath().startsWith("/header/")) { + String headerName = request.getPath().replace("/header/",""); + return new MockResponse().setBody(headerName + ":" + request.getHeader(headerName)).setResponseCode(200); + } return new MockResponse().setResponseCode(404); } catch (Throwable exc) { diff --git a/spring-web/src/test/java/org/springframework/http/client/HttpComponentsClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/HttpComponentsClientHttpRequestFactoryTests.java index d1d355ccfc2..8ad64e76b51 100644 --- a/spring-web/src/test/java/org/springframework/http/client/HttpComponentsClientHttpRequestFactoryTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/HttpComponentsClientHttpRequestFactoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,9 @@ package org.springframework.http.client; +import java.io.InputStreamReader; import java.net.URI; +import java.util.stream.Stream; import org.apache.hc.client5.http.classic.HttpClient; import org.apache.hc.client5.http.config.Configurable; @@ -26,8 +28,12 @@ import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; import org.apache.hc.client5.http.protocol.HttpClientContext; import org.apache.hc.core5.util.Timeout; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.util.FileCopyUtils; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.assertj.core.api.Assertions.assertThat; @@ -37,6 +43,7 @@ import static org.mockito.Mockito.withSettings; /** * @author Stephane Nicoll + * @author Brian Clozel */ class HttpComponentsClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTests { @@ -145,6 +152,36 @@ class HttpComponentsClientHttpRequestFactoryTests extends AbstractHttpRequestFac assertThat(requestConfig2.getConnectionRequestTimeout()).isEqualTo(Timeout.of(7000, MILLISECONDS)); } + @ParameterizedTest + @MethodSource("unsafeHttpMethods") + void shouldSetContentLengthWhenEmptyBody(HttpMethod method) throws Exception { + ClientHttpRequest request = factory.createRequest(URI.create(baseUrl + "/header/Content-Length"), method); + try (ClientHttpResponse response = request.execute()) { + assertThat(response.getStatusCode()).as("Invalid status code").isEqualTo(HttpStatus.OK); + String result = FileCopyUtils.copyToString(new InputStreamReader(response.getBody())); + assertThat(result).as("Invalid body").isEqualTo("Content-Length:0"); + } + } + + static Stream unsafeHttpMethods() { + return Stream.of(HttpMethod.POST, HttpMethod.PUT, HttpMethod.DELETE, HttpMethod.PATCH); + } + + @ParameterizedTest + @MethodSource("safeHttpMethods") + void shouldNotSetContentLengthWhenEmptyBodyAndSafeMethod(HttpMethod method) throws Exception { + ClientHttpRequest request = factory.createRequest(URI.create(baseUrl + "/header/Content-Length"), method); + try (ClientHttpResponse response = request.execute()) { + assertThat(response.getStatusCode()).as("Invalid status code").isEqualTo(HttpStatus.OK); + String result = FileCopyUtils.copyToString(new InputStreamReader(response.getBody())); + assertThat(result).as("Invalid body").isEqualTo("Content-Length:null"); + } + } + + static Stream safeHttpMethods() { + return Stream.of(HttpMethod.GET, HttpMethod.OPTIONS, HttpMethod.TRACE); + } + private RequestConfig retrieveRequestConfig(HttpComponentsClientHttpRequestFactory factory) throws Exception { URI uri = URI.create(baseUrl + "/status/ok"); HttpComponentsClientHttpRequest request = (HttpComponentsClientHttpRequest)