From 3022f5e34fed329998589b719f3fd545c271a9db Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Fri, 14 Jun 2013 12:34:12 +0100 Subject: [PATCH] Make Message type pluggable To improve compatibility between Spring's messaging classes and Spring Integration, the type of Message that is created has been made pluggable through the introduction of a factory abstraction; MessageFactory. By default a MessageFactory is provided that will create org.springframework.messaging.GenericMessage instances, however this can be replaced with an alternative implementation. For example, Spring Integration can provide an implementation that creates org.springframework.integration.message.GenericMessage instances. This control over the type of Message that's created allows messages to flow from Spring messaging code into Spring Integration code without any need for conversion. In further support of this goal, MessageChannel, MessageHandler, and SubscribableChannel have been genericized to make the Message type that they deal with more flexible. --- .../messaging/GenericMessage.java | 4 +- .../messaging/GenericMessageFactory.java | 34 +++++++++++++++ .../messaging/MessageChannel.java | 6 +-- .../messaging/MessageFactory.java | 41 +++++++++++++++++++ .../messaging/MessageHandler.java | 4 +- .../messaging/SubscribableChannel.java | 6 +-- .../service/AbstractPubSubMessageHandler.java | 9 +--- .../service/ReactorPubSubMessageHandler.java | 18 ++++++-- .../AnnotationPubSubMessageHandler.java | 19 ++++++++- .../MessageChannelArgumentResolver.java | 15 +++++-- .../method/MessageReturnValueHandler.java | 12 +++++- .../stomp/support/StompMessageConverter.java | 11 +++-- .../StompRelayPubSubMessageHandler.java | 17 +++++--- .../stomp/support/StompWebSocketHandler.java | 24 ++++++++--- .../support/ReactorMessageChannel.java | 3 +- .../support/StompMessageConverterTests.java | 16 ++++++-- 16 files changed, 188 insertions(+), 51 deletions(-) create mode 100644 spring-context/src/main/java/org/springframework/messaging/GenericMessageFactory.java create mode 100644 spring-context/src/main/java/org/springframework/messaging/MessageFactory.java diff --git a/spring-context/src/main/java/org/springframework/messaging/GenericMessage.java b/spring-context/src/main/java/org/springframework/messaging/GenericMessage.java index f847630361..548baf6806 100644 --- a/spring-context/src/main/java/org/springframework/messaging/GenericMessage.java +++ b/spring-context/src/main/java/org/springframework/messaging/GenericMessage.java @@ -44,7 +44,7 @@ public class GenericMessage implements Message, Serializable { * * @param payload the message payload */ - public GenericMessage(T payload) { + protected GenericMessage(T payload) { this(payload, null); } @@ -56,7 +56,7 @@ public class GenericMessage implements Message, Serializable { * @param headers message headers * @see MessageHeaders */ - public GenericMessage(T payload, Map headers) { + protected GenericMessage(T payload, Map headers) { Assert.notNull(payload, "payload must not be null"); if (headers == null) { headers = new HashMap(); diff --git a/spring-context/src/main/java/org/springframework/messaging/GenericMessageFactory.java b/spring-context/src/main/java/org/springframework/messaging/GenericMessageFactory.java new file mode 100644 index 0000000000..076cc5590f --- /dev/null +++ b/spring-context/src/main/java/org/springframework/messaging/GenericMessageFactory.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging; + +import java.util.Map; + + +/** + * A {@link MessageFactory} that creates {@link GenericMessage GenericMessages}. + * + * @author Andy Wilkinson + */ +public class GenericMessageFactory implements MessageFactory> { + + @Override + public

GenericMessage createMessage(P payload, Map headers) { + return new GenericMessage

(payload, headers); + } + +} diff --git a/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java b/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java index 6bb8128e9c..dd42f246f8 100644 --- a/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java +++ b/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java @@ -23,7 +23,7 @@ package org.springframework.messaging; * @author Mark Fisher * @since 4.0 */ -public interface MessageChannel { +public interface MessageChannel { /** * Send a {@link Message} to this channel. May throw a RuntimeException for @@ -38,7 +38,7 @@ public interface MessageChannel { * * @return whether or not the Message has been sent successfully */ - boolean send(Message message); + boolean send(M message); /** * Send a message, blocking until either the message is accepted or the @@ -51,6 +51,6 @@ public interface MessageChannel { * false if the specified timeout period elapses or * the send is interrupted */ - boolean send(Message message, long timeout); + boolean send(M message, long timeout); } diff --git a/spring-context/src/main/java/org/springframework/messaging/MessageFactory.java b/spring-context/src/main/java/org/springframework/messaging/MessageFactory.java new file mode 100644 index 0000000000..075cd2f425 --- /dev/null +++ b/spring-context/src/main/java/org/springframework/messaging/MessageFactory.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging; + +import java.util.Map; + + +/** + * A factory for creating messages, allowing for control of the concrete type of the message. + * + * + * + * @author Andy Wilkinson + */ +public interface MessageFactory> { + + /** + * Creates a new message with the given payload and headers + * + * @param payload The message payload + * @param headers The message headers + * @param

The payload's type + * + * @return the message + */ +

M createMessage(P payload, Map headers); +} diff --git a/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java b/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java index 1cbda2cf31..52a171e211 100644 --- a/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java +++ b/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java @@ -24,7 +24,7 @@ package org.springframework.messaging; * @author Iwein Fuld * @since 4.0 */ -public interface MessageHandler { +public interface MessageHandler { /** * TODO: support exceptions? @@ -46,6 +46,6 @@ public interface MessageHandler { * @throws org.springframework.integration.MessageDeliveryException when this handler failed to deliver the * reply related to the handling of the message */ - void handleMessage(Message message) throws MessagingException; + void handleMessage(M message) throws MessagingException; } diff --git a/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java b/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java index e7ff7385c7..6b8961a983 100644 --- a/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java +++ b/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java @@ -25,16 +25,16 @@ package org.springframework.messaging; * @author Mark Fisher * @since 4.0 */ -public interface SubscribableChannel extends MessageChannel { +public interface SubscribableChannel> extends MessageChannel { /** * Register a {@link MessageHandler} as a subscriber to this channel. */ - boolean subscribe(MessageHandler handler); + boolean subscribe(H handler); /** * Remove a {@link MessageHandler} from the subscribers of this channel. */ - boolean unsubscribe(MessageHandler handler); + boolean unsubscribe(H handler); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java index 335655ef76..b0da4182ec 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java @@ -40,7 +40,7 @@ import org.springframework.web.messaging.PubSubHeaders; * @author Rossen Stoyanchev * @since 4.0 */ -public abstract class AbstractPubSubMessageHandler implements MessageHandler { +public abstract class AbstractPubSubMessageHandler implements MessageHandler> { protected final Log logger = LogFactory.getLog(getClass()); @@ -54,11 +54,9 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler { private final PathMatcher pathMatcher = new AntPathMatcher(); - /** * @param publishChannel a channel for publishing messages from within the - * application; this constructor will also automatically subscribe the - * current instance to this channel + * application * * @param clientChannel a channel for sending messages to connected clients. */ @@ -67,9 +65,7 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler { Assert.notNull(publishChannel, "publishChannel is required"); Assert.notNull(clientChannel, "clientChannel is required"); - publishChannel.subscribe(this); this.publishChannel = publishChannel; - this.clientChannel = clientChannel; } @@ -146,7 +142,6 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler { return true; } - @Override public final void handleMessage(Message message) throws MessagingException { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java index 3d0813d744..00ac304a02 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java @@ -23,9 +23,10 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.SubscribableChannel; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.PubSubHeaders; @@ -50,6 +51,8 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { private MessageConverter payloadConverter; + private MessageFactory messageFactory; + private Map>> subscriptionsBySession = new ConcurrentHashMap>>(); @@ -59,13 +62,18 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { super(publishChannel, clientChannel); this.reactor = reactor; this.payloadConverter = new CompositeMessageConverter(null); + this.messageFactory = new GenericMessageFactory(); } + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; + } public void setMessageConverters(List converters) { this.payloadConverter = new CompositeMessageConverter(converters); } + @SuppressWarnings("unchecked") @Override public void handlePublish(Message message) { @@ -77,7 +85,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { // Convert to byte[] payload before the fan-out PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders()); byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType()); - message = new GenericMessage(payload, message.getHeaders()); + message = messageFactory.createMessage(payload, message.getHeaders()); this.reactor.notify(getPublishKey(inHeaders.getDestination()), Event.wrap(message)); } @@ -109,6 +117,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { Selector selector = new ObjectSelector(getPublishKey(headers.getDestination())); Registration registration = this.reactor.on(selector, new Consumer>>() { + @SuppressWarnings("unchecked") @Override public void accept(Event> event) { Message message = event.getData(); @@ -120,8 +129,9 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { } outHeaders.setSubscriptionId(subscriptionId); Object payload = message.getPayload(); - message = new GenericMessage(payload, outHeaders.toMessageHeaders()); - getClientChannel().send(message); + + Message outMessage = messageFactory.createMessage(payload, outHeaders.toMessageHeaders()); + getClientChannel().send(outMessage); } }); diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java index 474b309617..d26366b89e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java @@ -31,8 +31,10 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.annotation.MessageMapping; import org.springframework.stereotype.Controller; @@ -69,6 +71,8 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler private ReturnValueHandlerComposite returnValueHandlers = new ReturnValueHandlerComposite(); + private MessageFactory messageFactory = new GenericMessageFactory(); + public AnnotationPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) { @@ -79,6 +83,10 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler this.messageConverters = converters; } + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; + } + @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { this.applicationContext = applicationContext; @@ -92,9 +100,16 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler @Override public void afterPropertiesSet() { initHandlerMethods(); - this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(getPublishChannel())); + + MessageChannelArgumentResolver messageChannelArgumentResolver = new MessageChannelArgumentResolver(getPublishChannel()); + messageChannelArgumentResolver.setMessageFactory(messageFactory); + this.argumentResolvers.addResolver(messageChannelArgumentResolver); + this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverters)); - this.returnValueHandlers.addHandler(new MessageReturnValueHandler(getClientChannel())); + + MessageReturnValueHandler messageReturnValueHandler = new MessageReturnValueHandler(getClientChannel()); + messageReturnValueHandler.setMessageFactory(messageFactory); + this.returnValueHandlers.addHandler(messageReturnValueHandler); } protected void initHandlerMethods() { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java index 33518f3343..3e7ba1c58b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java @@ -17,9 +17,10 @@ package org.springframework.web.messaging.service.method; import org.springframework.core.MethodParameter; -import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.util.Assert; import org.springframework.web.messaging.PubSubHeaders; @@ -32,10 +33,16 @@ public class MessageChannelArgumentResolver implements ArgumentResolver { private final MessageChannel publishChannel; + private MessageFactory messageFactory; public MessageChannelArgumentResolver(MessageChannel publishChannel) { Assert.notNull(publishChannel, "publishChannel is required"); this.publishChannel = publishChannel; + this.messageFactory = new GenericMessageFactory(); + } + + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; } @Override @@ -48,19 +55,19 @@ public class MessageChannelArgumentResolver implements ArgumentResolver { final String sessionId = PubSubHeaders.fromMessageHeaders(message.getHeaders()).getSessionId(); - return new MessageChannel() { + return new MessageChannel>() { @Override public boolean send(Message message) { return send(message, -1); } + @SuppressWarnings("unchecked") @Override public boolean send(Message message, long timeout) { PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); headers.setSessionId(sessionId); - message = new GenericMessage(message.getPayload(), headers.toMessageHeaders()); - publishChannel.send(message); + publishChannel.send(messageFactory.createMessage(message.getPayload(), headers.toMessageHeaders())); return true; } }; diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java index 6e607040dd..4a1e69c04b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java @@ -17,9 +17,10 @@ package org.springframework.web.messaging.service.method; import org.springframework.core.MethodParameter; -import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.util.Assert; import org.springframework.web.messaging.PubSubHeaders; @@ -32,12 +33,18 @@ public class MessageReturnValueHandler implements ReturnValueHandler { private final MessageChannel clientChannel; + private MessageFactory messageFactory = new GenericMessageFactory(); + public MessageReturnValueHandler(MessageChannel clientChannel) { Assert.notNull(clientChannel, "clientChannel is required"); this.clientChannel = clientChannel; } + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; + } + @Override public boolean supportsReturnType(MethodParameter returnType) { @@ -56,6 +63,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler { // return Message.class.isAssignableFrom(paramType); } + @SuppressWarnings("unchecked") @Override public void handleReturnValue(Object returnValue, MethodParameter returnType, Message message) throws Exception { @@ -73,7 +81,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler { PubSubHeaders outHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders()); outHeaders.setSessionId(sessionId); outHeaders.setSubscriptionId(subscriptionId); - returnMessage = new GenericMessage(returnMessage.getPayload(), outHeaders.toMessageHeaders()); + returnMessage = messageFactory.createMessage(returnMessage.getPayload(), outHeaders.toMessageHeaders()); this.clientChannel.send(returnMessage); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java index 0c7376ddb4..52086b68c6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java @@ -22,8 +22,8 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; -import org.springframework.messaging.GenericMessage; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.MessageHeaders; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; @@ -48,11 +48,10 @@ public class StompMessageConverter { private static final byte COLON = ':'; - /** * @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String. */ - public Message toMessage(Object stompContent, String sessionId) { + public > M toMessage(Object stompContent, String sessionId, MessageFactory messageFactory) { byte[] byteContent = null; if (stompContent instanceof String) { @@ -103,7 +102,7 @@ public class StompMessageConverter { byte[] payload = new byte[totalLength - payloadIndex]; System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex); - return createMessage(command, stompHeaders.toMessageHeaders(), payload); + return createMessage(command, stompHeaders.toMessageHeaders(), payload, messageFactory); } private int findIndexOfPayload(byte[] bytes) { @@ -133,8 +132,8 @@ public class StompMessageConverter { return index; } - protected Message createMessage(StompCommand command, Map headers, byte[] payload) { - return new GenericMessage(payload, headers); + protected > M createMessage(StompCommand command, Map headers, byte[] payload, MessageFactory messageFactory) { + return messageFactory.createMessage(payload, headers); } public byte[] fromMessage(Message message) { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java index 1d2300e653..5783971b72 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java @@ -23,9 +23,10 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.springframework.http.MediaType; -import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.SubscribableChannel; import org.springframework.util.Assert; import org.springframework.web.messaging.MessageType; @@ -52,17 +53,17 @@ import reactor.tcp.netty.NettyTcpClient; */ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler { - private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); private MessageConverter payloadConverter; + private MessageFactory messageFactory = new GenericMessageFactory(); + private final TcpClient tcpClient; private final Map> connections = new ConcurrentHashMap>(); - public StompRelayPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) { super(publishChannel, clientChannel); @@ -81,6 +82,10 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler this.payloadConverter = new CompositeMessageConverter(converters); } + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; + } + @Override protected Collection getSupportedMessageTypes() { return null; @@ -105,13 +110,14 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler @Override public void accept(TcpConnection connection) { connection.in().consume(new Consumer() { + @SuppressWarnings("unchecked") @Override public void accept(String stompFrame) { if (stompFrame.isEmpty()) { // TODO: why are we getting empty frames? return; } - Message message = stompMessageConverter.toMessage(stompFrame, sessionId); + Message message = stompMessageConverter.toMessage(stompFrame, sessionId, messageFactory); getClientChannel().send(message); } }); @@ -128,6 +134,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler } + @SuppressWarnings("unchecked") private void forwardMessage(Message message, StompCommand command) { StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); @@ -139,7 +146,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler MediaType contentType = stompHeaders.getContentType(); byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType); - Message byteMessage = new GenericMessage(payload, stompHeaders.toMessageHeaders()); + Message byteMessage = messageFactory.createMessage(payload, stompHeaders.toMessageHeaders()); bytesToWrite = this.stompMessageConverter.fromMessage(byteMessage); } catch (Throwable ex) { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java index 53c57fc5ea..e09a905be2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java @@ -25,9 +25,10 @@ import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.MediaType; -import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.SubscribableChannel; import org.springframework.util.Assert; @@ -59,7 +60,10 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { private MessageConverter payloadConverter = new CompositeMessageConverter(null); + private MessageFactory messageFactory = new GenericMessageFactory(); + + @SuppressWarnings("unchecked") public StompWebSocketHandler(MessageChannel publishChannel, SubscribableChannel clientChannel) { Assert.notNull(publishChannel, "publishChannel is required"); @@ -74,6 +78,10 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { this.payloadConverter = new CompositeMessageConverter(converters); } + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; + } + public StompMessageConverter getStompMessageConverter() { return this.stompMessageConverter; } @@ -88,11 +96,12 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { this.sessions.put(session.getId(), session); } + @SuppressWarnings("unchecked") @Override protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) { try { String payload = textMessage.getPayload(); - Message message = this.stompMessageConverter.toMessage(payload, session.getId()); + Message message = this.stompMessageConverter.toMessage(payload, session.getId(), messageFactory); // TODO: validate size limits // http://stomp.github.io/stomp-specification-1.2.html#Size_Limits @@ -135,6 +144,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { } } + @SuppressWarnings("unchecked") protected void handleConnect(final WebSocketSession session, Message message) throws IOException { StompHeaders connectStompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); @@ -157,7 +167,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { // TODO: security - Message connectedMessage = new GenericMessage(new byte[0], connectedStompHeaders.toMessageHeaders()); + Message connectedMessage = messageFactory.createMessage(new byte[0], connectedStompHeaders.toMessageHeaders()); byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } @@ -177,12 +187,13 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { protected void handleDisconnect(Message stompMessage) { } + @SuppressWarnings("unchecked") protected void sendErrorMessage(WebSocketSession session, Throwable error) { StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR); stompHeaders.setMessage(error.getMessage()); - Message errorMessage = new GenericMessage(new byte[0], stompHeaders.toMessageHeaders()); + Message errorMessage = messageFactory.createMessage(new byte[0], stompHeaders.toMessageHeaders()); byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage); try { @@ -200,9 +211,10 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { }*/ - private final class ClientMessageConsumer implements MessageHandler { + private final class ClientMessageConsumer implements MessageHandler> { + @SuppressWarnings("unchecked") @Override public void handleMessage(Message message) { @@ -235,7 +247,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { try { Map messageHeaders = stompHeaders.toMessageHeaders(); - Message byteMessage = new GenericMessage(payload, messageHeaders); + Message byteMessage = messageFactory.createMessage(payload, messageHeaders); byte[] bytes = getStompMessageConverter().fromMessage(byteMessage); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/ReactorMessageChannel.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/ReactorMessageChannel.java index 4e6d25d5e9..c5a99c124c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/support/ReactorMessageChannel.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/ReactorMessageChannel.java @@ -36,7 +36,7 @@ import reactor.fn.selector.ObjectSelector; * @author Rossen Stoyanchev * @since 4.0 */ -public class ReactorMessageChannel implements SubscribableChannel { +public class ReactorMessageChannel implements SubscribableChannel, MessageHandler>> { private static Log logger = LogFactory.getLog(ReactorMessageChannel.class); @@ -125,6 +125,7 @@ public class ReactorMessageChannel implements SubscribableChannel { this.handler = handler; } + @SuppressWarnings("unchecked") @Override public void accept(Event> event) { Message message = event.getData(); diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java index e2eaed097f..366e9f758f 100644 --- a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java @@ -19,7 +19,9 @@ import java.util.Collections; import org.junit.Before; import org.junit.Test; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.MessageHeaders; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.stomp.StompCommand; @@ -35,19 +37,22 @@ public class StompMessageConverterTests { private StompMessageConverter converter; + private MessageFactory messageFactory = new GenericMessageFactory(); + @Before public void setup() { this.converter = new StompMessageConverter(); } + @SuppressWarnings("unchecked") @Test public void connectFrame() throws Exception { String accept = "accept-version:1.1\n"; String host = "host:github.org\n"; String frame = "\n\n\nCONNECT\n" + accept + host + "\n"; - Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); + Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123", messageFactory); assertEquals(0, message.getPayload().length); @@ -71,13 +76,14 @@ public class StompMessageConverterTests { assertTrue(convertedBack.contains(host)); } + @SuppressWarnings("unchecked") @Test public void connectWithEscapes() throws Exception { String accept = "accept-version:1.1\n"; String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; String frame = "CONNECT\n" + accept + host + "\n"; - Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); + Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123", messageFactory); assertEquals(0, message.getPayload().length); @@ -93,13 +99,14 @@ public class StompMessageConverterTests { assertTrue(convertedBack.contains(host)); } + @SuppressWarnings("unchecked") @Test public void connectCR12() throws Exception { String accept = "accept-version:1.2\n"; String host = "host:github.org\n"; String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); + Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123", messageFactory); assertEquals(0, message.getPayload().length); @@ -115,13 +122,14 @@ public class StompMessageConverterTests { assertTrue(convertedBack.contains(host)); } + @SuppressWarnings("unchecked") @Test public void connectWithEscapesAndCR12() throws Exception { String accept = "accept-version:1.1\n"; String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); + Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123", messageFactory); assertEquals(0, message.getPayload().length);