diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java index dec1b4846b4..d9ca92960eb 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java @@ -27,8 +27,6 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantLock; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -106,9 +104,11 @@ public abstract class AbstractSockJsSession implements SockJsSession { private volatile long timeLastActive = this.timeCreated; - private volatile ScheduledFuture heartbeatTask; + private ScheduledFuture heartbeatFuture; - private final Lock heartbeatLock = new ReentrantLock(); + private HeartbeatTask heartbeatTask; + + private final Object heartbeatLock = new Object(); private volatile boolean heartbeatDisabled; @@ -249,19 +249,10 @@ public abstract class AbstractSockJsSession implements SockJsSession { } public void sendHeartbeat() throws SockJsTransportFailureException { - if (isActive()) { - if (heartbeatLock.tryLock()) { - try { - if (this.heartbeatTask == null) { - // Cancelled while waiting to acquire the lock - return; - } - writeFrame(SockJsFrame.heartbeatFrame()); - scheduleHeartbeat(); - } - finally { - heartbeatLock.unlock(); - } + synchronized (this.heartbeatLock) { + if (isActive() && !this.heartbeatDisabled) { + writeFrame(SockJsFrame.heartbeatFrame()); + scheduleHeartbeat(); } } } @@ -270,56 +261,33 @@ public abstract class AbstractSockJsSession implements SockJsSession { if (this.heartbeatDisabled) { return; } - - Assert.state(this.config.getTaskScheduler() != null, "Expected SockJS TaskScheduler"); - cancelHeartbeat(); - if (!isActive()) { - return; - } - - Date time = new Date(System.currentTimeMillis() + this.config.getHeartbeatTime()); - this.heartbeatTask = this.config.getTaskScheduler().schedule(new Runnable() { - public void run() { - try { - sendHeartbeat(); - } - catch (Throwable ex) { - // ignore - } - } - }, time); - if (logger.isTraceEnabled()) { - logger.trace("Scheduled heartbeat in session " + getId()); - } - } - - protected void cancelHeartbeat() { - try { - ScheduledFuture task = this.heartbeatTask; - this.heartbeatTask = null; - if (task == null || task.isCancelled()) { + synchronized (this.heartbeatLock) { + cancelHeartbeat(); + if (!isActive()) { return; } - + Date time = new Date(System.currentTimeMillis() + this.config.getHeartbeatTime()); + this.heartbeatTask = new HeartbeatTask(); + this.heartbeatFuture = this.config.getTaskScheduler().schedule(this.heartbeatTask, time); if (logger.isTraceEnabled()) { - logger.trace("Cancelling heartbeat in session " + getId()); - } - if (task.cancel(false)) { - return; + logger.trace("Scheduled heartbeat in session " + getId()); } + } + } - if (logger.isTraceEnabled()) { - logger.trace("Failed to cancel heartbeat, acquiring heartbeat write lock."); + protected void cancelHeartbeat() { + synchronized (this.heartbeatLock) { + if (this.heartbeatFuture != null) { + if (logger.isTraceEnabled()) { + logger.trace("Cancelling heartbeat in session " + getId()); + } + this.heartbeatFuture.cancel(false); + this.heartbeatFuture = null; } - this.heartbeatLock.lock(); - - if (logger.isTraceEnabled()) { - logger.trace("Releasing heartbeat lock."); + if (this.heartbeatTask != null) { + this.heartbeatTask.cancel(); + this.heartbeatTask = null; } - this.heartbeatLock.unlock(); - } - catch (Throwable ex) { - logger.debug("Failure while cancelling heartbeat in session " + getId(), ex); } } @@ -469,4 +437,28 @@ public abstract class AbstractSockJsSession implements SockJsSession { return getClass().getSimpleName() + "[id=" + getId() + "]"; } + + private class HeartbeatTask implements Runnable { + + private boolean expired; + + @Override + public void run() { + synchronized (heartbeatLock) { + if (!this.expired) { + try { + sendHeartbeat(); + } + finally { + this.expired = true; + } + } + } + } + + void cancel() { + this.expired = true; + } + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java index d21599adc2e..81356230272 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java @@ -288,6 +288,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests