diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java index d067fc9ec1..f1f7b830d5 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java @@ -16,8 +16,6 @@ package org.springframework.messaging.simp.broker; -import static org.springframework.messaging.support.MessageHeaderAccessor.getAccessor; - import java.util.Collection; import java.util.HashSet; import java.util.LinkedHashMap; @@ -46,6 +44,8 @@ import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.PathMatcher; +import static org.springframework.messaging.support.MessageHeaderAccessor.getAccessor; + /** * Implementation of {@link SubscriptionRegistry} that stores subscriptions @@ -209,7 +209,13 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { for (String sessionId : allMatches.keySet()) { for (String subId : allMatches.get(sessionId)) { SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId); + if (info == null) { + continue; + } Subscription sub = info.getSubscription(subId); + if (sub == null) { + continue; + } Expression expression = sub.getSelectorExpression(); if (expression == null) { result.add(sessionId, subId); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java index 341d3074ff..c44fe3d85b 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java @@ -20,6 +20,9 @@ import java.util.Arrays; import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Before; import org.junit.Test; @@ -28,10 +31,13 @@ import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.AntPathMatcher; import org.springframework.util.MultiValueMap; +import org.springframework.util.PathMatcher; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** @@ -417,6 +423,38 @@ public class DefaultSubscriptionRegistryTests { // no ConcurrentModificationException } + @Test + public void findSubscriptionsWithConcurrentUnregisterAllSubscriptions() throws Exception { + + final CountDownLatch iterationPausedLatch = new CountDownLatch(1); + final CountDownLatch iterationResumeLatch = new CountDownLatch(1); + final CountDownLatch iterationDoneLatch = new CountDownLatch(1); + + PathMatcher pathMatcher = new PausingPathMatcher(iterationPausedLatch, iterationResumeLatch); + this.registry.setPathMatcher(pathMatcher); + this.registry.registerSubscription(subscribeMessage("sess1", "1", "/foo")); + this.registry.registerSubscription(subscribeMessage("sess2", "1", "/foo")); + + AtomicReference> subscriptions = new AtomicReference<>(); + new Thread(() -> { + subscriptions.set(registry.findSubscriptions(createMessage("/foo"))); + iterationDoneLatch.countDown(); + }).start(); + + assertTrue(iterationPausedLatch.await(10, TimeUnit.SECONDS)); + + this.registry.unregisterAllSubscriptions("sess1"); + this.registry.unregisterAllSubscriptions("sess2"); + + iterationResumeLatch.countDown(); + assertTrue(iterationDoneLatch.await(10, TimeUnit.SECONDS)); + + MultiValueMap result = subscriptions.get(); + assertNotNull(result); + assertEquals(0, result.size()); + } + + private Message createMessage(String destination) { SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(); accessor.setDestination(destination); @@ -452,4 +490,34 @@ public class DefaultSubscriptionRegistryTests { return list; } + /** + * An extension of AntPathMatcher with a pair of CountDownLatch's to pause + * while matching, allowing another thread to something, and resume when the + * other thread signals it's okay to do so. + */ + private static class PausingPathMatcher extends AntPathMatcher { + + private final CountDownLatch iterationPausedLatch; + + private final CountDownLatch iterationResumeLatch; + + + public PausingPathMatcher(CountDownLatch iterationPausedLatch, CountDownLatch iterationResumeLatch) { + this.iterationPausedLatch = iterationPausedLatch; + this.iterationResumeLatch = iterationResumeLatch; + } + + @Override + public boolean match(String pattern, String path) { + try { + this.iterationPausedLatch.countDown(); + assertTrue(this.iterationResumeLatch.await(10, TimeUnit.SECONDS)); + return super.match(pattern, path); + } + catch (InterruptedException ex) { + ex.printStackTrace(); + return false; + } + } + } }