Refactor and polish spring-messaging

Remove base class for STOMP-related message handler classes
(AbstractSimpMessageHandler), polish subclasses and fix issues with
more significant updates to STOMP broker relay.

Introduce base class for SubscribableChannel implementations providing
consistent logging for all channel implementations.
This commit is contained in:
Rossen Stoyanchev 2013-07-13 19:05:32 -04:00
parent f5f3f66b13
commit 2a48ad88fb
18 changed files with 521 additions and 532 deletions

View File

@ -1,164 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.handler;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.CollectionUtils;
import org.springframework.util.PathMatcher;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractSimpMessageHandler implements MessageHandler {
protected final Log logger = LogFactory.getLog(getClass());
private final List<String> allowedDestinations = new ArrayList<String>();
private final List<String> disallowedDestinations = new ArrayList<String>();
private final PathMatcher pathMatcher = new AntPathMatcher();
/**
* Ant-style destination patterns that this service is allowed to process.
*/
public void setAllowedDestinations(String... patterns) {
this.allowedDestinations.clear();
this.allowedDestinations.addAll(Arrays.asList(patterns));
}
/**
* Ant-style destination patterns that this service should skip.
*/
public void setDisallowedDestinations(String... patterns) {
this.disallowedDestinations.clear();
this.disallowedDestinations.addAll(Arrays.asList(patterns));
}
protected abstract Collection<SimpMessageType> getSupportedMessageTypes();
protected boolean canHandle(Message<?> message, SimpMessageType messageType) {
if (!CollectionUtils.isEmpty(getSupportedMessageTypes())) {
if (!getSupportedMessageTypes().contains(messageType)) {
return false;
}
}
return isDestinationAllowed(message);
}
protected boolean isDestinationAllowed(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
String destination = headers.getDestination();
if (destination == null) {
return true;
}
if (!this.disallowedDestinations.isEmpty()) {
for (String pattern : this.disallowedDestinations) {
if (this.pathMatcher.match(pattern, destination)) {
if (logger.isTraceEnabled()) {
logger.trace("Skip message id=" + message.getHeaders().getId());
}
return false;
}
}
}
if (!this.allowedDestinations.isEmpty()) {
for (String pattern : this.allowedDestinations) {
if (this.pathMatcher.match(pattern, destination)) {
return true;
}
}
if (logger.isTraceEnabled()) {
logger.trace("Skip message id=" + message.getHeaders().getId());
}
return false;
}
return true;
}
@Override
public final void handleMessage(Message<?> message) throws MessagingException {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
SimpMessageType messageType = headers.getMessageType();
if (!canHandle(message, messageType)) {
return;
}
if (SimpMessageType.MESSAGE.equals(messageType)) {
handlePublish(message);
}
else if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
handleSubscribe(message);
}
else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
handleUnsubscribe(message);
}
else if (SimpMessageType.CONNECT.equals(messageType)) {
handleConnect(message);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
handleDisconnect(message);
}
else {
handleOther(message);
}
}
protected void handleConnect(Message<?> message) {
}
protected void handlePublish(Message<?> message) {
}
protected void handleSubscribe(Message<?> message) {
}
protected void handleUnsubscribe(Message<?> message) {
}
protected void handleDisconnect(Message<?> message) {
}
protected void handleOther(Message<?> message) {
}
}

View File

