@ -40,6 +40,7 @@ import java.util.Collection;
@@ -40,6 +40,7 @@ import java.util.Collection;
import java.util.Date ;
import java.util.EnumSet ;
import java.util.HashMap ;
import java.util.HashSet ;
import java.util.LinkedHashMap ;
import java.util.List ;
import java.util.Locale ;
@ -65,6 +66,7 @@ import jakarta.servlet.Filter;
@@ -65,6 +66,7 @@ import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain ;
import jakarta.servlet.FilterConfig ;
import jakarta.servlet.GenericServlet ;
import jakarta.servlet.ServletConfig ;
import jakarta.servlet.ServletContext ;
import jakarta.servlet.ServletContextEvent ;
import jakarta.servlet.ServletContextListener ;
@ -1366,6 +1368,26 @@ public abstract class AbstractServletWebServerFactoryTests {
@@ -1366,6 +1368,26 @@ public abstract class AbstractServletWebServerFactoryTests {
+ " \\(http(/1.1)?\\), [0-9]+ \\(http(/1.1)?\\) with context path '/'" ) ;
}
@Test
void servletComponentsAreInitializedWithTheSameThreadContextClassLoader ( ) {
AbstractServletWebServerFactory factory = getFactory ( ) ;
ThreadContextClassLoaderCapturingServlet servlet = new ThreadContextClassLoaderCapturingServlet ( ) ;
ThreadContextClassLoaderCapturingFilter filter = new ThreadContextClassLoaderCapturingFilter ( ) ;
ThreadContextClassLoaderCapturingListener listener = new ThreadContextClassLoaderCapturingListener ( ) ;
this . webServer = factory . getWebServer ( ( context ) - > {
context . addServlet ( "tcclCapturingServlet" , servlet ) . setLoadOnStartup ( 0 ) ;
context . addFilter ( "tcclCapturingFilter" , filter ) ;
context . addListener ( listener ) ;
} ) ;
this . webServer . start ( ) ;
assertThat ( servlet . contextClassLoader ) . isNotNull ( ) ;
assertThat ( filter . contextClassLoader ) . isNotNull ( ) ;
assertThat ( listener . contextClassLoader ) . isNotNull ( ) ;
assertThat ( new HashSet < > (
Arrays . asList ( servlet . contextClassLoader , filter . contextClassLoader , listener . contextClassLoader ) ) )
. hasSize ( 1 ) ;
}
protected Future < Object > initiateGetRequest ( int port , String path ) {
return initiateGetRequest ( HttpClients . createMinimal ( ) , port , path ) ;
}
@ -1822,4 +1844,43 @@ public abstract class AbstractServletWebServerFactoryTests {
@@ -1822,4 +1844,43 @@ public abstract class AbstractServletWebServerFactoryTests {
}
static class ThreadContextClassLoaderCapturingServlet extends HttpServlet {
private ClassLoader contextClassLoader ;
@Override
public void init ( ServletConfig config ) throws ServletException {
this . contextClassLoader = Thread . currentThread ( ) . getContextClassLoader ( ) ;
}
}
static class ThreadContextClassLoaderCapturingListener implements ServletContextListener {
private ClassLoader contextClassLoader ;
@Override
public void contextInitialized ( ServletContextEvent sce ) {
this . contextClassLoader = Thread . currentThread ( ) . getContextClassLoader ( ) ;
}
}
static class ThreadContextClassLoaderCapturingFilter implements Filter {
private ClassLoader contextClassLoader ;
@Override
public void init ( FilterConfig filterConfig ) throws ServletException {
this . contextClassLoader = Thread . currentThread ( ) . getContextClassLoader ( ) ;
}
@Override
public void doFilter ( ServletRequest request , ServletResponse response , FilterChain chain )
throws IOException , ServletException {
chain . doFilter ( request , response ) ;
}
}
}