@ -16,6 +16,7 @@
@@ -16,6 +16,7 @@
package org.springframework.web.socket ;
import java.io.IOException ;
import java.net.URI ;
import java.util.ArrayList ;
import java.util.List ;
@ -87,6 +88,26 @@ class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests {
@@ -87,6 +88,26 @@ class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests {
}
@ParameterizedWebSocketTest
void useHeadersAfterHandshake (
WebSocketTestServer server , WebSocketClient webSocketClient , TestInfo testInfo ) throws Exception {
super . setup ( server , webSocketClient , testInfo ) ;
WebSocketHttpHeaders headers = new WebSocketHttpHeaders ( ) ;
URI url = URI . create ( getWsBaseUrl ( ) + "/ws" ) ;
WebSocketSession session = this . webSocketClient . execute ( new TextWebSocketHandler ( ) , headers , url ) . get ( ) ;
TestWebSocketHandler serverHandler = this . wac . getBean ( TestWebSocketHandler . class ) ;
serverHandler . setWaitMessageCount ( 1 ) ;
session . sendMessage ( new TextMessage ( "header" ) ) ;
session . close ( ) ;
serverHandler . await ( ) ;
assertThat ( serverHandler . getReceivedMessages ( ) ) . hasSize ( 1 ) ;
}
@Configuration
@EnableWebSocket
static class TestConfig implements WebSocketConfigurer {
@ -131,7 +152,10 @@ class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests {
@@ -131,7 +152,10 @@ class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests {
}
@Override
public void handleMessage ( WebSocketSession session , WebSocketMessage < ? > message ) {
public void handleMessage ( WebSocketSession session , WebSocketMessage < ? > message ) throws IOException {
if ( message instanceof TextMessage textMessage & & textMessage . getPayload ( ) . equals ( "header" ) ) {
session . sendMessage ( new TextMessage ( session . getHandshakeHeaders ( ) . headerNames ( ) . toString ( ) ) ) ;
}
this . receivedMessages . add ( message ) ;
if ( this . receivedMessages . size ( ) > = this . waitMessageCount ) {
this . latch . countDown ( ) ;