Browse Source

Ensure WebSocketHttpRequestHandler writes headers

Closes gh-23179
pull/23837/head
Rossen Stoyanchev 7 years ago
parent
commit
5af9a8edae
  1. 5
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java
  2. 51
      spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java
  3. 141
      spring-websocket/src/test/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandlerTests.java

5
spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -166,7 +166,6 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycl @@ -166,7 +166,6 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycl
}
this.handshakeHandler.doHandshake(request, response, this.wsHandler, attributes);
chain.applyAfterHandshake(request, response, null);
response.close();
}
catch (HandshakeFailureException ex) {
failure = ex;
@ -177,8 +176,10 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycl @@ -177,8 +176,10 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycl
finally {
if (failure != null) {
chain.applyAfterHandshake(request, response, failure);
response.close();
throw failure;
}
response.close();
}
}

51
spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ import org.junit.Test; @@ -26,6 +26,7 @@ import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.WebSocketExtension;
@ -62,14 +63,9 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { @@ -62,14 +63,9 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
public void supportedSubProtocols() {
this.handshakeHandler.setSupportedProtocols("stomp", "mqtt");
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
this.servletRequest.setMethod("GET");
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
headers.setUpgrade("WebSocket");
headers.setConnection("Upgrade");
headers.setSecWebSocketVersion("13");
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
headers.setSecWebSocketProtocol("STOMP");
this.servletRequest.setMethod("GET");
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("STOMP");
WebSocketHandler handler = new TextWebSocketHandler();
Map<String, Object> attributes = Collections.emptyMap();
@ -88,16 +84,10 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { @@ -88,16 +84,10 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
given(this.upgradeStrategy.getSupportedExtensions(this.request)).willReturn(Collections.singletonList(extension1));
this.servletRequest.setMethod("GET");
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
headers.setUpgrade("WebSocket");
headers.setConnection("Upgrade");
headers.setSecWebSocketVersion("13");
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
headers.setSecWebSocketExtensions(Arrays.asList(extension1, extension2));
initHeaders(this.request.getHeaders()).setSecWebSocketExtensions(Arrays.asList(extension1, extension2));
WebSocketHandler handler = new TextWebSocketHandler();
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
Map<String, Object> attributes = Collections.emptyMap();
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
@ -109,16 +99,10 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { @@ -109,16 +99,10 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
this.servletRequest.setMethod("GET");
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
headers.setUpgrade("WebSocket");
headers.setConnection("Upgrade");
headers.setSecWebSocketVersion("13");
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
headers.setSecWebSocketProtocol("v11.stomp");
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v11.stomp");
WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp");
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
Map<String, Object> attributes = Collections.emptyMap();
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
verify(this.upgradeStrategy).upgrade(this.request, this.response, "v11.stomp",
@ -130,22 +114,25 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { @@ -130,22 +114,25 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
this.servletRequest.setMethod("GET");
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
headers.setUpgrade("WebSocket");
headers.setConnection("Upgrade");
headers.setSecWebSocketVersion("13");
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
headers.setSecWebSocketProtocol("v10.stomp");
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v10.stomp");
WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp");
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
Map<String, Object> attributes = Collections.emptyMap();
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
Collections.emptyList(), null, handler, attributes);
}
private WebSocketHttpHeaders initHeaders(HttpHeaders httpHeaders) {
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(httpHeaders);
headers.setUpgrade("WebSocket");
headers.setConnection("Upgrade");
headers.setSecWebSocketVersion("13");
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
return headers;
}
private static class SubProtocolCapableHandler extends TextWebSocketHandler implements SubProtocolCapable {

141
spring-websocket/src/test/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandlerTests.java

@ -0,0 +1,141 @@ @@ -0,0 +1,141 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server.support;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import javax.servlet.ServletException;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
/**
* Unit tests for {@link WebSocketHttpRequestHandler}.
* @author Rossen Stoyanchev
* @since 5.1.9
*/
public class WebSocketHttpRequestHandlerTests {
private HandshakeHandler handshakeHandler;
private WebSocketHttpRequestHandler requestHandler;
private MockHttpServletResponse response;
@Before
public void setUp() {
this.handshakeHandler = mock(HandshakeHandler.class);
this.requestHandler = new WebSocketHttpRequestHandler(mock(WebSocketHandler.class), this.handshakeHandler);
this.response = new MockHttpServletResponse();
}
@Test
public void success() throws ServletException, IOException {
TestInterceptor interceptor = new TestInterceptor(true);
this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor));
this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response);
verify(this.handshakeHandler).doHandshake(any(), any(), any(), any());
assertEquals("headerValue", this.response.getHeader("headerName"));
}
@Test
public void failure() throws ServletException, IOException {
TestInterceptor interceptor = new TestInterceptor(true);
this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor));
when(this.handshakeHandler.doHandshake(any(), any(), any(), any()))
.thenThrow(new IllegalStateException("bad state"));
try {
this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response);
fail();
}
catch (HandshakeFailureException ex) {
assertSame(ex, interceptor.getException());
assertEquals("headerValue", this.response.getHeader("headerName"));
assertEquals("exceptionHeaderValue", this.response.getHeader("exceptionHeaderName"));
}
}
@Test // gh-23179
public void handshakeNotAllowed() throws ServletException, IOException {
TestInterceptor interceptor = new TestInterceptor(false);
this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor));
this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response);
verifyNoMoreInteractions(this.handshakeHandler);
assertEquals("headerValue", this.response.getHeader("headerName"));
}
private static class TestInterceptor implements HandshakeInterceptor {
private final boolean allowHandshake;
private Exception exception;
private TestInterceptor(boolean allowHandshake) {
this.allowHandshake = allowHandshake;
}
public Exception getException() {
return this.exception;
}
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) {
response.getHeaders().add("headerName", "headerValue");
return this.allowHandshake;
}
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Exception exception) {
response.getHeaders().add("exceptionHeaderName", "exceptionHeaderValue");
this.exception = exception;
}
}
}
Loading…
Cancel
Save