@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2012 - 2017 the original author or authors .
* Copyright 2012 - 2020 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -26,6 +26,7 @@ import java.util.Collections;
@@ -26,6 +26,7 @@ import java.util.Collections;
import java.util.HashSet ;
import java.util.List ;
import java.util.Set ;
import java.util.function.Predicate ;
/ * *
* < p >
@ -66,10 +67,15 @@ import java.util.Set;
@@ -66,10 +67,15 @@ import java.util.Set;
* Rejects URLs that contain a URL encoded percent . See
* { @link # setAllowUrlEncodedPercent ( boolean ) }
* < / li >
* < li >
* Rejects hosts that are not allowed . See
* { @link # setAllowedHostnames ( Predicate ) }
* < / li >
* < / ul >
*
* @see DefaultHttpFirewall
* @author Rob Winch
* @author Eddú Meléndez
* @since 4 . 2 . 4
* /
public class StrictHttpFirewall implements HttpFirewall {
@ -96,6 +102,8 @@ public class StrictHttpFirewall implements HttpFirewall {
@@ -96,6 +102,8 @@ public class StrictHttpFirewall implements HttpFirewall {
private Set < String > allowedHttpMethods = createDefaultAllowedHttpMethods ( ) ;
private Predicate < String > allowedHostnames = hostname - > true ;
public StrictHttpFirewall ( ) {
urlBlacklistsAddAll ( FORBIDDEN_SEMICOLON ) ;
urlBlacklistsAddAll ( FORBIDDEN_FORWARDSLASH ) ;
@ -277,6 +285,13 @@ public class StrictHttpFirewall implements HttpFirewall {
@@ -277,6 +285,13 @@ public class StrictHttpFirewall implements HttpFirewall {
}
}
public void setAllowedHostnames ( Predicate < String > allowedHostnames ) {
if ( allowedHostnames = = null ) {
throw new IllegalArgumentException ( "allowedHostnames cannot be null" ) ;
}
this . allowedHostnames = allowedHostnames ;
}
private void urlBlacklistsAddAll ( Collection < String > values ) {
this . encodedUrlBlacklist . addAll ( values ) ;
this . decodedUrlBlacklist . addAll ( values ) ;
@ -291,6 +306,7 @@ public class StrictHttpFirewall implements HttpFirewall {
@@ -291,6 +306,7 @@ public class StrictHttpFirewall implements HttpFirewall {
public FirewalledRequest getFirewalledRequest ( HttpServletRequest request ) throws RequestRejectedException {
rejectForbiddenHttpMethod ( request ) ;
rejectedBlacklistedUrls ( request ) ;
rejectedUntrustedHosts ( request ) ;
if ( ! isNormalized ( request ) ) {
throw new RequestRejectedException ( "The request was rejected because the URL was not normalized." ) ;
@ -332,6 +348,13 @@ public class StrictHttpFirewall implements HttpFirewall {
@@ -332,6 +348,13 @@ public class StrictHttpFirewall implements HttpFirewall {
}
}
private void rejectedUntrustedHosts ( HttpServletRequest request ) {
String serverName = request . getServerName ( ) ;
if ( serverName ! = null & & ! this . allowedHostnames . test ( serverName ) ) {
throw new RequestRejectedException ( "The request was rejected because the domain " + serverName + " is untrusted." ) ;
}
}
@Override
public HttpServletResponse getFirewalledResponse ( HttpServletResponse response ) {
return new FirewalledResponse ( response ) ;