@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2022 the original author or authors .
* Copyright 2002 - 2023 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -17,14 +17,20 @@
@@ -17,14 +17,20 @@
package org.springframework.security.config.annotation.web.configuration ;
import java.net.URI ;
import java.util.Arrays ;
import java.util.HashMap ;
import java.util.Map ;
import java.util.concurrent.Executors ;
import java.util.concurrent.Future ;
import java.util.concurrent.ThreadFactory ;
import jakarta.servlet.http.HttpServletRequest ;
import jakarta.servlet.http.HttpServletResponse ;
import org.junit.jupiter.api.AfterEach ;
import org.junit.jupiter.api.BeforeEach ;
import org.junit.jupiter.api.Test ;
import org.junit.jupiter.api.condition.DisabledOnJre ;
import org.junit.jupiter.api.condition.JRE ;
import org.junit.jupiter.api.extension.ExtendWith ;
import reactor.core.CoreSubscriber ;
import reactor.core.publisher.BaseSubscriber ;
@ -35,6 +41,8 @@ import reactor.util.context.Context;
@@ -35,6 +41,8 @@ import reactor.util.context.Context;
import org.springframework.context.annotation.Bean ;
import org.springframework.context.annotation.Configuration ;
import org.springframework.core.task.SimpleAsyncTaskExecutor ;
import org.springframework.core.task.VirtualThreadTaskExecutor ;
import org.springframework.http.HttpMethod ;
import org.springframework.http.HttpStatus ;
import org.springframework.mock.web.MockHttpServletRequest ;
@ -46,6 +54,7 @@ import org.springframework.security.config.annotation.web.configuration.Security
@@ -46,6 +54,7 @@ import org.springframework.security.config.annotation.web.configuration.Security
import org.springframework.security.config.test.SpringTestContext ;
import org.springframework.security.config.test.SpringTestContextExtension ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.core.context.SecurityContext ;
import org.springframework.security.core.context.SecurityContextHolder ;
import org.springframework.security.core.context.SecurityContextHolderStrategy ;
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction ;
@ -271,6 +280,58 @@ public class SecurityReactorContextConfigurationTests {
@@ -271,6 +280,58 @@ public class SecurityReactorContextConfigurationTests {
verify ( strategy , times ( 2 ) ) . getContext ( ) ;
}
@Test
public void createPublisherWhenThreadFactoryIsPlatformThenSecurityContextAttributesAvailable ( ) throws Exception {
this . spring . register ( SecurityConfig . class ) . autowire ( ) ;
ThreadFactory threadFactory = Executors . defaultThreadFactory ( ) ;
assertContextAttributesAvailable ( threadFactory ) ;
}
@Test
@DisabledOnJre ( JRE . JAVA_17 )
public void createPublisherWhenThreadFactoryIsVirtualThenSecurityContextAttributesAvailable ( ) throws Exception {
this . spring . register ( SecurityConfig . class ) . autowire ( ) ;
ThreadFactory threadFactory = new VirtualThreadTaskExecutor ( ) . getVirtualThreadFactory ( ) ;
assertContextAttributesAvailable ( threadFactory ) ;
}
private void assertContextAttributesAvailable ( ThreadFactory threadFactory ) throws Exception {
Map < Object , Object > expectedContextAttributes = new HashMap < > ( ) ;
expectedContextAttributes . put ( HttpServletRequest . class , this . servletRequest ) ;
expectedContextAttributes . put ( HttpServletResponse . class , this . servletResponse ) ;
expectedContextAttributes . put ( Authentication . class , this . authentication ) ;
try ( SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor ( threadFactory ) ) {
Future < Map < Object , Object > > future = taskExecutor . submit ( this : : propagateRequestAttributes ) ;
assertThat ( future . get ( ) ) . isEqualTo ( expectedContextAttributes ) ;
}
}
private Map < Object , Object > propagateRequestAttributes ( ) {
RequestAttributes requestAttributes = new ServletRequestAttributes ( this . servletRequest , this . servletResponse ) ;
RequestContextHolder . setRequestAttributes ( requestAttributes ) ;
SecurityContext securityContext = SecurityContextHolder . createEmptyContext ( ) ;
securityContext . setAuthentication ( this . authentication ) ;
SecurityContextHolder . setContext ( securityContext ) ;
// @formatter:off
return Mono . deferContextual ( Mono : : just )
. filter ( ( ctx ) - > ctx . hasKey ( SecurityReactorContextSubscriber . SECURITY_CONTEXT_ATTRIBUTES ) )
. map ( ( ctx ) - > ctx . < Map < Object , Object > > get ( SecurityReactorContextSubscriber . SECURITY_CONTEXT_ATTRIBUTES ) )
. map ( ( attributes ) - > {
Map < Object , Object > map = new HashMap < > ( ) ;
// Copy over items from lazily loaded map
Arrays . asList ( HttpServletRequest . class , HttpServletResponse . class , Authentication . class )
. forEach ( ( key ) - > map . put ( key , attributes . get ( key ) ) ) ;
return map ;
} )
. block ( ) ;
// @formatter:on
}
@Configuration
@EnableWebSecurity
static class SecurityConfig {