diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java index b9da1e4d02f..9eed2bf0030 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java @@ -254,8 +254,7 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle { URI uri = request.getURI(); // Copy request headers, as they might be pooled and recycled by // the server implementation once the handshake HTTP exchange is done. - HttpHeaders headers = new HttpHeaders(); - headers.addAll(request.getHeaders()); + HttpHeaders headers = HttpHeaders.copyOf(request.getHeaders()); MultiValueMap cookies = request.getCookies(); Mono principal = exchange.getPrincipal(); String logPrefix = exchange.getLogPrefix(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/StandardWebSocketUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/StandardWebSocketUpgradeStrategy.java index 5890430f220..956f7503467 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/StandardWebSocketUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/StandardWebSocketUpgradeStrategy.java @@ -97,7 +97,9 @@ public class StandardWebSocketUpgradeStrategy implements RequestUpgradeStrategy @Nullable Principal user, WebSocketHandler wsHandler, Map attrs) throws HandshakeFailureException { - HttpHeaders headers = request.getHeaders(); + // Copy request headers, as they might be pooled and recycled by + // the server implementation once the handshake HTTP exchange is done. + HttpHeaders headers = HttpHeaders.copyOf(request.getHeaders()); InetSocketAddress localAddr = null; try { localAddr = request.getLocalAddress(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java index 6c824209762..ab35f7fcefb 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java @@ -16,6 +16,7 @@ package org.springframework.web.socket; +import java.io.IOException; import java.net.URI; import java.util.ArrayList; import java.util.List; @@ -87,6 +88,26 @@ class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests { } + @ParameterizedWebSocketTest + void useHeadersAfterHandshake( + WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception { + + super.setup(server, webSocketClient, testInfo); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); + URI url = URI.create(getWsBaseUrl() + "/ws"); + WebSocketSession session = this.webSocketClient.execute(new TextWebSocketHandler(), headers, url).get(); + TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class); + serverHandler.setWaitMessageCount(1); + + session.sendMessage(new TextMessage("header")); + + session.close(); + serverHandler.await(); + assertThat(serverHandler.getReceivedMessages()).hasSize(1); + } + + @Configuration @EnableWebSocket static class TestConfig implements WebSocketConfigurer { @@ -131,7 +152,10 @@ class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests { } @Override - public void handleMessage(WebSocketSession session, WebSocketMessage message) { + public void handleMessage(WebSocketSession session, WebSocketMessage message) throws IOException { + if (message instanceof TextMessage textMessage && textMessage.getPayload().equals("header")) { + session.sendMessage(new TextMessage(session.getHandshakeHeaders().headerNames().toString())); + } this.receivedMessages.add(message); if (this.receivedMessages.size() >= this.waitMessageCount) { this.latch.countDown();