@ -1,5 +1,5 @@
/ *
/ *
* Copyright 2002 - 2023 the original author or authors .
* Copyright 2002 - 2024 the original author or authors .
*
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
* you may not use this file except in compliance with the License .
@ -16,13 +16,12 @@
package org.springframework.web.filter ;
package org.springframework.web.filter ;
import java.util.function.BiConsumer ;
import java.util.stream.Stream ;
import java.util.stream.Stream ;
import jakarta.servlet.http.HttpServletResponse ;
import jakarta.servlet.http.HttpServletResponse ;
import org.junit.jupiter.api.Named ;
import org.junit.jupiter.api.Test ;
import org.junit.jupiter.api.Test ;
import org.junit.jupiter.params.ParameterizedTest ;
import org.junit.jupiter.params.ParameterizedTest ;
import org.junit.jupiter.params.provider.Arguments ;
import org.junit.jupiter.params.provider.MethodSource ;
import org.junit.jupiter.params.provider.MethodSource ;
import org.springframework.http.MediaType ;
import org.springframework.http.MediaType ;
@ -33,17 +32,17 @@ import org.springframework.web.util.ContentCachingResponseWrapper;
import static java.nio.charset.StandardCharsets.UTF_8 ;
import static java.nio.charset.StandardCharsets.UTF_8 ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.junit.jupiter.api.Named.named ;
import static org.junit.jupiter.api.Named.named ;
import static org.junit.jupiter.params.provider.Arguments.arguments ;
import static org.springframework.http.HttpHeaders.CONTENT_LENGTH ;
import static org.springframework.http.HttpHeaders.CONTENT_LENGTH ;
import static org.springframework.http.HttpHeaders.CONTENT_TYPE ;
import static org.springframework.http.HttpHeaders.CONTENT_TYPE ;
import static org.springframework.http.HttpHeaders.TRANSFER_ENCODING ;
import static org.springframework.http.HttpHeaders.TRANSFER_ENCODING ;
/ * *
/ * *
* Unit tests for { @link ContentCachingResponseWrapper } .
* Unit tests for { @link ContentCachingResponseWrapper } .
*
* @author Rossen Stoyanchev
* @author Rossen Stoyanchev
* @author Sam Brannen
* @author Sam Brannen
* /
* /
public class ContentCachingResponseWrapperTests {
class ContentCachingResponseWrapperTests {
@Test
@Test
void copyBodyToResponse ( ) throws Exception {
void copyBodyToResponse ( ) throws Exception {
@ -119,31 +118,76 @@ public class ContentCachingResponseWrapperTests {
}
}
@ParameterizedTest ( name = "[{index}] {0}" )
@ParameterizedTest ( name = "[{index}] {0}" )
@MethodSource ( "setContentType Functions" )
@MethodSource ( "setContentLength Functions" )
void copyBodyToResponseWithOverridingHeaders ( BiConsumer < HttpServletResponse , String > setContentType ) throws Exception {
void copyBodyToResponseWithOverridingContentLength ( SetContentLength setContentLength ) throws Exception {
byte [ ] responseBody = "Hello World" . getBytes ( UTF_8 ) ;
byte [ ] responseBody = "Hello World" . getBytes ( UTF_8 ) ;
int responseLength = responseBody . length ;
int responseLength = responseBody . length ;
int originalContentLength = 11 ;
int originalContentLength = 11 ;
int overridingContentLength = 22 ;
int overridingContentLength = 22 ;
String originalContentType = MediaType . TEXT_PLAIN_VALUE ;
String overridingContentType = MediaType . APPLICATION_JSON_VALUE ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
response . setContentLength ( originalContentLength ) ;
response . setContentLength ( originalContentLength ) ;
response . setContentType ( originalContentType ) ;
ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper ( response ) ;
ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper ( response ) ;
responseWrapper . setStatus ( HttpServletResponse . SC_CREATED ) ;
responseWrapper . setContentLength ( overridingContentLength ) ;
responseWrapper . setContentLength ( overridingContentLength ) ;
setContentType . accept ( responseWrapper , overridingContentType ) ;
assertThat ( responseWrapper . getStatus ( ) ) . isEqualTo ( HttpServletResponse . SC_CREATED ) ;
setContentLength . invoke ( responseWrapper , overridingContentLength ) ;
assertThat ( responseWrapper . getContentSize ( ) ) . isZero ( ) ;
assertThat ( responseWrapper . getContentSize ( ) ) . isZero ( ) ;
assertThat ( responseWrapper . getHeaderNames ( ) ) . containsExactlyInAnyOrder ( CONTENT_TYPE , CONTENT_ LENGTH ) ;
assertThat ( responseWrapper . getHeaderNames ( ) ) . containsExactlyInAnyOrder ( CONTENT_LENGTH ) ;
assertHeader ( response , CONTENT_LENGTH , originalContentLength ) ;
assertHeader ( response , CONTENT_LENGTH , originalContentLength ) ;
assertHeader ( responseWrapper , CONTENT_LENGTH , overridingContentLength ) ;
assertHeader ( responseWrapper , CONTENT_LENGTH , overridingContentLength ) ;
FileCopyUtils . copy ( responseBody , responseWrapper . getOutputStream ( ) ) ;
assertThat ( responseWrapper . getContentSize ( ) ) . isEqualTo ( responseLength ) ;
responseWrapper . copyBodyToResponse ( ) ;
assertThat ( responseWrapper . getContentSize ( ) ) . isZero ( ) ;
assertThat ( responseWrapper . getHeaderNames ( ) ) . containsExactlyInAnyOrder ( CONTENT_LENGTH ) ;
assertHeader ( response , CONTENT_LENGTH , responseLength ) ;
assertHeader ( responseWrapper , CONTENT_LENGTH , responseLength ) ;
assertThat ( response . getContentLength ( ) ) . isEqualTo ( responseLength ) ;
assertThat ( response . getContentAsByteArray ( ) ) . isEqualTo ( responseBody ) ;
assertThat ( response . getHeaderNames ( ) ) . containsExactlyInAnyOrder ( CONTENT_LENGTH ) ;
}
private static Stream < Named < SetContentLength > > setContentLengthFunctions ( ) {
return Stream . of (
named ( "setContentLength()" , HttpServletResponse : : setContentLength ) ,
named ( "setContentLengthLong()" , HttpServletResponse : : setContentLengthLong ) ,
named ( "setIntHeader()" , ( response , contentLength ) - > response . setIntHeader ( CONTENT_LENGTH , contentLength ) ) ,
named ( "addIntHeader()" , ( response , contentLength ) - > response . addIntHeader ( CONTENT_LENGTH , contentLength ) ) ,
named ( "setHeader()" , ( response , contentLength ) - > response . setHeader ( CONTENT_LENGTH , "" + contentLength ) ) ,
named ( "addHeader()" , ( response , contentLength ) - > response . addHeader ( CONTENT_LENGTH , "" + contentLength ) )
) ;
}
@ParameterizedTest ( name = "[{index}] {0}" )
@MethodSource ( "setContentTypeFunctions" )
void copyBodyToResponseWithOverridingContentType ( SetContentType setContentType ) throws Exception {
byte [ ] responseBody = "Hello World" . getBytes ( UTF_8 ) ;
int responseLength = responseBody . length ;
String originalContentType = MediaType . TEXT_PLAIN_VALUE ;
String overridingContentType = MediaType . APPLICATION_JSON_VALUE ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
response . setContentType ( originalContentType ) ;
ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper ( response ) ;
assertContentTypeHeader ( response , originalContentType ) ;
assertContentTypeHeader ( response , originalContentType ) ;
assertContentTypeHeader ( responseWrapper , originalContentType ) ;
setContentType . invoke ( responseWrapper , overridingContentType ) ;
assertThat ( responseWrapper . getContentSize ( ) ) . isZero ( ) ;
assertThat ( responseWrapper . getHeaderNames ( ) ) . containsExactlyInAnyOrder ( CONTENT_TYPE ) ;
assertContentTypeHeader ( response , overridingContentType ) ;
assertContentTypeHeader ( responseWrapper , overridingContentType ) ;
assertContentTypeHeader ( responseWrapper , overridingContentType ) ;
FileCopyUtils . copy ( responseBody , responseWrapper . getOutputStream ( ) ) ;
FileCopyUtils . copy ( responseBody , responseWrapper . getOutputStream ( ) ) ;
@ -151,7 +195,6 @@ public class ContentCachingResponseWrapperTests {
responseWrapper . copyBodyToResponse ( ) ;
responseWrapper . copyBodyToResponse ( ) ;
assertThat ( responseWrapper . getStatus ( ) ) . isEqualTo ( HttpServletResponse . SC_CREATED ) ;
assertThat ( responseWrapper . getContentSize ( ) ) . isZero ( ) ;
assertThat ( responseWrapper . getContentSize ( ) ) . isZero ( ) ;
assertThat ( responseWrapper . getHeaderNames ( ) ) . containsExactlyInAnyOrder ( CONTENT_TYPE , CONTENT_LENGTH ) ;
assertThat ( responseWrapper . getHeaderNames ( ) ) . containsExactlyInAnyOrder ( CONTENT_TYPE , CONTENT_LENGTH ) ;
@ -160,24 +203,19 @@ public class ContentCachingResponseWrapperTests {
assertContentTypeHeader ( response , overridingContentType ) ;
assertContentTypeHeader ( response , overridingContentType ) ;
assertContentTypeHeader ( responseWrapper , overridingContentType ) ;
assertContentTypeHeader ( responseWrapper , overridingContentType ) ;
assertThat ( response . getStatus ( ) ) . isEqualTo ( HttpServletResponse . SC_CREATED ) ;
assertThat ( response . getContentLength ( ) ) . isEqualTo ( responseLength ) ;
assertThat ( response . getContentLength ( ) ) . isEqualTo ( responseLength ) ;
assertThat ( response . getContentAsByteArray ( ) ) . isEqualTo ( responseBody ) ;
assertThat ( response . getContentAsByteArray ( ) ) . isEqualTo ( responseBody ) ;
assertThat ( response . getHeaderNames ( ) ) . containsExactlyInAnyOrder ( CONTENT_TYPE , CONTENT_LENGTH ) ;
assertThat ( response . getHeaderNames ( ) ) . containsExactlyInAnyOrder ( CONTENT_TYPE , CONTENT_LENGTH ) ;
}
}
private static Stream < Arguments > setContentTypeFunctions ( ) {
private static Stream < Named < SetContentType > > setContentTypeFunctions ( ) {
return Stream . of (
return Stream . of (
namedArguments ( "setContentType()" , HttpServletResponse : : setContentType ) ,
named ( "setContentType()" , HttpServletResponse : : setContentType ) ,
namedArguments ( "setHeader()" , ( response , contentType ) - > response . setHeader ( CONTENT_TYPE , contentType ) ) ,
named ( "setHeader()" , ( response , contentType ) - > response . setHeader ( CONTENT_TYPE , contentType ) ) ,
namedArguments ( "addHeader()" , ( response , contentType ) - > response . addHeader ( CONTENT_TYPE , contentType ) )
named ( "addHeader()" , ( response , contentType ) - > response . addHeader ( CONTENT_TYPE , contentType ) )
) ;
) ;
}
}
private static Arguments namedArguments ( String name , BiConsumer < HttpServletResponse , String > setContentTypeFunction ) {
return arguments ( named ( name , setContentTypeFunction ) ) ;
}
@Test
@Test
void copyBodyToResponseWithTransferEncoding ( ) throws Exception {
void copyBodyToResponseWithTransferEncoding ( ) throws Exception {
byte [ ] responseBody = "6\r\nHello 5\r\nWorld0\r\n\r\n" . getBytes ( UTF_8 ) ;
byte [ ] responseBody = "6\r\nHello 5\r\nWorld0\r\n\r\n" . getBytes ( UTF_8 ) ;
@ -217,4 +255,15 @@ public class ContentCachingResponseWrapperTests {
assertThat ( response . getContentType ( ) ) . as ( CONTENT_TYPE ) . isEqualTo ( contentType ) ;
assertThat ( response . getContentType ( ) ) . as ( CONTENT_TYPE ) . isEqualTo ( contentType ) ;
}
}
@FunctionalInterface
private interface SetContentLength {
void invoke ( HttpServletResponse response , int contentLength ) ;
}
@FunctionalInterface
private interface SetContentType {
void invoke ( HttpServletResponse response , String contentType ) ;
}
}
}