From 3a2c15b0fd5d370f80964b8f4d62b353823e06cf Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 11 Apr 2013 17:15:18 -0400 Subject: [PATCH] Add flush method to ServerHttpResponse This is useful to make sure response headers are written to the underlying response. It is also useful in conjunction with long running, async requests and HTTP streaming, to ensure the Servlet response buffer is sent to the client without additional delay and also causes an IOException to be raised if the client has gone away. --- .../server/AsyncServletServerHttpRequest.java | 22 +++++- .../http/server/ServerHttpResponse.java | 6 ++ .../server/ServletServerHttpResponse.java | 7 ++ .../springframework/sockjs/SockJsSession.java | 4 +- .../sockjs/server/AbstractServerSession.java | 39 ++++++--- .../sockjs/server/AbstractSockJsService.java | 79 ++++++++++--------- .../server/NestedSockJsRuntimeException.java | 39 +++++++++ .../sockjs/server/SockJsFrame.java | 7 +- .../server/support/DefaultSockJsService.java | 7 +- .../AbstractHttpSendingTransportHandler.java | 4 +- .../transport/AbstractHttpServerSession.java | 33 ++++---- .../EventSourceTransportHandler.java | 2 +- .../transport/HtmlFileTransportHandler.java | 2 +- .../transport/PollingHttpServerSession.java | 6 +- .../transport/SockJsWebSocketHandler.java | 4 +- .../transport/StreamingHttpServerSession.java | 10 +-- .../WebSocketSockJsHandlerAdapter.java | 4 +- .../XhrStreamingTransportHandler.java | 2 +- .../websocket/WebSocketSession.java | 4 +- .../endpoint/StandardWebSocketSession.java | 4 +- .../server/endpoint/EndpointRegistration.java | 9 ++- .../handshake/EndpointHandshakeHandler.java | 70 +++++++++------- .../GlassfishRequestUpgradeStrategy.java | 8 +- .../RequestUpgradeStrategy.java} | 4 +- .../TomcatRequestUpgradeStrategy.java | 18 +++-- 25 files changed, 257 insertions(+), 137 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/sockjs/server/NestedSockJsRuntimeException.java rename spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/{EndpointRequestUpgradeStrategy.java => handshake/RequestUpgradeStrategy.java} (91%) diff --git a/spring-web/src/main/java/org/springframework/http/server/AsyncServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/AsyncServletServerHttpRequest.java index 37d424e017d..1e53fc7abf8 100644 --- a/spring-web/src/main/java/org/springframework/http/server/AsyncServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/AsyncServletServerHttpRequest.java @@ -116,21 +116,35 @@ public class AsyncServletServerHttpRequest extends ServletServerHttpRequest // Implementation of AsyncListener methods // --------------------------------------------------------------------- + @Override public void onStartAsync(AsyncEvent event) throws IOException { } + @Override public void onError(AsyncEvent event) throws IOException { } + @Override public void onTimeout(AsyncEvent event) throws IOException { - for (Runnable handler : this.timeoutHandlers) { - handler.run(); + try { + for (Runnable handler : this.timeoutHandlers) { + handler.run(); + } + } + catch (Throwable t) { + // ignore } } + @Override public void onComplete(AsyncEvent event) throws IOException { - for (Runnable handler : this.completionHandlers) { - handler.run(); + try { + for (Runnable handler : this.completionHandlers) { + handler.run(); + } + } + catch (Throwable t) { + // ignore } this.asyncContext = null; this.asyncCompleted.set(true); diff --git a/spring-web/src/main/java/org/springframework/http/server/ServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServerHttpResponse.java index 8ce306271ff..02541fab927 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServerHttpResponse.java @@ -17,6 +17,7 @@ package org.springframework.http.server; import java.io.Closeable; +import java.io.IOException; import org.springframework.http.HttpOutputMessage; import org.springframework.http.HttpStatus; @@ -35,6 +36,11 @@ public interface ServerHttpResponse extends HttpOutputMessage, Closeable { */ void setStatusCode(HttpStatus status); + /** + * TODO + */ + void flush() throws IOException; + /** * Close this response, freeing any resources created. */ diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java index b0901324575..3a2ba0445d8 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java @@ -80,6 +80,13 @@ public class ServletServerHttpResponse implements ServerHttpResponse { return this.servletResponse.getOutputStream(); } + @Override + public void flush() throws IOException { + writeCookies(); + writeHeaders(); + this.servletResponse.flushBuffer(); + } + public void close() { writeCookies(); writeHeaders(); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSession.java index 77a61128e3c..0dab8e17f52 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/SockJsSession.java @@ -16,6 +16,8 @@ package org.springframework.sockjs; +import java.io.IOException; + /** @@ -25,7 +27,7 @@ package org.springframework.sockjs; */ public interface SockJsSession { - void sendMessage(String text) throws Exception; + void sendMessage(String text) throws IOException; void close(); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSession.java index b4624914bb9..feb7599b807 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractServerSession.java @@ -17,6 +17,8 @@ package org.springframework.sockjs.server; import java.io.EOFException; +import java.io.IOException; +import java.net.SocketException; import java.util.Date; import java.util.concurrent.ScheduledFuture; @@ -54,12 +56,12 @@ public abstract class AbstractServerSession extends SockJsSessionSupport { return this.sockJsConfig; } - public final synchronized void sendMessage(String message) { + public final synchronized void sendMessage(String message) throws IOException { Assert.isTrue(!isClosed(), "Cannot send a message, session has been closed"); sendMessageInternal(message); } - protected abstract void sendMessageInternal(String message); + protected abstract void sendMessageInternal(String message) throws IOException; public final synchronized void close() { if (!isClosed()) { @@ -67,7 +69,12 @@ public abstract class AbstractServerSession extends SockJsSessionSupport { if (isActive()) { // deliver messages "in flight" before sending close frame - writeFrame(SockJsFrame.closeFrameGoAway()); + try { + writeFrame(SockJsFrame.closeFrameGoAway()); + } + catch (Exception e) { + // ignore + } } super.close(); @@ -83,26 +90,33 @@ public abstract class AbstractServerSession extends SockJsSessionSupport { * For internal use within a TransportHandler and the (TransportHandler-specific) * session sub-class. The frame is written only if the connection is active. */ - protected void writeFrame(SockJsFrame frame) { + protected void writeFrame(SockJsFrame frame) throws IOException { if (logger.isTraceEnabled()) { logger.trace("Preparing to write " + frame); } try { writeFrameInternal(frame); } - catch (EOFException ex) { - logger.warn("Client went away. Terminating connection abruptly"); + catch (IOException ex) { + if (ex instanceof EOFException || ex instanceof SocketException) { + logger.warn("Client went away. Terminating connection"); + } + else { + logger.warn("Failed to send message. Terminating connection: " + ex.getMessage()); + } deactivate(); close(); + throw ex; } catch (Throwable t) { - logger.warn("Failed to send message. Terminating connection abruptly: " + t.getMessage()); + logger.warn("Failed to send message. Terminating connection: " + t.getMessage()); deactivate(); close(); + throw new NestedSockJsRuntimeException("Failed to write frame " + frame, t); } } - protected abstract void writeFrameInternal(SockJsFrame frame) throws Exception; + protected abstract void writeFrameInternal(SockJsFrame frame) throws IOException; /** * Some {@link TransportHandler} types cannot detect if a client connection is closed @@ -111,7 +125,7 @@ public abstract class AbstractServerSession extends SockJsSessionSupport { */ protected abstract void deactivate(); - public synchronized void sendHeartbeat() { + public synchronized void sendHeartbeat() throws IOException { if (isActive()) { writeFrame(SockJsFrame.heartbeatFrame()); scheduleHeartbeat(); @@ -127,7 +141,12 @@ public abstract class AbstractServerSession extends SockJsSessionSupport { Date time = new Date(System.currentTimeMillis() + getSockJsConfig().getHeartbeatTime()); this.heartbeatTask = getSockJsConfig().getHeartbeatScheduler().schedule(new Runnable() { public void run() { - sendHeartbeat(); + try { + sendHeartbeat(); + } + catch (IOException e) { + // ignore + } } }, time); if (logger.isTraceEnabled()) { diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java index 412704b3355..bff4ccace3a 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/AbstractSockJsService.java @@ -214,50 +214,57 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { request.getHeaders(); } catch (IllegalArgumentException ex) { - // Ignore invalid Content-Type (TODO!!) + // Ignore invalid Content-Type (TODO) } - if (sockJsPath.equals("") || sockJsPath.equals("/")) { - response.getHeaders().setContentType(new MediaType("text", "plain", Charset.forName("UTF-8"))); - response.getBody().write("Welcome to SockJS!\n".getBytes("UTF-8")); - return; - } - else if (sockJsPath.equals("/info")) { - this.infoHandler.handle(request, response); - return; - } - else if (sockJsPath.matches("/iframe[0-9-.a-z_]*.html")) { - this.iframeHandler.handle(request, response); - return; - } - else if (sockJsPath.equals("/websocket")) { - handleRawWebSocket(request, response); - return; - } - - String[] pathSegments = StringUtils.tokenizeToStringArray(sockJsPath.substring(1), "/"); - if (pathSegments.length != 3) { - logger.debug("Expected /{server}/{session}/{transport} but got " + sockJsPath); - response.setStatusCode(HttpStatus.NOT_FOUND); - return; - } + try { + if (sockJsPath.equals("") || sockJsPath.equals("/")) { + response.getHeaders().setContentType(new MediaType("text", "plain", Charset.forName("UTF-8"))); + response.getBody().write("Welcome to SockJS!\n".getBytes("UTF-8")); + return; + } + else if (sockJsPath.equals("/info")) { + this.infoHandler.handle(request, response); + return; + } + else if (sockJsPath.matches("/iframe[0-9-.a-z_]*.html")) { + this.iframeHandler.handle(request, response); + return; + } + else if (sockJsPath.equals("/websocket")) { + handleRawWebSocket(request, response); + return; + } - String serverId = pathSegments[0]; - String sessionId = pathSegments[1]; - String transport = pathSegments[2]; + String[] pathSegments = StringUtils.tokenizeToStringArray(sockJsPath.substring(1), "/"); + if (pathSegments.length != 3) { + logger.debug("Expected /{server}/{session}/{transport} but got " + sockJsPath); + response.setStatusCode(HttpStatus.NOT_FOUND); + return; + } - if (!validateRequest(serverId, sessionId, transport)) { - response.setStatusCode(HttpStatus.NOT_FOUND); - return; - } + String serverId = pathSegments[0]; + String sessionId = pathSegments[1]; + String transport = pathSegments[2]; - handleRequestInternal(request, response, sessionId, TransportType.fromValue(transport)); + if (!validateRequest(serverId, sessionId, transport)) { + response.setStatusCode(HttpStatus.NOT_FOUND); + return; + } + handleTransportRequest(request, response, sessionId, TransportType.fromValue(transport)); + } + finally { + response.flush(); + } } protected abstract void handleRawWebSocket(ServerHttpRequest request, ServerHttpResponse response) throws Exception; + protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, + String sessionId, TransportType transportType) throws Exception; + protected boolean validateRequest(String serverId, String sessionId, String transport) { if (!StringUtils.hasText(serverId) || !StringUtils.hasText(sessionId) || !StringUtils.hasText(transport)) { @@ -279,9 +286,6 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { return true; } - protected abstract void handleRequestInternal(ServerHttpRequest request, ServerHttpResponse response, - String sessionId, TransportType transportType) throws Exception; - protected void addCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) { String origin = request.getHeaders().getFirst("origin"); @@ -316,7 +320,6 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { logger.debug("Sending Method Not Allowed (405)"); response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED); response.getHeaders().setAllow(new HashSet(httpMethods)); - response.getBody(); // ensure headers are flushed (TODO!) } @@ -350,8 +353,6 @@ public abstract class AbstractSockJsService implements SockJsConfiguration { addCorsHeaders(request, response, HttpMethod.GET, HttpMethod.OPTIONS); addCacheHeaders(response); - - response.getBody(); // ensure headers are flushed (TODO!) } else { sendMethodNotAllowed(response, Arrays.asList(HttpMethod.OPTIONS, HttpMethod.GET)); diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/NestedSockJsRuntimeException.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/NestedSockJsRuntimeException.java new file mode 100644 index 00000000000..9f89c4a1207 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/NestedSockJsRuntimeException.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.sockjs.server; + +import org.springframework.core.NestedRuntimeException; + + +/** + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +@SuppressWarnings("serial") +public class NestedSockJsRuntimeException extends NestedRuntimeException { + + + public NestedSockJsRuntimeException(String msg) { + super(msg); + } + + public NestedSockJsRuntimeException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsFrame.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsFrame.java index 131354fbde8..ded997e9455 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsFrame.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/SockJsFrame.java @@ -80,8 +80,11 @@ public class SockJsFrame { } public String toString() { - String quoted = this.content.replace("\n", "\\n").replace("\r", "\\r"); - return "SockJsFrame content='" + quoted + "'"; + String result = this.content; + if (result.length() > 80) { + result = result.substring(0, 80) + "...(truncated)"; + } + return "SockJsFrame content='" + result.replace("\n", "\\n").replace("\r", "\\r") + "'"; } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java index 69f325da2c3..8a6ca6a5646 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/support/DefaultSockJsService.java @@ -175,7 +175,7 @@ public class DefaultSockJsService extends AbstractSockJsService } @Override - protected void handleRequestInternal(ServerHttpRequest request, ServerHttpResponse response, + protected void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, String sessionId, TransportType transportType) throws Exception { TransportHandler transportHandler = this.transportHandlers.get(transportType); @@ -192,7 +192,6 @@ public class DefaultSockJsService extends AbstractSockJsService response.setStatusCode(HttpStatus.NO_CONTENT); addCorsHeaders(request, response, supportedMethod, HttpMethod.OPTIONS); addCacheHeaders(response); - response.getBody(); // ensure headers are flushed (TODO!) } else { List supportedMethods = Arrays.asList(supportedMethod); @@ -214,7 +213,7 @@ public class DefaultSockJsService extends AbstractSockJsService if (isJsessionIdCookieNeeded()) { Cookie cookie = request.getCookies().getCookie("JSESSIONID"); String jsid = (cookie != null) ? cookie.getValue() : "dummy"; - // TODO: Jetty sets Expires header, so bypass Cookie object for now + // TODO: bypass use of Cookie object (causes Jetty to set Expires header) response.getHeaders().set("Set-Cookie", "JSESSIONID=" + jsid + ";path=/"); // TODO } @@ -223,8 +222,6 @@ public class DefaultSockJsService extends AbstractSockJsService } transportHandler.handleRequest(request, response, session); - - response.close(); // ensure headers are flushed (TODO !!) } public SockJsSessionSupport getSockJsSession(String sessionId, TransportHandler transportHandler) { diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java index dffa4389c02..8b43a6aff6a 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpSendingTransportHandler.java @@ -78,7 +78,7 @@ public abstract class AbstractHttpSendingTransportHandler implements TransportHa } else if (httpServerSession.isActive()) { logger.debug("another " + getTransportType() + " connection still open: " + httpServerSession); - httpServerSession.writeFrame(response.getBody(), SockJsFrame.closeFrameAnotherConnectionOpen()); + httpServerSession.writeFrame(response, SockJsFrame.closeFrameAnotherConnectionOpen()); } else { logger.debug("starting " + getTransportType() + " async request"); @@ -91,7 +91,7 @@ public abstract class AbstractHttpSendingTransportHandler implements TransportHa logger.debug("Opening " + getTransportType() + " connection"); session.setFrameFormat(getFrameFormat(request)); - session.writeFrame(response.getBody(), SockJsFrame.openFrame()); + session.writeFrame(response, SockJsFrame.openFrame()); session.connectionInitialized(); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSession.java index 0c0d60ae66f..5f657d94709 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/AbstractHttpServerSession.java @@ -16,7 +16,6 @@ package org.springframework.sockjs.server.transport; import java.io.IOException; -import java.io.OutputStream; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -44,7 +43,7 @@ public abstract class AbstractHttpServerSession extends AbstractServerSession { private AsyncServerHttpRequest asyncRequest; - private OutputStream outputStream; + private ServerHttpResponse response; public AbstractHttpServerSession(String sessionId, SockJsConfiguration sockJsConfig) { @@ -60,7 +59,7 @@ public abstract class AbstractHttpServerSession extends AbstractServerSession { if (isClosed()) { logger.debug("connection already closed"); - writeFrame(response.getBody(), SockJsFrame.closeFrameGoAway()); + writeFrame(response, SockJsFrame.closeFrameGoAway()); return; } @@ -70,11 +69,11 @@ public abstract class AbstractHttpServerSession extends AbstractServerSession { this.asyncRequest.setTimeout(-1); this.asyncRequest.startAsync(); - this.outputStream = response.getBody(); + this.response = response; this.frameFormat = frameFormat; scheduleHeartbeat(); - tryFlush(); + tryFlushCache(); } public synchronized boolean isActive() { @@ -85,24 +84,28 @@ public abstract class AbstractHttpServerSession extends AbstractServerSession { return this.messageCache; } - protected final synchronized void sendMessageInternal(String message) { + protected ServerHttpResponse getResponse() { + return this.response; + } + + protected final synchronized void sendMessageInternal(String message) throws IOException { // assert close() was not called // threads: TH-Session-Endpoint or any other thread this.messageCache.add(message); - tryFlush(); + tryFlushCache(); } - private void tryFlush() { + private void tryFlushCache() throws IOException { if (isActive() && !getMessageCache().isEmpty()) { logger.trace("Flushing messages"); - flush(); + flushCache(); } } /** * Only called if the connection is currently active */ - protected abstract void flush(); + protected abstract void flushCache() throws IOException; protected void closeInternal() { resetRequest(); @@ -110,7 +113,7 @@ public abstract class AbstractHttpServerSession extends AbstractServerSession { protected synchronized void writeFrameInternal(SockJsFrame frame) throws IOException { if (isActive()) { - writeFrame(this.outputStream, frame); + writeFrame(this.response, frame); } } @@ -119,18 +122,18 @@ public abstract class AbstractHttpServerSession extends AbstractServerSession { * even when the connection is not active, as long as a valid OutputStream * is provided. */ - public void writeFrame(OutputStream outputStream, SockJsFrame frame) throws IOException { + public void writeFrame(ServerHttpResponse response, SockJsFrame frame) throws IOException { frame = this.frameFormat.format(frame); if (logger.isTraceEnabled()) { logger.trace("Writing " + frame); } - outputStream.write(frame.getContentBytes()); + response.getBody().write(frame.getContentBytes()); } @Override protected void deactivate() { - this.outputStream = null; this.asyncRequest = null; + this.response = null; updateLastActiveTime(); } @@ -138,8 +141,8 @@ public abstract class AbstractHttpServerSession extends AbstractServerSession { if (isActive()) { this.asyncRequest.completeAsync(); } - this.outputStream = null; this.asyncRequest = null; + this.response = null; updateLastActiveTime(); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java index 0725e6284cb..235b6f5de13 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/EventSourceTransportHandler.java @@ -54,7 +54,7 @@ public class EventSourceTransportHandler extends AbstractStreamingTransportHandl protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) throws IOException { response.getBody().write('\r'); response.getBody().write('\n'); - response.getBody().flush(); + response.flush(); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java index 188c59d9ab4..35e48b238e4 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/HtmlFileTransportHandler.java @@ -102,7 +102,7 @@ public class HtmlFileTransportHandler extends AbstractStreamingTransportHandler String html = String.format(PARTIAL_HTML_CONTENT, callback); response.getBody().write(html.getBytes("UTF-8")); - response.getBody().flush(); + response.flush(); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingHttpServerSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingHttpServerSession.java index 3dd0e8d517d..34302f610f2 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingHttpServerSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/PollingHttpServerSession.java @@ -15,6 +15,8 @@ */ package org.springframework.sockjs.server.transport; +import java.io.IOException; + import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; @@ -26,7 +28,7 @@ public class PollingHttpServerSession extends AbstractHttpServerSession { } @Override - protected void flush() { + protected void flushCache() throws IOException { cancelHeartbeat(); String[] messages = getMessageCache().toArray(new String[getMessageCache().size()]); getMessageCache().clear(); @@ -34,7 +36,7 @@ public class PollingHttpServerSession extends AbstractHttpServerSession { } @Override - protected void writeFrame(SockJsFrame frame) { + protected void writeFrame(SockJsFrame frame) throws IOException { super.writeFrame(frame); resetRequest(); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java index db5da4c3628..c19818e03df 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java @@ -92,14 +92,14 @@ public class SockJsWebSocketHandler extends AbstractSockJsWebSocketHandler { } @Override - public void sendMessageInternal(String message) { + public void sendMessageInternal(String message) throws IOException { cancelHeartbeat(); writeFrame(SockJsFrame.messageFrame(message)); scheduleHeartbeat(); } @Override - protected void writeFrameInternal(SockJsFrame frame) throws Exception { + protected void writeFrameInternal(SockJsFrame frame) throws IOException { if (logger.isTraceEnabled()) { logger.trace("Write " + frame); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingHttpServerSession.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingHttpServerSession.java index accff89ee6e..94cf47cd9de 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingHttpServerSession.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/StreamingHttpServerSession.java @@ -16,8 +16,8 @@ package org.springframework.sockjs.server.transport; import java.io.IOException; -import java.io.OutputStream; +import org.springframework.http.server.ServerHttpResponse; import org.springframework.sockjs.server.SockJsConfiguration; import org.springframework.sockjs.server.SockJsFrame; @@ -31,7 +31,7 @@ public class StreamingHttpServerSession extends AbstractHttpServerSession { super(sessionId, sockJsConfig); } - protected void flush() { + protected void flushCache() throws IOException { cancelHeartbeat(); @@ -64,9 +64,9 @@ public class StreamingHttpServerSession extends AbstractHttpServerSession { } @Override - public void writeFrame(OutputStream outputStream, SockJsFrame frame) throws IOException { - super.writeFrame(outputStream, frame); - outputStream.flush(); + public void writeFrame(ServerHttpResponse response, SockJsFrame frame) throws IOException { + super.writeFrame(response, frame); + response.flush(); } } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketSockJsHandlerAdapter.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketSockJsHandlerAdapter.java index e617ab794ff..82e94f67884 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketSockJsHandlerAdapter.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/WebSocketSockJsHandlerAdapter.java @@ -16,6 +16,8 @@ package org.springframework.sockjs.server.transport; +import java.io.IOException; + import org.springframework.sockjs.SockJsHandler; import org.springframework.sockjs.SockJsSessionSupport; import org.springframework.sockjs.server.SockJsConfiguration; @@ -60,7 +62,7 @@ public class WebSocketSockJsHandlerAdapter extends AbstractSockJsWebSocketHandle } @Override - public void sendMessage(String message) throws Exception { + public void sendMessage(String message) throws IOException { this.wsSession.sendText(message); } diff --git a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java index 6a697474979..0f49b75de34 100644 --- a/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/sockjs/server/transport/XhrStreamingTransportHandler.java @@ -56,7 +56,7 @@ public class XhrStreamingTransportHandler extends AbstractStreamingTransportHand response.getBody().write('h'); } response.getBody().write('\n'); - response.getBody().flush(); + response.flush(); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/websocket/WebSocketSession.java b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketSession.java index d930dde9d0f..2abe644a5a4 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/WebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/WebSocketSession.java @@ -16,6 +16,8 @@ package org.springframework.websocket; +import java.io.IOException; + /** @@ -27,7 +29,7 @@ public interface WebSocketSession { boolean isOpen(); - void sendText(String text) throws Exception; + void sendText(String text) throws IOException; void close(); diff --git a/spring-websocket/src/main/java/org/springframework/websocket/endpoint/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/websocket/endpoint/StandardWebSocketSession.java index 302a0333923..dc4126b168f 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/endpoint/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/endpoint/StandardWebSocketSession.java @@ -16,6 +16,8 @@ package org.springframework.websocket.endpoint; +import java.io.IOException; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.websocket.WebSocketSession; @@ -44,7 +46,7 @@ public class StandardWebSocketSession implements WebSocketSession { } @Override - public void sendText(String text) throws Exception { + public void sendText(String text) throws IOException { logger.trace("Sending text message: " + text); // TODO: check closed this.session.getBasicRemote().sendText(text); diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java index 9dfc509042e..3c33a31017d 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRegistration.java @@ -29,6 +29,8 @@ import javax.websocket.HandshakeResponse; import javax.websocket.server.HandshakeRequest; import javax.websocket.server.ServerEndpointConfig; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; @@ -54,6 +56,8 @@ import org.springframework.websocket.endpoint.WebSocketHandlerEndpoint; */ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAware { + private static Log logger = LogFactory.getLog(EndpointRegistration.class); + private final String path; private final Class endpointClass; @@ -130,8 +134,9 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw if (this.endpointClass != null) { WebApplicationContext wac = ContextLoader.getCurrentWebApplicationContext(); if (wac == null) { - throw new IllegalStateException("Failed to find WebApplicationContext. " - + "Was org.springframework.web.context.ContextLoader used to load the WebApplicationContext?"); + String message = "Failed to find the root WebApplicationContext. Was ContextLoaderListener not used?"; + logger.error(message); + throw new IllegalStateException(); } return wac.getAutowireCapableBeanFactory().createBean(this.endpointClass); } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/EndpointHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/EndpointHandshakeHandler.java index a63033361f6..089a9f8efdb 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/EndpointHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/EndpointHandshakeHandler.java @@ -26,12 +26,11 @@ import org.springframework.websocket.WebSocketHandler; import org.springframework.websocket.endpoint.WebSocketHandlerEndpoint; import org.springframework.websocket.server.AbstractHandshakeHandler; import org.springframework.websocket.server.HandshakeHandler; -import org.springframework.websocket.server.endpoint.EndpointRequestUpgradeStrategy; /** * A {@link HandshakeHandler} for use with standard Java WebSocket runtimes. A - * container-specific {@link EndpointRequestUpgradeStrategy} is required since standard + * container-specific {@link RequestUpgradeStrategy} is required since standard * Java WebSocket currently does not provide any means of integrating a WebSocket * handshake into an HTTP request processing pipeline. Currently available are * implementations for Tomcat and Glassfish. @@ -41,48 +40,26 @@ import org.springframework.websocket.server.endpoint.EndpointRequestUpgradeStrat */ public class EndpointHandshakeHandler extends AbstractHandshakeHandler { - private static final boolean tomcatWebSocketPresent = ClassUtils.isPresent( - "org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", EndpointHandshakeHandler.class.getClassLoader()); - - private static final boolean glassfishWebSocketPresent = ClassUtils.isPresent( - "org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", EndpointHandshakeHandler.class.getClassLoader()); - - private final EndpointRequestUpgradeStrategy upgradeStrategy; + private final RequestUpgradeStrategy upgradeStrategy; public EndpointHandshakeHandler(Endpoint endpoint) { super(endpoint); - this.upgradeStrategy = createUpgradeStrategy(); + this.upgradeStrategy = createRequestUpgradeStrategy(); } public EndpointHandshakeHandler(WebSocketHandler webSocketHandler) { super(webSocketHandler); - this.upgradeStrategy = createUpgradeStrategy(); + this.upgradeStrategy = createRequestUpgradeStrategy(); } public EndpointHandshakeHandler(Class handlerClass) { super(handlerClass); - this.upgradeStrategy = createUpgradeStrategy(); + this.upgradeStrategy = createRequestUpgradeStrategy(); } - private static EndpointRequestUpgradeStrategy createUpgradeStrategy() { - String className; - if (tomcatWebSocketPresent) { - className = "org.springframework.websocket.server.endpoint.support.TomcatRequestUpgradeStrategy"; - } - else if (glassfishWebSocketPresent) { - className = "org.springframework.websocket.server.endpoint.support.GlassfishRequestUpgradeStrategy"; - } - else { - throw new IllegalStateException("No suitable EndpointRequestUpgradeStrategy"); - } - try { - Class clazz = ClassUtils.forName(className, EndpointHandshakeHandler.class.getClassLoader()); - return (EndpointRequestUpgradeStrategy) BeanUtils.instantiateClass(clazz.getConstructor()); - } - catch (Throwable t) { - throw new IllegalStateException("Failed to instantiate " + className, t); - } + protected RequestUpgradeStrategy createRequestUpgradeStrategy() { + return new RequestUpgradeStrategyFactory().create(); } @Override @@ -113,4 +90,37 @@ public class EndpointHandshakeHandler extends AbstractHandshakeHandler { this.upgradeStrategy.upgrade(request, response, protocol, endpoint); } + + private static class RequestUpgradeStrategyFactory { + + private static final String packageName = EndpointHandshakeHandler.class.getPackage().getName(); + + private static final boolean tomcatWebSocketPresent = ClassUtils.isPresent( + "org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", EndpointHandshakeHandler.class.getClassLoader()); + + private static final boolean glassfishWebSocketPresent = ClassUtils.isPresent( + "org.glassfish.tyrus.servlet.TyrusHttpUpgradeHandler", EndpointHandshakeHandler.class.getClassLoader()); + + + private RequestUpgradeStrategy create() { + String className; + if (tomcatWebSocketPresent) { + className = packageName + ".TomcatRequestUpgradeStrategy"; + } + else if (glassfishWebSocketPresent) { + className = packageName + ".GlassfishRequestUpgradeStrategy"; + } + else { + throw new IllegalStateException("No suitable " + RequestUpgradeStrategy.class.getSimpleName()); + } + try { + Class clazz = ClassUtils.forName(className, EndpointHandshakeHandler.class.getClassLoader()); + return (RequestUpgradeStrategy) BeanUtils.instantiateClass(clazz.getConstructor()); + } + catch (Throwable t) { + throw new IllegalStateException("Failed to instantiate " + className, t); + } + } + } + } \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/GlassfishRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/GlassfishRequestUpgradeStrategy.java index 52750dc4d13..6a23883fb5c 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/GlassfishRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/GlassfishRequestUpgradeStrategy.java @@ -46,7 +46,6 @@ import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; import org.springframework.websocket.server.endpoint.EndpointRegistration; -import org.springframework.websocket.server.endpoint.EndpointRequestUpgradeStrategy; /** * Glassfish support for upgrading an {@link HttpServletRequest} during a WebSocket @@ -55,7 +54,7 @@ import org.springframework.websocket.server.endpoint.EndpointRequestUpgradeStrat * @author Rossen Stoyanchev * @since 4.0 */ -public class GlassfishRequestUpgradeStrategy implements EndpointRequestUpgradeStrategy { +public class GlassfishRequestUpgradeStrategy implements RequestUpgradeStrategy { private final static Random random = new Random(); @@ -77,7 +76,8 @@ public class GlassfishRequestUpgradeStrategy implements EndpointRequestUpgradeSt servletResponse = new AlreadyUpgradedResponseWrapper(servletResponse); TyrusEndpoint tyrusEndpoint = createTyrusEndpoint(servletRequest, endpoint); - WebSocketEngine.getEngine().register(tyrusEndpoint); + WebSocketEngine engine = WebSocketEngine.getEngine(); + engine.register(tyrusEndpoint); try { if (!performUpgrade(servletRequest, servletResponse, request.getHeaders(), tyrusEndpoint)) { @@ -85,7 +85,7 @@ public class GlassfishRequestUpgradeStrategy implements EndpointRequestUpgradeSt } } finally { - WebSocketEngine.getEngine().unregister(tyrusEndpoint); + engine.unregister(tyrusEndpoint); } } diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/RequestUpgradeStrategy.java similarity index 91% rename from spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRequestUpgradeStrategy.java rename to spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/RequestUpgradeStrategy.java index 20472434293..2dd09023b72 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/EndpointRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/RequestUpgradeStrategy.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.websocket.server.endpoint; +package org.springframework.websocket.server.endpoint.handshake; import javax.websocket.Endpoint; @@ -29,7 +29,7 @@ import org.springframework.http.server.ServerHttpResponse; * @author Rossen Stoyanchev * @since 4.0 */ -public interface EndpointRequestUpgradeStrategy { +public interface RequestUpgradeStrategy { String[] getSupportedVersions(); diff --git a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/TomcatRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/TomcatRequestUpgradeStrategy.java index a44468bdab8..b07cd2e0dff 100644 --- a/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/TomcatRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/websocket/server/endpoint/handshake/TomcatRequestUpgradeStrategy.java @@ -16,6 +16,7 @@ package org.springframework.websocket.server.endpoint.handshake; +import java.io.IOException; import java.lang.reflect.Method; import java.util.Collections; @@ -29,10 +30,10 @@ import org.apache.tomcat.websocket.server.WsServerContainer; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.sockjs.server.NestedSockJsRuntimeException; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; import org.springframework.websocket.server.endpoint.EndpointRegistration; -import org.springframework.websocket.server.endpoint.EndpointRequestUpgradeStrategy; /** @@ -41,7 +42,7 @@ import org.springframework.websocket.server.endpoint.EndpointRequestUpgradeStrat * @author Rossen Stoyanchev * @since 4.0 */ -public class TomcatRequestUpgradeStrategy implements EndpointRequestUpgradeStrategy { +public class TomcatRequestUpgradeStrategy implements RequestUpgradeStrategy { @Override @@ -51,7 +52,7 @@ public class TomcatRequestUpgradeStrategy implements EndpointRequestUpgradeStrat @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol, - Endpoint endpoint) throws Exception { + Endpoint endpoint) throws IOException { Assert.isTrue(request instanceof ServletServerHttpRequest); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -59,9 +60,14 @@ public class TomcatRequestUpgradeStrategy implements EndpointRequestUpgradeStrat WsHttpUpgradeHandler upgradeHandler = servletRequest.upgrade(WsHttpUpgradeHandler.class); WsHandshakeRequest webSocketRequest = new WsHandshakeRequest(servletRequest); - Method method = ReflectionUtils.findMethod(WsHandshakeRequest.class, "finished"); - ReflectionUtils.makeAccessible(method); - method.invoke(webSocketRequest); + try { + Method method = ReflectionUtils.findMethod(WsHandshakeRequest.class, "finished"); + ReflectionUtils.makeAccessible(method); + method.invoke(webSocketRequest); + } + catch (Exception ex) { + throw new NestedSockJsRuntimeException("Failed to upgrade HttpServletRequest", ex); + } // TODO: use ServletContext attribute when Tomcat is updated WsServerContainer serverContainer = WsServerContainer.getServerContainer();