Add SubscribableChannel and ReactorMessageChannel

This commit is contained in:
Rossen Stoyanchev 2013-06-13 01:13:37 -04:00
parent a1cfa3832e
commit 3e0aac08dc
10 changed files with 463 additions and 156 deletions

View File

@ -0,0 +1,51 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging;
/**
* Base interface for any component that handles Messages.
*
* @author Mark Fisher
* @author Iwein Fuld
* @since 4.0
*/
public interface MessageHandler {
/**
* TODO: support exceptions?
*
* Handles the message if possible. If the handler cannot deal with the
* message this will result in a <code>MessageRejectedException</code> e.g.
* in case of a Selective Consumer. When a consumer tries to handle a
* message, but fails to do so, a <code>MessageHandlingException</code> is
* thrown. In the last case it is recommended to treat the message as tainted
* and go into an error scenario.
* <p>
* When the handling results in a failure of another message being sent
* (e.g. a "reply" message), that failure will trigger a
* <code>MessageDeliveryException</code>.
*
* @param message the message to be handled
* @throws org.springframework.integration.MessageRejectedException if the handler doesn't accept the message
* @throws org.springframework.integration.MessageHandlingException when something fails during the handling
* @throws org.springframework.integration.MessageDeliveryException when this handler failed to deliver the
* reply related to the handling of the message
*/
void handleMessage(Message<?> message) throws MessagingException;
}

View File

@ -0,0 +1,40 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging;
/**
* Interface for any MessageChannel implementation that accepts subscribers.
* The subscribers must implement the {@link MessageHandler} interface and
* will be invoked when a Message is available.
*
* @author Mark Fisher
* @since 4.0
*/
public interface SubscribableChannel extends MessageChannel {
/**
* Register a {@link MessageHandler} as a subscriber to this channel.
*/
boolean subscribe(MessageHandler handler);
/**
* Remove a {@link MessageHandler} from the subscribers of this channel.
*/
boolean unsubscribe(MessageHandler handler);
}

View File

