diff --git a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java index 31e66eddf97..4b47dd3c769 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java @@ -18,8 +18,12 @@ package org.springframework.web.filter; import java.io.IOException; import java.io.InputStream; +import java.io.PrintWriter; + import javax.servlet.FilterChain; import javax.servlet.ServletException; +import javax.servlet.ServletOutputStream; +import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -55,11 +59,12 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { private static final String DIRECTIVE_NO_STORE = "no-store"; - /** Checking for Servlet 3.0+ HttpServletResponse.getHeader(String) */ private static final boolean responseGetHeaderAvailable = ClassUtils.hasMethod(HttpServletResponse.class, "getHeader", String.class); + private static final String STREAMING_ATTRIBUTE = ShallowEtagHeaderFilter.class.getName() + ".STREAMING"; + /** * The default value is "false" so that the filter may delay the generation of @@ -76,12 +81,12 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { HttpServletResponse responseToUse = response; if (!isAsyncDispatch(request) && !(response instanceof ContentCachingResponseWrapper)) { - responseToUse = new ContentCachingResponseWrapper(response); + responseToUse = new HttpStreamingAwareContentCachingResponseWrapper(response, request); } filterChain.doFilter(request, responseToUse); - if (!isAsyncStarted(request)) { + if (!isAsyncStarted(request) && !isContentCachingDisabled(request)) { updateResponse(request, responseToUse); } } @@ -90,7 +95,6 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { ContentCachingResponseWrapper responseWrapper = WebUtils.getNativeResponse(response, ContentCachingResponseWrapper.class); Assert.notNull(responseWrapper, "ShallowEtagResponseWrapper not found"); - HttpServletResponse rawResponse = (HttpServletResponse) responseWrapper.getResponse(); int statusCode = responseWrapper.getStatusCode(); @@ -169,4 +173,48 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { return builder.toString(); } + /** + * This method can be used to disable the content caching response wrapper + * of the ShallowEtagHeaderFilter. This can be done before the start of HTTP + * streaming for example where the response will be written to asynchronously + * and not in the context of a Servlet container thread. + * @since 4.2 + */ + public static void disableContentCaching(ServletRequest request) { + Assert.notNull(request); + request.setAttribute(STREAMING_ATTRIBUTE, true); + } + + private static boolean isContentCachingDisabled(HttpServletRequest request) { + return (request.getAttribute(STREAMING_ATTRIBUTE) != null); + } + + + private static class HttpStreamingAwareContentCachingResponseWrapper extends ContentCachingResponseWrapper { + + private final HttpServletRequest request; + + + public HttpStreamingAwareContentCachingResponseWrapper(HttpServletResponse response, + HttpServletRequest request) { + + super(response); + this.request = request; + } + + @Override + public ServletOutputStream getOutputStream() throws IOException { + return (useRawResponse() ? getResponse().getOutputStream() : super.getOutputStream()); + } + + @Override + public PrintWriter getWriter() throws IOException { + return (useRawResponse() ? getResponse().getWriter() : super.getWriter()); + } + + private boolean useRawResponse() { + return isContentCachingDisabled(this.request); + } + } + } diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java index f28d5b2b528..8a5b94cb224 100644 --- a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java +++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java @@ -91,7 +91,7 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { } @Override - public ServletOutputStream getOutputStream() { + public ServletOutputStream getOutputStream() throws IOException { return this.outputStream; } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java index da796ac3f85..1c40e6fcbc6 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java @@ -69,14 +69,10 @@ public class ShallowEtagHeaderFilterTests { MockHttpServletResponse response = new MockHttpServletResponse(); final byte[] responseBody = "Hello World".getBytes("UTF-8"); - FilterChain filterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) - throws IOException, ServletException { - assertEquals("Invalid request passed", request, filterRequest); - ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); - FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); - } + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); }; filter.doFilter(request, response, filterChain); @@ -93,15 +89,11 @@ public class ShallowEtagHeaderFilterTests { request.addHeader("If-None-Match", etag); MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) - throws IOException, ServletException { - assertEquals("Invalid request passed", request, filterRequest); - byte[] responseBody = "Hello World".getBytes("UTF-8"); - FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); - filterResponse.setContentLength(responseBody.length); - } + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + byte[] responseBody = "Hello World".getBytes("UTF-8"); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + filterResponse.setContentLength(responseBody.length); }; filter.doFilter(request, response, filterChain); @@ -118,15 +110,11 @@ public class ShallowEtagHeaderFilterTests { request.addHeader("If-None-Match", etag); MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) - throws IOException, ServletException { - assertEquals("Invalid request passed", request, filterRequest); - ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); - String responseBody = "Hello World"; - FileCopyUtils.copy(responseBody, filterResponse.getWriter()); - } + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + String responseBody = "Hello World"; + FileCopyUtils.copy(responseBody, filterResponse.getWriter()); }; filter.doFilter(request, response, filterChain); @@ -136,20 +124,38 @@ public class ShallowEtagHeaderFilterTests { assertArrayEquals("Invalid content", new byte[0], response.getContentAsByteArray()); } + // SPR-12960 + + @Test + public void filterWriterWithDisabledCaching() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + final byte[] responseBody = "Hello World".getBytes("UTF-8"); + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + }; + + ShallowEtagHeaderFilter.disableContentCaching(request); + this.filter.doFilter(request, response, filterChain); + + assertEquals(200, response.getStatus()); + assertNull(response.getHeader("ETag")); + assertArrayEquals(responseBody, response.getContentAsByteArray()); + } + @Test public void filterSendError() throws Exception { final MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); MockHttpServletResponse response = new MockHttpServletResponse(); final byte[] responseBody = "Hello World".getBytes("UTF-8"); - FilterChain filterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) - throws IOException, ServletException { - assertEquals("Invalid request passed", request, filterRequest); - FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); - ((HttpServletResponse) filterResponse).sendError(HttpServletResponse.SC_FORBIDDEN); - } + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + ((HttpServletResponse) filterResponse).sendError(HttpServletResponse.SC_FORBIDDEN); }; filter.doFilter(request, response, filterChain); @@ -165,14 +171,10 @@ public class ShallowEtagHeaderFilterTests { MockHttpServletResponse response = new MockHttpServletResponse(); final byte[] responseBody = "Hello World".getBytes("UTF-8"); - FilterChain filterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) - throws IOException, ServletException { - assertEquals("Invalid request passed", request, filterRequest); - FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); - ((HttpServletResponse) filterResponse).sendError(HttpServletResponse.SC_FORBIDDEN, "ERROR"); - } + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + ((HttpServletResponse) filterResponse).sendError(HttpServletResponse.SC_FORBIDDEN, "ERROR"); }; filter.doFilter(request, response, filterChain); @@ -189,14 +191,10 @@ public class ShallowEtagHeaderFilterTests { MockHttpServletResponse response = new MockHttpServletResponse(); final byte[] responseBody = "Hello World".getBytes("UTF-8"); - FilterChain filterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) - throws IOException, ServletException { - assertEquals("Invalid request passed", request, filterRequest); - FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); - ((HttpServletResponse) filterResponse).sendRedirect("http://www.google.com"); - } + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertEquals("Invalid request passed", request, filterRequest); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); + ((HttpServletResponse) filterResponse).sendRedirect("http://www.google.com"); }; filter.doFilter(request, response, filterChain); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java index fe99556b966..338a4faae62 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java @@ -19,6 +19,8 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.io.IOException; import java.io.OutputStream; import java.util.List; + +import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; @@ -37,6 +39,7 @@ import org.springframework.util.Assert; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.context.request.async.WebAsyncUtils; +import org.springframework.web.filter.ShallowEtagHeaderFilter; import org.springframework.web.method.support.HandlerMethodReturnValueHandler; import org.springframework.web.method.support.ModelAndViewContainer; @@ -95,6 +98,9 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur } } + ServletRequest request = webRequest.getNativeRequest(ServletRequest.class); + ShallowEtagHeaderFilter.disableContentCaching(request); + Assert.isInstanceOf(ResponseBodyEmitter.class, returnValue); ResponseBodyEmitter emitter = (ResponseBodyEmitter) returnValue; emitter.extendResponse(outputMessage); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/StreamingResponseBodyReturnValueHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/StreamingResponseBodyReturnValueHandler.java index 3e1a09148f3..40852899cb0 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/StreamingResponseBodyReturnValueHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/StreamingResponseBodyReturnValueHandler.java @@ -18,11 +18,9 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.io.OutputStream; import java.util.concurrent.Callable; +import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletResponse; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; import org.springframework.http.ResponseEntity; @@ -31,6 +29,7 @@ import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.async.WebAsyncUtils; +import org.springframework.web.filter.ShallowEtagHeaderFilter; import org.springframework.web.method.support.HandlerMethodReturnValueHandler; import org.springframework.web.method.support.ModelAndViewContainer; @@ -45,9 +44,6 @@ import org.springframework.web.method.support.ModelAndViewContainer; */ public class StreamingResponseBodyReturnValueHandler implements HandlerMethodReturnValueHandler { - private static final Log logger = LogFactory.getLog(StreamingResponseBodyReturnValueHandler.class); - - @Override public boolean supportsReturnType(MethodParameter returnType) { if (StreamingResponseBody.class.isAssignableFrom(returnType.getParameterType())) { @@ -84,6 +80,9 @@ public class StreamingResponseBodyReturnValueHandler implements HandlerMethodRet } } + ServletRequest request = webRequest.getNativeRequest(ServletRequest.class); + ShallowEtagHeaderFilter.disableContentCaching(request); + Assert.isInstanceOf(StreamingResponseBody.class, returnValue); StreamingResponseBody streamingBody = (StreamingResponseBody) returnValue; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index 8f2bba16609..869e08956f8 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -26,10 +26,14 @@ import java.util.Map; import java.util.Queue; import java.util.concurrent.LinkedBlockingQueue; +import javax.servlet.ServletRequest; + import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpAsyncRequestControl; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.web.filter.ShallowEtagHeaderFilter; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; @@ -202,6 +206,8 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.asyncRequestControl = request.getAsyncRequestControl(response); this.asyncRequestControl.start(-1); + disableShallowEtagHeaderFilter(request); + // Let "our" handler know before sending the open frame to the remote handler delegateConnectionEstablished(); @@ -243,6 +249,8 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.asyncRequestControl = request.getAsyncRequestControl(response); this.asyncRequestControl.start(-1); + disableShallowEtagHeaderFilter(request); + handleRequestInternal(request, response, false); this.readyToSend = isActive(); } @@ -253,6 +261,13 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { } } + private void disableShallowEtagHeaderFilter(ServerHttpRequest request) { + if (request instanceof ServletServerHttpRequest) { + ServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); + ShallowEtagHeaderFilter.disableContentCaching(servletRequest); + } + } + /** * Invoked when a SockJS transport request is received. * @param request the current request