diff --git a/core/spring-security-core.gradle b/core/spring-security-core.gradle index 7a326ed591..3eb2343e3e 100644 --- a/core/spring-security-core.gradle +++ b/core/spring-security-core.gradle @@ -16,6 +16,7 @@ dependencies { api 'io.micrometer:micrometer-observation' optional 'com.fasterxml.jackson.core:jackson-databind' + optional 'io.micrometer:context-propagation' optional 'io.projectreactor:reactor-core' optional 'jakarta.annotation:jakarta.annotation-api' optional 'org.aspectj:aspectjrt' diff --git a/core/src/main/java/org/springframework/security/core/context/ReactiveSecurityContextHolderThreadLocalAccessor.java b/core/src/main/java/org/springframework/security/core/context/ReactiveSecurityContextHolderThreadLocalAccessor.java new file mode 100644 index 0000000000..6b7953f95b --- /dev/null +++ b/core/src/main/java/org/springframework/security/core/context/ReactiveSecurityContextHolderThreadLocalAccessor.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.core.context; + +import io.micrometer.context.ThreadLocalAccessor; +import reactor.core.publisher.Mono; + +import org.springframework.util.Assert; + +/** + * A {@link ThreadLocalAccessor} for accessing a {@link SecurityContext} with the + * {@link ReactiveSecurityContextHolder}. + *

+ * This class adapts the {@link ReactiveSecurityContextHolder} to the + * {@link ThreadLocalAccessor} contract to allow Micrometer Context Propagation to + * automatically propagate a {@link SecurityContext} in Reactive applications. It is + * automatically registered with the {@link io.micrometer.context.ContextRegistry} through + * the {@link java.util.ServiceLoader} mechanism when context-propagation is on the + * classpath. + * + * @author Steve Riesenberg + * @since 6.5 + * @see io.micrometer.context.ContextRegistry + */ +public final class ReactiveSecurityContextHolderThreadLocalAccessor + implements ThreadLocalAccessor> { + + private static final ThreadLocal> threadLocal = new ThreadLocal<>(); + + @Override + public Object key() { + return SecurityContext.class; + } + + @Override + public Mono getValue() { + return threadLocal.get(); + } + + @Override + public void setValue(Mono securityContext) { + Assert.notNull(securityContext, "securityContext cannot be null"); + threadLocal.set(securityContext); + } + + @Override + public void setValue() { + threadLocal.remove(); + } + +} diff --git a/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderThreadLocalAccessor.java b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderThreadLocalAccessor.java new file mode 100644 index 0000000000..79b817d73f --- /dev/null +++ b/core/src/main/java/org/springframework/security/core/context/SecurityContextHolderThreadLocalAccessor.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.core.context; + +import io.micrometer.context.ThreadLocalAccessor; + +import org.springframework.util.Assert; + +/** + * A {@link ThreadLocalAccessor} for accessing a {@link SecurityContext} with the + * {@link SecurityContextHolder}. + *

+ * This class adapts the {@link SecurityContextHolder} to the {@link ThreadLocalAccessor} + * contract to allow Micrometer Context Propagation to automatically propagate a + * {@link SecurityContext} in Servlet applications. It is automatically registered with + * the {@link io.micrometer.context.ContextRegistry} through the + * {@link java.util.ServiceLoader} mechanism when context-propagation is on the classpath. + * + * @author Steve Riesenberg + * @since 6.5 + * @see io.micrometer.context.ContextRegistry + */ +public final class SecurityContextHolderThreadLocalAccessor implements ThreadLocalAccessor { + + @Override + public Object key() { + return SecurityContext.class.getName(); + } + + @Override + public SecurityContext getValue() { + SecurityContext securityContext = SecurityContextHolder.getContext(); + SecurityContext emptyContext = SecurityContextHolder.createEmptyContext(); + + return !securityContext.equals(emptyContext) ? securityContext : null; + } + + @Override + public void setValue(SecurityContext securityContext) { + Assert.notNull(securityContext, "securityContext cannot be null"); + SecurityContextHolder.setContext(securityContext); + } + + @Override + public void setValue() { + SecurityContextHolder.clearContext(); + } + +} diff --git a/core/src/main/resources/META-INF/services/io.micrometer.context.ThreadLocalAccessor b/core/src/main/resources/META-INF/services/io.micrometer.context.ThreadLocalAccessor new file mode 100644 index 0000000000..65b406e6ae --- /dev/null +++ b/core/src/main/resources/META-INF/services/io.micrometer.context.ThreadLocalAccessor @@ -0,0 +1,2 @@ +org.springframework.security.core.context.ReactiveSecurityContextHolderThreadLocalAccessor +org.springframework.security.core.context.SecurityContextHolderThreadLocalAccessor diff --git a/core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderThreadLocalAccessorTests.java b/core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderThreadLocalAccessorTests.java new file mode 100644 index 0000000000..12d5a6384d --- /dev/null +++ b/core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderThreadLocalAccessorTests.java @@ -0,0 +1,123 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.core.context; + +import java.util.concurrent.CountDownLatch; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.security.authentication.TestingAuthenticationToken; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link ReactiveSecurityContextHolderThreadLocalAccessor}. + * + * @author Steve Riesenberg + */ +public class ReactiveSecurityContextHolderThreadLocalAccessorTests { + + private ReactiveSecurityContextHolderThreadLocalAccessor threadLocalAccessor; + + @BeforeEach + public void setUp() { + this.threadLocalAccessor = new ReactiveSecurityContextHolderThreadLocalAccessor(); + } + + @AfterEach + public void tearDown() { + this.threadLocalAccessor.setValue(); + } + + @Test + public void keyAlwaysReturnsSecurityContextClass() { + assertThat(this.threadLocalAccessor.key()).isEqualTo(SecurityContext.class); + } + + @Test + public void getValueWhenThreadLocalNotSetThenReturnsNull() { + assertThat(this.threadLocalAccessor.getValue()).isNull(); + } + + @Test + public void getValueWhenThreadLocalSetThenReturnsSecurityContextMono() { + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(new TestingAuthenticationToken("user", "password")); + Mono mono = Mono.just(securityContext); + this.threadLocalAccessor.setValue(mono); + + assertThat(this.threadLocalAccessor.getValue()).isSameAs(mono); + } + + @Test + public void getValueWhenThreadLocalSetOnAnotherThreadThenReturnsNull() throws InterruptedException { + CountDownLatch threadLocalSet = new CountDownLatch(1); + CountDownLatch threadLocalRead = new CountDownLatch(1); + CountDownLatch threadLocalCleared = new CountDownLatch(1); + + Runnable task = () -> { + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(new TestingAuthenticationToken("user", "password")); + Mono mono = Mono.just(securityContext); + this.threadLocalAccessor.setValue(mono); + threadLocalSet.countDown(); + try { + threadLocalRead.await(); + } + catch (InterruptedException ignored) { + } + finally { + this.threadLocalAccessor.setValue(); + threadLocalCleared.countDown(); + } + }; + try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor()) { + taskExecutor.execute(task); + threadLocalSet.await(); + assertThat(this.threadLocalAccessor.getValue()).isNull(); + threadLocalRead.countDown(); + threadLocalCleared.await(); + } + } + + @Test + public void setValueWhenNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.threadLocalAccessor.setValue(null)) + .withMessage("securityContext cannot be null"); + // @formatter:on + } + + @Test + public void setValueWhenThreadLocalSetThenClearsThreadLocal() { + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(new TestingAuthenticationToken("user", "password")); + Mono mono = Mono.just(securityContext); + this.threadLocalAccessor.setValue(mono); + assertThat(this.threadLocalAccessor.getValue()).isSameAs(mono); + + this.threadLocalAccessor.setValue(); + assertThat(this.threadLocalAccessor.getValue()).isNull(); + } + +} diff --git a/core/src/test/java/org/springframework/security/core/context/SecurityContextHolderThreadLocalAccessorTests.java b/core/src/test/java/org/springframework/security/core/context/SecurityContextHolderThreadLocalAccessorTests.java new file mode 100644 index 0000000000..c3a6a6bb93 --- /dev/null +++ b/core/src/test/java/org/springframework/security/core/context/SecurityContextHolderThreadLocalAccessorTests.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.core.context; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.security.authentication.TestingAuthenticationToken; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link SecurityContextHolderThreadLocalAccessor}. + * + * @author Steve Riesenberg + */ +public class SecurityContextHolderThreadLocalAccessorTests { + + private SecurityContextHolderThreadLocalAccessor threadLocalAccessor; + + @BeforeEach + public void setUp() { + this.threadLocalAccessor = new SecurityContextHolderThreadLocalAccessor(); + } + + @AfterEach + public void tearDown() { + this.threadLocalAccessor.setValue(); + } + + @Test + public void keyAlwaysReturnsSecurityContextClassName() { + assertThat(this.threadLocalAccessor.key()).isEqualTo(SecurityContext.class.getName()); + } + + @Test + public void getValueWhenSecurityContextHolderNotSetThenReturnsNull() { + assertThat(this.threadLocalAccessor.getValue()).isNull(); + } + + @Test + public void getValueWhenSecurityContextHolderSetThenReturnsSecurityContext() { + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(new TestingAuthenticationToken("user", "password")); + SecurityContextHolder.setContext(securityContext); + assertThat(this.threadLocalAccessor.getValue()).isSameAs(securityContext); + } + + @Test + public void setValueWhenSecurityContextThenSetsSecurityContextHolder() { + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(new TestingAuthenticationToken("user", "password")); + this.threadLocalAccessor.setValue(securityContext); + assertThat(SecurityContextHolder.getContext()).isSameAs(securityContext); + } + + @Test + public void setValueWhenNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.threadLocalAccessor.setValue(null)) + .withMessage("securityContext cannot be null"); + // @formatter:on + } + + @Test + public void setValueWhenSecurityContextSetThenClearsSecurityContextHolder() { + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(new TestingAuthenticationToken("user", "password")); + SecurityContextHolder.setContext(securityContext); + this.threadLocalAccessor.setValue(); + + SecurityContext emptyContext = SecurityContextHolder.createEmptyContext(); + assertThat(SecurityContextHolder.getContext()).isEqualTo(emptyContext); + } + +} diff --git a/dependencies/spring-security-dependencies.gradle b/dependencies/spring-security-dependencies.gradle index 7d69e3ec13..2a2d433252 100644 --- a/dependencies/spring-security-dependencies.gradle +++ b/dependencies/spring-security-dependencies.gradle @@ -35,6 +35,7 @@ dependencies { api libs.com.unboundid.unboundid.ldapsdk api libs.commons.collections api libs.io.mockk + api libs.io.micrometer.context.propagation api libs.io.micrometer.micrometer.observation api libs.jakarta.annotation.jakarta.annotation.api api libs.jakarta.inject.jakarta.inject.api diff --git a/docs/modules/ROOT/pages/whats-new.adoc b/docs/modules/ROOT/pages/whats-new.adoc index ce39420774..03368827eb 100644 --- a/docs/modules/ROOT/pages/whats-new.adoc +++ b/docs/modules/ROOT/pages/whats-new.adoc @@ -4,6 +4,10 @@ Spring Security 6.5 provides a number of new features. Below are the highlights of the release, or you can view https://github.com/spring-projects/spring-security/releases[the release notes] for a detailed listing of each feature and bug fix. +== New Features + +* Support for automatic context-propagation with Micrometer (https://github.com/spring-projects/spring-security/issues/16665[gh-16665]) + == Breaking Changes === Observability diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index c189bc8393..63fe86fa4d 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -28,6 +28,7 @@ com-squareup-okhttp3-okhttp = { module = "com.squareup.okhttp3:okhttp", version. com-unboundid-unboundid-ldapsdk = "com.unboundid:unboundid-ldapsdk:6.0.11" com-unboundid-unboundid-ldapsdk7 = "com.unboundid:unboundid-ldapsdk:7.0.1" commons-collections = "commons-collections:commons-collections:3.2.2" +io-micrometer-context-propagation = "io.micrometer:context-propagation:1.1.2" io-micrometer-micrometer-observation = "io.micrometer:micrometer-observation:1.14.5" io-mockk = "io.mockk:mockk:1.13.17" io-projectreactor-reactor-bom = "io.projectreactor:reactor-bom:2023.0.16" diff --git a/oauth2/oauth2-client/spring-security-oauth2-client.gradle b/oauth2/oauth2-client/spring-security-oauth2-client.gradle index 9231fc2c22..35806dfebf 100644 --- a/oauth2/oauth2-client/spring-security-oauth2-client.gradle +++ b/oauth2/oauth2-client/spring-security-oauth2-client.gradle @@ -19,6 +19,7 @@ dependencies { testImplementation project(path: ':spring-security-oauth2-core', configuration: 'tests') testImplementation project(path: ':spring-security-oauth2-jose', configuration: 'tests') testImplementation 'com.squareup.okhttp3:mockwebserver' + testImplementation 'io.micrometer:context-propagation' testImplementation 'io.projectreactor.netty:reactor-netty' testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.skyscreamer:jsonassert' diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java index 4bf76c6c2d..46a7c102fb 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java @@ -19,15 +19,22 @@ package org.springframework.security.oauth2.client.web; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; import java.util.function.Function; +import io.micrometer.context.ContextExecutorService; +import io.micrometer.context.ContextSnapshotFactory; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.test.publisher.PublisherProbe; import reactor.util.context.Context; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.core.task.support.ExecutorServiceAdapter; import org.springframework.http.MediaType; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; @@ -565,6 +572,41 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { assertThat(requestScopeAttribute).contains("read", "write"); } + @Test + public void authorizeWhenBlockingExecutionAndContextPropagationEnabledThenContextPropagated() + throws InterruptedException { + Hooks.enableAutomaticContextPropagation(); + given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) + .willReturn(Mono.just(this.clientRegistration)); + given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(this.authorizedClient)); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest + .withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + CountDownLatch countDownLatch = new CountDownLatch(1); + Runnable task = () -> { + try { + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest) + .block(); + assertThat(authorizedClient).isSameAs(this.authorizedClient); + } + finally { + countDownLatch.countDown(); + } + }; + try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor()) { + ContextSnapshotFactory contextSnapshotFactory = ContextSnapshotFactory.builder().build(); + ExecutorService executorService = ContextExecutorService.wrap(new ExecutorServiceAdapter(taskExecutor), + contextSnapshotFactory); + Mono.fromRunnable(() -> executorService.execute(task)).contextWrite(this.context).block(); + } + + countDownLatch.await(); + verify(this.authorizedClientProvider).authorize(any(OAuth2AuthorizationContext.class)); + } + private Mono currentServerWebExchange() { return Mono.deferContextual(Mono::just) .filter((c) -> c.hasKey(ServerWebExchange.class)) diff --git a/web/spring-security-web.gradle b/web/spring-security-web.gradle index a709c95a99..1a87fac63e 100644 --- a/web/spring-security-web.gradle +++ b/web/spring-security-web.gradle @@ -36,6 +36,7 @@ dependencies { api 'org.springframework:spring-web' optional 'com.fasterxml.jackson.core:jackson-databind' + optional 'io.micrometer:context-propagation' optional 'io.projectreactor:reactor-core' optional 'org.springframework:spring-jdbc' optional 'org.springframework:spring-tx' diff --git a/web/src/main/java/org/springframework/security/web/server/ServerWebExchangeThreadLocalAccessor.java b/web/src/main/java/org/springframework/security/web/server/ServerWebExchangeThreadLocalAccessor.java new file mode 100644 index 0000000000..b943ecaad4 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/ServerWebExchangeThreadLocalAccessor.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server; + +import io.micrometer.context.ThreadLocalAccessor; + +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + +/** + * A {@link ThreadLocalAccessor} for accessing a {@link ServerWebExchange}. + *

