diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java b/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java index c0d4546e083..5dd258ee987 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MockMvcTester.java @@ -23,13 +23,18 @@ import java.util.Map; import java.util.function.Function; import java.util.stream.StreamSupport; +import org.assertj.core.api.AssertProvider; + +import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.lang.Nullable; +import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.RequestBuilder; +import org.springframework.test.web.servlet.request.AbstractMockHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.test.web.servlet.setup.DefaultMockMvcBuilder; import org.springframework.test.web.servlet.setup.MockMvcBuilders; @@ -58,11 +63,25 @@ import org.springframework.web.context.WebApplicationContext; * MockMvcTester mvc = MockMvcTester.of(new PersonController()); * * - *

Once a tester instance is available, you can perform requests in a similar - * fashion as with {@link MockMvc}, and wrapping the result in - * {@code assertThat()} provides access to assertions. For instance: + *

Simple, single-statement assertions can be done wrapping the request + * builder in {@code assertThat()} provides access to assertions. For instance: *


  * // perform a GET on /hi and assert the response body is equal to Hello
+ * assertThat(mvc.get().uri("/hi")).hasStatusOk().hasBodyTextEqualTo("Hello");
+ * 
+ * + *

For more complex scenarios the {@linkplain MvcTestResult result} of the + * exchange can be assigned in a variable to run multiple assertions: + *


+ * // perform a POST on /save and assert the response body is empty
+ * MvcTestResult result = mvc.post().uri("/save").exchange();
+ * assertThat(result).hasStatus(HttpStatus.CREATED);
+ * assertThat(result).body().isEmpty();
+ * 
+ * + *

You can also perform requests using the static builders approach that + * {@link MockMvc} uses. For instance:


+ * // perform a GET on /hi and assert the response body is equal to Hello
  * assertThat(mvc.perform(get("/hi")))
  *         .hasStatusOk().hasBodyTextEqualTo("Hello");
  * 
@@ -74,12 +93,11 @@ import org.springframework.web.context.WebApplicationContext; * which allows you to assert that a request failed unexpectedly: *

  * // perform a GET on /boom and assert the message for the the unresolved exception
- * assertThat(mvc.perform(get("/boom")))
- *         .hasUnresolvedException())
+ * assertThat(mvc.get().uri("/boom")).hasUnresolvedException())
  *         .withMessage("Test exception");
  * 
* - *

{@link MockMvcTester} can be configured with a list of + *

