From a3e7fd47c6d2b00dda030f42410144bd8939cf5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Tue, 28 May 2024 10:54:30 +0200 Subject: [PATCH 1/2] Allow MockHttpServletRequestBuilder to support AssertJ This commit moves the features of MockHttpServletRequestBuilder in an abstract class so that another class can offer the same feature whilst providing AssertJ support. This wasn't possible previously as the builder's return type would lose the concrete builder types. This change benefits MockMultipartHttpServletRequestBuilder that can use the same mechanism to offer additional settings. This change also makes it so that a builder instance can be created using only the HttpMethod. Previously, the URI had to be provided as well and that makes it impossible to specify it using the builder. See gh-32913 --- .../servlet/client/MockMvcHttpConnector.java | 5 +- ...AbstractMockHttpServletRequestBuilder.java | 948 ++++++++++++++++++ .../MockHttpServletRequestBuilder.java | 887 +--------------- ...ockMultipartHttpServletRequestBuilder.java | 10 +- .../web/servlet/MockHttpServletRequestDsl.kt | 3 +- 5 files changed, 964 insertions(+), 889 deletions(-) create mode 100644 spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcHttpConnector.java b/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcHttpConnector.java index 36a49b58c2b..d54278330a8 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcHttpConnector.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcHttpConnector.java @@ -54,6 +54,7 @@ import org.springframework.test.web.reactive.server.MockServerClientHttpResponse import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.RequestBuilder; +import org.springframework.test.web.servlet.request.AbstractMockHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.MockMultipartHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; @@ -134,7 +135,7 @@ public class MockMvcHttpConnector implements ClientHttpConnector { // Initialize the client request requestCallback.apply(httpRequest).block(TIMEOUT); - MockHttpServletRequestBuilder requestBuilder = + AbstractMockHttpServletRequestBuilder requestBuilder = initRequestBuilder(httpMethod, uri, httpRequest, contentRef.get()); requestBuilder.headers(httpRequest.getHeaders()); @@ -149,7 +150,7 @@ public class MockMvcHttpConnector implements ClientHttpConnector { return requestBuilder; } - private MockHttpServletRequestBuilder initRequestBuilder( + private AbstractMockHttpServletRequestBuilder initRequestBuilder( HttpMethod httpMethod, URI uri, MockClientHttpRequest httpRequest, @Nullable byte[] bytes) { String contentType = httpRequest.getHeaders().getFirst(HttpHeaders.CONTENT_TYPE); diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java new file mode 100644 index 00000000000..f8b0a54f3dc --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java @@ -0,0 +1,948 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.request; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.security.Principal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import jakarta.servlet.ServletContext; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpSession; + +import org.springframework.beans.Mergeable; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.lang.Nullable; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.WebApplicationContextUtils; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.FlashMap; +import org.springframework.web.servlet.FlashMapManager; +import org.springframework.web.servlet.support.SessionFlashMapManager; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; +import org.springframework.web.util.UrlPathHelper; + +/** + * Base builder for {@link MockHttpServletRequest} required as input to + * perform requests in {@link MockMvc}. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Arjen Poutsma + * @author Sam Brannen + * @author Kamill Sokol + * @since 6.2 + * @param a self reference to the builder type + */ +public abstract class AbstractMockHttpServletRequestBuilder> + implements ConfigurableSmartRequestBuilder, Mergeable { + + private final HttpMethod method; + + @Nullable + private URI uri; + + private String contextPath = ""; + + private String servletPath = ""; + + @Nullable + private String pathInfo = ""; + + @Nullable + private Boolean secure; + + @Nullable + private Principal principal; + + @Nullable + private MockHttpSession session; + + @Nullable + private String remoteAddress; + + @Nullable + private String characterEncoding; + + @Nullable + private byte[] content; + + @Nullable + private String contentType; + + private final MultiValueMap headers = new LinkedMultiValueMap<>(); + + private final MultiValueMap parameters = new LinkedMultiValueMap<>(); + + private final MultiValueMap queryParams = new LinkedMultiValueMap<>(); + + private final MultiValueMap formFields = new LinkedMultiValueMap<>(); + + private final List cookies = new ArrayList<>(); + + private final List locales = new ArrayList<>(); + + private final Map requestAttributes = new LinkedHashMap<>(); + + private final Map sessionAttributes = new LinkedHashMap<>(); + + private final Map flashAttributes = new LinkedHashMap<>(); + + private final List postProcessors = new ArrayList<>(); + + + /** + * Create a new instance using the specified {@link HttpMethod}. + * @param httpMethod the HTTP method (GET, POST, etc.) + */ + protected AbstractMockHttpServletRequestBuilder(HttpMethod httpMethod) { + Assert.notNull(httpMethod, "'httpMethod' is required"); + this.method = httpMethod; + } + + @SuppressWarnings("unchecked") + protected B self() { + return (B) this; + } + + /** + * Specify the URI using an absolute, fully constructed {@link java.net.URI}. + */ + public B uri(URI uri) { + this.uri = uri; + return self(); + } + + /** + * Specify the URI for the request using a URI template and URI variables. + */ + public B uri(String uriTemplate, Object... uriVariables) { + return uri(initUri(uriTemplate, uriVariables)); + } + + private static URI initUri(String uri, Object[] vars) { + Assert.notNull(uri, "'uri' must not be null"); + Assert.isTrue(uri.isEmpty() || uri.startsWith("/") || uri.startsWith("http://") || uri.startsWith("https://"), + () -> "'uri' should start with a path or be a complete HTTP URI: " + uri); + String uriString = (uri.isEmpty() ? "/" : uri); + return UriComponentsBuilder.fromUriString(uriString).buildAndExpand(vars).encode().toUri(); + } + + /** + * Specify the portion of the requestURI that represents the context path. + * The context path, if specified, must match to the start of the request URI. + *

In most cases, tests can be written by omitting the context path from + * the requestURI. This is because most applications don't actually depend + * on the name under which they're deployed. If specified here, the context + * path must start with a "/" and must not end with a "/". + * @see jakarta.servlet.http.HttpServletRequest#getContextPath() + */ + public B contextPath(String contextPath) { + if (StringUtils.hasText(contextPath)) { + Assert.isTrue(contextPath.startsWith("/"), "Context path must start with a '/'"); + Assert.isTrue(!contextPath.endsWith("/"), "Context path must not end with a '/'"); + } + this.contextPath = contextPath; + return self(); + } + + /** + * Specify the portion of the requestURI that represents the path to which + * the Servlet is mapped. This is typically a portion of the requestURI + * after the context path. + *

In most cases, tests can be written by omitting the servlet path from + * the requestURI. This is because most applications don't actually depend + * on the prefix to which a servlet is mapped. For example if a Servlet is + * mapped to {@code "/main/*"}, tests can be written with the requestURI + * {@code "/accounts/1"} as opposed to {@code "/main/accounts/1"}. + * If specified here, the servletPath must start with a "/" and must not + * end with a "/". + * @see jakarta.servlet.http.HttpServletRequest#getServletPath() + */ + public B servletPath(String servletPath) { + if (StringUtils.hasText(servletPath)) { + Assert.isTrue(servletPath.startsWith("/"), "Servlet path must start with a '/'"); + Assert.isTrue(!servletPath.endsWith("/"), "Servlet path must not end with a '/'"); + } + this.servletPath = servletPath; + return self(); + } + + /** + * Specify the portion of the requestURI that represents the pathInfo. + *

If left unspecified (recommended), the pathInfo will be automatically derived + * by removing the contextPath and the servletPath from the requestURI and using any + * remaining part. If specified here, the pathInfo must start with a "/". + *

If specified, the pathInfo will be used as-is. + * @see jakarta.servlet.http.HttpServletRequest#getPathInfo() + */ + public B pathInfo(@Nullable String pathInfo) { + if (StringUtils.hasText(pathInfo)) { + Assert.isTrue(pathInfo.startsWith("/"), "Path info must start with a '/'"); + } + this.pathInfo = pathInfo; + return self(); + } + + /** + * Set the secure property of the {@link ServletRequest} indicating use of a + * secure channel, such as HTTPS. + * @param secure whether the request is using a secure channel + */ + public B secure(boolean secure){ + this.secure = secure; + return self(); + } + + /** + * Set the character encoding of the request. + * @param encoding the character encoding + * @since 5.3.10 + * @see StandardCharsets + * @see #characterEncoding(String) + */ + public B characterEncoding(Charset encoding) { + return characterEncoding(encoding.name()); + } + + /** + * Set the character encoding of the request. + * @param encoding the character encoding + */ + public B characterEncoding(String encoding) { + this.characterEncoding = encoding; + return self(); + } + + /** + * Set the request body. + *

If content is provided and {@link #contentType(MediaType)} is set to + * {@code application/x-www-form-urlencoded}, the content will be parsed + * and used to populate the {@link #param(String, String...) request + * parameters} map. + * @param content the body content + */ + public B content(byte[] content) { + this.content = content; + return self(); + } + + /** + * Set the request body as a UTF-8 String. + *

If content is provided and {@link #contentType(MediaType)} is set to + * {@code application/x-www-form-urlencoded}, the content will be parsed + * and used to populate the {@link #param(String, String...) request + * parameters} map. + * @param content the body content + */ + public B content(String content) { + this.content = content.getBytes(StandardCharsets.UTF_8); + return self(); + } + + /** + * Set the 'Content-Type' header of the request. + *

If content is provided and {@code contentType} is set to + * {@code application/x-www-form-urlencoded}, the content will be parsed + * and used to populate the {@link #param(String, String...) request + * parameters} map. + * @param contentType the content type + */ + public B contentType(MediaType contentType) { + Assert.notNull(contentType, "'contentType' must not be null"); + this.contentType = contentType.toString(); + return self(); + } + + /** + * Set the 'Content-Type' header of the request as a raw String value, + * possibly not even well-formed (for testing purposes). + * @param contentType the content type + * @since 4.1.2 + */ + public B contentType(String contentType) { + Assert.notNull(contentType, "'contentType' must not be null"); + this.contentType = contentType; + return self(); + } + + /** + * Set the 'Accept' header to the given media type(s). + * @param mediaTypes one or more media types + */ + public B accept(MediaType... mediaTypes) { + Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); + this.headers.set("Accept", MediaType.toString(Arrays.asList(mediaTypes))); + return self(); + } + + /** + * Set the {@code Accept} header using raw String values, possibly not even + * well-formed (for testing purposes). + * @param mediaTypes one or more media types; internally joined as + * comma-separated String + */ + public B accept(String... mediaTypes) { + Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); + this.headers.set("Accept", String.join(", ", mediaTypes)); + return self(); + } + + /** + * Add a header to the request. Values are always added. + * @param name the header name + * @param values one or more header values + */ + public B header(String name, Object... values) { + addToMultiValueMap(this.headers, name, values); + return self(); + } + + /** + * Add all headers to the request. Values are always added. + * @param httpHeaders the headers and values to add + */ + public B headers(HttpHeaders httpHeaders) { + httpHeaders.forEach(this.headers::addAll); + return self(); + } + + /** + * Add a request parameter to {@link MockHttpServletRequest#getParameterMap()}. + *

In the Servlet API, a request parameter may be parsed from the query + * string and/or from the body of an {@code application/x-www-form-urlencoded} + * request. This method simply adds to the request parameter map. You may + * also use add Servlet request parameters by specifying the query or form + * data through one of the following: + *

    + *
  • Supply a URL with a query to {@link MockMvcRequestBuilders}. + *
  • Add query params via {@link #queryParam} or {@link #queryParams}. + *
  • Provide {@link #content} with {@link #contentType} + * {@code application/x-www-form-urlencoded}. + *
+ * @param name the parameter name + * @param values one or more values + */ + public B param(String name, String... values) { + addToMultiValueMap(this.parameters, name, values); + return self(); + } + + /** + * Variant of {@link #param(String, String...)} with a {@link MultiValueMap}. + * @param params the parameters to add + * @since 4.2.4 + */ + public B params(MultiValueMap params) { + params.forEach((name, values) -> { + for (String value : values) { + this.parameters.add(name, value); + } + }); + return self(); + } + + /** + * Append to the query string and also add to the + * {@link #param(String, String...) request parameters} map. The parameter + * name and value are encoded when they are added to the query string. + * @param name the parameter name + * @param values one or more values + * @since 5.2.2 + */ + public B queryParam(String name, String... values) { + param(name, values); + this.queryParams.addAll(name, Arrays.asList(values)); + return self(); + } + + /** + * Append to the query string and also add to the + * {@link #params(MultiValueMap) request parameters} map. The parameter + * name and value are encoded when they are added to the query string. + * @param params the parameters to add + * @since 5.2.2 + */ + public B queryParams(MultiValueMap params) { + params(params); + this.queryParams.addAll(params); + return self(); + } + + /** + * Append the given value(s) to the given form field and also add them to the + * {@linkplain #param(String, String...) request parameters} map. + * @param name the field name + * @param values one or more values + * @since 6.1.7 + */ + public B formField(String name, String... values) { + param(name, values); + this.formFields.addAll(name, Arrays.asList(values)); + return self(); + } + + /** + * Variant of {@link #formField(String, String...)} with a {@link MultiValueMap}. + * @param formFields the form fields to add + * @since 6.1.7 + */ + public B formFields(MultiValueMap formFields) { + params(formFields); + this.formFields.addAll(formFields); + return self(); + } + + /** + * Add the given cookies to the request. Cookies are always added. + * @param cookies the cookies to add + */ + public B cookie(Cookie... cookies) { + Assert.notEmpty(cookies, "'cookies' must not be empty"); + this.cookies.addAll(Arrays.asList(cookies)); + return self(); + } + + /** + * Add the specified locales as preferred request locales. + * @param locales the locales to add + * @since 4.3.6 + * @see #locale(Locale) + */ + public B locale(Locale... locales) { + Assert.notEmpty(locales, "'locales' must not be empty"); + this.locales.addAll(Arrays.asList(locales)); + return self(); + } + + /** + * Set the locale of the request, overriding any previous locales. + * @param locale the locale, or {@code null} to reset it + * @see #locale(Locale...) + */ + public B locale(@Nullable Locale locale) { + this.locales.clear(); + if (locale != null) { + this.locales.add(locale); + } + return self(); + } + + /** + * Set a request attribute. + * @param name the attribute name + * @param value the attribute value + */ + public B requestAttr(String name, Object value) { + addToMap(this.requestAttributes, name, value); + return self(); + } + + /** + * Set a session attribute. + * @param name the session attribute name + * @param value the session attribute value + */ + public B sessionAttr(String name, Object value) { + addToMap(this.sessionAttributes, name, value); + return self(); + } + + /** + * Set session attributes. + * @param sessionAttributes the session attributes + */ + public B sessionAttrs(Map sessionAttributes) { + Assert.notEmpty(sessionAttributes, "'sessionAttributes' must not be empty"); + sessionAttributes.forEach(this::sessionAttr); + return self(); + } + + /** + * Set an "input" flash attribute. + * @param name the flash attribute name + * @param value the flash attribute value + */ + public B flashAttr(String name, Object value) { + addToMap(this.flashAttributes, name, value); + return self(); + } + + /** + * Set flash attributes. + * @param flashAttributes the flash attributes + */ + public B flashAttrs(Map flashAttributes) { + Assert.notEmpty(flashAttributes, "'flashAttributes' must not be empty"); + flashAttributes.forEach(this::flashAttr); + return self(); + } + + /** + * Set the HTTP session to use, possibly re-used across requests. + *

Individual attributes provided via {@link #sessionAttr(String, Object)} + * override the content of the session provided here. + * @param session the HTTP session + */ + public B session(MockHttpSession session) { + Assert.notNull(session, "'session' must not be null"); + this.session = session; + return self(); + } + + /** + * Set the principal of the request. + * @param principal the principal + */ + public B principal(Principal principal) { + Assert.notNull(principal, "'principal' must not be null"); + this.principal = principal; + return self(); + } + + /** + * Set the remote address of the request. + * @param remoteAddress the remote address (IP) + * @since 6.0.10 + */ + public B remoteAddress(String remoteAddress) { + Assert.hasText(remoteAddress, "'remoteAddress' must not be null or blank"); + this.remoteAddress = remoteAddress; + return self(); + } + + /** + * An extension point for further initialization of {@link MockHttpServletRequest} + * in ways not built directly into the {@code MockHttpServletRequestBuilder}. + * Implementation of this interface can have builder-style methods themselves + * and be made accessible through static factory methods. + * @param postProcessor a post-processor to add + */ + @Override + public B with(RequestPostProcessor postProcessor) { + Assert.notNull(postProcessor, "postProcessor is required"); + this.postProcessors.add(postProcessor); + return self(); + } + + + /** + * {@inheritDoc} + * @return always returns {@code true}. + */ + @Override + public boolean isMergeEnabled() { + return true; + } + + /** + * Merges the properties of the "parent" RequestBuilder accepting values + * only if not already set in "this" instance. + * @param parent the parent {@code RequestBuilder} to inherit properties from + * @return the result of the merge + */ + @Override + public Object merge(@Nullable Object parent) { + if (parent == null) { + return this; + } + if (!(parent instanceof AbstractMockHttpServletRequestBuilder parentBuilder)) { + throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); + } + if (!StringUtils.hasText(this.contextPath)) { + this.contextPath = parentBuilder.contextPath; + } + if (!StringUtils.hasText(this.servletPath)) { + this.servletPath = parentBuilder.servletPath; + } + if ("".equals(this.pathInfo)) { + this.pathInfo = parentBuilder.pathInfo; + } + + if (this.secure == null) { + this.secure = parentBuilder.secure; + } + if (this.principal == null) { + this.principal = parentBuilder.principal; + } + if (this.session == null) { + this.session = parentBuilder.session; + } + if (this.remoteAddress == null) { + this.remoteAddress = parentBuilder.remoteAddress; + } + + if (this.characterEncoding == null) { + this.characterEncoding = parentBuilder.characterEncoding; + } + if (this.content == null) { + this.content = parentBuilder.content; + } + if (this.contentType == null) { + this.contentType = parentBuilder.contentType; + } + + for (Map.Entry> entry : parentBuilder.headers.entrySet()) { + String headerName = entry.getKey(); + if (!this.headers.containsKey(headerName)) { + this.headers.put(headerName, entry.getValue()); + } + } + for (Map.Entry> entry : parentBuilder.parameters.entrySet()) { + String paramName = entry.getKey(); + if (!this.parameters.containsKey(paramName)) { + this.parameters.put(paramName, entry.getValue()); + } + } + for (Map.Entry> entry : parentBuilder.queryParams.entrySet()) { + String paramName = entry.getKey(); + if (!this.queryParams.containsKey(paramName)) { + this.queryParams.put(paramName, entry.getValue()); + } + } + for (Map.Entry> entry : parentBuilder.formFields.entrySet()) { + String paramName = entry.getKey(); + if (!this.formFields.containsKey(paramName)) { + this.formFields.put(paramName, entry.getValue()); + } + } + for (Cookie cookie : parentBuilder.cookies) { + if (!containsCookie(cookie)) { + this.cookies.add(cookie); + } + } + for (Locale locale : parentBuilder.locales) { + if (!this.locales.contains(locale)) { + this.locales.add(locale); + } + } + + for (Map.Entry entry : parentBuilder.requestAttributes.entrySet()) { + String attributeName = entry.getKey(); + if (!this.requestAttributes.containsKey(attributeName)) { + this.requestAttributes.put(attributeName, entry.getValue()); + } + } + for (Map.Entry entry : parentBuilder.sessionAttributes.entrySet()) { + String attributeName = entry.getKey(); + if (!this.sessionAttributes.containsKey(attributeName)) { + this.sessionAttributes.put(attributeName, entry.getValue()); + } + } + for (Map.Entry entry : parentBuilder.flashAttributes.entrySet()) { + String attributeName = entry.getKey(); + if (!this.flashAttributes.containsKey(attributeName)) { + this.flashAttributes.put(attributeName, entry.getValue()); + } + } + + this.postProcessors.addAll(0, parentBuilder.postProcessors); + + return this; + } + + private boolean containsCookie(Cookie cookie) { + for (Cookie cookieToCheck : this.cookies) { + if (ObjectUtils.nullSafeEquals(cookieToCheck.getName(), cookie.getName())) { + return true; + } + } + return false; + } + + /** + * Build a {@link MockHttpServletRequest}. + */ + @Override + public final MockHttpServletRequest buildRequest(ServletContext servletContext) { + Assert.notNull(this.uri, "'uri' is required"); + MockHttpServletRequest request = createServletRequest(servletContext); + + request.setAsyncSupported(true); + request.setMethod(this.method.name()); + + String requestUri = this.uri.getRawPath(); + request.setRequestURI(requestUri); + + if (this.uri.getScheme() != null) { + request.setScheme(this.uri.getScheme()); + } + if (this.uri.getHost() != null) { + request.setServerName(this.uri.getHost()); + } + if (this.uri.getPort() != -1) { + request.setServerPort(this.uri.getPort()); + } + + updatePathRequestProperties(request, requestUri); + + if (this.secure != null) { + request.setSecure(this.secure); + } + if (this.principal != null) { + request.setUserPrincipal(this.principal); + } + if (this.remoteAddress != null) { + request.setRemoteAddr(this.remoteAddress); + } + if (this.session != null) { + request.setSession(this.session); + } + + request.setCharacterEncoding(this.characterEncoding); + request.setContent(this.content); + request.setContentType(this.contentType); + + this.headers.forEach((name, values) -> { + for (Object value : values) { + request.addHeader(name, value); + } + }); + + if (!ObjectUtils.isEmpty(this.content) && + !this.headers.containsKey(HttpHeaders.CONTENT_LENGTH) && + !this.headers.containsKey(HttpHeaders.TRANSFER_ENCODING)) { + + request.addHeader(HttpHeaders.CONTENT_LENGTH, this.content.length); + } + + String query = this.uri.getRawQuery(); + if (!this.queryParams.isEmpty()) { + String str = UriComponentsBuilder.newInstance().queryParams(this.queryParams).build().encode().getQuery(); + query = StringUtils.hasLength(query) ? (query + "&" + str) : str; + } + if (query != null) { + request.setQueryString(query); + } + addRequestParams(request, UriComponentsBuilder.fromUri(this.uri).build().getQueryParams()); + + this.parameters.forEach((name, values) -> { + for (String value : values) { + request.addParameter(name, value); + } + }); + + if (!this.formFields.isEmpty()) { + if (this.content != null && this.content.length > 0) { + throw new IllegalStateException("Could not write form data with an existing body"); + } + Charset charset = (this.characterEncoding != null ? + Charset.forName(this.characterEncoding) : StandardCharsets.UTF_8); + MediaType mediaType = (request.getContentType() != null ? + MediaType.parseMediaType(request.getContentType()) : + new MediaType(MediaType.APPLICATION_FORM_URLENCODED, charset)); + if (!mediaType.isCompatibleWith(MediaType.APPLICATION_FORM_URLENCODED)) { + throw new IllegalStateException("Invalid content type: '" + mediaType + + "' is not compatible with '" + MediaType.APPLICATION_FORM_URLENCODED + "'"); + } + request.setContent(writeFormData(mediaType, charset)); + if (request.getContentType() == null) { + request.setContentType(mediaType.toString()); + } + } + if (this.content != null && this.content.length > 0) { + String requestContentType = request.getContentType(); + if (requestContentType != null) { + try { + MediaType mediaType = MediaType.parseMediaType(requestContentType); + if (MediaType.APPLICATION_FORM_URLENCODED.includes(mediaType)) { + addRequestParams(request, parseFormData(mediaType)); + } + } + catch (Exception ex) { + // Must be invalid, ignore + } + } + } + + if (!ObjectUtils.isEmpty(this.cookies)) { + request.setCookies(this.cookies.toArray(new Cookie[0])); + } + if (!ObjectUtils.isEmpty(this.locales)) { + request.setPreferredLocales(this.locales); + } + + this.requestAttributes.forEach(request::setAttribute); + this.sessionAttributes.forEach((name, attribute) -> { + HttpSession session = request.getSession(); + Assert.state(session != null, "No HttpSession"); + session.setAttribute(name, attribute); + }); + + FlashMap flashMap = new FlashMap(); + flashMap.putAll(this.flashAttributes); + FlashMapManager flashMapManager = getFlashMapManager(request); + flashMapManager.saveOutputFlashMap(flashMap, request, new MockHttpServletResponse()); + + return request; + } + + /** + * Create a new {@link MockHttpServletRequest} based on the supplied + * {@code ServletContext}. + *

Can be overridden in subclasses. + */ + protected MockHttpServletRequest createServletRequest(ServletContext servletContext) { + return new MockHttpServletRequest(servletContext); + } + + /** + * Update the contextPath, servletPath, and pathInfo of the request. + */ + private void updatePathRequestProperties(MockHttpServletRequest request, String requestUri) { + if (!requestUri.startsWith(this.contextPath)) { + throw new IllegalArgumentException( + "Request URI [" + requestUri + "] does not start with context path [" + this.contextPath + "]"); + } + request.setContextPath(this.contextPath); + request.setServletPath(this.servletPath); + + if ("".equals(this.pathInfo)) { + if (!requestUri.startsWith(this.contextPath + this.servletPath)) { + throw new IllegalArgumentException( + "Invalid servlet path [" + this.servletPath + "] for request URI [" + requestUri + "]"); + } + String extraPath = requestUri.substring(this.contextPath.length() + this.servletPath.length()); + this.pathInfo = (StringUtils.hasText(extraPath) ? + UrlPathHelper.defaultInstance.decodeRequestString(request, extraPath) : null); + } + request.setPathInfo(this.pathInfo); + } + + private void addRequestParams(MockHttpServletRequest request, MultiValueMap map) { + map.forEach((key, values) -> values.forEach(value -> { + value = (value != null ? UriUtils.decode(value, StandardCharsets.UTF_8) : null); + request.addParameter(UriUtils.decode(key, StandardCharsets.UTF_8), value); + })); + } + + private byte[] writeFormData(MediaType mediaType, Charset charset) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + HttpOutputMessage message = new HttpOutputMessage() { + @Override + public OutputStream getBody() { + return out; + } + + @Override + public HttpHeaders getHeaders() { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(mediaType); + return headers; + } + }; + try { + FormHttpMessageConverter messageConverter = new FormHttpMessageConverter(); + messageConverter.setCharset(charset); + messageConverter.write(this.formFields, mediaType, message); + return out.toByteArray(); + } + catch (IOException ex) { + throw new IllegalStateException("Failed to write form data to request body", ex); + } + } + + @SuppressWarnings("unchecked") + private MultiValueMap parseFormData(MediaType mediaType) { + HttpInputMessage message = new HttpInputMessage() { + @Override + public InputStream getBody() { + byte[] bodyContent = AbstractMockHttpServletRequestBuilder.this.content; + return (bodyContent != null ? new ByteArrayInputStream(bodyContent) : InputStream.nullInputStream()); + } + @Override + public HttpHeaders getHeaders() { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(mediaType); + return headers; + } + }; + + try { + return (MultiValueMap) new FormHttpMessageConverter().read(null, message); + } + catch (IOException ex) { + throw new IllegalStateException("Failed to parse form data in request body", ex); + } + } + + private FlashMapManager getFlashMapManager(MockHttpServletRequest request) { + FlashMapManager flashMapManager = null; + try { + ServletContext servletContext = request.getServletContext(); + WebApplicationContext wac = WebApplicationContextUtils.getRequiredWebApplicationContext(servletContext); + flashMapManager = wac.getBean(DispatcherServlet.FLASH_MAP_MANAGER_BEAN_NAME, FlashMapManager.class); + } + catch (IllegalStateException | NoSuchBeanDefinitionException ex) { + // ignore + } + return (flashMapManager != null ? flashMapManager : new SessionFlashMapManager()); + } + + @Override + public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { + for (RequestPostProcessor postProcessor : this.postProcessors) { + request = postProcessor.postProcessRequest(request); + } + return request; + } + + + private static void addToMap(Map map, String name, Object value) { + Assert.hasLength(name, "'name' must not be empty"); + Assert.notNull(value, "'value' must not be null"); + map.put(name, value); + } + + private static void addToMultiValueMap(MultiValueMap map, String name, T[] values) { + Assert.hasLength(name, "'name' must not be empty"); + Assert.notEmpty(values, "'values' must not be empty"); + for (T value : values) { + map.add(name, value); + } + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilder.java index 77d324413d1..d7dc20cff96 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilder.java @@ -16,54 +16,12 @@ package org.springframework.test.web.servlet.request; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; import java.net.URI; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; -import java.security.Principal; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import jakarta.servlet.ServletContext; -import jakarta.servlet.ServletRequest; -import jakarta.servlet.http.Cookie; -import jakarta.servlet.http.HttpSession; - -import org.springframework.beans.Mergeable; -import org.springframework.beans.factory.NoSuchBeanDefinitionException; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpMethod; -import org.springframework.http.HttpOutputMessage; -import org.springframework.http.MediaType; -import org.springframework.http.converter.FormHttpMessageConverter; -import org.springframework.lang.Nullable; import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.mock.web.MockHttpSession; import org.springframework.test.web.servlet.MockMvc; import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; -import org.springframework.util.ObjectUtils; -import org.springframework.util.StringUtils; -import org.springframework.web.context.WebApplicationContext; -import org.springframework.web.context.support.WebApplicationContextUtils; -import org.springframework.web.servlet.DispatcherServlet; -import org.springframework.web.servlet.FlashMap; -import org.springframework.web.servlet.FlashMapManager; -import org.springframework.web.servlet.support.SessionFlashMapManager; -import org.springframework.web.util.UriComponentsBuilder; -import org.springframework.web.util.UriUtils; -import org.springframework.web.util.UrlPathHelper; /** * Default builder for {@link MockHttpServletRequest} required as input to @@ -84,60 +42,7 @@ import org.springframework.web.util.UrlPathHelper; * @since 3.2 */ public class MockHttpServletRequestBuilder - implements ConfigurableSmartRequestBuilder, Mergeable { - - private final HttpMethod method; - - private final URI uri; - - private String contextPath = ""; - - private String servletPath = ""; - - @Nullable - private String pathInfo = ""; - - @Nullable - private Boolean secure; - - @Nullable - private Principal principal; - - @Nullable - private MockHttpSession session; - - @Nullable - private String remoteAddress; - - @Nullable - private String characterEncoding; - - @Nullable - private byte[] content; - - @Nullable - private String contentType; - - private final MultiValueMap headers = new LinkedMultiValueMap<>(); - - private final MultiValueMap parameters = new LinkedMultiValueMap<>(); - - private final MultiValueMap queryParams = new LinkedMultiValueMap<>(); - - private final MultiValueMap formFields = new LinkedMultiValueMap<>(); - - private final List cookies = new ArrayList<>(); - - private final List locales = new ArrayList<>(); - - private final Map requestAttributes = new LinkedHashMap<>(); - - private final Map sessionAttributes = new LinkedHashMap<>(); - - private final Map flashAttributes = new LinkedHashMap<>(); - - private final List postProcessors = new ArrayList<>(); - + extends AbstractMockHttpServletRequestBuilder { /** * Package private constructor. To get an instance, use static factory @@ -150,15 +55,8 @@ public class MockHttpServletRequestBuilder * @param uriVariables zero or more URI variables */ MockHttpServletRequestBuilder(HttpMethod httpMethod, String uriTemplate, Object... uriVariables) { - this(httpMethod, initUri(uriTemplate, uriVariables)); - } - - private static URI initUri(String uri, Object[] vars) { - Assert.notNull(uri, "'uri' must not be null"); - Assert.isTrue(uri.isEmpty() || uri.startsWith("/") || uri.startsWith("http://") || uri.startsWith("https://"), - () -> "'uri' should start with a path or be a complete HTTP URI: " + uri); - String uriString = (uri.isEmpty() ? "/" : uri); - return UriComponentsBuilder.fromUriString(uriString).buildAndExpand(vars).encode().toUri(); + super(httpMethod); + super.uri(uriTemplate, uriVariables); } /** @@ -169,784 +67,9 @@ public class MockHttpServletRequestBuilder * @since 4.0.3 */ MockHttpServletRequestBuilder(HttpMethod httpMethod, URI uri) { - Assert.notNull(httpMethod, "'httpMethod' is required"); + super(httpMethod); Assert.notNull(uri, "'uri' is required"); - this.method = httpMethod; - this.uri = uri; - } - - - /** - * Specify the portion of the requestURI that represents the context path. - * The context path, if specified, must match to the start of the request URI. - *

In most cases, tests can be written by omitting the context path from - * the requestURI. This is because most applications don't actually depend - * on the name under which they're deployed. If specified here, the context - * path must start with a "/" and must not end with a "/". - * @see jakarta.servlet.http.HttpServletRequest#getContextPath() - */ - public MockHttpServletRequestBuilder contextPath(String contextPath) { - if (StringUtils.hasText(contextPath)) { - Assert.isTrue(contextPath.startsWith("/"), "Context path must start with a '/'"); - Assert.isTrue(!contextPath.endsWith("/"), "Context path must not end with a '/'"); - } - this.contextPath = contextPath; - return this; - } - - /** - * Specify the portion of the requestURI that represents the path to which - * the Servlet is mapped. This is typically a portion of the requestURI - * after the context path. - *

In most cases, tests can be written by omitting the servlet path from - * the requestURI. This is because most applications don't actually depend - * on the prefix to which a servlet is mapped. For example if a Servlet is - * mapped to {@code "/main/*"}, tests can be written with the requestURI - * {@code "/accounts/1"} as opposed to {@code "/main/accounts/1"}. - * If specified here, the servletPath must start with a "/" and must not - * end with a "/". - * @see jakarta.servlet.http.HttpServletRequest#getServletPath() - */ - public MockHttpServletRequestBuilder servletPath(String servletPath) { - if (StringUtils.hasText(servletPath)) { - Assert.isTrue(servletPath.startsWith("/"), "Servlet path must start with a '/'"); - Assert.isTrue(!servletPath.endsWith("/"), "Servlet path must not end with a '/'"); - } - this.servletPath = servletPath; - return this; - } - - /** - * Specify the portion of the requestURI that represents the pathInfo. - *

If left unspecified (recommended), the pathInfo will be automatically derived - * by removing the contextPath and the servletPath from the requestURI and using any - * remaining part. If specified here, the pathInfo must start with a "/". - *

If specified, the pathInfo will be used as-is. - * @see jakarta.servlet.http.HttpServletRequest#getPathInfo() - */ - public MockHttpServletRequestBuilder pathInfo(@Nullable String pathInfo) { - if (StringUtils.hasText(pathInfo)) { - Assert.isTrue(pathInfo.startsWith("/"), "Path info must start with a '/'"); - } - this.pathInfo = pathInfo; - return this; - } - - /** - * Set the secure property of the {@link ServletRequest} indicating use of a - * secure channel, such as HTTPS. - * @param secure whether the request is using a secure channel - */ - public MockHttpServletRequestBuilder secure(boolean secure){ - this.secure = secure; - return this; - } - - /** - * Set the character encoding of the request. - * @param encoding the character encoding - * @since 5.3.10 - * @see StandardCharsets - * @see #characterEncoding(String) - */ - public MockHttpServletRequestBuilder characterEncoding(Charset encoding) { - return this.characterEncoding(encoding.name()); - } - - /** - * Set the character encoding of the request. - * @param encoding the character encoding - */ - public MockHttpServletRequestBuilder characterEncoding(String encoding) { - this.characterEncoding = encoding; - return this; - } - - /** - * Set the request body. - *

If content is provided and {@link #contentType(MediaType)} is set to - * {@code application/x-www-form-urlencoded}, the content will be parsed - * and used to populate the {@link #param(String, String...) request - * parameters} map. - * @param content the body content - */ - public MockHttpServletRequestBuilder content(byte[] content) { - this.content = content; - return this; - } - - /** - * Set the request body as a UTF-8 String. - *

If content is provided and {@link #contentType(MediaType)} is set to - * {@code application/x-www-form-urlencoded}, the content will be parsed - * and used to populate the {@link #param(String, String...) request - * parameters} map. - * @param content the body content - */ - public MockHttpServletRequestBuilder content(String content) { - this.content = content.getBytes(StandardCharsets.UTF_8); - return this; - } - - /** - * Set the 'Content-Type' header of the request. - *

If content is provided and {@code contentType} is set to - * {@code application/x-www-form-urlencoded}, the content will be parsed - * and used to populate the {@link #param(String, String...) request - * parameters} map. - * @param contentType the content type - */ - public MockHttpServletRequestBuilder contentType(MediaType contentType) { - Assert.notNull(contentType, "'contentType' must not be null"); - this.contentType = contentType.toString(); - return this; - } - - /** - * Set the 'Content-Type' header of the request as a raw String value, - * possibly not even well-formed (for testing purposes). - * @param contentType the content type - * @since 4.1.2 - */ - public MockHttpServletRequestBuilder contentType(String contentType) { - Assert.notNull(contentType, "'contentType' must not be null"); - this.contentType = contentType; - return this; - } - - /** - * Set the 'Accept' header to the given media type(s). - * @param mediaTypes one or more media types - */ - public MockHttpServletRequestBuilder accept(MediaType... mediaTypes) { - Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); - this.headers.set("Accept", MediaType.toString(Arrays.asList(mediaTypes))); - return this; - } - - /** - * Set the {@code Accept} header using raw String values, possibly not even - * well-formed (for testing purposes). - * @param mediaTypes one or more media types; internally joined as - * comma-separated String - */ - public MockHttpServletRequestBuilder accept(String... mediaTypes) { - Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); - this.headers.set("Accept", String.join(", ", mediaTypes)); - return this; - } - - /** - * Add a header to the request. Values are always added. - * @param name the header name - * @param values one or more header values - */ - public MockHttpServletRequestBuilder header(String name, Object... values) { - addToMultiValueMap(this.headers, name, values); - return this; - } - - /** - * Add all headers to the request. Values are always added. - * @param httpHeaders the headers and values to add - */ - public MockHttpServletRequestBuilder headers(HttpHeaders httpHeaders) { - httpHeaders.forEach(this.headers::addAll); - return this; - } - - /** - * Add a request parameter to {@link MockHttpServletRequest#getParameterMap()}. - *

In the Servlet API, a request parameter may be parsed from the query - * string and/or from the body of an {@code application/x-www-form-urlencoded} - * request. This method simply adds to the request parameter map. You may - * also use add Servlet request parameters by specifying the query or form - * data through one of the following: - *

    - *
  • Supply a URI with a query to {@link MockMvcRequestBuilders}. - *
  • Add query params via {@link #queryParam} or {@link #queryParams}. - *
  • Provide {@link #content} with {@link #contentType} - * {@code application/x-www-form-urlencoded}. - *
- * @param name the parameter name - * @param values one or more values - */ - public MockHttpServletRequestBuilder param(String name, String... values) { - addToMultiValueMap(this.parameters, name, values); - return this; - } - - /** - * Variant of {@link #param(String, String...)} with a {@link MultiValueMap}. - * @param params the parameters to add - * @since 4.2.4 - */ - public MockHttpServletRequestBuilder params(MultiValueMap params) { - params.forEach((name, values) -> { - for (String value : values) { - this.parameters.add(name, value); - } - }); - return this; - } - - /** - * Append to the query string and also add to the - * {@link #param(String, String...) request parameters} map. The parameter - * name and value are encoded when they are added to the query string. - * @param name the parameter name - * @param values one or more values - * @since 5.2.2 - */ - public MockHttpServletRequestBuilder queryParam(String name, String... values) { - param(name, values); - this.queryParams.addAll(name, Arrays.asList(values)); - return this; - } - - /** - * Append to the query string and also add to the - * {@link #params(MultiValueMap) request parameters} map. The parameter - * name and value are encoded when they are added to the query string. - * @param params the parameters to add - * @since 5.2.2 - */ - public MockHttpServletRequestBuilder queryParams(MultiValueMap params) { - params(params); - this.queryParams.addAll(params); - return this; - } - - /** - * Append the given value(s) to the given form field and also add them to the - * {@linkplain #param(String, String...) request parameters} map. - * @param name the field name - * @param values one or more values - * @since 6.1.7 - */ - public MockHttpServletRequestBuilder formField(String name, String... values) { - param(name, values); - this.formFields.addAll(name, Arrays.asList(values)); - return this; - } - - /** - * Variant of {@link #formField(String, String...)} with a {@link MultiValueMap}. - * @param formFields the form fields to add - * @since 6.1.7 - */ - public MockHttpServletRequestBuilder formFields(MultiValueMap formFields) { - params(formFields); - this.formFields.addAll(formFields); - return this; - } - - /** - * Add the given cookies to the request. Cookies are always added. - * @param cookies the cookies to add - */ - public MockHttpServletRequestBuilder cookie(Cookie... cookies) { - Assert.notEmpty(cookies, "'cookies' must not be empty"); - this.cookies.addAll(Arrays.asList(cookies)); - return this; - } - - /** - * Add the specified locales as preferred request locales. - * @param locales the locales to add - * @since 4.3.6 - * @see #locale(Locale) - */ - public MockHttpServletRequestBuilder locale(Locale... locales) { - Assert.notEmpty(locales, "'locales' must not be empty"); - this.locales.addAll(Arrays.asList(locales)); - return this; - } - - /** - * Set the locale of the request, overriding any previous locales. - * @param locale the locale, or {@code null} to reset it - * @see #locale(Locale...) - */ - public MockHttpServletRequestBuilder locale(@Nullable Locale locale) { - this.locales.clear(); - if (locale != null) { - this.locales.add(locale); - } - return this; - } - - /** - * Set a request attribute. - * @param name the attribute name - * @param value the attribute value - */ - public MockHttpServletRequestBuilder requestAttr(String name, Object value) { - addToMap(this.requestAttributes, name, value); - return this; - } - - /** - * Set a session attribute. - * @param name the session attribute name - * @param value the session attribute value - */ - public MockHttpServletRequestBuilder sessionAttr(String name, Object value) { - addToMap(this.sessionAttributes, name, value); - return this; - } - - /** - * Set session attributes. - * @param sessionAttributes the session attributes - */ - public MockHttpServletRequestBuilder sessionAttrs(Map sessionAttributes) { - Assert.notEmpty(sessionAttributes, "'sessionAttributes' must not be empty"); - sessionAttributes.forEach(this::sessionAttr); - return this; - } - - /** - * Set an "input" flash attribute. - * @param name the flash attribute name - * @param value the flash attribute value - */ - public MockHttpServletRequestBuilder flashAttr(String name, Object value) { - addToMap(this.flashAttributes, name, value); - return this; - } - - /** - * Set flash attributes. - * @param flashAttributes the flash attributes - */ - public MockHttpServletRequestBuilder flashAttrs(Map flashAttributes) { - Assert.notEmpty(flashAttributes, "'flashAttributes' must not be empty"); - flashAttributes.forEach(this::flashAttr); - return this; - } - - /** - * Set the HTTP session to use, possibly re-used across requests. - *

Individual attributes provided via {@link #sessionAttr(String, Object)} - * override the content of the session provided here. - * @param session the HTTP session - */ - public MockHttpServletRequestBuilder session(MockHttpSession session) { - Assert.notNull(session, "'session' must not be null"); - this.session = session; - return this; - } - - /** - * Set the principal of the request. - * @param principal the principal - */ - public MockHttpServletRequestBuilder principal(Principal principal) { - Assert.notNull(principal, "'principal' must not be null"); - this.principal = principal; - return this; - } - - /** - * Set the remote address of the request. - * @param remoteAddress the remote address (IP) - * @since 6.0.10 - */ - public MockHttpServletRequestBuilder remoteAddress(String remoteAddress) { - Assert.hasText(remoteAddress, "'remoteAddress' must not be null or blank"); - this.remoteAddress = remoteAddress; - return this; - } - - /** - * An extension point for further initialization of {@link MockHttpServletRequest} - * in ways not built directly into the {@code MockHttpServletRequestBuilder}. - * Implementation of this interface can have builder-style methods themselves - * and be made accessible through static factory methods. - * @param postProcessor a post-processor to add - */ - @Override - public MockHttpServletRequestBuilder with(RequestPostProcessor postProcessor) { - Assert.notNull(postProcessor, "postProcessor is required"); - this.postProcessors.add(postProcessor); - return this; - } - - - /** - * {@inheritDoc} - * @return always returns {@code true}. - */ - @Override - public boolean isMergeEnabled() { - return true; - } - - /** - * Merges the properties of the "parent" RequestBuilder accepting values - * only if not already set in "this" instance. - * @param parent the parent {@code RequestBuilder} to inherit properties from - * @return the result of the merge - */ - @Override - public Object merge(@Nullable Object parent) { - if (parent == null) { - return this; - } - if (!(parent instanceof MockHttpServletRequestBuilder parentBuilder)) { - throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); - } - if (!StringUtils.hasText(this.contextPath)) { - this.contextPath = parentBuilder.contextPath; - } - if (!StringUtils.hasText(this.servletPath)) { - this.servletPath = parentBuilder.servletPath; - } - if ("".equals(this.pathInfo)) { - this.pathInfo = parentBuilder.pathInfo; - } - - if (this.secure == null) { - this.secure = parentBuilder.secure; - } - if (this.principal == null) { - this.principal = parentBuilder.principal; - } - if (this.session == null) { - this.session = parentBuilder.session; - } - if (this.remoteAddress == null) { - this.remoteAddress = parentBuilder.remoteAddress; - } - - if (this.characterEncoding == null) { - this.characterEncoding = parentBuilder.characterEncoding; - } - if (this.content == null) { - this.content = parentBuilder.content; - } - if (this.contentType == null) { - this.contentType = parentBuilder.contentType; - } - - for (Map.Entry> entry : parentBuilder.headers.entrySet()) { - String headerName = entry.getKey(); - if (!this.headers.containsKey(headerName)) { - this.headers.put(headerName, entry.getValue()); - } - } - for (Map.Entry> entry : parentBuilder.parameters.entrySet()) { - String paramName = entry.getKey(); - if (!this.parameters.containsKey(paramName)) { - this.parameters.put(paramName, entry.getValue()); - } - } - for (Map.Entry> entry : parentBuilder.queryParams.entrySet()) { - String paramName = entry.getKey(); - if (!this.queryParams.containsKey(paramName)) { - this.queryParams.put(paramName, entry.getValue()); - } - } - for (Map.Entry> entry : parentBuilder.formFields.entrySet()) { - String paramName = entry.getKey(); - if (!this.formFields.containsKey(paramName)) { - this.formFields.put(paramName, entry.getValue()); - } - } - for (Cookie cookie : parentBuilder.cookies) { - if (!containsCookie(cookie)) { - this.cookies.add(cookie); - } - } - for (Locale locale : parentBuilder.locales) { - if (!this.locales.contains(locale)) { - this.locales.add(locale); - } - } - - for (Map.Entry entry : parentBuilder.requestAttributes.entrySet()) { - String attributeName = entry.getKey(); - if (!this.requestAttributes.containsKey(attributeName)) { - this.requestAttributes.put(attributeName, entry.getValue()); - } - } - for (Map.Entry entry : parentBuilder.sessionAttributes.entrySet()) { - String attributeName = entry.getKey(); - if (!this.sessionAttributes.containsKey(attributeName)) { - this.sessionAttributes.put(attributeName, entry.getValue()); - } - } - for (Map.Entry entry : parentBuilder.flashAttributes.entrySet()) { - String attributeName = entry.getKey(); - if (!this.flashAttributes.containsKey(attributeName)) { - this.flashAttributes.put(attributeName, entry.getValue()); - } - } - - this.postProcessors.addAll(0, parentBuilder.postProcessors); - - return this; - } - - private boolean containsCookie(Cookie cookie) { - for (Cookie cookieToCheck : this.cookies) { - if (ObjectUtils.nullSafeEquals(cookieToCheck.getName(), cookie.getName())) { - return true; - } - } - return false; - } - - /** - * Build a {@link MockHttpServletRequest}. - */ - @Override - public final MockHttpServletRequest buildRequest(ServletContext servletContext) { - MockHttpServletRequest request = createServletRequest(servletContext); - - request.setAsyncSupported(true); - request.setMethod(this.method.name()); - - String requestUri = this.uri.getRawPath(); - request.setRequestURI(requestUri); - - if (this.uri.getScheme() != null) { - request.setScheme(this.uri.getScheme()); - } - if (this.uri.getHost() != null) { - request.setServerName(this.uri.getHost()); - } - if (this.uri.getPort() != -1) { - request.setServerPort(this.uri.getPort()); - } - - updatePathRequestProperties(request, requestUri); - - if (this.secure != null) { - request.setSecure(this.secure); - } - if (this.principal != null) { - request.setUserPrincipal(this.principal); - } - if (this.remoteAddress != null) { - request.setRemoteAddr(this.remoteAddress); - } - if (this.session != null) { - request.setSession(this.session); - } - - request.setCharacterEncoding(this.characterEncoding); - request.setContent(this.content); - request.setContentType(this.contentType); - - this.headers.forEach((name, values) -> { - for (Object value : values) { - request.addHeader(name, value); - } - }); - - if (!ObjectUtils.isEmpty(this.content) && - !this.headers.containsKey(HttpHeaders.CONTENT_LENGTH) && - !this.headers.containsKey(HttpHeaders.TRANSFER_ENCODING)) { - - request.addHeader(HttpHeaders.CONTENT_LENGTH, this.content.length); - } - - String query = this.uri.getRawQuery(); - if (!this.queryParams.isEmpty()) { - String str = UriComponentsBuilder.newInstance().queryParams(this.queryParams).build().encode().getQuery(); - query = StringUtils.hasLength(query) ? (query + "&" + str) : str; - } - if (query != null) { - request.setQueryString(query); - } - addRequestParams(request, UriComponentsBuilder.fromUri(this.uri).build().getQueryParams()); - - this.parameters.forEach((name, values) -> { - for (String value : values) { - request.addParameter(name, value); - } - }); - - if (!this.formFields.isEmpty()) { - if (this.content != null && this.content.length > 0) { - throw new IllegalStateException("Could not write form data with an existing body"); - } - Charset charset = (this.characterEncoding != null ? - Charset.forName(this.characterEncoding) : StandardCharsets.UTF_8); - MediaType mediaType = (request.getContentType() != null ? - MediaType.parseMediaType(request.getContentType()) : - new MediaType(MediaType.APPLICATION_FORM_URLENCODED, charset)); - if (!mediaType.isCompatibleWith(MediaType.APPLICATION_FORM_URLENCODED)) { - throw new IllegalStateException("Invalid content type: '" + mediaType + - "' is not compatible with '" + MediaType.APPLICATION_FORM_URLENCODED + "'"); - } - request.setContent(writeFormData(mediaType, charset)); - if (request.getContentType() == null) { - request.setContentType(mediaType.toString()); - } - } - if (this.content != null && this.content.length > 0) { - String requestContentType = request.getContentType(); - if (requestContentType != null) { - try { - MediaType mediaType = MediaType.parseMediaType(requestContentType); - if (MediaType.APPLICATION_FORM_URLENCODED.includes(mediaType)) { - addRequestParams(request, parseFormData(mediaType)); - } - } - catch (Exception ex) { - // Must be invalid, ignore - } - } - } - - if (!ObjectUtils.isEmpty(this.cookies)) { - request.setCookies(this.cookies.toArray(new Cookie[0])); - } - if (!ObjectUtils.isEmpty(this.locales)) { - request.setPreferredLocales(this.locales); - } - - this.requestAttributes.forEach(request::setAttribute); - this.sessionAttributes.forEach((name, attribute) -> { - HttpSession session = request.getSession(); - Assert.state(session != null, "No HttpSession"); - session.setAttribute(name, attribute); - }); - - FlashMap flashMap = new FlashMap(); - flashMap.putAll(this.flashAttributes); - FlashMapManager flashMapManager = getFlashMapManager(request); - flashMapManager.saveOutputFlashMap(flashMap, request, new MockHttpServletResponse()); - - return request; - } - - /** - * Create a new {@link MockHttpServletRequest} based on the supplied - * {@code ServletContext}. - *

Can be overridden in subclasses. - */ - protected MockHttpServletRequest createServletRequest(ServletContext servletContext) { - return new MockHttpServletRequest(servletContext); - } - - /** - * Update the contextPath, servletPath, and pathInfo of the request. - */ - private void updatePathRequestProperties(MockHttpServletRequest request, String requestUri) { - if (!requestUri.startsWith(this.contextPath)) { - throw new IllegalArgumentException( - "Request URI [" + requestUri + "] does not start with context path [" + this.contextPath + "]"); - } - request.setContextPath(this.contextPath); - request.setServletPath(this.servletPath); - - if ("".equals(this.pathInfo)) { - if (!requestUri.startsWith(this.contextPath + this.servletPath)) { - throw new IllegalArgumentException( - "Invalid servlet path [" + this.servletPath + "] for request URI [" + requestUri + "]"); - } - String extraPath = requestUri.substring(this.contextPath.length() + this.servletPath.length()); - this.pathInfo = (StringUtils.hasText(extraPath) ? - UrlPathHelper.defaultInstance.decodeRequestString(request, extraPath) : null); - } - request.setPathInfo(this.pathInfo); - } - - private void addRequestParams(MockHttpServletRequest request, MultiValueMap map) { - map.forEach((key, values) -> values.forEach(value -> { - value = (value != null ? UriUtils.decode(value, StandardCharsets.UTF_8) : null); - request.addParameter(UriUtils.decode(key, StandardCharsets.UTF_8), value); - })); - } - - private byte[] writeFormData(MediaType mediaType, Charset charset) { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - HttpOutputMessage message = new HttpOutputMessage() { - @Override - public OutputStream getBody() { - return out; - } - - @Override - public HttpHeaders getHeaders() { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(mediaType); - return headers; - } - }; - try { - FormHttpMessageConverter messageConverter = new FormHttpMessageConverter(); - messageConverter.setCharset(charset); - messageConverter.write(this.formFields, mediaType, message); - return out.toByteArray(); - } - catch (IOException ex) { - throw new IllegalStateException("Failed to write form data to request body", ex); - } - } - - @SuppressWarnings("unchecked") - private MultiValueMap parseFormData(MediaType mediaType) { - HttpInputMessage message = new HttpInputMessage() { - @Override - public InputStream getBody() { - byte[] bodyContent = MockHttpServletRequestBuilder.this.content; - return (bodyContent != null ? new ByteArrayInputStream(bodyContent) : InputStream.nullInputStream()); - } - @Override - public HttpHeaders getHeaders() { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(mediaType); - return headers; - } - }; - - try { - return (MultiValueMap) new FormHttpMessageConverter().read(null, message); - } - catch (IOException ex) { - throw new IllegalStateException("Failed to parse form data in request body", ex); - } - } - - private FlashMapManager getFlashMapManager(MockHttpServletRequest request) { - FlashMapManager flashMapManager = null; - try { - ServletContext servletContext = request.getServletContext(); - WebApplicationContext wac = WebApplicationContextUtils.getRequiredWebApplicationContext(servletContext); - flashMapManager = wac.getBean(DispatcherServlet.FLASH_MAP_MANAGER_BEAN_NAME, FlashMapManager.class); - } - catch (IllegalStateException | NoSuchBeanDefinitionException ex) { - // ignore - } - return (flashMapManager != null ? flashMapManager : new SessionFlashMapManager()); - } - - @Override - public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - for (RequestPostProcessor postProcessor : this.postProcessors) { - request = postProcessor.postProcessRequest(request); - } - return request; - } - - - private static void addToMap(Map map, String name, Object value) { - Assert.hasLength(name, "'name' must not be empty"); - Assert.notNull(value, "'value' must not be null"); - map.put(name, value); - } - - private static void addToMultiValueMap(MultiValueMap map, String name, T[] values) { - Assert.hasLength(name, "'name' must not be empty"); - Assert.notEmpty(values, "'values' must not be empty"); - for (T value : values) { - map.add(name, value); - } + super.uri(uri); } } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java index fcadb3f531f..36a4a548c3e 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java @@ -47,7 +47,7 @@ import org.springframework.util.MultiValueMap; * @author Arjen Poutsma * @since 3.2 */ -public class MockMultipartHttpServletRequestBuilder extends MockHttpServletRequestBuilder { +public class MockMultipartHttpServletRequestBuilder extends AbstractMockHttpServletRequestBuilder { private final List files = new ArrayList<>(); @@ -73,7 +73,8 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque * @since 5.3.22 */ MockMultipartHttpServletRequestBuilder(HttpMethod httpMethod, String uriTemplate, Object... uriVariables) { - super(httpMethod, uriTemplate, uriVariables); + super(httpMethod); + super.uri(uriTemplate, uriVariables); super.contentType(MediaType.MULTIPART_FORM_DATA); } @@ -92,7 +93,8 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque * @since 5.3.21 */ MockMultipartHttpServletRequestBuilder(HttpMethod httpMethod, URI uri) { - super(httpMethod, uri); + super(httpMethod); + super.uri(uri); super.contentType(MediaType.MULTIPART_FORM_DATA); } @@ -134,7 +136,7 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque if (parent == null) { return this; } - if (parent instanceof MockHttpServletRequestBuilder) { + if (parent instanceof AbstractMockHttpServletRequestBuilder) { super.merge(parent); if (parent instanceof MockMultipartHttpServletRequestBuilder parentBuilder) { this.files.addAll(parentBuilder.files); diff --git a/spring-test/src/main/kotlin/org/springframework/test/web/servlet/MockHttpServletRequestDsl.kt b/spring-test/src/main/kotlin/org/springframework/test/web/servlet/MockHttpServletRequestDsl.kt index 39298114469..bacee88b6a5 100644 --- a/spring-test/src/main/kotlin/org/springframework/test/web/servlet/MockHttpServletRequestDsl.kt +++ b/spring-test/src/main/kotlin/org/springframework/test/web/servlet/MockHttpServletRequestDsl.kt @@ -25,6 +25,7 @@ import org.springframework.util.MultiValueMap import java.security.Principal import java.util.* import jakarta.servlet.http.Cookie +import org.springframework.test.web.servlet.request.AbstractMockHttpServletRequestBuilder /** * Provide a [MockHttpServletRequestBuilder] Kotlin DSL in order to be able to write idiomatic Kotlin code. @@ -40,7 +41,7 @@ import jakarta.servlet.http.Cookie * @author Sebastien Deleuze * @since 5.2 */ -open class MockHttpServletRequestDsl internal constructor (private val builder: MockHttpServletRequestBuilder) { +open class MockHttpServletRequestDsl internal constructor (private val builder: AbstractMockHttpServletRequestBuilder<*>) { /** * @see [MockHttpServletRequestBuilder.contextPath] From 8d2bc3bdbabcfaf1ab4d0232f891bcf380ad975f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Tue, 28 May 2024 11:08:13 +0200 Subject: [PATCH 2/2] Add support for fluent preparation of the request in MockMvcTester See gh-32913 --- .../web/servlet/assertj/MockMvcTester.java | 180 +++++++++++++++--- ...vcTesterCompatibilityIntegrationTests.java | 94 +++++++++ .../MockMvcTesterIntegrationTests.java | 110 ++++++----- 3 files changed, 306 insertions(+), 78 deletions(-) create mode 100644 spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterCompatibilityIntegrationTests.java diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java b/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java index c0d4546e083..5dd258ee987 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java @@ -23,13 +23,18 @@ import java.util.Map; import java.util.function.Function; import java.util.stream.StreamSupport; +import org.assertj.core.api.AssertProvider; + +import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.lang.Nullable; +import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.RequestBuilder; +import org.springframework.test.web.servlet.request.AbstractMockHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.test.web.servlet.setup.DefaultMockMvcBuilder; import org.springframework.test.web.servlet.setup.MockMvcBuilders; @@ -58,11 +63,25 @@ import org.springframework.web.context.WebApplicationContext; * MockMvcTester mvc = MockMvcTester.of(new PersonController()); * * - *

Once a tester instance is available, you can perform requests in a similar - * fashion as with {@link MockMvc}, and wrapping the result in - * {@code assertThat()} provides access to assertions. For instance: + *

Simple, single-statement assertions can be done wrapping the request + * builder in {@code assertThat()} provides access to assertions. For instance: *


  * // perform a GET on /hi and assert the response body is equal to Hello
+ * assertThat(mvc.get().uri("/hi")).hasStatusOk().hasBodyTextEqualTo("Hello");
+ * 
+ * + *

For more complex scenarios the {@linkplain MvcTestResult result} of the + * exchange can be assigned in a variable to run multiple assertions: + *


+ * // perform a POST on /save and assert the response body is empty
+ * MvcTestResult result = mvc.post().uri("/save").exchange();
+ * assertThat(result).hasStatus(HttpStatus.CREATED);
+ * assertThat(result).body().isEmpty();
+ * 
+ * + *

You can also perform requests using the static builders approach that + * {@link MockMvc} uses. For instance:


+ * // perform a GET on /hi and assert the response body is equal to Hello
  * assertThat(mvc.perform(get("/hi")))
  *         .hasStatusOk().hasBodyTextEqualTo("Hello");
  * 
@@ -74,12 +93,11 @@ import org.springframework.web.context.WebApplicationContext; * which allows you to assert that a request failed unexpectedly: *

  * // perform a GET on /boom and assert the message for the the unresolved exception
- * assertThat(mvc.perform(get("/boom")))
- *         .hasUnresolvedException())
+ * assertThat(mvc.get().uri("/boom")).hasUnresolvedException())
  *         .withMessage("Test exception");
  * 
* - *

{@link MockMvcTester} can be configured with a list of + *

{@code MockMvcTester} can be configured with a list of * {@linkplain HttpMessageConverter message converters} to allow the response * body to be deserialized, rather than asserting on the raw values. * @@ -104,8 +122,7 @@ public final class MockMvcTester { } /** - * Create a {@link MockMvcTester} instance that delegates to the given - * {@link MockMvc} instance. + * Create an instance that delegates to the given {@link MockMvc} instance. * @param mockMvc the MockMvc instance to delegate calls to */ public static MockMvcTester create(MockMvc mockMvc) { @@ -113,9 +130,9 @@ public final class MockMvcTester { } /** - * Create an {@link MockMvcTester} instance using the given, fully - * initialized (i.e., refreshed) {@link WebApplicationContext}. The - * given {@code customizations} are applied to the {@link DefaultMockMvcBuilder} + * Create an instance using the given, fully initialized (i.e., + * refreshed) {@link WebApplicationContext}. The given + * {@code customizations} are applied to the {@link DefaultMockMvcBuilder} * that ultimately creates the underlying {@link MockMvc} instance. *

If no further customization of the underlying {@link MockMvc} instance * is required, use {@link #from(WebApplicationContext)}. @@ -134,8 +151,8 @@ public final class MockMvcTester { } /** - * Shortcut to create an {@link MockMvcTester} instance using the given, - * fully initialized (i.e., refreshed) {@link WebApplicationContext}. + * Shortcut to create an instance using the given fully initialized (i.e., + * refreshed) {@link WebApplicationContext}. *

Consider using {@link #from(WebApplicationContext, Function)} if * further customization of the underlying {@link MockMvc} instance is * required. @@ -148,9 +165,8 @@ public final class MockMvcTester { } /** - * Create an {@link MockMvcTester} instance by registering one or more - * {@code @Controller} instances and configuring Spring MVC infrastructure - * programmatically. + * Create an instance by registering one or more {@code @Controller} instances + * and configuring Spring MVC infrastructure programmatically. *

This allows full control over the instantiation and initialization of * controllers and their dependencies, similar to plain unit tests while * also making it possible to test one controller at a time. @@ -170,8 +186,8 @@ public final class MockMvcTester { } /** - * Shortcut to create an {@link MockMvcTester} instance by registering one - * or more {@code @Controller} instances. + * Shortcut to create an instance by registering one or more {@code @Controller} + * instances. *

The minimum infrastructure required by the * {@link org.springframework.web.servlet.DispatcherServlet DispatcherServlet} * to serve requests with annotated controllers is created. Consider using @@ -187,8 +203,8 @@ public final class MockMvcTester { } /** - * Return a new {@link MockMvcTester} instance using the specified - * {@linkplain HttpMessageConverter message converters}. + * Return a new instance using the specified {@linkplain HttpMessageConverter + * message converters}. *

If none are specified, only basic assertions on the response body can * be performed. Consider registering a suitable JSON converter for asserting * against JSON data structures. @@ -200,8 +216,105 @@ public final class MockMvcTester { } /** - * Perform a request and return a {@link MvcTestResult result} that can be - * used with standard {@link org.assertj.core.api.Assertions AssertJ} assertions. + * Prepare an HTTP GET request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder get() { + return method(HttpMethod.GET); + } + + /** + * Prepare an HTTP HEAD request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder head() { + return method(HttpMethod.HEAD); + } + + /** + * Prepare an HTTP POST request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder post() { + return method(HttpMethod.POST); + } + + /** + * Prepare an HTTP PUT request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder put() { + return method(HttpMethod.PUT); + } + + /** + * Prepare an HTTP PATCH request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder patch() { + return method(HttpMethod.PATCH); + } + + /** + * Prepare an HTTP DELETE request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder delete() { + return method(HttpMethod.DELETE); + } + + /** + * Prepare an HTTP OPTIONS request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder options() { + return method(HttpMethod.OPTIONS); + } + + /** + * Prepare a request for the specified {@code HttpMethod}. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder method(HttpMethod method) { + return new MockMvcRequestBuilder(method); + } + + /** + * Perform a request using {@link MockMvcRequestBuilders} and return a + * {@link MvcTestResult result} that can be used with standard + * {@link org.assertj.core.api.Assertions AssertJ} assertions. *

Use static methods of {@link MockMvcRequestBuilders} to prepare the * request, wrapping the invocation in {@code assertThat}. The following * asserts that a {@linkplain MockMvcRequestBuilders#get(URI) GET} request @@ -226,6 +339,8 @@ public final class MockMvcTester { * {@link org.springframework.test.web.servlet.request.MockMvcRequestBuilders} * @return an {@link MvcTestResult} to be wrapped in {@code assertThat} * @see MockMvc#perform(RequestBuilder) + * @see #get() + * @see #post() */ public MvcTestResult perform(RequestBuilder requestBuilder) { Object result = getMvcResultOrFailure(requestBuilder); @@ -259,4 +374,25 @@ public final class MockMvcTester { .findFirst().orElse(null); } + + /** + * A builder for {@link MockHttpServletRequest} that supports AssertJ. + */ + public final class MockMvcRequestBuilder extends AbstractMockHttpServletRequestBuilder + implements AssertProvider { + + private MockMvcRequestBuilder(HttpMethod httpMethod) { + super(httpMethod); + } + + public MvcTestResult exchange() { + return perform(this); + } + + @Override + public MvcTestResultAssert assertThat() { + return new MvcTestResultAssert(exchange(), MockMvcTester.this.jsonMessageConverter); + } + } + } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterCompatibilityIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterCompatibilityIntegrationTests.java new file mode 100644 index 00000000000..712b725d90f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterCompatibilityIntegrationTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.assertj; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.http.MediaType; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * Integration tests for {@link MockMvcTester} that use the methods that + * integrate with {@link MockMvc} way of building the requests and + * asserting the responses. + * + * @author Stephane Nicoll + */ +@SpringJUnitConfig +@WebAppConfiguration +class MockMvcTesterCompatibilityIntegrationTests { + + private final MockMvcTester mvc; + + MockMvcTesterCompatibilityIntegrationTests(@Autowired WebApplicationContext wac) { + this.mvc = MockMvcTester.from(wac); + } + + @Test + void performGet() { + assertThat(this.mvc.perform(get("/greet"))).hasStatusOk(); + } + + @Test + void performGetWithInvalidMediaTypeAssertion() { + MvcTestResult result = this.mvc.perform(get("/greet")); + assertThatExceptionOfType(AssertionError.class) + .isThrownBy(() -> assertThat(result).hasContentTypeCompatibleWith(MediaType.APPLICATION_JSON)) + .withMessageContaining("is compatible with 'application/json'"); + } + + @Test + void assertHttpStatusCode() { + assertThat(this.mvc.get().uri("/greet")).matches(status().isOk()); + } + + + @Configuration + @EnableWebMvc + @Import(TestController.class) + static class WebConfiguration { + } + + @RestController + static class TestController { + + @GetMapping(path = "/greet", produces = "text/plain") + String greet() { + return "hello"; + } + + @GetMapping(path = "/message", produces = MediaType.APPLICATION_JSON_VALUE) + String message() { + return "{\"message\": \"hello\"}"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java index 90f68c1f553..6778de3fbd0 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java @@ -43,7 +43,7 @@ import org.springframework.stereotype.Controller; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; import org.springframework.test.context.web.WebAppConfiguration; import org.springframework.test.web.Person; -import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; +import org.springframework.test.web.servlet.ResultMatcher; import org.springframework.ui.Model; import org.springframework.validation.Errors; import org.springframework.web.bind.annotation.GetMapping; @@ -63,9 +63,8 @@ import static java.util.Map.entry; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.InstanceOfAssertFactories.map; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Integration tests for {@link MockMvcTester}. @@ -77,10 +76,10 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @WebAppConfiguration public class MockMvcTesterIntegrationTests { - private final MockMvcTester mockMvc; + private final MockMvcTester mvc; MockMvcTesterIntegrationTests(WebApplicationContext wac) { - this.mockMvc = MockMvcTester.from(wac); + this.mvc = MockMvcTester.from(wac); } @Nested @@ -88,24 +87,24 @@ public class MockMvcTesterIntegrationTests { @Test void hasAsyncStartedTrue() { - assertThat(perform(get("/callable").accept(MediaType.APPLICATION_JSON))) + assertThat(mvc.get().uri("/callable").accept(MediaType.APPLICATION_JSON)) .request().hasAsyncStarted(true); } @Test void hasAsyncStartedFalse() { - assertThat(perform(get("/greet"))).request().hasAsyncStarted(false); + assertThat(mvc.get().uri("/greet")).request().hasAsyncStarted(false); } @Test void attributes() { - assertThat(perform(get("/greet"))).request().attributes() + assertThat(mvc.get().uri("/greet")).request().attributes() .containsKey(DispatcherServlet.WEB_APPLICATION_CONTEXT_ATTRIBUTE); } @Test void sessionAttributes() { - assertThat(perform(get("/locale"))).request().sessionAttributes() + assertThat(mvc.get().uri("/locale")).request().sessionAttributes() .containsOnly(entry("locale", Locale.UK)); } } @@ -116,17 +115,17 @@ public class MockMvcTesterIntegrationTests { @Test void containsCookie() { Cookie cookie = new Cookie("test", "value"); - assertThat(performWithCookie(cookie, get("/greet"))).cookies().containsCookie("test"); + assertThat(withCookie(cookie).get().uri("/greet")).cookies().containsCookie("test"); } @Test void hasValue() { Cookie cookie = new Cookie("test", "value"); - assertThat(performWithCookie(cookie, get("/greet"))).cookies().hasValue("test", "value"); + assertThat(withCookie(cookie).get().uri("/greet")).cookies().hasValue("test", "value"); } - private MvcTestResult performWithCookie(Cookie cookie, MockHttpServletRequestBuilder request) { - MockMvcTester mockMvc = MockMvcTester.of(List.of(new TestController()), builder -> builder.addInterceptors( + private MockMvcTester withCookie(Cookie cookie) { + return MockMvcTester.of(List.of(new TestController()), builder -> builder.addInterceptors( new HandlerInterceptor() { @Override public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) { @@ -134,7 +133,6 @@ public class MockMvcTesterIntegrationTests { return true; } }).build()); - return mockMvc.perform(request); } } @@ -143,12 +141,12 @@ public class MockMvcTesterIntegrationTests { @Test void statusOk() { - assertThat(perform(get("/greet"))).hasStatusOk(); + assertThat(mvc.get().uri("/greet")).hasStatusOk(); } @Test void statusSeries() { - assertThat(perform(get("/greet"))).hasStatus2xxSuccessful(); + assertThat(mvc.get().uri("/greet")).hasStatus2xxSuccessful(); } } @@ -158,13 +156,13 @@ public class MockMvcTesterIntegrationTests { @Test void shouldAssertHeader() { - assertThat(perform(get("/greet"))) + assertThat(mvc.get().uri("/greet")) .hasHeader("Content-Type", "text/plain;charset=ISO-8859-1"); } @Test void shouldAssertHeaderWithCallback() { - assertThat(perform(get("/greet"))).headers().satisfies(textContent("ISO-8859-1")); + assertThat(mvc.get().uri("/greet")).headers().satisfies(textContent("ISO-8859-1")); } private Consumer textContent(String charset) { @@ -179,33 +177,33 @@ public class MockMvcTesterIntegrationTests { @Test void hasViewName() { - assertThat(perform(get("/persons/{0}", "Andy"))).hasViewName("persons/index"); + assertThat(mvc.get().uri("/persons/{0}", "Andy")).hasViewName("persons/index"); } @Test void viewNameWithCustomAssertion() { - assertThat(perform(get("/persons/{0}", "Andy"))).viewName().startsWith("persons"); + assertThat(mvc.get().uri("/persons/{0}", "Andy")).viewName().startsWith("persons"); } @Test void containsAttributes() { - assertThat(perform(post("/persons").param("name", "Andy"))).model() + assertThat(mvc.post().uri("/persons").param("name", "Andy")).model() .containsOnlyKeys("name").containsEntry("name", "Andy"); } @Test void hasErrors() { - assertThat(perform(post("/persons"))).model().hasErrors(); + assertThat(mvc.post().uri("/persons")).model().hasErrors(); } @Test void hasAttributeErrors() { - assertThat(perform(post("/persons"))).model().hasAttributeErrors("person"); + assertThat(mvc.post().uri("/persons")).model().hasAttributeErrors("person"); } @Test void hasAttributeErrorsCount() { - assertThat(perform(post("/persons"))).model().extractingBindingResult("person").hasErrorsCount(1); + assertThat(mvc.post().uri("/persons")).model().extractingBindingResult("person").hasErrorsCount(1); } } @@ -215,7 +213,7 @@ public class MockMvcTesterIntegrationTests { @Test void containsAttributes() { - assertThat(perform(post("/persons").param("name", "Andy"))).flash() + assertThat(mvc.post().uri("/persons").param("name", "Andy")).flash() .containsOnlyKeys("message").hasEntrySatisfying("message", value -> assertThat(value).isInstanceOfSatisfying(String.class, stringValue -> assertThat(stringValue).startsWith("success"))); @@ -227,31 +225,31 @@ public class MockMvcTesterIntegrationTests { @Test void asyncResult() { - assertThat(perform(get("/callable").accept(MediaType.APPLICATION_JSON))) + assertThat(mvc.get().uri("/callable").accept(MediaType.APPLICATION_JSON)) .asyncResult().asInstanceOf(map(String.class, Object.class)) .containsOnly(entry("key", "value")); } @Test void stringContent() { - assertThat(perform(get("/greet"))).body().asString().isEqualTo("hello"); + assertThat(mvc.get().uri("/greet")).body().asString().isEqualTo("hello"); } @Test void jsonPathContent() { - assertThat(perform(get("/message"))).bodyJson() + assertThat(mvc.get().uri("/message")).bodyJson() .extractingPath("$.message").asString().isEqualTo("hello"); } @Test void jsonContentCanLoadResourceFromClasspath() { - assertThat(perform(get("/message"))).bodyJson().isLenientlyEqualTo( + assertThat(mvc.get().uri("/message")).bodyJson().isLenientlyEqualTo( new ClassPathResource("message.json", MockMvcTesterIntegrationTests.class)); } @Test void jsonContentUsingResourceLoaderClass() { - assertThat(perform(get("/message"))).bodyJson().withResourceLoadClass(MockMvcTesterIntegrationTests.class) + assertThat(mvc.get().uri("/message")).bodyJson().withResourceLoadClass(MockMvcTesterIntegrationTests.class) .isLenientlyEqualTo("message.json"); } @@ -262,22 +260,22 @@ public class MockMvcTesterIntegrationTests { @Test void handlerOn404() { - assertThat(perform(get("/unknown-resource"))).handler().isNull(); + assertThat(mvc.get().uri("/unknown-resource")).handler().isNull(); } @Test void hasType() { - assertThat(perform(get("/greet"))).handler().hasType(TestController.class); + assertThat(mvc.get().uri("/greet")).handler().hasType(TestController.class); } @Test void isMethodHandler() { - assertThat(perform(get("/greet"))).handler().isMethodHandler(); + assertThat(mvc.get().uri("/greet")).handler().isMethodHandler(); } @Test void isInvokedOn() { - assertThat(perform(get("/callable"))).handler() + assertThat(mvc.get().uri("/callable")).handler() .isInvokedOn(AsyncController.class, AsyncController::getCallable); } @@ -288,31 +286,31 @@ public class MockMvcTesterIntegrationTests { @Test void doesNotHaveUnresolvedException() { - assertThat(perform(get("/greet"))).doesNotHaveUnresolvedException(); + assertThat(mvc.get().uri("/greet")).doesNotHaveUnresolvedException(); } @Test void hasUnresolvedException() { - assertThat(perform(get("/error/1"))).hasUnresolvedException(); + assertThat(mvc.get().uri("/error/1")).hasUnresolvedException(); } @Test void doesNotHaveUnresolvedExceptionWithUnresolvedException() { assertThatExceptionOfType(AssertionError.class) - .isThrownBy(() -> assertThat(perform(get("/error/1"))).doesNotHaveUnresolvedException()) + .isThrownBy(() -> assertThat(mvc.get().uri("/error/1")).doesNotHaveUnresolvedException()) .withMessage("Expected request to succeed, but it failed"); } @Test void hasUnresolvedExceptionWithoutUnresolvedException() { assertThatExceptionOfType(AssertionError.class) - .isThrownBy(() -> assertThat(perform(get("/greet"))).hasUnresolvedException()) + .isThrownBy(() -> assertThat(mvc.get().uri("/greet")).hasUnresolvedException()) .withMessage("Expected request to fail, but it succeeded"); } @Test void unresolvedExceptionWithFailedRequest() { - assertThat(perform(get("/error/1"))).unresolvedException() + assertThat(mvc.get().uri("/error/1")).unresolvedException() .isInstanceOf(ServletException.class) .cause().isInstanceOf(IllegalStateException.class).hasMessage("Expected"); } @@ -320,7 +318,7 @@ public class MockMvcTesterIntegrationTests { @Test void unresolvedExceptionWithSuccessfulRequest() { assertThatExceptionOfType(AssertionError.class) - .isThrownBy(() -> assertThat(perform(get("/greet"))).unresolvedException()) + .isThrownBy(() -> assertThat(mvc.get().uri("/greet")).unresolvedException()) .withMessage("Expected request to fail, but it succeeded"); } @@ -406,7 +404,7 @@ public class MockMvcTesterIntegrationTests { private void testAssertionFailureWithUnresolvableException(Consumer assertions) { - MvcTestResult result = perform(get("/error/1")); + MvcTestResult result = mvc.get().uri("/error/1").exchange(); assertThatExceptionOfType(AssertionError.class) .isThrownBy(() -> assertions.accept(result)) .withMessageContainingAll("Request failed unexpectedly:", @@ -418,49 +416,49 @@ public class MockMvcTesterIntegrationTests { @Test void hasForwardUrl() { - assertThat(perform(get("/persons/John"))).hasForwardedUrl("persons/index"); + assertThat(mvc.get().uri("/persons/John")).hasForwardedUrl("persons/index"); } @Test void hasRedirectUrl() { - assertThat(perform(post("/persons").param("name", "Andy"))).hasStatus(HttpStatus.FOUND) + assertThat(mvc.post().uri("/persons").param("name", "Andy")).hasStatus(HttpStatus.FOUND) .hasRedirectedUrl("/persons/Andy"); } @Test void satisfiesAllowsAdditionalAssertions() { - assertThat(this.mockMvc.perform(get("/greet"))).satisfies(result -> { + assertThat(mvc.get().uri("/greet")).satisfies(result -> { assertThat(result).isInstanceOf(MvcTestResult.class); assertThat(result).hasStatusOk(); }); } @Test - void resultMatcherCanBeReused() { - assertThat(this.mockMvc.perform(get("/greet"))).matches(status().isOk()); + void resultMatcherCanBeReused() throws Exception { + MvcTestResult result = mvc.get().uri("/greet").exchange(); + ResultMatcher matcher = mock(ResultMatcher.class); + assertThat(result).matches(matcher); + verify(matcher).match(result.getMvcResult()); } @Test void resultMatcherFailsWithDedicatedException() { + ResultMatcher matcher = result -> assertThat(result.getResponse().getStatus()) + .isEqualTo(HttpStatus.NOT_FOUND.value()); assertThatExceptionOfType(AssertionError.class) - .isThrownBy(() -> assertThat(this.mockMvc.perform(get("/greet"))) - .matches(status().isNotFound())) - .withMessageContaining("Status expected:<404> but was:<200>"); + .isThrownBy(() -> assertThat(mvc.get().uri("/greet")) + .matches(matcher)) + .withMessageContaining("expected: 404").withMessageContaining(" but was: 200"); } @Test void shouldApplyResultHandler() { // Spring RESTDocs example AtomicBoolean applied = new AtomicBoolean(); - assertThat(this.mockMvc.perform(get("/greet"))).apply(result -> applied.set(true)); + assertThat(mvc.get().uri("/greet")).apply(result -> applied.set(true)); assertThat(applied).isTrue(); } - private MvcTestResult perform(MockHttpServletRequestBuilder builder) { - return this.mockMvc.perform(builder); - } - - @Configuration @EnableWebMvc @Import({ TestController.class, PersonController.class, AsyncController.class,