@ -34,7 +34,7 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist
@Override
public void addSubscription(Message<?> message) {
public final void registerSubscription(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
if (!SimpMessageType.SUBSCRIBE.equals(headers.getMessageType())) {
logger.error("Expected SUBSCRIBE message: " + message);
@ -55,6 +55,9 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist
logger.error("Ignoring destination. No destination in message: " + message);
return;
}
if (logger.isDebugEnabled()) {
logger.debug("Subscribe request: " + message);
}
addSubscriptionInternal(sessionId, subscriptionId, destination, message);
}
@ -62,7 +65,7 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist
String destination, Message<?> message);
@Override
public void removeSubscription(Message<?> message) {
public final void unregisterSubscription(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
if (!SimpMessageType.UNSUBSCRIBE.equals(headers.getMessageType())) {
logger.error("Expected UNSUBSCRIBE message: " + message);
@ -78,17 +81,19 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist
logger.error("Ignoring subscription. No subscriptionId in message: " + message);
return;
}
if (logger.isDebugEnabled()) {
logger.debug("Unubscribe request: " + message);
}
removeSubscriptionInternal(sessionId, subscriptionId, message);
}
protected abstract void removeSubscriptionInternal(String sessionId, String subscriptionId, Message<?> message);
@Override
public void removeSessionSubscriptions(String sessionId) {
}
public abstract void unregisterAllSubscriptions(String sessionId);
@Override
public MultiValueMap<String, String> findSubscriptions(Message<?> message) {
public final MultiValueMap<String, String> findSubscriptions(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
if (!SimpMessageType.MESSAGE.equals(headers.getMessageType())) {
logger.error("Unexpected message type: " + message);
@ -99,6 +104,9 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist
logger.error("Ignoring destination. No destination in message: " + message);
return null;
}
if (logger.isTraceEnabled()) {
logger.trace("Find subscriptions, destination=" + headers.getDestination());
}
return findSubscriptionsInternal(destination, message);
}

View File

@ -19,13 +19,14 @@ package org.springframework.messaging.simp.handler;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.ApplicationContext;
@ -34,6 +35,8 @@ import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.support.MessageBodyArgumentResolver;
import org.springframework.messaging.handler.annotation.support.MessageExceptionHandlerMethodResolver;
@ -60,8 +63,9 @@ import org.springframework.web.method.HandlerMethodSelector;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler
implements ApplicationContextAware, InitializingBean {
public class AnnotationSimpMessageHandler implements MessageHandler, ApplicationContextAware, InitializingBean {
private static final Log logger = LogFactory.getLog(AnnotationSimpMessageHandler.class);
private final MessageChannel outboundChannel;
@ -104,11 +108,6 @@ public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler
this.applicationContext = applicationContext;
}
@Override
protected Collection<SimpMessageType> getSupportedMessageTypes() {
return Arrays.asList(SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE, SimpMessageType.UNSUBSCRIBE);
}
@Override
public void afterPropertiesSet() {
@ -183,18 +182,20 @@ public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler
}
@Override
public void handlePublish(Message<?> message) {
handleMessageInternal(message, this.messageMethods);
}
public void handleMessage(Message<?> message) throws MessagingException {
@Override
public void handleSubscribe(Message<?> message) {
handleMessageInternal(message, this.subscribeMethods);
}
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
SimpMessageType messageType = headers.getMessageType();
@Override
public void handleUnsubscribe(Message<?> message) {
handleMessageInternal(message, this.unsubscribeMethods);
if (SimpMessageType.MESSAGE.equals(messageType)) {
handleMessageInternal(message, this.messageMethods);
}
else if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
handleMessageInternal(message, this.subscribeMethods);
}
else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
handleMessageInternal(message, this.unsubscribeMethods);
}
}
private void handleMessageInternal(final Message<?> message, Map<MappingInfo, HandlerMethod> handlerMethods) {

View File

@ -74,9 +74,14 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
}
@Override
public void removeSessionSubscriptions(String sessionId) {
public void unregisterAllSubscriptions(String sessionId) {
SessionSubscriptionInfo info = this.subscriptionRegistry.removeSubscriptions(sessionId);
this.destinationCache.removeSessionSubscriptions(info);
if (info != null) {
if (logger.isDebugEnabled()) {
logger.debug("Unregistering subscriptions for sessionId=" + sessionId);
}
this.destinationCache.removeSessionSubscriptions(info);
}
}
@Override

View File

@ -16,11 +16,12 @@
package org.springframework.messaging.simp.handler;
import java.util.Arrays;
import java.util.Collection;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder;
@ -32,7 +33,9 @@ import org.springframework.util.MultiValueMap;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler {
public class SimpleBrokerMessageHandler implements MessageHandler {
private static final Log logger = LogFactory.getLog(SimpleBrokerMessageHandler.class);
private final MessageChannel outboundChannel;
@ -54,42 +57,36 @@ public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler {
this.subscriptionRegistry = subscriptionRegistry;
}
@Override
protected Collection<SimpMessageType> getSupportedMessageTypes() {
return Arrays.asList(SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE,
SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT);
public SubscriptionRegistry getSubscriptionRegistry() {
return this.subscriptionRegistry;
}
@Override
public void handleSubscribe(Message<?> message) {
public void handleMessage(Message<?> message) throws MessagingException {
if (logger.isDebugEnabled()) {
logger.debug("Subscribe " + message);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
SimpMessageType messageType = headers.getMessageType();
if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
// TODO: need a way to communicate back if subscription was successfully created or
// not in which case an ERROR should be sent back and close the connection
// http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE
this.subscriptionRegistry.registerSubscription(message);
}
this.subscriptionRegistry.addSubscription(message);
// TODO: need a way to communicate back if subscription was successfully created or
// not in which case an ERROR should be sent back and close the connection
// http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE
}
@Override
protected void handleUnsubscribe(Message<?> message) {
this.subscriptionRegistry.removeSubscription(message);
}
@Override
public void handlePublish(Message<?> message) {
if (logger.isTraceEnabled()) {
logger.trace("Message received: " + message);
else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
this.subscriptionRegistry.unregisterSubscription(message);
}
else if (SimpMessageType.MESSAGE.equals(messageType)) {
sendMessageToSubscribers(headers.getDestination(), message);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
String sessionId = SimpMessageHeaderAccessor.wrap(message).getSessionId();
this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
}
}
String destination = SimpMessageHeaderAccessor.wrap(message).getDestination();
protected void sendMessageToSubscribers(String destination, Message<?> message) {
MultiValueMap<String,String> subscriptions = this.subscriptionRegistry.findSubscriptions(message);
for (String sessionId : subscriptions.keySet()) {
for (String subscriptionId : subscriptions.get(sessionId)) {
@ -99,7 +96,6 @@ public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler {
Message<?> clientMessage = MessageBuilder.withPayload(
message.getPayload()).copyHeaders(headers.toMap()).build();
try {
this.outboundChannel.send(clientMessage);
}
@ -110,11 +106,4 @@ public class SimpleBrokerMessageHandler extends AbstractSimpMessageHandler {
}
}
}
@Override
public void handleDisconnect(Message<?> message) {
String sessionId = SimpMessageHeaderAccessor.wrap(message).getSessionId();
this.subscriptionRegistry.removeSessionSubscriptions(sessionId);
}
}

View File

@ -19,19 +19,37 @@ package org.springframework.messaging.simp.handler;
import org.springframework.messaging.Message;
import org.springframework.util.MultiValueMap;
/**
* A registry of subscription by session that allows looking up subscriptions.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public interface SubscriptionRegistry {
void addSubscription(Message<?> subscribeMessage);
/**
* Register a subscription represented by the given message.
* @param subscribeMessage the subscription request
*/
void registerSubscription(Message<?> subscribeMessage);
void removeSubscription(Message<?> unsubscribeMessage);
/**
* Unregister a subscription.
* @param unsubscribeMessage the request to unsubscribe
*/
void unregisterSubscription(Message<?> unsubscribeMessage);
void removeSessionSubscriptions(String sessionId);
/**
* Remove all subscriptions associated with the given sessionId.
*/
void unregisterAllSubscriptions(String sessionId);
MultiValueMap<String, String> findSubscriptions(Message<?> message);
/**
* Find all subscriptions that should receive the given message.
*
* @param message the message
* @return a {@link MultiValueMap} from sessionId to subscriptionId's, possibly empty.
*/
MultiValueMap<String, String> findSubscriptions(Message<?> message);
}

View File

@ -26,12 +26,13 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.SmartLifecycle;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.AbstractSimpMessageHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@ -51,12 +52,15 @@ import reactor.tcp.spec.TcpClientSpec;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompRelayMessageHandler extends AbstractSimpMessageHandler implements SmartLifecycle {
public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLifecycle {
private static final Log logger = LogFactory.getLog(StompBrokerRelayMessageHandler.class);
private static final String STOMP_RELAY_SYSTEM_SESSION_ID = "stompRelaySystemSessionId";
private final MessageChannel outboundChannel;
private MessageChannel outboundChannel;
private final String[] destinationPrefixes;
private String relayHost = "127.0.0.1";
@ -81,13 +85,16 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
/**
* @param outboundChannel a channel for messages going out to clients
* @param destinationPrefixes the broker supported destination prefixes; destinations
* that do not match the given prefix are ignored.
*/
public StompRelayMessageHandler(MessageChannel outboundChannel) {
public StompBrokerRelayMessageHandler(MessageChannel outboundChannel, Collection<String> destinationPrefixes) {
Assert.notNull(outboundChannel, "outboundChannel is required");
Assert.notNull(destinationPrefixes, "destinationPrefixes is required");
this.outboundChannel = outboundChannel;
this.destinationPrefixes = destinationPrefixes.toArray(new String[destinationPrefixes.size()]);
}
/**
* Set the STOMP message broker host.
*/
@ -148,9 +155,11 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
return this.systemPasscode;
}
@Override
protected Collection<SimpMessageType> getSupportedMessageTypes() {
return null;
/**
* @return the configured STOMP broker supported destination prefixes.
*/
public String[] getDestinationPrefixes() {
return destinationPrefixes;
}
@Override
@ -173,44 +182,66 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
@Override
public void start() {
synchronized (this.lifecycleMonitor) {
if (logger.isDebugEnabled()) {
logger.debug("Starting STOMP broker relay");
}
this.environment = new Environment();
this.tcpClient = new TcpClientSpec<String, String>(NettyTcpClient.class)
.env(this.environment)
.codec(new DelimitedCodec<String, String>((byte) 0, true, StandardCodecs.STRING_CODEC))
.connect(this.relayHost, this.relayPort)
.get();
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setAcceptVersion("1.1,1.2");
headers.setLogin(this.systemLogin);
headers.setPasscode(this.systemPasscode);
headers.setHeartbeat(0,0); // TODO
Message<?> message = MessageBuilder.withPayload(
new byte[0]).copyHeaders(headers.toNativeHeaderMap()).build();
RelaySession session = new RelaySession(message, headers) {
@Override
protected void sendMessageToClient(Message<?> message) {
// TODO: check for ERROR frame (reconnect?)
}
};
this.relaySessions.put(STOMP_RELAY_SYSTEM_SESSION_ID, session);
openSystemSession();
this.running = true;
}
}
/**
* Open a "system" session for sending messages from parts of the application
* not assoicated with a client STOMP session.
*/
private void openSystemSession() {
RelaySession session = new RelaySession(STOMP_RELAY_SYSTEM_SESSION_ID) {
@Override
protected void sendMessageToClient(Message<?> message) {
// ignore, only used to send messages
// TODO: ERROR frame/reconnect
}
};
this.relaySessions.put(STOMP_RELAY_SYSTEM_SESSION_ID, session);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setAcceptVersion("1.1,1.2");
headers.setLogin(this.systemLogin);
headers.setPasscode(this.systemPasscode);
headers.setHeartbeat(0,0); // TODO
if (logger.isDebugEnabled()) {
logger.debug("Sending STOMP CONNECT frame to initialize \"system\" TCP connection");
}
Message<?> message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build();
session.open(message);
}
@Override
public void stop() {
synchronized (this.lifecycleMonitor) {
if (logger.isDebugEnabled()) {
logger.debug("Stopping STOMP broker relay");
}
this.running = false;
try {
this.tcpClient.close().await(5000, TimeUnit.MILLISECONDS);
}
catch (Throwable t) {
logger.error("Failed to close reactor TCP client", t);
}
try {
this.environment.shutdown();
}
catch (InterruptedException e) {
// ignore
catch (Throwable t) {
logger.error("Failed to shut down reactor Environment", t);
}
}
}
@ -224,75 +255,87 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
}
@Override
public void handleConnect(Message<?> message) {
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
String sessionId = stompHeaders.getSessionId();
if (sessionId == null) {
logger.error("No sessionId in message " + message);
return;
}
RelaySession relaySession = new RelaySession(message, stompHeaders);
this.relaySessions.put(sessionId, relaySession);
}
@Override
public void handlePublish(Message<?> message) {
forwardMessage(message, StompCommand.SEND);
}
@Override
public void handleSubscribe(Message<?> message) {
forwardMessage(message, StompCommand.SUBSCRIBE);
}
@Override
public void handleUnsubscribe(Message<?> message) {
forwardMessage(message, StompCommand.UNSUBSCRIBE);
}
@Override
public void handleDisconnect(Message<?> message) {
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
if (stompHeaders.getStompCommand() != null) {
forwardMessage(message, StompCommand.DISCONNECT);
}
String sessionId = stompHeaders.getSessionId();
if (sessionId == null) {
logger.error("No sessionId in message " + message);
return;
}
}
@Override
public void handleOther(Message<?> message) {
StompCommand command = (StompCommand) message.getHeaders().get(SimpMessageHeaderAccessor.PROTOCOL_MESSAGE_TYPE);
Assert.notNull(command, "Expected STOMP command: " + message.getHeaders());
forwardMessage(message, command);
}
private void forwardMessage(Message<?> message, StompCommand command) {
public void handleMessage(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setStompCommandIfNotSet(command);
String sessionId = headers.getSessionId();
if (sessionId == null) {
if (StompCommand.SEND.equals(command)) {
sessionId = STOMP_RELAY_SYSTEM_SESSION_ID;
}
else {
logger.error("No sessionId in message " + message);
return;
}
}
String destination = headers.getDestination();
StompCommand command = headers.getStompCommand();
SimpMessageType messageType = headers.getMessageType();
RelaySession session = this.relaySessions.get(sessionId);
if (session == null) {
logger.warn("Session id=" + sessionId + " not found. Message cannot be forwarded: " + message);
if (!this.running) {
if (logger.isTraceEnabled()) {
logger.trace("STOMP broker relay not running. Ignoring message id=" + headers.getId());
}
return;
}
session.forward(message, headers);
if (SimpMessageType.MESSAGE.equals(messageType)) {
sessionId = (sessionId == null) ? STOMP_RELAY_SYSTEM_SESSION_ID : sessionId;
headers.setSessionId(sessionId);
command = (command == null) ? StompCommand.SEND : command;
headers.setStompCommandIfNotSet(command);
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
}
if (headers.getStompCommand() == null) {
logger.error("Ignoring message, no STOMP command: " + message);
return;
}
if (sessionId == null) {
logger.error("Ignoring message, no sessionId: " + message);
return;
}
if (command.requiresDestination() && (destination == null)) {
logger.error("Ignoring " + command + " message, no destination: " + message);
return;
}
try {
if ((destination == null) || supportsDestination(destination)) {
if (logger.isTraceEnabled()) {
logger.trace("Processing message: " + message);
}
handleInternal(message, messageType, sessionId);
}
}
catch (Throwable t) {
logger.error("Failed to handle message " + message, t);
}
}
protected boolean supportsDestination(String destination) {
for (String prefix : this.destinationPrefixes) {
if (destination.startsWith(prefix)) {
return true;
}
}
return false;
}
protected void handleInternal(Message<?> message, SimpMessageType messageType, String sessionId) {
if (SimpMessageType.CONNECT.equals(messageType)) {
RelaySession session = new RelaySession(sessionId);
this.relaySessions.put(sessionId, session);
session.open(message);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
RelaySession session = this.relaySessions.remove(sessionId);
if (session != null) {
if (logger.isTraceEnabled()) {
logger.trace("Session already removed, sessionId=" + sessionId);
}
session.forward(message);
}
}
else {
RelaySession session = this.relaySessions.get(sessionId);
if (session == null) {
logger.warn("Session id=" + sessionId + " not found. Ignoring message: " + message);
return;
}
session.forward(message);
}
}
@ -300,21 +343,23 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
private final String sessionId;
private final Promise<TcpConnection<String, String>> promise;
private final BlockingQueue<Message<?>> messageQueue = new LinkedBlockingQueue<Message<?>>(50);
private final Object monitor = new Object();
private Promise<TcpConnection<String, String>> promise;
private volatile boolean isConnected = false;
private final Object monitor = new Object();
public RelaySession(final Message<?> message, final StompHeaderAccessor stompHeaders) {
public RelaySession(String sessionId) {
Assert.notNull(sessionId, "sessionId is required");
this.sessionId = sessionId;
}
public void open(final Message<?> message) {
Assert.notNull(message, "message is required");
Assert.notNull(stompHeaders, "stompHeaders is required");
this.sessionId = stompHeaders.getSessionId();
this.promise = tcpClient.open();
this.promise.consume(new Consumer<TcpConnection<String,String>>() {
@ -326,11 +371,9 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
readStompFrame(stompFrame);
}
});
stompHeaders.setHeartbeat(0,0); // TODO
forwardInternal(message, stompHeaders, connection);
forwardInternal(message, connection);
}
});
this.promise.onError(new Consumer<Throwable>() {
@Override
public void accept(Throwable ex) {
@ -339,14 +382,12 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
sendError(sessionId, "Failed to connect to message broker " + ex.toString());
}
});
// TODO: ATM no way to detect closed socket
}
private void readStompFrame(String stompFrame) {
// heartbeat
if (StringUtils.isEmpty(stompFrame)) {
// heartbeat?
return;
}
@ -359,13 +400,13 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
if (StompCommand.CONNECTED == headers.getStompCommand()) {
synchronized(this.monitor) {
this.isConnected = true;
flushMessages(promise.get());
flushMessages(this.promise.get());
}
return;
}
if (StompCommand.ERROR == headers.getStompCommand()) {
if (logger.isDebugEnabled()) {
logger.warn("STOMP ERROR: " + headers.getMessage() + ". Removing session: " + this.sessionId);
logger.warn("STOMP ERROR: " + headers.getMessage() + ". Removing session id=" + this.sessionId);
}
relaySessions.remove(this.sessionId);
}
@ -388,14 +429,14 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
sendMessageToClient(errorMessage);
}
public void forward(Message<?> message, StompHeaderAccessor headers) {
public void forward(Message<?> message) {
if (!this.isConnected) {
synchronized(this.monitor) {
if (!this.isConnected) {
this.messageQueue.add(message);
if (logger.isTraceEnabled()) {
logger.trace("Queued message " + message + ", queue size=" + this.messageQueue.size());
logger.trace("Not connected yet, message queued, queue size=" + this.messageQueue.size());
}
return;
}
@ -405,7 +446,7 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
TcpConnection<String, String> connection = this.promise.get();
if (this.messageQueue.isEmpty()) {
forwardInternal(message, headers, connection);
forwardInternal(message, connection);
}
else {
this.messageQueue.add(message);
@ -413,36 +454,37 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
}
}
private void flushMessages(TcpConnection<String, String> connection) {
List<Message<?>> messages = new ArrayList<Message<?>>();
this.messageQueue.drainTo(messages);
for (Message<?> message : messages) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (!forwardInternal(message, headers, connection)) {
return;
}
}
}
private boolean forwardInternal(Message<?> message, TcpConnection<String, String> connection) {
private boolean forwardInternal(Message<?> message, StompHeaderAccessor headers, TcpConnection<String, String> connection) {
try {
headers.setStompCommandIfNotSet(StompCommand.SEND);
if (logger.isTraceEnabled()) {
logger.trace("Forwarding message " + message);
logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId());
}
byte[] bytes = stompMessageConverter.fromMessage(message);
connection.send(new String(bytes, Charset.forName("UTF-8")));
}
catch (Throwable ex) {
logger.error("Failed to forward message " + message, ex);
connection.close();
logger.error("Forward failed message id=" + message.getHeaders().getId(), ex);
try {
connection.close();
}
catch (Throwable t) {
// ignore
}
sendError(this.sessionId, "Failed to forward message " + message + ": " + ex.getMessage());
return false;
}
return true;
}
}
private void flushMessages(TcpConnection<String, String> connection) {
List<Message<?>> messages = new ArrayList<Message<?>>();
this.messageQueue.drainTo(messages);
for (Message<?> message : messages) {
if (!forwardInternal(message, connection)) {
return;
}
}
}
}
}

View File

@ -16,8 +16,11 @@
package org.springframework.messaging.simp.stomp;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.messaging.simp.SimpMessageType;
@ -49,21 +52,28 @@ public enum StompCommand {
ERROR;
private static Map<StompCommand, SimpMessageType> commandToMessageType = new HashMap<StompCommand, SimpMessageType>();
private static Map<StompCommand, SimpMessageType> messageTypeLookup = new HashMap<StompCommand, SimpMessageType>();
private static Set<StompCommand> destinationRequiredLookup =
new HashSet<StompCommand>(Arrays.asList(SEND, SUBSCRIBE, MESSAGE));
static {
commandToMessageType.put(StompCommand.CONNECT, SimpMessageType.CONNECT);
commandToMessageType.put(StompCommand.STOMP, SimpMessageType.CONNECT);
commandToMessageType.put(StompCommand.SEND, SimpMessageType.MESSAGE);
commandToMessageType.put(StompCommand.MESSAGE, SimpMessageType.MESSAGE);
commandToMessageType.put(StompCommand.SUBSCRIBE, SimpMessageType.SUBSCRIBE);
commandToMessageType.put(StompCommand.UNSUBSCRIBE, SimpMessageType.UNSUBSCRIBE);
commandToMessageType.put(StompCommand.DISCONNECT, SimpMessageType.DISCONNECT);
messageTypeLookup.put(StompCommand.CONNECT, SimpMessageType.CONNECT);
messageTypeLookup.put(StompCommand.STOMP, SimpMessageType.CONNECT);
messageTypeLookup.put(StompCommand.SEND, SimpMessageType.MESSAGE);
messageTypeLookup.put(StompCommand.MESSAGE, SimpMessageType.MESSAGE);
messageTypeLookup.put(StompCommand.SUBSCRIBE, SimpMessageType.SUBSCRIBE);
messageTypeLookup.put(StompCommand.UNSUBSCRIBE, SimpMessageType.UNSUBSCRIBE);
messageTypeLookup.put(StompCommand.DISCONNECT, SimpMessageType.DISCONNECT);
}
public SimpMessageType getMessageType() {
SimpMessageType messageType = commandToMessageType.get(this);
return (messageType != null) ? messageType : SimpMessageType.OTHER;
SimpMessageType type = messageTypeLookup.get(this);
return (type != null) ? type : SimpMessageType.OTHER;
}
public boolean requiresDestination() {
return destinationRequiredLookup.contains(this);
}
}

View File

@ -27,6 +27,7 @@ import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
@ -84,21 +85,31 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
*/
private StompHeaderAccessor(StompCommand command, Map<String, List<String>> externalSourceHeaders) {
super(command.getMessageType(), command, externalSourceHeaders);
initSimpMessageHeaders();
if (externalSourceHeaders != null) {
setSimpMessageHeaders(externalSourceHeaders);
}
}
private void initSimpMessageHeaders() {
String destination = getFirstNativeHeader(DESTINATION);
if (destination != null) {
super.setDestination(destination);
private void setSimpMessageHeaders(Map<String, List<String>> extHeaders) {
List<String> values = extHeaders.get(StompHeaderAccessor.DESTINATION);
if (!CollectionUtils.isEmpty(values)) {
super.setDestination(values.get(0));
}
String contentType = getFirstNativeHeader(CONTENT_TYPE);
if (contentType != null) {
super.setContentType(MediaType.parseMediaType(contentType));
values = extHeaders.get(StompHeaderAccessor.CONTENT_TYPE);
if (!CollectionUtils.isEmpty(values)) {
super.setContentType(MediaType.parseMediaType(values.get(0)));
}
if (StompCommand.SUBSCRIBE.equals(getStompCommand()) || StompCommand.UNSUBSCRIBE.equals(getStompCommand())) {
if (getFirstNativeHeader(STOMP_ID) != null) {
super.setSubscriptionId(getFirstNativeHeader(STOMP_ID));
StompCommand command = getStompCommand();
if (StompCommand.SUBSCRIBE.equals(command) || StompCommand.UNSUBSCRIBE.equals(command)) {
values = extHeaders.get(StompHeaderAccessor.STOMP_ID);
if (!CollectionUtils.isEmpty(values)) {
super.setSubscriptionId(values.get(0));
}
}
else if (StompCommand.MESSAGE.equals(command)) {
values = extHeaders.get(StompHeaderAccessor.SUBSCRIPTION);
if (!CollectionUtils.isEmpty(values)) {
super.setSubscriptionId(values.get(0));
}
}
}

View File

@ -26,7 +26,6 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.CloseStatus;
@ -176,10 +175,11 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
this.sessions.remove(session.getId());
String sessionId = session.getId();
this.sessions.remove(sessionId);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT);
headers.setSessionId(session.getId());
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId(sessionId);
Message<?> message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build();
this.clientInputChannel.send(message);
}

View File

@ -21,6 +21,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -202,10 +203,22 @@ public class MessageHeaderAccessor {
}
}
public UUID getId() {
return (UUID) getHeader(MessageHeaders.ID);
}
public Long getTimestamp() {
return (Long) getHeader(MessageHeaders.TIMESTAMP);
}
public void setReplyChannel(MessageChannel replyChannel) {
setHeader(MessageHeaders.REPLY_CHANNEL, replyChannel);
}
public Object getReplyChannel() {
return getHeader(MessageHeaders.REPLY_CHANNEL);
}
public void setReplyChannelName(String replyChannelName) {
setHeader(MessageHeaders.REPLY_CHANNEL, replyChannelName);
}
@ -214,6 +227,10 @@ public class MessageHeaderAccessor {
setHeader(MessageHeaders.ERROR_CHANNEL, errorChannel);
}
public Object getErrorChannel() {
return getHeader(MessageHeaders.ERROR_CHANNEL);
}
public void setErrorChannelName(String errorChannelName) {
setHeader(MessageHeaders.ERROR_CHANNEL, errorChannelName);
}

View File

@ -50,7 +50,6 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor {
* A constructor for creating new headers, accepting an optional native header map.
*/
public NativeMessageHeaderAccessor(Map<String, List<String>> nativeHeaders) {
super();
this.originalNativeHeaders = nativeHeaders;
}

View File

@ -0,0 +1,104 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.support.channel;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.BeanNameAware;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
/**
* Abstract base class for {@link SubscribableChannel} implementations.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractSubscribableChannel implements SubscribableChannel, BeanNameAware {
protected Log logger = LogFactory.getLog(getClass());
private String beanName;
public AbstractSubscribableChannel() {
this.beanName = getClass().getSimpleName() + "@" + ObjectUtils.getIdentityHexString(this);
}
/**
* {@inheritDoc}
* <p>Used primarily for logging purposes.
*/
@Override
public void setBeanName(String name) {
this.beanName = name;
}
/**
* @return the name for this channel.
*/
public String getBeanName() {
return this.beanName;
}
@Override
public final boolean send(Message<?> message) {
return send(message, INDEFINITE_TIMEOUT);
}
@Override
public final boolean send(Message<?> message, long timeout) {
Assert.notNull(message, "Message must not be null");
if (logger.isTraceEnabled()) {
logger.trace("[" + this.beanName + "] sending message " + message);
}
return sendInternal(message, timeout);
}
protected abstract boolean sendInternal(Message<?> message, long timeout);
@Override
public final boolean subscribe(MessageHandler handler) {
if (hasSubscription(handler)) {
logger.warn("[" + this.beanName + "] handler already subscribed " + handler);
return false;
}
if (logger.isDebugEnabled()) {
logger.debug("[" + this.beanName + "] subscribing " + handler);
}
return subscribeInternal(handler);
}
protected abstract boolean hasSubscription(MessageHandler handler);
protected abstract boolean subscribeInternal(MessageHandler handler);
@Override
public final boolean unsubscribe(MessageHandler handler) {
if (logger.isDebugEnabled()) {
logger.debug("[" + this.beanName + "] unsubscribing " + handler);
}
return unsubscribeInternal(handler);
}
protected abstract boolean unsubscribeInternal(MessageHandler handler);
}

View File

@ -19,16 +19,14 @@ package org.springframework.messaging.support.channel;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.SubscribableChannel;
import reactor.core.Reactor;
import reactor.event.Event;
import reactor.event.registry.Registration;
import reactor.event.selector.ObjectSelector;
import reactor.event.selector.Selector;
import reactor.function.Consumer;
@ -36,88 +34,52 @@ import reactor.function.Consumer;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ReactorMessageChannel implements SubscribableChannel {
private static Log logger = LogFactory.getLog(ReactorMessageChannel.class);
public class ReactorSubscribableChannel extends AbstractSubscribableChannel {
private final Reactor reactor;
private final Object key = new Object();
private String name = toString(); // TODO
private final Map<MessageHandler, Registration<?>> registrations = new HashMap<MessageHandler, Registration<?>>();
private final Map<MessageHandler, Registration<?>> registrations =
new HashMap<MessageHandler, Registration<?>>();
public ReactorMessageChannel(Reactor reactor) {
public ReactorSubscribableChannel(Reactor reactor) {
this.reactor = reactor;
}
public void setName(String name) {
this.name = name;
}
public String getName() {
return this.name;
@Override
protected boolean hasSubscription(MessageHandler handler) {
return this.registrations.containsKey(handler);
}
@Override
public boolean send(Message<?> message) {
return send(message, -1);
}
@Override
public boolean send(Message<?> message, long timeout) {
if (logger.isTraceEnabled()) {
logger.trace("Channel " + getName() + ", sending message id=" + message.getHeaders().getId());
}
public boolean sendInternal(Message<?> message, long timeout) {
this.reactor.notify(this.key, Event.wrap(message));
return true;
}
@Override
public boolean subscribe(final MessageHandler handler) {
if (this.registrations.containsKey(handler)) {
logger.warn("Channel " + getName() + ", handler already subscribed " + handler);
return false;
}
if (logger.isTraceEnabled()) {
logger.trace("Channel " + getName() + ", subscribing handler " + handler);
}
Registration<Consumer<Event<Message<?>>>> registration = this.reactor.on(
ObjectSelector.objectSelector(key), new MessageHandlerConsumer(handler));
public boolean subscribeInternal(final MessageHandler handler) {
Selector selector = ObjectSelector.objectSelector(this.key);
MessageHandlerConsumer consumer = new MessageHandlerConsumer(handler);
Registration<Consumer<Event<Message<?>>>> registration = this.reactor.on(selector, consumer);
this.registrations.put(handler, registration);
return true;
}
@Override
public boolean unsubscribe(MessageHandler handler) {
if (logger.isTraceEnabled()) {
logger.trace("Channel " + getName() + ", removing subscription for handler " + handler);
}
public boolean unsubscribeInternal(MessageHandler handler) {
Registration<?> registration = this.registrations.remove(handler);
if (registration == null) {
if (logger.isTraceEnabled()) {
logger.trace("Channel " + getName() + ", no subscription for handler " + handler);
}
return false;
if (registration != null) {
registration.cancel();
return true;
}
registration.cancel();
return true;
return false;
}
private static final class MessageHandlerConsumer implements Consumer<Event<Message<?>>> {
private final class MessageHandlerConsumer implements Consumer<Event<Message<?>>> {
private final MessageHandler handler;
@ -132,10 +94,8 @@ public class ReactorMessageChannel implements SubscribableChannel {
this.handler.handleMessage(message);
}
catch (Throwable t) {
// TODO
logger.error("Failed to process message " + message, t);
}
}
}
}

View File

@ -18,80 +18,75 @@ package org.springframework.messaging.support.channel;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.Executor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
/**
* A {@link SubscribableChannel} that sends messages to each of its subscribers.
*
* @author Phillip Webb
* @author Rossen Stoyanchev
* @since 4.0
*/
public class PublishSubscribeChannel implements SubscribableChannel {
public class TaskExecutorSubscribableChannel extends AbstractSubscribableChannel {
private final Executor executor;
private final TaskExecutor executor;
private final Set<MessageHandler> handlers = new CopyOnWriteArraySet<MessageHandler>();
/**
* Create a new {@link PublishSubscribeChannel} instance where messages will be sent
* Create a new {@link TaskExecutorSubscribableChannel} instance where messages will be sent
* in the callers thread.
*/
public PublishSubscribeChannel() {
public TaskExecutorSubscribableChannel() {
this(null);
}
/**
* Create a new {@link PublishSubscribeChannel} instance where messages will be sent
* Create a new {@link TaskExecutorSubscribableChannel} instance where messages will be sent
* via the specified executor.
* @param executor the executor used to send the message or {@code null} to execute in
* the callers thread.
*/
public PublishSubscribeChannel(Executor executor) {
public TaskExecutorSubscribableChannel(TaskExecutor executor) {
this.executor = executor;
}
@Override
public boolean send(Message<?> message) {
return send(message, INDEFINITE_TIMEOUT);
protected boolean hasSubscription(MessageHandler handler) {
return this.handlers.contains(handler);
}
@Override
public boolean send(Message<?> message, long timeout) {
Assert.notNull(message, "Message must not be null");
Assert.notNull(message.getPayload(), "Message payload must not be null");
public boolean sendInternal(final Message<?> message, long timeout) {
for (final MessageHandler handler : this.handlers) {
dispatchToHandler(message, handler);
if (this.executor == null) {
handler.handleMessage(message);
}
else {
this.executor.execute(new Runnable() {
@Override
public void run() {
handler.handleMessage(message);
}
});
}
}
return true;
}
private void dispatchToHandler(final Message<?> message, final MessageHandler handler) {
if (this.executor == null) {
handler.handleMessage(message);
}
else {
this.executor.execute(new Runnable() {
@Override
public void run() {
handler.handleMessage(message);
}
});
}
}
@Override
public boolean subscribe(MessageHandler handler) {
public boolean subscribeInternal(MessageHandler handler) {
return this.handlers.add(handler);
}
@Override
public boolean unsubscribe(MessageHandler handler) {
public boolean unsubscribeInternal(MessageHandler handler) {
return this.handlers.remove(handler);
}

View File

@ -25,7 +25,6 @@ import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.DefaultSubscriptionRegistry;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.MultiValueMap;
@ -49,30 +48,30 @@ public class DefaultSubscriptionRegistryTests {
@Test
public void addSubscriptionInvalidInput() {
public void registerSubscriptionInvalidInput() {
String sessId = "sess01";
String subsId = "subs01";
String dest = "/foo";
this.registry.addSubscription(subscribeMessage(null, subsId, dest));
this.registry.registerSubscription(subscribeMessage(null, subsId, dest));
assertEquals(0, this.registry.findSubscriptions(message(dest)).size());
this.registry.addSubscription(subscribeMessage(sessId, null, dest));
this.registry.registerSubscription(subscribeMessage(sessId, null, dest));
assertEquals(0, this.registry.findSubscriptions(message(dest)).size());
this.registry.addSubscription(subscribeMessage(sessId, subsId, null));
this.registry.registerSubscription(subscribeMessage(sessId, subsId, null));
assertEquals(0, this.registry.findSubscriptions(message(dest)).size());
}
@Test
public void addSubscription() {
public void registerSubscription() {
String sessId = "sess01";
String subsId = "subs01";
String dest = "/foo";
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest));
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
assertEquals("Expected one element " + actual, 1, actual.size());
@ -80,14 +79,14 @@ public class DefaultSubscriptionRegistryTests {
}
@Test
public void addSubscriptionOneSession() {
public void registerSubscriptionOneSession() {
String sessId = "sess01";
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
String dest = "/foo";
for (String subId : subscriptionIds) {
this.registry.addSubscription(subscribeMessage(sessId, subId, dest));
this.registry.registerSubscription(subscribeMessage(sessId, subId, dest));
}
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
@ -97,7 +96,7 @@ public class DefaultSubscriptionRegistryTests {
}
@Test
public void addSubscriptionMultipleSessions() {
public void registerSubscriptionMultipleSessions() {
List<String> sessIds = Arrays.asList("sess01", "sess02", "sess03");
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
@ -105,7 +104,7 @@ public class DefaultSubscriptionRegistryTests {
for (String sessId : sessIds) {
for (String subsId : subscriptionIds) {
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest));
}
}
@ -118,14 +117,14 @@ public class DefaultSubscriptionRegistryTests {
}
@Test
public void addSubscriptionWithDestinationPattern() {
public void registerSubscriptionWithDestinationPattern() {
String sessId = "sess01";
String subsId = "subs01";
String destPattern = "/topic/PRICE.STOCK.*.IBM";
String dest = "/topic/PRICE.STOCK.NASDAQ.IBM";
this.registry.addSubscription(subscribeMessage(sessId, subsId, destPattern));
this.registry.registerSubscription(subscribeMessage(sessId, subsId, destPattern));
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
assertEquals("Expected one element " + actual, 1, actual.size());
@ -133,13 +132,13 @@ public class DefaultSubscriptionRegistryTests {
}
@Test
public void addSubscriptionWithDestinationPatternRegex() {
public void registerSubscriptionWithDestinationPatternRegex() {
String sessId = "sess01";
String subsId = "subs01";
String destPattern = "/topic/PRICE.STOCK.*.{ticker:(IBM|MSFT)}";
this.registry.addSubscription(subscribeMessage(sessId, subsId, destPattern));
this.registry.registerSubscription(subscribeMessage(sessId, subsId, destPattern));
Message<?> message = message("/topic/PRICE.STOCK.NASDAQ.IBM");
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message);
@ -159,7 +158,7 @@ public class DefaultSubscriptionRegistryTests {
}
@Test
public void removeSubscription() {
public void unregisterSubscription() {
List<String> sessIds = Arrays.asList("sess01", "sess02", "sess03");
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
@ -167,13 +166,13 @@ public class DefaultSubscriptionRegistryTests {
for (String sessId : sessIds) {
for (String subsId : subscriptionIds) {
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest));
}
}
this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(0)));
this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(1)));
this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(2)));
this.registry.unregisterSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(0)));
this.registry.unregisterSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(1)));
this.registry.unregisterSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(2)));
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
@ -183,7 +182,7 @@ public class DefaultSubscriptionRegistryTests {
}
@Test
public void removeSessionSubscriptions() {
public void unregisterAllSubscriptions() {
List<String> sessIds = Arrays.asList("sess01", "sess02", "sess03");
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
@ -191,12 +190,12 @@ public class DefaultSubscriptionRegistryTests {
for (String sessId : sessIds) {
for (String subsId : subscriptionIds) {
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
this.registry.registerSubscription(subscribeMessage(sessId, subsId, dest));
}
}
this.registry.removeSessionSubscriptions(sessIds.get(0));
this.registry.removeSessionSubscriptions(sessIds.get(1));
this.registry.unregisterAllSubscriptions(sessIds.get(0));
this.registry.unregisterAllSubscriptions(sessIds.get(1));
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
@ -204,6 +203,12 @@ public class DefaultSubscriptionRegistryTests {
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(2))));
}
@Test
public void unregisterAllSubscriptionsNoMatch() {
this.registry.unregisterAllSubscriptions("bogus");
// no exceptions
}
@Test
public void findSubscriptionsNoMatches() {
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message("/foo"));

View File

@ -16,8 +16,6 @@
package org.springframework.messaging.simp.handler;
import java.util.Arrays;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
@ -30,7 +28,6 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
@ -57,26 +54,19 @@ public class SimpleBrokerWebMessageHandlerTests {
}
@Test
public void getSupportedMessageTypes() {
assertEquals(Arrays.asList(SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE,
SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT),
this.messageHandler.getSupportedMessageTypes());
}
@Test
public void subcribePublish() {
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub1", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub2", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub3", "/bar"));
this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub1", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub2", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub3", "/bar"));
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub1", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub2", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub3", "/bar"));
this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub1", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub2", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub3", "/bar"));
this.messageHandler.handlePublish(createMessage("/foo", "message1"));
this.messageHandler.handlePublish(createMessage("/bar", "message2"));
this.messageHandler.handleMessage(createMessage("/foo", "message1"));
this.messageHandler.handleMessage(createMessage("/bar", "message2"));
verify(this.clientChannel, times(6)).send(this.messageCaptor.capture());
assertCapturedMessage("sess1", "sub1", "/foo");
@ -93,21 +83,21 @@ public class SimpleBrokerWebMessageHandlerTests {
String sess1 = "sess1";
String sess2 = "sess2";
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub1", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub2", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub3", "/bar"));
this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub1", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub2", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage(sess1, "sub3", "/bar"));
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub1", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub2", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub3", "/bar"));
this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub1", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub2", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage(sess2, "sub3", "/bar"));
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT);
headers.setSessionId(sess1);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build();
this.messageHandler.handleDisconnect(message);
this.messageHandler.handleMessage(message);
this.messageHandler.handlePublish(createMessage("/foo", "message1"));
this.messageHandler.handlePublish(createMessage("/bar", "message2"));
this.messageHandler.handleMessage(createMessage("/foo", "message1"));
this.messageHandler.handleMessage(createMessage("/bar", "message2"));
verify(this.clientChannel, times(3)).send(this.messageCaptor.capture());
assertCapturedMessage(sess2, "sub1", "/foo");

View File

@ -16,8 +16,6 @@
package org.springframework.messaging.support.channel;
import java.util.concurrent.Executor;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@ -26,18 +24,19 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.core.task.TaskExecutor;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.channel.PublishSubscribeChannel;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import static org.mockito.BDDMockito.*;
import static org.mockito.Mockito.*;
/**
* Tests for {@link PublishSubscribeChannel}.
* Tests for {@link TaskExecutorSubscribableChannel}.
*
* @author Phillip Webb
*/
@ -47,7 +46,7 @@ public class PublishSubscibeChannelTests {
public ExpectedException thrown = ExpectedException.none();
private PublishSubscribeChannel channel = new PublishSubscribeChannel();
private TaskExecutorSubscribableChannel channel = new TaskExecutorSubscribableChannel();
@Mock
private MessageHandler handler;
@ -89,8 +88,8 @@ public class PublishSubscibeChannelTests {
@Test
public void sendWithExecutor() throws Exception {
Executor executor = mock(Executor.class);
this.channel = new PublishSubscribeChannel(executor);
TaskExecutor executor = mock(TaskExecutor.class);
this.channel = new TaskExecutorSubscribableChannel(executor);
this.channel.subscribe(this.handler);
this.channel.send(this.message);
verify(executor).execute(this.runnableCaptor.capture());