diff --git a/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java b/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java index 133128e1548..57f33a48d91 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -159,7 +159,7 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter } /** - * Sets the maximum length of the payload body to be included in the log message. + * Set the maximum length of the payload body to be included in the log message. * Default is 50 characters. * @since 3.0 */ @@ -233,7 +233,7 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter HttpServletRequest requestToUse = request; if (isIncludePayload() && isFirstRequest && !(request instanceof ContentCachingRequestWrapper)) { - requestToUse = new ContentCachingRequestWrapper(request); + requestToUse = new ContentCachingRequestWrapper(request, getMaxPayloadLength()); } boolean shouldLog = shouldLog(requestToUse); diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingRequestWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingRequestWrapper.java index 9b9a97cffed..259e057d81a 100644 --- a/spring-web/src/main/java/org/springframework/web/util/ContentCachingRequestWrapper.java +++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingRequestWrapper.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,6 +51,8 @@ public class ContentCachingRequestWrapper extends HttpServletRequestWrapper { private final ByteArrayOutputStream cachedContent; + private final Integer contentCacheLimit; + private ServletInputStream inputStream; private BufferedReader reader; @@ -64,6 +66,20 @@ public class ContentCachingRequestWrapper extends HttpServletRequestWrapper { super(request); int contentLength = request.getContentLength(); this.cachedContent = new ByteArrayOutputStream(contentLength >= 0 ? contentLength : 1024); + this.contentCacheLimit = null; + } + + /** + * Create a new ContentCachingRequestWrapper for the given servlet request. + * @param request the original servlet request + * @param contentCacheLimit the maximum number of bytes to cache per request + * @since 4.3.6 + * @see #handleContentOverflow(int) + */ + public ContentCachingRequestWrapper(HttpServletRequest request, int contentCacheLimit) { + super(request); + this.cachedContent = new ByteArrayOutputStream(contentCacheLimit); + this.contentCacheLimit = contentCacheLimit; } @@ -160,16 +176,33 @@ public class ContentCachingRequestWrapper extends HttpServletRequestWrapper { /** * Return the cached request content as a byte array. + *

The returned array will never be larger than the content cache limit. + * @see #ContentCachingRequestWrapper(HttpServletRequest, int) */ public byte[] getContentAsByteArray() { return this.cachedContent.toByteArray(); } + /** + * Template method for handling a content overflow: specifically, a request + * body being read that exceeds the specified content cache limit. + *

The default implementation is empty. Subclasses may override this to + * throw a payload-too-large exception or the like. + * @param contentCacheLimit the maximum number of bytes to cache per request + * which has just been exceeded + * @since 4.3.6 + * @see #ContentCachingRequestWrapper(HttpServletRequest, int) + */ + protected void handleContentOverflow(int contentCacheLimit) { + } + private class ContentCachingInputStream extends ServletInputStream { private final ServletInputStream is; + private boolean overflow = false; + public ContentCachingInputStream(ServletInputStream is) { this.is = is; } @@ -177,8 +210,14 @@ public class ContentCachingRequestWrapper extends HttpServletRequestWrapper { @Override public int read() throws IOException { int ch = this.is.read(); - if (ch != -1) { - cachedContent.write(ch); + if (ch != -1 && !this.overflow) { + if (contentCacheLimit != null && cachedContent.size() == contentCacheLimit) { + this.overflow = true; + handleContentOverflow(contentCacheLimit); + } + else { + cachedContent.write(ch); + } } return ch; } diff --git a/spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java index 5a179e726f6..57777f27796 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,8 @@ import org.junit.Test; import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.util.FileCopyUtils; +import org.springframework.web.util.ContentCachingRequestWrapper; +import org.springframework.web.util.WebUtils; import static org.junit.Assert.*; @@ -51,7 +53,6 @@ public class RequestLoggingFilterTests { request.setQueryString("booking=42"); FilterChain filterChain = new NoOpFilterChain(); - filter.doFilter(request, response, filterChain); assertNotNull(filter.beforeRequestMessage); @@ -169,6 +170,9 @@ public class RequestLoggingFilterTests { ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); byte[] buf = FileCopyUtils.copyToByteArray(filterRequest.getInputStream()); assertArrayEquals(requestBody, buf); + ContentCachingRequestWrapper wrapper = + WebUtils.getNativeRequest(filterRequest, ContentCachingRequestWrapper.class); + assertArrayEquals("Hel".getBytes("UTF-8"), wrapper.getContentAsByteArray()); } }; diff --git a/spring-web/src/test/java/org/springframework/web/util/ContentCachingRequestWrapperTests.java b/spring-web/src/test/java/org/springframework/web/util/ContentCachingRequestWrapperTests.java index 69661536f95..83abf07ad28 100644 --- a/spring-web/src/test/java/org/springframework/web/util/ContentCachingRequestWrapperTests.java +++ b/spring-web/src/test/java/org/springframework/web/util/ContentCachingRequestWrapperTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,15 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.web.util; -import org.junit.Assert; -import org.junit.Before; import org.junit.Test; import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.util.FileCopyUtils; +import static org.junit.Assert.*; + /** * @author Brian Clozel */ @@ -31,12 +32,8 @@ public class ContentCachingRequestWrapperTests { protected static final String CHARSET = "UTF-8"; - private MockHttpServletRequest request; + private final MockHttpServletRequest request = new MockHttpServletRequest(); - @Before - public void setup() throws Exception { - this.request = new MockHttpServletRequest(); - } @Test public void cachedContent() throws Exception { @@ -46,7 +43,41 @@ public class ContentCachingRequestWrapperTests { ContentCachingRequestWrapper wrapper = new ContentCachingRequestWrapper(this.request); byte[] response = FileCopyUtils.copyToByteArray(wrapper.getInputStream()); - Assert.assertArrayEquals(response, wrapper.getContentAsByteArray()); + assertArrayEquals(response, wrapper.getContentAsByteArray()); + } + + @Test + public void cachedContentWithLimit() throws Exception { + this.request.setMethod("GET"); + this.request.setCharacterEncoding(CHARSET); + this.request.setContent("Hello World".getBytes(CHARSET)); + + ContentCachingRequestWrapper wrapper = new ContentCachingRequestWrapper(this.request, 3); + byte[] response = FileCopyUtils.copyToByteArray(wrapper.getInputStream()); + assertArrayEquals("Hello World".getBytes(CHARSET), response); + assertArrayEquals("Hel".getBytes(CHARSET), wrapper.getContentAsByteArray()); + } + + @Test + public void cachedContentWithOverflow() throws Exception { + this.request.setMethod("GET"); + this.request.setCharacterEncoding(CHARSET); + this.request.setContent("Hello World".getBytes(CHARSET)); + + ContentCachingRequestWrapper wrapper = new ContentCachingRequestWrapper(this.request, 3) { + @Override + protected void handleContentOverflow(int contentCacheLimit) { + throw new IllegalStateException(String.valueOf(contentCacheLimit)); + } + }; + + try { + FileCopyUtils.copyToByteArray(wrapper.getInputStream()); + fail("Should have thrown IllegalStateException"); + } + catch (IllegalStateException ex) { + assertEquals("3", ex.getMessage()); + } } @Test @@ -55,29 +86,28 @@ public class ContentCachingRequestWrapperTests { this.request.setContentType(FORM_CONTENT_TYPE); this.request.setCharacterEncoding(CHARSET); this.request.setParameter("first", "value"); - this.request.setParameter("second", new String[] {"foo", "bar"}); + this.request.setParameter("second", "foo", "bar"); ContentCachingRequestWrapper wrapper = new ContentCachingRequestWrapper(this.request); // getting request parameters will consume the request body - Assert.assertFalse(wrapper.getParameterMap().isEmpty()); - Assert.assertEquals("first=value&second=foo&second=bar", new String(wrapper.getContentAsByteArray())); + assertFalse(wrapper.getParameterMap().isEmpty()); + assertEquals("first=value&second=foo&second=bar", new String(wrapper.getContentAsByteArray())); // SPR-12810 : inputstream body should be consumed - Assert.assertEquals("", new String(FileCopyUtils.copyToByteArray(wrapper.getInputStream()))); + assertEquals("", new String(FileCopyUtils.copyToByteArray(wrapper.getInputStream()))); } - // SPR-12810 - @Test + @Test // SPR-12810 public void inputStreamFormPostRequest() throws Exception { this.request.setMethod("POST"); this.request.setContentType(FORM_CONTENT_TYPE); this.request.setCharacterEncoding(CHARSET); this.request.setParameter("first", "value"); - this.request.setParameter("second", new String[] {"foo", "bar"}); + this.request.setParameter("second", "foo", "bar"); ContentCachingRequestWrapper wrapper = new ContentCachingRequestWrapper(this.request); byte[] response = FileCopyUtils.copyToByteArray(wrapper.getInputStream()); - Assert.assertArrayEquals(response, wrapper.getContentAsByteArray()); + assertArrayEquals(response, wrapper.getContentAsByteArray()); } }