Browse Source

Do not cache Content-Type in ContentCachingResponseWrapper

Based on feedback from several members of the community, we have
decided to revert the caching of the Content-Type header that was
introduced in ContentCachingResponseWrapper in 375e0e6827.

This commit therefore completely removes Content-Type caching in
ContentCachingResponseWrapper and updates the existing tests
accordingly.

To provide guards against future regressions in this area, this commit
also introduces explicit tests for the 6 ways to set the content length
in ContentCachingResponseWrapper and modifies a test in
ShallowEtagHeaderFilterTests to ensure that a Content-Type header set
directly on ContentCachingResponseWrapper is propagated to the
underlying response even if content caching is disabled for the
ShallowEtagHeaderFilter.

See gh-32039
See gh-32317
Closes gh-32321
pull/32357/head
Sam Brannen 2 years ago
parent
commit
d1b3107398
  1. 45
      spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java
  2. 97
      spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java
  3. 7
      spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java

45
spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java

@ -60,9 +60,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @@ -60,9 +60,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
@Nullable
private Integer contentLength;
@Nullable
private String contentType;
/**
* Create a new ContentCachingResponseWrapper for the given servlet response.
@ -150,28 +147,11 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @@ -150,28 +147,11 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
setContentLength((int) len);
}
@Override
public void setContentType(@Nullable String type) {
this.contentType = type;
}
@Override
@Nullable
public String getContentType() {
if (this.contentType != null) {
return this.contentType;
}
return super.getContentType();
}
@Override
public boolean containsHeader(String name) {
if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return true;
}
else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return true;
}
else {
return super.containsHeader(name);
}
@ -182,9 +162,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @@ -182,9 +162,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
this.contentType = value;
}
else {
super.setHeader(name, value);
}
@ -195,9 +172,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @@ -195,9 +172,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
this.contentType = value;
}
else {
super.addHeader(name, value);
}
@ -229,9 +203,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @@ -229,9 +203,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return this.contentLength.toString();
}
else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return this.contentType;
}
else {
return super.getHeader(name);
}
@ -242,9 +213,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @@ -242,9 +213,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return Collections.singleton(this.contentLength.toString());
}
else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return Collections.singleton(this.contentType);
}
else {
return super.getHeaders(name);
}
@ -253,14 +221,9 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @@ -253,14 +221,9 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
@Override
public Collection<String> getHeaderNames() {
Collection<String> headerNames = super.getHeaderNames();
if (this.contentLength != null || this.contentType != null) {
if (this.contentLength != null) {
Set<String> result = new LinkedHashSet<>(headerNames);
if (this.contentLength != null) {
result.add(HttpHeaders.CONTENT_LENGTH);
}
if (this.contentType != null) {
result.add(HttpHeaders.CONTENT_TYPE);
}
result.add(HttpHeaders.CONTENT_LENGTH);
return result;
}
else {
@ -333,10 +296,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @@ -333,10 +296,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
}
this.contentLength = null;
}
if (this.contentType != null) {
rawResponse.setContentType(this.contentType);
this.contentType = null;
}
}
this.content.writeTo(rawResponse.getOutputStream());
this.content.reset();

97
spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,13 +16,12 @@ @@ -16,13 +16,12 @@
package org.springframework.web.filter;
import java.util.function.BiConsumer;
import java.util.stream.Stream;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Named;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.http.MediaType;
@ -33,17 +32,17 @@ import org.springframework.web.util.ContentCachingResponseWrapper; @@ -33,17 +32,17 @@ import org.springframework.web.util.ContentCachingResponseWrapper;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Named.named;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.springframework.http.HttpHeaders.CONTENT_LENGTH;
import static org.springframework.http.HttpHeaders.CONTENT_TYPE;
import static org.springframework.http.HttpHeaders.TRANSFER_ENCODING;
/**
* Unit tests for {@link ContentCachingResponseWrapper}.
*
* @author Rossen Stoyanchev
* @author Sam Brannen
*/
public class ContentCachingResponseWrapperTests {
class ContentCachingResponseWrapperTests {
@Test
void copyBodyToResponse() throws Exception {
@ -119,31 +118,76 @@ public class ContentCachingResponseWrapperTests { @@ -119,31 +118,76 @@ public class ContentCachingResponseWrapperTests {
}
@ParameterizedTest(name = "[{index}] {0}")
@MethodSource("setContentTypeFunctions")
void copyBodyToResponseWithOverridingHeaders(BiConsumer<HttpServletResponse, String> setContentType) throws Exception {
@MethodSource("setContentLengthFunctions")
void copyBodyToResponseWithOverridingContentLength(SetContentLength setContentLength) throws Exception {
byte[] responseBody = "Hello World".getBytes(UTF_8);
int responseLength = responseBody.length;
int originalContentLength = 11;
int overridingContentLength = 22;
String originalContentType = MediaType.TEXT_PLAIN_VALUE;
String overridingContentType = MediaType.APPLICATION_JSON_VALUE;
MockHttpServletResponse response = new MockHttpServletResponse();
response.setContentLength(originalContentLength);
response.setContentType(originalContentType);
ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response);
responseWrapper.setStatus(HttpServletResponse.SC_CREATED);
responseWrapper.setContentLength(overridingContentLength);
setContentType.accept(responseWrapper, overridingContentType);
assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
setContentLength.invoke(responseWrapper, overridingContentLength);
assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH);
assertHeader(response, CONTENT_LENGTH, originalContentLength);
assertHeader(responseWrapper, CONTENT_LENGTH, overridingContentLength);
FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream());
assertThat(responseWrapper.getContentSize()).isEqualTo(responseLength);
responseWrapper.copyBodyToResponse();
assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH);
assertHeader(response, CONTENT_LENGTH, responseLength);
assertHeader(responseWrapper, CONTENT_LENGTH, responseLength);
assertThat(response.getContentLength()).isEqualTo(responseLength);
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
assertThat(response.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH);
}
private static Stream<Named<SetContentLength>> setContentLengthFunctions() {
return Stream.of(
named("setContentLength()", HttpServletResponse::setContentLength),
named("setContentLengthLong()", HttpServletResponse::setContentLengthLong),
named("setIntHeader()", (response, contentLength) -> response.setIntHeader(CONTENT_LENGTH, contentLength)),
named("addIntHeader()", (response, contentLength) -> response.addIntHeader(CONTENT_LENGTH, contentLength)),
named("setHeader()", (response, contentLength) -> response.setHeader(CONTENT_LENGTH, "" + contentLength)),
named("addHeader()", (response, contentLength) -> response.addHeader(CONTENT_LENGTH, "" + contentLength))
);
}
@ParameterizedTest(name = "[{index}] {0}")
@MethodSource("setContentTypeFunctions")
void copyBodyToResponseWithOverridingContentType(SetContentType setContentType) throws Exception {
byte[] responseBody = "Hello World".getBytes(UTF_8);
int responseLength = responseBody.length;
String originalContentType = MediaType.TEXT_PLAIN_VALUE;
String overridingContentType = MediaType.APPLICATION_JSON_VALUE;
MockHttpServletResponse response = new MockHttpServletResponse();
response.setContentType(originalContentType);
ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response);
assertContentTypeHeader(response, originalContentType);
assertContentTypeHeader(responseWrapper, originalContentType);
setContentType.invoke(responseWrapper, overridingContentType);
assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE);
assertContentTypeHeader(response, overridingContentType);
assertContentTypeHeader(responseWrapper, overridingContentType);
FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream());
@ -151,7 +195,6 @@ public class ContentCachingResponseWrapperTests { @@ -151,7 +195,6 @@ public class ContentCachingResponseWrapperTests {
responseWrapper.copyBodyToResponse();
assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);
@ -160,24 +203,19 @@ public class ContentCachingResponseWrapperTests { @@ -160,24 +203,19 @@ public class ContentCachingResponseWrapperTests {
assertContentTypeHeader(response, overridingContentType);
assertContentTypeHeader(responseWrapper, overridingContentType);
assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
assertThat(response.getContentLength()).isEqualTo(responseLength);
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
assertThat(response.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);
}
private static Stream<Arguments> setContentTypeFunctions() {
private static Stream<Named<SetContentType>> setContentTypeFunctions() {
return Stream.of(
namedArguments("setContentType()", HttpServletResponse::setContentType),
namedArguments("setHeader()", (response, contentType) -> response.setHeader(CONTENT_TYPE, contentType)),
namedArguments("addHeader()", (response, contentType) -> response.addHeader(CONTENT_TYPE, contentType))
named("setContentType()", HttpServletResponse::setContentType),
named("setHeader()", (response, contentType) -> response.setHeader(CONTENT_TYPE, contentType)),
named("addHeader()", (response, contentType) -> response.addHeader(CONTENT_TYPE, contentType))
);
}
private static Arguments namedArguments(String name, BiConsumer<HttpServletResponse, String> setContentTypeFunction) {
return arguments(named(name, setContentTypeFunction));
}
@Test
void copyBodyToResponseWithTransferEncoding() throws Exception {
byte[] responseBody = "6\r\nHello 5\r\nWorld0\r\n\r\n".getBytes(UTF_8);
@ -217,4 +255,15 @@ public class ContentCachingResponseWrapperTests { @@ -217,4 +255,15 @@ public class ContentCachingResponseWrapperTests {
assertThat(response.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType);
}
@FunctionalInterface
private interface SetContentLength {
void invoke(HttpServletResponse response, int contentLength);
}
@FunctionalInterface
private interface SetContentType {
void invoke(HttpServletResponse response, String contentType);
}
}

7
spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java

@ -28,6 +28,7 @@ import org.springframework.web.testfixture.servlet.MockHttpServletResponse; @@ -28,6 +28,7 @@ import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE;
import static org.springframework.http.MediaType.TEXT_PLAIN_VALUE;
/**
@ -36,6 +37,7 @@ import static org.springframework.http.MediaType.TEXT_PLAIN_VALUE; @@ -36,6 +37,7 @@ import static org.springframework.http.MediaType.TEXT_PLAIN_VALUE;
* @author Arjen Poutsma
* @author Brian Clozel
* @author Juergen Hoeller
* @author Sam Brannen
*/
class ShallowEtagHeaderFilterTests {
@ -123,7 +125,7 @@ class ShallowEtagHeaderFilterTests { @@ -123,7 +125,7 @@ class ShallowEtagHeaderFilterTests {
assertThat(response.getStatus()).as("Invalid status").isEqualTo(304);
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\"");
assertThat(response.containsHeader("Content-Length")).as("Response has Content-Length header").isFalse();
assertThat(response.containsHeader("Content-Type")).as("Response has Content-Type header").isFalse();
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(TEXT_PLAIN_VALUE);
assertThat(response.getContentAsByteArray()).as("Invalid content").isEmpty();
}
@ -173,11 +175,13 @@ class ShallowEtagHeaderFilterTests { @@ -173,11 +175,13 @@ class ShallowEtagHeaderFilterTests {
void filterWriterWithDisabledCaching() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels");
MockHttpServletResponse response = new MockHttpServletResponse();
response.setContentType(TEXT_PLAIN_VALUE);
byte[] responseBody = "Hello World".getBytes(UTF_8);
FilterChain filterChain = (filterRequest, filterResponse) -> {
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
filterResponse.setContentType(APPLICATION_JSON_VALUE);
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
};
@ -186,6 +190,7 @@ class ShallowEtagHeaderFilterTests { @@ -186,6 +190,7 @@ class ShallowEtagHeaderFilterTests {
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getHeader("ETag")).isNull();
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(APPLICATION_JSON_VALUE);
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
}

Loading…
Cancel
Save