diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/websocket/servlet/WebSocketServletAutoConfiguration.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/websocket/servlet/WebSocketServletAutoConfiguration.java index 781f61309f7..48d3379ef3f 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/websocket/servlet/WebSocketServletAutoConfiguration.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/websocket/servlet/WebSocketServletAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2022 the original author or authors. + * Copyright 2012-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,15 @@ package org.springframework.boot.autoconfigure.websocket.servlet; +import java.util.List; + +import jakarta.servlet.DispatcherType; import jakarta.servlet.Servlet; import jakarta.websocket.server.ServerContainer; import org.apache.catalina.startup.Tomcat; import org.apache.tomcat.websocket.server.WsSci; import org.eclipse.jetty.websocket.jakarta.server.config.JakartaWebSocketServletContainerInitializer; +import org.eclipse.jetty.websocket.servlet.WebSocketUpgradeFilter; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -28,8 +32,10 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication.Type; import org.springframework.boot.autoconfigure.web.servlet.ServletWebServerFactoryAutoConfiguration; +import org.springframework.boot.web.servlet.FilterRegistrationBean; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.core.Ordered; /** * Auto configuration for WebSocket servlet server in embedded Tomcat, Jetty or Undertow. @@ -79,6 +85,20 @@ public class WebSocketServletAutoConfiguration { return new JettyWebSocketServletWebServerCustomizer(); } + @Bean + @ConditionalOnMissingBean(value = WebSocketUpgradeFilter.class, + parameterizedContainer = FilterRegistrationBean.class) + FilterRegistrationBean webSocketUpgradeFilter() { + WebSocketUpgradeFilter websocketFilter = new WebSocketUpgradeFilter(); + FilterRegistrationBean registration = new FilterRegistrationBean<>(websocketFilter); + registration.setAsyncSupported(true); + registration.setDispatcherTypes(DispatcherType.REQUEST); + registration.setName(WebSocketUpgradeFilter.class.getName()); + registration.setOrder(Ordered.LOWEST_PRECEDENCE); + registration.setUrlPatterns(List.of("/*")); + return registration; + } + } @Configuration(proxyBeanMethods = false) diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/websocket/servlet/WebSocketServletAutoConfigurationTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/websocket/servlet/WebSocketServletAutoConfigurationTests.java index d4eb88131c3..423c89df1d5 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/websocket/servlet/WebSocketServletAutoConfigurationTests.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/websocket/servlet/WebSocketServletAutoConfigurationTests.java @@ -16,23 +16,46 @@ package org.springframework.boot.autoconfigure.websocket.servlet; +import java.io.IOException; +import java.util.Map; import java.util.stream.Stream; +import jakarta.servlet.DispatcherType; +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.websocket.DeploymentException; import jakarta.websocket.server.ServerContainer; +import jakarta.websocket.server.ServerEndpoint; +import org.eclipse.jetty.websocket.servlet.WebSocketUpgradeFilter; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.servlet.DispatcherServletAutoConfiguration; +import org.springframework.boot.test.context.runner.WebApplicationContextRunner; +import org.springframework.boot.test.web.client.TestRestTemplate; import org.springframework.boot.testsupport.classpath.ForkedClassPath; import org.springframework.boot.testsupport.web.servlet.DirtiesUrlFactories; import org.springframework.boot.testsupport.web.servlet.Servlet5ClassPathOverrides; import org.springframework.boot.web.embedded.jetty.JettyServletWebServerFactory; import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactory; +import org.springframework.boot.web.server.WebServer; import org.springframework.boot.web.server.WebServerFactoryCustomizerBeanPostProcessor; +import org.springframework.boot.web.servlet.FilterRegistrationBean; import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext; import org.springframework.boot.web.servlet.server.ServletWebServerFactory; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.core.Ordered; +import org.springframework.http.HttpStatus; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; import static org.assertj.core.api.Assertions.assertThat; @@ -56,18 +79,90 @@ class WebSocketServletAutoConfigurationTests { } } + @ParameterizedTest(name = "{0}") + @MethodSource("testConfiguration") + @ForkedClassPath + void webSocketUpgradeDoesNotPreventAFilterFromRejectingTheRequest(String server, Class... configuration) + throws DeploymentException { + try (AnnotationConfigServletWebServerApplicationContext context = new AnnotationConfigServletWebServerApplicationContext( + configuration)) { + ServerContainer serverContainer = (ServerContainer) context.getServletContext() + .getAttribute("jakarta.websocket.server.ServerContainer"); + serverContainer.addEndpoint(TestEndpoint.class); + WebServer webServer = context.getWebServer(); + int port = webServer.getPort(); + TestRestTemplate rest = new TestRestTemplate(); + RequestEntity request = RequestEntity.get("http://localhost:" + port) + .header("Upgrade", "websocket") + .header("Connection", "upgrade") + .header("Sec-WebSocket-Version", "13") + .header("Sec-WebSocket-Key", "key") + .build(); + ResponseEntity response = rest.exchange(request, Void.class); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); + } + } + + @Test + @SuppressWarnings("rawtypes") + void whenCustomUpgradeFilterRegistrationIsDefinedAutoConfiguredRegistrationOfJettyUpgradeFilterBacksOff() { + new WebApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(JettyConfiguration.class, + WebSocketServletAutoConfiguration.JettyWebSocketConfiguration.class)) + .withUserConfiguration(CustomUpgradeFilterRegistrationConfiguration.class) + .run((context) -> { + Map filterRegistrations = context + .getBeansOfType(FilterRegistrationBean.class); + assertThat(filterRegistrations).containsOnlyKeys("unauthorizedFilter", + "customUpgradeFilterRegistration"); + }); + } + + @Test + @SuppressWarnings("rawtypes") + void whenCustomUpgradeFilterIsDefinedAutoConfiguredRegistrationOfJettyUpgradeFilterBacksOff() { + new WebApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(JettyConfiguration.class, + WebSocketServletAutoConfiguration.JettyWebSocketConfiguration.class)) + .withUserConfiguration(CustomUpgradeFilterConfiguration.class) + .run((context) -> { + Map filterRegistrations = context + .getBeansOfType(FilterRegistrationBean.class); + assertThat(filterRegistrations).containsOnlyKeys("unauthorizedFilter"); + }); + } + static Stream testConfiguration() { + String response = "Tomcat"; return Stream.of( Arguments.of("Jetty", - new Class[] { JettyConfiguration.class, + new Class[] { JettyConfiguration.class, DispatcherServletAutoConfiguration.class, WebSocketServletAutoConfiguration.JettyWebSocketConfiguration.class }), - Arguments.of("Tomcat", new Class[] { TomcatConfiguration.class, - WebSocketServletAutoConfiguration.TomcatWebSocketConfiguration.class })); + Arguments.of(response, + new Class[] { TomcatConfiguration.class, DispatcherServletAutoConfiguration.class, + WebSocketServletAutoConfiguration.TomcatWebSocketConfiguration.class })); } @Configuration(proxyBeanMethods = false) static class CommonConfiguration { + @Bean + FilterRegistrationBean unauthorizedFilter() { + FilterRegistrationBean registration = new FilterRegistrationBean<>(new Filter() { + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + ((HttpServletResponse) response).sendError(HttpStatus.UNAUTHORIZED.value()); + } + + }); + registration.setOrder(Ordered.HIGHEST_PRECEDENCE); + registration.addUrlPatterns("/*"); + registration.setDispatcherTypes(DispatcherType.REQUEST); + return registration; + } + @Bean WebServerFactoryCustomizerBeanPostProcessor ServletWebServerCustomizerBeanPostProcessor() { return new WebServerFactoryCustomizerBeanPostProcessor(); @@ -100,4 +195,31 @@ class WebSocketServletAutoConfigurationTests { } + @Configuration(proxyBeanMethods = false) + static class CustomUpgradeFilterRegistrationConfiguration { + + @Bean + FilterRegistrationBean customUpgradeFilterRegistration() { + FilterRegistrationBean registration = new FilterRegistrationBean<>( + new WebSocketUpgradeFilter()); + return registration; + } + + } + + @Configuration(proxyBeanMethods = false) + static class CustomUpgradeFilterConfiguration { + + @Bean + WebSocketUpgradeFilter customUpgradeFilter() { + return new WebSocketUpgradeFilter(); + } + + } + + @ServerEndpoint("/") + public static class TestEndpoint { + + } + }