From 3fc8ec498c7c7ae3074a4980971b5eb31635cc2c Mon Sep 17 00:00:00 2001 From: Juergen Hoeller Date: Wed, 13 Jun 2018 22:03:16 +0200 Subject: [PATCH] MockHttpServletRequest returns a single InputStream or Reader Issue: SPR-16505 Issue: SPR-16499 --- .../mock/web/MockHttpServletRequest.java | 33 ++++++++++++--- .../mock/web/MockHttpServletRequestTests.java | 33 +++++++++++++++ ...jectToStringHttpMessageConverterTests.java | 15 ++++--- .../mock/web/test/MockHttpServletRequest.java | 41 +++++++++++++------ .../web/filter/FormContentFilterTests.java | 29 ++++++------- 5 files changed, 113 insertions(+), 38 deletions(-) diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java index fa0440baa11..9f1e037a83a 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java @@ -178,6 +178,12 @@ public class MockHttpServletRequest implements HttpServletRequest { @Nullable private String contentType; + @Nullable + private ServletInputStream inputStream; + + @Nullable + private BufferedReader reader; + private final Map parameters = new LinkedHashMap<>(16); private String protocol = DEFAULT_PROTOCOL; @@ -492,12 +498,18 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public ServletInputStream getInputStream() { - if (this.content != null) { - return new DelegatingServletInputStream(new ByteArrayInputStream(this.content)); + if (this.inputStream != null) { + return this.inputStream; } - else { - return EMPTY_SERVLET_INPUT_STREAM; + else if (this.reader != null) { + throw new IllegalStateException( + "Cannot call getInputStream() after getReader() has already been called for the current request") ; } + + this.inputStream = (this.content != null ? + new DelegatingServletInputStream(new ByteArrayInputStream(this.content)) : + EMPTY_SERVLET_INPUT_STREAM); + return this.inputStream; } /** @@ -695,16 +707,25 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public BufferedReader getReader() throws UnsupportedEncodingException { + if (this.reader != null) { + return this.reader; + } + else if (this.inputStream != null) { + throw new IllegalStateException( + "Cannot call getReader() after getInputStream() has already been called for the current request") ; + } + if (this.content != null) { InputStream sourceStream = new ByteArrayInputStream(this.content); Reader sourceReader = (this.characterEncoding != null) ? new InputStreamReader(sourceStream, this.characterEncoding) : new InputStreamReader(sourceStream); - return new BufferedReader(sourceReader); + this.reader = new BufferedReader(sourceReader); } else { - return EMPTY_BUFFERED_READER; + this.reader = EMPTY_BUFFERED_READER; } + return this.reader; } public void setRemoteAddr(String remoteAddr) { diff --git a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java index f3cc3a7a55d..ac493936c4f 100644 --- a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java +++ b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java @@ -47,6 +47,7 @@ import static org.junit.Assert.*; * @author Sam Brannen * @author Brian Clozel * @author Jakub Narloch + * @author Av Pinzur */ public class MockHttpServletRequestTests { @@ -112,6 +113,38 @@ public class MockHttpServletRequestTests { assertNull(request.getContentAsByteArray()); } + @Test // SPR-16505 + public void getReaderTwice() throws IOException { + byte[] bytes = "body".getBytes(Charset.defaultCharset()); + request.setContent(bytes); + assertSame(request.getReader(), request.getReader()); + } + + @Test // SPR-16505 + public void getInputStreamTwice() throws IOException { + byte[] bytes = "body".getBytes(Charset.defaultCharset()); + request.setContent(bytes); + assertSame(request.getInputStream(), request.getInputStream()); + } + + @Test // SPR-16499 + public void getReaderAfterGettingInputStream() throws IOException { + exception.expect(IllegalStateException.class); + exception.expectMessage( + "Cannot call getReader() after getInputStream() has already been called for the current request"); + request.getInputStream(); + request.getReader(); + } + + @Test // SPR-16499 + public void getInputStreamAfterGettingReader() throws IOException { + exception.expect(IllegalStateException.class); + exception.expectMessage( + "Cannot call getInputStream() after getReader() has already been called for the current request"); + request.getReader(); + request.getInputStream(); + } + @Test public void setContentType() { String contentType = "test/plain"; diff --git a/spring-web/src/test/java/org/springframework/http/converter/ObjectToStringHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/ObjectToStringHttpMessageConverterTests.java index 8aa6a1c41f9..649e16aeb00 100644 --- a/spring-web/src/test/java/org/springframework/http/converter/ObjectToStringHttpMessageConverterTests.java +++ b/spring-web/src/test/java/org/springframework/http/converter/ObjectToStringHttpMessageConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,7 +52,7 @@ public class ObjectToStringHttpMessageConverterTests { @Before - public void setUp() { + public void setup() { ConversionService conversionService = new DefaultConversionService(); this.converter = new ObjectToStringHttpMessageConverter(conversionService); @@ -60,6 +60,7 @@ public class ObjectToStringHttpMessageConverterTests { this.response = new ServletServerHttpResponse(this.servletResponse); } + @Test public void canRead() { assertFalse(this.converter.canRead(Math.class, null)); @@ -121,20 +122,22 @@ public class ObjectToStringHttpMessageConverterTests { @Test public void read() throws IOException { + Short shortValue = Short.valueOf((short) 781); MockHttpServletRequest request = new MockHttpServletRequest(); request.setContentType(MediaType.TEXT_PLAIN_VALUE); - - Short shortValue = Short.valueOf((short) 781); - request.setContent(shortValue.toString().getBytes( - StringHttpMessageConverter.DEFAULT_CHARSET)); + request.setContent(shortValue.toString().getBytes(StringHttpMessageConverter.DEFAULT_CHARSET)); assertEquals(shortValue, this.converter.read(Short.class, new ServletServerHttpRequest(request))); Float floatValue = Float.valueOf(123); + request = new MockHttpServletRequest(); + request.setContentType(MediaType.TEXT_PLAIN_VALUE); request.setCharacterEncoding("UTF-16"); request.setContent(floatValue.toString().getBytes("UTF-16")); assertEquals(floatValue, this.converter.read(Float.class, new ServletServerHttpRequest(request))); Long longValue = Long.valueOf(55819182821331L); + request = new MockHttpServletRequest(); + request.setContentType(MediaType.TEXT_PLAIN_VALUE); request.setCharacterEncoding("UTF-8"); request.setContent(longValue.toString().getBytes("UTF-8")); assertEquals(longValue, this.converter.read(Long.class, new ServletServerHttpRequest(request))); diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java index 5ca558586a1..357239d1ed5 100644 --- a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java @@ -174,6 +174,10 @@ public class MockHttpServletRequest implements HttpServletRequest { private String contentType; + private ServletInputStream inputStream; + + private BufferedReader reader; + private final Map parameters = new LinkedHashMap<>(16); private String protocol = DEFAULT_PROTOCOL; @@ -473,12 +477,18 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public ServletInputStream getInputStream() { - if (this.content != null) { - return new DelegatingServletInputStream(new ByteArrayInputStream(this.content)); + if (this.inputStream != null) { + return this.inputStream; } - else { - return EMPTY_SERVLET_INPUT_STREAM; + else if (this.reader != null) { + throw new IllegalStateException( + "Cannot call getInputStream() after getReader() has already been called for the current request") ; } + + this.inputStream = (this.content != null ? + new DelegatingServletInputStream(new ByteArrayInputStream(this.content)) : + EMPTY_SERVLET_INPUT_STREAM); + return this.inputStream; } /** @@ -507,8 +517,7 @@ public class MockHttpServletRequest implements HttpServletRequest { */ public void setParameters(Map params) { Assert.notNull(params, "Parameter map must not be null"); - for (String key : params.keySet()) { - Object value = params.get(key); + params.forEach((key, value) -> { if (value instanceof String) { setParameter(key, (String) value); } @@ -519,7 +528,7 @@ public class MockHttpServletRequest implements HttpServletRequest { throw new IllegalArgumentException( "Parameter map value must be single value " + " or array of type [" + String.class.getName() + "]"); } - } + }); } /** @@ -557,8 +566,7 @@ public class MockHttpServletRequest implements HttpServletRequest { */ public void addParameters(Map params) { Assert.notNull(params, "Parameter map must not be null"); - for (String key : params.keySet()) { - Object value = params.get(key); + params.forEach((key, value) -> { if (value instanceof String) { addParameter(key, (String) value); } @@ -569,7 +577,7 @@ public class MockHttpServletRequest implements HttpServletRequest { throw new IllegalArgumentException("Parameter map value must be single value " + " or array of type [" + String.class.getName() + "]"); } - } + }); } /** @@ -677,16 +685,25 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public BufferedReader getReader() throws UnsupportedEncodingException { + if (this.reader != null) { + return this.reader; + } + else if (this.inputStream != null) { + throw new IllegalStateException( + "Cannot call getReader() after getInputStream() has already been called for the current request") ; + } + if (this.content != null) { InputStream sourceStream = new ByteArrayInputStream(this.content); Reader sourceReader = (this.characterEncoding != null) ? new InputStreamReader(sourceStream, this.characterEncoding) : new InputStreamReader(sourceStream); - return new BufferedReader(sourceReader); + this.reader = new BufferedReader(sourceReader); } else { - return EMPTY_BUFFERED_READER; + this.reader = EMPTY_BUFFERED_READER; } + return this.reader; } public void setRemoteAddr(String remoteAddr) { diff --git a/spring-web/src/test/java/org/springframework/web/filter/FormContentFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/FormContentFilterTests.java index c15cd11f883..bcc9f00ff2a 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/FormContentFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/FormContentFilterTests.java @@ -50,7 +50,6 @@ public class FormContentFilterTests { @Before public void setup() { this.request = new MockHttpServletRequest("PUT", "/"); - this.request.addHeader("Content-Type", "application/x-www-form-urlencoded; charset=ISO-8859-1"); this.request.setContentType("application/x-www-form-urlencoded; charset=ISO-8859-1"); this.response = new MockHttpServletResponse(); this.filterChain = new MockFilterChain(); @@ -59,29 +58,31 @@ public class FormContentFilterTests { @Test public void wrapPutPatchAndDeleteOnly() throws Exception { - this.request.setContent("foo=bar".getBytes("ISO-8859-1")); for (HttpMethod method : HttpMethod.values()) { - this.request.setMethod(method.name()); + MockHttpServletRequest request = new MockHttpServletRequest(method.name(), "/"); + request.setContent("foo=bar".getBytes("ISO-8859-1")); + request.setContentType("application/x-www-form-urlencoded; charset=ISO-8859-1"); this.filterChain = new MockFilterChain(); - this.filter.doFilter(this.request, this.response, this.filterChain); + this.filter.doFilter(request, this.response, this.filterChain); if (method == HttpMethod.PUT || method == HttpMethod.PATCH || method == HttpMethod.DELETE) { - assertNotSame(this.request, this.filterChain.getRequest()); + assertNotSame(request, this.filterChain.getRequest()); } else { - assertSame(this.request, this.filterChain.getRequest()); + assertSame(request, this.filterChain.getRequest()); } } } @Test public void wrapFormEncodedOnly() throws Exception { - this.request.setContent("".getBytes("ISO-8859-1")); String[] contentTypes = new String[] {"text/plain", "multipart/form-data"}; for (String contentType : contentTypes) { - this.request.setContentType(contentType); + MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/"); + request.setContent("".getBytes("ISO-8859-1")); + request.setContentType(contentType); this.filterChain = new MockFilterChain(); - this.filter.doFilter(this.request, this.response, this.filterChain); - assertSame(this.request, this.filterChain.getRequest()); + this.filter.doFilter(request, this.response, this.filterChain); + assertSame(request, this.filterChain.getRequest()); } } @@ -146,7 +147,7 @@ public class FormContentFilterTests { String[] values = this.filterChain.getRequest().getParameterValues("name"); assertNotSame("Request not wrapped", this.request, filterChain.getRequest()); - assertArrayEquals(new String[]{"value1", "value2", "value3", "value4"}, values); + assertArrayEquals(new String[] {"value1", "value2", "value3", "value4"}, values); } @Test @@ -160,7 +161,7 @@ public class FormContentFilterTests { String[] values = this.filterChain.getRequest().getParameterValues("name"); assertNotSame("Request not wrapped", this.request, this.filterChain.getRequest()); - assertArrayEquals(new String[]{"value1", "value2"}, values); + assertArrayEquals(new String[] {"value1", "value2"}, values); } @Test @@ -173,7 +174,7 @@ public class FormContentFilterTests { String[] values = this.filterChain.getRequest().getParameterValues("anotherName"); assertNotSame("Request not wrapped", this.request, this.filterChain.getRequest()); - assertArrayEquals(new String[]{"anotherValue"}, values); + assertArrayEquals(new String[] {"anotherValue"}, values); } @Test @@ -211,7 +212,7 @@ public class FormContentFilterTests { this.request.addParameter("hiddenField", "testHidden"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertArrayEquals(new String[]{"testHidden"}, + assertArrayEquals(new String[] {"testHidden"}, this.filterChain.getRequest().getParameterValues("hiddenField")); }