+ * This class adapts the existing Reactor Context attribute + * {@code ServerWebExchange.class} to the {@link ThreadLocalAccessor} contract to allow + * Micrometer Context Propagation to automatically propagate a {@link ServerWebExchange} + * in Reactive applications. It is automatically registered with the + * {@link io.micrometer.context.ContextRegistry} through the + * {@link java.util.ServiceLoader} mechanism when context-propagation is on the classpath. + * + * @author Steve Riesenberg + * @since 6.5 + * @see io.micrometer.context.ContextRegistry + */ +public final class ServerWebExchangeThreadLocalAccessor implements ThreadLocalAccessor { + + private static final ThreadLocal threadLocal = new ThreadLocal<>(); + + @Override + public Object key() { + return ServerWebExchange.class; + } + + @Override + public ServerWebExchange getValue() { + return threadLocal.get(); + } + + @Override + public void setValue(ServerWebExchange exchange) { + Assert.notNull(exchange, "exchange cannot be null"); + threadLocal.set(exchange); + } + + @Override + public void setValue() { + threadLocal.remove(); + } + +} diff --git a/web/src/main/resources/META-INF/services/io.micrometer.context.ThreadLocalAccessor b/web/src/main/resources/META-INF/services/io.micrometer.context.ThreadLocalAccessor new file mode 100644 index 0000000000..63959839f8 --- /dev/null +++ b/web/src/main/resources/META-INF/services/io.micrometer.context.ThreadLocalAccessor @@ -0,0 +1 @@ +org.springframework.security.web.server.ServerWebExchangeThreadLocalAccessor diff --git a/web/src/test/java/org/springframework/security/web/ServerWebExchangeThreadLocalAccessorTests.java b/web/src/test/java/org/springframework/security/web/ServerWebExchangeThreadLocalAccessorTests.java new file mode 100644 index 0000000000..d5bea2890b --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/ServerWebExchangeThreadLocalAccessorTests.java @@ -0,0 +1,118 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web; + +import java.util.concurrent.CountDownLatch; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.web.server.ServerWebExchangeThreadLocalAccessor; +import org.springframework.web.server.ServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link ServerWebExchangeThreadLocalAccessor}. + * + * @author Steve Riesenberg + */ +public class ServerWebExchangeThreadLocalAccessorTests { + + private ServerWebExchangeThreadLocalAccessor threadLocalAccessor; + + private ServerWebExchange exchange; + + @BeforeEach + public void setUp() { + this.threadLocalAccessor = new ServerWebExchangeThreadLocalAccessor(); + this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build(); + } + + @AfterEach + public void tearDown() { + this.threadLocalAccessor.setValue(); + } + + @Test + public void keyAlwaysReturnsServerWebExchangeClass() { + assertThat(this.threadLocalAccessor.key()).isEqualTo(ServerWebExchange.class); + } + + @Test + public void getValueWhenThreadLocalNotSetThenReturnsNull() { + assertThat(this.threadLocalAccessor.getValue()).isNull(); + } + + @Test + public void getValueWhenThreadLocalSetThenReturnsServerWebExchange() { + this.threadLocalAccessor.setValue(this.exchange); + assertThat(this.threadLocalAccessor.getValue()).isSameAs(this.exchange); + } + + @Test + public void getValueWhenThreadLocalSetOnAnotherThreadThenReturnsNull() throws InterruptedException { + CountDownLatch threadLocalSet = new CountDownLatch(1); + CountDownLatch threadLocalRead = new CountDownLatch(1); + CountDownLatch threadLocalCleared = new CountDownLatch(1); + + Runnable task = () -> { + this.threadLocalAccessor.setValue(this.exchange); + threadLocalSet.countDown(); + try { + threadLocalRead.await(); + } + catch (InterruptedException ignored) { + } + finally { + this.threadLocalAccessor.setValue(); + threadLocalCleared.countDown(); + } + }; + try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor()) { + taskExecutor.execute(task); + threadLocalSet.await(); + assertThat(this.threadLocalAccessor.getValue()).isNull(); + threadLocalRead.countDown(); + threadLocalCleared.await(); + } + } + + @Test + public void setValueWhenNullThenThrowsIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.threadLocalAccessor.setValue(null)) + .withMessage("exchange cannot be null"); + // @formatter:on + } + + @Test + public void setValueWhenThreadLocalSetThenClearsThreadLocal() { + this.threadLocalAccessor.setValue(this.exchange); + assertThat(this.threadLocalAccessor.getValue()).isSameAs(this.exchange); + + this.threadLocalAccessor.setValue(); + assertThat(this.threadLocalAccessor.getValue()).isNull(); + } + +}