diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java index 37986ac66a3..02094d75c5f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java @@ -64,11 +64,22 @@ public abstract class AbstractHttpSendingTransportHandler extends TransportHandl if (sockJsSession.isNew()) { logger.debug("Opening " + getTransportType() + " connection"); - sockJsSession.setInitialRequest(request, response, getFrameFormat(request)); + sockJsSession.handleInitialRequest(request, response, getFrameFormat(request)); + } + else if (sockJsSession.isClosed()) { + logger.debug("Connection already closed (but not removed yet)"); + SockJsFrame frame = SockJsFrame.closeFrameGoAway(); + try { + response.getBody().write(frame.getContentBytes()); + } + catch (IOException ex) { + throw new SockJsException("Failed to send " + frame, sockJsSession.getId(), ex); + } + return; } else if (!sockJsSession.isActive()) { logger.debug("starting " + getTransportType() + " async request"); - sockJsSession.setLongPollingRequest(request, response, getFrameFormat(request)); + sockJsSession.startLongPollingRequest(request, response, getFrameFormat(request)); } else { logger.debug("another " + getTransportType() + " connection still open: " + sockJsSession); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java index 0ac2af03984..bc441df097c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java @@ -134,14 +134,20 @@ public class HtmlFileTransportHandler extends AbstractHttpSendingTransportHandle } @Override - protected void writePrelude() throws IOException { - + protected void afterRequestUpdated() { // we already validated the parameter above.. String callback = getCallbackParam(getRequest()); String html = String.format(PARTIAL_HTML_CONTENT, callback); - getResponse().getBody().write(html.getBytes("UTF-8")); - getResponse().flush(); + + try { + getResponse().getBody().write(html.getBytes("UTF-8")); + getResponse().flush(); + } + catch (IOException e) { + tryCloseWithSockJsTransportError(e, CloseStatus.SERVER_ERROR); + throw new SockJsTransportFailureException("Failed to write HTML content", getId(), e); + } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index adf50a9215f..5b76febfafd 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -134,7 +134,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { return this.acceptedProtocol; } - public synchronized void setInitialRequest(ServerHttpRequest request, ServerHttpResponse response, + public synchronized void handleInitialRequest(ServerHttpRequest request, ServerHttpResponse response, FrameFormat frameFormat) throws SockJsException { udpateRequest(request, response, frameFormat); @@ -164,15 +164,10 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { protected void writePrelude() throws IOException { } - public synchronized void setLongPollingRequest(ServerHttpRequest request, ServerHttpResponse response, - FrameFormat frameFormat) throws SockJsException { + public synchronized void startLongPollingRequest(ServerHttpRequest request, + ServerHttpResponse response, FrameFormat frameFormat) throws SockJsException { udpateRequest(request, response, frameFormat); - if (isClosed()) { - logger.debug("Connection already closed (but not removed yet)"); - writeFrame(SockJsFrame.closeFrameGoAway()); - return; - } try { this.asyncRequestControl.start(-1); scheduleHeartbeat(); @@ -184,7 +179,8 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { } } - private void udpateRequest(ServerHttpRequest request, ServerHttpResponse response, FrameFormat frameFormat) { + private void udpateRequest(ServerHttpRequest request, ServerHttpResponse response, + FrameFormat frameFormat) { Assert.notNull(request, "expected request"); Assert.notNull(response, "expected response"); Assert.notNull(frameFormat, "expected frameFormat"); @@ -192,8 +188,11 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.response = response; this.asyncRequestControl = request.getAsyncRequestControl(response); this.frameFormat = frameFormat; + afterRequestUpdated(); } + protected void afterRequestUpdated() { + } @Override public synchronized boolean isActive() { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java index 017bc67f2ce..ee824256cca 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java @@ -46,14 +46,14 @@ public class StreamingSockJsSession extends AbstractHttpSockJsSession { @Override - public synchronized void setInitialRequest(ServerHttpRequest request, ServerHttpResponse response, + public synchronized void handleInitialRequest(ServerHttpRequest request, ServerHttpResponse response, FrameFormat frameFormat) throws SockJsException { - super.setInitialRequest(request, response, frameFormat); + super.handleInitialRequest(request, response, frameFormat); // the WebSocketHandler delegate may have closed the session if (!isClosed()) { - super.setLongPollingRequest(request, response, frameFormat); + super.startLongPollingRequest(request, response, frameFormat); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java index ed8c9b28a4d..d270d7e28bf 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSessionTests.java @@ -78,7 +78,7 @@ public class AbstractHttpSockJsSessionTests extends BaseAbstractSockJsSessionTes @Test public void setInitialRequest() throws Exception { - this.session.setInitialRequest(this.request, this.response, this.frameFormat); + this.session.handleInitialRequest(this.request, this.response, this.frameFormat); assertTrue(this.session.hasRequest()); assertTrue(this.session.hasResponse()); @@ -93,7 +93,7 @@ public class AbstractHttpSockJsSessionTests extends BaseAbstractSockJsSessionTes public void setLongPollingRequest() throws Exception { this.session.getMessageCache().add("x"); - this.session.setLongPollingRequest(this.request, this.response, this.frameFormat); + this.session.startLongPollingRequest(this.request, this.response, this.frameFormat); assertTrue(this.session.hasRequest()); assertTrue(this.session.hasResponse()); @@ -111,7 +111,7 @@ public class AbstractHttpSockJsSessionTests extends BaseAbstractSockJsSessionTes this.session.delegateConnectionClosed(CloseStatus.NORMAL); assertClosed(); - this.session.setLongPollingRequest(this.request, this.response, this.frameFormat); + this.session.startLongPollingRequest(this.request, this.response, this.frameFormat); assertEquals("c[3000,\"Go away!\"]", this.servletResponse.getContentAsString()); assertFalse(this.servletRequest.isAsyncStarted());