@ -18,9 +18,12 @@ package org.springframework.security.web.authentication.ui;
@@ -18,9 +18,12 @@ package org.springframework.security.web.authentication.ui;
import java.io.IOException ;
import java.nio.charset.StandardCharsets ;
import java.util.Collection ;
import java.util.Collections ;
import java.util.List ;
import java.util.Map ;
import java.util.function.Function ;
import java.util.function.Predicate ;
import java.util.stream.Collectors ;
import jakarta.servlet.FilterChain ;
@ -31,10 +34,14 @@ import jakarta.servlet.http.HttpServletRequest;
@@ -31,10 +34,14 @@ import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse ;
import org.jspecify.annotations.Nullable ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.core.context.SecurityContextHolder ;
import org.springframework.security.core.context.SecurityContextHolderStrategy ;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter ;
import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices ;
import org.springframework.util.Assert ;
import org.springframework.web.filter.GenericFilterBean ;
import org.springframework.web.util.UriComponentsBuilder ;
/ * *
* For internal use with namespace configuration in the case where a user doesn ' t
@ -78,6 +85,8 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
@@ -78,6 +85,8 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
private @Nullable String rememberMeParameter ;
private final Collection < String > allowedParameters = List . of ( "authority" ) ;
@SuppressWarnings ( "NullAway.Init" )
private Map < String , String > oauth2AuthenticationUrlToClientName ;
@ -223,16 +232,43 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
@@ -223,16 +232,43 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
String errorMsg = "Invalid credentials" ;
String contextPath = request . getContextPath ( ) ;
return HtmlTemplates . fromTemplate ( LOGIN_PAGE_TEMPLATE )
HtmlTemplates . Builder builder = HtmlTemplates . fromTemplate ( LOGIN_PAGE_TEMPLATE )
. withRawHtml ( "contextPath" , contextPath )
. withRawHtml ( "javaScript" , renderJavaScript ( request , contextPath ) )
. withRawHtml ( "formLogin" , renderFormLogin ( request , loginError , logoutSuccess , contextPath , errorMsg ) )
. withRawHtml ( "oneTimeTokenLogin" ,
renderOneTimeTokenLogin ( request , loginError , logoutSuccess , contextPath , errorMsg ) )
. withRawHtml ( "oauth2Login" , renderOAuth2Login ( loginError , logoutSuccess , errorMsg , contextPath ) )
. withRawHtml ( "saml2Login" , renderSaml2Login ( loginError , logoutSuccess , errorMsg , contextPath ) )
. withRawHtml ( "passkeyLogin" , renderPasskeyLogin ( ) )
. render ( ) ;
. withRawHtml ( "javaScript" , "" )
. withRawHtml ( "formLogin" , "" )
. withRawHtml ( "oneTimeTokenLogin" , "" )
. withRawHtml ( "oauth2Login" , "" )
. withRawHtml ( "saml2Login" , "" )
. withRawHtml ( "passkeyLogin" , "" ) ;
Predicate < String > wantsAuthority = wantsAuthority ( request ) ;
if ( wantsAuthority . test ( "FACTOR_WEBAUTHN" ) ) {
builder . withRawHtml ( "javaScript" , renderJavaScript ( request , contextPath ) )
. withRawHtml ( "passkeyLogin" , renderPasskeyLogin ( ) ) ;
}
if ( wantsAuthority . test ( "FACTOR_PASSWORD" ) ) {
builder . withRawHtml ( "formLogin" ,
renderFormLogin ( request , loginError , logoutSuccess , contextPath , errorMsg ) ) ;
}
if ( wantsAuthority . test ( "FACTOR_OTT" ) ) {
builder . withRawHtml ( "oneTimeTokenLogin" ,
renderOneTimeTokenLogin ( request , loginError , logoutSuccess , contextPath , errorMsg ) ) ;
}
if ( wantsAuthority . test ( "FACTOR_AUTHORIZATION_CODE" ) ) {
builder . withRawHtml ( "oauth2Login" , renderOAuth2Login ( loginError , logoutSuccess , errorMsg , contextPath ) ) ;
}
if ( wantsAuthority . test ( "FACTOR_SAML_RESPONSE" ) ) {
builder . withRawHtml ( "saml2Login" , renderSaml2Login ( loginError , logoutSuccess , errorMsg , contextPath ) ) ;
}
return builder . render ( ) ;
}
private Predicate < String > wantsAuthority ( HttpServletRequest request ) {
String [ ] authorities = request . getParameterValues ( "authority" ) ;
if ( authorities = = null ) {
return ( authority ) - > true ;
}
return List . of ( authorities ) : : contains ;
}
private String renderJavaScript ( HttpServletRequest request , String contextPath ) {
@ -413,10 +449,19 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
@@ -413,10 +449,19 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
if ( request . getQueryString ( ) ! = null ) {
uri + = "?" + request . getQueryString ( ) ;
}
UriComponentsBuilder addAllowed = UriComponentsBuilder . fromUriString ( url ) ;
for ( String parameter : this . allowedParameters ) {
String [ ] values = request . getParameterValues ( parameter ) ;
if ( values ! = null ) {
for ( String value : values ) {
addAllowed . queryParam ( parameter , value ) ;
}
}
}
if ( "" . equals ( request . getContextPath ( ) ) ) {
return uri . equals ( url ) ;
return uri . equals ( addAllowed . toUriString ( ) ) ;
}
return uri . equals ( request . getContextPath ( ) + url ) ;
return uri . equals ( request . getContextPath ( ) + addAllowed . toUriString ( ) ) ;
}
private static final String CSRF_HEADERS = "" "