Browse Source

Merge branch '7.0.x'

pull/36489/head
Brian Clozel 1 week ago
parent
commit
4637da1f36
  1. 3
      spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java
  2. 4
      spring-websocket/src/main/java/org/springframework/web/socket/server/standard/StandardWebSocketUpgradeStrategy.java
  3. 26
      spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java

3
spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java

@ -254,8 +254,7 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle {
URI uri = request.getURI(); URI uri = request.getURI();
// Copy request headers, as they might be pooled and recycled by // Copy request headers, as they might be pooled and recycled by
// the server implementation once the handshake HTTP exchange is done. // the server implementation once the handshake HTTP exchange is done.
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = HttpHeaders.copyOf(request.getHeaders());
headers.addAll(request.getHeaders());
MultiValueMap<String, HttpCookie> cookies = request.getCookies(); MultiValueMap<String, HttpCookie> cookies = request.getCookies();
Mono<Principal> principal = exchange.getPrincipal(); Mono<Principal> principal = exchange.getPrincipal();
String logPrefix = exchange.getLogPrefix(); String logPrefix = exchange.getLogPrefix();

4
spring-websocket/src/main/java/org/springframework/web/socket/server/standard/StandardWebSocketUpgradeStrategy.java

@ -97,7 +97,9 @@ public class StandardWebSocketUpgradeStrategy implements RequestUpgradeStrategy
@Nullable Principal user, WebSocketHandler wsHandler, Map<String, Object> attrs) @Nullable Principal user, WebSocketHandler wsHandler, Map<String, Object> attrs)
throws HandshakeFailureException { throws HandshakeFailureException {
HttpHeaders headers = request.getHeaders(); // Copy request headers, as they might be pooled and recycled by
// the server implementation once the handshake HTTP exchange is done.
HttpHeaders headers = HttpHeaders.copyOf(request.getHeaders());
InetSocketAddress localAddr = null; InetSocketAddress localAddr = null;
try { try {
localAddr = request.getLocalAddress(); localAddr = request.getLocalAddress();

26
spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java

@ -16,6 +16,7 @@
package org.springframework.web.socket; package org.springframework.web.socket;
import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -87,6 +88,26 @@ class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests {
} }
@ParameterizedWebSocketTest
void useHeadersAfterHandshake(
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
super.setup(server, webSocketClient, testInfo);
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
URI url = URI.create(getWsBaseUrl() + "/ws");
WebSocketSession session = this.webSocketClient.execute(new TextWebSocketHandler(), headers, url).get();
TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
serverHandler.setWaitMessageCount(1);
session.sendMessage(new TextMessage("header"));
session.close();
serverHandler.await();
assertThat(serverHandler.getReceivedMessages()).hasSize(1);
}
@Configuration @Configuration
@EnableWebSocket @EnableWebSocket
static class TestConfig implements WebSocketConfigurer { static class TestConfig implements WebSocketConfigurer {
@ -131,7 +152,10 @@ class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests {
} }
@Override @Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) { public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws IOException {
if (message instanceof TextMessage textMessage && textMessage.getPayload().equals("header")) {
session.sendMessage(new TextMessage(session.getHandshakeHeaders().headerNames().toString()));
}
this.receivedMessages.add(message); this.receivedMessages.add(message);
if (this.receivedMessages.size() >= this.waitMessageCount) { if (this.receivedMessages.size() >= this.waitMessageCount) {
this.latch.countDown(); this.latch.countDown();

Loading…
Cancel
Save