diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractSubscriptionRegistry.java index 4a5497c5704..3a9a04ece1a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractSubscriptionRegistry.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; /** @@ -42,66 +41,84 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist @Override public final void registerSubscription(Message message) { MessageHeaders headers = message.getHeaders(); + SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers); - Assert.isTrue(SimpMessageType.SUBSCRIBE.equals(messageType), "Expected SUBSCRIBE: " + message); + if (!SimpMessageType.SUBSCRIBE.equals(messageType)) { + throw new IllegalArgumentException("Expected SUBSCRIBE: " + message); + } + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); if (sessionId == null) { logger.error("No sessionId in " + message); return; } + String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers); if (subscriptionId == null) { logger.error("No subscriptionId in " + message); return; } + String destination = SimpMessageHeaderAccessor.getDestination(headers); if (destination == null) { logger.error("No destination in " + message); return; } + addSubscriptionInternal(sessionId, subscriptionId, destination, message); } - protected abstract void addSubscriptionInternal(String sessionId, String subscriptionId, - String destination, Message message); - @Override public final void unregisterSubscription(Message message) { MessageHeaders headers = message.getHeaders(); + SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers); - Assert.isTrue(SimpMessageType.UNSUBSCRIBE.equals(messageType), "Expected UNSUBSCRIBE: " + message); + if (!SimpMessageType.UNSUBSCRIBE.equals(messageType)) { + throw new IllegalArgumentException("Expected UNSUBSCRIBE: " + message); + } + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); if (sessionId == null) { logger.error("No sessionId in " + message); return; } + String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers); if (subscriptionId == null) { logger.error("No subscriptionId " + message); return; } + removeSubscriptionInternal(sessionId, subscriptionId, message); } + @Override + public final MultiValueMap findSubscriptions(Message message) { + MessageHeaders headers = message.getHeaders(); + + SimpMessageType type = SimpMessageHeaderAccessor.getMessageType(headers); + if (!SimpMessageType.MESSAGE.equals(type)) { + throw new IllegalArgumentException("Unexpected message type: " + type); + } + + String destination = SimpMessageHeaderAccessor.getDestination(headers); + if (destination == null) { + logger.error("No destination in " + message); + return null; + } + + return findSubscriptionsInternal(destination, message); + } + + + protected abstract void addSubscriptionInternal(String sessionId, String subscriptionId, + String destination, Message message); + protected abstract void removeSubscriptionInternal(String sessionId, String subscriptionId, Message message); @Override public abstract void unregisterAllSubscriptions(String sessionId); - @Override - public final MultiValueMap findSubscriptions(Message message) { - MessageHeaders headers = message.getHeaders(); - SimpMessageType type = SimpMessageHeaderAccessor.getMessageType(headers); - Assert.isTrue(SimpMessageType.MESSAGE.equals(type), "Unexpected message type: " + type); - String destination = SimpMessageHeaderAccessor.getDestination(headers); - if (destination == null) { - logger.error("No destination in " + message); - return null; - } - return findSubscriptionsInternal(destination, message); - } - - protected abstract MultiValueMap findSubscriptionsInternal( - String destination, Message message); + protected abstract MultiValueMap findSubscriptionsInternal(String destination, Message message); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java index f71c67d1894..d45159f45b6 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java @@ -25,7 +25,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import org.junit.Before; import org.junit.Test; import org.springframework.messaging.Message; @@ -36,9 +35,7 @@ import org.springframework.util.AntPathMatcher; import org.springframework.util.MultiValueMap; import org.springframework.util.PathMatcher; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** @@ -82,8 +79,8 @@ public class DefaultSubscriptionRegistryTests { String dest = "/foo"; this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest)); - MultiValueMap actual = this.registry.findSubscriptions(createMessage(dest)); + MultiValueMap actual = this.registry.findSubscriptions(createMessage(dest)); assertNotNull(actual); assertEquals("Expected one element " + actual, 1, actual.size()); assertEquals(Collections.singletonList(subsId), actual.get(sessId)); @@ -100,7 +97,6 @@ public class DefaultSubscriptionRegistryTests { } MultiValueMap actual = this.registry.findSubscriptions(createMessage(dest)); - assertNotNull(actual); assertEquals(1, actual.size()); assertEquals(subscriptionIds, sort(actual.get(sessId)));