From bcecce7aace32538badce20b9ee8b0d03da12114 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Mon, 6 May 2024 17:33:30 +0200 Subject: [PATCH] Add shortcuts for frequently used assertions See gh-32712 --- .../AbstractHttpServletResponseAssert.java | 105 ++++++++++++++++-- .../servlet/assertj/MvcTestResultAssert.java | 10 -- ...bstractHttpServletResponseAssertTests.java | 59 +++++++++- .../MockMvcTesterIntegrationTests.java | 14 +-- 4 files changed, 157 insertions(+), 31 deletions(-) diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/AbstractHttpServletResponseAssert.java b/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/AbstractHttpServletResponseAssert.java index ddba8a187f6..fe079f7a9d8 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/AbstractHttpServletResponseAssert.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/AbstractHttpServletResponseAssert.java @@ -28,7 +28,9 @@ import org.assertj.core.api.Assertions; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus.Series; +import org.springframework.http.MediaType; import org.springframework.test.http.HttpHeadersAssert; +import org.springframework.test.http.MediaTypeAssert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.function.SingletonSupplier; @@ -48,15 +50,24 @@ import org.springframework.util.function.SingletonSupplier; public abstract class AbstractHttpServletResponseAssert, ACTUAL> extends AbstractObjectAssert { - private final Supplier> statusAssert; + private final Supplier contentTypeAssertSupplier; private final Supplier headersAssertSupplier; + private final Supplier> statusAssert; + protected AbstractHttpServletResponseAssert(ACTUAL actual, Class selfType) { super(actual, selfType); - this.statusAssert = SingletonSupplier.of(() -> Assertions.assertThat(getResponse().getStatus()).as("HTTP status code")); + this.contentTypeAssertSupplier = SingletonSupplier.of(() -> new MediaTypeAssert(getResponse().getContentType())); this.headersAssertSupplier = SingletonSupplier.of(() -> new HttpHeadersAssert(getHttpHeaders(getResponse()))); + this.statusAssert = SingletonSupplier.of(() -> Assertions.assertThat(getResponse().getStatus()).as("HTTP status code")); + } + + private static HttpHeaders getHttpHeaders(HttpServletResponse response) { + MultiValueMap headers = new LinkedMultiValueMap<>(); + response.getHeaderNames().forEach(name -> headers.put(name, new ArrayList<>(response.getHeaders(name)))); + return new HttpHeaders(headers); } /** @@ -67,6 +78,14 @@ public abstract class AbstractHttpServletResponseAssert headers = new LinkedMultiValueMap<>(); - response.getHeaderNames().forEach(name -> headers.put(name, new ArrayList<>(response.getHeaders(name)))); - return new HttpHeaders(headers); - } - } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MvcTestResultAssert.java b/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MvcTestResultAssert.java index ca1dd65c4ac..7c7ff332ec8 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MvcTestResultAssert.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/assertj/MvcTestResultAssert.java @@ -30,12 +30,10 @@ import org.assertj.core.api.ObjectAssert; import org.assertj.core.error.BasicErrorMessageFactory; import org.assertj.core.internal.Failures; -import org.springframework.http.MediaType; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.lang.Nullable; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.test.http.MediaTypeAssert; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultHandler; import org.springframework.test.web.servlet.ResultMatcher; @@ -87,14 +85,6 @@ public class MvcTestResultAssert extends AbstractMockHttpServletResponseAssert headers) { MockHttpServletResponse response = new MockHttpServletResponse(); headers.forEach(response::addHeader); @@ -51,6 +69,45 @@ class AbstractHttpServletResponseAssertTests { } } + @Nested + class ContentTypeTests { + + @Test + void contentType() { + MockHttpServletResponse response = createResponse("text/plain"); + assertThat(response).hasContentType(MediaType.TEXT_PLAIN); + } + + @Test + void contentTypeAndRepresentation() { + MockHttpServletResponse response = createResponse("text/plain"); + assertThat(response).hasContentType("text/plain"); + } + + @Test + void contentTypeCompatibleWith() { + MockHttpServletResponse response = createResponse("application/json;charset=UTF-8"); + assertThat(response).hasContentTypeCompatibleWith(MediaType.APPLICATION_JSON); + } + + @Test + void contentTypeCompatibleWithAndStringRepresentation() { + MockHttpServletResponse response = createResponse("text/plain"); + assertThat(response).hasContentTypeCompatibleWith("text/*"); + } + + @Test + void contentTypeCanBeAsserted() { + MockHttpServletResponse response = createResponse("text/plain"); + assertThat(response).contentType().isInstanceOf(MediaType.class).isCompatibleWith("text/*").isNotNull(); + } + + private MockHttpServletResponse createResponse(String contentType) { + MockHttpServletResponse response = new MockHttpServletResponse(); + response.setContentType(contentType); + return response; + } + } @Nested class StatusTests { diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java index 4d58ad5c37a..1a6f69daa3f 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/assertj/MockMvcTesterIntegrationTests.java @@ -138,16 +138,6 @@ public class MockMvcTesterIntegrationTests { } } - @Nested - class ContentTypeTests { - - @Test - void contentType() { - assertThat(perform(get("/greet"))).contentType().isCompatibleWith("text/plain"); - } - - } - @Nested class StatusTests { @@ -168,8 +158,8 @@ public class MockMvcTesterIntegrationTests { @Test void shouldAssertHeader() { - assertThat(perform(get("/greet"))).headers() - .hasValue("Content-Type", "text/plain;charset=ISO-8859-1"); + assertThat(perform(get("/greet"))) + .hasHeader("Content-Type", "text/plain;charset=ISO-8859-1"); } @Test