diff --git a/spring-context/src/main/java/org/springframework/messaging/GenericMessage.java b/spring-context/src/main/java/org/springframework/messaging/GenericMessage.java index f847630361..548baf6806 100644 --- a/spring-context/src/main/java/org/springframework/messaging/GenericMessage.java +++ b/spring-context/src/main/java/org/springframework/messaging/GenericMessage.java @@ -44,7 +44,7 @@ public class GenericMessage implements Message, Serializable { * * @param payload the message payload */ - public GenericMessage(T payload) { + protected GenericMessage(T payload) { this(payload, null); } @@ -56,7 +56,7 @@ public class GenericMessage implements Message, Serializable { * @param headers message headers * @see MessageHeaders */ - public GenericMessage(T payload, Map headers) { + protected GenericMessage(T payload, Map headers) { Assert.notNull(payload, "payload must not be null"); if (headers == null) { headers = new HashMap(); diff --git a/spring-context/src/main/java/org/springframework/messaging/GenericMessageFactory.java b/spring-context/src/main/java/org/springframework/messaging/GenericMessageFactory.java new file mode 100644 index 0000000000..076cc5590f --- /dev/null +++ b/spring-context/src/main/java/org/springframework/messaging/GenericMessageFactory.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging; + +import java.util.Map; + + +/** + * A {@link MessageFactory} that creates {@link GenericMessage GenericMessages}. + * + * @author Andy Wilkinson + */ +public class GenericMessageFactory implements MessageFactory> { + + @Override + public

GenericMessage createMessage(P payload, Map headers) { + return new GenericMessage

(payload, headers); + } + +} diff --git a/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java b/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java index 6bb8128e9c..dd42f246f8 100644 --- a/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java +++ b/spring-context/src/main/java/org/springframework/messaging/MessageChannel.java @@ -23,7 +23,7 @@ package org.springframework.messaging; * @author Mark Fisher * @since 4.0 */ -public interface MessageChannel { +public interface MessageChannel { /** * Send a {@link Message} to this channel. May throw a RuntimeException for @@ -38,7 +38,7 @@ public interface MessageChannel { * * @return whether or not the Message has been sent successfully */ - boolean send(Message message); + boolean send(M message); /** * Send a message, blocking until either the message is accepted or the @@ -51,6 +51,6 @@ public interface MessageChannel { * false if the specified timeout period elapses or * the send is interrupted */ - boolean send(Message message, long timeout); + boolean send(M message, long timeout); } diff --git a/spring-context/src/main/java/org/springframework/messaging/MessageFactory.java b/spring-context/src/main/java/org/springframework/messaging/MessageFactory.java new file mode 100644 index 0000000000..075cd2f425 --- /dev/null +++ b/spring-context/src/main/java/org/springframework/messaging/MessageFactory.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging; + +import java.util.Map; + + +/** + * A factory for creating messages, allowing for control of the concrete type of the message. + * + * + * + * @author Andy Wilkinson + */ +public interface MessageFactory> { + + /** + * Creates a new message with the given payload and headers + * + * @param payload The message payload + * @param headers The message headers + * @param

The payload's type + * + * @return the message + */ +

M createMessage(P payload, Map headers); +} diff --git a/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java b/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java index 1cbda2cf31..52a171e211 100644 --- a/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java +++ b/spring-context/src/main/java/org/springframework/messaging/MessageHandler.java @@ -24,7 +24,7 @@ package org.springframework.messaging; * @author Iwein Fuld * @since 4.0 */ -public interface MessageHandler { +public interface MessageHandler { /** * TODO: support exceptions? @@ -46,6 +46,6 @@ public interface MessageHandler { * @throws org.springframework.integration.MessageDeliveryException when this handler failed to deliver the * reply related to the handling of the message */ - void handleMessage(Message message) throws MessagingException; + void handleMessage(M message) throws MessagingException; } diff --git a/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java b/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java index e7ff7385c7..6b8961a983 100644 --- a/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java +++ b/spring-context/src/main/java/org/springframework/messaging/SubscribableChannel.java @@ -25,16 +25,16 @@ package org.springframework.messaging; * @author Mark Fisher * @since 4.0 */ -public interface SubscribableChannel extends MessageChannel { +public interface SubscribableChannel> extends MessageChannel { /** * Register a {@link MessageHandler} as a subscriber to this channel. */ - boolean subscribe(MessageHandler handler); + boolean subscribe(H handler); /** * Remove a {@link MessageHandler} from the subscribers of this channel. */ - boolean unsubscribe(MessageHandler handler); + boolean unsubscribe(H handler); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java index 335655ef76..b0da4182ec 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java @@ -40,7 +40,7 @@ import org.springframework.web.messaging.PubSubHeaders; * @author Rossen Stoyanchev * @since 4.0 */ -public abstract class AbstractPubSubMessageHandler implements MessageHandler { +public abstract class AbstractPubSubMessageHandler implements MessageHandler> { protected final Log logger = LogFactory.getLog(getClass()); @@ -54,11 +54,9 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler { private final PathMatcher pathMatcher = new AntPathMatcher(); - /** * @param publishChannel a channel for publishing messages from within the - * application; this constructor will also automatically subscribe the - * current instance to this channel + * application * * @param clientChannel a channel for sending messages to connected clients. */ @@ -67,9 +65,7 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler { Assert.notNull(publishChannel, "publishChannel is required"); Assert.notNull(clientChannel, "clientChannel is required"); - publishChannel.subscribe(this); this.publishChannel = publishChannel; - this.clientChannel = clientChannel; } @@ -146,7 +142,6 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler { return true; } - @Override public final void handleMessage(Message message) throws MessagingException { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java index 3d0813d744..00ac304a02 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java @@ -23,9 +23,10 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.SubscribableChannel; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.PubSubHeaders; @@ -50,6 +51,8 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { private MessageConverter payloadConverter; + private MessageFactory messageFactory; + private Map>> subscriptionsBySession = new ConcurrentHashMap>>(); @@ -59,13 +62,18 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { super(publishChannel, clientChannel); this.reactor = reactor; this.payloadConverter = new CompositeMessageConverter(null); + this.messageFactory = new GenericMessageFactory(); } + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; + } public void setMessageConverters(List converters) { this.payloadConverter = new CompositeMessageConverter(converters); } + @SuppressWarnings("unchecked") @Override public void handlePublish(Message message) { @@ -77,7 +85,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { // Convert to byte[] payload before the fan-out PubSubHeaders inHeaders = PubSubHeaders.fromMessageHeaders(message.getHeaders()); byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType()); - message = new GenericMessage(payload, message.getHeaders()); + message = messageFactory.createMessage(payload, message.getHeaders()); this.reactor.notify(getPublishKey(inHeaders.getDestination()), Event.wrap(message)); } @@ -109,6 +117,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { Selector selector = new ObjectSelector(getPublishKey(headers.getDestination())); Registration registration = this.reactor.on(selector, new Consumer>>() { + @SuppressWarnings("unchecked") @Override public void accept(Event> event) { Message message = event.getData(); @@ -120,8 +129,9 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { } outHeaders.setSubscriptionId(subscriptionId); Object payload = message.getPayload(); - message = new GenericMessage(payload, outHeaders.toMessageHeaders()); - getClientChannel().send(message); + + Message outMessage = messageFactory.createMessage(payload, outHeaders.toMessageHeaders()); + getClientChannel().send(outMessage); } }); diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java index 474b309617..d26366b89e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java @@ -31,8 +31,10 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.annotation.MessageMapping; import org.springframework.stereotype.Controller; @@ -69,6 +71,8 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler private ReturnValueHandlerComposite returnValueHandlers = new ReturnValueHandlerComposite(); + private MessageFactory messageFactory = new GenericMessageFactory(); + public AnnotationPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) { @@ -79,6 +83,10 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler this.messageConverters = converters; } + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; + } + @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { this.applicationContext = applicationContext; @@ -92,9 +100,16 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler @Override public void afterPropertiesSet() { initHandlerMethods(); - this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(getPublishChannel())); + + MessageChannelArgumentResolver messageChannelArgumentResolver = new MessageChannelArgumentResolver(getPublishChannel()); + messageChannelArgumentResolver.setMessageFactory(messageFactory); + this.argumentResolvers.addResolver(messageChannelArgumentResolver); + this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverters)); - this.returnValueHandlers.addHandler(new MessageReturnValueHandler(getClientChannel())); + + MessageReturnValueHandler messageReturnValueHandler = new MessageReturnValueHandler(getClientChannel()); + messageReturnValueHandler.setMessageFactory(messageFactory); + this.returnValueHandlers.addHandler(messageReturnValueHandler); } protected void initHandlerMethods() { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java index 33518f3343..3e7ba1c58b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java @@ -17,9 +17,10 @@ package org.springframework.web.messaging.service.method; import org.springframework.core.MethodParameter; -import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.util.Assert; import org.springframework.web.messaging.PubSubHeaders; @@ -32,10 +33,16 @@ public class MessageChannelArgumentResolver implements ArgumentResolver { private final MessageChannel publishChannel; + private MessageFactory messageFactory; public MessageChannelArgumentResolver(MessageChannel publishChannel) { Assert.notNull(publishChannel, "publishChannel is required"); this.publishChannel = publishChannel; + this.messageFactory = new GenericMessageFactory(); + } + + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; } @Override @@ -48,19 +55,19 @@ public class MessageChannelArgumentResolver implements ArgumentResolver { final String sessionId = PubSubHeaders.fromMessageHeaders(message.getHeaders()).getSessionId(); - return new MessageChannel() { + return new MessageChannel>() { @Override public boolean send(Message message) { return send(message, -1); } + @SuppressWarnings("unchecked") @Override public boolean send(Message message, long timeout) { PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); headers.setSessionId(sessionId); - message = new GenericMessage(message.getPayload(), headers.toMessageHeaders()); - publishChannel.send(message); + publishChannel.send(messageFactory.createMessage(message.getPayload(), headers.toMessageHeaders())); return true; } }; diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java index 6e607040dd..4a1e69c04b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java @@ -17,9 +17,10 @@ package org.springframework.web.messaging.service.method; import org.springframework.core.MethodParameter; -import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.util.Assert; import org.springframework.web.messaging.PubSubHeaders; @@ -32,12 +33,18 @@ public class MessageReturnValueHandler implements ReturnValueHandler { private final MessageChannel clientChannel; + private MessageFactory messageFactory = new GenericMessageFactory(); + public MessageReturnValueHandler(MessageChannel clientChannel) { Assert.notNull(clientChannel, "clientChannel is required"); this.clientChannel = clientChannel; } + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; + } + @Override public boolean supportsReturnType(MethodParameter returnType) { @@ -56,6 +63,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler { // return Message.class.isAssignableFrom(paramType); } + @SuppressWarnings("unchecked") @Override public void handleReturnValue(Object returnValue, MethodParameter returnType, Message message) throws Exception { @@ -73,7 +81,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler { PubSubHeaders outHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders()); outHeaders.setSessionId(sessionId); outHeaders.setSubscriptionId(subscriptionId); - returnMessage = new GenericMessage(returnMessage.getPayload(), outHeaders.toMessageHeaders()); + returnMessage = messageFactory.createMessage(returnMessage.getPayload(), outHeaders.toMessageHeaders()); this.clientChannel.send(returnMessage); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java index 0c7376ddb4..52086b68c6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java @@ -22,8 +22,8 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; -import org.springframework.messaging.GenericMessage; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.MessageHeaders; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; @@ -48,11 +48,10 @@ public class StompMessageConverter { private static final byte COLON = ':'; - /** * @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String. */ - public Message toMessage(Object stompContent, String sessionId) { + public > M toMessage(Object stompContent, String sessionId, MessageFactory messageFactory) { byte[] byteContent = null; if (stompContent instanceof String) { @@ -103,7 +102,7 @@ public class StompMessageConverter { byte[] payload = new byte[totalLength - payloadIndex]; System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex); - return createMessage(command, stompHeaders.toMessageHeaders(), payload); + return createMessage(command, stompHeaders.toMessageHeaders(), payload, messageFactory); } private int findIndexOfPayload(byte[] bytes) { @@ -133,8 +132,8 @@ public class StompMessageConverter { return index; } - protected Message createMessage(StompCommand command, Map headers, byte[] payload) { - return new GenericMessage(payload, headers); + protected > M createMessage(StompCommand command, Map headers, byte[] payload, MessageFactory messageFactory) { + return messageFactory.createMessage(payload, headers); } public byte[] fromMessage(Message message) { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java index 1d2300e653..5783971b72 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java @@ -23,9 +23,10 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.springframework.http.MediaType; -import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.SubscribableChannel; import org.springframework.util.Assert; import org.springframework.web.messaging.MessageType; @@ -52,17 +53,17 @@ import reactor.tcp.netty.NettyTcpClient; */ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler { - private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); private MessageConverter payloadConverter; + private MessageFactory messageFactory = new GenericMessageFactory(); + private final TcpClient tcpClient; private final Map> connections = new ConcurrentHashMap>(); - public StompRelayPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) { super(publishChannel, clientChannel); @@ -81,6 +82,10 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler this.payloadConverter = new CompositeMessageConverter(converters); } + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; + } + @Override protected Collection getSupportedMessageTypes() { return null; @@ -105,13 +110,14 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler @Override public void accept(TcpConnection connection) { connection.in().consume(new Consumer() { + @SuppressWarnings("unchecked") @Override public void accept(String stompFrame) { if (stompFrame.isEmpty()) { // TODO: why are we getting empty frames? return; } - Message message = stompMessageConverter.toMessage(stompFrame, sessionId); + Message message = stompMessageConverter.toMessage(stompFrame, sessionId, messageFactory); getClientChannel().send(message); } }); @@ -128,6 +134,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler } + @SuppressWarnings("unchecked") private void forwardMessage(Message message, StompCommand command) { StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); @@ -139,7 +146,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler MediaType contentType = stompHeaders.getContentType(); byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType); - Message byteMessage = new GenericMessage(payload, stompHeaders.toMessageHeaders()); + Message byteMessage = messageFactory.createMessage(payload, stompHeaders.toMessageHeaders()); bytesToWrite = this.stompMessageConverter.fromMessage(byteMessage); } catch (Throwable ex) { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java index 53c57fc5ea..e09a905be2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java @@ -25,9 +25,10 @@ import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.MediaType; -import org.springframework.messaging.GenericMessage; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.SubscribableChannel; import org.springframework.util.Assert; @@ -59,7 +60,10 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { private MessageConverter payloadConverter = new CompositeMessageConverter(null); + private MessageFactory messageFactory = new GenericMessageFactory(); + + @SuppressWarnings("unchecked") public StompWebSocketHandler(MessageChannel publishChannel, SubscribableChannel clientChannel) { Assert.notNull(publishChannel, "publishChannel is required"); @@ -74,6 +78,10 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { this.payloadConverter = new CompositeMessageConverter(converters); } + public void setMessageFactory(MessageFactory messageFactory) { + this.messageFactory = messageFactory; + } + public StompMessageConverter getStompMessageConverter() { return this.stompMessageConverter; } @@ -88,11 +96,12 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { this.sessions.put(session.getId(), session); } + @SuppressWarnings("unchecked") @Override protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) { try { String payload = textMessage.getPayload(); - Message message = this.stompMessageConverter.toMessage(payload, session.getId()); + Message message = this.stompMessageConverter.toMessage(payload, session.getId(), messageFactory); // TODO: validate size limits // http://stomp.github.io/stomp-specification-1.2.html#Size_Limits @@ -135,6 +144,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { } } + @SuppressWarnings("unchecked") protected void handleConnect(final WebSocketSession session, Message message) throws IOException { StompHeaders connectStompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); @@ -157,7 +167,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { // TODO: security - Message connectedMessage = new GenericMessage(new byte[0], connectedStompHeaders.toMessageHeaders()); + Message connectedMessage = messageFactory.createMessage(new byte[0], connectedStompHeaders.toMessageHeaders()); byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } @@ -177,12 +187,13 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { protected void handleDisconnect(Message stompMessage) { } + @SuppressWarnings("unchecked") protected void sendErrorMessage(WebSocketSession session, Throwable error) { StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR); stompHeaders.setMessage(error.getMessage()); - Message errorMessage = new GenericMessage(new byte[0], stompHeaders.toMessageHeaders()); + Message errorMessage = messageFactory.createMessage(new byte[0], stompHeaders.toMessageHeaders()); byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage); try { @@ -200,9 +211,10 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { }*/ - private final class ClientMessageConsumer implements MessageHandler { + private final class ClientMessageConsumer implements MessageHandler> { + @SuppressWarnings("unchecked") @Override public void handleMessage(Message message) { @@ -235,7 +247,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { try { Map messageHeaders = stompHeaders.toMessageHeaders(); - Message byteMessage = new GenericMessage(payload, messageHeaders); + Message byteMessage = messageFactory.createMessage(payload, messageHeaders); byte[] bytes = getStompMessageConverter().fromMessage(byteMessage); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/ReactorMessageChannel.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/ReactorMessageChannel.java index 4e6d25d5e9..c5a99c124c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/support/ReactorMessageChannel.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/ReactorMessageChannel.java @@ -36,7 +36,7 @@ import reactor.fn.selector.ObjectSelector; * @author Rossen Stoyanchev * @since 4.0 */ -public class ReactorMessageChannel implements SubscribableChannel { +public class ReactorMessageChannel implements SubscribableChannel, MessageHandler>> { private static Log logger = LogFactory.getLog(ReactorMessageChannel.class); @@ -125,6 +125,7 @@ public class ReactorMessageChannel implements SubscribableChannel { this.handler = handler; } + @SuppressWarnings("unchecked") @Override public void accept(Event> event) { Message message = event.getData(); diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java index e2eaed097f..366e9f758f 100644 --- a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java @@ -19,7 +19,9 @@ import java.util.Collections; import org.junit.Before; import org.junit.Test; +import org.springframework.messaging.GenericMessageFactory; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageFactory; import org.springframework.messaging.MessageHeaders; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.stomp.StompCommand; @@ -35,19 +37,22 @@ public class StompMessageConverterTests { private StompMessageConverter converter; + private MessageFactory messageFactory = new GenericMessageFactory(); + @Before public void setup() { this.converter = new StompMessageConverter(); } + @SuppressWarnings("unchecked") @Test public void connectFrame() throws Exception { String accept = "accept-version:1.1\n"; String host = "host:github.org\n"; String frame = "\n\n\nCONNECT\n" + accept + host + "\n"; - Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); + Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123", messageFactory); assertEquals(0, message.getPayload().length); @@ -71,13 +76,14 @@ public class StompMessageConverterTests { assertTrue(convertedBack.contains(host)); } + @SuppressWarnings("unchecked") @Test public void connectWithEscapes() throws Exception { String accept = "accept-version:1.1\n"; String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; String frame = "CONNECT\n" + accept + host + "\n"; - Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123"); + Message message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123", messageFactory); assertEquals(0, message.getPayload().length); @@ -93,13 +99,14 @@ public class StompMessageConverterTests { assertTrue(convertedBack.contains(host)); } + @SuppressWarnings("unchecked") @Test public void connectCR12() throws Exception { String accept = "accept-version:1.2\n"; String host = "host:github.org\n"; String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); + Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123", messageFactory); assertEquals(0, message.getPayload().length); @@ -115,13 +122,14 @@ public class StompMessageConverterTests { assertTrue(convertedBack.contains(host)); } + @SuppressWarnings("unchecked") @Test public void connectWithEscapesAndCR12() throws Exception { String accept = "accept-version:1.1\n"; String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123"); + Message message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123", messageFactory); assertEquals(0, message.getPayload().length);