@ -18,37 +18,35 @@ package org.springframework.web.messaging.service;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.PathMatcher;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.event.EventConsumer;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractMessageService {
public abstract class AbstractPubSubMessageHandler implements MessageHandler {
protected final Log logger = LogFactory.getLog(getClass());
private final MessageChannel publishChannel;
public static final String CLIENT_TO_SERVER_MESSAGE_KEY = "clientToServerMessageKey";
public static final String CLIENT_CONNECTION_CLOSED_KEY = "clientConnectionClosed";
public static final String SERVER_TO_CLIENT_MESSAGE_KEY = "serverToClientMessageKey";
private final EventBus eventBus;
private final MessageChannel clientChannel;
private final List<String> allowedDestinations = new ArrayList<String>();
@ -57,56 +55,31 @@ public abstract class AbstractMessageService {
private final PathMatcher pathMatcher = new AntPathMatcher();
public AbstractMessageService(EventBus reactor) {
/**
* @param publishChannel a channel for publishing messages from within the
* application; this constructor will also automatically subscribe the
* current instance to this channel
*
* @param clientChannel a channel for sending messages to connected clients.
*/
public AbstractPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) {
Assert.notNull(reactor, "reactor is required");
this.eventBus = reactor;
Assert.notNull(publishChannel, "publishChannel is required");
Assert.notNull(clientChannel, "clientChannel is required");
this.eventBus.registerConsumer(CLIENT_TO_SERVER_MESSAGE_KEY, new EventConsumer<Message<?>>() {
publishChannel.subscribe(this);
this.publishChannel = publishChannel;
@Override
public void accept(Message<?> message) {
if (!isAllowedDestination(message)) {
return;
}
if (logger.isTraceEnabled()) {
logger.trace("Processing message id=" + message.getHeaders().getId());
}
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
MessageType messageType = headers.getMessageType();
if (messageType == null || messageType.equals(MessageType.OTHER)) {
processOther(message);
}
else if (MessageType.CONNECT.equals(messageType)) {
processConnect(message);
}
else if (MessageType.MESSAGE.equals(messageType)) {
processMessage(message);
}
else if (MessageType.SUBSCRIBE.equals(messageType)) {
processSubscribe(message);
}
else if (MessageType.UNSUBSCRIBE.equals(messageType)) {
processUnsubscribe(message);
}
else if (MessageType.DISCONNECT.equals(messageType)) {
processDisconnect(message);
}
}
});
this.eventBus.registerConsumer(CLIENT_CONNECTION_CLOSED_KEY, new EventConsumer<String>() {
@Override
public void accept(String sessionId) {
processClientConnectionClosed(sessionId);
}
});
this.clientChannel = clientChannel;
}
public MessageChannel getPublishChannel() {
return this.publishChannel;
}
public MessageChannel getClientChannel() {
return this.clientChannel;
}
/**
* Ant-style destination patterns that this service is allowed to process.
@ -124,16 +97,29 @@ public abstract class AbstractMessageService {
this.disallowedDestinations.addAll(Arrays.asList(patterns));
}
public EventBus getEventBus() {
return this.eventBus;
protected abstract Collection<MessageType> getSupportedMessageTypes();
protected boolean canHandle(Message<?> message, MessageType messageType) {
if (!CollectionUtils.isEmpty(getSupportedMessageTypes())) {
if (!getSupportedMessageTypes().contains(messageType)) {
return false;
}
}
return isDestinationAllowed(message);
}
private boolean isAllowedDestination(Message<?> message) {
protected boolean isDestinationAllowed(Message<?> message) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
String destination = headers.getDestination();
if (destination == null) {
return true;
}
if (!this.disallowedDestinations.isEmpty()) {
for (String pattern : this.disallowedDestinations) {
if (this.pathMatcher.match(pattern, destination)) {
@ -144,6 +130,7 @@ public abstract class AbstractMessageService {
}
}
}
if (!this.allowedDestinations.isEmpty()) {
for (String pattern : this.allowedDestinations) {
if (this.pathMatcher.match(pattern, destination)) {
@ -155,28 +142,61 @@ public abstract class AbstractMessageService {
}
return false;
}
return true;
}
protected void processConnect(Message<?> message) {
@Override
public final void handleMessage(Message<?> message) throws MessagingException {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
MessageType messageType = headers.getMessageType();
if (!canHandle(message, messageType)) {
return;
}
if (logger.isTraceEnabled()) {
logger.trace("Handling message id=" + message.getHeaders().getId());
}
if (MessageType.MESSAGE.equals(messageType)) {
handlePublish(message);
}
else if (MessageType.SUBSCRIBE.equals(messageType)) {
handleSubscribe(message);
}
else if (MessageType.UNSUBSCRIBE.equals(messageType)) {
handleUnsubscribe(message);
}
else if (MessageType.CONNECT.equals(messageType)) {
handleConnect(message);
}
else if (MessageType.DISCONNECT.equals(messageType)) {
handleDisconnect(message);
}
else {
handleOther(message);
}
}
protected void processMessage(Message<?> message) {
protected void handleConnect(Message<?> message) {
}
protected void processSubscribe(Message<?> message) {
protected void handlePublish(Message<?> message) {
}
protected void processUnsubscribe(Message<?> message) {
protected void handleSubscribe(Message<?> message) {
}
protected void processDisconnect(Message<?> message) {
protected void handleUnsubscribe(Message<?> message) {
}
protected void processOther(Message<?> message) {
protected void handleDisconnect(Message<?> message) {
}
protected void processClientConnectionClosed(String sessionId) {
protected void handleOther(Message<?> message) {
}
}

View File

@ -17,34 +17,47 @@
package org.springframework.web.messaging.service;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.event.EventConsumer;
import org.springframework.web.messaging.event.EventRegistration;
import reactor.core.Reactor;
import reactor.fn.Consumer;
import reactor.fn.Event;
import reactor.fn.registry.Registration;
import reactor.fn.selector.ObjectSelector;
import reactor.fn.selector.Selector;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public class PubSubMessageService extends AbstractMessageService {
public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
private final Reactor reactor;
private MessageConverter payloadConverter;
private Map<String, List<EventRegistration>> subscriptionsBySession =
new ConcurrentHashMap<String, List<EventRegistration>>();
private Map<String, List<Registration<?>>> subscriptionsBySession = new ConcurrentHashMap<String, List<Registration<?>>>();
public PubSubMessageService(EventBus reactor) {
super(reactor);
public ReactorPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel,
Reactor reactor) {
super(publishChannel, clientChannel);
this.reactor = reactor;
this.payloadConverter = new CompositeMessageConverter(null);
}
@ -54,7 +67,7 @@ public class PubSubMessageService extends AbstractMessageService {
}
@Override
protected void processMessage(Message<?> message) {
public void handlePublish(Message<?> message) {
if (logger.isDebugEnabled()) {
logger.debug("Message received: " + message);
@ -66,7 +79,7 @@ public class PubSubMessageService extends AbstractMessageService {
byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType());
message = new GenericMessage<byte[]>(payload, message.getHeaders());
getEventBus().send(getPublishKey(inHeaders.getDestination()), message);
this.reactor.notify(getPublishKey(inHeaders.getDestination()), Event.wrap(message));
}
catch (Exception ex) {
logger.error("Failed to publish " + message, ex);
@ -78,56 +91,69 @@ public class PubSubMessageService extends AbstractMessageService {
}
@Override
protected void processSubscribe(Message<?> message) {
protected Collection<MessageType> getSupportedMessageTypes() {
return Arrays.asList(MessageType.MESSAGE, MessageType.SUBSCRIBE, MessageType.UNSUBSCRIBE);
}
@Override
public void handleSubscribe(Message<?> message) {
if (logger.isDebugEnabled()) {
logger.debug("Subscribe " + message);
}
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
final String subscriptionId = headers.getSubscriptionId();
EventRegistration registration = getEventBus().registerConsumer(getPublishKey(headers.getDestination()),
new EventConsumer<Message<?>>() {
Selector selector = new ObjectSelector<String>(getPublishKey(headers.getDestination()));
Registration<?> registration = this.reactor.on(selector,
new Consumer<Event<Message<?>>>() {
@Override
public void accept(Message<?> message) {
public void accept(Event<Message<?>> event) {
Message<?> message = event.getData();
PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaders outHeaders = PubSubHeaders.create();
outHeaders.setDestinations(inHeaders.getDestinations());
outHeaders.setContentType(inHeaders.getContentType());
if (inHeaders.getContentType() != null) {
outHeaders.setContentType(inHeaders.getContentType());
}
outHeaders.setSubscriptionId(subscriptionId);
Object payload = message.getPayload();
message = new GenericMessage<Object>(payload, outHeaders.toMessageHeaders());
getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message);
getClientChannel().send(message);
}
});
addSubscription((String) message.getHeaders().get("sessionId"), registration);
addSubscription(headers.getSessionId(), registration);
}
private void addSubscription(String sessionId, EventRegistration registration) {
List<EventRegistration> list = this.subscriptionsBySession.get(sessionId);
private void addSubscription(String sessionId, Registration<?> registration) {
List<Registration<?>> list = this.subscriptionsBySession.get(sessionId);
if (list == null) {
list = new ArrayList<EventRegistration>();
list = new ArrayList<Registration<?>>();
this.subscriptionsBySession.put(sessionId, list);
}
list.add(registration);
}
@Override
public void processDisconnect(Message<?> message) {
String sessionId = (String) message.getHeaders().get("sessionId");
removeSubscriptions(sessionId);
public void handleDisconnect(Message<?> message) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
removeSubscriptions(headers.getSessionId());
}
@Override
protected void processClientConnectionClosed(String sessionId) {
/* @Override
public void handleClientConnectionClosed(String sessionId) {
removeSubscriptions(sessionId);
}
*/
private void removeSubscriptions(String sessionId) {
List<EventRegistration> registrations = this.subscriptionsBySession.remove(sessionId);
List<Registration<?>> registrations = this.subscriptionsBySession.remove(sessionId);
if (logger.isTraceEnabled()) {
logger.trace("Cancelling " + registrations.size() + " subscriptions for session=" + sessionId);
}
for (EventRegistration registration : registrations) {
for (Registration<?> registration : registrations) {
registration.cancel();
}
}

View File

@ -19,6 +19,7 @@ package org.springframework.web.messaging.service.method;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -31,16 +32,18 @@ import org.springframework.context.ApplicationContextAware;
import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.annotation.MessageMapping;
import org.springframework.stereotype.Controller;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils.MethodFilter;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.annotation.SubscribeEvent;
import org.springframework.web.messaging.annotation.UnsubscribeEvent;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.service.AbstractMessageService;
import org.springframework.web.messaging.service.AbstractPubSubMessageHandler;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.method.HandlerMethodSelector;
@ -49,7 +52,8 @@ import org.springframework.web.method.HandlerMethodSelector;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class AnnotationMessageService extends AbstractMessageService implements ApplicationContextAware, InitializingBean {
public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler
implements ApplicationContextAware, InitializingBean {
private List<MessageConverter> messageConverters;
@ -66,10 +70,10 @@ public class AnnotationMessageService extends AbstractMessageService implements
private ReturnValueHandlerComposite returnValueHandlers = new ReturnValueHandlerComposite();
public AnnotationMessageService(EventBus eventBus) {
super(eventBus);
}
public AnnotationPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) {
super(publishChannel, clientChannel);
}
public void setMessageConverters(List<MessageConverter> converters) {
this.messageConverters = converters;
@ -80,12 +84,17 @@ public class AnnotationMessageService extends AbstractMessageService implements
this.applicationContext = applicationContext;
}
@Override
protected Collection<MessageType> getSupportedMessageTypes() {
return Arrays.asList(MessageType.MESSAGE, MessageType.SUBSCRIBE, MessageType.UNSUBSCRIBE);
}
@Override
public void afterPropertiesSet() {
initHandlerMethods();
this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(getEventBus()));
this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(getPublishChannel()));
this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverters));
this.returnValueHandlers.addHandler(new MessageReturnValueHandler(getEventBus()));
this.returnValueHandlers.addHandler(new MessageReturnValueHandler(getClientChannel()));
}
protected void initHandlerMethods() {
@ -151,21 +160,21 @@ public class AnnotationMessageService extends AbstractMessageService implements
}
@Override
protected void processMessage(Message<?> message) {
handleMessage(message, this.messageMethods);
public void handlePublish(Message<?> message) {
handleMessageInternal(message, this.messageMethods);
}
@Override
protected void processSubscribe(Message<?> message) {
handleMessage(message, this.subscribeMethods);
public void handleSubscribe(Message<?> message) {
handleMessageInternal(message, this.subscribeMethods);
}
@Override
protected void processUnsubscribe(Message<?> message) {
handleMessage(message, this.unsubscribeMethods);
public void handleUnsubscribe(Message<?> message) {
handleMessageInternal(message, this.unsubscribeMethods);
}
private void handleMessage(final Message<?> message, Map<MappingInfo, HandlerMethod> handlerMethods) {
private void handleMessageInternal(final Message<?> message, Map<MappingInfo, HandlerMethod> handlerMethods) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
String destination = headers.getDestination();

View File

@ -21,8 +21,6 @@ import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.service.AbstractMessageService;
import reactor.util.Assert;
@ -33,12 +31,12 @@ import reactor.util.Assert;
*/
public class MessageChannelArgumentResolver implements ArgumentResolver {
private final EventBus eventBus;
private final MessageChannel publishChannel;
public MessageChannelArgumentResolver(EventBus eventBus) {
Assert.notNull(eventBus, "reactor is required");
this.eventBus = eventBus;
public MessageChannelArgumentResolver(MessageChannel publishChannel) {
Assert.notNull(publishChannel, "publishChannel is required");
this.publishChannel = publishChannel;
}
@Override
@ -55,13 +53,15 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
@Override
public boolean send(Message<?> message) {
return send(message, -1);
}
@Override
public boolean send(Message<?> message, long timeout) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
headers.setSessionId(sessionId);
message = new GenericMessage<Object>(message.getPayload(), headers.toMessageHeaders());
eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message);
publishChannel.send(message);
return true;
}
};

View File

@ -19,9 +19,8 @@ package org.springframework.web.messaging.service.method;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.service.AbstractMessageService;
import reactor.util.Assert;
@ -32,11 +31,12 @@ import reactor.util.Assert;
*/
public class MessageReturnValueHandler implements ReturnValueHandler {
private final EventBus eventBus;
private final MessageChannel clientChannel;
public MessageReturnValueHandler(EventBus eventBus) {
this.eventBus = eventBus;
public MessageReturnValueHandler(MessageChannel clientChannel) {
Assert.notNull(clientChannel, "clientChannel is required");
this.clientChannel = clientChannel;
}
@ -76,7 +76,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
outHeaders.setSubscriptionId(subscriptionId);
returnMessage = new GenericMessage<Object>(returnMessage.getPayload(), outHeaders.toMessageHeaders());
this.eventBus.send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, returnMessage);
this.clientChannel.send(returnMessage);
}
}

View File

@ -23,6 +23,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@ -33,10 +34,12 @@ import org.springframework.core.task.TaskExecutor;
import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.service.AbstractMessageService;
import org.springframework.web.messaging.service.AbstractPubSubMessageHandler;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompHeaders;
@ -47,7 +50,10 @@ import reactor.util.Assert;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class RelayStompService extends AbstractMessageService {
public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler {
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
private MessageConverter payloadConverter;
@ -55,11 +61,14 @@ public class RelayStompService extends AbstractMessageService {
private Map<String, RelaySession> relaySessions = new ConcurrentHashMap<String, RelaySession>();
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
/**
* @param executor
*/
public StompRelayPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel,
TaskExecutor executor) {
public RelayStompService(EventBus eventBus, TaskExecutor executor) {
super(eventBus);
super(publishChannel, clientChannel);
this.taskExecutor = executor; // For now, a naive way to manage socket reading
this.payloadConverter = new CompositeMessageConverter(null);
}
@ -69,7 +78,13 @@ public class RelayStompService extends AbstractMessageService {
this.payloadConverter = new CompositeMessageConverter(converters);
}
protected void processConnect(Message<?> message) {
@Override
protected Collection<MessageType> getSupportedMessageTypes() {
return null;
}
@Override
public void handleConnect(Message<?> message) {
String sessionId = (String) message.getHeaders().get("sessionId");
@ -95,7 +110,7 @@ public class RelayStompService extends AbstractMessageService {
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
String sessionId = stompHeaders.getSessionId();
RelaySession session = RelayStompService.this.relaySessions.get(sessionId);
RelaySession session = StompRelayPubSubMessageHandler.this.relaySessions.get(sessionId);
Assert.notNull(session, "RelaySession not found");
try {
@ -133,40 +148,40 @@ public class RelayStompService extends AbstractMessageService {
}
@Override
protected void processMessage(Message<?> message) {
public void handlePublish(Message<?> message) {
forwardMessage(message, StompCommand.SEND);
}
@Override
protected void processSubscribe(Message<?> message) {
public void handleSubscribe(Message<?> message) {
forwardMessage(message, StompCommand.SUBSCRIBE);
}
@Override
protected void processUnsubscribe(Message<?> message) {
public void handleUnsubscribe(Message<?> message) {
forwardMessage(message, StompCommand.UNSUBSCRIBE);
}
@Override
protected void processDisconnect(Message<?> message) {
public void handleDisconnect(Message<?> message) {
forwardMessage(message, StompCommand.DISCONNECT);
}
@Override
protected void processOther(Message<?> message) {
public void handleOther(Message<?> message) {
StompCommand command = (StompCommand) message.getHeaders().get("stompCommand");
Assert.notNull(command, "Expected STOMP command: " + message.getHeaders());
forwardMessage(message, command);
}
@Override
protected void processClientConnectionClosed(String sessionId) {
/* @Override
public void handleClientConnectionClosed(String sessionId) {
if (logger.isDebugEnabled()) {
logger.debug("Client connection closed for STOMP session=" + sessionId + ". Clearing relay session.");
}
clearRelaySession(sessionId);
}
*/
private final static class RelaySession {
@ -219,7 +234,7 @@ public class RelayStompService extends AbstractMessageService {
else if (b == 0x00) {
byte[] bytes = out.toByteArray();
Message<byte[]> message = stompMessageConverter.toMessage(bytes, sessionId);
getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message);
getClientChannel().send(message);
out.reset();
}
else {
@ -241,7 +256,7 @@ public class RelayStompService extends AbstractMessageService {
stompHeaders.setMessage(message);
stompHeaders.setSessionId(this.sessionId);
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.toMessageHeaders());
getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, errorMessage);
getClientChannel().send(errorMessage);
}
}

View File

@ -27,12 +27,13 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.event.EventConsumer;
import org.springframework.web.messaging.service.AbstractMessageService;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompConversionException;
import org.springframework.web.messaging.stomp.StompHeaders;
@ -50,19 +51,22 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
private static Log logger = LogFactory.getLog(StompWebSocketHandler.class);
private final MessageChannel publishChannel;
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
private final EventBus eventBus;
private MessageConverter payloadConverter = new CompositeMessageConverter(null);
public StompWebSocketHandler(EventBus eventBus) {
this.eventBus = eventBus;
this.eventBus.registerConsumer(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY,
new ClientMessageConsumer());
public StompWebSocketHandler(MessageChannel publishChannel, SubscribableChannel clientChannel) {
Assert.notNull(publishChannel, "publishChannel is required");
Assert.notNull(clientChannel, "clientChannel is required");
this.publishChannel = publishChannel;
clientChannel.subscribe(new ClientMessageConsumer());
}
@ -115,7 +119,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
else if (MessageType.DISCONNECT.equals(messageType)) {
handleDisconnect(message);
}
this.eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message);
this.publishChannel.send(message);
}
catch (Throwable t) {
logger.error("Terminating STOMP session due to failure to send message: ", t);
@ -189,17 +193,18 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
}
}
@Override
/* @Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
this.sessions.remove(session.getId());
eventBus.send(AbstractMessageService.CLIENT_CONNECTION_CLOSED_KEY, session.getId());
}
}*/
private final class ClientMessageConsumer implements EventConsumer<Message<?>> {
private final class ClientMessageConsumer implements MessageHandler {
@Override
public void accept(Message<?> message) {
public void handleMessage(Message<?> message) {
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
stompHeaders.setStompCommandIfNotSet(StompCommand.MESSAGE);

View File

@ -0,0 +1,141 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.support;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.SubscribableChannel;
import reactor.core.Reactor;
import reactor.fn.Consumer;
import reactor.fn.Event;
import reactor.fn.registry.Registration;
import reactor.fn.selector.ObjectSelector;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ReactorMessageChannel implements SubscribableChannel {
private static Log logger = LogFactory.getLog(ReactorMessageChannel.class);
private final Reactor reactor;
private final Object key = new Object();
private String name = toString(); // TODO
private final Map<MessageHandler, Registration<?>> registrations =
new HashMap<MessageHandler, Registration<?>>();
public ReactorMessageChannel(Reactor reactor) {
this.reactor = reactor;
}
public void setName(String name) {
this.name = name;
}
public String getName() {
return this.name;
}
@Override
public boolean send(Message<?> message) {
return send(message, -1);
}
@Override
public boolean send(Message<?> message, long timeout) {
if (logger.isTraceEnabled()) {
logger.trace("Channel " + getName() + ", sending message id=" + message.getHeaders().getId());
}
this.reactor.notify(this.key, Event.wrap(message));
return true;
}
@Override
public boolean subscribe(final MessageHandler handler) {
if (this.registrations.containsKey(handler)) {
logger.warn("Channel " + getName() + ", handler already subscribed " + handler);
return false;
}
if (logger.isTraceEnabled()) {
logger.trace("Channel " + getName() + ", subscribing handler " + handler);
}
Registration<Consumer<Event<Message<?>>>> registration = this.reactor.on(
ObjectSelector.objectSelector(key), new MessageHandlerConsumer(handler));
this.registrations.put(handler, registration);
return true;
}
@Override
public boolean unsubscribe(MessageHandler handler) {
if (logger.isTraceEnabled()) {
logger.trace("Channel " + getName() + ", removing subscription for handler " + handler);
}
Registration<?> registration = this.registrations.get(handler);
if (registration == null) {
if (logger.isTraceEnabled()) {
logger.trace("Channel " + getName() + ", no subscription for handler " + handler);
}
return false;
}
registration.cancel();
return true;
}
private static final class MessageHandlerConsumer implements Consumer<Event<Message<?>>> {
private final MessageHandler handler;
private MessageHandlerConsumer(MessageHandler handler) {
this.handler = handler;
}
@Override
public void accept(Event<Message<?>> event) {
Message<?> message = event.getData();
try {
this.handler.handleMessage(message);
}
catch (Throwable t) {
// TODO
logger.error("Failed to process message " + message, t);
}
}
}
}