diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java index f2fbadeef5f..b1db14696b2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java @@ -27,11 +27,10 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantLock; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.core.NestedCheckedException; import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; @@ -106,9 +105,11 @@ public abstract class AbstractSockJsSession implements SockJsSession { private volatile long timeLastActive = this.timeCreated; - private volatile ScheduledFuture heartbeatTask; + private ScheduledFuture heartbeatFuture; + + private HeartbeatTask heartbeatTask; - private final Lock heartbeatLock = new ReentrantLock(); + private final Object heartbeatLock = new Object(); private volatile boolean heartbeatDisabled; @@ -249,15 +250,10 @@ public abstract class AbstractSockJsSession implements SockJsSession { } public void sendHeartbeat() throws SockJsTransportFailureException { - if (isActive()) { - if (heartbeatLock.tryLock()) { - try { - writeFrame(SockJsFrame.heartbeatFrame()); - scheduleHeartbeat(); - } - finally { - heartbeatLock.unlock(); - } + synchronized (this.heartbeatLock) { + if (isActive() && !this.heartbeatDisabled) { + writeFrame(SockJsFrame.heartbeatFrame()); + scheduleHeartbeat(); } } } @@ -266,56 +262,33 @@ public abstract class AbstractSockJsSession implements SockJsSession { if (this.heartbeatDisabled) { return; } - - Assert.state(this.config.getTaskScheduler() != null, "Expected SockJS TaskScheduler"); - cancelHeartbeat(); - if (!isActive()) { - return; - } - - Date time = new Date(System.currentTimeMillis() + this.config.getHeartbeatTime()); - this.heartbeatTask = this.config.getTaskScheduler().schedule(new Runnable() { - public void run() { - try { - sendHeartbeat(); - } - catch (Throwable ex) { - // ignore - } - } - }, time); - if (logger.isTraceEnabled()) { - logger.trace("Scheduled heartbeat in session " + getId()); - } - } - - protected void cancelHeartbeat() { - try { - ScheduledFuture task = this.heartbeatTask; - this.heartbeatTask = null; - if (task == null || task.isCancelled()) { + synchronized (this.heartbeatLock) { + cancelHeartbeat(); + if (!isActive()) { return; } - + Date time = new Date(System.currentTimeMillis() + this.config.getHeartbeatTime()); + this.heartbeatTask = new HeartbeatTask(); + this.heartbeatFuture = this.config.getTaskScheduler().schedule(this.heartbeatTask, time); if (logger.isTraceEnabled()) { - logger.trace("Cancelling heartbeat in session " + getId()); - } - if (task.cancel(false)) { - return; + logger.trace("Scheduled heartbeat in session " + getId()); } + } + } - if (logger.isTraceEnabled()) { - logger.trace("Failed to cancel heartbeat, acquiring heartbeat write lock."); + protected void cancelHeartbeat() { + synchronized (this.heartbeatLock) { + if (this.heartbeatFuture != null) { + if (logger.isTraceEnabled()) { + logger.trace("Cancelling heartbeat in session " + getId()); + } + this.heartbeatFuture.cancel(false); + this.heartbeatFuture = null; } - this.heartbeatLock.lock(); - - if (logger.isTraceEnabled()) { - logger.trace("Releasing heartbeat lock."); + if (this.heartbeatTask != null) { + this.heartbeatTask.cancel(); + this.heartbeatTask = null; } - this.heartbeatLock.unlock(); - } - catch (Throwable ex) { - logger.debug("Failure while cancelling heartbeat in session " + getId(), ex); } } @@ -465,4 +438,28 @@ public abstract class AbstractSockJsSession implements SockJsSession { return getClass().getSimpleName() + "[id=" + getId() + "]"; } + + private class HeartbeatTask implements Runnable { + + private boolean expired; + + @Override + public void run() { + synchronized (heartbeatLock) { + if (!this.expired) { + try { + sendHeartbeat(); + } + finally { + this.expired = true; + } + } + } + } + + void cancel() { + this.expired = true; + } + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java index 8741ef315e4..e09ff565c82 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/SockJsSessionTests.java @@ -270,6 +270,7 @@ public class SockJsSessionTests extends AbstractSockJsSessionTests