@ -20,9 +20,7 @@ import java.io.IOException;
import java.net.InetSocketAddress ;
import java.net.InetSocketAddress ;
import java.util.Collections ;
import java.util.Collections ;
import java.util.Enumeration ;
import java.util.Enumeration ;
import java.util.List ;
import java.util.Locale ;
import java.util.Locale ;
import java.util.Map ;
import java.util.Set ;
import java.util.Set ;
import java.util.function.Supplier ;
import java.util.function.Supplier ;
@ -37,7 +35,6 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest ;
import org.springframework.http.server.ServerHttpRequest ;
import org.springframework.http.server.ServletServerHttpRequest ;
import org.springframework.http.server.ServletServerHttpRequest ;
import org.springframework.lang.Nullable ;
import org.springframework.lang.Nullable ;
import org.springframework.util.CollectionUtils ;
import org.springframework.util.LinkedCaseInsensitiveMap ;
import org.springframework.util.LinkedCaseInsensitiveMap ;
import org.springframework.util.StringUtils ;
import org.springframework.util.StringUtils ;
import org.springframework.web.util.UriComponents ;
import org.springframework.web.util.UriComponents ;
@ -169,23 +166,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
* /
* /
private static class ForwardedHeaderRemovingRequest extends HttpServletRequestWrapper {
private static class ForwardedHeaderRemovingRequest extends HttpServletRequestWrapper {
private final Map < String , Lis t< String > > headers ;
private final Se t< String > headerName s ;
public ForwardedHeaderRemovingRequest ( HttpServletRequest request ) {
public ForwardedHeaderRemovingRequest ( HttpServletRequest request ) {
super ( request ) ;
super ( request ) ;
this . headers = initHeaders ( request ) ;
this . headerNames = headerNames ( request ) ;
}
}
private static Map < String , List < String > > initHeaders ( HttpServletRequest request ) {
private static Set < String > headerNames ( HttpServletRequest request ) {
Map < String , List < String > > headers = new LinkedCaseInsensitiveMap < > ( Locale . ENGLISH ) ;
final var headerNames = Collections . newSetFromMap ( new LinkedCaseInsensitiveMap < > ( Locale . ENGLISH ) ) ;
Enumeration < String > names = request . getHeaderNames ( ) ;
final var names = request . getHeaderNames ( ) ;
while ( names . hasMoreElements ( ) ) {
while ( names . hasMoreElements ( ) ) {
String name = names . nextElement ( ) ;
final var name = names . nextElement ( ) ;
if ( ! FORWARDED_HEADER_NAMES . contains ( name ) ) {
headerNames . add ( name ) ;
headers . put ( name , Collections . list ( request . getHeaders ( name ) ) ) ;
}
}
}
return headers ;
headerNames . removeAll ( FORWARDED_HEADER_NAMES ) ;
return Collections . unmodifiableSet ( headerNames ) ;
}
}
// Override header accessors to not expose forwarded headers
// Override header accessors to not expose forwarded headers
@ -193,19 +193,25 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@Override
@Override
@Nullable
@Nullable
public String getHeader ( String name ) {
public String getHeader ( String name ) {
List < String > value = this . headers . get ( name ) ;
if ( FORWARDED_HEADER_NAMES . contains ( name ) ) {
return ( CollectionUtils . isEmpty ( value ) ? null : value . get ( 0 ) ) ;
return null ;
}
return super . getHeader ( name ) ;
}
}
@Override
@Override
public Enumeration < String > getHeaders ( String name ) {
public Enumeration < String > getHeaders ( String name ) {
List < String > value = this . headers . get ( name ) ;
if ( FORWARDED_HEADER_NAMES . contains ( name ) ) {
return ( Collections . enumeration ( value ! = null ? value : Collections . emptySet ( ) ) ) ;
return Collections . emptyEnumeration ( ) ;
}
return super . getHeaders ( name ) ;
}
}
@Override
@Override
public Enumeration < String > getHeaderNames ( ) {
public Enumeration < String > getHeaderNames ( ) {
return Collections . enumeration ( this . headers . keySet ( ) ) ;
return Collections . enumeration ( this . headerNames ) ;
}
}
}
}