From e176d764ba2ae752d5d5c389b2f8274d89bae173 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Fri, 25 Mar 2022 13:45:30 -0500 Subject: [PATCH] Add SecurityContextRepository.loadContext(HttpServletRequest) This allows loading the SecurityContext lazily, without the need for the response, and does not attempt to automatically save the request when the response is comitted. Closes gh-11028 --- .../SecurityContextConfigurerTests.java | 3 +- .../config/http/MiscHttpConfigTests.java | 4 +-- ...estAttributeSecurityContextRepository.java | 11 +++++--- .../context/SecurityContextHolderFilter.java | 3 +- .../context/SecurityContextRepository.java | 16 +++++++++++ ...SessionSecurityContextRepositoryTests.java | 28 +++++++++++++++++++ .../SecurityContextHolderFilterTests.java | 7 ++--- 7 files changed, 57 insertions(+), 15 deletions(-) diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurerTests.java index bc2113dd6c..59e0db7892 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurerTests.java @@ -80,7 +80,8 @@ public class SecurityContextConfigurerTests { @Test public void securityContextWhenInvokedTwiceThenUsesOriginalSecurityContextRepository() throws Exception { this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire(); - given(DuplicateDoesNotOverrideConfig.SCR.loadContext(any())).willReturn(mock(SecurityContext.class)); + given(DuplicateDoesNotOverrideConfig.SCR.loadContext(any(HttpRequestResponseHolder.class))) + .willReturn(mock(SecurityContext.class)); this.mvc.perform(get("/")); verify(DuplicateDoesNotOverrideConfig.SCR).loadContext(any(HttpRequestResponseHolder.class)); } diff --git a/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java b/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java index ace1a7903f..6a5293fb9d 100644 --- a/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java @@ -124,7 +124,6 @@ import static org.springframework.security.test.web.servlet.request.SecurityMock import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.x509; -import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -468,11 +467,10 @@ public class MiscHttpConfigTests { this.spring.configLocations(xml("ExplicitSaveAndExplicitRepository")).autowire(); SecurityContextRepository repository = this.spring.getContext().getBean(SecurityContextRepository.class); SecurityContext context = new SecurityContextImpl(new TestingAuthenticationToken("user", "password")); - given(repository.loadContext(any(HttpRequestResponseHolder.class))).willReturn(context); + given(repository.loadContext(any(HttpServletRequest.class))).willReturn(() -> context); // @formatter:off MvcResult result = this.mvc.perform(formLogin()) .andExpect(status().is3xxRedirection()) - .andExpect(authenticated()) .andReturn(); // @formatter:on verify(repository, atLeastOnce()).saveContext(any(SecurityContext.class), any(HttpServletRequest.class), diff --git a/web/src/main/java/org/springframework/security/web/context/RequestAttributeSecurityContextRepository.java b/web/src/main/java/org/springframework/security/web/context/RequestAttributeSecurityContextRepository.java index 795dea007e..d19ac85b30 100644 --- a/web/src/main/java/org/springframework/security/web/context/RequestAttributeSecurityContextRepository.java +++ b/web/src/main/java/org/springframework/security/web/context/RequestAttributeSecurityContextRepository.java @@ -16,6 +16,8 @@ package org.springframework.security.web.context; +import java.util.function.Supplier; + import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -64,17 +66,18 @@ public final class RequestAttributeSecurityContextRepository implements Security @Override public boolean containsContext(HttpServletRequest request) { - return loadContext(request) != null; + return loadContext(request).get() != null; } @Override public SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder) { - SecurityContext context = loadContext(requestResponseHolder.getRequest()); + SecurityContext context = loadContext(requestResponseHolder.getRequest()).get(); return (context != null) ? context : SecurityContextHolder.createEmptyContext(); } - private SecurityContext loadContext(HttpServletRequest request) { - return (SecurityContext) request.getAttribute(this.requestAttributeName); + @Override + public Supplier loadContext(HttpServletRequest request) { + return () -> (SecurityContext) request.getAttribute(this.requestAttributeName); } @Override diff --git a/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java b/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java index 4aba0d3bf9..05e3fd493f 100644 --- a/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java +++ b/web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java @@ -58,8 +58,7 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - SecurityContext securityContext = this.securityContextRepository - .loadContext(new HttpRequestResponseHolder(request, response)); + SecurityContext securityContext = this.securityContextRepository.loadContext(request).get(); try { SecurityContextHolder.setContext(securityContext); filterChain.doFilter(request, response); diff --git a/web/src/main/java/org/springframework/security/web/context/SecurityContextRepository.java b/web/src/main/java/org/springframework/security/web/context/SecurityContextRepository.java index 9518481928..95df5213ae 100644 --- a/web/src/main/java/org/springframework/security/web/context/SecurityContextRepository.java +++ b/web/src/main/java/org/springframework/security/web/context/SecurityContextRepository.java @@ -16,6 +16,8 @@ package org.springframework.security.web.context; +import java.util.function.Supplier; + import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -61,6 +63,20 @@ public interface SecurityContextRepository { */ SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder); + /** + * Obtains the security context for the supplied request. For an unauthenticated user, + * an empty context implementation should be returned. This method should not return + * null. + * @param request the {@link HttpServletRequest} to load the {@link SecurityContext} + * from + * @return a {@link Supplier} that returns the {@link SecurityContext} which cannot be + * null. + * @since 5.7 + */ + default Supplier loadContext(HttpServletRequest request) { + return () -> loadContext(new HttpRequestResponseHolder(request, null)); + } + /** * Stores the security context on completion of a request. * @param context the non-null context which was obtained from the holder. diff --git a/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java b/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java index d3e6e9053b..8718d191a9 100644 --- a/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java @@ -63,6 +63,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; /** * @author Luke Taylor @@ -141,6 +142,33 @@ public class HttpSessionSecurityContextRepositoryTests { assertThat(repo.loadContext(holder)).isEqualTo(SecurityContextHolder.createEmptyContext()); } + @Test + public void loadContextHttpServletRequestWhenNotSavedThenEmptyContextReturned() { + HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); + MockHttpServletRequest request = new MockHttpServletRequest(); + assertThat(repo.loadContext(request).get()).isEqualTo(SecurityContextHolder.createEmptyContext()); + } + + @Test + public void loadContextHttpServletRequestWhenSavedThenSavedContextReturned() { + SecurityContextImpl expectedContext = new SecurityContextImpl(this.testToken); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); + repo.saveContext(expectedContext, request, response); + assertThat(repo.loadContext(request).get()).isEqualTo(expectedContext); + } + + @Test + public void loadContextHttpServletRequestWhenNotAccessedThenHttpSessionNotAccessed() { + HttpSession session = mock(HttpSession.class); + HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setSession(session); + repo.loadContext(request); + verifyNoInteractions(session); + } + @Test public void existingContextIsSuccessFullyLoadedFromSessionAndSavedBack() { HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); diff --git a/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java b/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java index 5f9b8f0c04..f169591c0d 100644 --- a/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java @@ -49,11 +49,8 @@ class SecurityContextHolderFilterTests { @Mock private HttpServletResponse response; - @Mock - private FilterChain chain; - @Captor - private ArgumentCaptor requestResponse; + private ArgumentCaptor requestArg; private SecurityContextHolderFilter filter; @@ -71,7 +68,7 @@ class SecurityContextHolderFilterTests { void doFilterThenSetsAndClearsSecurityContextHolder() throws Exception { Authentication authentication = TestAuthentication.authenticatedUser(); SecurityContext expectedContext = new SecurityContextImpl(authentication); - given(this.repository.loadContext(this.requestResponse.capture())).willReturn(expectedContext); + given(this.repository.loadContext(this.requestArg.capture())).willReturn(() -> expectedContext); FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext()) .isEqualTo(expectedContext);