diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSimpMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSimpMessageHandler.java deleted file mode 100644 index f3738f74e51..00000000000 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSimpMessageHandler.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.messaging.simp.handler; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageHandler; -import org.springframework.messaging.MessagingException; -import org.springframework.messaging.simp.SimpMessageHeaderAccessor; -import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.util.AntPathMatcher; -import org.springframework.util.CollectionUtils; -import org.springframework.util.PathMatcher; - - -/** - * @author Rossen Stoyanchev - * @since 4.0 - */ -public abstract class AbstractSimpMessageHandler implements MessageHandler { - - protected final Log logger = LogFactory.getLog(getClass()); - - private final List allowedDestinations = new ArrayList(); - - private final List disallowedDestinations = new ArrayList(); - - private final PathMatcher pathMatcher = new AntPathMatcher(); - - - /** - * Ant-style destination patterns that this service is allowed to process. - */ - public void setAllowedDestinations(String... patterns) { - this.allowedDestinations.clear(); - this.allowedDestinations.addAll(Arrays.asList(patterns)); - } - - /** - * Ant-style destination patterns that this service should skip. - */ - public void setDisallowedDestinations(String... patterns) { - this.disallowedDestinations.clear(); - this.disallowedDestinations.addAll(Arrays.asList(patterns)); - } - - protected abstract Collection getSupportedMessageTypes(); - - - protected boolean canHandle(Message message, SimpMessageType messageType) { - - if (!CollectionUtils.isEmpty(getSupportedMessageTypes())) { - if (!getSupportedMessageTypes().contains(messageType)) { - return false; - } - } - - return isDestinationAllowed(message); - } - - protected boolean isDestinationAllowed(Message message) { - - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - String destination = headers.getDestination(); - - if (destination == null) { - return true; - } - - if (!this.disallowedDestinations.isEmpty()) { - for (String pattern : this.disallowedDestinations) { - if (this.pathMatcher.match(pattern, destination)) { - if (logger.isTraceEnabled()) { - logger.trace("Skip message id=" + message.getHeaders().getId()); - } - return false; - } - } - } - - if (!this.allowedDestinations.isEmpty()) { - for (String pattern : this.allowedDestinations) { - if (this.pathMatcher.match(pattern, destination)) { - return true; - } - } - if (logger.isTraceEnabled()) { - logger.trace("Skip message id=" + message.getHeaders().getId()); - } - return false; - } - - return true; - } - - @Override - public final void handleMessage(Message message) throws MessagingException { - - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - SimpMessageType messageType = headers.getMessageType(); - - if (!canHandle(message, messageType)) { - return; - } - - if (SimpMessageType.MESSAGE.equals(messageType)) { - handlePublish(message); - } - else if (SimpMessageType.SUBSCRIBE.equals(messageType)) { - handleSubscribe(message); - } - else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) { - handleUnsubscribe(message); - } - else if (SimpMessageType.CONNECT.equals(messageType)) { - handleConnect(message); - } - else if (SimpMessageType.DISCONNECT.equals(messageType)) { - handleDisconnect(message); - } - else { - handleOther(message); - } - } - - protected void handleConnect(Message message) { - } - - protected void handlePublish(Message message) { - } - - protected void handleSubscribe(Message message) { - } - - protected void handleUnsubscribe(Message message) { - } - - protected void handleDisconnect(Message message) { - } - - protected void handleOther(Message message) { - } - -} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSubscriptionRegistry.java index 60b59e3f39a..61cde8e89eb 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AbstractSubscriptionRegistry.java @@ -34,7 +34,7 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist @Override - public void addSubscription(Message message) { + public final void registerSubscription(Message message) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); if (!SimpMessageType.SUBSCRIBE.equals(headers.getMessageType())) { logger.error("Expected SUBSCRIBE message: " + message); @@ -55,6 +55,9 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist logger.error("Ignoring destination. No destination in message: " + message); return; } + if (logger.isDebugEnabled()) { + logger.debug("Subscribe request: " + message); + } addSubscriptionInternal(sessionId, subscriptionId, destination, message); } @@ -62,7 +65,7 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist String destination, Message message); @Override - public void removeSubscription(Message message) { + public final void unregisterSubscription(Message message) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); if (!SimpMessageType.UNSUBSCRIBE.equals(headers.getMessageType())) { logger.error("Expected UNSUBSCRIBE message: " + message); @@ -78,17 +81,19 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist logger.error("Ignoring subscription. No subscriptionId in message: " + message); return; } + if (logger.isDebugEnabled()) { + logger.debug("Unubscribe request: " + message); + } removeSubscriptionInternal(sessionId, subscriptionId, message); } protected abstract void removeSubscriptionInternal(String sessionId, String subscriptionId, Message message); @Override - public void removeSessionSubscriptions(String sessionId) { - } + public abstract void unregisterAllSubscriptions(String sessionId); @Override - public MultiValueMap findSubscriptions(Message message) { + public final MultiValueMap findSubscriptions(Message message) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); if (!SimpMessageType.MESSAGE.equals(headers.getMessageType())) { logger.error("Unexpected message type: " + message); @@ -99,6 +104,9 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist logger.error("Ignoring destination. No destination in message: " + message); return null; } + if (logger.isTraceEnabled()) { + logger.trace("Find subscriptions, destination=" + headers.getDestination()); + } return findSubscriptionsInternal(destination, message); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java index 4f1bc2fb9e1..51e2c3e72d3 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationSimpMessageHandler.java @@ -19,13 +19,14 @@ package org.springframework.messaging.simp.handler; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.util.Arrays; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.beans.BeansException; import org.springframework.beans.factory.InitializingBean; import org.springframework.context.ApplicationContext; @@ -34,6 +35,8 @@ import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.support.MessageBodyArgumentResolver; import org.springframework.messaging.handler.annotation.support.MessageExceptionHandlerMethodResolver; @@ -60,8 +63,9 @@ import org.springframework.web.method.HandlerMethodSelector; * @author Rossen Stoyanchev * @since 4.0 */ -public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler - implements ApplicationContextAware, InitializingBean { +public class AnnotationSimpMessageHandler implements MessageHandler, ApplicationContextAware, InitializingBean { + + private static final Log logger = LogFactory.getLog(AnnotationSimpMessageHandler.class); private final MessageChannel outboundChannel; @@ -104,11 +108,6 @@ public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler this.applicationContext = applicationContext; } - @Override - protected Collection getSupportedMessageTypes() { - return Arrays.asList(SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE, SimpMessageType.UNSUBSCRIBE); - } - @Override public void afterPropertiesSet() { @@ -183,18 +182,20 @@ public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler } @Override - public void handlePublish(Message message) { - handleMessageInternal(message, this.messageMethods); - } + public void handleMessage(Message message) throws MessagingException { - @Override - public void handleSubscribe(Message message) { - handleMessageInternal(message, this.subscribeMethods); - } + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + SimpMessageType messageType = headers.getMessageType(); - @Override - public void handleUnsubscribe(Message message) { - handleMessageInternal(message, this.unsubscribeMethods); + if (SimpMessageType.MESSAGE.equals(messageType)) { + handleMessageInternal(message, this.messageMethods); + } + else if (SimpMessageType.SUBSCRIBE.equals(messageType)) { + handleMessageInternal(message, this.subscribeMethods); + } + else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) { + handleMessageInternal(message, this.unsubscribeMethods); + } } private void handleMessageInternal(final Message message, Map handlerMethods) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java index 2cf3bc88f17..a29131e975c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistry.java @@ -74,9 +74,14 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { } @Override - public void removeSessionSubscriptions(String sessionId) { + public void unregisterAllSubscriptions(String sessionId) { SessionSubscriptionInfo info = this.subscriptionRegistry.removeSubscriptions(sessionId); - this.destinationCache.removeSessionSubscriptions(info); + if (info != null) { + if (logger.isDebugEnabled()) { + logger.debug("Unregistering subscriptions for sessionId=" + sessionId); + } + this.destinationCache.removeSessionSubscriptions(info); + } } @Override diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java index 0fcf5eb613c..f0d5519963a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleBrokerMessageHandler.java @@ -16,11 +16,12 @@ package org.springframework.messaging.simp.handler; -import java.util.Arrays; -import java.util.Collection; - +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; @@ -32,7 +33,9 @@ import org.springframework.util.MultiValueMap; * @author Rossen Stoyanchev * @since 4.0 */ -public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler { +public class SimpleBrokerMessageHandler implements MessageHandler { + + private static final Log logger = LogFactory.getLog(SimpleBrokerMessageHandler.class); private final MessageChannel outboundChannel; @@ -54,42 +57,36 @@ public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler { this.subscriptionRegistry = subscriptionRegistry; } - @Override - protected Collection getSupportedMessageTypes() { - return Arrays.asList(SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE, - SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT); + public SubscriptionRegistry getSubscriptionRegistry() { + return this.subscriptionRegistry; } @Override - public void handleSubscribe(Message message) { + public void handleMessage(Message message) throws MessagingException { - if (logger.isDebugEnabled()) { - logger.debug("Subscribe " + message); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + SimpMessageType messageType = headers.getMessageType(); + + if (SimpMessageType.SUBSCRIBE.equals(messageType)) { + // TODO: need a way to communicate back if subscription was successfully created or + // not in which case an ERROR should be sent back and close the connection + // http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE + this.subscriptionRegistry.registerSubscription(message); } - - this.subscriptionRegistry.addSubscription(message); - - // TODO: need a way to communicate back if subscription was successfully created or - // not in which case an ERROR should be sent back and close the connection - // http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE - } - - @Override - protected void handleUnsubscribe(Message message) { - this.subscriptionRegistry.removeSubscription(message); - } - - @Override - public void handlePublish(Message message) { - - if (logger.isTraceEnabled()) { - logger.trace("Message received: " + message); + else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) { + this.subscriptionRegistry.unregisterSubscription(message); } + else if (SimpMessageType.MESSAGE.equals(messageType)) { + sendMessageToSubscribers(headers.getDestination(), message); + } + else if (SimpMessageType.DISCONNECT.equals(messageType)) { + String sessionId = SimpMessageHeaderAccessor.wrap(message).getSessionId(); + this.subscriptionRegistry.unregisterAllSubscriptions(sessionId); + } + } - String destination = SimpMessageHeaderAccessor.wrap(message).getDestination(); - + protected void sendMessageToSubscribers(String destination, Message message) { MultiValueMap subscriptions = this.subscriptionRegistry.findSubscriptions(message); - for (String sessionId : subscriptions.keySet()) { for (String subscriptionId : subscriptions.get(sessionId)) { @@ -99,7 +96,6 @@ public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler { Message clientMessage = MessageBuilder.withPayload( message.getPayload()).copyHeaders(headers.toMap()).build(); - try { this.outboundChannel.send(clientMessage); } @@ -110,11 +106,4 @@ public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler { } } } - - @Override - public void handleDisconnect(Message message) { - String sessionId = SimpMessageHeaderAccessor.wrap(message).getSessionId(); - this.subscriptionRegistry.removeSessionSubscriptions(sessionId); - } - } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SubscriptionRegistry.java index b76fe231912..6e42369de49 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SubscriptionRegistry.java @@ -19,19 +19,37 @@ package org.springframework.messaging.simp.handler; import org.springframework.messaging.Message; import org.springframework.util.MultiValueMap; - /** + * A registry of subscription by session that allows looking up subscriptions. + * * @author Rossen Stoyanchev * @since 4.0 */ public interface SubscriptionRegistry { - void addSubscription(Message subscribeMessage); + /** + * Register a subscription represented by the given message. + * @param subscribeMessage the subscription request + */ + void registerSubscription(Message subscribeMessage); - void removeSubscription(Message unsubscribeMessage); + /** + * Unregister a subscription. + * @param unsubscribeMessage the request to unsubscribe + */ + void unregisterSubscription(Message unsubscribeMessage); - void removeSessionSubscriptions(String sessionId); + /** + * Remove all subscriptions associated with the given sessionId. + */ + void unregisterAllSubscriptions(String sessionId); - MultiValueMap findSubscriptions(Message message); + /** + * Find all subscriptions that should receive the given message. + * + * @param message the message + * @return a {@link MultiValueMap} from sessionId to subscriptionId's, possibly empty. + */ + MultiValueMap findSubscriptions(Message message); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java similarity index 60% rename from spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java rename to spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 4fae99140de..9a23bf8f354 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -26,12 +26,13 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.context.SmartLifecycle; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.MessageHandler; import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.messaging.simp.handler.AbstractSimpMessageHandler; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -51,12 +52,15 @@ import reactor.tcp.spec.TcpClientSpec; * @author Rossen Stoyanchev * @since 4.0 */ -public class StompRelayMessageHandler extends AbstractSimpMessageHandler implements SmartLifecycle { +public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLifecycle { + + private static final Log logger = LogFactory.getLog(StompBrokerRelayMessageHandler.class); private static final String STOMP_RELAY_SYSTEM_SESSION_ID = "stompRelaySystemSessionId"; + private final MessageChannel outboundChannel; - private MessageChannel outboundChannel; + private final String[] destinationPrefixes; private String relayHost = "127.0.0.1"; @@ -81,13 +85,16 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme /** * @param outboundChannel a channel for messages going out to clients + * @param destinationPrefixes the broker supported destination prefixes; destinations + * that do not match the given prefix are ignored. */ - public StompRelayMessageHandler(MessageChannel outboundChannel) { + public StompBrokerRelayMessageHandler(MessageChannel outboundChannel, Collection destinationPrefixes) { Assert.notNull(outboundChannel, "outboundChannel is required"); + Assert.notNull(destinationPrefixes, "destinationPrefixes is required"); this.outboundChannel = outboundChannel; + this.destinationPrefixes = destinationPrefixes.toArray(new String[destinationPrefixes.size()]); } - /** * Set the STOMP message broker host. */ @@ -148,9 +155,11 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme return this.systemPasscode; } - @Override - protected Collection getSupportedMessageTypes() { - return null; + /** + * @return the configured STOMP broker supported destination prefixes. + */ + public String[] getDestinationPrefixes() { + return destinationPrefixes; } @Override @@ -173,44 +182,66 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme @Override public void start() { synchronized (this.lifecycleMonitor) { - + if (logger.isDebugEnabled()) { + logger.debug("Starting STOMP broker relay"); + } this.environment = new Environment(); this.tcpClient = new TcpClientSpec(NettyTcpClient.class) .env(this.environment) .codec(new DelimitedCodec((byte) 0, true, StandardCodecs.STRING_CODEC)) .connect(this.relayHost, this.relayPort) .get(); - - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); - headers.setAcceptVersion("1.1,1.2"); - headers.setLogin(this.systemLogin); - headers.setPasscode(this.systemPasscode); - headers.setHeartbeat(0,0); // TODO - Message message = MessageBuilder.withPayload( - new byte[0]).copyHeaders(headers.toNativeHeaderMap()).build(); - - RelaySession session = new RelaySession(message, headers) { - @Override - protected void sendMessageToClient(Message message) { - // TODO: check for ERROR frame (reconnect?) - } - }; - this.relaySessions.put(STOMP_RELAY_SYSTEM_SESSION_ID, session); - + openSystemSession(); this.running = true; } } + /** + * Open a "system" session for sending messages from parts of the application + * not assoicated with a client STOMP session. + */ + private void openSystemSession() { + + RelaySession session = new RelaySession(STOMP_RELAY_SYSTEM_SESSION_ID) { + @Override + protected void sendMessageToClient(Message message) { + // ignore, only used to send messages + // TODO: ERROR frame/reconnect + } + }; + this.relaySessions.put(STOMP_RELAY_SYSTEM_SESSION_ID, session); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + headers.setAcceptVersion("1.1,1.2"); + headers.setLogin(this.systemLogin); + headers.setPasscode(this.systemPasscode); + headers.setHeartbeat(0,0); // TODO + + if (logger.isDebugEnabled()) { + logger.debug("Sending STOMP CONNECT frame to initialize \"system\" TCP connection"); + } + Message message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); + session.open(message); + } + @Override public void stop() { synchronized (this.lifecycleMonitor) { + if (logger.isDebugEnabled()) { + logger.debug("Stopping STOMP broker relay"); + } this.running = false; try { this.tcpClient.close().await(5000, TimeUnit.MILLISECONDS); + } + catch (Throwable t) { + logger.error("Failed to close reactor TCP client", t); + } + try { this.environment.shutdown(); } - catch (InterruptedException e) { - // ignore + catch (Throwable t) { + logger.error("Failed to shut down reactor Environment", t); } } } @@ -224,75 +255,87 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme } @Override - public void handleConnect(Message message) { - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - String sessionId = stompHeaders.getSessionId(); - if (sessionId == null) { - logger.error("No sessionId in message " + message); - return; - } - RelaySession relaySession = new RelaySession(message, stompHeaders); - this.relaySessions.put(sessionId, relaySession); - } - - @Override - public void handlePublish(Message message) { - forwardMessage(message, StompCommand.SEND); - } - - @Override - public void handleSubscribe(Message message) { - forwardMessage(message, StompCommand.SUBSCRIBE); - } - - @Override - public void handleUnsubscribe(Message message) { - forwardMessage(message, StompCommand.UNSUBSCRIBE); - } - - @Override - public void handleDisconnect(Message message) { - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - if (stompHeaders.getStompCommand() != null) { - forwardMessage(message, StompCommand.DISCONNECT); - } - String sessionId = stompHeaders.getSessionId(); - if (sessionId == null) { - logger.error("No sessionId in message " + message); - return; - } - } - - @Override - public void handleOther(Message message) { - StompCommand command = (StompCommand) message.getHeaders().get(SimpMessageHeaderAccessor.PROTOCOL_MESSAGE_TYPE); - Assert.notNull(command, "Expected STOMP command: " + message.getHeaders()); - forwardMessage(message, command); - } - - private void forwardMessage(Message message, StompCommand command) { + public void handleMessage(Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - headers.setStompCommandIfNotSet(command); - String sessionId = headers.getSessionId(); - if (sessionId == null) { - if (StompCommand.SEND.equals(command)) { - sessionId = STOMP_RELAY_SYSTEM_SESSION_ID; - } - else { - logger.error("No sessionId in message " + message); - return; - } - } + String destination = headers.getDestination(); + StompCommand command = headers.getStompCommand(); + SimpMessageType messageType = headers.getMessageType(); - RelaySession session = this.relaySessions.get(sessionId); - if (session == null) { - logger.warn("Session id=" + sessionId + " not found. Message cannot be forwarded: " + message); + if (!this.running) { + if (logger.isTraceEnabled()) { + logger.trace("STOMP broker relay not running. Ignoring message id=" + headers.getId()); + } return; } - session.forward(message, headers); + if (SimpMessageType.MESSAGE.equals(messageType)) { + sessionId = (sessionId == null) ? STOMP_RELAY_SYSTEM_SESSION_ID : sessionId; + headers.setSessionId(sessionId); + command = (command == null) ? StompCommand.SEND : command; + headers.setStompCommandIfNotSet(command); + message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build(); + } + + if (headers.getStompCommand() == null) { + logger.error("Ignoring message, no STOMP command: " + message); + return; + } + if (sessionId == null) { + logger.error("Ignoring message, no sessionId: " + message); + return; + } + if (command.requiresDestination() && (destination == null)) { + logger.error("Ignoring " + command + " message, no destination: " + message); + return; + } + + try { + if ((destination == null) || supportsDestination(destination)) { + if (logger.isTraceEnabled()) { + logger.trace("Processing message: " + message); + } + handleInternal(message, messageType, sessionId); + } + } + catch (Throwable t) { + logger.error("Failed to handle message " + message, t); + } + } + + protected boolean supportsDestination(String destination) { + for (String prefix : this.destinationPrefixes) { + if (destination.startsWith(prefix)) { + return true; + } + } + return false; + } + + protected void handleInternal(Message message, SimpMessageType messageType, String sessionId) { + if (SimpMessageType.CONNECT.equals(messageType)) { + RelaySession session = new RelaySession(sessionId); + this.relaySessions.put(sessionId, session); + session.open(message); + } + else if (SimpMessageType.DISCONNECT.equals(messageType)) { + RelaySession session = this.relaySessions.remove(sessionId); + if (session != null) { + if (logger.isTraceEnabled()) { + logger.trace("Session already removed, sessionId=" + sessionId); + } + session.forward(message); + } + } + else { + RelaySession session = this.relaySessions.get(sessionId); + if (session == null) { + logger.warn("Session id=" + sessionId + " not found. Ignoring message: " + message); + return; + } + session.forward(message); + } } @@ -300,21 +343,23 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme private final String sessionId; - private final Promise> promise; - private final BlockingQueue> messageQueue = new LinkedBlockingQueue>(50); - private final Object monitor = new Object(); + private Promise> promise; private volatile boolean isConnected = false; + private final Object monitor = new Object(); - public RelaySession(final Message message, final StompHeaderAccessor stompHeaders) { + public RelaySession(String sessionId) { + Assert.notNull(sessionId, "sessionId is required"); + this.sessionId = sessionId; + } + + public void open(final Message message) { Assert.notNull(message, "message is required"); - Assert.notNull(stompHeaders, "stompHeaders is required"); - this.sessionId = stompHeaders.getSessionId(); this.promise = tcpClient.open(); this.promise.consume(new Consumer>() { @@ -326,11 +371,9 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme readStompFrame(stompFrame); } }); - stompHeaders.setHeartbeat(0,0); // TODO - forwardInternal(message, stompHeaders, connection); + forwardInternal(message, connection); } }); - this.promise.onError(new Consumer() { @Override public void accept(Throwable ex) { @@ -339,14 +382,12 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme sendError(sessionId, "Failed to connect to message broker " + ex.toString()); } }); - - // TODO: ATM no way to detect closed socket } private void readStompFrame(String stompFrame) { + // heartbeat if (StringUtils.isEmpty(stompFrame)) { - // heartbeat? return; } @@ -359,13 +400,13 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme if (StompCommand.CONNECTED == headers.getStompCommand()) { synchronized(this.monitor) { this.isConnected = true; - flushMessages(promise.get()); + flushMessages(this.promise.get()); } return; } if (StompCommand.ERROR == headers.getStompCommand()) { if (logger.isDebugEnabled()) { - logger.warn("STOMP ERROR: " + headers.getMessage() + ". Removing session: " + this.sessionId); + logger.warn("STOMP ERROR: " + headers.getMessage() + ". Removing session id=" + this.sessionId); } relaySessions.remove(this.sessionId); } @@ -388,14 +429,14 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme sendMessageToClient(errorMessage); } - public void forward(Message message, StompHeaderAccessor headers) { + public void forward(Message message) { if (!this.isConnected) { synchronized(this.monitor) { if (!this.isConnected) { this.messageQueue.add(message); if (logger.isTraceEnabled()) { - logger.trace("Queued message " + message + ", queue size=" + this.messageQueue.size()); + logger.trace("Not connected yet, message queued, queue size=" + this.messageQueue.size()); } return; } @@ -405,7 +446,7 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme TcpConnection connection = this.promise.get(); if (this.messageQueue.isEmpty()) { - forwardInternal(message, headers, connection); + forwardInternal(message, connection); } else { this.messageQueue.add(message); @@ -413,36 +454,37 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme } } - private void flushMessages(TcpConnection connection) { - List> messages = new ArrayList>(); - this.messageQueue.drainTo(messages); - for (Message message : messages) { - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - if (!forwardInternal(message, headers, connection)) { - return; - } - } - } + private boolean forwardInternal(Message message, TcpConnection connection) { - private boolean forwardInternal(Message message, StompHeaderAccessor headers, TcpConnection connection) { try { - headers.setStompCommandIfNotSet(StompCommand.SEND); - if (logger.isTraceEnabled()) { - logger.trace("Forwarding message " + message); + logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId()); } - byte[] bytes = stompMessageConverter.fromMessage(message); connection.send(new String(bytes, Charset.forName("UTF-8"))); } catch (Throwable ex) { - logger.error("Failed to forward message " + message, ex); - connection.close(); + logger.error("Forward failed message id=" + message.getHeaders().getId(), ex); + try { + connection.close(); + } + catch (Throwable t) { + // ignore + } sendError(this.sessionId, "Failed to forward message " + message + ": " + ex.getMessage()); return false; } return true; } - } + private void flushMessages(TcpConnection connection) { + List> messages = new ArrayList>(); + this.messageQueue.drainTo(messages); + for (Message message : messages) { + if (!forwardInternal(message, connection)) { + return; + } + } + } + } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java index e66d09efa4e..a5e3c53e98a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCommand.java @@ -16,8 +16,11 @@ package org.springframework.messaging.simp.stomp; +import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; +import java.util.Set; import org.springframework.messaging.simp.SimpMessageType; @@ -49,21 +52,28 @@ public enum StompCommand { ERROR; - private static Map commandToMessageType = new HashMap(); + private static Map messageTypeLookup = new HashMap(); + + private static Set destinationRequiredLookup = + new HashSet(Arrays.asList(SEND, SUBSCRIBE, MESSAGE)); static { - commandToMessageType.put(StompCommand.CONNECT, SimpMessageType.CONNECT); - commandToMessageType.put(StompCommand.STOMP, SimpMessageType.CONNECT); - commandToMessageType.put(StompCommand.SEND, SimpMessageType.MESSAGE); - commandToMessageType.put(StompCommand.MESSAGE, SimpMessageType.MESSAGE); - commandToMessageType.put(StompCommand.SUBSCRIBE, SimpMessageType.SUBSCRIBE); - commandToMessageType.put(StompCommand.UNSUBSCRIBE, SimpMessageType.UNSUBSCRIBE); - commandToMessageType.put(StompCommand.DISCONNECT, SimpMessageType.DISCONNECT); + messageTypeLookup.put(StompCommand.CONNECT, SimpMessageType.CONNECT); + messageTypeLookup.put(StompCommand.STOMP, SimpMessageType.CONNECT); + messageTypeLookup.put(StompCommand.SEND, SimpMessageType.MESSAGE); + messageTypeLookup.put(StompCommand.MESSAGE, SimpMessageType.MESSAGE); + messageTypeLookup.put(StompCommand.SUBSCRIBE, SimpMessageType.SUBSCRIBE); + messageTypeLookup.put(StompCommand.UNSUBSCRIBE, SimpMessageType.UNSUBSCRIBE); + messageTypeLookup.put(StompCommand.DISCONNECT, SimpMessageType.DISCONNECT); } public SimpMessageType getMessageType() { - SimpMessageType messageType = commandToMessageType.get(this); - return (messageType != null) ? messageType : SimpMessageType.OTHER; + SimpMessageType type = messageTypeLookup.get(this); + return (type != null) ? type : SimpMessageType.OTHER; + } + + public boolean requiresDestination() { + return destinationRequiredLookup.contains(this); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java index 8ad764baa8a..670c2f14dea 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java @@ -27,6 +27,7 @@ import org.springframework.http.MediaType; import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -84,21 +85,31 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { */ private StompHeaderAccessor(StompCommand command, Map> externalSourceHeaders) { super(command.getMessageType(), command, externalSourceHeaders); - initSimpMessageHeaders(); + if (externalSourceHeaders != null) { + setSimpMessageHeaders(externalSourceHeaders); + } } - private void initSimpMessageHeaders() { - String destination = getFirstNativeHeader(DESTINATION); - if (destination != null) { - super.setDestination(destination); + private void setSimpMessageHeaders(Map> extHeaders) { + List values = extHeaders.get(StompHeaderAccessor.DESTINATION); + if (!CollectionUtils.isEmpty(values)) { + super.setDestination(values.get(0)); } - String contentType = getFirstNativeHeader(CONTENT_TYPE); - if (contentType != null) { - super.setContentType(MediaType.parseMediaType(contentType)); + values = extHeaders.get(StompHeaderAccessor.CONTENT_TYPE); + if (!CollectionUtils.isEmpty(values)) { + super.setContentType(MediaType.parseMediaType(values.get(0))); } - if (StompCommand.SUBSCRIBE.equals(getStompCommand()) || StompCommand.UNSUBSCRIBE.equals(getStompCommand())) { - if (getFirstNativeHeader(STOMP_ID) != null) { - super.setSubscriptionId(getFirstNativeHeader(STOMP_ID)); + StompCommand command = getStompCommand(); + if (StompCommand.SUBSCRIBE.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) { + values = extHeaders.get(StompHeaderAccessor.STOMP_ID); + if (!CollectionUtils.isEmpty(values)) { + super.setSubscriptionId(values.get(0)); + } + } + else if (StompCommand.MESSAGE.equals(command)) { + values = extHeaders.get(StompHeaderAccessor.SUBSCRIPTION); + if (!CollectionUtils.isEmpty(values)) { + super.setSubscriptionId(values.get(0)); } } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java index 1ea634f23f4..30edde3f04d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompWebSocketHandler.java @@ -26,7 +26,6 @@ import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandler; -import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.socket.CloseStatus; @@ -176,10 +175,11 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { - this.sessions.remove(session.getId()); + String sessionId = session.getId(); + this.sessions.remove(sessionId); - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); - headers.setSessionId(session.getId()); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + headers.setSessionId(sessionId); Message message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); this.clientInputChannel.send(message); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java index aa62f8c36a0..d8fdf64c992 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.UUID; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -202,10 +203,22 @@ public class MessageHeaderAccessor { } } + public UUID getId() { + return (UUID) getHeader(MessageHeaders.ID); + } + + public Long getTimestamp() { + return (Long) getHeader(MessageHeaders.TIMESTAMP); + } + public void setReplyChannel(MessageChannel replyChannel) { setHeader(MessageHeaders.REPLY_CHANNEL, replyChannel); } + public Object getReplyChannel() { + return getHeader(MessageHeaders.REPLY_CHANNEL); + } + public void setReplyChannelName(String replyChannelName) { setHeader(MessageHeaders.REPLY_CHANNEL, replyChannelName); } @@ -214,6 +227,10 @@ public class MessageHeaderAccessor { setHeader(MessageHeaders.ERROR_CHANNEL, errorChannel); } + public Object getErrorChannel() { + return getHeader(MessageHeaders.ERROR_CHANNEL); + } + public void setErrorChannelName(String errorChannelName) { setHeader(MessageHeaders.ERROR_CHANNEL, errorChannelName); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java index e9dc13a2b89..b5d147ef3ad 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java @@ -50,7 +50,6 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { * A constructor for creating new headers, accepting an optional native header map. */ public NativeMessageHeaderAccessor(Map> nativeHeaders) { - super(); this.originalNativeHeaders = nativeHeaders; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java new file mode 100644 index 00000000000..b7a3800d856 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java @@ -0,0 +1,104 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support.channel; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.beans.factory.BeanNameAware; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.SubscribableChannel; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; + + +/** + * Abstract base class for {@link SubscribableChannel} implementations. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public abstract class AbstractSubscribableChannel implements SubscribableChannel, BeanNameAware { + + protected Log logger = LogFactory.getLog(getClass()); + + private String beanName; + + + public AbstractSubscribableChannel() { + this.beanName = getClass().getSimpleName() + "@" + ObjectUtils.getIdentityHexString(this); + } + + /** + * {@inheritDoc} + *

Used primarily for logging purposes. + */ + @Override + public void setBeanName(String name) { + this.beanName = name; + } + + /** + * @return the name for this channel. + */ + public String getBeanName() { + return this.beanName; + } + + @Override + public final boolean send(Message message) { + return send(message, INDEFINITE_TIMEOUT); + } + + @Override + public final boolean send(Message message, long timeout) { + Assert.notNull(message, "Message must not be null"); + if (logger.isTraceEnabled()) { + logger.trace("[" + this.beanName + "] sending message " + message); + } + return sendInternal(message, timeout); + } + + protected abstract boolean sendInternal(Message message, long timeout); + + @Override + public final boolean subscribe(MessageHandler handler) { + if (hasSubscription(handler)) { + logger.warn("[" + this.beanName + "] handler already subscribed " + handler); + return false; + } + if (logger.isDebugEnabled()) { + logger.debug("[" + this.beanName + "] subscribing " + handler); + } + return subscribeInternal(handler); + } + + protected abstract boolean hasSubscription(MessageHandler handler); + + protected abstract boolean subscribeInternal(MessageHandler handler); + + @Override + public final boolean unsubscribe(MessageHandler handler) { + if (logger.isDebugEnabled()) { + logger.debug("[" + this.beanName + "] unsubscribing " + handler); + } + return unsubscribeInternal(handler); + } + + protected abstract boolean unsubscribeInternal(MessageHandler handler); + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorMessageChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorSubscribableChannel.java similarity index 52% rename from spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorMessageChannel.java rename to spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorSubscribableChannel.java index 2fc744049da..d60ccea291f 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorMessageChannel.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ReactorSubscribableChannel.java @@ -19,16 +19,14 @@ package org.springframework.messaging.support.channel; import java.util.HashMap; import java.util.Map; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; -import org.springframework.messaging.SubscribableChannel; import reactor.core.Reactor; import reactor.event.Event; import reactor.event.registry.Registration; import reactor.event.selector.ObjectSelector; +import reactor.event.selector.Selector; import reactor.function.Consumer; @@ -36,88 +34,52 @@ import reactor.function.Consumer; * @author Rossen Stoyanchev * @since 4.0 */ -public class ReactorMessageChannel implements SubscribableChannel { - - private static Log logger = LogFactory.getLog(ReactorMessageChannel.class); +public class ReactorSubscribableChannel extends AbstractSubscribableChannel { private final Reactor reactor; private final Object key = new Object(); - private String name = toString(); // TODO + private final Map> registrations = new HashMap>(); - private final Map> registrations = - new HashMap>(); - - - public ReactorMessageChannel(Reactor reactor) { + public ReactorSubscribableChannel(Reactor reactor) { this.reactor = reactor; } - public void setName(String name) { - this.name = name; - } - public String getName() { - return this.name; + @Override + protected boolean hasSubscription(MessageHandler handler) { + return this.registrations.containsKey(handler); } @Override - public boolean send(Message message) { - return send(message, -1); - } - - @Override - public boolean send(Message message, long timeout) { - if (logger.isTraceEnabled()) { - logger.trace("Channel " + getName() + ", sending message id=" + message.getHeaders().getId()); - } + public boolean sendInternal(Message message, long timeout) { this.reactor.notify(this.key, Event.wrap(message)); return true; } @Override - public boolean subscribe(final MessageHandler handler) { - - if (this.registrations.containsKey(handler)) { - logger.warn("Channel " + getName() + ", handler already subscribed " + handler); - return false; - } - - if (logger.isTraceEnabled()) { - logger.trace("Channel " + getName() + ", subscribing handler " + handler); - } - - Registration>>> registration = this.reactor.on( - ObjectSelector.objectSelector(key), new MessageHandlerConsumer(handler)); - + public boolean subscribeInternal(final MessageHandler handler) { + Selector selector = ObjectSelector.objectSelector(this.key); + MessageHandlerConsumer consumer = new MessageHandlerConsumer(handler); + Registration>>> registration = this.reactor.on(selector, consumer); this.registrations.put(handler, registration); - return true; } @Override - public boolean unsubscribe(MessageHandler handler) { - - if (logger.isTraceEnabled()) { - logger.trace("Channel " + getName() + ", removing subscription for handler " + handler); - } - + public boolean unsubscribeInternal(MessageHandler handler) { Registration registration = this.registrations.remove(handler); - if (registration == null) { - if (logger.isTraceEnabled()) { - logger.trace("Channel " + getName() + ", no subscription for handler " + handler); - } - return false; + if (registration != null) { + registration.cancel(); + return true; } - - registration.cancel(); - return true; + return false; } - private static final class MessageHandlerConsumer implements Consumer>> { + private final class MessageHandlerConsumer implements Consumer>> { private final MessageHandler handler; @@ -132,10 +94,8 @@ public class ReactorMessageChannel implements SubscribableChannel { this.handler.handleMessage(message); } catch (Throwable t) { - // TODO logger.error("Failed to process message " + message, t); } } } - } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/PublishSubscribeChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/TaskExecutorSubscribableChannel.java similarity index 57% rename from spring-messaging/src/main/java/org/springframework/messaging/support/channel/PublishSubscribeChannel.java rename to spring-messaging/src/main/java/org/springframework/messaging/support/channel/TaskExecutorSubscribableChannel.java index 5d48bd00a82..0d58885eee0 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/PublishSubscribeChannel.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/TaskExecutorSubscribableChannel.java @@ -18,80 +18,75 @@ package org.springframework.messaging.support.channel; import java.util.Set; import java.util.concurrent.CopyOnWriteArraySet; -import java.util.concurrent.Executor; +import org.springframework.core.task.TaskExecutor; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.SubscribableChannel; -import org.springframework.util.Assert; /** * A {@link SubscribableChannel} that sends messages to each of its subscribers. * * @author Phillip Webb + * @author Rossen Stoyanchev * @since 4.0 */ -public class PublishSubscribeChannel implements SubscribableChannel { +public class TaskExecutorSubscribableChannel extends AbstractSubscribableChannel { - private final Executor executor; + private final TaskExecutor executor; private final Set handlers = new CopyOnWriteArraySet(); /** - * Create a new {@link PublishSubscribeChannel} instance where messages will be sent + * Create a new {@link TaskExecutorSubscribableChannel} instance where messages will be sent * in the callers thread. */ - public PublishSubscribeChannel() { + public TaskExecutorSubscribableChannel() { this(null); } /** - * Create a new {@link PublishSubscribeChannel} instance where messages will be sent + * Create a new {@link TaskExecutorSubscribableChannel} instance where messages will be sent * via the specified executor. * @param executor the executor used to send the message or {@code null} to execute in * the callers thread. */ - public PublishSubscribeChannel(Executor executor) { + public TaskExecutorSubscribableChannel(TaskExecutor executor) { this.executor = executor; } + @Override - public boolean send(Message message) { - return send(message, INDEFINITE_TIMEOUT); + protected boolean hasSubscription(MessageHandler handler) { + return this.handlers.contains(handler); } @Override - public boolean send(Message message, long timeout) { - Assert.notNull(message, "Message must not be null"); - Assert.notNull(message.getPayload(), "Message payload must not be null"); + public boolean sendInternal(final Message message, long timeout) { for (final MessageHandler handler : this.handlers) { - dispatchToHandler(message, handler); + if (this.executor == null) { + handler.handleMessage(message); + } + else { + this.executor.execute(new Runnable() { + @Override + public void run() { + handler.handleMessage(message); + } + }); + } } return true; } - private void dispatchToHandler(final Message message, final MessageHandler handler) { - if (this.executor == null) { - handler.handleMessage(message); - } - else { - this.executor.execute(new Runnable() { - @Override - public void run() { - handler.handleMessage(message); - } - }); - } - } - @Override - public boolean subscribe(MessageHandler handler) { + public boolean subscribeInternal(MessageHandler handler) { return this.handlers.add(handler); } @Override - public boolean unsubscribe(MessageHandler handler) { + public boolean unsubscribeInternal(MessageHandler handler) { return this.handlers.remove(handler); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistryTests.java index bae10f6ac9a..fdc4c6c9480 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/DefaultSubscriptionRegistryTests.java @@ -25,7 +25,6 @@ import org.junit.Test; import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.messaging.simp.handler.DefaultSubscriptionRegistry; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.MultiValueMap; @@ -49,30 +48,30 @@ public class DefaultSubscriptionRegistryTests { @Test - public void addSubscriptionInvalidInput() { + public void registerSubscriptionInvalidInput() { String sessId = "sess01"; String subsId = "subs01"; String dest = "/foo"; - this.registry.addSubscription(subscribeMessage(null, subsId, dest)); + this.registry.registerSubscription(subscribeMessage(null, subsId, dest)); assertEquals(0, this.registry.findSubscriptions(message(dest)).size()); - this.registry.addSubscription(subscribeMessage(sessId, null, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, null, dest)); assertEquals(0, this.registry.findSubscriptions(message(dest)).size()); - this.registry.addSubscription(subscribeMessage(sessId, subsId, null)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, null)); assertEquals(0, this.registry.findSubscriptions(message(dest)).size()); } @Test - public void addSubscription() { + public void registerSubscription() { String sessId = "sess01"; String subsId = "subs01"; String dest = "/foo"; - this.registry.addSubscription(subscribeMessage(sessId, subsId, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest)); MultiValueMap actual = this.registry.findSubscriptions(message(dest)); assertEquals("Expected one element " + actual, 1, actual.size()); @@ -80,14 +79,14 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void addSubscriptionOneSession() { + public void registerSubscriptionOneSession() { String sessId = "sess01"; List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); String dest = "/foo"; for (String subId : subscriptionIds) { - this.registry.addSubscription(subscribeMessage(sessId, subId, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, subId, dest)); } MultiValueMap actual = this.registry.findSubscriptions(message(dest)); @@ -97,7 +96,7 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void addSubscriptionMultipleSessions() { + public void registerSubscriptionMultipleSessions() { List sessIds = Arrays.asList("sess01", "sess02", "sess03"); List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); @@ -105,7 +104,7 @@ public class DefaultSubscriptionRegistryTests { for (String sessId : sessIds) { for (String subsId : subscriptionIds) { - this.registry.addSubscription(subscribeMessage(sessId, subsId, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest)); } } @@ -118,14 +117,14 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void addSubscriptionWithDestinationPattern() { + public void registerSubscriptionWithDestinationPattern() { String sessId = "sess01"; String subsId = "subs01"; String destPattern = "/topic/PRICE.STOCK.*.IBM"; String dest = "/topic/PRICE.STOCK.NASDAQ.IBM"; - this.registry.addSubscription(subscribeMessage(sessId, subsId, destPattern)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, destPattern)); MultiValueMap actual = this.registry.findSubscriptions(message(dest)); assertEquals("Expected one element " + actual, 1, actual.size()); @@ -133,13 +132,13 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void addSubscriptionWithDestinationPatternRegex() { + public void registerSubscriptionWithDestinationPatternRegex() { String sessId = "sess01"; String subsId = "subs01"; String destPattern = "/topic/PRICE.STOCK.*.{ticker:(IBM|MSFT)}"; - this.registry.addSubscription(subscribeMessage(sessId, subsId, destPattern)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, destPattern)); Message message = message("/topic/PRICE.STOCK.NASDAQ.IBM"); MultiValueMap actual = this.registry.findSubscriptions(message); @@ -159,7 +158,7 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void removeSubscription() { + public void unregisterSubscription() { List sessIds = Arrays.asList("sess01", "sess02", "sess03"); List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); @@ -167,13 +166,13 @@ public class DefaultSubscriptionRegistryTests { for (String sessId : sessIds) { for (String subsId : subscriptionIds) { - this.registry.addSubscription(subscribeMessage(sessId, subsId, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest)); } } - this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(0))); - this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(1))); - this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(2))); + this.registry.unregisterSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(0))); + this.registry.unregisterSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(1))); + this.registry.unregisterSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(2))); MultiValueMap actual = this.registry.findSubscriptions(message(dest)); @@ -183,7 +182,7 @@ public class DefaultSubscriptionRegistryTests { } @Test - public void removeSessionSubscriptions() { + public void unregisterAllSubscriptions() { List sessIds = Arrays.asList("sess01", "sess02", "sess03"); List subscriptionIds = Arrays.asList("subs01", "subs02", "subs03"); @@ -191,12 +190,12 @@ public class DefaultSubscriptionRegistryTests { for (String sessId : sessIds) { for (String subsId : subscriptionIds) { - this.registry.addSubscription(subscribeMessage(sessId, subsId, dest)); + this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest)); } } - this.registry.removeSessionSubscriptions(sessIds.get(0)); - this.registry.removeSessionSubscriptions(sessIds.get(1)); + this.registry.unregisterAllSubscriptions(sessIds.get(0)); + this.registry.unregisterAllSubscriptions(sessIds.get(1)); MultiValueMap actual = this.registry.findSubscriptions(message(dest)); @@ -204,6 +203,12 @@ public class DefaultSubscriptionRegistryTests { assertEquals(subscriptionIds, sort(actual.get(sessIds.get(2)))); } + @Test + public void unregisterAllSubscriptionsNoMatch() { + this.registry.unregisterAllSubscriptions("bogus"); + // no exceptions + } + @Test public void findSubscriptionsNoMatches() { MultiValueMap actual = this.registry.findSubscriptions(message("/foo")); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerWebMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerWebMessageHandlerTests.java index 5044f74e6db..156c59ad7dc 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerWebMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpleBrokerWebMessageHandlerTests.java @@ -16,8 +16,6 @@ package org.springframework.messaging.simp.handler; -import java.util.Arrays; - import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -30,7 +28,6 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; -import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -57,26 +54,19 @@ public class SimpleBrokerWebMessageHandlerTests { } - @Test - public void getSupportedMessageTypes() { - assertEquals(Arrays.asList(SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE, - SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT), - this.messageHandler.getSupportedMessageTypes()); - } - @Test public void subcribePublish() { - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub1", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub2", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub3", "/bar")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub1", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub2", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub3", "/bar")); - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub1", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub2", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub3", "/bar")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub1", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub2", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub3", "/bar")); - this.messageHandler.handlePublish(createMessage("/foo", "message1")); - this.messageHandler.handlePublish(createMessage("/bar", "message2")); + this.messageHandler.handleMessage(createMessage("/foo", "message1")); + this.messageHandler.handleMessage(createMessage("/bar", "message2")); verify(this.clientChannel, times(6)).send(this.messageCaptor.capture()); assertCapturedMessage("sess1", "sub1", "/foo"); @@ -93,21 +83,21 @@ public class SimpleBrokerWebMessageHandlerTests { String sess1 = "sess1"; String sess2 = "sess2"; - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub1", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub2", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub3", "/bar")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub1", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub2", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub3", "/bar")); - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub1", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub2", "/foo")); - this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub3", "/bar")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub1", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub2", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub3", "/bar")); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); headers.setSessionId(sess1); Message message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); - this.messageHandler.handleDisconnect(message); + this.messageHandler.handleMessage(message); - this.messageHandler.handlePublish(createMessage("/foo", "message1")); - this.messageHandler.handlePublish(createMessage("/bar", "message2")); + this.messageHandler.handleMessage(createMessage("/foo", "message1")); + this.messageHandler.handleMessage(createMessage("/bar", "message2")); verify(this.clientChannel, times(3)).send(this.messageCaptor.capture()); assertCapturedMessage(sess2, "sub1", "/foo"); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java index a4c9d5c3c88..1483e1623d5 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java @@ -16,8 +16,6 @@ package org.springframework.messaging.support.channel; -import java.util.concurrent.Executor; - import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -26,18 +24,19 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.springframework.core.task.TaskExecutor; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessagingException; import org.springframework.messaging.support.MessageBuilder; -import org.springframework.messaging.support.channel.PublishSubscribeChannel; import static org.hamcrest.Matchers.*; import static org.junit.Assert.*; import static org.mockito.BDDMockito.*; +import static org.mockito.Mockito.*; /** - * Tests for {@link PublishSubscribeChannel}. + * Tests for {@link TaskExecutorSubscribableChannel}. * * @author Phillip Webb */ @@ -47,7 +46,7 @@ public class PublishSubscibeChannelTests { public ExpectedException thrown = ExpectedException.none(); - private PublishSubscribeChannel channel = new PublishSubscribeChannel(); + private TaskExecutorSubscribableChannel channel = new TaskExecutorSubscribableChannel(); @Mock private MessageHandler handler; @@ -89,8 +88,8 @@ public class PublishSubscibeChannelTests { @Test public void sendWithExecutor() throws Exception { - Executor executor = mock(Executor.class); - this.channel = new PublishSubscribeChannel(executor); + TaskExecutor executor = mock(TaskExecutor.class); + this.channel = new TaskExecutorSubscribableChannel(executor); this.channel.subscribe(this.handler); this.channel.send(this.message); verify(executor).execute(this.runnableCaptor.capture());