diff --git a/data/src/main/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtension.java b/data/src/main/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtension.java index a4ec7b00d2..ba1a938b48 100644 --- a/data/src/main/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtension.java +++ b/data/src/main/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtension.java @@ -27,6 +27,7 @@ import org.springframework.security.authentication.AuthenticationTrustResolverIm import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.util.Assert; /** @@ -89,6 +90,9 @@ import org.springframework.util.Assert; */ public class SecurityEvaluationContextExtension implements EvaluationContextExtension { + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private Authentication authentication; private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl(); @@ -131,11 +135,22 @@ public class SecurityEvaluationContextExtension implements EvaluationContextExte return root; } + /** + * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use + * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}. + * + * @since 5.8 + */ + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null"); + this.securityContextHolderStrategy = securityContextHolderStrategy; + } + private Authentication getAuthentication() { if (this.authentication != null) { return this.authentication; } - SecurityContext context = SecurityContextHolder.getContext(); + SecurityContext context = this.securityContextHolderStrategy.getContext(); return context.getAuthentication(); } diff --git a/data/src/test/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtensionTests.java b/data/src/test/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtensionTests.java index 293c114fb1..0480a5f6c7 100644 --- a/data/src/test/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtensionTests.java +++ b/data/src/test/java/org/springframework/security/data/repository/query/SecurityEvaluationContextExtensionTests.java @@ -29,9 +29,14 @@ import org.springframework.security.authentication.AuthenticationTrustResolver; import org.springframework.security.authentication.AuthenticationTrustResolverImpl; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; public class SecurityEvaluationContextExtensionTests { @@ -59,6 +64,16 @@ public class SecurityEvaluationContextExtensionTests { assertThat(getRoot().getAuthentication()).isSameAs(authentication); } + @Test + public void getRootObjectUseSecurityContextHolderStrategy() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password", "ROLE_USER"); + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.getContext()).willReturn(new SecurityContextImpl(authentication)); + this.securityExtension.setSecurityContextHolderStrategy(strategy); + assertThat(getRoot().getAuthentication()).isSameAs(authentication); + verify(strategy).getContext(); + } + @Test public void getRootObjectExplicitAuthenticationOverridesSecurityContextHolder() { TestingAuthenticationToken explicit = new TestingAuthenticationToken("explicit", "password", "ROLE_EXPLICIT");