@ -48,7 +48,6 @@ import org.springframework.http.MediaType;
@@ -48,7 +48,6 @@ import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter ;
import org.springframework.http.server.PathContainer ;
import org.springframework.http.server.RequestPath ;
import org.springframework.lang.NonNull ;
import org.springframework.lang.Nullable ;
import org.springframework.util.Assert ;
import org.springframework.util.CollectionUtils ;
@ -90,7 +89,8 @@ public abstract class RequestPredicates {
@@ -90,7 +89,8 @@ public abstract class RequestPredicates {
* @return a predicate that tests against the given HTTP method
* /
public static RequestPredicate method ( HttpMethod httpMethod ) {
return new HttpMethodPredicate ( httpMethod ) ;
Assert . notNull ( httpMethod , "HttpMethod must not be null" ) ;
return new SingleHttpMethodPredicate ( httpMethod ) ;
}
/ * *
@ -100,7 +100,13 @@ public abstract class RequestPredicates {
@@ -100,7 +100,13 @@ public abstract class RequestPredicates {
* @return a predicate that tests against the given HTTP methods
* /
public static RequestPredicate methods ( HttpMethod . . . httpMethods ) {
return new HttpMethodPredicate ( httpMethods ) ;
Assert . notEmpty ( httpMethods , "HttpMethods must not be empty" ) ;
if ( httpMethods . length = = 1 ) {
return new SingleHttpMethodPredicate ( httpMethods [ 0 ] ) ;
}
else {
return new MultipleHttpMethodsPredicate ( httpMethods ) ;
}
}
/ * *
@ -150,7 +156,12 @@ public abstract class RequestPredicates {
@@ -150,7 +156,12 @@ public abstract class RequestPredicates {
* /
public static RequestPredicate contentType ( MediaType . . . mediaTypes ) {
Assert . notEmpty ( mediaTypes , "'mediaTypes' must not be empty" ) ;
return new ContentTypePredicate ( mediaTypes ) ;
if ( mediaTypes . length = = 1 ) {
return new SingleContentTypePredicate ( mediaTypes [ 0 ] ) ;
}
else {
return new MultipleContentTypesPredicate ( mediaTypes ) ;
}
}
/ * *
@ -162,7 +173,12 @@ public abstract class RequestPredicates {
@@ -162,7 +173,12 @@ public abstract class RequestPredicates {
* /
public static RequestPredicate accept ( MediaType . . . mediaTypes ) {
Assert . notEmpty ( mediaTypes , "'mediaTypes' must not be empty" ) ;
return new AcceptPredicate ( mediaTypes ) ;
if ( mediaTypes . length = = 1 ) {
return new SingleAcceptPredicate ( mediaTypes [ 0 ] ) ;
}
else {
return new MultipleAcceptsPredicate ( mediaTypes ) ;
}
}
/ * *
@ -527,29 +543,23 @@ public abstract class RequestPredicates {
@@ -527,29 +543,23 @@ public abstract class RequestPredicates {
}
private static class HttpMethodPredicate implements RequestPredicate {
private final Set < HttpMethod > httpMethods ;
private static class SingleHttpMethodPredicate implements RequestPredicate {
public HttpMethodPredicate ( HttpMethod httpMethod ) {
Assert . notNull ( httpMethod , "HttpMethod must not be null" ) ;
this . httpMethods = Set . of ( httpMethod ) ;
}
private final HttpMethod httpMethod ;
public HttpMethodPredicate ( HttpMethod . . . httpMethods ) {
Assert . notEmpty ( httpMethods , "HttpMethods must not be empty" ) ;
this . httpMethods = new LinkedHashSet < > ( Arrays . asList ( httpMethods ) ) ;
public SingleHttpMethodPredicate ( HttpMethod httpMethod ) {
this . httpMethod = httpMethod ;
}
@Override
public boolean test ( ServerRequest request ) {
HttpMethod method = method ( request ) ;
boolean match = this . httpMethods . contain s ( method ) ;
traceMatch ( "Method" , this . httpMethods , method , match ) ;
boolean match = this . httpMethod . equal s( method ) ;
traceMatch ( "Method" , this . httpMethod , method , match ) ;
return match ;
}
private static HttpMethod method ( ServerRequest request ) {
static HttpMethod method ( ServerRequest request ) {
if ( CorsUtils . isPreFlightRequest ( request . servletRequest ( ) ) ) {
String accessControlRequestMethod =
request . headers ( ) . firstHeader ( HttpHeaders . ACCESS_CONTROL_REQUEST_METHOD ) ;
@ -560,6 +570,34 @@ public abstract class RequestPredicates {
@@ -560,6 +570,34 @@ public abstract class RequestPredicates {
return request . method ( ) ;
}
@Override
public void accept ( Visitor visitor ) {
visitor . method ( Set . of ( this . httpMethod ) ) ;
}
@Override
public String toString ( ) {
return this . httpMethod . toString ( ) ;
}
}
private static class MultipleHttpMethodsPredicate implements RequestPredicate {
private final Set < HttpMethod > httpMethods ;
public MultipleHttpMethodsPredicate ( HttpMethod [ ] httpMethods ) {
this . httpMethods = new LinkedHashSet < > ( Arrays . asList ( httpMethods ) ) ;
}
@Override
public boolean test ( ServerRequest request ) {
HttpMethod method = SingleHttpMethodPredicate . method ( request ) ;
boolean match = this . httpMethods . contains ( method ) ;
traceMatch ( "Method" , this . httpMethods , method , match ) ;
return match ;
}
@Override
public void accept ( Visitor visitor ) {
visitor . method ( Collections . unmodifiableSet ( this . httpMethods ) ) ;
@ -567,12 +605,7 @@ public abstract class RequestPredicates {
@@ -567,12 +605,7 @@ public abstract class RequestPredicates {
@Override
public String toString ( ) {
if ( this . httpMethods . size ( ) = = 1 ) {
return this . httpMethods . iterator ( ) . next ( ) . toString ( ) ;
}
else {
return this . httpMethods . toString ( ) ;
}
return this . httpMethods . toString ( ) ;
}
}
@ -667,20 +700,46 @@ public abstract class RequestPredicates {
@@ -667,20 +700,46 @@ public abstract class RequestPredicates {
}
private static class ContentTypePredicate extends HeadersPredicate {
private static class Single ContentTypePredicate extends HeadersPredicate {
private final Set < MediaType > mediaTypes ;
private final MediaType mediaType ;
public ContentTypePredicate ( MediaType . . . mediaTypes ) {
this ( Set . of ( mediaTypes ) ) ;
public SingleContentTypePredicate ( MediaType mediaType ) {
super ( headers - > {
MediaType contentType = headers . contentType ( ) . orElse ( MediaType . APPLICATION_OCTET_STREAM ) ;
boolean match = mediaType . includes ( contentType ) ;
traceMatch ( "Content-Type" , mediaType , contentType , match ) ;
return match ;
} ) ;
this . mediaType = mediaType ;
}
@Override
public void accept ( Visitor visitor ) {
visitor . header ( HttpHeaders . CONTENT_TYPE , this . mediaType . toString ( ) ) ;
}
private ContentTypePredicate ( Set < MediaType > mediaTypes ) {
@Override
public String toString ( ) {
return "Content-Type: " + this . mediaType ;
}
}
private static class MultipleContentTypesPredicate extends HeadersPredicate {
private final MediaType [ ] mediaTypes ;
public MultipleContentTypesPredicate ( MediaType [ ] mediaTypes ) {
super ( headers - > {
MediaType contentType =
headers . contentType ( ) . orElse ( MediaType . APPLICATION_OCTET_STREAM ) ;
boolean match = mediaTypes . stream ( )
. anyMatch ( mediaType - > mediaType . includes ( contentType ) ) ;
MediaType contentType = headers . contentType ( ) . orElse ( MediaType . APPLICATION_OCTET_STREAM ) ;
boolean match = false ;
for ( MediaType mediaType : mediaTypes ) {
if ( mediaType . includes ( contentType ) ) {
match = true ;
break ;
}
}
traceMatch ( "Content-Type" , mediaTypes , contentType , match ) ;
return match ;
} ) ;
@ -689,44 +748,37 @@ public abstract class RequestPredicates {
@@ -689,44 +748,37 @@ public abstract class RequestPredicates {
@Override
public void accept ( Visitor visitor ) {
visitor . header ( HttpHeaders . CONTENT_TYPE ,
( this . mediaTypes . size ( ) = = 1 ) ?
this . mediaTypes . iterator ( ) . next ( ) . toString ( ) :
this . mediaTypes . toString ( ) ) ;
visitor . header ( HttpHeaders . CONTENT_TYPE , Arrays . toString ( this . mediaTypes ) ) ;
}
@Override
public String toString ( ) {
return String . format ( "Content-Type: %s" ,
( this . mediaTypes . size ( ) = = 1 ) ?
this . mediaTypes . iterator ( ) . next ( ) . toString ( ) :
this . mediaTypes . toString ( ) ) ;
return "Content-Type: " + Arrays . toString ( this . mediaTypes ) ;
}
}
private static class AcceptPredicate extends HeadersPredicate {
private final Set < MediaType > mediaTypes ;
private static class SingleAcceptPredicate extends HeadersPredicate {
public AcceptPredicate ( MediaType . . . mediaTypes ) {
this ( Set . of ( mediaTypes ) ) ;
}
private final MediaType mediaType ;
private AcceptPredicate ( Set < MediaType > mediaTypes ) {
public SingleAcceptPredicate ( MediaType mediaType ) {
super ( headers - > {
List < MediaType > acceptedMediaTypes = acceptedMediaTypes ( headers ) ;
boolean match = acceptedMediaTypes . stream ( )
. anyMatch ( acceptedMediaType - > mediaTypes . stream ( )
. anyMatch ( acceptedMediaType : : isCompatibleWith ) ) ;
traceMatch ( "Accept" , mediaTypes , acceptedMediaTypes , match ) ;
boolean match = false ;
for ( MediaType acceptedMediaType : acceptedMediaTypes ) {
if ( acceptedMediaType . isCompatibleWith ( mediaType ) ) {
match = true ;
break ;
}
}
traceMatch ( "Accept" , mediaType , acceptedMediaTypes , match ) ;
return match ;
} ) ;
this . mediaTypes = mediaTypes ;
this . mediaType = mediaType ;
}
@NonNull
private static List < MediaType > acceptedMediaTypes ( ServerRequest . Headers headers ) {
static List < MediaType > acceptedMediaTypes ( ServerRequest . Headers headers ) {
List < MediaType > acceptedMediaTypes = headers . accept ( ) ;
if ( acceptedMediaTypes . isEmpty ( ) ) {
acceptedMediaTypes = Collections . singletonList ( MediaType . ALL ) ;
@ -739,18 +791,47 @@ public abstract class RequestPredicates {
@@ -739,18 +791,47 @@ public abstract class RequestPredicates {
@Override
public void accept ( Visitor visitor ) {
visitor . header ( HttpHeaders . ACCEPT ,
( this . mediaTypes . size ( ) = = 1 ) ?
this . mediaTypes . iterator ( ) . next ( ) . toString ( ) :
this . mediaTypes . toString ( ) ) ;
visitor . header ( HttpHeaders . ACCEPT , this . mediaType . toString ( ) ) ;
}
@Override
public String toString ( ) {
return "Accept: " + this . mediaType ;
}
}
private static class MultipleAcceptsPredicate extends HeadersPredicate {
private final MediaType [ ] mediaTypes ;
public MultipleAcceptsPredicate ( MediaType [ ] mediaTypes ) {
super ( headers - > {
List < MediaType > acceptedMediaTypes = SingleAcceptPredicate . acceptedMediaTypes ( headers ) ;
boolean match = false ;
outer :
for ( MediaType acceptedMediaType : acceptedMediaTypes ) {
for ( MediaType mediaType : mediaTypes ) {
if ( acceptedMediaType . isCompatibleWith ( mediaType ) ) {
match = true ;
break outer ;
}
}
}
traceMatch ( "Accept" , mediaTypes , acceptedMediaTypes , match ) ;
return match ;
} ) ;
this . mediaTypes = mediaTypes ;
}
@Override
public void accept ( Visitor visitor ) {
visitor . header ( HttpHeaders . ACCEPT , Arrays . toString ( this . mediaTypes ) ) ;
}
@Override
public String toString ( ) {
return String . format ( "Accept: %s" ,
( this . mediaTypes . size ( ) = = 1 ) ?
this . mediaTypes . iterator ( ) . next ( ) . toString ( ) :
this . mediaTypes . toString ( ) ) ;
return "Accept: " + Arrays . toString ( this . mediaTypes ) ;
}
}