diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java index fd481f80a8..421912a2cc 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java @@ -402,7 +402,7 @@ public class FormLoginConfigurerTests { UserDetails user = PasswordEncodedUser.user(); this.mockMvc.perform(get("/profile").with(user(user))) .andExpect(status().is3xxRedirection()) - .andExpect(redirectedUrl("http://localhost/login")); + .andExpect(redirectedUrl("http://localhost/login?authority=FACTOR_PASSWORD")); this.mockMvc .perform(post("/ott/generate").param("username", "rod") .with(user(user)) @@ -418,11 +418,11 @@ public class FormLoginConfigurerTests { user = PasswordEncodedUser.withUserDetails(user).authorities("profile:read", "FACTOR_OTT").build(); this.mockMvc.perform(get("/profile").with(user(user))) .andExpect(status().is3xxRedirection()) - .andExpect(redirectedUrl("http://localhost/login")); + .andExpect(redirectedUrl("http://localhost/login?authority=FACTOR_PASSWORD")); user = PasswordEncodedUser.withUserDetails(user).authorities("profile:read", "FACTOR_PASSWORD").build(); this.mockMvc.perform(get("/profile").with(user(user))) .andExpect(status().is3xxRedirection()) - .andExpect(redirectedUrl("http://localhost/login")); + .andExpect(redirectedUrl("http://localhost/login?authority=FACTOR_OTT")); user = PasswordEncodedUser.withUserDetails(user) .authorities("profile:read", "FACTOR_PASSWORD", "FACTOR_OTT") .build(); @@ -438,7 +438,7 @@ public class FormLoginConfigurerTests { this.mockMvc.perform(get("/login")).andExpect(status().isOk()); this.mockMvc.perform(get("/profile").with(SecurityMockMvcRequestPostProcessors.x509("rod.cer"))) .andExpect(status().is3xxRedirection()) - .andExpect(redirectedUrl("http://localhost/login")); + .andExpect(redirectedUrl("http://localhost/login?authority=FACTOR_PASSWORD")); this.mockMvc .perform(post("/login").param("username", "rod") .param("password", "password") diff --git a/core/src/main/java/org/springframework/security/core/GrantedAuthority.java b/core/src/main/java/org/springframework/security/core/GrantedAuthority.java index 143b254b85..8e7ff7ddc5 100644 --- a/core/src/main/java/org/springframework/security/core/GrantedAuthority.java +++ b/core/src/main/java/org/springframework/security/core/GrantedAuthority.java @@ -31,6 +31,8 @@ import org.springframework.security.authorization.AuthorizationManager; */ public interface GrantedAuthority extends Serializable { + String MISSING_AUTHORITIES_ATTRIBUTE = GrantedAuthority.class + ".missingAuthorities"; + /** * If the GrantedAuthority can be represented as a String * and that String is sufficient in precision to be relied upon for an diff --git a/web/src/main/java/org/springframework/security/web/access/DelegatingMissingAuthorityAccessDeniedHandler.java b/web/src/main/java/org/springframework/security/web/access/DelegatingMissingAuthorityAccessDeniedHandler.java index 215ed6832c..1e8773be52 100644 --- a/web/src/main/java/org/springframework/security/web/access/DelegatingMissingAuthorityAccessDeniedHandler.java +++ b/web/src/main/java/org/springframework/security/web/access/DelegatingMissingAuthorityAccessDeniedHandler.java @@ -91,14 +91,19 @@ public final class DelegatingMissingAuthorityAccessDeniedHandler implements Acce public void handle(HttpServletRequest request, HttpServletResponse response, AccessDeniedException denied) throws IOException, ServletException { Collection authorities = missingAuthorities(denied); - AuthenticationEntryPoint entryPoint = entryPoint(authorities); - if (entryPoint == null) { - this.defaultAccessDeniedHandler.handle(request, response, denied); + for (GrantedAuthority needed : authorities) { + AuthenticationEntryPoint entryPoint = this.entryPoints.get(needed.getAuthority()); + if (entryPoint == null) { + continue; + } + this.requestCache.saveRequest(request, response); + request.setAttribute(GrantedAuthority.MISSING_AUTHORITIES_ATTRIBUTE, List.of(needed)); + String message = String.format("Missing Authorities %s", List.of(needed)); + AuthenticationException ex = new InsufficientAuthenticationException(message, denied); + entryPoint.commence(request, response, ex); return; } - this.requestCache.saveRequest(request, response); - AuthenticationException ex = new InsufficientAuthenticationException("missing authorities", denied); - entryPoint.commence(request, response, ex); + this.defaultAccessDeniedHandler.handle(request, response, denied); } /** @@ -121,17 +126,6 @@ public final class DelegatingMissingAuthorityAccessDeniedHandler implements Acce this.requestCache = requestCache; } - private @Nullable AuthenticationEntryPoint entryPoint(Collection authorities) { - for (GrantedAuthority needed : authorities) { - AuthenticationEntryPoint entryPoint = this.entryPoints.get(needed.getAuthority()); - if (entryPoint == null) { - continue; - } - return entryPoint; - } - return null; - } - private Collection missingAuthorities(AccessDeniedException ex) { AuthorizationDeniedException denied = findAuthorizationDeniedException(ex); if (denied == null) { diff --git a/web/src/main/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPoint.java b/web/src/main/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPoint.java index 2cc2ac92d3..50e0adbc05 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPoint.java +++ b/web/src/main/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPoint.java @@ -17,6 +17,7 @@ package org.springframework.security.web.authentication; import java.io.IOException; +import java.util.Collection; import jakarta.servlet.RequestDispatcher; import jakarta.servlet.ServletException; @@ -30,6 +31,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.beans.factory.InitializingBean; import org.springframework.core.log.LogMessage; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.DefaultRedirectStrategy; import org.springframework.security.web.PortMapper; @@ -40,6 +42,7 @@ import org.springframework.security.web.util.RedirectUrlBuilder; import org.springframework.security.web.util.UrlUtils; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; /** * Used by the {@link ExceptionTranslationFilter} to commence a form login authentication @@ -109,6 +112,12 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin */ protected String determineUrlToUseForThisRequest(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) { + Object value = request.getAttribute(GrantedAuthority.MISSING_AUTHORITIES_ATTRIBUTE); + if (value instanceof Collection authorities) { + return UriComponentsBuilder.fromUriString(getLoginFormUrl()) + .queryParam("authority", authorities) + .toUriString(); + } return getLoginFormUrl(); } diff --git a/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java index 3c74146b83..05adc4b58b 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java @@ -18,9 +18,12 @@ package org.springframework.security.web.authentication.ui; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.function.Function; +import java.util.function.Predicate; import java.util.stream.Collectors; import jakarta.servlet.FilterChain; @@ -31,10 +34,14 @@ import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import org.jspecify.annotations.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices; import org.springframework.util.Assert; import org.springframework.web.filter.GenericFilterBean; +import org.springframework.web.util.UriComponentsBuilder; /** * For internal use with namespace configuration in the case where a user doesn't @@ -78,6 +85,8 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { private @Nullable String rememberMeParameter; + private final Collection allowedParameters = List.of("authority"); + @SuppressWarnings("NullAway.Init") private Map oauth2AuthenticationUrlToClientName; @@ -223,16 +232,43 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { String errorMsg = "Invalid credentials"; String contextPath = request.getContextPath(); - return HtmlTemplates.fromTemplate(LOGIN_PAGE_TEMPLATE) + HtmlTemplates.Builder builder = HtmlTemplates.fromTemplate(LOGIN_PAGE_TEMPLATE) .withRawHtml("contextPath", contextPath) - .withRawHtml("javaScript", renderJavaScript(request, contextPath)) - .withRawHtml("formLogin", renderFormLogin(request, loginError, logoutSuccess, contextPath, errorMsg)) - .withRawHtml("oneTimeTokenLogin", - renderOneTimeTokenLogin(request, loginError, logoutSuccess, contextPath, errorMsg)) - .withRawHtml("oauth2Login", renderOAuth2Login(loginError, logoutSuccess, errorMsg, contextPath)) - .withRawHtml("saml2Login", renderSaml2Login(loginError, logoutSuccess, errorMsg, contextPath)) - .withRawHtml("passkeyLogin", renderPasskeyLogin()) - .render(); + .withRawHtml("javaScript", "") + .withRawHtml("formLogin", "") + .withRawHtml("oneTimeTokenLogin", "") + .withRawHtml("oauth2Login", "") + .withRawHtml("saml2Login", "") + .withRawHtml("passkeyLogin", ""); + + Predicate wantsAuthority = wantsAuthority(request); + if (wantsAuthority.test("FACTOR_WEBAUTHN")) { + builder.withRawHtml("javaScript", renderJavaScript(request, contextPath)) + .withRawHtml("passkeyLogin", renderPasskeyLogin()); + } + if (wantsAuthority.test("FACTOR_PASSWORD")) { + builder.withRawHtml("formLogin", + renderFormLogin(request, loginError, logoutSuccess, contextPath, errorMsg)); + } + if (wantsAuthority.test("FACTOR_OTT")) { + builder.withRawHtml("oneTimeTokenLogin", + renderOneTimeTokenLogin(request, loginError, logoutSuccess, contextPath, errorMsg)); + } + if (wantsAuthority.test("FACTOR_AUTHORIZATION_CODE")) { + builder.withRawHtml("oauth2Login", renderOAuth2Login(loginError, logoutSuccess, errorMsg, contextPath)); + } + if (wantsAuthority.test("FACTOR_SAML_RESPONSE")) { + builder.withRawHtml("saml2Login", renderSaml2Login(loginError, logoutSuccess, errorMsg, contextPath)); + } + return builder.render(); + } + + private Predicate wantsAuthority(HttpServletRequest request) { + String[] authorities = request.getParameterValues("authority"); + if (authorities == null) { + return (authority) -> true; + } + return List.of(authorities)::contains; } private String renderJavaScript(HttpServletRequest request, String contextPath) { @@ -413,10 +449,19 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { if (request.getQueryString() != null) { uri += "?" + request.getQueryString(); } + UriComponentsBuilder addAllowed = UriComponentsBuilder.fromUriString(url); + for (String parameter : this.allowedParameters) { + String[] values = request.getParameterValues(parameter); + if (values != null) { + for (String value : values) { + addAllowed.queryParam(parameter, value); + } + } + } if ("".equals(request.getContextPath())) { - return uri.equals(url); + return uri.equals(addAllowed.toUriString()); } - return uri.equals(request.getContextPath() + url); + return uri.equals(request.getContextPath() + addAllowed.toUriString()); } private static final String CSRF_HEADERS = """ diff --git a/web/src/test/java/org/springframework/security/web/authentication/DefaultLoginPageGeneratingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/DefaultLoginPageGeneratingFilterTests.java index 2dff03e7e4..fc469f4f3c 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/DefaultLoginPageGeneratingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/DefaultLoginPageGeneratingFilterTests.java @@ -28,6 +28,7 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.web.WebAttributes; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; +import org.springframework.security.web.servlet.TestMockHttpServletRequests; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; @@ -191,6 +192,60 @@ public class DefaultLoginPageGeneratingFilterTests { """); } + @Test + public void generateWhenOneTimeTokenRequestedThenOttForm() throws Exception { + DefaultLoginPageGeneratingFilter filter = new DefaultLoginPageGeneratingFilter(); + filter.setLoginPageUrl(DefaultLoginPageGeneratingFilter.DEFAULT_LOGIN_PAGE_URL); + filter.setFormLoginEnabled(true); + filter.setOneTimeTokenEnabled(true); + filter.setOneTimeTokenGenerationUrl("/ott/authenticate"); + MockHttpServletResponse response = new MockHttpServletResponse(); + filter.doFilter(TestMockHttpServletRequests.get("/login?authority=FACTOR_OTT").build(), response, this.chain); + assertThat(response.getContentAsString()).contains("Request a One-Time Token"); + assertThat(response.getContentAsString()).contains(""" + + """); + assertThat(response.getContentAsString()).doesNotContain("Password"); + } + + @Test + public void generateWhenTwoAuthoritiesRequestedThenBothForms() throws Exception { + DefaultLoginPageGeneratingFilter filter = new DefaultLoginPageGeneratingFilter(); + filter.setLoginPageUrl(DefaultLoginPageGeneratingFilter.DEFAULT_LOGIN_PAGE_URL); + filter.setFormLoginEnabled(true); + filter.setUsernameParameter("username"); + filter.setPasswordParameter("password"); + filter.setOneTimeTokenEnabled(true); + filter.setOneTimeTokenGenerationUrl("/ott/authenticate"); + MockHttpServletResponse response = new MockHttpServletResponse(); + filter.doFilter( + TestMockHttpServletRequests.get("/login?authority=FACTOR_OTT&authority=FACTOR_PASSWORD").build(), + response, this.chain); + assertThat(response.getContentAsString()).contains("Request a One-Time Token"); + assertThat(response.getContentAsString()).contains(""" + + """); + assertThat(response.getContentAsString()).contains("Password"); + } + @Test void generatesThenRenders() throws ServletException, IOException { DefaultLoginPageGeneratingFilter filter = new DefaultLoginPageGeneratingFilter(