diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java index c37bcda17c0..28004020bc3 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractBrokerMessageHandler.java @@ -27,7 +27,10 @@ import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.SmartLifecycle; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.SubscribableChannel; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -44,6 +47,12 @@ public abstract class AbstractBrokerMessageHandler protected final Log logger = LogFactory.getLog(getClass()); + private final SubscribableChannel clientInboundChannel; + + private final MessageChannel clientOutboundChannel; + + private final SubscribableChannel brokerChannel; + private final Collection destinationPrefixes; private ApplicationEventPublisher eventPublisher; @@ -61,16 +70,53 @@ public abstract class AbstractBrokerMessageHandler private final Object lifecycleMonitor = new Object(); - public AbstractBrokerMessageHandler() { - this(Collections.emptyList()); + /** + * Constructor with no destination prefixes (matches all destinations). + * @param inboundChannel the channel for receiving messages from clients (e.g. WebSocket clients) + * @param outboundChannel the channel for sending messages to clients (e.g. WebSocket clients) + * @param brokerChannel the channel for the application to send messages to the broker + */ + public AbstractBrokerMessageHandler(SubscribableChannel inboundChannel, MessageChannel outboundChannel, + SubscribableChannel brokerChannel) { + + this(inboundChannel, outboundChannel, brokerChannel, Collections.emptyList()); } - public AbstractBrokerMessageHandler(Collection destinationPrefixes) { + /** + * Constructor with destination prefixes to match to destinations of messages. + * @param inboundChannel the channel for receiving messages from clients (e.g. WebSocket clients) + * @param outboundChannel the channel for sending messages to clients (e.g. WebSocket clients) + * @param brokerChannel the channel for the application to send messages to the broker + * @param destinationPrefixes prefixes to use to filter out messages + */ + public AbstractBrokerMessageHandler(SubscribableChannel inboundChannel, MessageChannel outboundChannel, + SubscribableChannel brokerChannel, Collection destinationPrefixes) { + + Assert.notNull(inboundChannel, "'inboundChannel' must not be null"); + Assert.notNull(outboundChannel, "'outboundChannel' must not be null"); + Assert.notNull(brokerChannel, "'brokerChannel' must not be null"); + + this.clientInboundChannel = inboundChannel; + this.clientOutboundChannel = outboundChannel; + this.brokerChannel = brokerChannel; + destinationPrefixes = (destinationPrefixes != null) ? destinationPrefixes : Collections.emptyList(); this.destinationPrefixes = Collections.unmodifiableCollection(destinationPrefixes); } + public SubscribableChannel getClientInboundChannel() { + return this.clientInboundChannel; + } + + public MessageChannel getClientOutboundChannel() { + return this.clientOutboundChannel; + } + + public SubscribableChannel getBrokerChannel() { + return this.brokerChannel; + } + public Collection getDestinationPrefixes() { return this.destinationPrefixes; } @@ -117,6 +163,8 @@ public abstract class AbstractBrokerMessageHandler if (logger.isInfoEnabled()) { logger.info("Starting..."); } + this.clientInboundChannel.subscribe(this); + this.brokerChannel.subscribe(this); startInternal(); this.running = true; if (logger.isInfoEnabled()) { @@ -135,6 +183,8 @@ public abstract class AbstractBrokerMessageHandler logger.info("Stopping..."); } stopInternal(); + this.clientInboundChannel.unsubscribe(this); + this.brokerChannel.unsubscribe(this); this.running = false; if (logger.isDebugEnabled()) { logger.info("Stopped."); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java index 269c0058a9f..1ce6bdcf491 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java @@ -43,12 +43,6 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { private static final byte[] EMPTY_PAYLOAD = new byte[0]; - private final SubscribableChannel clientInboundChannel; - - private final MessageChannel clientOutboundChannel; - - private final SubscribableChannel brokerChannel; - private SubscriptionRegistry subscriptionRegistry; private PathMatcher pathMatcher; @@ -59,38 +53,19 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { /** * Create a SimpleBrokerMessageHandler instance with the given message channels * and destination prefixes. - * - * @param inChannel the channel for receiving messages from clients (e.g. WebSocket clients) - * @param outChannel the channel for sending messages to clients (e.g. WebSocket clients) + * @param clientInboundChannel the channel for receiving messages from clients (e.g. WebSocket clients) + * @param clientOutboundChannel the channel for sending messages to clients (e.g. WebSocket clients) * @param brokerChannel the channel for the application to send messages to the broker + * @param destinationPrefixes prefixes to use to filter out messages */ - public SimpleBrokerMessageHandler(SubscribableChannel inChannel, MessageChannel outChannel, + public SimpleBrokerMessageHandler(SubscribableChannel clientInboundChannel, MessageChannel clientOutboundChannel, SubscribableChannel brokerChannel, Collection destinationPrefixes) { - super(destinationPrefixes); - Assert.notNull(inChannel, "'clientInboundChannel' must not be null"); - Assert.notNull(outChannel, "'clientOutboundChannel' must not be null"); - Assert.notNull(brokerChannel, "'brokerChannel' must not be null"); - this.clientInboundChannel = inChannel; - this.clientOutboundChannel = outChannel; - this.brokerChannel = brokerChannel; - DefaultSubscriptionRegistry subscriptionRegistry = new DefaultSubscriptionRegistry(); - this.subscriptionRegistry = subscriptionRegistry; + super(clientInboundChannel, clientOutboundChannel, brokerChannel, destinationPrefixes); + this.subscriptionRegistry = new DefaultSubscriptionRegistry(); } - public SubscribableChannel getClientInboundChannel() { - return this.clientInboundChannel; - } - - public MessageChannel getClientOutboundChannel() { - return this.clientOutboundChannel; - } - - public SubscribableChannel getBrokerChannel() { - return this.brokerChannel; - } - /** * Configure a custom SubscriptionRegistry to use for storing subscriptions. * @@ -147,15 +122,11 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { @Override public void startInternal() { publishBrokerAvailableEvent(); - this.clientInboundChannel.subscribe(this); - this.brokerChannel.subscribe(this); } @Override public void stopInternal() { publishBrokerUnavailableEvent(); - this.clientInboundChannel.unsubscribe(this); - this.brokerChannel.unsubscribe(this); } @Override @@ -191,7 +162,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { connectAck.setSessionId(sessionId); connectAck.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message); Message messageOut = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAck.getMessageHeaders()); - this.clientOutboundChannel.send(messageOut); + getClientOutboundChannel().send(messageOut); } else if (SimpMessageType.DISCONNECT.equals(messageType)) { if (logger.isDebugEnabled()) { @@ -234,7 +205,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { Object payload = message.getPayload(); Message reply = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders()); try { - this.clientOutboundChannel.send(reply); + getClientOutboundChannel().send(reply); } catch (Throwable ex) { logger.error("Failed to send " + message, ex); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java index 56c4a1c25d2..59bdd8df18b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java @@ -26,6 +26,8 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.context.annotation.Bean; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.converter.*; import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; @@ -239,13 +241,13 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC @Bean public AbstractBrokerMessageHandler simpleBrokerMessageHandler() { SimpleBrokerMessageHandler handler = getBrokerRegistry().getSimpleBroker(brokerChannel()); - return (handler != null) ? handler : noopBroker; + return (handler != null) ? handler : new NoOpBrokerMessageHandler(); } @Bean public AbstractBrokerMessageHandler stompBrokerRelayMessageHandler() { AbstractBrokerMessageHandler handler = getBrokerRegistry().getStompBrokerRelay(brokerChannel()); - return (handler != null) ? handler : noopBroker; + return (handler != null) ? handler : new NoOpBrokerMessageHandler(); } @Bean @@ -373,7 +375,11 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC } - private static final AbstractBrokerMessageHandler noopBroker = new AbstractBrokerMessageHandler() { + private class NoOpBrokerMessageHandler extends AbstractBrokerMessageHandler { + + public NoOpBrokerMessageHandler() { + super(clientInboundChannel(), clientOutboundChannel(), brokerChannel()); + } @Override public void start() { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 6ea47a793b2..e4430dee705 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -93,12 +93,6 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } - private final SubscribableChannel clientInboundChannel; - - private final MessageChannel clientOutboundChannel; - - private final SubscribableChannel brokerChannel; - private String relayHost = "127.0.0.1"; private int relayPort = 61613; @@ -130,22 +124,16 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler /** * Create a StompBrokerRelayMessageHandler instance with the given message channels * and destination prefixes. - * @param clientInChannel the channel for receiving messages from clients (e.g. WebSocket clients) - * @param clientOutChannel the channel for sending messages to clients (e.g. WebSocket clients) + * @param inboundChannel the channel for receiving messages from clients (e.g. WebSocket clients) + * @param outboundChannel the channel for sending messages to clients (e.g. WebSocket clients) * @param brokerChannel the channel for the application to send messages to the broker * @param destinationPrefixes the broker supported destination prefixes; destinations * that do not match the given prefix are ignored. */ - public StompBrokerRelayMessageHandler(SubscribableChannel clientInChannel, MessageChannel clientOutChannel, + public StompBrokerRelayMessageHandler(SubscribableChannel inboundChannel, MessageChannel outboundChannel, SubscribableChannel brokerChannel, Collection destinationPrefixes) { - super(destinationPrefixes); - Assert.notNull(clientInChannel, "'clientInChannel' must not be null"); - Assert.notNull(clientOutChannel, "'clientOutChannel' must not be null"); - Assert.notNull(brokerChannel, "'brokerChannel' must not be null"); - this.clientInboundChannel = clientInChannel; - this.clientOutboundChannel = clientOutChannel; - this.brokerChannel = brokerChannel; + super(inboundChannel, outboundChannel, brokerChannel, destinationPrefixes); } @@ -362,9 +350,6 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @Override protected void startInternal() { - this.clientInboundChannel.subscribe(this); - this.brokerChannel.subscribe(this); - if (this.tcpClient == null) { StompDecoder decoder = new StompDecoder(); decoder.setHeaderInitializer(getHeaderInitializer()); @@ -397,10 +382,6 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @Override protected void stopInternal() { publishBrokerUnavailableEvent(); - - this.clientInboundChannel.unsubscribe(this); - this.brokerChannel.unsubscribe(this); - try { this.tcpClient.shutdown().get(5000, TimeUnit.MILLISECONDS); } @@ -594,7 +575,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler protected void sendMessageToClient(Message message) { if (this.isRemoteClientSession) { - StompBrokerRelayMessageHandler.this.clientOutboundChannel.send(message); + StompBrokerRelayMessageHandler.this.getClientOutboundChannel().send(message); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/BrokerMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/BrokerMessageHandlerTests.java index 2f4660dfb62..d15d90fc1c2 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/BrokerMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/BrokerMessageHandlerTests.java @@ -22,6 +22,8 @@ import org.mockito.MockitoAnnotations; import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEventPublisher; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.support.GenericMessage; import java.util.ArrayList; @@ -32,6 +34,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; /** * Unit tests for {@link org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler}. @@ -133,6 +136,7 @@ public class BrokerMessageHandlerTests { private TestBrokerMesageHandler() { + super(mock(SubscribableChannel.class), mock(MessageChannel.class), mock(SubscribableChannel.class)); setApplicationEventPublisher(this); }