@ -28,7 +28,9 @@ import org.assertj.core.api.Assertions;
@@ -28,7 +28,9 @@ import org.assertj.core.api.Assertions;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.HttpStatus ;
import org.springframework.http.HttpStatus.Series ;
import org.springframework.http.MediaType ;
import org.springframework.test.http.HttpHeadersAssert ;
import org.springframework.test.http.MediaTypeAssert ;
import org.springframework.util.LinkedMultiValueMap ;
import org.springframework.util.MultiValueMap ;
import org.springframework.util.function.SingletonSupplier ;
@ -48,15 +50,24 @@ import org.springframework.util.function.SingletonSupplier;
@@ -48,15 +50,24 @@ import org.springframework.util.function.SingletonSupplier;
public abstract class AbstractHttpServletResponseAssert < R extends HttpServletResponse , SELF extends AbstractHttpServletResponseAssert < R , SELF , ACTUAL > , ACTUAL >
extends AbstractObjectAssert < SELF , ACTUAL > {
private final Supplier < AbstractIntegerAssert < ? > > statusAssert ;
private final Supplier < MediaTypeAssert > contentTypeAssertSupplier ;
private final Supplier < HttpHeadersAssert > headersAssertSupplier ;
private final Supplier < AbstractIntegerAssert < ? > > statusAssert ;
protected AbstractHttpServletResponseAssert ( ACTUAL actual , Class < ? > selfType ) {
super ( actual , selfType ) ;
this . statusAssert = SingletonSupplier . of ( ( ) - > Assertions . assertTha t( getResponse ( ) . getStatus ( ) ) . as ( "HTTP status code" ) ) ;
this . contentTypeAssertSupplier = SingletonSupplier . of ( ( ) - > new MediaTypeAsser t( getResponse ( ) . getContentType ( ) ) ) ;
this . headersAssertSupplier = SingletonSupplier . of ( ( ) - > new HttpHeadersAssert ( getHttpHeaders ( getResponse ( ) ) ) ) ;
this . statusAssert = SingletonSupplier . of ( ( ) - > Assertions . assertThat ( getResponse ( ) . getStatus ( ) ) . as ( "HTTP status code" ) ) ;
}
private static HttpHeaders getHttpHeaders ( HttpServletResponse response ) {
MultiValueMap < String , String > headers = new LinkedMultiValueMap < > ( ) ;
response . getHeaderNames ( ) . forEach ( name - > headers . put ( name , new ArrayList < > ( response . getHeaders ( name ) ) ) ) ;
return new HttpHeaders ( headers ) ;
}
/ * *
@ -67,6 +78,14 @@ public abstract class AbstractHttpServletResponseAssert<R extends HttpServletRes
@@ -67,6 +78,14 @@ public abstract class AbstractHttpServletResponseAssert<R extends HttpServletRes
* /
protected abstract R getResponse ( ) ;
/ * *
* Return a new { @linkplain MediaTypeAssert assertion } object that uses the
* response ' s { @linkplain MediaType content type } as the object to test .
* /
public MediaTypeAssert contentType ( ) {
return this . contentTypeAssertSupplier . get ( ) ;
}
/ * *
* Return a new { @linkplain HttpHeadersAssert assertion } object that uses
* { @link HttpHeaders } as the object to test . The returned assertion
@ -84,6 +103,82 @@ public abstract class AbstractHttpServletResponseAssert<R extends HttpServletRes
@@ -84,6 +103,82 @@ public abstract class AbstractHttpServletResponseAssert<R extends HttpServletRes
return this . headersAssertSupplier . get ( ) ;
}
// Content-type shortcuts
/ * *
* Verify that the response ' s { @code Content - Type } is equal to the given value .
* @param contentType the expected content type
* /
public SELF hasContentType ( MediaType contentType ) {
contentType ( ) . isEqualTo ( contentType ) ;
return this . myself ;
}
/ * *
* Verify that the response ' s { @code Content - Type } is equal to the given
* string representation .
* @param contentType the expected content type
* /
public SELF hasContentType ( String contentType ) {
contentType ( ) . isEqualTo ( contentType ) ;
return this . myself ;
}
/ * *
* Verify that the response ' s { @code Content - Type } is
* { @linkplain MediaType # isCompatibleWith ( MediaType ) compatible } with the
* given value .
* @param contentType the expected compatible content type
* /
public SELF hasContentTypeCompatibleWith ( MediaType contentType ) {
contentType ( ) . isCompatibleWith ( contentType ) ;
return this . myself ;
}
/ * *
* Verify that the response ' s { @code Content - Type } is
* { @linkplain MediaType # isCompatibleWith ( MediaType ) compatible } with the
* given string representation .
* @param contentType the expected compatible content type
* /
public SELF hasContentTypeCompatibleWith ( String contentType ) {
contentType ( ) . isCompatibleWith ( contentType ) ;
return this . myself ;
}
// Headers shortcuts
/ * *
* Verify that the response contains a header with the given { @code name } .
* @param name the name of an expected HTTP header
* /
public SELF containsHeader ( String name ) {
headers ( ) . containsHeader ( name ) ;
return this . myself ;
}
/ * *
* Verify that the response does not contain a header with the given { @code name } .
* @param name the name of an HTTP header that should not be present
* /
public SELF doesNotContainHeader ( String name ) {
headers ( ) . doesNotContainHeader ( name ) ;
return this . myself ;
}
/ * *
* Verify that the response contains a header with the given { @code name }
* and primary { @code value } .
* @param name the name of an expected HTTP header
* @param value the expected value of the header
* /
public SELF hasHeader ( String name , String value ) {
headers ( ) . hasValue ( name , value ) ;
return this . myself ;
}
// Status
/ * *
* Verify that the HTTP status is equal to the specified status code .
* @param status the expected HTTP status code
@ -159,10 +254,4 @@ public abstract class AbstractHttpServletResponseAssert<R extends HttpServletRes
@@ -159,10 +254,4 @@ public abstract class AbstractHttpServletResponseAssert<R extends HttpServletRes
return this . statusAssert . get ( ) ;
}
private static HttpHeaders getHttpHeaders ( HttpServletResponse response ) {
MultiValueMap < String , String > headers = new LinkedMultiValueMap < > ( ) ;
response . getHeaderNames ( ) . forEach ( name - > headers . put ( name , new ArrayList < > ( response . getHeaders ( name ) ) ) ) ;
return new HttpHeaders ( headers ) ;
}
}