Browse Source

Copy WS handshake headers to store in session

Prior to this commit, the `StandardWebSocketUpgradeStrategy` would get
the HTTP headers from the handshake request and store them in the
WebSocket session for the entire duration of the session.
As of gh-36334, Spring MVC manages HTTP directly with a native API
instead of copying them. This improves performance but also uncovered
this bug: we cannot keep a reference to HTTP headers once the HTTP
exchange is finished, because such resources can be recycled and reused.

This commit ensures that the handshake headers are copied into the
session info to keep them around for the entire duration of the session.
Without that, Tomcat will raise an `IllegalStateException` at runtime.

This was already done for WebFlux in SPR-17250, but the latest header
management changes in Framework uncovered this issue for the Standard
WebSocket container case.

Fixes gh-36486
7.0.x
Brian Clozel 1 day ago
parent
commit
a21706c58a
  1. 3
      spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java
  2. 4
      spring-websocket/src/main/java/org/springframework/web/socket/server/standard/StandardWebSocketUpgradeStrategy.java
  3. 26
      spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java

3
spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java

@ -254,8 +254,7 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle {
URI uri = request.getURI(); URI uri = request.getURI();
// Copy request headers, as they might be pooled and recycled by // Copy request headers, as they might be pooled and recycled by
// the server implementation once the handshake HTTP exchange is done. // the server implementation once the handshake HTTP exchange is done.
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = HttpHeaders.copyOf(request.getHeaders());
headers.addAll(request.getHeaders());
MultiValueMap<String, HttpCookie> cookies = request.getCookies(); MultiValueMap<String, HttpCookie> cookies = request.getCookies();
Mono<Principal> principal = exchange.getPrincipal(); Mono<Principal> principal = exchange.getPrincipal();
String logPrefix = exchange.getLogPrefix(); String logPrefix = exchange.getLogPrefix();

4
spring-websocket/src/main/java/org/springframework/web/socket/server/standard/StandardWebSocketUpgradeStrategy.java

@ -97,7 +97,9 @@ public class StandardWebSocketUpgradeStrategy implements RequestUpgradeStrategy
@Nullable Principal user, WebSocketHandler wsHandler, Map<String, Object> attrs) @Nullable Principal user, WebSocketHandler wsHandler, Map<String, Object> attrs)
throws HandshakeFailureException { throws HandshakeFailureException {
HttpHeaders headers = request.getHeaders(); // Copy request headers, as they might be pooled and recycled by
// the server implementation once the handshake HTTP exchange is done.
HttpHeaders headers = HttpHeaders.copyOf(request.getHeaders());
InetSocketAddress localAddr = null; InetSocketAddress localAddr = null;
try { try {
localAddr = request.getLocalAddress(); localAddr = request.getLocalAddress();

26
spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java

@ -16,6 +16,7 @@
package org.springframework.web.socket; package org.springframework.web.socket;
import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -87,6 +88,26 @@ class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests {
} }
@ParameterizedWebSocketTest
void useHeadersAfterHandshake(
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
super.setup(server, webSocketClient, testInfo);
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
URI url = URI.create(getWsBaseUrl() + "/ws");
WebSocketSession session = this.webSocketClient.execute(new TextWebSocketHandler(), headers, url).get();
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
serverHandler.setWaitMessageCount(1);
session.sendMessage(new TextMessage("header"));
session.close();
serverHandler.await();
assertThat(serverHandler.getReceivedMessages()).hasSize(1);
}
@Configuration @Configuration
@EnableWebSocket @EnableWebSocket
static class TestConfig implements WebSocketConfigurer { static class TestConfig implements WebSocketConfigurer {
@ -131,7 +152,10 @@ class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests {
} }
@Override @Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) { public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws IOException {
if (message instanceof TextMessage textMessage && textMessage.getPayload().equals("header")) {
session.sendMessage(new TextMessage(session.getHandshakeHeaders().headerNames().toString()));
}
this.receivedMessages.add(message); this.receivedMessages.add(message);
if (this.receivedMessages.size() >= this.waitMessageCount) { if (this.receivedMessages.size() >= this.waitMessageCount) {
this.latch.countDown(); this.latch.countDown();

Loading…
Cancel
Save