@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2022 the original author or authors .
* Copyright 2002 - 2024 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -21,6 +21,10 @@ import java.io.InputStream;
@@ -21,6 +21,10 @@ import java.io.InputStream;
import java.io.OutputStreamWriter ;
import java.io.PrintWriter ;
import java.io.UnsupportedEncodingException ;
import java.util.ArrayList ;
import java.util.Collection ;
import java.util.Collections ;
import java.util.List ;
import jakarta.servlet.ServletOutputStream ;
import jakarta.servlet.WriteListener ;
@ -55,6 +59,9 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
@@ -55,6 +59,9 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
@Nullable
private Integer contentLength ;
@Nullable
private String contentType ;
/ * *
* Create a new ContentCachingResponseWrapper for the given servlet response .
@ -139,6 +146,122 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
@@ -139,6 +146,122 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
this . contentLength = lenInt ;
}
@Override
public void setContentType ( String type ) {
this . contentType = type ;
}
@Override
@Nullable
public String getContentType ( ) {
return this . contentType ;
}
@Override
public boolean containsHeader ( String name ) {
if ( HttpHeaders . CONTENT_LENGTH . equalsIgnoreCase ( name ) ) {
return this . contentLength ! = null ;
}
else if ( HttpHeaders . CONTENT_TYPE . equalsIgnoreCase ( name ) ) {
return this . contentType ! = null ;
}
else {
return super . containsHeader ( name ) ;
}
}
@Override
public void setHeader ( String name , String value ) {
if ( HttpHeaders . CONTENT_LENGTH . equalsIgnoreCase ( name ) ) {
this . contentLength = Integer . valueOf ( value ) ;
}
else if ( HttpHeaders . CONTENT_TYPE . equalsIgnoreCase ( name ) ) {
this . contentType = value ;
}
else {
super . setHeader ( name , value ) ;
}
}
@Override
public void addHeader ( String name , String value ) {
if ( HttpHeaders . CONTENT_LENGTH . equalsIgnoreCase ( name ) ) {
this . contentLength = Integer . valueOf ( value ) ;
}
else if ( HttpHeaders . CONTENT_TYPE . equalsIgnoreCase ( name ) ) {
this . contentType = value ;
}
else {
super . addHeader ( name , value ) ;
}
}
@Override
public void setIntHeader ( String name , int value ) {
if ( HttpHeaders . CONTENT_LENGTH . equalsIgnoreCase ( name ) ) {
this . contentLength = Integer . valueOf ( value ) ;
}
else {
super . setIntHeader ( name , value ) ;
}
}
@Override
public void addIntHeader ( String name , int value ) {
if ( HttpHeaders . CONTENT_LENGTH . equalsIgnoreCase ( name ) ) {
this . contentLength = Integer . valueOf ( value ) ;
}
else {
super . addIntHeader ( name , value ) ;
}
}
@Override
@Nullable
public String getHeader ( String name ) {
if ( HttpHeaders . CONTENT_LENGTH . equalsIgnoreCase ( name ) ) {
return ( this . contentLength ! = null ) ? this . contentLength . toString ( ) : null ;
}
else if ( HttpHeaders . CONTENT_TYPE . equalsIgnoreCase ( name ) ) {
return this . contentType ;
}
else {
return super . getHeader ( name ) ;
}
}
@Override
public Collection < String > getHeaders ( String name ) {
if ( HttpHeaders . CONTENT_LENGTH . equalsIgnoreCase ( name ) ) {
return this . contentLength ! = null ? Collections . singleton ( this . contentLength . toString ( ) ) :
Collections . emptySet ( ) ;
}
else if ( HttpHeaders . CONTENT_TYPE . equalsIgnoreCase ( name ) ) {
return this . contentType ! = null ? Collections . singleton ( this . contentType ) : Collections . emptySet ( ) ;
}
else {
return super . getHeaders ( name ) ;
}
}
@Override
public Collection < String > getHeaderNames ( ) {
Collection < String > headerNames = super . getHeaderNames ( ) ;
if ( this . contentLength ! = null | | this . contentType ! = null ) {
List < String > result = new ArrayList < > ( headerNames ) ;
if ( this . contentLength ! = null ) {
result . add ( HttpHeaders . CONTENT_LENGTH ) ;
}
if ( this . contentType ! = null ) {
result . add ( HttpHeaders . CONTENT_TYPE ) ;
}
return result ;
}
else {
return headerNames ;
}
}
@Override
public void setBufferSize ( int size ) {
if ( size > this . content . size ( ) ) {
@ -197,11 +320,17 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
@@ -197,11 +320,17 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
protected void copyBodyToResponse ( boolean complete ) throws IOException {
if ( this . content . size ( ) > 0 ) {
HttpServletResponse rawResponse = ( HttpServletResponse ) getResponse ( ) ;
if ( ( complete | | this . contentLength ! = null ) & & ! rawResponse . isCommitted ( ) ) {
if ( rawResponse . getHeader ( HttpHeaders . TRANSFER_ENCODING ) = = null ) {
rawResponse . setContentLength ( complete ? this . content . size ( ) : this . contentLength ) ;
if ( ! rawResponse . isCommitted ( ) ) {
if ( complete | | this . contentLength ! = null ) {
if ( rawResponse . getHeader ( HttpHeaders . TRANSFER_ENCODING ) = = null ) {
rawResponse . setContentLength ( complete ? this . content . size ( ) : this . contentLength ) ;
}
this . contentLength = null ;
}
if ( complete | | this . contentType ! = null ) {
rawResponse . setContentType ( this . contentType ) ;
this . contentType = null ;
}
this . contentLength = null ;
}
this . content . writeTo ( rawResponse . getOutputStream ( ) ) ;
this . content . reset ( ) ;