From 5af9a8edae06b90ba8b466bc07be286b2cc7ab2f Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 3 Jul 2019 15:24:12 +0100 Subject: [PATCH] Ensure WebSocketHttpRequestHandler writes headers Closes gh-23179 --- .../support/WebSocketHttpRequestHandler.java | 5 +- .../server/DefaultHandshakeHandlerTests.java | 51 +++---- .../WebSocketHttpRequestHandlerTests.java | 141 ++++++++++++++++++ 3 files changed, 163 insertions(+), 34 deletions(-) create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandlerTests.java diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java index 296d974dfcc..4f9fa2d4786 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -166,7 +166,6 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycl } this.handshakeHandler.doHandshake(request, response, this.wsHandler, attributes); chain.applyAfterHandshake(request, response, null); - response.close(); } catch (HandshakeFailureException ex) { failure = ex; @@ -177,8 +176,10 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycl finally { if (failure != null) { chain.applyAfterHandshake(request, response, failure); + response.close(); throw failure; } + response.close(); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java index f0346905222..6178fe804d2 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.springframework.http.HttpHeaders; import org.springframework.web.socket.AbstractHttpRequestTests; import org.springframework.web.socket.SubProtocolCapable; import org.springframework.web.socket.WebSocketExtension; @@ -62,14 +63,9 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { public void supportedSubProtocols() { this.handshakeHandler.setSupportedProtocols("stomp", "mqtt"); given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"}); - this.servletRequest.setMethod("GET"); - WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); - headers.setUpgrade("WebSocket"); - headers.setConnection("Upgrade"); - headers.setSecWebSocketVersion("13"); - headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); - headers.setSecWebSocketProtocol("STOMP"); + this.servletRequest.setMethod("GET"); + initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("STOMP"); WebSocketHandler handler = new TextWebSocketHandler(); Map attributes = Collections.emptyMap(); @@ -88,16 +84,10 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { given(this.upgradeStrategy.getSupportedExtensions(this.request)).willReturn(Collections.singletonList(extension1)); this.servletRequest.setMethod("GET"); - - WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); - headers.setUpgrade("WebSocket"); - headers.setConnection("Upgrade"); - headers.setSecWebSocketVersion("13"); - headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); - headers.setSecWebSocketExtensions(Arrays.asList(extension1, extension2)); + initHeaders(this.request.getHeaders()).setSecWebSocketExtensions(Arrays.asList(extension1, extension2)); WebSocketHandler handler = new TextWebSocketHandler(); - Map attributes = Collections.emptyMap(); + Map attributes = Collections.emptyMap(); this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); verify(this.upgradeStrategy).upgrade(this.request, this.response, null, @@ -109,16 +99,10 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"}); this.servletRequest.setMethod("GET"); - - WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); - headers.setUpgrade("WebSocket"); - headers.setConnection("Upgrade"); - headers.setSecWebSocketVersion("13"); - headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); - headers.setSecWebSocketProtocol("v11.stomp"); + initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v11.stomp"); WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp"); - Map attributes = Collections.emptyMap(); + Map attributes = Collections.emptyMap(); this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); verify(this.upgradeStrategy).upgrade(this.request, this.response, "v11.stomp", @@ -130,22 +114,25 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"}); this.servletRequest.setMethod("GET"); - - WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); - headers.setUpgrade("WebSocket"); - headers.setConnection("Upgrade"); - headers.setSecWebSocketVersion("13"); - headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); - headers.setSecWebSocketProtocol("v10.stomp"); + initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v10.stomp"); WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp"); - Map attributes = Collections.emptyMap(); + Map attributes = Collections.emptyMap(); this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); verify(this.upgradeStrategy).upgrade(this.request, this.response, null, Collections.emptyList(), null, handler, attributes); } + private WebSocketHttpHeaders initHeaders(HttpHeaders httpHeaders) { + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(httpHeaders); + headers.setUpgrade("WebSocket"); + headers.setConnection("Upgrade"); + headers.setSecWebSocketVersion("13"); + headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); + return headers; + } + private static class SubProtocolCapableHandler extends TextWebSocketHandler implements SubProtocolCapable { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandlerTests.java new file mode 100644 index 00000000000..fb7671eb4d5 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandlerTests.java @@ -0,0 +1,141 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.socket.server.support; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import javax.servlet.ServletException; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.HandshakeFailureException; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link WebSocketHttpRequestHandler}. + * @author Rossen Stoyanchev + * @since 5.1.9 + */ +public class WebSocketHttpRequestHandlerTests { + + private HandshakeHandler handshakeHandler; + + private WebSocketHttpRequestHandler requestHandler; + + private MockHttpServletResponse response; + + + @Before + public void setUp() { + this.handshakeHandler = mock(HandshakeHandler.class); + this.requestHandler = new WebSocketHttpRequestHandler(mock(WebSocketHandler.class), this.handshakeHandler); + this.response = new MockHttpServletResponse(); + } + + + @Test + public void success() throws ServletException, IOException { + TestInterceptor interceptor = new TestInterceptor(true); + this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor)); + this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response); + + verify(this.handshakeHandler).doHandshake(any(), any(), any(), any()); + assertEquals("headerValue", this.response.getHeader("headerName")); + } + + @Test + public void failure() throws ServletException, IOException { + TestInterceptor interceptor = new TestInterceptor(true); + this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor)); + + when(this.handshakeHandler.doHandshake(any(), any(), any(), any())) + .thenThrow(new IllegalStateException("bad state")); + + try { + this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response); + fail(); + } + catch (HandshakeFailureException ex) { + assertSame(ex, interceptor.getException()); + assertEquals("headerValue", this.response.getHeader("headerName")); + assertEquals("exceptionHeaderValue", this.response.getHeader("exceptionHeaderName")); + } + } + + @Test // gh-23179 + public void handshakeNotAllowed() throws ServletException, IOException { + TestInterceptor interceptor = new TestInterceptor(false); + this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor)); + + this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response); + + verifyNoMoreInteractions(this.handshakeHandler); + assertEquals("headerValue", this.response.getHeader("headerName")); + } + + + private static class TestInterceptor implements HandshakeInterceptor { + + private final boolean allowHandshake; + + private Exception exception; + + + private TestInterceptor(boolean allowHandshake) { + this.allowHandshake = allowHandshake; + } + + + public Exception getException() { + return this.exception; + } + + + @Override + public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Map attributes) { + + response.getHeaders().add("headerName", "headerValue"); + return this.allowHandshake; + } + + @Override + public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Exception exception) { + + response.getHeaders().add("exceptionHeaderName", "exceptionHeaderValue"); + this.exception = exception; + } + } + +}