@ -20,6 +20,9 @@ import java.util.Arrays;
@@ -20,6 +20,9 @@ import java.util.Arrays;
import java.util.Collections ;
import java.util.Iterator ;
import java.util.List ;
import java.util.concurrent.CountDownLatch ;
import java.util.concurrent.TimeUnit ;
import java.util.concurrent.atomic.AtomicReference ;
import org.junit.Before ;
import org.junit.Test ;
@ -28,10 +31,13 @@ import org.springframework.messaging.Message;
@@ -28,10 +31,13 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor ;
import org.springframework.messaging.simp.SimpMessageType ;
import org.springframework.messaging.support.MessageBuilder ;
import org.springframework.util.AntPathMatcher ;
import org.springframework.util.MultiValueMap ;
import org.springframework.util.PathMatcher ;
import static org.junit.Assert.assertEquals ;
import static org.junit.Assert.assertNotNull ;
import static org.junit.Assert.assertTrue ;
/ * *
@ -417,6 +423,38 @@ public class DefaultSubscriptionRegistryTests {
@@ -417,6 +423,38 @@ public class DefaultSubscriptionRegistryTests {
// no ConcurrentModificationException
}
@Test
public void findSubscriptionsWithConcurrentUnregisterAllSubscriptions ( ) throws Exception {
final CountDownLatch iterationPausedLatch = new CountDownLatch ( 1 ) ;
final CountDownLatch iterationResumeLatch = new CountDownLatch ( 1 ) ;
final CountDownLatch iterationDoneLatch = new CountDownLatch ( 1 ) ;
PathMatcher pathMatcher = new PausingPathMatcher ( iterationPausedLatch , iterationResumeLatch ) ;
this . registry . setPathMatcher ( pathMatcher ) ;
this . registry . registerSubscription ( subscribeMessage ( "sess1" , "1" , "/foo" ) ) ;
this . registry . registerSubscription ( subscribeMessage ( "sess2" , "1" , "/foo" ) ) ;
AtomicReference < MultiValueMap < String , String > > subscriptions = new AtomicReference < > ( ) ;
new Thread ( ( ) - > {
subscriptions . set ( registry . findSubscriptions ( createMessage ( "/foo" ) ) ) ;
iterationDoneLatch . countDown ( ) ;
} ) . start ( ) ;
assertTrue ( iterationPausedLatch . await ( 10 , TimeUnit . SECONDS ) ) ;
this . registry . unregisterAllSubscriptions ( "sess1" ) ;
this . registry . unregisterAllSubscriptions ( "sess2" ) ;
iterationResumeLatch . countDown ( ) ;
assertTrue ( iterationDoneLatch . await ( 10 , TimeUnit . SECONDS ) ) ;
MultiValueMap < String , String > result = subscriptions . get ( ) ;
assertNotNull ( result ) ;
assertEquals ( 0 , result . size ( ) ) ;
}
private Message < ? > createMessage ( String destination ) {
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor . create ( ) ;
accessor . setDestination ( destination ) ;
@ -452,4 +490,34 @@ public class DefaultSubscriptionRegistryTests {
@@ -452,4 +490,34 @@ public class DefaultSubscriptionRegistryTests {
return list ;
}
/ * *
* An extension of AntPathMatcher with a pair of CountDownLatch ' s to pause
* while matching , allowing another thread to something , and resume when the
* other thread signals it ' s okay to do so .
* /
private static class PausingPathMatcher extends AntPathMatcher {
private final CountDownLatch iterationPausedLatch ;
private final CountDownLatch iterationResumeLatch ;
public PausingPathMatcher ( CountDownLatch iterationPausedLatch , CountDownLatch iterationResumeLatch ) {
this . iterationPausedLatch = iterationPausedLatch ;
this . iterationResumeLatch = iterationResumeLatch ;
}
@Override
public boolean match ( String pattern , String path ) {
try {
this . iterationPausedLatch . countDown ( ) ;
assertTrue ( this . iterationResumeLatch . await ( 10 , TimeUnit . SECONDS ) ) ;
return super . match ( pattern , path ) ;
}
catch ( InterruptedException ex ) {
ex . printStackTrace ( ) ;
return false ;
}
}
}
}