Browse Source

Fix MockHttpServletRequest.setCookies to produce single cookie header

Prior to this commit, MockHttpServletRequest.setCookies() produced one
Cookie header per supplied cookie, resulting in multiple Cookie headers
which violates the specification.

This commit fixes this by ensuring that all cookie name-value pairs are
stored under a single Cookie header, separated by a semicolon.

Closes gh-23074
pull/25598/head
Ilya Lukyanovich 7 years ago committed by Sam Brannen
parent
commit
5990548f6f
  1. 21
      spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java
  2. 6
      spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java
  3. 106
      spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java

21
spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2019 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -41,6 +41,7 @@ import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.TimeZone; import java.util.TimeZone;
import java.util.stream.Collectors;
import javax.servlet.AsyncContext; import javax.servlet.AsyncContext;
import javax.servlet.DispatcherType; import javax.servlet.DispatcherType;
import javax.servlet.RequestDispatcher; import javax.servlet.RequestDispatcher;
@ -58,6 +59,7 @@ import javax.servlet.http.Part;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.LinkedCaseInsensitiveMap;
@ -951,14 +953,20 @@ public class MockHttpServletRequest implements HttpServletRequest {
public void setCookies(@Nullable Cookie... cookies) { public void setCookies(@Nullable Cookie... cookies) {
this.cookies = (ObjectUtils.isEmpty(cookies) ? null : cookies); this.cookies = (ObjectUtils.isEmpty(cookies) ? null : cookies);
this.headers.remove(HttpHeaders.COOKIE); if (this.cookies == null) {
if (this.cookies != null) { removeHeader(HttpHeaders.COOKIE);
Arrays.stream(this.cookies) }
.map(c -> c.getName() + '=' + (c.getValue() == null ? "" : c.getValue())) else {
.forEach(value -> doAddHeaderValue(HttpHeaders.COOKIE, value, false)); doAddHeaderValue(HttpHeaders.COOKIE, encodeCookies(this.cookies), true);
} }
} }
private static String encodeCookies(@NonNull Cookie... cookies) {
return Arrays.stream(cookies)
.map(c -> c.getName() + '=' + (c.getValue() == null ? "" : c.getValue()))
.collect(Collectors.joining("; "));
}
@Override @Override
@Nullable @Nullable
public Cookie[] getCookies() { public Cookie[] getCookies() {
@ -1272,6 +1280,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
* Otherwise it simply returns the current session id. * Otherwise it simply returns the current session id.
* @since 4.0.3 * @since 4.0.3
*/ */
@Override
public String changeSessionId() { public String changeSessionId() {
Assert.isTrue(this.session != null, "The request does not have a session"); Assert.isTrue(this.session != null, "The request does not have a session");
if (this.session instanceof MockHttpSession) { if (this.session instanceof MockHttpSession) {

6
spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2019 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -259,7 +259,9 @@ public class MockHttpServletRequestTests {
assertEquals("bar", cookies[0].getValue()); assertEquals("bar", cookies[0].getValue());
assertEquals("baz", cookies[1].getName()); assertEquals("baz", cookies[1].getName());
assertEquals("qux", cookies[1].getValue()); assertEquals("qux", cookies[1].getValue());
assertEquals(Arrays.asList("foo=bar", "baz=qux"), cookieHeaders);
assertEquals(1, cookieHeaders.size());
assertEquals("foo=bar; baz=qux", cookieHeaders.get(0));
} }
@Test @Test

106
spring-web/src/test/java/org/springframework/mock/web/test/MockHttpServletRequest.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2018 the original author or authors. * Copyright 2002-2019 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -41,6 +41,7 @@ import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.TimeZone; import java.util.TimeZone;
import java.util.stream.Collectors;
import javax.servlet.AsyncContext; import javax.servlet.AsyncContext;
import javax.servlet.DispatcherType; import javax.servlet.DispatcherType;
import javax.servlet.RequestDispatcher; import javax.servlet.RequestDispatcher;
@ -58,6 +59,8 @@ import javax.servlet.http.Part;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
@ -168,10 +171,13 @@ public class MockHttpServletRequest implements HttpServletRequest {
private final Map<String, Object> attributes = new LinkedHashMap<>(); private final Map<String, Object> attributes = new LinkedHashMap<>();
@Nullable
private String characterEncoding; private String characterEncoding;
@Nullable
private byte[] content; private byte[] content;
@Nullable
private String contentType; private String contentType;
private final Map<String, String[]> parameters = new LinkedHashMap<>(16); private final Map<String, String[]> parameters = new LinkedHashMap<>(16);
@ -205,6 +211,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
private boolean asyncSupported = false; private boolean asyncSupported = false;
@Nullable
private MockAsyncContext asyncContext; private MockAsyncContext asyncContext;
private DispatcherType dispatcherType = DispatcherType.REQUEST; private DispatcherType dispatcherType = DispatcherType.REQUEST;
@ -214,32 +221,42 @@ public class MockHttpServletRequest implements HttpServletRequest {
// HttpServletRequest properties // HttpServletRequest properties
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
@Nullable
private String authType; private String authType;
@Nullable
private Cookie[] cookies; private Cookie[] cookies;
private final Map<String, HeaderValueHolder> headers = new LinkedCaseInsensitiveMap<>(); private final Map<String, HeaderValueHolder> headers = new LinkedCaseInsensitiveMap<>();
@Nullable
private String method; private String method;
@Nullable
private String pathInfo; private String pathInfo;
private String contextPath = ""; private String contextPath = "";
@Nullable
private String queryString; private String queryString;
@Nullable
private String remoteUser; private String remoteUser;
private final Set<String> userRoles = new HashSet<>(); private final Set<String> userRoles = new HashSet<>();
@Nullable
private Principal userPrincipal; private Principal userPrincipal;
@Nullable
private String requestedSessionId; private String requestedSessionId;
@Nullable
private String requestURI; private String requestURI;
private String servletPath = ""; private String servletPath = "";
@Nullable
private HttpSession session; private HttpSession session;
private boolean requestedSessionIdValid = true; private boolean requestedSessionIdValid = true;
@ -273,7 +290,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
* @see #setRequestURI * @see #setRequestURI
* @see #MockHttpServletRequest(ServletContext, String, String) * @see #MockHttpServletRequest(ServletContext, String, String)
*/ */
public MockHttpServletRequest(String method, String requestURI) { public MockHttpServletRequest(@Nullable String method, @Nullable String requestURI) {
this(null, method, requestURI); this(null, method, requestURI);
} }
@ -283,7 +300,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
* (may be {@code null} to use a default {@link MockServletContext}) * (may be {@code null} to use a default {@link MockServletContext})
* @see #MockHttpServletRequest(ServletContext, String, String) * @see #MockHttpServletRequest(ServletContext, String, String)
*/ */
public MockHttpServletRequest(ServletContext servletContext) { public MockHttpServletRequest(@Nullable ServletContext servletContext) {
this(servletContext, "", ""); this(servletContext, "", "");
} }
@ -300,7 +317,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
* @see #setPreferredLocales * @see #setPreferredLocales
* @see MockServletContext * @see MockServletContext
*/ */
public MockHttpServletRequest(ServletContext servletContext, String method, String requestURI) { public MockHttpServletRequest(@Nullable ServletContext servletContext, @Nullable String method, @Nullable String requestURI) {
this.servletContext = (servletContext != null ? servletContext : new MockServletContext()); this.servletContext = (servletContext != null ? servletContext : new MockServletContext());
this.method = method; this.method = method;
this.requestURI = requestURI; this.requestURI = requestURI;
@ -369,12 +386,13 @@ public class MockHttpServletRequest implements HttpServletRequest {
} }
@Override @Override
@Nullable
public String getCharacterEncoding() { public String getCharacterEncoding() {
return this.characterEncoding; return this.characterEncoding;
} }
@Override @Override
public void setCharacterEncoding(String characterEncoding) { public void setCharacterEncoding(@Nullable String characterEncoding) {
this.characterEncoding = characterEncoding; this.characterEncoding = characterEncoding;
updateContentTypeHeader(); updateContentTypeHeader();
} }
@ -399,7 +417,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
* @see #getContentAsByteArray() * @see #getContentAsByteArray()
* @see #getContentAsString() * @see #getContentAsString()
*/ */
public void setContent(byte[] content) { public void setContent(@Nullable byte[] content) {
this.content = content; this.content = content;
} }
@ -410,6 +428,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
* @see #setContent(byte[]) * @see #setContent(byte[])
* @see #getContentAsString() * @see #getContentAsString()
*/ */
@Nullable
public byte[] getContentAsByteArray() { public byte[] getContentAsByteArray() {
return this.content; return this.content;
} }
@ -425,6 +444,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
* @see #setCharacterEncoding(String) * @see #setCharacterEncoding(String)
* @see #getContentAsByteArray() * @see #getContentAsByteArray()
*/ */
@Nullable
public String getContentAsString() throws IllegalStateException, UnsupportedEncodingException { public String getContentAsString() throws IllegalStateException, UnsupportedEncodingException {
Assert.state(this.characterEncoding != null, Assert.state(this.characterEncoding != null,
"Cannot get content as a String for a null character encoding. " + "Cannot get content as a String for a null character encoding. " +
@ -446,7 +466,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
return getContentLength(); return getContentLength();
} }
public void setContentType(String contentType) { public void setContentType(@Nullable String contentType) {
this.contentType = contentType; this.contentType = contentType;
if (contentType != null) { if (contentType != null) {
try { try {
@ -467,6 +487,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
} }
@Override @Override
@Nullable
public String getContentType() { public String getContentType() {
return this.contentType; return this.contentType;
} }
@ -507,8 +528,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
*/ */
public void setParameters(Map<String, ?> params) { public void setParameters(Map<String, ?> params) {
Assert.notNull(params, "Parameter map must not be null"); Assert.notNull(params, "Parameter map must not be null");
for (String key : params.keySet()) { params.forEach((key, value) -> {
Object value = params.get(key);
if (value instanceof String) { if (value instanceof String) {
setParameter(key, (String) value); setParameter(key, (String) value);
} }
@ -519,7 +539,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Parameter map value must be single value " + " or array of type [" + String.class.getName() + "]"); "Parameter map value must be single value " + " or array of type [" + String.class.getName() + "]");
} }
} });
} }
/** /**
@ -527,7 +547,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
* <p>If there are already one or more values registered for the given * <p>If there are already one or more values registered for the given
* parameter name, the given value will be added to the end of the list. * parameter name, the given value will be added to the end of the list.
*/ */
public void addParameter(String name, String value) { public void addParameter(String name, @Nullable String value) {
addParameter(name, new String[] {value}); addParameter(name, new String[] {value});
} }
@ -557,8 +577,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
*/ */
public void addParameters(Map<String, ?> params) { public void addParameters(Map<String, ?> params) {
Assert.notNull(params, "Parameter map must not be null"); Assert.notNull(params, "Parameter map must not be null");
for (String key : params.keySet()) { params.forEach((key, value) -> {
Object value = params.get(key);
if (value instanceof String) { if (value instanceof String) {
addParameter(key, (String) value); addParameter(key, (String) value);
} }
@ -569,7 +588,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
throw new IllegalArgumentException("Parameter map value must be single value " + throw new IllegalArgumentException("Parameter map value must be single value " +
" or array of type [" + String.class.getName() + "]"); " or array of type [" + String.class.getName() + "]");
} }
} });
} }
/** /**
@ -588,6 +607,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
} }
@Override @Override
@Nullable
public String getParameter(String name) { public String getParameter(String name) {
Assert.notNull(name, "Parameter name must not be null"); Assert.notNull(name, "Parameter name must not be null");
String[] arr = this.parameters.get(name); String[] arr = this.parameters.get(name);
@ -708,7 +728,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
} }
@Override @Override
public void setAttribute(String name, Object value) { public void setAttribute(String name, @Nullable Object value) {
checkActive(); checkActive();
Assert.notNull(name, "Attribute name must not be null"); Assert.notNull(name, "Attribute name must not be null");
if (value != null) { if (value != null) {
@ -872,7 +892,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
} }
@Override @Override
public AsyncContext startAsync(ServletRequest request, ServletResponse response) { public AsyncContext startAsync(ServletRequest request, @Nullable ServletResponse response) {
Assert.state(this.asyncSupported, "Async not supported"); Assert.state(this.asyncSupported, "Async not supported");
this.asyncStarted = true; this.asyncStarted = true;
this.asyncContext = new MockAsyncContext(request, response); this.asyncContext = new MockAsyncContext(request, response);
@ -897,11 +917,12 @@ public class MockHttpServletRequest implements HttpServletRequest {
return this.asyncSupported; return this.asyncSupported;
} }
public void setAsyncContext(MockAsyncContext asyncContext) { public void setAsyncContext(@Nullable MockAsyncContext asyncContext) {
this.asyncContext = asyncContext; this.asyncContext = asyncContext;
} }
@Override @Override
@Nullable
public AsyncContext getAsyncContext() { public AsyncContext getAsyncContext() {
return this.asyncContext; return this.asyncContext;
} }
@ -920,26 +941,34 @@ public class MockHttpServletRequest implements HttpServletRequest {
// HttpServletRequest interface // HttpServletRequest interface
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
public void setAuthType(String authType) { public void setAuthType(@Nullable String authType) {
this.authType = authType; this.authType = authType;
} }
@Override @Override
@Nullable
public String getAuthType() { public String getAuthType() {
return this.authType; return this.authType;
} }
public void setCookies(Cookie... cookies) { public void setCookies(@Nullable Cookie... cookies) {
this.cookies = (ObjectUtils.isEmpty(cookies) ? null : cookies); this.cookies = (ObjectUtils.isEmpty(cookies) ? null : cookies);
this.headers.remove(HttpHeaders.COOKIE); if (this.cookies == null) {
if (this.cookies != null) { removeHeader(HttpHeaders.COOKIE);
Arrays.stream(this.cookies) }
.map(c -> c.getName() + '=' + (c.getValue() == null ? "" : c.getValue())) else {
.forEach(value -> doAddHeaderValue(HttpHeaders.COOKIE, value, false)); doAddHeaderValue(HttpHeaders.COOKIE, encodeCookies(this.cookies), true);
} }
} }
private static String encodeCookies(@NonNull Cookie... cookies) {
return Arrays.stream(cookies)
.map(c -> c.getName() + '=' + (c.getValue() == null ? "" : c.getValue()))
.collect(Collectors.joining("; "));
}
@Override @Override
@Nullable
public Cookie[] getCookies() { public Cookie[] getCookies() {
return this.cookies; return this.cookies;
} }
@ -983,7 +1012,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
} }
} }
private void doAddHeaderValue(String name, Object value, boolean replace) { private void doAddHeaderValue(String name, @Nullable Object value, boolean replace) {
HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name); HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
Assert.notNull(value, "Header value must not be null"); Assert.notNull(value, "Header value must not be null");
if (header == null || replace) { if (header == null || replace) {
@ -1059,6 +1088,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
} }
@Override @Override
@Nullable
public String getHeader(String name) { public String getHeader(String name) {
HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name); HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name);
return (header != null ? header.getStringValue() : null); return (header != null ? header.getStringValue() : null);
@ -1093,25 +1123,28 @@ public class MockHttpServletRequest implements HttpServletRequest {
} }
} }
public void setMethod(String method) { public void setMethod(@Nullable String method) {
this.method = method; this.method = method;
} }
@Override @Override
@Nullable
public String getMethod() { public String getMethod() {
return this.method; return this.method;
} }
public void setPathInfo(String pathInfo) { public void setPathInfo(@Nullable String pathInfo) {
this.pathInfo = pathInfo; this.pathInfo = pathInfo;
} }
@Override @Override
@Nullable
public String getPathInfo() { public String getPathInfo() {
return this.pathInfo; return this.pathInfo;
} }
@Override @Override
@Nullable
public String getPathTranslated() { public String getPathTranslated() {
return (this.pathInfo != null ? getRealPath(this.pathInfo) : null); return (this.pathInfo != null ? getRealPath(this.pathInfo) : null);
} }
@ -1125,20 +1158,22 @@ public class MockHttpServletRequest implements HttpServletRequest {
return this.contextPath; return this.contextPath;
} }
public void setQueryString(String queryString) { public void setQueryString(@Nullable String queryString) {
this.queryString = queryString; this.queryString = queryString;
} }
@Override @Override
@Nullable
public String getQueryString() { public String getQueryString() {
return this.queryString; return this.queryString;
} }
public void setRemoteUser(String remoteUser) { public void setRemoteUser(@Nullable String remoteUser) {
this.remoteUser = remoteUser; this.remoteUser = remoteUser;
} }
@Override @Override
@Nullable
public String getRemoteUser() { public String getRemoteUser() {
return this.remoteUser; return this.remoteUser;
} }
@ -1153,29 +1188,32 @@ public class MockHttpServletRequest implements HttpServletRequest {
((MockServletContext) this.servletContext).getDeclaredRoles().contains(role))); ((MockServletContext) this.servletContext).getDeclaredRoles().contains(role)));
} }
public void setUserPrincipal(Principal userPrincipal) { public void setUserPrincipal(@Nullable Principal userPrincipal) {
this.userPrincipal = userPrincipal; this.userPrincipal = userPrincipal;
} }
@Override @Override
@Nullable
public Principal getUserPrincipal() { public Principal getUserPrincipal() {
return this.userPrincipal; return this.userPrincipal;
} }
public void setRequestedSessionId(String requestedSessionId) { public void setRequestedSessionId(@Nullable String requestedSessionId) {
this.requestedSessionId = requestedSessionId; this.requestedSessionId = requestedSessionId;
} }
@Override @Override
@Nullable
public String getRequestedSessionId() { public String getRequestedSessionId() {
return this.requestedSessionId; return this.requestedSessionId;
} }
public void setRequestURI(String requestURI) { public void setRequestURI(@Nullable String requestURI) {
this.requestURI = requestURI; this.requestURI = requestURI;
} }
@Override @Override
@Nullable
public String getRequestURI() { public String getRequestURI() {
return this.requestURI; return this.requestURI;
} }
@ -1216,6 +1254,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
} }
@Override @Override
@Nullable
public HttpSession getSession(boolean create) { public HttpSession getSession(boolean create) {
checkActive(); checkActive();
// Reset session if invalidated. // Reset session if invalidated.
@ -1230,6 +1269,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
} }
@Override @Override
@Nullable
public HttpSession getSession() { public HttpSession getSession() {
return getSession(true); return getSession(true);
} }
@ -1240,6 +1280,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
* Otherwise it simply returns the current session id. * Otherwise it simply returns the current session id.
* @since 4.0.3 * @since 4.0.3
*/ */
@Override
public String changeSessionId() { public String changeSessionId() {
Assert.isTrue(this.session != null, "The request does not have a session"); Assert.isTrue(this.session != null, "The request does not have a session");
if (this.session instanceof MockHttpSession) { if (this.session instanceof MockHttpSession) {
@ -1303,6 +1344,7 @@ public class MockHttpServletRequest implements HttpServletRequest {
} }
@Override @Override
@Nullable
public Part getPart(String name) throws IOException, ServletException { public Part getPart(String name) throws IOException, ServletException {
return this.parts.getFirst(name); return this.parts.getFirst(name);
} }

Loading…
Cancel
Save