diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubChannelRegistry.java b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubChannelRegistry.java new file mode 100644 index 00000000000..db27c0f561a --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubChannelRegistry.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; + + +/** + * @author Rossen Stoyanchev + * @since 4.0 + */ +public interface PubSubChannelRegistry { + + MessageChannel> getClientInputChannel(); + + MessageChannel> getClientOutputChannel(); + + MessageChannel> getMessageBrokerChannel(); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubChannelRegistryAware.java b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubChannelRegistryAware.java new file mode 100644 index 00000000000..f12ab764bf0 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubChannelRegistryAware.java @@ -0,0 +1,28 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging; + + +/** + * @author Rossen Stoyanchev + * @since 4.0 + */ +public interface PubSubChannelRegistryAware { + + void setPubSubChannelRegistry(PubSubChannelRegistry registry); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java index 873d9ef766f..4e8a0eb7614 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java @@ -24,12 +24,9 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessagingException; -import org.springframework.messaging.SubscribableChannel; import org.springframework.util.AntPathMatcher; -import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.PathMatcher; import org.springframework.web.messaging.MessageType; @@ -44,10 +41,6 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler allowedDestinations = new ArrayList(); private final List disallowedDestinations = new ArrayList(); @@ -55,29 +48,6 @@ public abstract class AbstractPubSubMessageHandler implements MessageHandler> clientChannel; private final Reactor reactor; @@ -52,14 +57,22 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { private Map>> subscriptionsBySession = new ConcurrentHashMap>>(); - public ReactorPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel, - Reactor reactor) { - - super(publishChannel, clientChannel); + /** + * @param clientChannel a channel for broadcasting messages to subscribed clients + * @param reactor + */ + public ReactorPubSubMessageHandler(Reactor reactor) { + Assert.notNull(reactor, "reactor is required"); this.reactor = reactor; this.payloadConverter = new CompositeMessageConverter(null); } + + @Override + public void setPubSubChannelRegistry(PubSubChannelRegistry registry) { + this.clientChannel = registry.getClientOutputChannel(); + } + public void setMessageConverters(List converters) { this.payloadConverter = new CompositeMessageConverter(converters); } @@ -148,7 +161,6 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { this.subscriptionId = subscriptionId; } - @SuppressWarnings("unchecked") @Override public void accept(Event> event) { @@ -160,7 +172,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubSubMessageHandler { Message clientMessage = MessageBuilder.fromPayloadAndHeaders(sentMessage.getPayload(), clientHeaders.toMessageHeaders()).build(); - getClientChannel().send(clientMessage); + clientChannel.send(clientMessage); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java index 7587fe79d24..5517a456a88 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java @@ -32,13 +32,13 @@ import org.springframework.context.ApplicationContextAware; import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.annotation.MessageMapping; import org.springframework.stereotype.Controller; import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils.MethodFilter; import org.springframework.web.messaging.MessageType; +import org.springframework.web.messaging.PubSubChannelRegistry; +import org.springframework.web.messaging.PubSubChannelRegistryAware; import org.springframework.web.messaging.PubSubHeaders; import org.springframework.web.messaging.annotation.SubscribeEvent; import org.springframework.web.messaging.annotation.UnsubscribeEvent; @@ -53,7 +53,7 @@ import org.springframework.web.method.HandlerMethodSelector; * @since 4.0 */ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler - implements ApplicationContextAware, InitializingBean { + implements ApplicationContextAware, InitializingBean, PubSubChannelRegistryAware { private List messageConverters; @@ -70,8 +70,10 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler private ReturnValueHandlerComposite returnValueHandlers = new ReturnValueHandlerComposite(); - public AnnotationPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) { - super(publishChannel, clientChannel); + @Override + public void setPubSubChannelRegistry(PubSubChannelRegistry registry) { + this.argumentResolvers.setPubSubChannelRegistry(registry); + this.returnValueHandlers.setPubSubChannelRegistry(registry); } public void setMessageConverters(List converters) { @@ -93,10 +95,10 @@ public class AnnotationPubSubMessageHandler extends AbstractPubSubMessageHandler initHandlerMethods(); - this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(getPublishChannel())); + this.argumentResolvers.addResolver(new MessageChannelArgumentResolver()); this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverters)); - this.returnValueHandlers.addHandler(new MessageReturnValueHandler(getClientChannel())); + this.returnValueHandlers.addHandler(new MessageReturnValueHandler()); } protected void initHandlerMethods() { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/ArgumentResolverComposite.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/ArgumentResolverComposite.java index cce2a091b18..3e22e9f7b30 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/ArgumentResolverComposite.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/ArgumentResolverComposite.java @@ -27,6 +27,8 @@ import org.apache.commons.logging.LogFactory; import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.util.Assert; +import org.springframework.web.messaging.PubSubChannelRegistry; +import org.springframework.web.messaging.PubSubChannelRegistryAware; /** * Resolves method parameters by delegating to a list of registered @@ -81,9 +83,9 @@ public class ArgumentResolverComposite implements ArgumentResolver { private ArgumentResolver getArgumentResolver(MethodParameter parameter) { ArgumentResolver result = this.argumentResolverCache.get(parameter); if (result == null) { - for (ArgumentResolver methodArgumentResolver : this.argumentResolvers) { - if (methodArgumentResolver.supportsParameter(parameter)) { - result = methodArgumentResolver; + for (ArgumentResolver resolver : this.argumentResolvers) { + if (resolver.supportsParameter(parameter)) { + result = resolver; this.argumentResolverCache.put(parameter, result); break; } @@ -103,8 +105,7 @@ public class ArgumentResolverComposite implements ArgumentResolver { /** * Add the given {@link ArgumentResolver}s. */ - public ArgumentResolverComposite addResolvers( - List argumentResolvers) { + public ArgumentResolverComposite addResolvers(List argumentResolvers) { if (argumentResolvers != null) { for (ArgumentResolver resolver : argumentResolvers) { this.argumentResolvers.add(resolver); @@ -113,4 +114,12 @@ public class ArgumentResolverComposite implements ArgumentResolver { return this; } + public void setPubSubChannelRegistry(PubSubChannelRegistry registry) { + for (ArgumentResolver resolver : this.argumentResolvers) { + if (resolver instanceof PubSubChannelRegistryAware) { + ((PubSubChannelRegistryAware) resolver).setPubSubChannelRegistry(registry); + } + } + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java index 6b3a762e5a5..d4451a8c6e1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java @@ -19,23 +19,25 @@ package org.springframework.web.messaging.service.method; import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; +import org.springframework.web.messaging.PubSubChannelRegistry; +import org.springframework.web.messaging.PubSubChannelRegistryAware; import org.springframework.web.messaging.PubSubHeaders; +import org.springframework.web.messaging.support.SessionMessageChannel; /** * @author Rossen Stoyanchev * @since 4.0 */ -public class MessageChannelArgumentResolver implements ArgumentResolver { +public class MessageChannelArgumentResolver implements ArgumentResolver, PubSubChannelRegistryAware { - private final MessageChannel publishChannel; + private MessageChannel> messageBrokerChannel; - public MessageChannelArgumentResolver(MessageChannel publishChannel) { - Assert.notNull(publishChannel, "publishChannel is required"); - this.publishChannel = publishChannel; + @Override + public void setPubSubChannelRegistry(PubSubChannelRegistry registry) { + this.messageBrokerChannel = registry.getMessageBrokerChannel(); } @Override @@ -45,27 +47,9 @@ public class MessageChannelArgumentResolver implements ArgumentResolver { @Override public Object resolveArgument(MethodParameter parameter, Message message) throws Exception { - + Assert.notNull(this.messageBrokerChannel, "messageBrokerChannel is required"); final String sessionId = PubSubHeaders.fromMessageHeaders(message.getHeaders()).getSessionId(); - - return new MessageChannel>() { - - @Override - public boolean send(Message message) { - return send(message, -1); - } - - @SuppressWarnings("unchecked") - @Override - public boolean send(Message message, long timeout) { - PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); - headers.setSessionId(sessionId); - MessageBuilder messageToSend = MessageBuilder.fromPayloadAndHeaders( - message.getPayload(), headers.toMessageHeaders()); - publishChannel.send(messageToSend.build()); - return true; - } - }; + return new SessionMessageChannel(this.messageBrokerChannel, sessionId); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java index 4d53b339044..2e4d9a4738f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java @@ -19,10 +19,10 @@ package org.springframework.web.messaging.service.method; import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.MessageFactory; -import org.springframework.messaging.support.GenericMessageFactory; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; +import org.springframework.web.messaging.PubSubChannelRegistry; +import org.springframework.web.messaging.PubSubChannelRegistryAware; import org.springframework.web.messaging.PubSubHeaders; @@ -30,23 +30,16 @@ import org.springframework.web.messaging.PubSubHeaders; * @author Rossen Stoyanchev * @since 4.0 */ -public class MessageReturnValueHandler implements ReturnValueHandler { +public class MessageReturnValueHandler implements ReturnValueHandler, PubSubChannelRegistryAware { - private final MessageChannel clientChannel; - - private MessageFactory messageFactory = new GenericMessageFactory(); + private MessageChannel clientChannel; - public MessageReturnValueHandler(MessageChannel clientChannel) { - Assert.notNull(clientChannel, "clientChannel is required"); - this.clientChannel = clientChannel; + @Override + public void setPubSubChannelRegistry(PubSubChannelRegistry registry) { + this.clientChannel = registry.getClientOutputChannel(); } - public void setMessageFactory(MessageFactory messageFactory) { - this.messageFactory = messageFactory; - } - - @Override public boolean supportsReturnType(MethodParameter returnType) { Class paramType = returnType.getParameterType(); @@ -69,6 +62,8 @@ public class MessageReturnValueHandler implements ReturnValueHandler { public void handleReturnValue(Object returnValue, MethodParameter returnType, Message message) throws Exception { + Assert.notNull(this.clientChannel, "No clientChannel to send messages to"); + Message returnMessage = (Message) returnValue; if (returnMessage == null) { return; @@ -91,7 +86,8 @@ public class MessageReturnValueHandler implements ReturnValueHandler { returnHeaders.setSessionId(sessionId); returnHeaders.setSubscriptionId(subscriptionId); - return MessageBuilder.fromPayloadAndHeaders(returnMessage.getPayload(), returnHeaders.toMessageHeaders()).build(); + Object payload = returnMessage.getPayload(); + return MessageBuilder.fromPayloadAndHeaders(payload, returnHeaders.toMessageHeaders()).build(); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/ReturnValueHandlerComposite.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/ReturnValueHandlerComposite.java index 195e755294c..f66044d45fb 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/ReturnValueHandlerComposite.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/ReturnValueHandlerComposite.java @@ -22,6 +22,8 @@ import java.util.List; import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.util.Assert; +import org.springframework.web.messaging.PubSubChannelRegistry; +import org.springframework.web.messaging.PubSubChannelRegistryAware; /** @@ -77,4 +79,12 @@ public class ReturnValueHandlerComposite implements ReturnValueHandler { handler.handleReturnValue(returnValue, returnType, message); } + public void setPubSubChannelRegistry(PubSubChannelRegistry registry) { + for (ReturnValueHandler handler : this.returnValueHandlers) { + if (handler instanceof PubSubChannelRegistryAware) { + ((PubSubChannelRegistryAware) handler).setPubSubChannelRegistry(registry); + } + } + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java index 7f7c3ae6e3e..a0488b08371 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java @@ -25,10 +25,11 @@ import java.util.concurrent.ConcurrentHashMap; import org.springframework.http.MediaType; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.web.messaging.MessageType; +import org.springframework.web.messaging.PubSubChannelRegistry; +import org.springframework.web.messaging.PubSubChannelRegistryAware; import org.springframework.web.messaging.PubSubHeaders; import org.springframework.web.messaging.converter.CompositeMessageConverter; import org.springframework.web.messaging.converter.MessageConverter; @@ -50,7 +51,10 @@ import reactor.tcp.netty.NettyTcpClient; * @author Rossen Stoyanchev * @since 4.0 */ -public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler { +public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler + implements PubSubChannelRegistryAware { + + private MessageChannel> clientChannel; private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); @@ -61,9 +65,12 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler private final Map> connections = new ConcurrentHashMap>(); - public StompRelayPubSubMessageHandler(SubscribableChannel publishChannel, MessageChannel clientChannel) { - super(publishChannel, clientChannel); + /** + * @param clientChannel a channel for sending messages from the remote message broker + * back to clients + */ + public StompRelayPubSubMessageHandler() { this.tcpClient = new TcpClient.Spec(NettyTcpClient.class) .using(new Environment()) @@ -75,6 +82,11 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler } + @Override + public void setPubSubChannelRegistry(PubSubChannelRegistry registry) { + this.clientChannel = registry.getClientOutputChannel(); + } + public void setMessageConverters(List converters) { this.payloadConverter = new CompositeMessageConverter(converters); } @@ -103,7 +115,6 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler @Override public void accept(TcpConnection connection) { connection.in().consume(new Consumer() { - @SuppressWarnings("unchecked") @Override public void accept(String stompFrame) { if (stompFrame.isEmpty()) { @@ -111,7 +122,7 @@ public class StompRelayPubSubMessageHandler extends AbstractPubSubMessageHandler return; } Message message = stompMessageConverter.toMessage(stompFrame, sessionId); - getClientChannel().send(message); + clientChannel.send(message); } }); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java index 3a03c3a68e2..66d923abfba 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java @@ -28,10 +28,10 @@ import org.springframework.http.MediaType; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandler; -import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.support.MessageBuilder; -import org.springframework.util.Assert; import org.springframework.web.messaging.MessageType; +import org.springframework.web.messaging.PubSubChannelRegistry; +import org.springframework.web.messaging.PubSubChannelRegistryAware; import org.springframework.web.messaging.converter.CompositeMessageConverter; import org.springframework.web.messaging.converter.MessageConverter; import org.springframework.web.messaging.stomp.StompCommand; @@ -42,21 +42,21 @@ import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; +import reactor.util.Assert; + /** * @author Rossen Stoyanchev * @since 4.0 */ -public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { +public class StompWebSocketHandler extends TextWebSocketHandlerAdapter + implements MessageHandler>, PubSubChannelRegistryAware { - /** - * - */ private static final byte[] EMPTY_PAYLOAD = new byte[0]; private static Log logger = LogFactory.getLog(StompWebSocketHandler.class); - private final MessageChannel publishChannel; + private MessageChannel outputChannel; private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); @@ -65,17 +65,11 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { private MessageConverter payloadConverter = new CompositeMessageConverter(null); - @SuppressWarnings("unchecked") - public StompWebSocketHandler(MessageChannel publishChannel, SubscribableChannel clientChannel) { - - Assert.notNull(publishChannel, "publishChannel is required"); - Assert.notNull(clientChannel, "clientChannel is required"); - - this.publishChannel = publishChannel; - clientChannel.subscribe(new ClientMessageConsumer()); + @Override + public void setPubSubChannelRegistry(PubSubChannelRegistry registry) { + this.outputChannel = registry.getClientInputChannel(); } - public void setMessageConverters(List converters) { this.payloadConverter = new CompositeMessageConverter(converters); } @@ -91,9 +85,13 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { + Assert.notNull(this.outputChannel, "No output channel for STOMP messages."); this.sessions.put(session.getId(), session); } + /** + * Handle incoming WebSocket messages from clients. + */ @SuppressWarnings("unchecked") @Override protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) { @@ -115,7 +113,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { handleConnect(session, message); } else if (MessageType.MESSAGE.equals(messageType)) { - handleMessage(message); + handlePublish(message); } else if (MessageType.SUBSCRIBE.equals(messageType)) { handleSubscribe(message); @@ -126,7 +124,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { else if (MessageType.DISCONNECT.equals(messageType)) { handleDisconnect(message); } - this.publishChannel.send(message); + this.outputChannel.send(message); } catch (Throwable t) { logger.error("Terminating STOMP session due to failure to send message: ", t); @@ -170,6 +168,9 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } + protected void handlePublish(Message stompMessage) { + } + protected void handleSubscribe(Message message) { // TODO: need a way to communicate back if subscription was successfully created or // not in which case an ERROR should be sent back and close the connection @@ -179,9 +180,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { protected void handleUnsubscribe(Message message) { } - protected void handleMessage(Message stompMessage) { - } - protected void handleDisconnect(Message stompMessage) { } @@ -202,62 +200,62 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter { } } -/* @Override + /* + @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { this.sessions.remove(session.getId()); eventBus.send(AbstractMessageService.CLIENT_CONNECTION_CLOSED_KEY, session.getId()); - }*/ + } + */ + /** + * Handle STOMP messages going back out to WebSocket clients. + */ + @Override + public void handleMessage(Message message) { - private final class ClientMessageConsumer implements MessageHandler> { + StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders()); + headers.setStompCommandIfNotSet(StompCommand.MESSAGE); + if (StompCommand.CONNECTED.equals(headers.getStompCommand())) { + // Ignore for now since we already sent it + return; + } - @Override - public void handleMessage(Message message) { + String sessionId = headers.getSessionId(); + if (sessionId == null) { + logger.error("No \"sessionId\" header in message: " + message); + } + WebSocketSession session = getWebSocketSession(sessionId); + if (session == null) { + logger.error("Session not found: " + message); + } - StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders()); - headers.setStompCommandIfNotSet(StompCommand.MESSAGE); + byte[] payload; + try { + MediaType contentType = headers.getContentType(); + payload = payloadConverter.convertToPayload(message.getPayload(), contentType); + } + catch (Throwable t) { + logger.error("Failed to send " + message, t); + return; + } - if (StompCommand.CONNECTED.equals(headers.getStompCommand())) { - // Ignore for now since we already sent it - return; - } - - String sessionId = headers.getSessionId(); - if (sessionId == null) { - logger.error("No \"sessionId\" header in message: " + message); - } - WebSocketSession session = getWebSocketSession(sessionId); - if (session == null) { - logger.error("Session not found: " + message); - } - - byte[] payload; - try { - MediaType contentType = headers.getContentType(); - payload = payloadConverter.convertToPayload(message.getPayload(), contentType); - } - catch (Throwable t) { - logger.error("Failed to send " + message, t); - return; - } - - try { - Message byteMessage = MessageBuilder.fromPayloadAndHeaders(payload, - headers.toMessageHeaders()).build(); - byte[] bytes = getStompMessageConverter().fromMessage(byteMessage); - session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); - } - catch (Throwable t) { - sendErrorMessage(session, t); - } - finally { - if (StompCommand.ERROR.equals(headers.getStompCommand())) { - try { - session.close(CloseStatus.PROTOCOL_ERROR); - } - catch (IOException e) { - } + try { + Message byteMessage = MessageBuilder.fromPayloadAndHeaders(payload, + headers.toMessageHeaders()).build(); + byte[] bytes = getStompMessageConverter().fromMessage(byteMessage); + session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); + } + catch (Throwable t) { + sendErrorMessage(session, t); + } + finally { + if (StompCommand.ERROR.equals(headers.getStompCommand())) { + try { + session.close(CloseStatus.PROTOCOL_ERROR); + } + catch (IOException e) { } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubChannelRegistryBuilder.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubChannelRegistryBuilder.java new file mode 100644 index 00000000000..58828c64023 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubChannelRegistryBuilder.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.support; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.SubscribableChannel; +import org.springframework.util.Assert; +import org.springframework.web.messaging.PubSubChannelRegistry; +import org.springframework.web.messaging.PubSubChannelRegistryAware; + + +/** + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class PubSubChannelRegistryBuilder { + + private SubscribableChannel, MessageHandler>> clientInputChannel; + + private SubscribableChannel, MessageHandler>> clientOutputChannel; + + private SubscribableChannel, MessageHandler>> messageBrokerChannel; + + private Set>> messageHandlers = new HashSet>>(); + + + public PubSubChannelRegistryBuilder( + SubscribableChannel, MessageHandler>> clientOutputChannel, + MessageHandler> clientGateway) { + + Assert.notNull(clientOutputChannel, "clientOutputChannel is required"); + Assert.notNull(clientGateway, "clientGateway is required"); + + this.clientOutputChannel = clientOutputChannel; + this.clientOutputChannel.subscribe(clientGateway); + this.messageHandlers.add(clientGateway); + } + + + public static PubSubChannelRegistryBuilder clientGateway( + SubscribableChannel, MessageHandler>> clientOutputChannel, + MessageHandler> clientGateway) { + + return new PubSubChannelRegistryBuilder(clientOutputChannel, clientGateway); + } + + + public PubSubChannelRegistryBuilder clientMessageHandlers( + SubscribableChannel, MessageHandler>> clientInputChannel, + List>> handlers) { + + Assert.notNull(clientInputChannel, "clientInputChannel is required"); + this.clientInputChannel = clientInputChannel; + + for (MessageHandler> handler : handlers) { + this.clientInputChannel.subscribe(handler); + this.messageHandlers.add(handler); + } + + return this; + } + + public PubSubChannelRegistryBuilder messageBrokerGateway( + SubscribableChannel, MessageHandler>> messageBrokerChannel, + MessageHandler> messageBrokerGateway) { + + Assert.notNull(messageBrokerChannel, "messageBrokerChannel is required"); + Assert.notNull(messageBrokerGateway, "messageBrokerGateway is required"); + + this.messageBrokerChannel = messageBrokerChannel; + this.messageBrokerChannel.subscribe(messageBrokerGateway); + this.messageHandlers.add(messageBrokerGateway); + + return this; + } + + public PubSubChannelRegistry build() { + + PubSubChannelRegistry registry = new PubSubChannelRegistry() { + + @Override + public MessageChannel> getClientInputChannel() { + return clientInputChannel; + } + + @Override + public MessageChannel> getClientOutputChannel() { + return clientOutputChannel; + } + + @Override + public MessageChannel> getMessageBrokerChannel() { + return messageBrokerChannel; + } + }; + + for (MessageHandler> handler : this.messageHandlers) { + if (handler instanceof PubSubChannelRegistryAware) { + ((PubSubChannelRegistryAware) handler).setPubSubChannelRegistry(registry); + } + } + + return registry; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/SessionMessageChannel.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/SessionMessageChannel.java new file mode 100644 index 00000000000..998a42e3cc9 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/SessionMessageChannel.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.support; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.web.messaging.PubSubHeaders; + +import reactor.util.Assert; + + +/** + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class SessionMessageChannel implements MessageChannel> { + + private MessageChannel> delegate; + + private final String sessionId; + + + public SessionMessageChannel(MessageChannel> delegate, String sessionId) { + Assert.notNull(delegate, "delegate is required"); + Assert.notNull(sessionId, "sessionId is required"); + this.sessionId = sessionId; + this.delegate = delegate; + } + + @Override + public boolean send(Message message) { + return send(message, -1); + } + + @Override + public boolean send(Message message, long timeout) { + PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + headers.setSessionId(this.sessionId); + MessageBuilder messageToSend = MessageBuilder.fromPayloadAndHeaders( + message.getPayload(), headers.toMessageHeaders()); + this.delegate.send(messageToSend.build()); + return true; + } +}