diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java index 8e1384c09e3..57e4a120e0f 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,7 +36,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; - import javax.servlet.AsyncContext; import javax.servlet.DispatcherType; import javax.servlet.RequestDispatcher; @@ -58,7 +57,7 @@ import org.springframework.util.StringUtils; /** * Mock implementation of the {@link javax.servlet.http.HttpServletRequest} interface. * - *

As of Spring 4.0, this set of mocks is designed on a Servlet 3.0 baseline. + *

As of Spring Framework 4.0, this set of mocks is designed on a Servlet 3.0 baseline. * * @author Juergen Hoeller * @author Rod Johnson @@ -70,10 +69,14 @@ import org.springframework.util.StringUtils; */ public class MockHttpServletRequest implements HttpServletRequest { + private static final String HTTP = "http"; + + private static final String HTTPS = "https"; + /** * The default protocol: 'http'. */ - public static final String DEFAULT_PROTOCOL = "http"; + public static final String DEFAULT_PROTOCOL = HTTP; /** * The default server address: '127.0.0.1'. @@ -330,9 +333,10 @@ public class MockHttpServletRequest implements HttpServletRequest { } private void updateContentTypeHeader() { - if (this.contentType != null) { + if (StringUtils.hasLength(this.contentType)) { StringBuilder sb = new StringBuilder(this.contentType); - if (!this.contentType.toLowerCase().contains(CHARSET_PREFIX) && this.characterEncoding != null) { + if (!this.contentType.toLowerCase().contains(CHARSET_PREFIX) && + StringUtils.hasLength(this.characterEncoding)) { sb.append(";").append(CHARSET_PREFIX).append(this.characterEncoding); } doAddHeaderValue(CONTENT_TYPE_HEADER, sb.toString(), true); @@ -357,8 +361,7 @@ public class MockHttpServletRequest implements HttpServletRequest { if (contentType != null) { int charsetIndex = contentType.toLowerCase().indexOf(CHARSET_PREFIX); if (charsetIndex != -1) { - String encoding = contentType.substring(charsetIndex + CHARSET_PREFIX.length()); - this.characterEncoding = encoding; + this.characterEncoding = contentType.substring(charsetIndex + CHARSET_PREFIX.length()); } updateContentTypeHeader(); } @@ -955,8 +958,8 @@ public class MockHttpServletRequest implements HttpServletRequest { public StringBuffer getRequestURL() { StringBuffer url = new StringBuffer(this.scheme).append("://").append(this.serverName); - if (this.serverPort > 0 - && (("http".equalsIgnoreCase(scheme) && this.serverPort != 80) || ("https".equalsIgnoreCase(scheme) && this.serverPort != 443))) { + if (this.serverPort > 0 && ((HTTP.equalsIgnoreCase(this.scheme) && this.serverPort != 80) || + (HTTPS.equalsIgnoreCase(this.scheme) && this.serverPort != 443))) { url.append(':').append(this.serverPort); } diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java index b9725ea5693..489dfa294f5 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java @@ -30,7 +30,6 @@ import java.nio.charset.Charset; import java.security.Principal; import java.util.Arrays; import java.util.Enumeration; -import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -40,6 +39,8 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.util.Assert; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.StringUtils; /** * {@link ServerHttpRequest} implementation that is based on a {@link HttpServletRequest}. @@ -111,21 +112,30 @@ public class ServletServerHttpRequest implements ServerHttpRequest { } } // HttpServletRequest exposes some headers as properties: we should include those if not already present - if (this.headers.getContentType() == null && this.servletRequest.getContentType() != null) { - MediaType contentType = MediaType.parseMediaType(this.servletRequest.getContentType()); - this.headers.setContentType(contentType); + MediaType contentType = this.headers.getContentType(); + if (contentType == null) { + String requestContentType = this.servletRequest.getContentType(); + if (StringUtils.hasLength(requestContentType)) { + contentType = MediaType.parseMediaType(requestContentType); + this.headers.setContentType(contentType); + } } - if (this.headers.getContentType() != null && this.headers.getContentType().getCharSet() == null && - this.servletRequest.getCharacterEncoding() != null) { - MediaType oldContentType = this.headers.getContentType(); - Charset charSet = Charset.forName(this.servletRequest.getCharacterEncoding()); - Map params = new HashMap(oldContentType.getParameters()); - params.put("charset", charSet.toString()); - MediaType newContentType = new MediaType(oldContentType.getType(), oldContentType.getSubtype(), params); - this.headers.setContentType(newContentType); + if (contentType != null && contentType.getCharSet() == null) { + String requestEncoding = this.servletRequest.getCharacterEncoding(); + if (StringUtils.hasLength(requestEncoding)) { + Charset charSet = Charset.forName(requestEncoding); + Map params = new LinkedCaseInsensitiveMap(); + params.putAll(contentType.getParameters()); + params.put("charset", charSet.toString()); + MediaType newContentType = new MediaType(contentType.getType(), contentType.getSubtype(), params); + this.headers.setContentType(newContentType); + } } - if (this.headers.getContentLength() == -1 && this.servletRequest.getContentLength() != -1) { - this.headers.setContentLength(this.servletRequest.getContentLength()); + if (this.headers.getContentLength() == -1) { + int requestContentLength = this.servletRequest.getContentLength(); + if (requestContentLength != -1) { + this.headers.setContentLength(requestContentLength); + } } } return this.headers; diff --git a/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpRequestTests.java b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpRequestTests.java index 32be2421463..490b0369b03 100644 --- a/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpRequestTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import java.util.List; import org.junit.Before; import org.junit.Test; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -39,12 +40,14 @@ public class ServletServerHttpRequestTests { private MockHttpServletRequest mockRequest; + @Before public void create() throws Exception { mockRequest = new MockHttpServletRequest(); request = new ServletServerHttpRequest(mockRequest); } + @Test public void getMethod() throws Exception { mockRequest.setMethod("POST"); @@ -65,8 +68,8 @@ public class ServletServerHttpRequestTests { public void getHeaders() throws Exception { String headerName = "MyHeader"; String headerValue1 = "value1"; - mockRequest.addHeader(headerName, headerValue1); String headerValue2 = "value2"; + mockRequest.addHeader(headerName, headerValue1); mockRequest.addHeader(headerName, headerValue2); mockRequest.setContentType("text/plain"); mockRequest.setCharacterEncoding("UTF-8"); @@ -82,6 +85,26 @@ public class ServletServerHttpRequestTests { headers.getContentType()); } + @Test + public void getHeadersWithEmptyContentTypeAndEncoding() throws Exception { + String headerName = "MyHeader"; + String headerValue1 = "value1"; + String headerValue2 = "value2"; + mockRequest.addHeader(headerName, headerValue1); + mockRequest.addHeader(headerName, headerValue2); + mockRequest.setContentType(""); + mockRequest.setCharacterEncoding(""); + + HttpHeaders headers = request.getHeaders(); + assertNotNull("No HttpHeaders returned", headers); + assertTrue("Invalid headers returned", headers.containsKey(headerName)); + List headerValues = headers.get(headerName); + assertEquals("Invalid header values returned", 2, headerValues.size()); + assertTrue("Invalid header values returned", headerValues.contains(headerValue1)); + assertTrue("Invalid header values returned", headerValues.contains(headerValue2)); + assertNull(headers.getContentType()); + } + @Test public void getBody() throws Exception { byte[] content = "Hello World".getBytes("UTF-8"); @@ -105,4 +128,4 @@ public class ServletServerHttpRequestTests { assertArrayEquals("Invalid content returned", content, result); } -} \ No newline at end of file +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java index 997af6e09b8..3115b006774 100644 --- a/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java @@ -36,7 +36,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; - import javax.servlet.AsyncContext; import javax.servlet.DispatcherType; import javax.servlet.RequestDispatcher; @@ -333,9 +332,10 @@ public class MockHttpServletRequest implements HttpServletRequest { } private void updateContentTypeHeader() { - if (this.contentType != null) { + if (StringUtils.hasLength(this.contentType)) { StringBuilder sb = new StringBuilder(this.contentType); - if (!this.contentType.toLowerCase().contains(CHARSET_PREFIX) && this.characterEncoding != null) { + if (!this.contentType.toLowerCase().contains(CHARSET_PREFIX) && + StringUtils.hasLength(this.characterEncoding)) { sb.append(";").append(CHARSET_PREFIX).append(this.characterEncoding); } doAddHeaderValue(CONTENT_TYPE_HEADER, sb.toString(), true); @@ -360,8 +360,7 @@ public class MockHttpServletRequest implements HttpServletRequest { if (contentType != null) { int charsetIndex = contentType.toLowerCase().indexOf(CHARSET_PREFIX); if (charsetIndex != -1) { - String encoding = contentType.substring(charsetIndex + CHARSET_PREFIX.length()); - this.characterEncoding = encoding; + this.characterEncoding = contentType.substring(charsetIndex + CHARSET_PREFIX.length()); } updateContentTypeHeader(); }