From 1fcc2fcd8845e50d964b3c146907b90c3f3f05c2 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 15 Mar 2016 11:16:44 -0500 Subject: [PATCH 1/2] Make OnCommittedResponseWrapper public This is preparing for changes in gh-2953 Issues gh-2953 --- ...ContextOnUpdateOrErrorResponseWrapper.java | 21 +- .../OnCommittedResponseWrapper.java | 280 +++++++++++------- .../OnCommittedResponseWrapperTests.java | 4 +- 3 files changed, 189 insertions(+), 116 deletions(-) rename web/src/main/java/org/springframework/security/web/{context => util}/OnCommittedResponseWrapper.java (65%) rename web/src/test/java/org/springframework/security/web/{context => util}/OnCommittedResponseWrapperTests.java (99%) diff --git a/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java b/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java index e885e551ed..65e177e09f 100644 --- a/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java +++ b/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java @@ -17,10 +17,9 @@ package org.springframework.security.web.context; import javax.servlet.http.HttpServletResponse; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.util.OnCommittedResponseWrapper; /** * Base class for response wrappers which encapsulate the logic for storing a security @@ -40,10 +39,8 @@ import org.springframework.security.core.context.SecurityContextHolder; * @author Rob Winch * @since 3.0 */ -public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends - OnCommittedResponseWrapper { - private final Log logger = LogFactory.getLog(getClass()); - +public abstract class SaveContextOnUpdateOrErrorResponseWrapper + extends OnCommittedResponseWrapper { private boolean contextSaved = false; /* See SEC-1052 */ @@ -86,12 +83,12 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends @Override protected void onResponseCommitted() { saveContext(SecurityContextHolder.getContext()); - contextSaved = true; + this.contextSaved = true; } @Override public final String encodeRedirectUrl(String url) { - if (disableUrlRewriting) { + if (this.disableUrlRewriting) { return url; } return super.encodeRedirectUrl(url); @@ -99,7 +96,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends @Override public final String encodeRedirectURL(String url) { - if (disableUrlRewriting) { + if (this.disableUrlRewriting) { return url; } return super.encodeRedirectURL(url); @@ -107,7 +104,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends @Override public final String encodeUrl(String url) { - if (disableUrlRewriting) { + if (this.disableUrlRewriting) { return url; } return super.encodeUrl(url); @@ -115,7 +112,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends @Override public final String encodeURL(String url) { - if (disableUrlRewriting) { + if (this.disableUrlRewriting) { return url; } return super.encodeURL(url); @@ -126,6 +123,6 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends * wrapper. */ public final boolean isContextSaved() { - return contextSaved; + return this.contextSaved; } } diff --git a/web/src/main/java/org/springframework/security/web/context/OnCommittedResponseWrapper.java b/web/src/main/java/org/springframework/security/web/util/OnCommittedResponseWrapper.java similarity index 65% rename from web/src/main/java/org/springframework/security/web/context/OnCommittedResponseWrapper.java rename to web/src/main/java/org/springframework/security/web/util/OnCommittedResponseWrapper.java index 742b88eb0a..14b454816a 100644 --- a/web/src/main/java/org/springframework/security/web/context/OnCommittedResponseWrapper.java +++ b/web/src/main/java/org/springframework/security/web/util/OnCommittedResponseWrapper.java @@ -13,33 +13,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.web.context; +package org.springframework.security.web.util; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Locale; import javax.servlet.ServletOutputStream; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; -import java.io.IOException; -import java.io.PrintWriter; -import java.util.Locale; /** - * Base class for response wrappers which encapsulate the logic for handling an event when the - * {@link javax.servlet.http.HttpServletResponse} is committed. + * Base class for response wrappers which encapsulate the logic for handling an event when + * the {@link javax.servlet.http.HttpServletResponse} is committed. * * @since 4.0.2 * @author Rob Winch */ -abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { - private final Log logger = LogFactory.getLog(getClass()); +public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { private boolean disableOnCommitted; /** - * The Content-Length response header. If this is greater than 0, then once {@link #contentWritten} is larger than - * or equal the response is considered committed. + * The Content-Length response header. If this is greater than 0, then once + * {@link #contentWritten} is larger than or equal the response is considered + * committed. */ private long contentLength; @@ -57,7 +55,7 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { @Override public void addHeader(String name, String value) { - if("Content-Length".equalsIgnoreCase(name)) { + if ("Content-Length".equalsIgnoreCase(name)) { setContentLength(Long.parseLong(value)); } super.addHeader(name, value); @@ -75,22 +73,33 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { } /** - * Invoke this method to disable invoking {@link OnCommittedResponseWrapper#onResponseCommitted()} when the {@link javax.servlet.http.HttpServletResponse} is - * committed. This can be useful in the event that Async Web Requests are - * made. + * Invoke this method to disable invoking + * {@link OnCommittedResponseWrapper#onResponseCommitted()} when the + * {@link javax.servlet.http.HttpServletResponse} is committed. This can be useful in + * the event that Async Web Requests are made. */ - public void disableOnResponseCommitted() { + protected void disableOnResponseCommitted() { this.disableOnCommitted = true; } /** - * Implement the logic for handling the {@link javax.servlet.http.HttpServletResponse} being committed + * Returns true if {@link #onResponseCommitted()} will be invoked when the response is + * committed, else false. + * @return if {@link #onResponseCommitted()} is enabled + */ + protected boolean isDisableOnResponseCommitted() { + return this.disableOnCommitted; + } + + /** + * Implement the logic for handling the {@link javax.servlet.http.HttpServletResponse} + * being committed */ protected abstract void onResponseCommitted(); /** - * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the - * superclass sendError() + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked + * before calling the superclass sendError() */ @Override public final void sendError(int sc) throws IOException { @@ -99,8 +108,8 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { } /** - * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the - * superclass sendError() + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked + * before calling the superclass sendError() */ @Override public final void sendError(int sc, String msg) throws IOException { @@ -109,8 +118,8 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { } /** - * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the - * superclass sendRedirect() + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked + * before calling the superclass sendRedirect() */ @Override public final void sendRedirect(String location) throws IOException { @@ -119,8 +128,9 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { } /** - * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the calling - * getOutputStream().close() or getOutputStream().flush() + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked + * before calling the calling getOutputStream().close() or + * getOutputStream().flush() */ @Override public ServletOutputStream getOutputStream() throws IOException { @@ -128,8 +138,9 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { } /** - * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the - * getWriter().close() or getWriter().flush() + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked + * before calling the getWriter().close() or + * getWriter().flush() */ @Override public PrintWriter getWriter() throws IOException { @@ -137,8 +148,8 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { } /** - * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the - * superclass flushBuffer() + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked + * before calling the superclass flushBuffer() */ @Override public void flushBuffer() throws IOException { @@ -187,36 +198,38 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { } /** - * Adds the contentLengthToWrite to the total contentWritten size and checks to see if the response should be - * written. + * Adds the contentLengthToWrite to the total contentWritten size and checks to see if + * the response should be written. * * @param contentLengthToWrite the size of the content that is about to be written. */ private void checkContentLength(long contentLengthToWrite) { - contentWritten += contentLengthToWrite; - boolean isBodyFullyWritten = contentLength > 0 && contentWritten >= contentLength; + this.contentWritten += contentLengthToWrite; + boolean isBodyFullyWritten = this.contentLength > 0 + && this.contentWritten >= this.contentLength; int bufferSize = getBufferSize(); - boolean requiresFlush = bufferSize > 0 && contentWritten >= bufferSize; - if(isBodyFullyWritten || requiresFlush) { + boolean requiresFlush = bufferSize > 0 && this.contentWritten >= bufferSize; + if (isBodyFullyWritten || requiresFlush) { doOnResponseCommitted(); } } /** * Calls onResponseCommmitted() with the current contents as long as - * {@link #disableOnResponseCommitted()()} was not invoked. + * {@link #disableOnResponseCommitted()} was not invoked. */ private void doOnResponseCommitted() { - if(!disableOnCommitted) { + if (!this.disableOnCommitted) { onResponseCommitted(); disableOnResponseCommitted(); } } /** - * Ensures {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the prior to methods that commit the response. We delegate all methods - * to the original {@link java.io.PrintWriter} to ensure that the behavior is as close to the original {@link java.io.PrintWriter} - * as possible. See SEC-2039 + * Ensures {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before + * calling the prior to methods that commit the response. We delegate all methods to + * the original {@link java.io.PrintWriter} to ensure that the behavior is as close to + * the original {@link java.io.PrintWriter} as possible. See SEC-2039 * @author Rob Winch */ private class SaveContextPrintWriter extends PrintWriter { @@ -227,197 +240,235 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { this.delegate = delegate; } + @Override public void flush() { doOnResponseCommitted(); - delegate.flush(); + this.delegate.flush(); } + @Override public void close() { doOnResponseCommitted(); - delegate.close(); + this.delegate.close(); } + @Override public int hashCode() { - return delegate.hashCode(); + return this.delegate.hashCode(); } + @Override public boolean equals(Object obj) { - return delegate.equals(obj); + return this.delegate.equals(obj); } + @Override public String toString() { - return getClass().getName() + "[delegate=" + delegate.toString() + "]"; + return getClass().getName() + "[delegate=" + this.delegate.toString() + "]"; } + @Override public boolean checkError() { - return delegate.checkError(); + return this.delegate.checkError(); } + @Override public void write(int c) { trackContentLength(c); - delegate.write(c); + this.delegate.write(c); } + @Override public void write(char[] buf, int off, int len) { checkContentLength(len); - delegate.write(buf, off, len); + this.delegate.write(buf, off, len); } + @Override public void write(char[] buf) { trackContentLength(buf); - delegate.write(buf); + this.delegate.write(buf); } + @Override public void write(String s, int off, int len) { checkContentLength(len); - delegate.write(s, off, len); + this.delegate.write(s, off, len); } + @Override public void write(String s) { trackContentLength(s); - delegate.write(s); + this.delegate.write(s); } + @Override public void print(boolean b) { trackContentLength(b); - delegate.print(b); + this.delegate.print(b); } + @Override public void print(char c) { trackContentLength(c); - delegate.print(c); + this.delegate.print(c); } + @Override public void print(int i) { trackContentLength(i); - delegate.print(i); + this.delegate.print(i); } + @Override public void print(long l) { trackContentLength(l); - delegate.print(l); + this.delegate.print(l); } + @Override public void print(float f) { trackContentLength(f); - delegate.print(f); + this.delegate.print(f); } + @Override public void print(double d) { trackContentLength(d); - delegate.print(d); + this.delegate.print(d); } + @Override public void print(char[] s) { trackContentLength(s); - delegate.print(s); + this.delegate.print(s); } + @Override public void print(String s) { trackContentLength(s); - delegate.print(s); + this.delegate.print(s); } + @Override public void print(Object obj) { trackContentLength(obj); - delegate.print(obj); + this.delegate.print(obj); } + @Override public void println() { trackContentLengthLn(); - delegate.println(); + this.delegate.println(); } + @Override public void println(boolean x) { trackContentLength(x); trackContentLengthLn(); - delegate.println(x); + this.delegate.println(x); } + @Override public void println(char x) { trackContentLength(x); trackContentLengthLn(); - delegate.println(x); + this.delegate.println(x); } + @Override public void println(int x) { trackContentLength(x); trackContentLengthLn(); - delegate.println(x); + this.delegate.println(x); } + @Override public void println(long x) { trackContentLength(x); trackContentLengthLn(); - delegate.println(x); + this.delegate.println(x); } + @Override public void println(float x) { trackContentLength(x); trackContentLengthLn(); - delegate.println(x); + this.delegate.println(x); } + @Override public void println(double x) { trackContentLength(x); trackContentLengthLn(); - delegate.println(x); + this.delegate.println(x); } + @Override public void println(char[] x) { trackContentLength(x); trackContentLengthLn(); - delegate.println(x); + this.delegate.println(x); } + @Override public void println(String x) { trackContentLength(x); trackContentLengthLn(); - delegate.println(x); + this.delegate.println(x); } + @Override public void println(Object x) { trackContentLength(x); trackContentLengthLn(); - delegate.println(x); + this.delegate.println(x); } + @Override public PrintWriter printf(String format, Object... args) { - return delegate.printf(format, args); + return this.delegate.printf(format, args); } + @Override public PrintWriter printf(Locale l, String format, Object... args) { - return delegate.printf(l, format, args); + return this.delegate.printf(l, format, args); } + @Override public PrintWriter format(String format, Object... args) { - return delegate.format(format, args); + return this.delegate.format(format, args); } + @Override public PrintWriter format(Locale l, String format, Object... args) { - return delegate.format(l, format, args); + return this.delegate.format(l, format, args); } + @Override public PrintWriter append(CharSequence csq) { checkContentLength(csq.length()); - return delegate.append(csq); + return this.delegate.append(csq); } + @Override public PrintWriter append(CharSequence csq, int start, int end) { checkContentLength(end - start); - return delegate.append(csq, start, end); + return this.delegate.append(csq, start, end); } + @Override public PrintWriter append(char c) { trackContentLength(c); - return delegate.append(c); + return this.delegate.append(c); } } /** - * Ensures{@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling methods that commit the response. We delegate all methods - * to the original {@link javax.servlet.ServletOutputStream} to ensure that the behavior is as close to the original {@link javax.servlet.ServletOutputStream} - * as possible. See SEC-2039 + * Ensures{@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before + * calling methods that commit the response. We delegate all methods to the original + * {@link javax.servlet.ServletOutputStream} to ensure that the behavior is as close + * to the original {@link javax.servlet.ServletOutputStream} as possible. See SEC-2039 * * @author Rob Winch */ @@ -428,123 +479,146 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { this.delegate = delegate; } + @Override public void write(int b) throws IOException { trackContentLength(b); this.delegate.write(b); } + @Override public void flush() throws IOException { doOnResponseCommitted(); - delegate.flush(); + this.delegate.flush(); } + @Override public void close() throws IOException { doOnResponseCommitted(); - delegate.close(); + this.delegate.close(); } + @Override public int hashCode() { - return delegate.hashCode(); + return this.delegate.hashCode(); } + @Override public boolean equals(Object obj) { - return delegate.equals(obj); + return this.delegate.equals(obj); } + @Override public void print(boolean b) throws IOException { trackContentLength(b); - delegate.print(b); + this.delegate.print(b); } + @Override public void print(char c) throws IOException { trackContentLength(c); - delegate.print(c); + this.delegate.print(c); } + @Override public void print(double d) throws IOException { trackContentLength(d); - delegate.print(d); + this.delegate.print(d); } + @Override public void print(float f) throws IOException { trackContentLength(f); - delegate.print(f); + this.delegate.print(f); } + @Override public void print(int i) throws IOException { trackContentLength(i); - delegate.print(i); + this.delegate.print(i); } + @Override public void print(long l) throws IOException { trackContentLength(l); - delegate.print(l); + this.delegate.print(l); } + @Override public void print(String s) throws IOException { trackContentLength(s); - delegate.print(s); + this.delegate.print(s); } + @Override public void println() throws IOException { trackContentLengthLn(); - delegate.println(); + this.delegate.println(); } + @Override public void println(boolean b) throws IOException { trackContentLength(b); trackContentLengthLn(); - delegate.println(b); + this.delegate.println(b); } + @Override public void println(char c) throws IOException { trackContentLength(c); trackContentLengthLn(); - delegate.println(c); + this.delegate.println(c); } + @Override public void println(double d) throws IOException { trackContentLength(d); trackContentLengthLn(); - delegate.println(d); + this.delegate.println(d); } + @Override public void println(float f) throws IOException { trackContentLength(f); trackContentLengthLn(); - delegate.println(f); + this.delegate.println(f); } + @Override public void println(int i) throws IOException { trackContentLength(i); trackContentLengthLn(); - delegate.println(i); + this.delegate.println(i); } + @Override public void println(long l) throws IOException { trackContentLength(l); trackContentLengthLn(); - delegate.println(l); + this.delegate.println(l); } + @Override public void println(String s) throws IOException { trackContentLength(s); trackContentLengthLn(); - delegate.println(s); + this.delegate.println(s); } + @Override public void write(byte[] b) throws IOException { trackContentLength(b); - delegate.write(b); + this.delegate.write(b); } + @Override public void write(byte[] b, int off, int len) throws IOException { checkContentLength(len); - delegate.write(b, off, len); + this.delegate.write(b, off, len); } + @Override public String toString() { - return getClass().getName() + "[delegate=" + delegate.toString() + "]"; + return getClass().getName() + "[delegate=" + this.delegate.toString() + "]"; } } } \ No newline at end of file diff --git a/web/src/test/java/org/springframework/security/web/context/OnCommittedResponseWrapperTests.java b/web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java similarity index 99% rename from web/src/test/java/org/springframework/security/web/context/OnCommittedResponseWrapperTests.java rename to web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java index d47c60e731..130d9cb814 100644 --- a/web/src/test/java/org/springframework/security/web/context/OnCommittedResponseWrapperTests.java +++ b/web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.web.context; +package org.springframework.security.web.util; import java.io.IOException; import java.io.PrintWriter; @@ -25,6 +25,8 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.security.web.util.OnCommittedResponseWrapper; + import javax.servlet.ServletOutputStream; import javax.servlet.http.HttpServletResponse; From 242b831f20c11171975c1e2bdd50c9ae1cdbf445 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 15 Mar 2016 12:30:37 -0500 Subject: [PATCH 2/2] Cache Control only written if not set Previously Spring Security always wrote cache control headers and relied on the application to override the values. This can cause problems with cache control. For example, applications may only set cache control if the header is not already set. Additionally, setting of Cache-Control should disable writing of Pragma. This commit delays writing headers until just before the response is committed and only writes the Cache Control headers if they do not exist. Fixes gh-2953 --- .../web/header/HeaderWriterFilter.java | 58 ++++++++++-- .../writers/CacheControlHeadersWriter.java | 46 +++++++-- .../web/header/HeaderWriterFilterTests.java | 54 +++++++++-- .../CacheControlHeadersWriterTests.java | 94 ++++++++++++++++--- 4 files changed, 217 insertions(+), 35 deletions(-) diff --git a/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java b/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java index e1ca10b473..c78710d416 100644 --- a/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java +++ b/web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java @@ -15,15 +15,17 @@ */ package org.springframework.security.web.header; -import org.springframework.util.Assert; -import org.springframework.web.filter.OncePerRequestFilter; +import java.io.IOException; +import java.util.List; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.util.*; + +import org.springframework.security.web.util.OnCommittedResponseWrapper; +import org.springframework.util.Assert; +import org.springframework.web.filter.OncePerRequestFilter; /** * Filter implementation to add headers to the current request. Can be useful to add @@ -56,12 +58,52 @@ public class HeaderWriterFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) - throws ServletException, IOException { + throws ServletException, IOException { - for (HeaderWriter headerWriter : headerWriters) { - headerWriter.writeHeaders(request, response); + HeaderWriterResponse headerWriterResponse = new HeaderWriterResponse(request, + response, this.headerWriters); + try { + filterChain.doFilter(request, headerWriterResponse); + } + finally { + headerWriterResponse.writeHeaders(); } - filterChain.doFilter(request, response); } + static class HeaderWriterResponse extends OnCommittedResponseWrapper { + private final HttpServletRequest request; + private final List headerWriters; + + HeaderWriterResponse(HttpServletRequest request, HttpServletResponse response, + List headerWriters) { + super(response); + this.request = request; + this.headerWriters = headerWriters; + } + + /* + * (non-Javadoc) + * + * @see org.springframework.security.web.util.OnCommittedResponseWrapper# + * onResponseCommitted() + */ + @Override + protected void onResponseCommitted() { + writeHeaders(); + this.disableOnResponseCommitted(); + } + + protected void writeHeaders() { + if (isDisableOnResponseCommitted()) { + return; + } + for (HeaderWriter headerWriter : this.headerWriters) { + headerWriter.writeHeaders(this.request, getHttpResponse()); + } + } + + private HttpServletResponse getHttpResponse() { + return (HttpServletResponse) getResponse(); + } + } } diff --git a/web/src/main/java/org/springframework/security/web/header/writers/CacheControlHeadersWriter.java b/web/src/main/java/org/springframework/security/web/header/writers/CacheControlHeadersWriter.java index ae6c93443e..d5f115abbb 100644 --- a/web/src/main/java/org/springframework/security/web/header/writers/CacheControlHeadersWriter.java +++ b/web/src/main/java/org/springframework/security/web/header/writers/CacheControlHeadersWriter.java @@ -15,14 +15,20 @@ */ package org.springframework.security.web.header.writers; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.security.web.header.Header; +import org.springframework.security.web.header.HeaderWriter; +import org.springframework.util.ReflectionUtils; /** - * A {@link StaticHeadersWriter} that inserts headers to prevent caching. Specifically it - * adds the following headers: + * Inserts headers to prevent caching if no cache control headers have been specified. + * Specifically it adds the following headers: *
    *
  • Cache-Control: no-cache, no-store, max-age=0, must-revalidate
  • *
  • Pragma: no-cache
  • @@ -32,21 +38,47 @@ import org.springframework.security.web.header.Header; * @author Rob Winch * @since 3.2 */ -public final class CacheControlHeadersWriter extends StaticHeadersWriter { +public final class CacheControlHeadersWriter implements HeaderWriter { + private static final String EXPIRES = "Expires"; + private static final String PRAGMA = "Pragma"; + private static final String CACHE_CONTROL = "Cache-Control"; + + private final Method getHeaderMethod; + + private final HeaderWriter delegate; /** * Creates a new instance */ public CacheControlHeadersWriter() { - super(createHeaders()); + this.delegate = new StaticHeadersWriter(createHeaders()); + this.getHeaderMethod = ReflectionUtils.findMethod(HttpServletResponse.class, + "getHeader", String.class); + } + + @Override + public void writeHeaders(HttpServletRequest request, HttpServletResponse response) { + if (hasHeader(response, CACHE_CONTROL) || hasHeader(response, EXPIRES) + || hasHeader(response, PRAGMA)) { + return; + } + this.delegate.writeHeaders(request, response); + } + + private boolean hasHeader(HttpServletResponse response, String headerName) { + if (this.getHeaderMethod == null) { + return false; + } + return ReflectionUtils.invokeMethod(this.getHeaderMethod, response, + headerName) != null; } private static List
    createHeaders() { List
    headers = new ArrayList
    (2); - headers.add(new Header("Cache-Control", + headers.add(new Header(CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate")); - headers.add(new Header("Pragma", "no-cache")); - headers.add(new Header("Expires", "0")); + headers.add(new Header(PRAGMA, "no-cache")); + headers.add(new Header(EXPIRES, "0")); return headers; } } diff --git a/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java b/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java index a6e9c80a54..2fb465d374 100644 --- a/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java @@ -15,21 +15,32 @@ */ package org.springframework.security.web.header; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.verify; - +import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; + import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.security.web.header.HeaderWriter; -import org.springframework.security.web.header.HeaderWriterFilter; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; /** * Tests for the {@code HeadersFilter} @@ -60,8 +71,8 @@ public class HeaderWriterFilterTests { @Test public void additionalHeadersShouldBeAddedToTheResponse() throws Exception { List headerWriters = new ArrayList(); - headerWriters.add(writer1); - headerWriters.add(writer2); + headerWriters.add(this.writer1); + headerWriters.add(this.writer2); HeaderWriterFilter filter = new HeaderWriterFilter(headerWriters); @@ -71,9 +82,34 @@ public class HeaderWriterFilterTests { filter.doFilter(request, response, filterChain); - verify(writer1).writeHeaders(request, response); - verify(writer2).writeHeaders(request, response); + verify(this.writer1).writeHeaders(request, response); + verify(this.writer2).writeHeaders(request, response); assertThat(filterChain.getRequest()).isEqualTo(request); // verify the filterChain // continued } + + // gh-2953 + @Test + public void headersDelayed() throws Exception { + HeaderWriterFilter filter = new HeaderWriterFilter( + Arrays.asList(this.writer1)); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + filter.doFilter(request, response, new FilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + verifyZeroInteractions(HeaderWriterFilterTests.this.writer1); + + response.flushBuffer(); + + verify(HeaderWriterFilterTests.this.writer1).writeHeaders( + any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + }); + + verifyNoMoreInteractions(this.writer1); + } } diff --git a/web/src/test/java/org/springframework/security/web/header/writers/CacheControlHeadersWriterTests.java b/web/src/test/java/org/springframework/security/web/header/writers/CacheControlHeadersWriterTests.java index 2fa90c8335..ca4f816c33 100644 --- a/web/src/test/java/org/springframework/security/web/header/writers/CacheControlHeadersWriterTests.java +++ b/web/src/test/java/org/springframework/security/web/header/writers/CacheControlHeadersWriterTests.java @@ -15,19 +15,32 @@ */ package org.springframework.security.web.header.writers; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.Arrays; +import javax.servlet.http.HttpServletResponse; + import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest; +import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.when; +import static org.powermock.api.mockito.PowerMockito.spy; /** * @author Rob Winch * */ +@RunWith(PowerMockRunner.class) +@PrepareOnlyThisForTest(ReflectionUtils.class) public class CacheControlHeadersWriterTests { private MockHttpServletRequest request; @@ -38,20 +51,79 @@ public class CacheControlHeadersWriterTests { @Before public void setup() { - request = new MockHttpServletRequest(); - response = new MockHttpServletResponse(); - writer = new CacheControlHeadersWriter(); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + this.writer = new CacheControlHeadersWriter(); } @Test public void writeHeaders() { - writer.writeHeaders(request, response); + this.writer.writeHeaders(this.request, this.response); + + assertThat(this.response.getHeaderNames().size()).isEqualTo(3); + assertThat(this.response.getHeaderValues("Cache-Control")).isEqualTo( + Arrays.asList("no-cache, no-store, max-age=0, must-revalidate")); + assertThat(this.response.getHeaderValues("Pragma")) + .isEqualTo(Arrays.asList("no-cache")); + assertThat(this.response.getHeaderValues("Expires")) + .isEqualTo(Arrays.asList("0")); + } + + @Test + public void writeHeadersServlet25() { + spy(ReflectionUtils.class); + when(ReflectionUtils.findMethod(HttpServletResponse.class, "getHeader", + String.class)).thenReturn(null); + this.response = spy(this.response); + doThrow(NoSuchMethodError.class).when(this.response).getHeader(anyString()); + this.writer = new CacheControlHeadersWriter(); - assertThat(response.getHeaderNames().size()).isEqualTo(3); - assertThat(response.getHeaderValues("Cache-Control")).isEqualTo( + this.writer.writeHeaders(this.request, this.response); + + assertThat(this.response.getHeaderNames().size()).isEqualTo(3); + assertThat(this.response.getHeaderValues("Cache-Control")).isEqualTo( Arrays.asList("no-cache, no-store, max-age=0, must-revalidate")); - assertThat(response.getHeaderValues("Pragma")).isEqualTo( - Arrays.asList("no-cache")); - assertThat(response.getHeaderValues("Expires")).isEqualTo(Arrays.asList("0")); + assertThat(this.response.getHeaderValues("Pragma")) + .isEqualTo(Arrays.asList("no-cache")); + assertThat(this.response.getHeaderValues("Expires")) + .isEqualTo(Arrays.asList("0")); + } + + // gh-2953 + @Test + public void writeHeadersDisabledIfCacheControl() { + this.response.setHeader("Cache-Control", "max-age: 123"); + + this.writer.writeHeaders(this.request, this.response); + + assertThat(this.response.getHeaderNames()).hasSize(1); + assertThat(this.response.getHeaderValues("Cache-Control")) + .containsOnly("max-age: 123"); + assertThat(this.response.getHeaderValue("Pragma")).isNull(); + assertThat(this.response.getHeaderValue("Expires")).isNull(); + } + + @Test + public void writeHeadersDisabledIfPragma() { + this.response.setHeader("Pragma", "mock"); + + this.writer.writeHeaders(this.request, this.response); + + assertThat(this.response.getHeaderNames()).hasSize(1); + assertThat(this.response.getHeaderValues("Pragma")).containsOnly("mock"); + assertThat(this.response.getHeaderValue("Expires")).isNull(); + assertThat(this.response.getHeaderValue("Cache-Control")).isNull(); + } + + @Test + public void writeHeadersDisabledIfExpires() { + this.response.setHeader("Expires", "mock"); + + this.writer.writeHeaders(this.request, this.response); + + assertThat(this.response.getHeaderNames()).hasSize(1); + assertThat(this.response.getHeaderValues("Expires")).containsOnly("mock"); + assertThat(this.response.getHeaderValue("Cache-Control")).isNull(); + assertThat(this.response.getHeaderValue("Pragma")).isNull(); } }