Make Message type pluggable

To improve compatibility between Spring's messaging classes and
Spring Integration, the type of Message that is created has been made
pluggable through the introduction of a factory abstraction;
MessageFactory.

By default a MessageFactory is provided that will create
org.springframework.messaging.GenericMessage instances, however this
can be replaced with an alternative implementation. For example,
Spring Integration can provide an implementation that creates
org.springframework.integration.message.GenericMessage instances.

This control over the type of Message that's created allows messages
to flow from Spring messaging code into Spring Integration code without
any need for conversion. In further support of this goal,
MessageChannel, MessageHandler, and SubscribableChannel have been
genericized to make the Message type that they deal with more
flexible.
This commit is contained in:
Andy Wilkinson 2013-06-14 12:34:12 +01:00
parent 641aaf4b6a
commit 3022f5e34f
16 changed files with 188 additions and 51 deletions

View File

@ -44,7 +44,7 @@ public class GenericMessage<T> implements Message<T>, Serializable {
*
* @param payload the message payload
*/
public GenericMessage(T payload) {
protected GenericMessage(T payload) {
this(payload, null);
}
@ -56,7 +56,7 @@ public class GenericMessage<T> implements Message<T>, Serializable {
* @param headers message headers
* @see MessageHeaders
*/
public GenericMessage(T payload, Map<String, Object> headers) {
protected GenericMessage(T payload, Map<String, Object> headers) {
Assert.notNull(payload, "payload must not be null");
if (headers == null) {
headers = new HashMap<String, Object>();

View File

@ -0,0 +1,34 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging;
import java.util.Map;
/**
* A {@link MessageFactory} that creates {@link GenericMessage GenericMessages}.
*
* @author Andy Wilkinson
*/
public class GenericMessageFactory implements MessageFactory<GenericMessage<?>> {
@Override
public <P> GenericMessage<?> createMessage(P payload, Map<String, Object> headers) {
return new GenericMessage<P>(payload, headers);
}
}

View File

@ -23,7 +23,7 @@ package org.springframework.messaging;
* @author Mark Fisher
* @since 4.0
*/
public interface MessageChannel {
public interface MessageChannel<M extends Message> {
/**
* Send a {@link Message} to this channel. May throw a RuntimeException for
@ -38,7 +38,7 @@ public interface MessageChannel {
*
* @return whether or not the Message has been sent successfully
*/
boolean send(Message<?> message);
boolean send(M message);
/**
* Send a message, blocking until either the message is accepted or the
@ -51,6 +51,6 @@ public interface MessageChannel {
* <code>false</code> if the specified timeout period elapses or
* the send is interrupted
*/
boolean send(Message<?> message, long timeout);
boolean send(M message, long timeout);
}

View File

@ -0,0 +1,41 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging;
import java.util.Map;
/**
* A factory for creating messages, allowing for control of the concrete type of the message.
*
*
*
* @author Andy Wilkinson
*/
public interface MessageFactory<M extends Message<?>> {
/**
* Creates a new message with the given payload and headers
*
* @param payload The message payload
* @param headers The message headers
* @param <P> The payload's type
*
* @return the message
*/
<P> M createMessage(P payload, Map<String, Object> headers);
}

View File

@ -24,7 +24,7 @@ package org.springframework.messaging;
* @author Iwein Fuld
* @since 4.0
*/
public interface MessageHandler {
public interface MessageHandler<M extends Message> {
/**
* TODO: support exceptions?
@ -46,6 +46,6 @@ public interface MessageHandler {
* @throws org.springframework.integration.MessageDeliveryException when this handler failed to deliver the
* reply related to the handling of the message
*/
void handleMessage(Message<?> message) throws MessagingException;
void handleMessage(M message) throws MessagingException;
}

View File

@ -25,16 +25,16 @@ package org.springframework.messaging;
* @author Mark Fisher
* @since 4.0
*/
public interface SubscribableChannel extends MessageChannel {
public interface SubscribableChannel<M extends Message, H extends MessageHandler<M>> extends MessageChannel<M> {
/**
* Register a {@link MessageHandler} as a subscriber to this channel.
*/
boolean subscribe(MessageHandler handler);
boolean subscribe(H handler);
/**
* Remove a {@link MessageHandler} from the subscribers of this channel.
*/
boolean unsubscribe(MessageHandler handler);
boolean unsubscribe(H handler);
}

View File

@ -40,7 +40,7 @@ import org.springframework.web.messaging.PubSubHeaders;
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractPubSubMessageHandler implements MessageHandler {
public abstract class AbstractPubSubMessageHandler implements MessageHandler<Message<?>> {
protected final Log logger = LogFactory.getLog(getClass());
@ -54,11 +54,9 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler {
private final PathMatcher pathMatcher = new AntPathMatcher();
/**
* @param publishChannel a channel for publishing messages from within the
* application; this constructor will also automatically subscribe the
* current instance to this channel
* application
*
* @param clientChannel a channel for sending messages to connected clients.
*/
@ -67,9 +65,7 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler {
Assert.notNull(publishChannel, "publishChannel is required");
Assert.notNull(clientChannel, "clientChannel is required");
publishChannel.subscribe(this);
this.publishChannel = publishChannel;
this.clientChannel = clientChannel;
}
@ -146,7 +142,6 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler {
return true;
}
@Override
public final void handleMessage(Message<?> message) throws MessagingException {

View File

@ -23,9 +23,10 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubHeaders;
@ -50,6 +51,8 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
private MessageConverter payloadConverter;
private MessageFactory messageFactory;
private Map<String, List<Registration<?>>> subscriptionsBySession = new ConcurrentHashMap<String, List<Registration<?>>>();
@ -59,13 +62,18 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
super(publishChannel, clientChannel);
this.reactor = reactor;
this.payloadConverter = new CompositeMessageConverter(null);
this.messageFactory = new GenericMessageFactory();
}
public void setMessageFactory(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
public void setMessageConverters(List<MessageConverter> converters) {
this.payloadConverter = new CompositeMessageConverter(converters);
}
@SuppressWarnings("unchecked")
@Override
public void handlePublish(Message<?> message) {
@ -77,7 +85,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
// Convert to byte[] payload before the fan-out
PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders());
byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType());
message = new GenericMessage<byte[]>(payload, message.getHeaders());
message = messageFactory.createMessage(payload, message.getHeaders());
this.reactor.notify(getPublishKey(inHeaders.getDestination()), Event.wrap(message));
}
@ -109,6 +117,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
Selector selector = new ObjectSelector<String>(getPublishKey(headers.getDestination()));
Registration<?> registration = this.reactor.on(selector,
new Consumer<Event<Message<?>>>() {
@SuppressWarnings("unchecked")
@Override
public void accept(Event<Message<?>> event) {
Message<?> message = event.getData();
@ -120,8 +129,9 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler {
}
outHeaders.setSubscriptionId(subscriptionId);
Object payload = message.getPayload();
message = new GenericMessage<Object>(payload, outHeaders.toMessageHeaders());
getClientChannel().send(message);
Message outMessage = messageFactory.createMessage(payload, outHeaders.toMessageHeaders());
getClientChannel().send(outMessage);
}
});

View File

@ -31,8 +31,10 @@ import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.annotation.MessageMapping;
import org.springframework.stereotype.Controller;
@ -69,6 +71,8 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler
private ReturnValueHandlerComposite returnValueHandlers = new ReturnValueHandlerComposite();
private MessageFactory messageFactory = new GenericMessageFactory();
public AnnotationPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) {
@ -79,6 +83,10 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler
this.messageConverters = converters;
}
public void setMessageFactory(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = applicationContext;
@ -92,9 +100,16 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler
@Override
public void afterPropertiesSet() {
initHandlerMethods();
this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(getPublishChannel()));
MessageChannelArgumentResolver messageChannelArgumentResolver = new MessageChannelArgumentResolver(getPublishChannel());
messageChannelArgumentResolver.setMessageFactory(messageFactory);
this.argumentResolvers.addResolver(messageChannelArgumentResolver);
this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverters));
this.returnValueHandlers.addHandler(new MessageReturnValueHandler(getClientChannel()));
MessageReturnValueHandler messageReturnValueHandler = new MessageReturnValueHandler(getClientChannel());
messageReturnValueHandler.setMessageFactory(messageFactory);
this.returnValueHandlers.addHandler(messageReturnValueHandler);
}
protected void initHandlerMethods() {

View File

@ -17,9 +17,10 @@
package org.springframework.web.messaging.service.method;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.util.Assert;
import org.springframework.web.messaging.PubSubHeaders;
@ -32,10 +33,16 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
private final MessageChannel publishChannel;
private MessageFactory messageFactory;
public MessageChannelArgumentResolver(MessageChannel publishChannel) {
Assert.notNull(publishChannel, "publishChannel is required");
this.publishChannel = publishChannel;
this.messageFactory = new GenericMessageFactory();
}
public void setMessageFactory(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
@Override
@ -48,19 +55,19 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
final String sessionId = PubSubHeaders.fromMessageHeaders(message.getHeaders()).getSessionId();
return new MessageChannel() {
return new MessageChannel<Message<?>>() {
@Override
public boolean send(Message<?> message) {
return send(message, -1);
}
@SuppressWarnings("unchecked")
@Override
public boolean send(Message<?> message, long timeout) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
headers.setSessionId(sessionId);
message = new GenericMessage<Object>(message.getPayload(), headers.toMessageHeaders());
publishChannel.send(message);
publishChannel.send(messageFactory.createMessage(message.getPayload(), headers.toMessageHeaders()));
return true;
}
};

View File

@ -17,9 +17,10 @@
package org.springframework.web.messaging.service.method;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.util.Assert;
import org.springframework.web.messaging.PubSubHeaders;
@ -32,12 +33,18 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
private final MessageChannel clientChannel;
private MessageFactory messageFactory = new GenericMessageFactory();
public MessageReturnValueHandler(MessageChannel clientChannel) {
Assert.notNull(clientChannel, "clientChannel is required");
this.clientChannel = clientChannel;
}
public void setMessageFactory(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
@Override
public boolean supportsReturnType(MethodParameter returnType) {
@ -56,6 +63,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
// return Message.class.isAssignableFrom(paramType);
}
@SuppressWarnings("unchecked")
@Override
public void handleReturnValue(Object returnValue, MethodParameter returnType, Message<?> message)
throws Exception {
@ -73,7 +81,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
PubSubHeaders outHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders());
outHeaders.setSessionId(sessionId);
outHeaders.setSubscriptionId(subscriptionId);
returnMessage = new GenericMessage<Object>(returnMessage.getPayload(), outHeaders.toMessageHeaders());
returnMessage = messageFactory.createMessage(returnMessage.getPayload(), outHeaders.toMessageHeaders());
this.clientChannel.send(returnMessage);
}

View File

@ -22,8 +22,8 @@ import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
@ -48,11 +48,10 @@ public class StompMessageConverter {
private static final byte COLON = ':';
/**
* @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String.
*/
public Message<byte[]> toMessage(Object stompContent, String sessionId) {
public <M extends Message<?>> M toMessage(Object stompContent, String sessionId, MessageFactory<M> messageFactory) {
byte[] byteContent = null;
if (stompContent instanceof String) {
@ -103,7 +102,7 @@ public class StompMessageConverter {
byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
return createMessage(command, stompHeaders.toMessageHeaders(), payload);
return createMessage(command, stompHeaders.toMessageHeaders(), payload, messageFactory);
}
private int findIndexOfPayload(byte[] bytes) {
@ -133,8 +132,8 @@ public class StompMessageConverter {
return index;
}
protected Message<byte[]> createMessage(StompCommand command, Map<String, Object> headers, byte[] payload) {
return new GenericMessage<byte[]>(payload, headers);
protected <M extends Message<?>> M createMessage(StompCommand command, Map<String, Object> headers, byte[] payload, MessageFactory<M> messageFactory) {
return messageFactory.createMessage(payload, headers);
}
public byte[] fromMessage(Message<byte[]> message) {

View File

@ -23,9 +23,10 @@ import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
import org.springframework.web.messaging.MessageType;
@ -52,17 +53,17 @@ import reactor.tcp.netty.NettyTcpClient;
*/
public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler {
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
private MessageConverter payloadConverter;
private MessageFactory messageFactory = new GenericMessageFactory();
private final TcpClient<String, String> tcpClient;
private final Map<String, TcpConnection<String, String>> connections =
new ConcurrentHashMap<String, TcpConnection<String, String>>();
public StompRelayPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) {
super(publishChannel, clientChannel);
@ -81,6 +82,10 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
this.payloadConverter = new CompositeMessageConverter(converters);
}
public void setMessageFactory(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
@Override
protected Collection<MessageType> getSupportedMessageTypes() {
return null;
@ -105,13 +110,14 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
@Override
public void accept(TcpConnection<String, String> connection) {
connection.in().consume(new Consumer<String>() {
@SuppressWarnings("unchecked")
@Override
public void accept(String stompFrame) {
if (stompFrame.isEmpty()) {
// TODO: why are we getting empty frames?
return;
}
Message<byte[]> message = stompMessageConverter.toMessage(stompFrame, sessionId);
Message<byte[]> message = stompMessageConverter.toMessage(stompFrame, sessionId, messageFactory);
getClientChannel().send(message);
}
});
@ -128,6 +134,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
}
@SuppressWarnings("unchecked")
private void forwardMessage(Message<?> message, StompCommand command) {
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
@ -139,7 +146,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler
MediaType contentType = stompHeaders.getContentType();
byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType);
Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, stompHeaders.toMessageHeaders());
Message<byte[]> byteMessage = messageFactory.createMessage(payload, stompHeaders.toMessageHeaders());
bytesToWrite = this.stompMessageConverter.fromMessage(byteMessage);
}
catch (Throwable ex) {

View File

@ -25,9 +25,10 @@ import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
@ -59,7 +60,10 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
private MessageConverter payloadConverter = new CompositeMessageConverter(null);
private MessageFactory messageFactory = new GenericMessageFactory();
@SuppressWarnings("unchecked")
public StompWebSocketHandler(MessageChannel publishChannel, SubscribableChannel clientChannel) {
Assert.notNull(publishChannel, "publishChannel is required");
@ -74,6 +78,10 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
this.payloadConverter = new CompositeMessageConverter(converters);
}
public void setMessageFactory(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
public StompMessageConverter getStompMessageConverter() {
return this.stompMessageConverter;
}
@ -88,11 +96,12 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
this.sessions.put(session.getId(), session);
}
@SuppressWarnings("unchecked")
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) {
try {
String payload = textMessage.getPayload();
Message<byte[]> message = this.stompMessageConverter.toMessage(payload, session.getId());
Message<byte[]> message = this.stompMessageConverter.toMessage(payload, session.getId(), messageFactory);
// TODO: validate size limits
// http://stomp.github.io/stomp-specification-1.2.html#Size_Limits
@ -135,6 +144,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
}
}
@SuppressWarnings("unchecked")
protected void handleConnect(final WebSocketSession session, Message<byte[]> message) throws IOException {
StompHeaders connectStompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
@ -157,7 +167,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
// TODO: security
Message<byte[]> connectedMessage = new GenericMessage<byte[]>(new byte[0], connectedStompHeaders.toMessageHeaders());
Message<byte[]> connectedMessage = messageFactory.createMessage(new byte[0], connectedStompHeaders.toMessageHeaders());
byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
@ -177,12 +187,13 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
protected void handleDisconnect(Message<byte[]> stompMessage) {
}
@SuppressWarnings("unchecked")
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR);
stompHeaders.setMessage(error.getMessage());
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.toMessageHeaders());
Message<byte[]> errorMessage = messageFactory.createMessage(new byte[0], stompHeaders.toMessageHeaders());
byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage);
try {
@ -200,9 +211,10 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
}*/
private final class ClientMessageConsumer implements MessageHandler {
private final class ClientMessageConsumer implements MessageHandler<Message<?>> {
@SuppressWarnings("unchecked")
@Override
public void handleMessage(Message<?> message) {
@ -235,7 +247,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter {
try {
Map<String, Object> messageHeaders = stompHeaders.toMessageHeaders();
Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, messageHeaders);
Message<byte[]> byteMessage = messageFactory.createMessage(payload, messageHeaders);
byte[] bytes = getStompMessageConverter().fromMessage(byteMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}

View File

@ -36,7 +36,7 @@ import reactor.fn.selector.ObjectSelector;
* @author Rossen Stoyanchev
* @since 4.0
*/
public class ReactorMessageChannel implements SubscribableChannel {
public class ReactorMessageChannel implements SubscribableChannel<Message<?>, MessageHandler<Message<?>>> {
private static Log logger = LogFactory.getLog(ReactorMessageChannel.class);
@ -125,6 +125,7 @@ public class ReactorMessageChannel implements SubscribableChannel {
this.handler = handler;
}
@SuppressWarnings("unchecked")
@Override
public void accept(Event<Message<?>> event) {
Message<?> message = event.getData();

View File

@ -19,7 +19,9 @@ import java.util.Collections;
import org.junit.Before;
import org.junit.Test;
import org.springframework.messaging.GenericMessageFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageFactory;
import org.springframework.messaging.MessageHeaders;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.stomp.StompCommand;
@ -35,19 +37,22 @@ public class StompMessageConverterTests {
private StompMessageConverter converter;
private MessageFactory messageFactory = new GenericMessageFactory();
@Before
public void setup() {
this.converter = new StompMessageConverter();
}
@SuppressWarnings("unchecked")
@Test
public void connectFrame() throws Exception {
String accept = "accept-version:1.1\n";
String host = "host:github.org\n";
String frame = "\n\n\nCONNECT\n" + accept + host + "\n";
Message<byte[]> message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
Message<byte[]> message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123", messageFactory);
assertEquals(0, message.getPayload().length);
@ -71,13 +76,14 @@ public class StompMessageConverterTests {
assertTrue(convertedBack.contains(host));
}
@SuppressWarnings("unchecked")
@Test
public void connectWithEscapes() throws Exception {
String accept = "accept-version:1.1\n";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String frame = "CONNECT\n" + accept + host + "\n";
Message<byte[]> message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
Message<byte[]> message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123", messageFactory);
assertEquals(0, message.getPayload().length);
@ -93,13 +99,14 @@ public class StompMessageConverterTests {
assertTrue(convertedBack.contains(host));
}
@SuppressWarnings("unchecked")
@Test
public void connectCR12() throws Exception {
String accept = "accept-version:1.2\n";
String host = "host:github.org\n";
String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
Message<byte[]> message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
Message<byte[]> message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123", messageFactory);
assertEquals(0, message.getPayload().length);
@ -115,13 +122,14 @@ public class StompMessageConverterTests {
assertTrue(convertedBack.contains(host));
}
@SuppressWarnings("unchecked")
@Test
public void connectWithEscapesAndCR12() throws Exception {
String accept = "accept-version:1.1\n";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
Message<byte[]> message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
Message<byte[]> message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123", messageFactory);
assertEquals(0, message.getPayload().length);