@ -20,7 +20,9 @@ import java.io.File;
@@ -20,7 +20,9 @@ import java.io.File;
import java.io.FileNotFoundException ;
import java.util.Collection ;
import java.util.Enumeration ;
import java.util.LinkedHashSet ;
import java.util.Map ;
import java.util.Set ;
import java.util.StringTokenizer ;
import java.util.TreeMap ;
import javax.servlet.ServletContext ;
@ -33,6 +35,7 @@ import javax.servlet.http.HttpServletRequest;
@@ -33,6 +35,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse ;
import javax.servlet.http.HttpSession ;
import org.springframework.http.HttpHeaders ;
import org.springframework.http.HttpRequest ;
import org.springframework.http.server.ServletServerHttpRequest ;
import org.springframework.lang.Nullable ;
@ -135,6 +138,16 @@ public abstract class WebUtils {
@@ -135,6 +138,16 @@ public abstract class WebUtils {
/** Key for the mutex session attribute */
public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils . class . getName ( ) + ".MUTEX" ;
private static final Set < String > FORWARDED_HEADER_NAMES = new LinkedHashSet < > ( 5 ) ;
static {
FORWARDED_HEADER_NAMES . add ( "Forwarded" ) ;
FORWARDED_HEADER_NAMES . add ( "X-Forwarded-Host" ) ;
FORWARDED_HEADER_NAMES . add ( "X-Forwarded-Port" ) ;
FORWARDED_HEADER_NAMES . add ( "X-Forwarded-Proto" ) ;
FORWARDED_HEADER_NAMES . add ( "X-Forwarded-Prefix" ) ;
}
/ * *
* Set a system property to the web application root directory .
@ -693,36 +706,60 @@ public abstract class WebUtils {
@@ -693,36 +706,60 @@ public abstract class WebUtils {
* @since 4 . 2
* /
public static boolean isSameOrigin ( HttpRequest request ) {
String origin = request . getHeaders ( ) . getOrigin ( ) ;
HttpHeaders headers = request . getHeaders ( ) ;
String origin = headers . getOrigin ( ) ;
if ( origin = = null ) {
return true ;
}
UriComponentsBuilder urlBuilder ;
String scheme ;
String host ;
int port ;
if ( request instanceof ServletServerHttpRequest ) {
// Build more efficiently if we can: we only need scheme, host, port for origin comparison
HttpServletRequest servletRequest = ( ( ServletServerHttpRequest ) request ) . getServletRequest ( ) ;
urlBuilder = new UriComponentsBuilder ( ) .
scheme ( servletRequest . getScheme ( ) ) .
host ( servletRequest . getServerName ( ) ) .
port ( servletRequest . getServerPort ( ) ) .
adaptFromForwardedHeaders ( request . getHeaders ( ) ) ;
scheme = servletRequest . getScheme ( ) ;
host = servletRequest . getServerName ( ) ;
port = servletRequest . getServerPort ( ) ;
if ( containsForwardedHeaders ( servletRequest ) ) {
UriComponents actualUrl = new UriComponentsBuilder ( )
. scheme ( scheme )
. host ( host )
. port ( port )
. adaptFromForwardedHeaders ( headers )
. build ( ) ;
scheme = actualUrl . getScheme ( ) ;
host = actualUrl . getHost ( ) ;
port = actualUrl . getPort ( ) ;
}
}
else {
urlBuilder = UriComponentsBuilder . fromHttpRequest ( request ) ;
UriComponents actualUrl = UriComponentsBuilder . fromHttpRequest ( request ) . build ( ) ;
scheme = actualUrl . getScheme ( ) ;
host = actualUrl . getHost ( ) ;
port = actualUrl . getPort ( ) ;
}
UriComponents actualUrl = urlBuilder . build ( ) ;
UriComponents originUrl = UriComponentsBuilder . fromOriginHeader ( origin ) . build ( ) ;
return ( ObjectUtils . nullSafeEquals ( actualUrl . getHost ( ) , originUrl . getHost ( ) ) & &
getPort ( actualUrl ) = = getPort ( originUrl ) ) ;
return ( ObjectUtils . nullSafeEquals ( host , originUrl . getHost ( ) ) & &
getPort ( scheme , port ) = = getPort ( originUrl . getScheme ( ) , originUrl . getPort ( ) ) ) ;
}
private static boolean containsForwardedHeaders ( HttpServletRequest request ) {
for ( String headerName : FORWARDED_HEADER_NAMES ) {
if ( request . getHeader ( headerName ) ! = null ) {
return true ;
}
}
return false ;
}
private static int getPort ( UriComponents uri ) {
int port = uri . getPort ( ) ;
private static int getPort ( String scheme , int port ) {
if ( port = = - 1 ) {
if ( "http" . equals ( uri . getScheme ( ) ) | | "ws" . equals ( uri . getScheme ( ) ) ) {
if ( "http" . equals ( s cheme) | | "ws" . equals ( s cheme) ) {
port = 80 ;
}
else if ( "https" . equals ( uri . getScheme ( ) ) | | "wss" . equals ( uri . getScheme ( ) ) ) {
else if ( "https" . equals ( s cheme) | | "wss" . equals ( s cheme) ) {
port = 443 ;
}
}