From 2a48ad88fb1ea72f14477447b9515d0537ec93c8 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Sat, 13 Jul 2013 19:05:32 -0400 Subject: [PATCH] Refactor and polish spring-messaging Remove base class for STOMP-related message handler classes (AbstractSimpMessageHandler), polish subclasses and fix issues with more significant updates to STOMP broker relay. Introduce base class for SubscribableChannel implementations providing consistent logging for all channel implementations. --- .../handler/AbstractSimpMessageHandler.java | 164 ---------- .../handler/AbstractSubscriptionRegistry.java | 18 +- .../handler/AnnotationSimpMessageHandler.java | 37 +-- .../handler/DefaultSubscriptionRegistry.java | 9 +- .../handler/SimpleBrokerMessageHandler.java | 69 ++-- .../simp/handler/SubscriptionRegistry.java | 28 +- ...va => StompBrokerRelayMessageHandler.java} | 294 ++++++++++-------- .../messaging/simp/stomp/StompCommand.java | 30 +- .../simp/stomp/StompHeaderAccessor.java | 33 +- .../simp/stomp/StompWebSocketHandler.java | 8 +- .../support/MessageHeaderAccessor.java | 17 + .../support/NativeMessageHeaderAccessor.java | 1 - .../channel/AbstractSubscribableChannel.java | 104 +++++++ ...l.java => ReactorSubscribableChannel.java} | 76 ++--- ...a => TaskExecutorSubscribableChannel.java} | 55 ++-- .../DefaultSubscriptionRegistryTests.java | 53 ++-- .../SimpleBrokerWebMessageHandlerTests.java | 44 +-- .../channel/PublishSubscibeChannelTests.java | 13 +- 18 files changed, 521 insertions(+), 532 deletions(-) delete mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSimpMessageHandler.java rename spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/{StompRelayMessageHandler.java => StompBrokerRelayMessageHandler.java} (60%) create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java rename spring-messaging/src/main/java/org/springframework/messaging/support/channel/{ReactorMessageChannel.java => ReactorSubscribableChannel.java} (52%) rename spring-messaging/src/main/java/org/springframework/messaging/support/channel/{PublishSubscribeChannel.java => TaskExecutorSubscribableChannel.java} (57%) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSimpMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSimpMessageHandler.java deleted file mode 100644 index f3738f74e51..00000000000 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSimpMessageHandler.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.messaging.simp.handler; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageHandler; -import org.springframework.messaging.MessagingException; -import org.springframework.messaging.simp.SimpMessageHeaderAccessor; -import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.util.AntPathMatcher; -import org.springframework.util.CollectionUtils; -import org.springframework.util.PathMatcher; - - -/** - * @author Rossen Stoyanchev - * @since 4.0 - */ -public abstract class AbstractSimpMessageHandler implements MessageHandler { - - protected final Log logger = LogFactory.getLog(getClass()); - - private final List allowedDestinations = new ArrayList(); - - private final List disallowedDestinations = new ArrayList(); - - private final PathMatcher pathMatcher = new AntPathMatcher(); - - - /** - * Ant-style destination patterns that this service is allowed to process. - */ - public void setAllowedDestinations(String... patterns) { - this.allowedDestinations.clear(); - this.allowedDestinations.addAll(Arrays.asList(patterns)); - } - - /** - * Ant-style destination patterns that this service should skip. - */ - public void setDisallowedDestinations(String... patterns) { - this.disallowedDestinations.clear(); - this.disallowedDestinations.addAll(Arrays.asList(patterns)); - } - - protected abstract Collection getSupportedMessageTypes(); - - - protected boolean canHandle(Message message, SimpMessageType messageType) { - - if (!CollectionUtils.isEmpty(getSupportedMessageTypes())) { - if (!getSupportedMessageTypes().contains(messageType)) { - return false; - } - } - - return isDestinationAllowed(message); - } - - protected boolean isDestinationAllowed(Message message) { - - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - String destination = headers.getDestination(); - - if (destination == null) { - return true; - } - - if (!this.disallowedDestinations.isEmpty()) { - for (String pattern : this.disallowedDestinations) { - if (this.pathMatcher.match(pattern, destination)) { - if (logger.isTraceEnabled()) { - logger.trace("Skip message id=" + message.getHeaders().getId()); - } - return false; - } - } - } - - if (!this.allowedDestinations.isEmpty()) { - for (String pattern : this.allowedDestinations) { - if (this.pathMatcher.match(pattern, destination)) { - return true; - } - } - if (logger.isTraceEnabled()) { - logger.trace("Skip message id=" + message.getHeaders().getId()); - } - return false; - } - - return true; - } - - @Override - public final void handleMessage(Message message) throws MessagingException { - - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - SimpMessageType messageType = headers.getMessageType(); - - if (!canHandle(message, messageType)) { - return; - } - - if (SimpMessageType.MESSAGE.equals(messageType)) { - handlePublish(message); - } - else if (SimpMessageType.SUBSCRIBE.equals(messageType)) { - handleSubscribe(message); - } - else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) { - handleUnsubscribe(message); - } - else if (SimpMessageType.CONNECT.equals(messageType)) { - handleConnect(message); - } - else if (SimpMessageType.DISCONNECT.equals(messageType)) { - handleDisconnect(message); - } - else { - handleOther(message); - } - } - - protected void handleConnect(Message message) { - } - - protected void handlePublish(Message message) { - } - - protected void handleSubscribe(Message message) { - } - - protected void handleUnsubscribe(Message message) { - } - - protected void handleDisconnect(Message message) { - } - - protected void handleOther(Message message) { - } - -} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSubscriptionRegistry.java index 60b59e3f39a..61cde8e89eb 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSubscriptionRegistry.java @@ -34,7 +34,7 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist @Override - public void addSubscription(Message message) { + public final void registerSubscription(Message message) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); if (!SimpMessageType.SUBSCRIBE.equals(headers.getMessageType())) { logger.error("Expected SUBSCRIBE message: " + message); @@ -55,6 +55,9 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist logger.error("Ignoring destination. No destination in message: " + message); return; } + if (logger.isDebugEnabled()) { + logger.debug("Subscribe request: " + message); + } addSubscriptionInternal(sessionId, subscriptionId, destination, message); } @@ -62,7 +65,7 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist String destination, Message message); @Override - public void removeSubscription(Message message) { + public final void unregisterSubscription(Message message) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); if (!SimpMessageType.UNSUBSCRIBE.equals(headers.getMessageType())) { logger.error("Expected UNSUBSCRIBE message: " + message); @@ -78,17 +81,19 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist logger.error("Ignoring subscription. No subscriptionId in message: " + message); return; } + if (logger.isDebugEnabled()) { + logger.debug("Unubscribe request: " + message); + } removeSubscriptionInternal(sessionId, subscriptionId, message); } protected abstract void removeSubscriptionInternal(String sessionId, String subscriptionId, Message message); @Override - public void removeSessionSubscriptions(String sessionId) { - } + public abstract void unregisterAllSubscriptions(String sessionId); @Override - public MultiValueMap findSubscriptions(Message message) { + public final MultiValueMap findSubscriptions(Message message) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); if (!SimpMessageType.MESSAGE.equals(headers.getMessageType())) { logger.error("Unexpected message type: " + message); @@ -99,6 +104,9 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist logger.error("Ignoring destination. No destination in message: " + message); return null; } + if (logger.isTraceEnabled()) { + logger.trace("Find subscriptions, destination=" + headers.getDestination()); + } return findSubscriptionsInternal(destination, message); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java index 4f1bc2fb9e1..51e2c3e72d3 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java @@ -19,13 +19,14 @@ package org.springframework.messaging.simp.handler; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.util.Arrays; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.beans.BeansException; import org.springframework.beans.factory.InitializingBean; import org.springframework.context.ApplicationContext; @@ -34,6 +35,8 @@ import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.support.MessageBodyArgumentResolver; import org.springframework.messaging.handler.annotation.support.MessageExceptionHandlerMethodResolver; @@ -60,8 +63,9 @@ import org.springframework.web.method.HandlerMethodSelector; * @author Rossen Stoyanchev * @since 4.0 */ -public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler - implements ApplicationContextAware, InitializingBean { +public class AnnotationSimpMessageHandler implements MessageHandler, ApplicationContextAware, InitializingBean { + + private static final Log logger = LogFactory.getLog(AnnotationSimpMessageHandler.class); private final MessageChannel outboundChannel; @@ -104,11 +108,6 @@ public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler this.applicationContext = applicationContext; } - @Override - protected Collection getSupportedMessageTypes() { - return Arrays.asList(SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE, SimpMessageType.UNSUBSCRIBE); - } - @Override public void afterPropertiesSet() { @@ -183,18 +182,20 @@ public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler } @Override - public void handlePublish(Message message) { - handleMessageInternal(message, this.messageMethods); - } + public void handleMessage(Message message) throws MessagingException { - @Override - public void handleSubscribe(Message message) { - handleMessageInternal(message, this.subscribeMethods); - } + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + SimpMessageType messageType = headers.getMessageType(); - @Override - public void handleUnsubscribe(Message message) { - handleMessageInternal(message, this.unsubscribeMethods); + if (SimpMessageType.MESSAGE.equals(messageType)) { + handleMessageInternal(message, this.messageMethods); + } + else if (SimpMessageType.SUBSCRIBE.equals(messageType)) { + handleMessageInternal(message, this.subscribeMethods); + } + else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) { + handleMessageInternal(message, this.unsubscribeMethods); + } } private void handleMessageInternal(final Message message, Map handlerMethods) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java index 2cf3bc88f17..a29131e975c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java @@ -74,9 +74,14 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { } @Override - public void removeSessionSubscriptions(String sessionId) { + public void unregisterAllSubscriptions(String sessionId) { SessionSubscriptionInfo info = this.subscriptionRegistry.removeSubscriptions(sessionId); - this.destinationCache.removeSessionSubscriptions(info); + if (info != null) { + if (logger.isDebugEnabled()) { + logger.debug("Unregistering subscriptions for sessionId=" + sessionId); + } + this.destinationCache.removeSessionSubscriptions(info); + } } @Override diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java index 0fcf5eb613c..f0d5519963a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java @@ -16,11 +16,12 @@ package org.springframework.messaging.simp.handler; -import java.util.Arrays; -import java.util.Collection; - +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; @@ -32,7 +33,9 @@ import org.springframework.util.MultiValueMap; * @author Rossen Stoyanchev * @since 4.0 */ -public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler { +public class SimpleBrokerMessageHandler implements MessageHandler { + + private static final Log logger = LogFactory.getLog(SimpleBrokerMessageHandler.class); private final MessageChannel outboundChannel; @@ -54,42 +57,36 @@ public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler { this.subscriptionRegistry = subscriptionRegistry; } - @Override - protected Collection getSupportedMessageTypes() { - return Arrays.asList(SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE, - SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT); + public SubscriptionRegistry getSubscriptionRegistry() { + return this.subscriptionRegistry; } @Override - public void handleSubscribe(Message message) { + public void handleMessage(Message message) throws MessagingException { - if (logger.isDebugEnabled()) { - logger.debug("Subscribe " + message); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + SimpMessageType messageType = headers.getMessageType(); + + if (SimpMessageType.SUBSCRIBE.equals(messageType)) { + // TODO: need a way to communicate back if subscription was successfully created or + // not in which case an ERROR should be sent back and close the connection + // http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE + this.subscriptionRegistry.registerSubscription(message); } - - this.subscriptionRegistry.addSubscription(message); - - // TODO: need a way to communicate back if subscription was successfully created or - // not in which case an ERROR should be sent back and close the connection - // http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE - } - - @Override - protected void handleUnsubscribe(Message message) { - this.subscriptionRegistry.removeSubscription(message); - } - - @Override - public void handlePublish(Message message) { - - if (logger.isTraceEnabled()) { - logger.trace("Message received: " + message); + else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) { + this.subscriptionRegistry.unregisterSubscription(message); } + else if (SimpMessageType.MESSAGE.equals(messageType)) { + sendMessageToSubscribers(headers.getDestination(), message); + } + else if (SimpMessageType.DISCONNECT.equals(messageType)) { + String sessionId = SimpMessageHeaderAccessor.wrap(message).getSessionId(); + this.subscriptionRegistry.unregisterAllSubscriptions(sessionId); + } + } - String destination = SimpMessageHeaderAccessor.wrap(message).getDestination(); - + protected void sendMessageToSubscribers(String destination, Message message) { MultiValueMap subscriptions = this.subscriptionRegistry.findSubscriptions(message); - for (String sessionId : subscriptions.keySet()) { for (String subscriptionId : subscriptions.get(sessionId)) { @@ -99,7 +96,6 @@ public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler { Message clientMessage = MessageBuilder.withPayload( message.getPayload()).copyHeaders(headers.toMap()).build(); - try { this.outboundChannel.send(clientMessage); } @@ -110,11 +106,4 @@ public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler { } } } - - @Override - public void handleDisconnect(Message message) { - String sessionId = SimpMessageHeaderAccessor.wrap(message).getSessionId(); - this.subscriptionRegistry.removeSessionSubscriptions(sessionId); - } - } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SubscriptionRegistry.java index b76fe231912..6e42369de49 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SubscriptionRegistry.java @@ -19,19 +19,37 @@ package org.springframework.messaging.simp.handler; import org.springframework.messaging.Message; import org.springframework.util.MultiValueMap; - /** + * A registry of subscription by session that allows looking up subscriptions. + * * @author Rossen Stoyanchev * @since 4.0 */ public interface SubscriptionRegistry { - void addSubscription(Message subscribeMessage); + /** + * Register a subscription represented by the given message. + * @param subscribeMessage the subscription request + */ + void registerSubscription(Message subscribeMessage); - void removeSubscription(Message unsubscribeMessage); + /** + * Unregister a subscription. + * @param unsubscribeMessage the request to unsubscribe + */ + void unregisterSubscription(Message unsubscribeMessage); - void removeSessionSubscriptions(String sessionId); + /** + * Remove all subscriptions associated with the given sessionId. + */ + void unregisterAllSubscriptions(String sessionId); - MultiValueMap findSubscriptions(Message message); + /** + * Find all subscriptions that should receive the given message. + * + * @param message the message + * @return a {@link MultiValueMap} from sessionId to subscriptionId's, possibly empty. + */ + MultiValueMap findSubscriptions(Message message); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java similarity index 60% rename from spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java rename to spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 4fae99140de..9a23bf8f354 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -26,12 +26,13 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.context.SmartLifecycle; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.MessageHandler; import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.messaging.simp.handler.AbstractSimpMessageHandler; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -51,12 +52,15 @@ import reactor.tcp.spec.TcpClientSpec; * @author Rossen Stoyanchev * @since 4.0 */ -public class StompRelayMessageHandler extends AbstractSimpMessageHandler implements SmartLifecycle { +public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLifecycle { + + private static final Log logger = LogFactory.getLog(StompBrokerRelayMessageHandler.class); private static final String STOMP_RELAY_SYSTEM_SESSION_ID = "stompRelaySystemSessionId"; + private final MessageChannel outboundChannel; - private MessageChannel outboundChannel; + private final String[] destinationPrefixes; private String relayHost = "127.0.0.1"; @@ -81,13 +85,16 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme /** * @param outboundChannel a channel for messages going out to clients + * @param destinationPrefixes the broker supported destination prefixes; destinations + * that do not match the given prefix are ignored. */ - public StompRelayMessageHandler(MessageChannel outboundChannel) { + public StompBrokerRelayMessageHandler(MessageChannel outboundChannel, Collection destinationPrefixes) { Assert.notNull(outboundChannel, "outboundChannel is required"); + Assert.notNull(destinationPrefixes, "destinationPrefixes is required"); this.outboundChannel = outboundChannel; + this.destinationPrefixes = destinationPrefixes.toArray(new String[destinationPrefixes.size()]); } - /** * Set the STOMP message broker host. */ @@ -148,9 +155,11 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme return this.systemPasscode; } - @Override - protected Collection getSupportedMessageTypes() { - return null; + /** + * @return the configured STOMP broker supported destination prefixes. + */ + public String[] getDestinationPrefixes() { + return destinationPrefixes; } @Override @@ -173,44 +182,66 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme @Override public void start() { synchronized (this.lifecycleMonitor) { - + if (logger.isDebugEnabled()) { + logger.debug("Starting STOMP broker relay"); + } this.environment = new Environment(); this.tcpClient = new TcpClientSpec(NettyTcpClient.class) .env(this.environment) .codec(new DelimitedCodec((byte) 0, true, StandardCodecs.STRING_CODEC)) .connect(this.relayHost, this.relayPort) .get(); - - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); - headers.setAcceptVersion("1.1,1.2"); - headers.setLogin(this.systemLogin); - headers.setPasscode(this.systemPasscode); - headers.setHeartbeat(0,0); // TODO - Message message = MessageBuilder.withPayload( - new byte[0]).copyHeaders(headers.toNativeHeaderMap()).build(); - - RelaySession session = new RelaySession(message, headers) { - @Override - protected void sendMessageToClient(Message message) { - // TODO: check for ERROR frame (reconnect?) - } - }; - this.relaySessions.put(STOMP_RELAY_SYSTEM_SESSION_ID, session); - + openSystemSession(); this.running = true; } } + /** + * Open a "system" session for sending messages from parts of the application + * not assoicated with a client STOMP session. + */ + private void openSystemSession() { + + RelaySession session = new RelaySession(STOMP_RELAY_SYSTEM_SESSION_ID) { + @Override + protected void sendMessageToClient(Message message) { + // ignore, only used to send messages + // TODO: ERROR frame/reconnect + } + }; + this.relaySessions.put(STOMP_RELAY_SYSTEM_SESSION_ID, session); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + headers.setAcceptVersion("1.1,1.2"); + headers.setLogin(this.systemLogin); + headers.setPasscode(this.systemPasscode); + headers.setHeartbeat(0,0); // TODO + + if (logger.isDebugEnabled()) { + logger.debug("Sending STOMP CONNECT frame to initialize \"system\" TCP connection"); + } + Message message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); + session.open(message); + } + @Override public void stop() { synchronized (this.lifecycleMonitor) { + if (logger.isDebugEnabled()) { + logger.debug("Stopping STOMP broker relay"); + } this.running = false; try { this.tcpClient.close().await(5000, TimeUnit.MILLISECONDS); + } + catch (Throwable t) { + logger.error("Failed to close reactor TCP client", t); + } + try { this.environment.shutdown(); } - catch (InterruptedException e) { - // ignore + catch (Throwable t) { + logger.error("Failed to shut down reactor Environment", t); } } } @@ -224,75 +255,87 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme } @Override - public void handleConnect(Message message) { - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - String sessionId = stompHeaders.getSessionId(); - if (sessionId == null) { - logger.error("No sessionId in message " + message); - return; - } - RelaySession relaySession = new RelaySession(message, stompHeaders); - this.relaySessions.put(sessionId, relaySession); - } - - @Override - public void handlePublish(Message message) { - forwardMessage(message, StompCommand.SEND); - } - - @Override - public void handleSubscribe(Message message) { - forwardMessage(message, StompCommand.SUBSCRIBE); - } - - @Override - public void handleUnsubscribe(Message message) { - forwardMessage(message, StompCommand.UNSUBSCRIBE); - } - - @Override - public void handleDisconnect(Message message) { - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - if (stompHeaders.getStompCommand() != null) { - forwardMessage(message, StompCommand.DISCONNECT); - } - String sessionId = stompHeaders.getSessionId(); - if (sessionId == null) { - logger.error("No sessionId in message " + message); - return; - } - } - - @Override - public void handleOther(Message message) { - StompCommand command = (StompCommand) message.getHeaders().get(SimpMessageHeaderAccessor.PROTOCOL_MESSAGE_TYPE); - Assert.notNull(command, "Expected STOMP command: " + message.getHeaders()); - forwardMessage(message, command); - } - - private void forwardMessage(Message message, StompCommand command) { + public void handleMessage(Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - headers.setStompCommandIfNotSet(command); - String sessionId = headers.getSessionId(); - if (sessionId == null) { - if (StompCommand.SEND.equals(command)) { - sessionId = STOMP_RELAY_SYSTEM_SESSION_ID; - } - else { - logger.error("No sessionId in message " + message); - return; - } - } + String destination = headers.getDestination(); + StompCommand command = headers.getStompCommand(); + SimpMessageType messageType = headers.getMessageType(); - RelaySession session = this.relaySessions.get(sessionId); - if (session == null) { - logger.warn("Session id=" + sessionId + " not found. Message cannot be forwarded: " + message); + if (!this.running) { + if (logger.isTraceEnabled()) { + logger.trace("STOMP broker relay not running. Ignoring message id=" + headers.getId()); + } return; } - session.forward(message, headers); + if (SimpMessageType.MESSAGE.equals(messageType)) { + sessionId = (sessionId == null) ? STOMP_RELAY_SYSTEM_SESSION_ID : sessionId; + headers.setSessionId(sessionId); + command = (command == null) ? StompCommand.SEND : command; + headers.setStompCommandIfNotSet(command); + message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build(); + } + + if (headers.getStompCommand() == null) { + logger.error("Ignoring message, no STOMP command: " + message); + return; + } + if (sessionId == null) { + logger.error("Ignoring message, no sessionId: " + message); + return; + } + if (command.requiresDestination() && (destination == null)) { + logger.error("Ignoring " + command + " message, no destination: " + message); + return; + } + + try { + if ((destination == null) || supportsDestination(destination)) { + if (logger.isTraceEnabled()) { + logger.trace("Processing message: " + message); + } + handleInternal(message, messageType, sessionId); + } + } + catch (Throwable t) { + logger.error("Failed to handle message " + message, t); + } + } + + protected boolean supportsDestination(String destination) { + for (String prefix : this.destinationPrefixes) { + if (destination.startsWith(prefix)) { + return true; + } + } + return false; + } + + protected void handleInternal(Message message, SimpMessageType messageType, String sessionId) { + if (SimpMessageType.CONNECT.equals(messageType)) { + RelaySession session = new RelaySession(sessionId); + this.relaySessions.put(sessionId, session); + session.open(message); + } + else if (SimpMessageType.DISCONNECT.equals(messageType)) { + RelaySession session = this.relaySessions.remove(sessionId); + if (session != null) { + if (logger.isTraceEnabled()) { + logger.trace("Session already removed, sessionId=" + sessionId); + } + session.forward(message); + } + } + else { + RelaySession session = this.relaySessions.get(sessionId); + if (session == null) { + logger.warn("Session id=" + sessionId + " not found. Ignoring message: " + message); + return; + } + session.forward(message); + } } @@ -300,21 +343,23 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme private final String sessionId; - private final Promise> promise; - private final BlockingQueue> messageQueue = new LinkedBlockingQueue>(50); - private final Object monitor = new Object(); + private Promise> promise; private volatile boolean isConnected = false; + private final Object monitor = new Object(); - public RelaySession(final Message message, final StompHeaderAccessor stompHeaders) { + public RelaySession(String sessionId) { + Assert.notNull(sessionId, "sessionId is required"); + this.sessionId = sessionId; + } + + public void open(final Message message) { Assert.notNull(message, "message is required"); - Assert.notNull(stompHeaders, "stompHeaders is required"); - this.sessionId = stompHeaders.getSessionId(); this.promise = tcpClient.open(); this.promise.consume(new Consumer>() { @@ -326,11 +371,9 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme readStompFrame(stompFrame); } }); - stompHeaders.setHeartbeat(0,0); // TODO - forwardInternal(message, stompHeaders, connection); + forwardInternal(message, connection); } }); - this.promise.onError(new Consumer() { @Override public void accept(Throwable ex) { @@ -339,14 +382,12 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme sendError(sessionId, "Failed to connect to message broker " + ex.toString()); } }); - - // TODO: ATM no way to detect closed socket } private void readStompFrame(String stompFrame) { + // heartbeat if (StringUtils.isEmpty(stompFrame)) { - // heartbeat? return; } @@ -359,13 +400,13 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme if (StompCommand.CONNECTED == headers.getStompCommand()) { synchronized(this.monitor) { this.isConnected = true; - flushMessages(promise.get()); + flushMessages(this.promise.get()); } return; } if (StompCommand.ERROR == headers.getStompCommand()) { if (logger.isDebugEnabled()) { - logger.warn("STOMP ERROR: " + headers.getMessage() + ". Removing session: " + this.sessionId); + logger.warn("STOMP ERROR: " + headers.getMessage() + ". Removing session id=" + this.sessionId); } relaySessions.remove(this.sessionId); } @@ -388,14 +429,14 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme sendMessageToClient(errorMessage); } - public void forward(Message message, StompHeaderAccessor headers) { + public void forward(Message message) { if (!this.isConnected) { synchronized(this.monitor) { if (!this.isConnected) { this.messageQueue.add(message); if (logger.isTraceEnabled()) { - logger.trace("Queued message " + message + ", queue size=" + this.messageQueue.size()); + logger.trace("Not connected yet, message queued, queue size=" + this.messageQueue.size()); } return; } @@ -405,7 +446,7 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme TcpConnection connection = this.promise.get(); if (this.messageQueue.isEmpty()) { - forwardInternal(message, headers, connection); + forwardInternal(message, connection); } else { this.messageQueue.add(message); @@ -413,36 +454,37 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme } } - private void flushMessages(TcpConnection connection) { - List> messages = new ArrayList>(); - this.messageQueue.drainTo(messages); - for (Message message : messages) { - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - if (!forwardInternal(message, headers, connection)) { - return; - } - } - } + private boolean forwardInternal(Message message, TcpConnection connection) { - private boolean forwardInternal(Message message, StompHeaderAccessor headers, TcpConnection connection) { try { - headers.setStompCommandIfNotSet(StompCommand.SEND); - if (logger.isTraceEnabled()) { - logger.trace("Forwarding message " + message); + logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId()); } - byte[] bytes = stompMessageConverter.fromMessage(message); connection.send(new String(bytes, Charset.forName("UTF-8"))); } catch (Throwable ex) { - logger.error("Failed to forward message " + message, ex); - connection.close(); + logger.error("Forward failed message id=" + message.getHeaders().getId(), ex); + try { + connection.close(); + } + catch (Throwable t) { + // ignore + } sendError(this.sessionId, "Failed to forward message " + message + ": " + ex.getMessage()); return false; } return true; } - } + private void flushMessages(TcpConnection connection) { + List> messages = new ArrayList>(); + this.messageQueue.drainTo(messages); + for (Message message : messages) { + if (!forwardInternal(message, connection)) { + return; + } + } + } + } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java index e66d09efa4e..a5e3c53e98a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java @@ -16,8 +16,11 @@ package org.springframework.messaging.simp.stomp; +import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; +import java.util.Set; import org.springframework.messaging.simp.SimpMessageType; @@ -49,21 +52,28 @@ public enum StompCommand { ERROR; - private static Map commandToMessageType = new HashMap(); + private static Map messageTypeLookup = new HashMap(); + + private static Set destinationRequiredLookup = + new HashSet(Arrays.asList(SEND, SUBSCRIBE, MESSAGE)); static { - commandToMessageType.put(StompCommand.CONNECT, SimpMessageType.CONNECT); - commandToMessageType.put(StompCommand.STOMP, SimpMessageType.CONNECT); - commandToMessageType.put(StompCommand.SEND, SimpMessageType.MESSAGE); - commandToMessageType.put(StompCommand.MESSAGE, SimpMessageType.MESSAGE); - commandToMessageType.put(StompCommand.SUBSCRIBE, SimpMessageType.SUBSCRIBE); - commandToMessageType.put(StompCommand.UNSUBSCRIBE, SimpMessageType.UNSUBSCRIBE); - commandToMessageType.put(StompCommand.DISCONNECT, SimpMessageType.DISCONNECT); + messageTypeLookup.put(StompCommand.CONNECT, SimpMessageType.CONNECT); + messageTypeLookup.put(StompCommand.STOMP, SimpMessageType.CONNECT); + messageTypeLookup.put(StompCommand.SEND, SimpMessageType.MESSAGE); + messageTypeLookup.put(StompCommand.MESSAGE, SimpMessageType.MESSAGE); + messageTypeLookup.put(StompCommand.SUBSCRIBE, SimpMessageType.SUBSCRIBE); + messageTypeLookup.put(StompCommand.UNSUBSCRIBE, SimpMessageType.UNSUBSCRIBE); + messageTypeLookup.put(StompCommand.DISCONNECT, SimpMessageType.DISCONNECT); } public SimpMessageType getMessageType() { - SimpMessageType messageType = commandToMessageType.get(this); - return (messageType != null) ? messageType : SimpMessageType.OTHER; + SimpMessageType type = messageTypeLookup.get(this); + return (type != null) ? type : SimpMessageType.OTHER; + } + + public boolean requiresDestination() { + return destinationRequiredLookup.contains(this); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java index 8ad764baa8a..670c2f14dea 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java @@ -27,6 +27,7 @@ import org.springframework.http.MediaType; import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -84,21 +85,31 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { */ private StompHeaderAccessor(StompCommand command, Map> externalSourceHeaders) { super(command.getMessageType(), command, externalSourceHeaders); - initSimpMessageHeaders(); + if (externalSourceHeaders != null) { + setSimpMessageHeaders(externalSourceHeaders); + } } - private void initSimpMessageHeaders() { - String destination = getFirstNativeHeader(DESTINATION); - if (destination != null) { - super.setDestination(destination); + private void setSimpMessageHeaders(Map> extHeaders) { + List values = extHeaders.get(StompHeaderAccessor.DESTINATION); + if (!CollectionUtils.isEmpty(values)) { + super.setDestination(values.get(0)); } - String contentType = getFirstNativeHeader(CONTENT_TYPE); - if (contentType != null) { - super.setContentType(MediaType.parseMediaType(contentType)); + values = extHeaders.get(StompHeaderAccessor.CONTENT_TYPE); + if (!CollectionUtils.isEmpty(values)) { + super.setContentType(MediaType.parseMediaType(values.get(0))); } - if (StompCommand.SUBSCRIBE.equals(getStompCommand()) || StompCommand.UNSUBSCRIBE.equals(getStompCommand())) { - if (getFirstNativeHeader(STOMP_ID) != null) { - super.setSubscriptionId(getFirstNativeHeader(STOMP_ID)); + StompCommand command = getStompCommand(); + if (StompCommand.SUBSCRIBE.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) { + values = extHeaders.get(StompHeaderAccessor.STOMP_ID); + if (!CollectionUtils.isEmpty(values)) { + super.setSubscriptionId(values.get(0)); + } + } + else if (StompCommand.MESSAGE.equals(command)) { + values = extHeaders.get(StompHeaderAccessor.SUBSCRIPTION); + if (!CollectionUtils.isEmpty(values)) { + super.setSubscriptionId(values.get(0)); } } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java index 1ea634f23f4..30edde3f04d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java @@ -26,7 +26,6 @@ import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandler; -import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.socket.CloseStatus; @@ -176,10 +175,11 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { - this.sessions.remove(session.getId()); + String sessionId = session.getId(); + this.sessions.remove(sessionId); - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); - headers.setSessionId(session.getId()); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + headers.setSessionId(sessionId); Message message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); this.clientInputChannel.send(message); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java index aa62f8c36a0..d8fdf64c992 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.UUID; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -202,10 +203,22 @@ public class MessageHeaderAccessor { } } + public UUID getId() { + return (UUID) getHeader(MessageHeaders.ID); + } + + public Long getTimestamp() { + return (Long) getHeader(MessageHeaders.TIMESTAMP); + } + public void setReplyChannel(MessageChannel replyChannel) { setHeader(MessageHeaders.REPLY_CHANNEL, replyChannel); } + public Object getReplyChannel() { + return getHeader(MessageHeaders.REPLY_CHANNEL); + } + public void setReplyChannelName(String replyChannelName) { setHeader(MessageHeaders.REPLY_CHANNEL, replyChannelName); } @@ -214,6 +227,10 @@ public class MessageHeaderAccessor { setHeader(MessageHeaders.ERROR_CHANNEL, errorChannel); } + public Object getErrorChannel() { + return getHeader(MessageHeaders.ERROR_CHANNEL); + } + public void setErrorChannelName(String errorChannelName) { setHeader(MessageHeaders.ERROR_CHANNEL, errorChannelName); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java index e9dc13a2b89..b5d147ef3ad 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java @@ -50,7 +50,6 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { * A constructor for creating new headers, accepting an optional native header map. */ public NativeMessageHeaderAccessor(Map> nativeHeaders) { - super(); this.originalNativeHeaders = nativeHeaders; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java new file mode 100644 index 00000000000..b7a3800d856 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support.channel; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.beans.factory.BeanNameAware; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.SubscribableChannel; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; + + +/** + * Abstract base class for {@link SubscribableChannel} implementations. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public abstract class AbstractSubscribableChannel implements SubscribableChannel, BeanNameAware { + + protected Log logger = LogFactory.getLog(getClass()); + + private String beanName; + + + public AbstractSubscribableChannel() { + this.beanName = getClass().getSimpleName() + "@" + ObjectUtils.getIdentityHexString(this); + } + + /** + * {@inheritDoc} + *

Used primarily for logging purposes. + */ + @Override + public void setBeanName(String name) { + this.beanName = name; + } + + /** + * @return the name for this channel. + */ + public String getBeanName() { + return this.beanName; + } + + @Override + public final boolean send(Message message) { + return send(message, INDEFINITE_TIMEOUT); + } + + @Override + public final boolean send(Message message, long timeout) { + Assert.notNull(message, "Message must not be null"); + if (logger.isTraceEnabled()) { + logger.trace("[" + this.beanName + "] sending message " + message); + } + return sendInternal(message, timeout); + } + + protected abstract boolean sendInternal(Message message, long timeout); + + @Override + public final boolean subscribe(MessageHandler handler) { + if (hasSubscription(handler)) { + logger.warn("[" + this.beanName + "] handler already subscribed " + handler); + return false; + } + if (logger.isDebugEnabled()) { + logger.debug("[" + this.beanName + "] subscribing " + handler); + } + return subscribeInternal(handler); + } + + protected abstract boolean hasSubscription(MessageHandler handler); + + protected abstract boolean subscribeInternal(MessageHandler handler); + + @Override + public final boolean unsubscribe(MessageHandler handler) { + if (logger.isDebugEnabled()) { + logger.debug("[" + this.beanName + "] unsubscribing " + handler); + } + return unsubscribeInternal(handler); + } + + protected abstract boolean unsubscribeInternal(MessageHandler handler); + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorMessageChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorSubscribableChannel.java similarity index 52% rename from spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorMessageChannel.java rename to spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorSubscribableChannel.java index 2fc744049da..d60ccea291f 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorMessageChannel.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorSubscribableChannel.java @@ -19,16 +19,14 @@ package org.springframework.messaging.support.channel; import java.util.HashMap; import java.util.Map; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; -import org.springframework.messaging.SubscribableChannel; import reactor.core.Reactor; import reactor.event.Event; import reactor.event.registry.Registration; import reactor.event.selector.ObjectSelector; +import reactor.event.selector.Selector; import reactor.function.Consumer; @@ -36,88 +34,52 @@ import reactor.function.Consumer; * @author Rossen Stoyanchev * @since 4.0 */ -public class ReactorMessageChannel implements SubscribableChannel { - - private static Log logger = LogFactory.getLog(ReactorMessageChannel.class); +public class ReactorSubscribableChannel extends AbstractSubscribableChannel { private final Reactor reactor; private final Object key = new Object(); - private String name = toString(); // TODO + private final Map> registrations = new HashMap>(); - private final Map> registrations = - new HashMap>(); - - - public ReactorMessageChannel(Reactor reactor) { + public ReactorSubscribableChannel(Reactor reactor) { this.reactor = reactor; } - public void setName(String name) { - this.name = name; - } - public String getName() { - return this.name; + @Override + protected boolean hasSubscription(MessageHandler handler) { + return this.registrations.containsKey(handler); } @Override - public boolean send(Message message) { - return send(message, -1); - } - - @Override - public boolean send(Message message, long timeout) { - if (logger.isTraceEnabled()) { - logger.trace("Channel " + getName() + ", sending message id=" + message.getHeaders().getId()); - } + public boolean sendInternal(Message message, long timeout) { this.reactor.notify(this.key, Event.wrap(message)); return true; } @Override - public boolean subscribe(final MessageHandler handler) { - - if (this.registrations.containsKey(handler)) { - logger.warn("Channel " + getName() + ", handler already subscribed " + handler); - return false; - } - - if (logger.isTraceEnabled()) { - logger.trace("Channel " + getName() + ", subscribing handler " + handler); - } - - Registration>>> registration = this.reactor.on( - ObjectSelector.objectSelector(key), new MessageHandlerConsumer(handler)); - + public boolean subscribeInternal(final MessageHandler handler) { + Selector selector = ObjectSelector.objectSelector(this.key); + MessageHandlerConsumer consumer = new MessageHandlerConsumer(handler); + Registration>>> registration = this.reactor.on(selector, consumer); this.registrations.put(handler, registration); - return true; } @Override - public boolean unsubscribe(MessageHandler handler) { - - if (logger.isTraceEnabled()) { - logger.trace("Channel " + getName() + ", removing subscription for handler " + handler); - } - + public boolean unsubscribeInternal(MessageHandler handler) { Registration registration = this.registrations.remove(handler); - if (registration == null) { - if (logger.isTraceEnabled()) { - logger.trace("Channel " + getName() + ", no subscription for handler " + handler); - } - return false; + if (registration != null) { + registration.cancel(); + return true; } - - registration.cancel(); - return true; + return false; } - private static final class MessageHandlerConsumer implements Consumer>> { + private final class MessageHandlerConsumer implements Consumer>> { private final MessageHandler handler; @@ -132,10 +94,8 @@ public class ReactorMessageChannel implements SubscribableChannel { this.handler.handleMessage(message); } catch (Throwable t) { - // TODO logger.error("Failed to process message " + message, t); } } } - } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/PublishSubscribeChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/TaskExecutorSubscribableChannel.java similarity index 57% rename from spring-messaging/src/main/java/org/springframework/messaging/support/channel/PublishSubscribeChannel.java rename to spring-messaging/src/main/java/org/springframework/messaging/support/channel/TaskExecutorSubscribableChannel.java index 5d48bd00a82..0d58885eee0 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/PublishSubscribeChannel.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/TaskExecutorSubscribableChannel.java @@ -18,80 +18,75 @@ package org.springframework.messaging.support.channel; import java.util.Set; import java.util.concurrent.CopyOnWriteArraySet; -import java.util.concurrent.Executor; +import org.springframework.core.task.TaskExecutor; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.SubscribableChannel; -import org.springframework.util.Assert; /** * A {@link SubscribableChannel} that sends messages to each of its subscribers. * * @author Phillip Webb + * @author Rossen Stoyanchev * @since 4.0 */ -public class PublishSubscribeChannel implements SubscribableChannel { +public class TaskExecutorSubscribableChannel extends AbstractSubscribableChannel { - private final Executor executor; + private final TaskExecutor executor; private final Set handlers = new CopyOnWriteArraySet(); /** - * Create a new {@link PublishSubscribeChannel} instance where messages will be sent + * Create a new {@link TaskExecutorSubscribableChannel} instance where messages will be sent * in the callers thread. */ - public PublishSubscribeChannel() { + public TaskExecutorSubscribableChannel() { this(null); } /** - * Create a new {@link PublishSubscribeChannel} instance where messages will be sent + * Create a new {@link TaskExecutorSubscribableChannel} instance where messages will be sent * via the specified executor. * @param executor the executor used to send the message or {@code null} to execute in * the callers thread. */ - public PublishSubscribeChannel(Executor executor) { + public TaskExecutorSubscribableChannel(TaskExecutor executor) { this.executor = executor; } + @Override - public boolean send(Message message) { - return send(message, INDEFINITE_TIMEOUT); + protected boolean hasSubscription(MessageHandler handler) { + return this.handlers.contains(handler); } @Override - public boolean send(Message message, long timeout) { - Assert.notNull(message, "Message must not be null"); - Assert.notNull(message.getPayload(), "Message payload must not be null"); + public boolean sendInternal(final Message message, long timeout) { for (final MessageHandler handler : this.handlers) { - dispatchToHandler(message, handler); + if (this.executor == null) { + handler.handleMessage(message); + } + else { + this.executor.execute(new Runnable() { + @Override + public void run() { + handler.handleMessage(message); + } + }); + } } return true; } - private void dispatchToHandler(final Message message, final MessageHandler handler) { - if (this.executor == null) { - handler.handleMessage(message); - } - else { - this.executor.execute(new Runnable() { - @Override - public void run() { - handler.handleMessage(message); - } - }); - } - } - @Override - public boolean subscribe(MessageHandler handler) { + public boolean subscribeInternal(MessageHandler handler) { return this.handlers.add(handler); } @Override - public boolean unsubscribe(MessageHandler handler) { + public boolean unsubscribeInternal(MessageHandler handler) { return this.handlers.remove(handler); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistryTests.java index bae10f6ac9a..fdc4c6c9480 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistryTests.java @@ -25,7 +25,6 @@ import org.junit.Test; import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.messaging.simp.handler.DefaultSubscriptionRegistry; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.MultiValueMap; @@ -49,30 +48,30 @@ public class DefaultSubscriptionRegistryTests { @Test - public void addSubscriptionInvalidInput() { + public void registerSubscriptionInvalidInput() { String sessId = "sess01"; String subsId = "subs01"; String dest = "/foo"; - this.registry.addSubscription(subscribeMessage(null, subsId, dest)); + this.registry.registerSubscription(subscribeMessage(null, subsId, dest)); assertEquals(0, this.registry.findSubscriptions(message(dest)).size()); - this.registry.addSubscription(subscribeMessage(sessId, null, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, null, dest)); assertEquals(0, this.registry.findSubscriptions(message(dest)).size()); - this.registry.addSubscription(subscribeMessage(sessId, subsId, null)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, null)); assertEquals(0, this.registry.findSubscriptions(message(dest)).size()); } @Test - public void addSubscription() { + public void registerSubscription() { String sessId = "sess01"; String subsId = "subs01"; String dest = "/foo"; - this.registry.addSubscription(subscribeMessage(sessId, subsId, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest)); MultiValueMap actual = this.registry.findSubscriptions(message(dest)); assertEquals("Expected one element " + actual, 1, actual.size()); @@ -80,14 +79,14 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void addSubscriptionOneSession() { + public void registerSubscriptionOneSession() { String sessId = "sess01"; List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); String dest = "/foo"; for (String subId : subscriptionIds) { - this.registry.addSubscription(subscribeMessage(sessId, subId, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, subId, dest)); } MultiValueMap actual = this.registry.findSubscriptions(message(dest)); @@ -97,7 +96,7 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void addSubscriptionMultipleSessions() { + public void registerSubscriptionMultipleSessions() { List sessIds = Arrays.asList("sess01", "sess02", "sess03"); List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); @@ -105,7 +104,7 @@ public class DefaultSubscriptionRegistryTests { for (String sessId : sessIds) { for (String subsId : subscriptionIds) { - this.registry.addSubscription(subscribeMessage(sessId, subsId, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest)); } } @@ -118,14 +117,14 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void addSubscriptionWithDestinationPattern() { + public void registerSubscriptionWithDestinationPattern() { String sessId = "sess01"; String subsId = "subs01"; String destPattern = "/topic/PRICE.STOCK.*.IBM"; String dest = "/topic/PRICE.STOCK.NASDAQ.IBM"; - this.registry.addSubscription(subscribeMessage(sessId, subsId, destPattern)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, destPattern)); MultiValueMap actual = this.registry.findSubscriptions(message(dest)); assertEquals("Expected one element " + actual, 1, actual.size()); @@ -133,13 +132,13 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void addSubscriptionWithDestinationPatternRegex() { + public void registerSubscriptionWithDestinationPatternRegex() { String sessId = "sess01"; String subsId = "subs01"; String destPattern = "/topic/PRICE.STOCK.*.{ticker:(IBM|MSFT)}"; - this.registry.addSubscription(subscribeMessage(sessId, subsId, destPattern)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, destPattern)); Message message = message("/topic/PRICE.STOCK.NASDAQ.IBM"); MultiValueMap actual = this.registry.findSubscriptions(message); @@ -159,7 +158,7 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void removeSubscription() { + public void unregisterSubscription() { List sessIds = Arrays.asList("sess01", "sess02", "sess03"); List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); @@ -167,13 +166,13 @@ public class DefaultSubscriptionRegistryTests { for (String sessId : sessIds) { for (String subsId : subscriptionIds) { - this.registry.addSubscription(subscribeMessage(sessId, subsId, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest)); } } - this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(0))); - this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(1))); - this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(2))); + this.registry.unregisterSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(0))); + this.registry.unregisterSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(1))); + this.registry.unregisterSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(2))); MultiValueMap actual = this.registry.findSubscriptions(message(dest)); @@ -183,7 +182,7 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void removeSessionSubscriptions() { + public void unregisterAllSubscriptions() { List sessIds = Arrays.asList("sess01", "sess02", "sess03"); List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); @@ -191,12 +190,12 @@ public class DefaultSubscriptionRegistryTests { for (String sessId : sessIds) { for (String subsId : subscriptionIds) { - this.registry.addSubscription(subscribeMessage(sessId, subsId, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest)); } } - this.registry.removeSessionSubscriptions(sessIds.get(0)); - this.registry.removeSessionSubscriptions(sessIds.get(1)); + this.registry.unregisterAllSubscriptions(sessIds.get(0)); + this.registry.unregisterAllSubscriptions(sessIds.get(1)); MultiValueMap actual = this.registry.findSubscriptions(message(dest)); @@ -204,6 +203,12 @@ public class DefaultSubscriptionRegistryTests { assertEquals(subscriptionIds, sort(actual.get(sessIds.get(2)))); } + @Test + public void unregisterAllSubscriptionsNoMatch() { + this.registry.unregisterAllSubscriptions("bogus"); + // no exceptions + } + @Test public void findSubscriptionsNoMatches() { MultiValueMap actual = this.registry.findSubscriptions(message("/foo")); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerWebMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerWebMessageHandlerTests.java index 5044f74e6db..156c59ad7dc 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerWebMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerWebMessageHandlerTests.java @@ -16,8 +16,6 @@ package org.springframework.messaging.simp.handler; -import java.util.Arrays; - import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -30,7 +28,6 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; -import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -57,26 +54,19 @@ public class SimpleBrokerWebMessageHandlerTests { } - @Test - public void getSupportedMessageTypes() { - assertEquals(Arrays.asList(SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE, - SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT), - this.messageHandler.getSupportedMessageTypes()); - } - @Test public void subcribePublish() { - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub1", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub2", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub3", "/bar")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub1", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub2", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub3", "/bar")); - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub1", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub2", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub3", "/bar")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub1", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub2", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub3", "/bar")); - this.messageHandler.handlePublish(createMessage("/foo", "message1")); - this.messageHandler.handlePublish(createMessage("/bar", "message2")); + this.messageHandler.handleMessage(createMessage("/foo", "message1")); + this.messageHandler.handleMessage(createMessage("/bar", "message2")); verify(this.clientChannel, times(6)).send(this.messageCaptor.capture()); assertCapturedMessage("sess1", "sub1", "/foo"); @@ -93,21 +83,21 @@ public class SimpleBrokerWebMessageHandlerTests { String sess1 = "sess1"; String sess2 = "sess2"; - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub1", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub2", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub3", "/bar")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub1", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub2", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub3", "/bar")); - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub1", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub2", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub3", "/bar")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub1", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub2", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub3", "/bar")); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); headers.setSessionId(sess1); Message message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); - this.messageHandler.handleDisconnect(message); + this.messageHandler.handleMessage(message); - this.messageHandler.handlePublish(createMessage("/foo", "message1")); - this.messageHandler.handlePublish(createMessage("/bar", "message2")); + this.messageHandler.handleMessage(createMessage("/foo", "message1")); + this.messageHandler.handleMessage(createMessage("/bar", "message2")); verify(this.clientChannel, times(3)).send(this.messageCaptor.capture()); assertCapturedMessage(sess2, "sub1", "/foo"); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java index a4c9d5c3c88..1483e1623d5 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java @@ -16,8 +16,6 @@ package org.springframework.messaging.support.channel; -import java.util.concurrent.Executor; - import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -26,18 +24,19 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.springframework.core.task.TaskExecutor; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessagingException; import org.springframework.messaging.support.MessageBuilder; -import org.springframework.messaging.support.channel.PublishSubscribeChannel; import static org.hamcrest.Matchers.*; import static org.junit.Assert.*; import static org.mockito.BDDMockito.*; +import static org.mockito.Mockito.*; /** - * Tests for {@link PublishSubscribeChannel}. + * Tests for {@link TaskExecutorSubscribableChannel}. * * @author Phillip Webb */ @@ -47,7 +46,7 @@ public class PublishSubscibeChannelTests { public ExpectedException thrown = ExpectedException.none(); - private PublishSubscribeChannel channel = new PublishSubscribeChannel(); + private TaskExecutorSubscribableChannel channel = new TaskExecutorSubscribableChannel(); @Mock private MessageHandler handler; @@ -89,8 +88,8 @@ public class PublishSubscibeChannelTests { @Test public void sendWithExecutor() throws Exception { - Executor executor = mock(Executor.class); - this.channel = new PublishSubscribeChannel(executor); + TaskExecutor executor = mock(TaskExecutor.class); + this.channel = new TaskExecutorSubscribableChannel(executor); this.channel.subscribe(this.handler); this.channel.send(this.message); verify(executor).execute(this.runnableCaptor.capture());