{@code MockMvcTester} can be configured with a list of * {@linkplain HttpMessageConverter message converters} to allow the response * body to be deserialized, rather than asserting on the raw values. * @@ -104,8 +122,7 @@ public final class MockMvcTester { } /** - * Create a {@link MockMvcTester} instance that delegates to the given - * {@link MockMvc} instance. + * Create an instance that delegates to the given {@link MockMvc} instance. * @param mockMvc the MockMvc instance to delegate calls to */ public static MockMvcTester create(MockMvc mockMvc) { @@ -113,9 +130,9 @@ public final class MockMvcTester { } /** - * Create an {@link MockMvcTester} instance using the given, fully - * initialized (i.e., refreshed) {@link WebApplicationContext}. The - * given {@code customizations} are applied to the {@link DefaultMockMvcBuilder} + * Create an instance using the given, fully initialized (i.e., + * refreshed) {@link WebApplicationContext}. The given + * {@code customizations} are applied to the {@link DefaultMockMvcBuilder} * that ultimately creates the underlying {@link MockMvc} instance. *

If no further customization of the underlying {@link MockMvc} instance * is required, use {@link #from(WebApplicationContext)}. @@ -134,8 +151,8 @@ public final class MockMvcTester { } /** - * Shortcut to create an {@link MockMvcTester} instance using the given, - * fully initialized (i.e., refreshed) {@link WebApplicationContext}. + * Shortcut to create an instance using the given fully initialized (i.e., + * refreshed) {@link WebApplicationContext}. *

Consider using {@link #from(WebApplicationContext, Function)} if * further customization of the underlying {@link MockMvc} instance is * required. @@ -148,9 +165,8 @@ public final class MockMvcTester { } /** - * Create an {@link MockMvcTester} instance by registering one or more - * {@code @Controller} instances and configuring Spring MVC infrastructure - * programmatically. + * Create an instance by registering one or more {@code @Controller} instances + * and configuring Spring MVC infrastructure programmatically. *

This allows full control over the instantiation and initialization of * controllers and their dependencies, similar to plain unit tests while * also making it possible to test one controller at a time. @@ -170,8 +186,8 @@ public final class MockMvcTester { } /** - * Shortcut to create an {@link MockMvcTester} instance by registering one - * or more {@code @Controller} instances. + * Shortcut to create an instance by registering one or more {@code @Controller} + * instances. *

The minimum infrastructure required by the * {@link org.springframework.web.servlet.DispatcherServlet DispatcherServlet} * to serve requests with annotated controllers is created. Consider using @@ -187,8 +203,8 @@ public final class MockMvcTester { } /** - * Return a new {@link MockMvcTester} instance using the specified - * {@linkplain HttpMessageConverter message converters}. + * Return a new instance using the specified {@linkplain HttpMessageConverter + * message converters}. *

If none are specified, only basic assertions on the response body can * be performed. Consider registering a suitable JSON converter for asserting * against JSON data structures. @@ -200,8 +216,105 @@ public final class MockMvcTester { } /** - * Perform a request and return a {@link MvcTestResult result} that can be - * used with standard {@link org.assertj.core.api.Assertions AssertJ} assertions. + * Prepare an HTTP GET request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder get() { + return method(HttpMethod.GET); + } + + /** + * Prepare an HTTP HEAD request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder head() { + return method(HttpMethod.HEAD); + } + + /** + * Prepare an HTTP POST request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder post() { + return method(HttpMethod.POST); + } + + /** + * Prepare an HTTP PUT request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder put() { + return method(HttpMethod.PUT); + } + + /** + * Prepare an HTTP PATCH request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder patch() { + return method(HttpMethod.PATCH); + } + + /** + * Prepare an HTTP DELETE request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder delete() { + return method(HttpMethod.DELETE); + } + + /** + * Prepare an HTTP OPTIONS request. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder options() { + return method(HttpMethod.OPTIONS); + } + + /** + * Prepare a request for the specified {@code HttpMethod}. + *

The returned builder can be wrapped in {@code assertThat} to enable + * assertions on the result. For multi-statements assertions, use + * {@linkplain MockMvcRequestBuilder#exchange() exchange} to assign the + * result. + * @return a request builder for specifying the target URI + */ + public MockMvcRequestBuilder method(HttpMethod method) { + return new MockMvcRequestBuilder(method); + } + + /** + * Perform a request using {@link MockMvcRequestBuilders} and return a + * {@link MvcTestResult result} that can be used with standard + * {@link org.assertj.core.api.Assertions AssertJ} assertions. *

Use static methods of {@link MockMvcRequestBuilders} to prepare the * request, wrapping the invocation in {@code assertThat}. The following * asserts that a {@linkplain MockMvcRequestBuilders#get(URI) GET} request @@ -226,6 +339,8 @@ public final class MockMvcTester { * {@link org.springframework.test.web.servlet.request.MockMvcRequestBuilders} * @return an {@link MvcTestResult} to be wrapped in {@code assertThat} * @see MockMvc#perform(RequestBuilder) + * @see #get() + * @see #post() */ public MvcTestResult perform(RequestBuilder requestBuilder) { Object result = getMvcResultOrFailure(requestBuilder); @@ -259,4 +374,25 @@ public final class MockMvcTester { .findFirst().orElse(null); } + + /** + * A builder for {@link MockHttpServletRequest} that supports AssertJ. + */ + public final class MockMvcRequestBuilder extends AbstractMockHttpServletRequestBuilder + implements AssertProvider { + + private MockMvcRequestBuilder(HttpMethod httpMethod) { + super(httpMethod); + } + + public MvcTestResult exchange() { + return perform(this); + } + + @Override + public MvcTestResultAssert assertThat() { + return new MvcTestResultAssert(exchange(), MockMvcTester.this.jsonMessageConverter); + } + } + } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcHttpConnector.java b/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcHttpConnector.java index 36a49b58c2b..d54278330a8 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcHttpConnector.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/client/MockMvcHttpConnector.java @@ -54,6 +54,7 @@ import org.springframework.test.web.reactive.server.MockServerClientHttpResponse import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.RequestBuilder; +import org.springframework.test.web.servlet.request.AbstractMockHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.MockMultipartHttpServletRequestBuilder; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; @@ -134,7 +135,7 @@ public class MockMvcHttpConnector implements ClientHttpConnector { // Initialize the client request requestCallback.apply(httpRequest).block(TIMEOUT); - MockHttpServletRequestBuilder requestBuilder = + AbstractMockHttpServletRequestBuilder requestBuilder = initRequestBuilder(httpMethod, uri, httpRequest, contentRef.get()); requestBuilder.headers(httpRequest.getHeaders()); @@ -149,7 +150,7 @@ public class MockMvcHttpConnector implements ClientHttpConnector { return requestBuilder; } - private MockHttpServletRequestBuilder initRequestBuilder( + private AbstractMockHttpServletRequestBuilder initRequestBuilder( HttpMethod httpMethod, URI uri, MockClientHttpRequest httpRequest, @Nullable byte[] bytes) { String contentType = httpRequest.getHeaders().getFirst(HttpHeaders.CONTENT_TYPE); diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java new file mode 100644 index 00000000000..f8b0a54f3dc --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/AbstractMockHttpServletRequestBuilder.java @@ -0,0 +1,948 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.request; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.security.Principal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import jakarta.servlet.ServletContext; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpSession; + +import org.springframework.beans.Mergeable; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.lang.Nullable; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.WebApplicationContextUtils; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.FlashMap; +import org.springframework.web.servlet.FlashMapManager; +import org.springframework.web.servlet.support.SessionFlashMapManager; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; +import org.springframework.web.util.UrlPathHelper; + +/** + * Base builder for {@link MockHttpServletRequest} required as input to + * perform requests in {@link MockMvc}. + * + * @author Rossen Stoyanchev + * @author Juergen Hoeller + * @author Arjen Poutsma + * @author Sam Brannen + * @author Kamill Sokol + * @since 6.2 + * @param a self reference to the builder type + */ +public abstract class AbstractMockHttpServletRequestBuilder> + implements ConfigurableSmartRequestBuilder, Mergeable { + + private final HttpMethod method; + + @Nullable + private URI uri; + + private String contextPath = ""; + + private String servletPath = ""; + + @Nullable + private String pathInfo = ""; + + @Nullable + private Boolean secure; + + @Nullable + private Principal principal; + + @Nullable + private MockHttpSession session; + + @Nullable + private String remoteAddress; + + @Nullable + private String characterEncoding; + + @Nullable + private byte[] content; + + @Nullable + private String contentType; + + private final MultiValueMap headers = new LinkedMultiValueMap<>(); + + private final MultiValueMap parameters = new LinkedMultiValueMap<>(); + + private final MultiValueMap queryParams = new LinkedMultiValueMap<>(); + + private final MultiValueMap formFields = new LinkedMultiValueMap<>(); + + private final List cookies = new ArrayList<>(); + + private final List locales = new ArrayList<>(); + + private final Map requestAttributes = new LinkedHashMap<>(); + + private final Map sessionAttributes = new LinkedHashMap<>(); + + private final Map flashAttributes = new LinkedHashMap<>(); + + private final List postProcessors = new ArrayList<>(); + + + /** + * Create a new instance using the specified {@link HttpMethod}. + * @param httpMethod the HTTP method (GET, POST, etc.) + */ + protected AbstractMockHttpServletRequestBuilder(HttpMethod httpMethod) { + Assert.notNull(httpMethod, "'httpMethod' is required"); + this.method = httpMethod; + } + + @SuppressWarnings("unchecked") + protected B self() { + return (B) this; + } + + /** + * Specify the URI using an absolute, fully constructed {@link java.net.URI}. + */ + public B uri(URI uri) { + this.uri = uri; + return self(); + } + + /** + * Specify the URI for the request using a URI template and URI variables. + */ + public B uri(String uriTemplate, Object... uriVariables) { + return uri(initUri(uriTemplate, uriVariables)); + } + + private static URI initUri(String uri, Object[] vars) { + Assert.notNull(uri, "'uri' must not be null"); + Assert.isTrue(uri.isEmpty() || uri.startsWith("/") || uri.startsWith("http://") || uri.startsWith("https://"), + () -> "'uri' should start with a path or be a complete HTTP URI: " + uri); + String uriString = (uri.isEmpty() ? "/" : uri); + return UriComponentsBuilder.fromUriString(uriString).buildAndExpand(vars).encode().toUri(); + } + + /** + * Specify the portion of the requestURI that represents the context path. + * The context path, if specified, must match to the start of the request URI. + *

In most cases, tests can be written by omitting the context path from + * the requestURI. This is because most applications don't actually depend + * on the name under which they're deployed. If specified here, the context + * path must start with a "/" and must not end with a "/". + * @see jakarta.servlet.http.HttpServletRequest#getContextPath() + */ + public B contextPath(String contextPath) { + if (StringUtils.hasText(contextPath)) { + Assert.isTrue(contextPath.startsWith("/"), "Context path must start with a '/'"); + Assert.isTrue(!contextPath.endsWith("/"), "Context path must not end with a '/'"); + } + this.contextPath = contextPath; + return self(); + } + + /** + * Specify the portion of the requestURI that represents the path to which + * the Servlet is mapped. This is typically a portion of the requestURI + * after the context path. + *

In most cases, tests can be written by omitting the servlet path from + * the requestURI. This is because most applications don't actually depend + * on the prefix to which a servlet is mapped. For example if a Servlet is + * mapped to {@code "/main/*"}, tests can be written with the requestURI + * {@code "/accounts/1"} as opposed to {@code "/main/accounts/1"}. + * If specified here, the servletPath must start with a "/" and must not + * end with a "/". + * @see jakarta.servlet.http.HttpServletRequest#getServletPath() + */ + public B servletPath(String servletPath) { + if (StringUtils.hasText(servletPath)) { + Assert.isTrue(servletPath.startsWith("/"), "Servlet path must start with a '/'"); + Assert.isTrue(!servletPath.endsWith("/"), "Servlet path must not end with a '/'"); + } + this.servletPath = servletPath; + return self(); + } + + /** + * Specify the portion of the requestURI that represents the pathInfo. + *

If left unspecified (recommended), the pathInfo will be automatically derived + * by removing the contextPath and the servletPath from the requestURI and using any + * remaining part. If specified here, the pathInfo must start with a "/". + *

If specified, the pathInfo will be used as-is. + * @see jakarta.servlet.http.HttpServletRequest#getPathInfo() + */ + public B pathInfo(@Nullable String pathInfo) { + if (StringUtils.hasText(pathInfo)) { + Assert.isTrue(pathInfo.startsWith("/"), "Path info must start with a '/'"); + } + this.pathInfo = pathInfo; + return self(); + } + + /** + * Set the secure property of the {@link ServletRequest} indicating use of a + * secure channel, such as HTTPS. + * @param secure whether the request is using a secure channel + */ + public B secure(boolean secure){ + this.secure = secure; + return self(); + } + + /** + * Set the character encoding of the request. + * @param encoding the character encoding + * @since 5.3.10 + * @see StandardCharsets + * @see #characterEncoding(String) + */ + public B characterEncoding(Charset encoding) { + return characterEncoding(encoding.name()); + } + + /** + * Set the character encoding of the request. + * @param encoding the character encoding + */ + public B characterEncoding(String encoding) { + this.characterEncoding = encoding; + return self(); + } + + /** + * Set the request body. + *

If content is provided and {@link #contentType(MediaType)} is set to + * {@code application/x-www-form-urlencoded}, the content will be parsed + * and used to populate the {@link #param(String, String...) request + * parameters} map. + * @param content the body content + */ + public B content(byte[] content) { + this.content = content; + return self(); + } + + /** + * Set the request body as a UTF-8 String. + *

If content is provided and {@link #contentType(MediaType)} is set to + * {@code application/x-www-form-urlencoded}, the content will be parsed + * and used to populate the {@link #param(String, String...) request + * parameters} map. + * @param content the body content + */ + public B content(String content) { + this.content = content.getBytes(StandardCharsets.UTF_8); + return self(); + } + + /** + * Set the 'Content-Type' header of the request. + *

If content is provided and {@code contentType} is set to + * {@code application/x-www-form-urlencoded}, the content will be parsed + * and used to populate the {@link #param(String, String...) request + * parameters} map. + * @param contentType the content type + */ + public B contentType(MediaType contentType) { + Assert.notNull(contentType, "'contentType' must not be null"); + this.contentType = contentType.toString(); + return self(); + } + + /** + * Set the 'Content-Type' header of the request as a raw String value, + * possibly not even well-formed (for testing purposes). + * @param contentType the content type + * @since 4.1.2 + */ + public B contentType(String contentType) { + Assert.notNull(contentType, "'contentType' must not be null"); + this.contentType = contentType; + return self(); + } + + /** + * Set the 'Accept' header to the given media type(s). + * @param mediaTypes one or more media types + */ + public B accept(MediaType... mediaTypes) { + Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); + this.headers.set("Accept", MediaType.toString(Arrays.asList(mediaTypes))); + return self(); + } + + /** + * Set the {@code Accept} header using raw String values, possibly not even + * well-formed (for testing purposes). + * @param mediaTypes one or more media types; internally joined as + * comma-separated String + */ + public B accept(String... mediaTypes) { + Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); + this.headers.set("Accept", String.join(", ", mediaTypes)); + return self(); + } + + /** + * Add a header to the request. Values are always added. + * @param name the header name + * @param values one or more header values + */ + public B header(String name, Object... values) { + addToMultiValueMap(this.headers, name, values); + return self(); + } + + /** + * Add all headers to the request. Values are always added. + * @param httpHeaders the headers and values to add + */ + public B headers(HttpHeaders httpHeaders) { + httpHeaders.forEach(this.headers::addAll); + return self(); + } + + /** + * Add a request parameter to {@link MockHttpServletRequest#getParameterMap()}. + *

In the Servlet API, a request parameter may be parsed from the query + * string and/or from the body of an {@code application/x-www-form-urlencoded} + * request. This method simply adds to the request parameter map. You may + * also use add Servlet request parameters by specifying the query or form + * data through one of the following: + *

+ * @param name the parameter name + * @param values one or more values + */ + public B param(String name, String... values) { + addToMultiValueMap(this.parameters, name, values); + return self(); + } + + /** + * Variant of {@link #param(String, String...)} with a {@link MultiValueMap}. + * @param params the parameters to add + * @since 4.2.4 + */ + public B params(MultiValueMap params) { + params.forEach((name, values) -> { + for (String value : values) { + this.parameters.add(name, value); + } + }); + return self(); + } + + /** + * Append to the query string and also add to the + * {@link #param(String, String...) request parameters} map. The parameter + * name and value are encoded when they are added to the query string. + * @param name the parameter name + * @param values one or more values + * @since 5.2.2 + */ + public B queryParam(String name, String... values) { + param(name, values); + this.queryParams.addAll(name, Arrays.asList(values)); + return self(); + } + + /** + * Append to the query string and also add to the + * {@link #params(MultiValueMap) request parameters} map. The parameter + * name and value are encoded when they are added to the query string. + * @param params the parameters to add + * @since 5.2.2 + */ + public B queryParams(MultiValueMap params) { + params(params); + this.queryParams.addAll(params); + return self(); + } + + /** + * Append the given value(s) to the given form field and also add them to the + * {@linkplain #param(String, String...) request parameters} map. + * @param name the field name + * @param values one or more values + * @since 6.1.7 + */ + public B formField(String name, String... values) { + param(name, values); + this.formFields.addAll(name, Arrays.asList(values)); + return self(); + } + + /** + * Variant of {@link #formField(String, String...)} with a {@link MultiValueMap}. + * @param formFields the form fields to add + * @since 6.1.7 + */ + public B formFields(MultiValueMap formFields) { + params(formFields); + this.formFields.addAll(formFields); + return self(); + } + + /** + * Add the given cookies to the request. Cookies are always added. + * @param cookies the cookies to add + */ + public B cookie(Cookie... cookies) { + Assert.notEmpty(cookies, "'cookies' must not be empty"); + this.cookies.addAll(Arrays.asList(cookies)); + return self(); + } + + /** + * Add the specified locales as preferred request locales. + * @param locales the locales to add + * @since 4.3.6 + * @see #locale(Locale) + */ + public B locale(Locale... locales) { + Assert.notEmpty(locales, "'locales' must not be empty"); + this.locales.addAll(Arrays.asList(locales)); + return self(); + } + + /** + * Set the locale of the request, overriding any previous locales. + * @param locale the locale, or {@code null} to reset it + * @see #locale(Locale...) + */ + public B locale(@Nullable Locale locale) { + this.locales.clear(); + if (locale != null) { + this.locales.add(locale); + } + return self(); + } + + /** + * Set a request attribute. + * @param name the attribute name + * @param value the attribute value + */ + public B requestAttr(String name, Object value) { + addToMap(this.requestAttributes, name, value); + return self(); + } + + /** + * Set a session attribute. + * @param name the session attribute name + * @param value the session attribute value + */ + public B sessionAttr(String name, Object value) { + addToMap(this.sessionAttributes, name, value); + return self(); + } + + /** + * Set session attributes. + * @param sessionAttributes the session attributes + */ + public B sessionAttrs(Map sessionAttributes) { + Assert.notEmpty(sessionAttributes, "'sessionAttributes' must not be empty"); + sessionAttributes.forEach(this::sessionAttr); + return self(); + } + + /** + * Set an "input" flash attribute. + * @param name the flash attribute name + * @param value the flash attribute value + */ + public B flashAttr(String name, Object value) { + addToMap(this.flashAttributes, name, value); + return self(); + } + + /** + * Set flash attributes. + * @param flashAttributes the flash attributes + */ + public B flashAttrs(Map flashAttributes) { + Assert.notEmpty(flashAttributes, "'flashAttributes' must not be empty"); + flashAttributes.forEach(this::flashAttr); + return self(); + } + + /** + * Set the HTTP session to use, possibly re-used across requests. + *

Individual attributes provided via {@link #sessionAttr(String, Object)} + * override the content of the session provided here. + * @param session the HTTP session + */ + public B session(MockHttpSession session) { + Assert.notNull(session, "'session' must not be null"); + this.session = session; + return self(); + } + + /** + * Set the principal of the request. + * @param principal the principal + */ + public B principal(Principal principal) { + Assert.notNull(principal, "'principal' must not be null"); + this.principal = principal; + return self(); + } + + /** + * Set the remote address of the request. + * @param remoteAddress the remote address (IP) + * @since 6.0.10 + */ + public B remoteAddress(String remoteAddress) { + Assert.hasText(remoteAddress, "'remoteAddress' must not be null or blank"); + this.remoteAddress = remoteAddress; + return self(); + } + + /** + * An extension point for further initialization of {@link MockHttpServletRequest} + * in ways not built directly into the {@code MockHttpServletRequestBuilder}. + * Implementation of this interface can have builder-style methods themselves + * and be made accessible through static factory methods. + * @param postProcessor a post-processor to add + */ + @Override + public B with(RequestPostProcessor postProcessor) { + Assert.notNull(postProcessor, "postProcessor is required"); + this.postProcessors.add(postProcessor); + return self(); + } + + + /** + * {@inheritDoc} + * @return always returns {@code true}. + */ + @Override + public boolean isMergeEnabled() { + return true; + } + + /** + * Merges the properties of the "parent" RequestBuilder accepting values + * only if not already set in "this" instance. + * @param parent the parent {@code RequestBuilder} to inherit properties from + * @return the result of the merge + */ + @Override + public Object merge(@Nullable Object parent) { + if (parent == null) { + return this; + } + if (!(parent instanceof AbstractMockHttpServletRequestBuilder parentBuilder)) { + throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); + } + if (!StringUtils.hasText(this.contextPath)) { + this.contextPath = parentBuilder.contextPath; + } + if (!StringUtils.hasText(this.servletPath)) { + this.servletPath = parentBuilder.servletPath; + } + if ("".equals(this.pathInfo)) { + this.pathInfo = parentBuilder.pathInfo; + } + + if (this.secure == null) { + this.secure = parentBuilder.secure; + } + if (this.principal == null) { + this.principal = parentBuilder.principal; + } + if (this.session == null) { + this.session = parentBuilder.session; + } + if (this.remoteAddress == null) { + this.remoteAddress = parentBuilder.remoteAddress; + } + + if (this.characterEncoding == null) { + this.characterEncoding = parentBuilder.characterEncoding; + } + if (this.content == null) { + this.content = parentBuilder.content; + } + if (this.contentType == null) { + this.contentType = parentBuilder.contentType; + } + + for (Map.Entry> entry : parentBuilder.headers.entrySet()) { + String headerName = entry.getKey(); + if (!this.headers.containsKey(headerName)) { + this.headers.put(headerName, entry.getValue()); + } + } + for (Map.Entry> entry : parentBuilder.parameters.entrySet()) { + String paramName = entry.getKey(); + if (!this.parameters.containsKey(paramName)) { + this.parameters.put(paramName, entry.getValue()); + } + } + for (Map.Entry> entry : parentBuilder.queryParams.entrySet()) { + String paramName = entry.getKey(); + if (!this.queryParams.containsKey(paramName)) { + this.queryParams.put(paramName, entry.getValue()); + } + } + for (Map.Entry> entry : parentBuilder.formFields.entrySet()) { + String paramName = entry.getKey(); + if (!this.formFields.containsKey(paramName)) { + this.formFields.put(paramName, entry.getValue()); + } + } + for (Cookie cookie : parentBuilder.cookies) { + if (!containsCookie(cookie)) { + this.cookies.add(cookie); + } + } + for (Locale locale : parentBuilder.locales) { + if (!this.locales.contains(locale)) { + this.locales.add(locale); + } + } + + for (Map.Entry entry : parentBuilder.requestAttributes.entrySet()) { + String attributeName = entry.getKey(); + if (!this.requestAttributes.containsKey(attributeName)) { + this.requestAttributes.put(attributeName, entry.getValue()); + } + } + for (Map.Entry entry : parentBuilder.sessionAttributes.entrySet()) { + String attributeName = entry.getKey(); + if (!this.sessionAttributes.containsKey(attributeName)) { + this.sessionAttributes.put(attributeName, entry.getValue()); + } + } + for (Map.Entry entry : parentBuilder.flashAttributes.entrySet()) { + String attributeName = entry.getKey(); + if (!this.flashAttributes.containsKey(attributeName)) { + this.flashAttributes.put(attributeName, entry.getValue()); + } + } + + this.postProcessors.addAll(0, parentBuilder.postProcessors); + + return this; + } + + private boolean containsCookie(Cookie cookie) { + for (Cookie cookieToCheck : this.cookies) { + if (ObjectUtils.nullSafeEquals(cookieToCheck.getName(), cookie.getName())) { + return true; + } + } + return false; + } + + /** + * Build a {@link MockHttpServletRequest}. + */ + @Override + public final MockHttpServletRequest buildRequest(ServletContext servletContext) { + Assert.notNull(this.uri, "'uri' is required"); + MockHttpServletRequest request = createServletRequest(servletContext); + + request.setAsyncSupported(true); + request.setMethod(this.method.name()); + + String requestUri = this.uri.getRawPath(); + request.setRequestURI(requestUri); + + if (this.uri.getScheme() != null) { + request.setScheme(this.uri.getScheme()); + } + if (this.uri.getHost() != null) { + request.setServerName(this.uri.getHost()); + } + if (this.uri.getPort() != -1) { + request.setServerPort(this.uri.getPort()); + } + + updatePathRequestProperties(request, requestUri); + + if (this.secure != null) { + request.setSecure(this.secure); + } + if (this.principal != null) { + request.setUserPrincipal(this.principal); + } + if (this.remoteAddress != null) { + request.setRemoteAddr(this.remoteAddress); + } + if (this.session != null) { + request.setSession(this.session); + } + + request.setCharacterEncoding(this.characterEncoding); + request.setContent(this.content); + request.setContentType(this.contentType); + + this.headers.forEach((name, values) -> { + for (Object value : values) { + request.addHeader(name, value); + } + }); + + if (!ObjectUtils.isEmpty(this.content) && + !this.headers.containsKey(HttpHeaders.CONTENT_LENGTH) && + !this.headers.containsKey(HttpHeaders.TRANSFER_ENCODING)) { + + request.addHeader(HttpHeaders.CONTENT_LENGTH, this.content.length); + } + + String query = this.uri.getRawQuery(); + if (!this.queryParams.isEmpty()) { + String str = UriComponentsBuilder.newInstance().queryParams(this.queryParams).build().encode().getQuery(); + query = StringUtils.hasLength(query) ? (query + "&" + str) : str; + } + if (query != null) { + request.setQueryString(query); + } + addRequestParams(request, UriComponentsBuilder.fromUri(this.uri).build().getQueryParams()); + + this.parameters.forEach((name, values) -> { + for (String value : values) { + request.addParameter(name, value); + } + }); + + if (!this.formFields.isEmpty()) { + if (this.content != null && this.content.length > 0) { + throw new IllegalStateException("Could not write form data with an existing body"); + } + Charset charset = (this.characterEncoding != null ? + Charset.forName(this.characterEncoding) : StandardCharsets.UTF_8); + MediaType mediaType = (request.getContentType() != null ? + MediaType.parseMediaType(request.getContentType()) : + new MediaType(MediaType.APPLICATION_FORM_URLENCODED, charset)); + if (!mediaType.isCompatibleWith(MediaType.APPLICATION_FORM_URLENCODED)) { + throw new IllegalStateException("Invalid content type: '" + mediaType + + "' is not compatible with '" + MediaType.APPLICATION_FORM_URLENCODED + "'"); + } + request.setContent(writeFormData(mediaType, charset)); + if (request.getContentType() == null) { + request.setContentType(mediaType.toString()); + } + } + if (this.content != null && this.content.length > 0) { + String requestContentType = request.getContentType(); + if (requestContentType != null) { + try { + MediaType mediaType = MediaType.parseMediaType(requestContentType); + if (MediaType.APPLICATION_FORM_URLENCODED.includes(mediaType)) { + addRequestParams(request, parseFormData(mediaType)); + } + } + catch (Exception ex) { + // Must be invalid, ignore + } + } + } + + if (!ObjectUtils.isEmpty(this.cookies)) { + request.setCookies(this.cookies.toArray(new Cookie[0])); + } + if (!ObjectUtils.isEmpty(this.locales)) { + request.setPreferredLocales(this.locales); + } + + this.requestAttributes.forEach(request::setAttribute); + this.sessionAttributes.forEach((name, attribute) -> { + HttpSession session = request.getSession(); + Assert.state(session != null, "No HttpSession"); + session.setAttribute(name, attribute); + }); + + FlashMap flashMap = new FlashMap(); + flashMap.putAll(this.flashAttributes); + FlashMapManager flashMapManager = getFlashMapManager(request); + flashMapManager.saveOutputFlashMap(flashMap, request, new MockHttpServletResponse()); + + return request; + } + + /** + * Create a new {@link MockHttpServletRequest} based on the supplied + * {@code ServletContext}. + *

Can be overridden in subclasses. + */ + protected MockHttpServletRequest createServletRequest(ServletContext servletContext) { + return new MockHttpServletRequest(servletContext); + } + + /** + * Update the contextPath, servletPath, and pathInfo of the request. + */ + private void updatePathRequestProperties(MockHttpServletRequest request, String requestUri) { + if (!requestUri.startsWith(this.contextPath)) { + throw new IllegalArgumentException( + "Request URI [" + requestUri + "] does not start with context path [" + this.contextPath + "]"); + } + request.setContextPath(this.contextPath); + request.setServletPath(this.servletPath); + + if ("".equals(this.pathInfo)) { + if (!requestUri.startsWith(this.contextPath + this.servletPath)) { + throw new IllegalArgumentException( + "Invalid servlet path [" + this.servletPath + "] for request URI [" + requestUri + "]"); + } + String extraPath = requestUri.substring(this.contextPath.length() + this.servletPath.length()); + this.pathInfo = (StringUtils.hasText(extraPath) ? + UrlPathHelper.defaultInstance.decodeRequestString(request, extraPath) : null); + } + request.setPathInfo(this.pathInfo); + } + + private void addRequestParams(MockHttpServletRequest request, MultiValueMap map) { + map.forEach((key, values) -> values.forEach(value -> { + value = (value != null ? UriUtils.decode(value, StandardCharsets.UTF_8) : null); + request.addParameter(UriUtils.decode(key, StandardCharsets.UTF_8), value); + })); + } + + private byte[] writeFormData(MediaType mediaType, Charset charset) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + HttpOutputMessage message = new HttpOutputMessage() { + @Override + public OutputStream getBody() { + return out; + } + + @Override + public HttpHeaders getHeaders() { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(mediaType); + return headers; + } + }; + try { + FormHttpMessageConverter messageConverter = new FormHttpMessageConverter(); + messageConverter.setCharset(charset); + messageConverter.write(this.formFields, mediaType, message); + return out.toByteArray(); + } + catch (IOException ex) { + throw new IllegalStateException("Failed to write form data to request body", ex); + } + } + + @SuppressWarnings("unchecked") + private MultiValueMap parseFormData(MediaType mediaType) { + HttpInputMessage message = new HttpInputMessage() { + @Override + public InputStream getBody() { + byte[] bodyContent = AbstractMockHttpServletRequestBuilder.this.content; + return (bodyContent != null ? new ByteArrayInputStream(bodyContent) : InputStream.nullInputStream()); + } + @Override + public HttpHeaders getHeaders() { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(mediaType); + return headers; + } + }; + + try { + return (MultiValueMap) new FormHttpMessageConverter().read(null, message); + } + catch (IOException ex) { + throw new IllegalStateException("Failed to parse form data in request body", ex); + } + } + + private FlashMapManager getFlashMapManager(MockHttpServletRequest request) { + FlashMapManager flashMapManager = null; + try { + ServletContext servletContext = request.getServletContext(); + WebApplicationContext wac = WebApplicationContextUtils.getRequiredWebApplicationContext(servletContext); + flashMapManager = wac.getBean(DispatcherServlet.FLASH_MAP_MANAGER_BEAN_NAME, FlashMapManager.class); + } + catch (IllegalStateException | NoSuchBeanDefinitionException ex) { + // ignore + } + return (flashMapManager != null ? flashMapManager : new SessionFlashMapManager()); + } + + @Override + public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { + for (RequestPostProcessor postProcessor : this.postProcessors) { + request = postProcessor.postProcessRequest(request); + } + return request; + } + + + private static void addToMap(Map map, String name, Object value) { + Assert.hasLength(name, "'name' must not be empty"); + Assert.notNull(value, "'value' must not be null"); + map.put(name, value); + } + + private static void addToMultiValueMap(MultiValueMap map, String name, T[] values) { + Assert.hasLength(name, "'name' must not be empty"); + Assert.notEmpty(values, "'values' must not be empty"); + for (T value : values) { + map.add(name, value); + } + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilder.java index 77d324413d1..d7dc20cff96 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockHttpServletRequestBuilder.java @@ -16,54 +16,12 @@ package org.springframework.test.web.servlet.request; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; import java.net.URI; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; -import java.security.Principal; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import jakarta.servlet.ServletContext; -import jakarta.servlet.ServletRequest; -import jakarta.servlet.http.Cookie; -import jakarta.servlet.http.HttpSession; - -import org.springframework.beans.Mergeable; -import org.springframework.beans.factory.NoSuchBeanDefinitionException; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpInputMessage; import org.springframework.http.HttpMethod; -import org.springframework.http.HttpOutputMessage; -import org.springframework.http.MediaType; -import org.springframework.http.converter.FormHttpMessageConverter; -import org.springframework.lang.Nullable; import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.mock.web.MockHttpSession; import org.springframework.test.web.servlet.MockMvc; import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; -import org.springframework.util.ObjectUtils; -import org.springframework.util.StringUtils; -import org.springframework.web.context.WebApplicationContext; -import org.springframework.web.context.support.WebApplicationContextUtils; -import org.springframework.web.servlet.DispatcherServlet; -import org.springframework.web.servlet.FlashMap; -import org.springframework.web.servlet.FlashMapManager; -import org.springframework.web.servlet.support.SessionFlashMapManager; -import org.springframework.web.util.UriComponentsBuilder; -import org.springframework.web.util.UriUtils; -import org.springframework.web.util.UrlPathHelper; /** * Default builder for {@link MockHttpServletRequest} required as input to @@ -84,60 +42,7 @@ import org.springframework.web.util.UrlPathHelper; * @since 3.2 */ public class MockHttpServletRequestBuilder - implements ConfigurableSmartRequestBuilder, Mergeable { - - private final HttpMethod method; - - private final URI uri; - - private String contextPath = ""; - - private String servletPath = ""; - - @Nullable - private String pathInfo = ""; - - @Nullable - private Boolean secure; - - @Nullable - private Principal principal; - - @Nullable - private MockHttpSession session; - - @Nullable - private String remoteAddress; - - @Nullable - private String characterEncoding; - - @Nullable - private byte[] content; - - @Nullable - private String contentType; - - private final MultiValueMap headers = new LinkedMultiValueMap<>(); - - private final MultiValueMap parameters = new LinkedMultiValueMap<>(); - - private final MultiValueMap queryParams = new LinkedMultiValueMap<>(); - - private final MultiValueMap formFields = new LinkedMultiValueMap<>(); - - private final List cookies = new ArrayList<>(); - - private final List locales = new ArrayList<>(); - - private final Map requestAttributes = new LinkedHashMap<>(); - - private final Map sessionAttributes = new LinkedHashMap<>(); - - private final Map flashAttributes = new LinkedHashMap<>(); - - private final List postProcessors = new ArrayList<>(); - + extends AbstractMockHttpServletRequestBuilder { /** * Package private constructor. To get an instance, use static factory @@ -150,15 +55,8 @@ public class MockHttpServletRequestBuilder * @param uriVariables zero or more URI variables */ MockHttpServletRequestBuilder(HttpMethod httpMethod, String uriTemplate, Object... uriVariables) { - this(httpMethod, initUri(uriTemplate, uriVariables)); - } - - private static URI initUri(String uri, Object[] vars) { - Assert.notNull(uri, "'uri' must not be null"); - Assert.isTrue(uri.isEmpty() || uri.startsWith("/") || uri.startsWith("http://") || uri.startsWith("https://"), - () -> "'uri' should start with a path or be a complete HTTP URI: " + uri); - String uriString = (uri.isEmpty() ? "/" : uri); - return UriComponentsBuilder.fromUriString(uriString).buildAndExpand(vars).encode().toUri(); + super(httpMethod); + super.uri(uriTemplate, uriVariables); } /** @@ -169,784 +67,9 @@ public class MockHttpServletRequestBuilder * @since 4.0.3 */ MockHttpServletRequestBuilder(HttpMethod httpMethod, URI uri) { - Assert.notNull(httpMethod, "'httpMethod' is required"); + super(httpMethod); Assert.notNull(uri, "'uri' is required"); - this.method = httpMethod; - this.uri = uri; - } - - - /** - * Specify the portion of the requestURI that represents the context path. - * The context path, if specified, must match to the start of the request URI. - *

In most cases, tests can be written by omitting the context path from - * the requestURI. This is because most applications don't actually depend - * on the name under which they're deployed. If specified here, the context - * path must start with a "/" and must not end with a "/". - * @see jakarta.servlet.http.HttpServletRequest#getContextPath() - */ - public MockHttpServletRequestBuilder contextPath(String contextPath) { - if (StringUtils.hasText(contextPath)) { - Assert.isTrue(contextPath.startsWith("/"), "Context path must start with a '/'"); - Assert.isTrue(!contextPath.endsWith("/"), "Context path must not end with a '/'"); - } - this.contextPath = contextPath; - return this; - } - - /** - * Specify the portion of the requestURI that represents the path to which - * the Servlet is mapped. This is typically a portion of the requestURI - * after the context path. - *

In most cases, tests can be written by omitting the servlet path from - * the requestURI. This is because most applications don't actually depend - * on the prefix to which a servlet is mapped. For example if a Servlet is - * mapped to {@code "/main/*"}, tests can be written with the requestURI - * {@code "/accounts/1"} as opposed to {@code "/main/accounts/1"}. - * If specified here, the servletPath must start with a "/" and must not - * end with a "/". - * @see jakarta.servlet.http.HttpServletRequest#getServletPath() - */ - public MockHttpServletRequestBuilder servletPath(String servletPath) { - if (StringUtils.hasText(servletPath)) { - Assert.isTrue(servletPath.startsWith("/"), "Servlet path must start with a '/'"); - Assert.isTrue(!servletPath.endsWith("/"), "Servlet path must not end with a '/'"); - } - this.servletPath = servletPath; - return this; - } - - /** - * Specify the portion of the requestURI that represents the pathInfo. - *

If left unspecified (recommended), the pathInfo will be automatically derived - * by removing the contextPath and the servletPath from the requestURI and using any - * remaining part. If specified here, the pathInfo must start with a "/". - *

If specified, the pathInfo will be used as-is. - * @see jakarta.servlet.http.HttpServletRequest#getPathInfo() - */ - public MockHttpServletRequestBuilder pathInfo(@Nullable String pathInfo) { - if (StringUtils.hasText(pathInfo)) { - Assert.isTrue(pathInfo.startsWith("/"), "Path info must start with a '/'"); - } - this.pathInfo = pathInfo; - return this; - } - - /** - * Set the secure property of the {@link ServletRequest} indicating use of a - * secure channel, such as HTTPS. - * @param secure whether the request is using a secure channel - */ - public MockHttpServletRequestBuilder secure(boolean secure){ - this.secure = secure; - return this; - } - - /** - * Set the character encoding of the request. - * @param encoding the character encoding - * @since 5.3.10 - * @see StandardCharsets - * @see #characterEncoding(String) - */ - public MockHttpServletRequestBuilder characterEncoding(Charset encoding) { - return this.characterEncoding(encoding.name()); - } - - /** - * Set the character encoding of the request. - * @param encoding the character encoding - */ - public MockHttpServletRequestBuilder characterEncoding(String encoding) { - this.characterEncoding = encoding; - return this; - } - - /** - * Set the request body. - *

If content is provided and {@link #contentType(MediaType)} is set to - * {@code application/x-www-form-urlencoded}, the content will be parsed - * and used to populate the {@link #param(String, String...) request - * parameters} map. - * @param content the body content - */ - public MockHttpServletRequestBuilder content(byte[] content) { - this.content = content; - return this; - } - - /** - * Set the request body as a UTF-8 String. - *

If content is provided and {@link #contentType(MediaType)} is set to - * {@code application/x-www-form-urlencoded}, the content will be parsed - * and used to populate the {@link #param(String, String...) request - * parameters} map. - * @param content the body content - */ - public MockHttpServletRequestBuilder content(String content) { - this.content = content.getBytes(StandardCharsets.UTF_8); - return this; - } - - /** - * Set the 'Content-Type' header of the request. - *

If content is provided and {@code contentType} is set to - * {@code application/x-www-form-urlencoded}, the content will be parsed - * and used to populate the {@link #param(String, String...) request - * parameters} map. - * @param contentType the content type - */ - public MockHttpServletRequestBuilder contentType(MediaType contentType) { - Assert.notNull(contentType, "'contentType' must not be null"); - this.contentType = contentType.toString(); - return this; - } - - /** - * Set the 'Content-Type' header of the request as a raw String value, - * possibly not even well-formed (for testing purposes). - * @param contentType the content type - * @since 4.1.2 - */ - public MockHttpServletRequestBuilder contentType(String contentType) { - Assert.notNull(contentType, "'contentType' must not be null"); - this.contentType = contentType; - return this; - } - - /** - * Set the 'Accept' header to the given media type(s). - * @param mediaTypes one or more media types - */ - public MockHttpServletRequestBuilder accept(MediaType... mediaTypes) { - Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); - this.headers.set("Accept", MediaType.toString(Arrays.asList(mediaTypes))); - return this; - } - - /** - * Set the {@code Accept} header using raw String values, possibly not even - * well-formed (for testing purposes). - * @param mediaTypes one or more media types; internally joined as - * comma-separated String - */ - public MockHttpServletRequestBuilder accept(String... mediaTypes) { - Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); - this.headers.set("Accept", String.join(", ", mediaTypes)); - return this; - } - - /** - * Add a header to the request. Values are always added. - * @param name the header name - * @param values one or more header values - */ - public MockHttpServletRequestBuilder header(String name, Object... values) { - addToMultiValueMap(this.headers, name, values); - return this; - } - - /** - * Add all headers to the request. Values are always added. - * @param httpHeaders the headers and values to add - */ - public MockHttpServletRequestBuilder headers(HttpHeaders httpHeaders) { - httpHeaders.forEach(this.headers::addAll); - return this; - } - - /** - * Add a request parameter to {@link MockHttpServletRequest#getParameterMap()}. - *

In the Servlet API, a request parameter may be parsed from the query - * string and/or from the body of an {@code application/x-www-form-urlencoded} - * request. This method simply adds to the request parameter map. You may - * also use add Servlet request parameters by specifying the query or form - * data through one of the following: - *

    - *
  • Supply a URI with a query to {@link MockMvcRequestBuilders}. - *
  • Add query params via {@link #queryParam} or {@link #queryParams}. - *
  • Provide {@link #content} with {@link #contentType} - * {@code application/x-www-form-urlencoded}. - *
- * @param name the parameter name - * @param values one or more values - */ - public MockHttpServletRequestBuilder param(String name, String... values) { - addToMultiValueMap(this.parameters, name, values); - return this; - } - - /** - * Variant of {@link #param(String, String...)} with a {@link MultiValueMap}. - * @param params the parameters to add - * @since 4.2.4 - */ - public MockHttpServletRequestBuilder params(MultiValueMap params) { - params.forEach((name, values) -> { - for (String value : values) { - this.parameters.add(name, value); - } - }); - return this; - } - - /** - * Append to the query string and also add to the - * {@link #param(String, String...) request parameters} map. The parameter - * name and value are encoded when they are added to the query string. - * @param name the parameter name - * @param values one or more values - * @since 5.2.2 - */ - public MockHttpServletRequestBuilder queryParam(String name, String... values) { - param(name, values); - this.queryParams.addAll(name, Arrays.asList(values)); - return this; - } - - /** - * Append to the query string and also add to the - * {@link #params(MultiValueMap) request parameters} map. The parameter - * name and value are encoded when they are added to the query string. - * @param params the parameters to add - * @since 5.2.2 - */ - public MockHttpServletRequestBuilder queryParams(MultiValueMap params) { - params(params); - this.queryParams.addAll(params); - return this; - } - - /** - * Append the given value(s) to the given form field and also add them to the - * {@linkplain #param(String, String...) request parameters} map. - * @param name the field name - * @param values one or more values - * @since 6.1.7 - */ - public MockHttpServletRequestBuilder formField(String name, String... values) { - param(name, values); - this.formFields.addAll(name, Arrays.asList(values)); - return this; - } - - /** - * Variant of {@link #formField(String, String...)} with a {@link MultiValueMap}. - * @param formFields the form fields to add - * @since 6.1.7 - */ - public MockHttpServletRequestBuilder formFields(MultiValueMap formFields) { - params(formFields); - this.formFields.addAll(formFields); - return this; - } - - /** - * Add the given cookies to the request. Cookies are always added. - * @param cookies the cookies to add - */ - public MockHttpServletRequestBuilder cookie(Cookie... cookies) { - Assert.notEmpty(cookies, "'cookies' must not be empty"); - this.cookies.addAll(Arrays.asList(cookies)); - return this; - } - - /** - * Add the specified locales as preferred request locales. - * @param locales the locales to add - * @since 4.3.6 - * @see #locale(Locale) - */ - public MockHttpServletRequestBuilder locale(Locale... locales) { - Assert.notEmpty(locales, "'locales' must not be empty"); - this.locales.addAll(Arrays.asList(locales)); - return this; - } - - /** - * Set the locale of the request, overriding any previous locales. - * @param locale the locale, or {@code null} to reset it - * @see #locale(Locale...) - */ - public MockHttpServletRequestBuilder locale(@Nullable Locale locale) { - this.locales.clear(); - if (locale != null) { - this.locales.add(locale); - } - return this; - } - - /** - * Set a request attribute. - * @param name the attribute name - * @param value the attribute value - */ - public MockHttpServletRequestBuilder requestAttr(String name, Object value) { - addToMap(this.requestAttributes, name, value); - return this; - } - - /** - * Set a session attribute. - * @param name the session attribute name - * @param value the session attribute value - */ - public MockHttpServletRequestBuilder sessionAttr(String name, Object value) { - addToMap(this.sessionAttributes, name, value); - return this; - } - - /** - * Set session attributes. - * @param sessionAttributes the session attributes - */ - public MockHttpServletRequestBuilder sessionAttrs(Map sessionAttributes) { - Assert.notEmpty(sessionAttributes, "'sessionAttributes' must not be empty"); - sessionAttributes.forEach(this::sessionAttr); - return this; - } - - /** - * Set an "input" flash attribute. - * @param name the flash attribute name - * @param value the flash attribute value - */ - public MockHttpServletRequestBuilder flashAttr(String name, Object value) { - addToMap(this.flashAttributes, name, value); - return this; - } - - /** - * Set flash attributes. - * @param flashAttributes the flash attributes - */ - public MockHttpServletRequestBuilder flashAttrs(Map flashAttributes) { - Assert.notEmpty(flashAttributes, "'flashAttributes' must not be empty"); - flashAttributes.forEach(this::flashAttr); - return this; - } - - /** - * Set the HTTP session to use, possibly re-used across requests. - *

Individual attributes provided via {@link #sessionAttr(String, Object)} - * override the content of the session provided here. - * @param session the HTTP session - */ - public MockHttpServletRequestBuilder session(MockHttpSession session) { - Assert.notNull(session, "'session' must not be null"); - this.session = session; - return this; - } - - /** - * Set the principal of the request. - * @param principal the principal - */ - public MockHttpServletRequestBuilder principal(Principal principal) { - Assert.notNull(principal, "'principal' must not be null"); - this.principal = principal; - return this; - } - - /** - * Set the remote address of the request. - * @param remoteAddress the remote address (IP) - * @since 6.0.10 - */ - public MockHttpServletRequestBuilder remoteAddress(String remoteAddress) { - Assert.hasText(remoteAddress, "'remoteAddress' must not be null or blank"); - this.remoteAddress = remoteAddress; - return this; - } - - /** - * An extension point for further initialization of {@link MockHttpServletRequest} - * in ways not built directly into the {@code MockHttpServletRequestBuilder}. - * Implementation of this interface can have builder-style methods themselves - * and be made accessible through static factory methods. - * @param postProcessor a post-processor to add - */ - @Override - public MockHttpServletRequestBuilder with(RequestPostProcessor postProcessor) { - Assert.notNull(postProcessor, "postProcessor is required"); - this.postProcessors.add(postProcessor); - return this; - } - - - /** - * {@inheritDoc} - * @return always returns {@code true}. - */ - @Override - public boolean isMergeEnabled() { - return true; - } - - /** - * Merges the properties of the "parent" RequestBuilder accepting values - * only if not already set in "this" instance. - * @param parent the parent {@code RequestBuilder} to inherit properties from - * @return the result of the merge - */ - @Override - public Object merge(@Nullable Object parent) { - if (parent == null) { - return this; - } - if (!(parent instanceof MockHttpServletRequestBuilder parentBuilder)) { - throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]"); - } - if (!StringUtils.hasText(this.contextPath)) { - this.contextPath = parentBuilder.contextPath; - } - if (!StringUtils.hasText(this.servletPath)) { - this.servletPath = parentBuilder.servletPath; - } - if ("".equals(this.pathInfo)) { - this.pathInfo = parentBuilder.pathInfo; - } - - if (this.secure == null) { - this.secure = parentBuilder.secure; - } - if (this.principal == null) { - this.principal = parentBuilder.principal; - } - if (this.session == null) { - this.session = parentBuilder.session; - } - if (this.remoteAddress == null) { - this.remoteAddress = parentBuilder.remoteAddress; - } - - if (this.characterEncoding == null) { - this.characterEncoding = parentBuilder.characterEncoding; - } - if (this.content == null) { - this.content = parentBuilder.content; - } - if (this.contentType == null) { - this.contentType = parentBuilder.contentType; - } - - for (Map.Entry> entry : parentBuilder.headers.entrySet()) { - String headerName = entry.getKey(); - if (!this.headers.containsKey(headerName)) { - this.headers.put(headerName, entry.getValue()); - } - } - for (Map.Entry> entry : parentBuilder.parameters.entrySet()) { - String paramName = entry.getKey(); - if (!this.parameters.containsKey(paramName)) { - this.parameters.put(paramName, entry.getValue()); - } - } - for (Map.Entry> entry : parentBuilder.queryParams.entrySet()) { - String paramName = entry.getKey(); - if (!this.queryParams.containsKey(paramName)) { - this.queryParams.put(paramName, entry.getValue()); - } - } - for (Map.Entry> entry : parentBuilder.formFields.entrySet()) { - String paramName = entry.getKey(); - if (!this.formFields.containsKey(paramName)) { - this.formFields.put(paramName, entry.getValue()); - } - } - for (Cookie cookie : parentBuilder.cookies) { - if (!containsCookie(cookie)) { - this.cookies.add(cookie); - } - } - for (Locale locale : parentBuilder.locales) { - if (!this.locales.contains(locale)) { - this.locales.add(locale); - } - } - - for (Map.Entry entry : parentBuilder.requestAttributes.entrySet()) { - String attributeName = entry.getKey(); - if (!this.requestAttributes.containsKey(attributeName)) { - this.requestAttributes.put(attributeName, entry.getValue()); - } - } - for (Map.Entry entry : parentBuilder.sessionAttributes.entrySet()) { - String attributeName = entry.getKey(); - if (!this.sessionAttributes.containsKey(attributeName)) { - this.sessionAttributes.put(attributeName, entry.getValue()); - } - } - for (Map.Entry entry : parentBuilder.flashAttributes.entrySet()) { - String attributeName = entry.getKey(); - if (!this.flashAttributes.containsKey(attributeName)) { - this.flashAttributes.put(attributeName, entry.getValue()); - } - } - - this.postProcessors.addAll(0, parentBuilder.postProcessors); - - return this; - } - - private boolean containsCookie(Cookie cookie) { - for (Cookie cookieToCheck : this.cookies) { - if (ObjectUtils.nullSafeEquals(cookieToCheck.getName(), cookie.getName())) { - return true; - } - } - return false; - } - - /** - * Build a {@link MockHttpServletRequest}. - */ - @Override - public final MockHttpServletRequest buildRequest(ServletContext servletContext) { - MockHttpServletRequest request = createServletRequest(servletContext); - - request.setAsyncSupported(true); - request.setMethod(this.method.name()); - - String requestUri = this.uri.getRawPath(); - request.setRequestURI(requestUri); - - if (this.uri.getScheme() != null) { - request.setScheme(this.uri.getScheme()); - } - if (this.uri.getHost() != null) { - request.setServerName(this.uri.getHost()); - } - if (this.uri.getPort() != -1) { - request.setServerPort(this.uri.getPort()); - } - - updatePathRequestProperties(request, requestUri); - - if (this.secure != null) { - request.setSecure(this.secure); - } - if (this.principal != null) { - request.setUserPrincipal(this.principal); - } - if (this.remoteAddress != null) { - request.setRemoteAddr(this.remoteAddress); - } - if (this.session != null) { - request.setSession(this.session); - } - - request.setCharacterEncoding(this.characterEncoding); - request.setContent(this.content); - request.setContentType(this.contentType); - - this.headers.forEach((name, values) -> { - for (Object value : values) { - request.addHeader(name, value); - } - }); - - if (!ObjectUtils.isEmpty(this.content) && - !this.headers.containsKey(HttpHeaders.CONTENT_LENGTH) && - !this.headers.containsKey(HttpHeaders.TRANSFER_ENCODING)) { - - request.addHeader(HttpHeaders.CONTENT_LENGTH, this.content.length); - } - - String query = this.uri.getRawQuery(); - if (!this.queryParams.isEmpty()) { - String str = UriComponentsBuilder.newInstance().queryParams(this.queryParams).build().encode().getQuery(); - query = StringUtils.hasLength(query) ? (query + "&" + str) : str; - } - if (query != null) { - request.setQueryString(query); - } - addRequestParams(request, UriComponentsBuilder.fromUri(this.uri).build().getQueryParams()); - - this.parameters.forEach((name, values) -> { - for (String value : values) { - request.addParameter(name, value); - } - }); - - if (!this.formFields.isEmpty()) { - if (this.content != null && this.content.length > 0) { - throw new IllegalStateException("Could not write form data with an existing body"); - } - Charset charset = (this.characterEncoding != null ? - Charset.forName(this.characterEncoding) : StandardCharsets.UTF_8); - MediaType mediaType = (request.getContentType() != null ? - MediaType.parseMediaType(request.getContentType()) : - new MediaType(MediaType.APPLICATION_FORM_URLENCODED, charset)); - if (!mediaType.isCompatibleWith(MediaType.APPLICATION_FORM_URLENCODED)) { - throw new IllegalStateException("Invalid content type: '" + mediaType + - "' is not compatible with '" + MediaType.APPLICATION_FORM_URLENCODED + "'"); - } - request.setContent(writeFormData(mediaType, charset)); - if (request.getContentType() == null) { - request.setContentType(mediaType.toString()); - } - } - if (this.content != null && this.content.length > 0) { - String requestContentType = request.getContentType(); - if (requestContentType != null) { - try { - MediaType mediaType = MediaType.parseMediaType(requestContentType); - if (MediaType.APPLICATION_FORM_URLENCODED.includes(mediaType)) { - addRequestParams(request, parseFormData(mediaType)); - } - } - catch (Exception ex) { - // Must be invalid, ignore - } - } - } - - if (!ObjectUtils.isEmpty(this.cookies)) { - request.setCookies(this.cookies.toArray(new Cookie[0])); - } - if (!ObjectUtils.isEmpty(this.locales)) { - request.setPreferredLocales(this.locales); - } - - this.requestAttributes.forEach(request::setAttribute); - this.sessionAttributes.forEach((name, attribute) -> { - HttpSession session = request.getSession(); - Assert.state(session != null, "No HttpSession"); - session.setAttribute(name, attribute); - }); - - FlashMap flashMap = new FlashMap(); - flashMap.putAll(this.flashAttributes); - FlashMapManager flashMapManager = getFlashMapManager(request); - flashMapManager.saveOutputFlashMap(flashMap, request, new MockHttpServletResponse()); - - return request; - } - - /** - * Create a new {@link MockHttpServletRequest} based on the supplied - * {@code ServletContext}. - *

Can be overridden in subclasses. - */ - protected MockHttpServletRequest createServletRequest(ServletContext servletContext) { - return new MockHttpServletRequest(servletContext); - } - - /** - * Update the contextPath, servletPath, and pathInfo of the request. - */ - private void updatePathRequestProperties(MockHttpServletRequest request, String requestUri) { - if (!requestUri.startsWith(this.contextPath)) { - throw new IllegalArgumentException( - "Request URI [" + requestUri + "] does not start with context path [" + this.contextPath + "]"); - } - request.setContextPath(this.contextPath); - request.setServletPath(this.servletPath); - - if ("".equals(this.pathInfo)) { - if (!requestUri.startsWith(this.contextPath + this.servletPath)) { - throw new IllegalArgumentException( - "Invalid servlet path [" + this.servletPath + "] for request URI [" + requestUri + "]"); - } - String extraPath = requestUri.substring(this.contextPath.length() + this.servletPath.length()); - this.pathInfo = (StringUtils.hasText(extraPath) ? - UrlPathHelper.defaultInstance.decodeRequestString(request, extraPath) : null); - } - request.setPathInfo(this.pathInfo); - } - - private void addRequestParams(MockHttpServletRequest request, MultiValueMap map) { - map.forEach((key, values) -> values.forEach(value -> { - value = (value != null ? UriUtils.decode(value, StandardCharsets.UTF_8) : null); - request.addParameter(UriUtils.decode(key, StandardCharsets.UTF_8), value); - })); - } - - private byte[] writeFormData(MediaType mediaType, Charset charset) { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - HttpOutputMessage message = new HttpOutputMessage() { - @Override - public OutputStream getBody() { - return out; - } - - @Override - public HttpHeaders getHeaders() { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(mediaType); - return headers; - } - }; - try { - FormHttpMessageConverter messageConverter = new FormHttpMessageConverter(); - messageConverter.setCharset(charset); - messageConverter.write(this.formFields, mediaType, message); - return out.toByteArray(); - } - catch (IOException ex) { - throw new IllegalStateException("Failed to write form data to request body", ex); - } - } - - @SuppressWarnings("unchecked") - private MultiValueMap parseFormData(MediaType mediaType) { - HttpInputMessage message = new HttpInputMessage() { - @Override - public InputStream getBody() { - byte[] bodyContent = MockHttpServletRequestBuilder.this.content; - return (bodyContent != null ? new ByteArrayInputStream(bodyContent) : InputStream.nullInputStream()); - } - @Override - public HttpHeaders getHeaders() { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(mediaType); - return headers; - } - }; - - try { - return (MultiValueMap) new FormHttpMessageConverter().read(null, message); - } - catch (IOException ex) { - throw new IllegalStateException("Failed to parse form data in request body", ex); - } - } - - private FlashMapManager getFlashMapManager(MockHttpServletRequest request) { - FlashMapManager flashMapManager = null; - try { - ServletContext servletContext = request.getServletContext(); - WebApplicationContext wac = WebApplicationContextUtils.getRequiredWebApplicationContext(servletContext); - flashMapManager = wac.getBean(DispatcherServlet.FLASH_MAP_MANAGER_BEAN_NAME, FlashMapManager.class); - } - catch (IllegalStateException | NoSuchBeanDefinitionException ex) { - // ignore - } - return (flashMapManager != null ? flashMapManager : new SessionFlashMapManager()); - } - - @Override - public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { - for (RequestPostProcessor postProcessor : this.postProcessors) { - request = postProcessor.postProcessRequest(request); - } - return request; - } - - - private static void addToMap(Map map, String name, Object value) { - Assert.hasLength(name, "'name' must not be empty"); - Assert.notNull(value, "'value' must not be null"); - map.put(name, value); - } - - private static void addToMultiValueMap(MultiValueMap map, String name, T[] values) { - Assert.hasLength(name, "'name' must not be empty"); - Assert.notEmpty(values, "'values' must not be empty"); - for (T value : values) { - map.add(name, value); - } + super.uri(uri); } } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java index fcadb3f531f..36a4a548c3e 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java @@ -47,7 +47,7 @@ import org.springframework.util.MultiValueMap; * @author Arjen Poutsma * @since 3.2 */ -public class MockMultipartHttpServletRequestBuilder extends MockHttpServletRequestBuilder { +public class MockMultipartHttpServletRequestBuilder extends AbstractMockHttpServletRequestBuilder { private final List files = new ArrayList<>(); @@ -73,7 +73,8 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque * @since 5.3.22 */ MockMultipartHttpServletRequestBuilder(HttpMethod httpMethod, String uriTemplate, Object... uriVariables) { - super(httpMethod, uriTemplate, uriVariables); + super(httpMethod); + super.uri(uriTemplate, uriVariables); super.contentType(MediaType.MULTIPART_FORM_DATA); } @@ -92,7 +93,8 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque * @since 5.3.21 */ MockMultipartHttpServletRequestBuilder(HttpMethod httpMethod, URI uri) { - super(httpMethod, uri); + super(httpMethod); + super.uri(uri); super.contentType(MediaType.MULTIPART_FORM_DATA); } @@ -134,7 +136,7 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque if (parent == null) { return this; } - if (parent instanceof MockHttpServletRequestBuilder) { + if (parent instanceof AbstractMockHttpServletRequestBuilder) { super.merge(parent); if (parent instanceof MockMultipartHttpServletRequestBuilder parentBuilder) { this.files.addAll(parentBuilder.files); diff --git a/spring-test/src/main/kotlin/org/springframework/test/web/servlet/MockHttpServletRequestDsl.kt b/spring-test/src/main/kotlin/org/springframework/test/web/servlet/MockHttpServletRequestDsl.kt index 39298114469..bacee88b6a5 100644 --- a/spring-test/src/main/kotlin/org/springframework/test/web/servlet/MockHttpServletRequestDsl.kt +++ b/spring-test/src/main/kotlin/org/springframework/test/web/servlet/MockHttpServletRequestDsl.kt @@ -25,6 +25,7 @@ import org.springframework.util.MultiValueMap import java.security.Principal import java.util.* import jakarta.servlet.http.Cookie +import org.springframework.test.web.servlet.request.AbstractMockHttpServletRequestBuilder /** * Provide a [MockHttpServletRequestBuilder] Kotlin DSL in order to be able to write idiomatic Kotlin code. @@ -40,7 +41,7 @@ import jakarta.servlet.http.Cookie * @author Sebastien Deleuze * @since 5.2 */ -open class MockHttpServletRequestDsl internal constructor (private val builder: MockHttpServletRequestBuilder) { +open class MockHttpServletRequestDsl internal constructor (private val builder: AbstractMockHttpServletRequestBuilder<*>) { /** * @see [MockHttpServletRequestBuilder.contextPath] diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterCompatibilityIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterCompatibilityIntegrationTests.java new file mode 100644 index 00000000000..712b725d90f --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterCompatibilityIntegrationTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.web.servlet.assertj; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.http.MediaType; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; +import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +/** + * Integration tests for {@link MockMvcTester} that use the methods that + * integrate with {@link MockMvc} way of building the requests and + * asserting the responses. + * + * @author Stephane Nicoll + */ +@SpringJUnitConfig +@WebAppConfiguration +class MockMvcTesterCompatibilityIntegrationTests { + + private final MockMvcTester mvc; + + MockMvcTesterCompatibilityIntegrationTests(@Autowired WebApplicationContext wac) { + this.mvc = MockMvcTester.from(wac); + } + + @Test + void performGet() { + assertThat(this.mvc.perform(get("/greet"))).hasStatusOk(); + } + + @Test + void performGetWithInvalidMediaTypeAssertion() { + MvcTestResult result = this.mvc.perform(get("/greet")); + assertThatExceptionOfType(AssertionError.class) + .isThrownBy(() -> assertThat(result).hasContentTypeCompatibleWith(MediaType.APPLICATION_JSON)) + .withMessageContaining("is compatible with 'application/json'"); + } + + @Test + void assertHttpStatusCode() { + assertThat(this.mvc.get().uri("/greet")).matches(status().isOk()); + } + + + @Configuration + @EnableWebMvc + @Import(TestController.class) + static class WebConfiguration { + } + + @RestController + static class TestController { + + @GetMapping(path = "/greet", produces = "text/plain") + String greet() { + return "hello"; + } + + @GetMapping(path = "/message", produces = MediaType.APPLICATION_JSON_VALUE) + String message() { + return "{\"message\": \"hello\"}"; + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java index 90f68c1f553..6778de3fbd0 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java @@ -43,7 +43,7 @@ import org.springframework.stereotype.Controller; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; import org.springframework.test.context.web.WebAppConfiguration; import org.springframework.test.web.Person; -import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; +import org.springframework.test.web.servlet.ResultMatcher; import org.springframework.ui.Model; import org.springframework.validation.Errors; import org.springframework.web.bind.annotation.GetMapping; @@ -63,9 +63,8 @@ import static java.util.Map.entry; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.InstanceOfAssertFactories.map; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Integration tests for {@link MockMvcTester}. @@ -77,10 +76,10 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers. @WebAppConfiguration public class MockMvcTesterIntegrationTests { - private final MockMvcTester mockMvc; + private final MockMvcTester mvc; MockMvcTesterIntegrationTests(WebApplicationContext wac) { - this.mockMvc = MockMvcTester.from(wac); + this.mvc = MockMvcTester.from(wac); } @Nested @@ -88,24 +87,24 @@ public class MockMvcTesterIntegrationTests { @Test void hasAsyncStartedTrue() { - assertThat(perform(get("/callable").accept(MediaType.APPLICATION_JSON))) + assertThat(mvc.get().uri("/callable").accept(MediaType.APPLICATION_JSON)) .request().hasAsyncStarted(true); } @Test void hasAsyncStartedFalse() { - assertThat(perform(get("/greet"))).request().hasAsyncStarted(false); + assertThat(mvc.get().uri("/greet")).request().hasAsyncStarted(false); } @Test void attributes() { - assertThat(perform(get("/greet"))).request().attributes() + assertThat(mvc.get().uri("/greet")).request().attributes() .containsKey(DispatcherServlet.WEB_APPLICATION_CONTEXT_ATTRIBUTE); } @Test void sessionAttributes() { - assertThat(perform(get("/locale"))).request().sessionAttributes() + assertThat(mvc.get().uri("/locale")).request().sessionAttributes() .containsOnly(entry("locale", Locale.UK)); } } @@ -116,17 +115,17 @@ public class MockMvcTesterIntegrationTests { @Test void containsCookie() { Cookie cookie = new Cookie("test", "value"); - assertThat(performWithCookie(cookie, get("/greet"))).cookies().containsCookie("test"); + assertThat(withCookie(cookie).get().uri("/greet")).cookies().containsCookie("test"); } @Test void hasValue() { Cookie cookie = new Cookie("test", "value"); - assertThat(performWithCookie(cookie, get("/greet"))).cookies().hasValue("test", "value"); + assertThat(withCookie(cookie).get().uri("/greet")).cookies().hasValue("test", "value"); } - private MvcTestResult performWithCookie(Cookie cookie, MockHttpServletRequestBuilder request) { - MockMvcTester mockMvc = MockMvcTester.of(List.of(new TestController()), builder -> builder.addInterceptors( + private MockMvcTester withCookie(Cookie cookie) { + return MockMvcTester.of(List.of(new TestController()), builder -> builder.addInterceptors( new HandlerInterceptor() { @Override public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) { @@ -134,7 +133,6 @@ public class MockMvcTesterIntegrationTests { return true; } }).build()); - return mockMvc.perform(request); } } @@ -143,12 +141,12 @@ public class MockMvcTesterIntegrationTests { @Test void statusOk() { - assertThat(perform(get("/greet"))).hasStatusOk(); + assertThat(mvc.get().uri("/greet")).hasStatusOk(); } @Test void statusSeries() { - assertThat(perform(get("/greet"))).hasStatus2xxSuccessful(); + assertThat(mvc.get().uri("/greet")).hasStatus2xxSuccessful(); } } @@ -158,13 +156,13 @@ public class MockMvcTesterIntegrationTests { @Test void shouldAssertHeader() { - assertThat(perform(get("/greet"))) + assertThat(mvc.get().uri("/greet")) .hasHeader("Content-Type", "text/plain;charset=ISO-8859-1"); } @Test void shouldAssertHeaderWithCallback() { - assertThat(perform(get("/greet"))).headers().satisfies(textContent("ISO-8859-1")); + assertThat(mvc.get().uri("/greet")).headers().satisfies(textContent("ISO-8859-1")); } private Consumer textContent(String charset) { @@ -179,33 +177,33 @@ public class MockMvcTesterIntegrationTests { @Test void hasViewName() { - assertThat(perform(get("/persons/{0}", "Andy"))).hasViewName("persons/index"); + assertThat(mvc.get().uri("/persons/{0}", "Andy")).hasViewName("persons/index"); } @Test void viewNameWithCustomAssertion() { - assertThat(perform(get("/persons/{0}", "Andy"))).viewName().startsWith("persons"); + assertThat(mvc.get().uri("/persons/{0}", "Andy")).viewName().startsWith("persons"); } @Test void containsAttributes() { - assertThat(perform(post("/persons").param("name", "Andy"))).model() + assertThat(mvc.post().uri("/persons").param("name", "Andy")).model() .containsOnlyKeys("name").containsEntry("name", "Andy"); } @Test void hasErrors() { - assertThat(perform(post("/persons"))).model().hasErrors(); + assertThat(mvc.post().uri("/persons")).model().hasErrors(); } @Test void hasAttributeErrors() { - assertThat(perform(post("/persons"))).model().hasAttributeErrors("person"); + assertThat(mvc.post().uri("/persons")).model().hasAttributeErrors("person"); } @Test void hasAttributeErrorsCount() { - assertThat(perform(post("/persons"))).model().extractingBindingResult("person").hasErrorsCount(1); + assertThat(mvc.post().uri("/persons")).model().extractingBindingResult("person").hasErrorsCount(1); } } @@ -215,7 +213,7 @@ public class MockMvcTesterIntegrationTests { @Test void containsAttributes() { - assertThat(perform(post("/persons").param("name", "Andy"))).flash() + assertThat(mvc.post().uri("/persons").param("name", "Andy")).flash() .containsOnlyKeys("message").hasEntrySatisfying("message", value -> assertThat(value).isInstanceOfSatisfying(String.class, stringValue -> assertThat(stringValue).startsWith("success"))); @@ -227,31 +225,31 @@ public class MockMvcTesterIntegrationTests { @Test void asyncResult() { - assertThat(perform(get("/callable").accept(MediaType.APPLICATION_JSON))) + assertThat(mvc.get().uri("/callable").accept(MediaType.APPLICATION_JSON)) .asyncResult().asInstanceOf(map(String.class, Object.class)) .containsOnly(entry("key", "value")); } @Test void stringContent() { - assertThat(perform(get("/greet"))).body().asString().isEqualTo("hello"); + assertThat(mvc.get().uri("/greet")).body().asString().isEqualTo("hello"); } @Test void jsonPathContent() { - assertThat(perform(get("/message"))).bodyJson() + assertThat(mvc.get().uri("/message")).bodyJson() .extractingPath("$.message").asString().isEqualTo("hello"); } @Test void jsonContentCanLoadResourceFromClasspath() { - assertThat(perform(get("/message"))).bodyJson().isLenientlyEqualTo( + assertThat(mvc.get().uri("/message")).bodyJson().isLenientlyEqualTo( new ClassPathResource("message.json", MockMvcTesterIntegrationTests.class)); } @Test void jsonContentUsingResourceLoaderClass() { - assertThat(perform(get("/message"))).bodyJson().withResourceLoadClass(MockMvcTesterIntegrationTests.class) + assertThat(mvc.get().uri("/message")).bodyJson().withResourceLoadClass(MockMvcTesterIntegrationTests.class) .isLenientlyEqualTo("message.json"); } @@ -262,22 +260,22 @@ public class MockMvcTesterIntegrationTests { @Test void handlerOn404() { - assertThat(perform(get("/unknown-resource"))).handler().isNull(); + assertThat(mvc.get().uri("/unknown-resource")).handler().isNull(); } @Test void hasType() { - assertThat(perform(get("/greet"))).handler().hasType(TestController.class); + assertThat(mvc.get().uri("/greet")).handler().hasType(TestController.class); } @Test void isMethodHandler() { - assertThat(perform(get("/greet"))).handler().isMethodHandler(); + assertThat(mvc.get().uri("/greet")).handler().isMethodHandler(); } @Test void isInvokedOn() { - assertThat(perform(get("/callable"))).handler() + assertThat(mvc.get().uri("/callable")).handler() .isInvokedOn(AsyncController.class, AsyncController::getCallable); } @@ -288,31 +286,31 @@ public class MockMvcTesterIntegrationTests { @Test void doesNotHaveUnresolvedException() { - assertThat(perform(get("/greet"))).doesNotHaveUnresolvedException(); + assertThat(mvc.get().uri("/greet")).doesNotHaveUnresolvedException(); } @Test void hasUnresolvedException() { - assertThat(perform(get("/error/1"))).hasUnresolvedException(); + assertThat(mvc.get().uri("/error/1")).hasUnresolvedException(); } @Test void doesNotHaveUnresolvedExceptionWithUnresolvedException() { assertThatExceptionOfType(AssertionError.class) - .isThrownBy(() -> assertThat(perform(get("/error/1"))).doesNotHaveUnresolvedException()) + .isThrownBy(() -> assertThat(mvc.get().uri("/error/1")).doesNotHaveUnresolvedException()) .withMessage("Expected request to succeed, but it failed"); } @Test void hasUnresolvedExceptionWithoutUnresolvedException() { assertThatExceptionOfType(AssertionError.class) - .isThrownBy(() -> assertThat(perform(get("/greet"))).hasUnresolvedException()) + .isThrownBy(() -> assertThat(mvc.get().uri("/greet")).hasUnresolvedException()) .withMessage("Expected request to fail, but it succeeded"); } @Test void unresolvedExceptionWithFailedRequest() { - assertThat(perform(get("/error/1"))).unresolvedException() + assertThat(mvc.get().uri("/error/1")).unresolvedException() .isInstanceOf(ServletException.class) .cause().isInstanceOf(IllegalStateException.class).hasMessage("Expected"); } @@ -320,7 +318,7 @@ public class MockMvcTesterIntegrationTests { @Test void unresolvedExceptionWithSuccessfulRequest() { assertThatExceptionOfType(AssertionError.class) - .isThrownBy(() -> assertThat(perform(get("/greet"))).unresolvedException()) + .isThrownBy(() -> assertThat(mvc.get().uri("/greet")).unresolvedException()) .withMessage("Expected request to fail, but it succeeded"); } @@ -406,7 +404,7 @@ public class MockMvcTesterIntegrationTests { private void testAssertionFailureWithUnresolvableException(Consumer assertions) { - MvcTestResult result = perform(get("/error/1")); + MvcTestResult result = mvc.get().uri("/error/1").exchange(); assertThatExceptionOfType(AssertionError.class) .isThrownBy(() -> assertions.accept(result)) .withMessageContainingAll("Request failed unexpectedly:", @@ -418,49 +416,49 @@ public class MockMvcTesterIntegrationTests { @Test void hasForwardUrl() { - assertThat(perform(get("/persons/John"))).hasForwardedUrl("persons/index"); + assertThat(mvc.get().uri("/persons/John")).hasForwardedUrl("persons/index"); } @Test void hasRedirectUrl() { - assertThat(perform(post("/persons").param("name", "Andy"))).hasStatus(HttpStatus.FOUND) + assertThat(mvc.post().uri("/persons").param("name", "Andy")).hasStatus(HttpStatus.FOUND) .hasRedirectedUrl("/persons/Andy"); } @Test void satisfiesAllowsAdditionalAssertions() { - assertThat(this.mockMvc.perform(get("/greet"))).satisfies(result -> { + assertThat(mvc.get().uri("/greet")).satisfies(result -> { assertThat(result).isInstanceOf(MvcTestResult.class); assertThat(result).hasStatusOk(); }); } @Test - void resultMatcherCanBeReused() { - assertThat(this.mockMvc.perform(get("/greet"))).matches(status().isOk()); + void resultMatcherCanBeReused() throws Exception { + MvcTestResult result = mvc.get().uri("/greet").exchange(); + ResultMatcher matcher = mock(ResultMatcher.class); + assertThat(result).matches(matcher); + verify(matcher).match(result.getMvcResult()); } @Test void resultMatcherFailsWithDedicatedException() { + ResultMatcher matcher = result -> assertThat(result.getResponse().getStatus()) + .isEqualTo(HttpStatus.NOT_FOUND.value()); assertThatExceptionOfType(AssertionError.class) - .isThrownBy(() -> assertThat(this.mockMvc.perform(get("/greet"))) - .matches(status().isNotFound())) - .withMessageContaining("Status expected:<404> but was:<200>"); + .isThrownBy(() -> assertThat(mvc.get().uri("/greet")) + .matches(matcher)) + .withMessageContaining("expected: 404").withMessageContaining(" but was: 200"); } @Test void shouldApplyResultHandler() { // Spring RESTDocs example AtomicBoolean applied = new AtomicBoolean(); - assertThat(this.mockMvc.perform(get("/greet"))).apply(result -> applied.set(true)); + assertThat(mvc.get().uri("/greet")).apply(result -> applied.set(true)); assertThat(applied).isTrue(); } - private MvcTestResult perform(MockHttpServletRequestBuilder builder) { - return this.mockMvc.perform(builder); - } - - @Configuration @EnableWebMvc @Import({ TestController.class, PersonController.class, AsyncController.class,