diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java index 0052e4a62b9..2f8af19d93f 100644 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java @@ -19,6 +19,7 @@ package org.springframework.http.client; import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; +import java.io.PushbackInputStream; import java.io.UncheckedIOException; import java.net.URI; import java.net.http.HttpClient; @@ -60,6 +61,7 @@ import org.springframework.util.StringUtils; * * @author Marten Deinum * @author Arjen Poutsma + * @author Brian Clozel * @since 6.1 */ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest { @@ -325,30 +327,61 @@ class JdkClientHttpRequest extends AbstractStreamingClientHttpRequest { */ private static final class DecompressingBodyHandler implements BodyHandler { + @Override public BodySubscriber apply(ResponseInfo responseInfo) { - String contentEncoding = responseInfo.headers().firstValue(HttpHeaders.CONTENT_ENCODING).orElse(""); - if (contentEncoding.equalsIgnoreCase("gzip")) { - return BodySubscribers.mapping( + + String contentEncoding = responseInfo.headers() + .firstValue(HttpHeaders.CONTENT_ENCODING) + .orElse("") + .toLowerCase(Locale.ROOT); + + return switch (contentEncoding) { + case "gzip", "deflate" -> BodySubscribers.mapping( BodySubscribers.ofInputStream(), - (InputStream is) -> { - try { - return new GZIPInputStream(is); - } - catch (IOException ex) { - throw new UncheckedIOException(ex); - } - }); + (InputStream is) -> decompressStream(is, contentEncoding)); + default -> BodySubscribers.ofInputStream(); + }; + } + + private static InputStream decompressStream(InputStream original, String contentEncoding) { + PushbackInputStream wrapped = new PushbackInputStream(original); + try { + if (hasResponseBody(wrapped)) { + if (contentEncoding.equals("gzip")) { + return new GZIPInputStream(wrapped); + } + else if (contentEncoding.equals("deflate")) { + return new InflaterInputStream(wrapped); + } + } + else { + return wrapped; + } } - else if (contentEncoding.equalsIgnoreCase("deflate")) { - return BodySubscribers.mapping( - BodySubscribers.ofInputStream(), - InflaterInputStream::new); + catch (IOException ex) { + throw new UncheckedIOException(ex); } - else { - return BodySubscribers.ofInputStream(); + return wrapped; + } + + private static boolean hasResponseBody(PushbackInputStream inputStream) { + try { + int b = inputStream.read(); + if (b == -1) { + return false; + } + else { + inputStream.unread(b); + return true; + } + + } + catch (IOException exc) { + return false; } } } + } diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java index 1376fdc2026..552413c3ddb 100644 --- a/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/AbstractMockWebServerTests.java @@ -112,7 +112,7 @@ public abstract class AbstractMockWebServerTests { String headerName = request.getTarget().replace("/header/",""); return new MockResponse.Builder().body(headerName + ":" + request.getHeaders().get(headerName)).code(200).build(); } - else if(request.getTarget().startsWith("/compress/") && request.getBody() != null) { + else if(request.getMethod().equals("POST") && request.getTarget().startsWith("/compress/") && request.getBody() != null) { String encoding = request.getTarget().replace("/compress/",""); String requestBody = request.getBody().utf8(); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); @@ -139,6 +139,13 @@ public abstract class AbstractMockWebServerTests { builder.setHeader(HttpHeaders.CONTENT_LENGTH, buffer.size()); return builder.build(); } + else if (request.getMethod().equals("HEAD") && request.getTarget().startsWith("/headforcompress/")) { + String encoding = request.getTarget().replace("/headforcompress/",""); + MockResponse.Builder builder = new MockResponse.Builder().code(200) + .setHeader(HttpHeaders.CONTENT_LENGTH, 500) + .setHeader(HttpHeaders.CONTENT_ENCODING, encoding); + return builder.build(); + } return new MockResponse.Builder().code(404).build(); } catch (Throwable ex) { diff --git a/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java index 391e42e4403..efb0a01dd06 100644 --- a/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java @@ -26,6 +26,8 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledForJreRange; import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; @@ -159,6 +161,23 @@ class JdkClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTests { } } + @ParameterizedTest + @ValueSource(strings = {"gzip", "deflate"}) + void gzipCompressionWithHeadRequest(String compression) throws IOException { + URI uri = URI.create(baseUrl + "/headforcompress/" + compression); + JdkClientHttpRequestFactory requestFactory = (JdkClientHttpRequestFactory) this.factory; + requestFactory.enableCompression(true); + ClientHttpRequest request = requestFactory.createRequest(uri, HttpMethod.HEAD); + try (ClientHttpResponse response = request.execute()) { + assertThat(response.getStatusCode()).as("Invalid response status").isEqualTo(HttpStatus.OK); + assertThat(response.getHeaders().getFirst("Content-Encoding")) + .as("Content Encoding should be removed").isNull(); + assertThat(response.getHeaders().getFirst("Content-Length")) + .as("Content-Length should be removed").isNull(); + assertThat(response.getBody()).as("Invalid response body").isEmpty(); + } + } + @Test // gh-34971 @EnabledForJreRange(min = JRE.JAVA_19) // behavior fixed in Java 19 void requestContentLengthHeaderWhenNoBody() throws Exception {