diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java index 443f7e5a859..7e3d7f12b69 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -69,14 +69,17 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { /** - * For internal use. - *

The original destination used by a client when subscribing. Such a - * destination may have been modified (e.g. user destinations) on the server - * side. This header provides a hint so messages sent to clients may have - * a destination matching to their original subscription. + * A header for internal use with "user" destinations where we need to + * restore the destination prior to sending messages to clients. */ public static final String ORIGINAL_DESTINATION = "simpOrigDestination"; + /** + * A header that indicates to the broker that the sender will ignore errors. + * The header is simply checked for presence or absence. + */ + public static final String IGNORE_ERROR = "simpIgnoreError"; + /** * A constructor for creating new message headers. diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java index 2e4faed3185..3964f7339ca 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java @@ -17,7 +17,10 @@ package org.springframework.messaging.simp.config; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.BeanInitializationException; @@ -25,6 +28,7 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.context.annotation.Bean; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHandler; import org.springframework.messaging.converter.ByteArrayMessageConverter; import org.springframework.messaging.converter.CompositeMessageConverter; import org.springframework.messaging.converter.DefaultContentTypeResolver; @@ -37,6 +41,7 @@ import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler; import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler; +import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; import org.springframework.messaging.simp.user.DefaultUserDestinationResolver; import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; import org.springframework.messaging.simp.user.UserDestinationMessageHandler; @@ -278,13 +283,26 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC @Bean public AbstractBrokerMessageHandler stompBrokerRelayMessageHandler() { - AbstractBrokerMessageHandler handler = getBrokerRegistry().getStompBrokerRelay(brokerChannel()); - return (handler != null ? handler : new NoOpBrokerMessageHandler()); + StompBrokerRelayMessageHandler handler = getBrokerRegistry().getStompBrokerRelay(brokerChannel()); + if (handler == null) { + return new NoOpBrokerMessageHandler(); + } + String destination = getBrokerRegistry().getUserDestinationBroadcast(); + if (destination != null) { + Map map = new HashMap(1); + map.put(destination, userDestinationMessageHandler()); + handler.setSystemSubscriptions(map); + } + return handler; } @Bean public UserDestinationMessageHandler userDestinationMessageHandler() { - return new UserDestinationMessageHandler(clientInboundChannel(), brokerChannel(), userDestinationResolver()); + UserDestinationMessageHandler handler = new UserDestinationMessageHandler(clientInboundChannel(), + brokerChannel(), userDestinationResolver()); + String destination = getBrokerRegistry().getUserDestinationBroadcast(); + handler.setUserDestinationBroadcast(destination); + return handler; } @Bean diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java index 669a751ba04..6970905322e 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java @@ -49,6 +49,8 @@ public class MessageBrokerRegistry { private String userDestinationPrefix; + private String userDestinationBroadcast; + private PathMatcher pathMatcher; @@ -137,6 +139,24 @@ public class MessageBrokerRegistry { return this.userDestinationPrefix; } + /** + * Set a destination to broadcast messages to that remain unresolved because + * the user is not connected. In a multi-application server scenario this + * gives other application servers a chance to try. + *

Note: this option applies only when the + * {@link #enableStompBrokerRelay "broker relay"} is enabled. + *

By default this is not set. + * @param destination the destination to forward unresolved + * messages to, e.g. "/topic/unresolved-user-destination". + */ + public void setUserDestinationBroadcast(String destination) { + this.userDestinationBroadcast = destination; + } + + protected String getUserDestinationBroadcast() { + return this.userDestinationBroadcast; + } + /** * Configure the PathMatcher to use to match the destinations of incoming * messages to {@code @MessageMapping} and {@code @SubscribeMapping} methods. diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index fc74ddc3601..dd30ba706b4 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -17,6 +17,7 @@ package org.springframework.messaging.simp.stomp; import java.util.Collection; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; @@ -26,6 +27,7 @@ import java.util.concurrent.atomic.AtomicInteger; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageDeliveryException; +import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -112,6 +114,8 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private String virtualHost; + private final Map systemSubscriptions = new HashMap(4); + private TcpOperations tcpClient; private MessageHeaderInitializer headerInitializer; @@ -281,6 +285,27 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler return this.systemPasscode; } + /** + * Configure one more destinations to subscribe to on the shared "system" + * connection along with MessageHandler's to handle received messages. + *

This is for internal use in a multi-application server scenario where + * servers forward messages to each other (e.g. unresolved user destinations). + * @param subscriptions the destinations to subscribe to. + */ + public void setSystemSubscriptions(Map subscriptions) { + this.systemSubscriptions.clear(); + if (subscriptions != null) { + this.systemSubscriptions.putAll(subscriptions); + } + } + + /** + * Return the configured map with subscriptions on the "system" connection. + */ + public Map getSystemSubscriptions() { + return this.systemSubscriptions; + } + /** * Set the value of the "host" header to use in STOMP CONNECT frames. When this * property is configured, a "host" header will be added to every STOMP frame sent to @@ -532,6 +557,10 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler return this.sessionId; } + protected TcpConnection getTcpConnection() { + return this.tcpConnection; + } + @Override public void afterConnected(TcpConnection connection) { if (logger.isDebugEnabled()) { @@ -579,13 +608,14 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler headerAccessor.setUser(this.connectHeaders.getUser()); headerAccessor.setMessage(errorText); Message errorMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders()); - headerAccessor.setImmutable(); - sendMessageToClient(errorMessage); + handleInboundMessage(errorMessage); } } - protected void sendMessageToClient(Message message) { + protected void handleInboundMessage(Message message) { if (this.isRemoteClientSession) { + StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + accessor.setImmutable(); StompBrokerRelayMessageHandler.this.getClientOutboundChannel().send(message); } } @@ -610,8 +640,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler logger.trace("Received " + accessor.getDetailedLogMessage(message.getPayload())); } - accessor.setImmutable(); - sendMessageToClient(message); + handleInboundMessage(message); } /** @@ -825,7 +854,6 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } } - private class SystemStompConnectionHandler extends StompConnectionHandler { public SystemStompConnectionHandler(StompHeaderAccessor connectHeaders) { @@ -839,6 +867,63 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } super.afterStompConnected(connectedHeaders); publishBrokerAvailableEvent(); + sendSystemSubscriptions(); + } + + private void sendSystemSubscriptions() { + int i = 0; + for (String destination : getSystemSubscriptions().keySet()) { + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); + accessor.setSubscriptionId(String.valueOf(i++)); + accessor.setDestination(destination); + if (logger.isDebugEnabled()) { + logger.debug("Subscribing to " + destination + " on \"system\" connection."); + } + TcpConnection conn = getTcpConnection(); + if (conn != null) { + MessageHeaders headers = accessor.getMessageHeaders(); + conn.send(MessageBuilder.createMessage(EMPTY_PAYLOAD, headers)).addCallback( + new ListenableFutureCallback() { + public void onSuccess(Void result) { + } + public void onFailure(Throwable ex) { + String error = "Failed to subscribe in \"system\" session."; + handleTcpConnectionFailure(error, ex); + } + }); + } + } + } + + @Override + protected void handleInboundMessage(Message message) { + StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + if (StompCommand.MESSAGE.equals(accessor.getCommand())) { + String destination = accessor.getDestination(); + if (destination == null) { + if (logger.isDebugEnabled()) { + logger.debug("Got message on \"system\" connection, with no destination: " + + accessor.getDetailedLogMessage(message.getPayload())); + } + return; + } + if (!getSystemSubscriptions().containsKey(destination)) { + if (logger.isDebugEnabled()) { + logger.debug("Got message on \"system\" connection with no handler: " + + accessor.getDetailedLogMessage(message.getPayload())); + } + return; + } + try { + MessageHandler handler = getSystemSubscriptions().get(destination); + handler.handleMessage(message); + } + catch (Throwable ex) { + if (logger.isDebugEnabled()) { + logger.debug("Error while handling message on \"system\" connection.", ex); + } + } + } } @Override @@ -857,7 +942,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler public ListenableFuture forward(Message message, StompHeaderAccessor accessor) { try { ListenableFuture future = super.forward(message, accessor); - future.get(); + if (message.getHeaders().get(SimpMessageHeaderAccessor.IGNORE_ERROR) == null) { + future.get(); + } return future; } catch (Throwable ex) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java index 2223af3dba0..9cf7730e5ff 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java @@ -16,7 +16,11 @@ package org.springframework.messaging.simp.user; -import java.util.Set; +import static org.springframework.messaging.simp.SimpMessageHeaderAccessor.*; +import static org.springframework.messaging.support.MessageHeaderAccessor.getAccessor; + +import java.util.Arrays; +import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -24,6 +28,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.context.SmartLifecycle; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessagingException; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.core.MessageSendingOperations; @@ -33,6 +38,7 @@ import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageHeaderInitializer; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * {@code MessageHandler} with support for "user" destinations. @@ -53,9 +59,11 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec private final SubscribableChannel brokerChannel; + private final UserDestinationResolver destinationResolver; + private final MessageSendingOperations messagingTemplate; - private final UserDestinationResolver destinationResolver; + private BroadcastHandler broadcastHandler; private MessageHeaderInitializer headerInitializer; @@ -93,6 +101,25 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec return this.destinationResolver; } + /** + * Set a destination to broadcast messages to that remain unresolved because + * the user is not connected. In a multi-application server scenario this + * gives other application servers a chance to try. + *

By default this is not set. + * @param destination the target destination. + */ + public void setUserDestinationBroadcast(String destination) { + this.broadcastHandler = (StringUtils.hasText(destination) ? + new BroadcastHandler(this.messagingTemplate, destination) : null); + } + + /** + * Return the configured destination for unresolved messages. + */ + public String getUserDestinationBroadcast() { + return (this.broadcastHandler != null ? this.broadcastHandler.getBroadcastDestination() : null); + } + /** * Return the messaging template used to send resolved messages to the * broker channel. @@ -164,29 +191,35 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec @Override public void handleMessage(Message message) throws MessagingException { + if (this.broadcastHandler != null) { + message = this.broadcastHandler.preHandle(message); + if (message == null) { + return; + } + } UserDestinationResult result = this.destinationResolver.resolveDestination(message); if (result == null) { return; } - Set destinations = result.getTargetDestinations(); - if (destinations.isEmpty()) { + if (result.getTargetDestinations().isEmpty()) { if (logger.isTraceEnabled()) { - logger.trace("No user destinations found for " + result.getSourceDestination()); + logger.trace("No active sessions for user destination: " + result.getSourceDestination()); + } + if (this.broadcastHandler != null) { + this.broadcastHandler.handleUnresolved(message); } return; } - if (SimpMessageType.MESSAGE.equals(SimpMessageHeaderAccessor.getMessageType(message.getHeaders()))) { - SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(message); - initHeaders(accessor); - String header = SimpMessageHeaderAccessor.ORIGINAL_DESTINATION; - accessor.setNativeHeader(header, result.getSubscribeDestination()); - message = MessageBuilder.createMessage(message.getPayload(), accessor.getMessageHeaders()); + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(message); + initHeaders(accessor); + accessor.setNativeHeader(ORIGINAL_DESTINATION, result.getSubscribeDestination()); + accessor.setLeaveMutable(true); + message = MessageBuilder.createMessage(message.getPayload(), accessor.getMessageHeaders()); + if (logger.isTraceEnabled()) { + logger.trace("Translated " + result.getSourceDestination() + " -> " + result.getTargetDestinations()); } - if (logger.isDebugEnabled()) { - logger.debug("Translated " + result.getSourceDestination() + " -> " + destinations); - } - for (String destination : destinations) { - this.messagingTemplate.send(destination, message); + for (String target : result.getTargetDestinations()) { + this.messagingTemplate.send(target, message); } } @@ -201,4 +234,73 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec return "UserDestinationMessageHandler[" + this.destinationResolver + "]"; } + + /** + * A handler that broadcasts locally unresolved messages to the broker and + * also handles similar broadcasts received from the broker. + */ + private static class BroadcastHandler { + + private static final List NO_COPY_LIST = Arrays.asList("subscription", "message-id"); + + + private final MessageSendingOperations messagingTemplate; + + private final String broadcastDestination; + + + public BroadcastHandler(MessageSendingOperations template, String destination) { + this.messagingTemplate = template; + this.broadcastDestination = destination; + } + + + public String getBroadcastDestination() { + return this.broadcastDestination; + } + + public Message preHandle(Message message) throws MessagingException { + String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders()); + if (!getBroadcastDestination().equals(destination)) { + return message; + } + SimpMessageHeaderAccessor accessor = getAccessor(message, SimpMessageHeaderAccessor.class); + if (accessor.getSessionId() == null) { + // Our own broadcast + return null; + } + destination = accessor.getFirstNativeHeader(ORIGINAL_DESTINATION); + if (logger.isTraceEnabled()) { + logger.trace("Checking unresolved user destination: " + destination); + } + SimpMessageHeaderAccessor newAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + for (String name : accessor.toNativeHeaderMap().keySet()) { + if (NO_COPY_LIST.contains(name)) { + continue; + } + newAccessor.setNativeHeader(name, accessor.getFirstNativeHeader(name)); + } + newAccessor.setDestination(destination); + newAccessor.setHeader(SimpMessageHeaderAccessor.IGNORE_ERROR, true); // ensure send doesn't block + return MessageBuilder.createMessage(message.getPayload(), newAccessor.getMessageHeaders()); + } + + public void handleUnresolved(Message message) { + MessageHeaders headers = message.getHeaders(); + if (SimpMessageHeaderAccessor.getFirstNativeHeader(ORIGINAL_DESTINATION, headers) != null) { + // Re-broadcast + return; + } + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(message); + String destination = accessor.getDestination(); + accessor.setNativeHeader(ORIGINAL_DESTINATION, destination); + accessor.setLeaveMutable(true); + message = MessageBuilder.createMessage(message.getPayload(), accessor.getMessageHeaders()); + if (logger.isTraceEnabled()) { + logger.trace("Translated " + destination + " -> " + getBroadcastDestination()); + } + this.messagingTemplate.send(getBroadcastDestination(), message); + } + } + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java index 881d38e1c2f..8240760a966 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java @@ -115,7 +115,7 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { */ public boolean containsNativeHeader(String headerName) { Map> map = getNativeHeaders(); - return (map != null ? map.containsKey(headerName) : false); + return (map != null && map.containsKey(headerName)); } /** @@ -207,4 +207,16 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { return nativeHeaders.remove(name); } + @SuppressWarnings("unchecked") + public static String getFirstNativeHeader(String headerName, Map headers) { + Map> map = (Map>) headers.get(NATIVE_HEADERS); + if (map != null) { + List values = map.get(headerName); + if (values != null) { + return values.get(0); + } + } + return null; + } + } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java index bd45517c774..2584d851a85 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -79,8 +79,11 @@ import static org.mockito.Mockito.*; public class MessageBrokerConfigurationTests { private ApplicationContext defaultContext = new AnnotationConfigApplicationContext(DefaultConfig.class); + private ApplicationContext simpleBrokerContext = new AnnotationConfigApplicationContext(SimpleBrokerConfig.class); + private ApplicationContext brokerRelayContext = new AnnotationConfigApplicationContext(BrokerRelayConfig.class); + private ApplicationContext customContext = new AnnotationConfigApplicationContext(CustomConfig.class); @@ -401,7 +404,17 @@ public class MessageBrokerConfigurationTests { assertEquals("a.a", handler.getPathMatcher().combine("a", "a")); } + @Test + public void userDestinationBroadcast() throws Exception { + StompBrokerRelayMessageHandler relay = this.brokerRelayContext.getBean(StompBrokerRelayMessageHandler.class); + UserDestinationMessageHandler userHandler = this.brokerRelayContext.getBean(UserDestinationMessageHandler.class); + assertEquals("/topic/unresolved", userHandler.getUserDestinationBroadcast()); + assertNotNull(relay.getSystemSubscriptions()); + assertSame(userHandler, relay.getSystemSubscriptions().get("/topic/unresolved")); + } + + @SuppressWarnings("unused") @Controller static class TestController { @@ -417,7 +430,7 @@ public class MessageBrokerConfigurationTests { } } - + @SuppressWarnings("unused") @Configuration static class SimpleBrokerConfig extends AbstractMessageBrokerConfiguration { @@ -451,6 +464,7 @@ public class MessageBrokerConfigurationTests { @Override public void configureMessageBroker(MessageBrokerRegistry registry) { registry.enableStompBrokerRelay("/topic", "/queue").setAutoStartup(true); + registry.setUserDestinationBroadcast("/topic/unresolved"); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java index 2115303aa70..ed7180608d0 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java @@ -15,15 +15,22 @@ */ package org.springframework.messaging.simp.stomp; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.Callable; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.StubMessageChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; @@ -37,8 +44,6 @@ import org.springframework.messaging.tcp.TcpOperations; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.util.concurrent.ListenableFutureTask; -import static org.junit.Assert.*; - /** * Unit tests for StompBrokerRelayMessageHandler. * @@ -74,62 +79,52 @@ public class StompBrokerRelayMessageHandlerTests { @Test - public void testVirtualHostHeader() throws Exception { + public void virtualHost() throws Exception { + + this.brokerRelay.setVirtualHost("ABC"); - String virtualHost = "ABC"; - this.brokerRelay.setVirtualHost(virtualHost); this.brokerRelay.start(); + this.brokerRelay.handleMessage(connectMessage("sess1", "joe")); - String sessionId = "sess1"; - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); - headers.setSessionId(sessionId); - this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); + assertEquals(2, this.tcpClient.getSentMessages().size()); - List> sent = this.tcpClient.connection.messages; - assertEquals(2, sent.size()); + StompHeaderAccessor headers1 = this.tcpClient.getSentHeaders(0); + assertEquals(StompCommand.CONNECT, headers1.getCommand()); + assertEquals(StompBrokerRelayMessageHandler.SYSTEM_SESSION_ID, headers1.getSessionId()); + assertEquals("ABC", headers1.getHost()); - StompHeaderAccessor headers1 = StompHeaderAccessor.wrap(sent.get(0)); - assertEquals(virtualHost, headers1.getHost()); - assertNotNull("The prepared message does not have an accessor", - MessageHeaderAccessor.getAccessor(sent.get(0), MessageHeaderAccessor.class)); - - StompHeaderAccessor headers2 = StompHeaderAccessor.wrap(sent.get(1)); - assertEquals(sessionId, headers2.getSessionId()); - assertEquals(virtualHost, headers2.getHost()); - assertNotNull("The prepared message does not have an accessor", - MessageHeaderAccessor.getAccessor(sent.get(1), MessageHeaderAccessor.class)); + StompHeaderAccessor headers2 = this.tcpClient.getSentHeaders(1); + assertEquals(StompCommand.CONNECT, headers2.getCommand()); + assertEquals("sess1", headers2.getSessionId()); + assertEquals("ABC", headers2.getHost()); } @Test - public void testLoginPasscode() throws Exception { - - this.brokerRelay.setClientLogin("clientlogin"); - this.brokerRelay.setClientPasscode("clientpasscode"); + public void loginAndPasscode() throws Exception { this.brokerRelay.setSystemLogin("syslogin"); this.brokerRelay.setSystemPasscode("syspasscode"); + this.brokerRelay.setClientLogin("clientlogin"); + this.brokerRelay.setClientPasscode("clientpasscode"); this.brokerRelay.start(); + this.brokerRelay.handleMessage(connectMessage("sess1", "joe")); - String sessionId = "sess1"; - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); - headers.setSessionId(sessionId); - this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); + assertEquals(2, this.tcpClient.getSentMessages().size()); - List> sent = this.tcpClient.connection.messages; - assertEquals(2, sent.size()); - - StompHeaderAccessor headers1 = StompHeaderAccessor.wrap(sent.get(0)); + StompHeaderAccessor headers1 = this.tcpClient.getSentHeaders(0); + assertEquals(StompCommand.CONNECT, headers1.getCommand()); assertEquals("syslogin", headers1.getLogin()); assertEquals("syspasscode", headers1.getPasscode()); - StompHeaderAccessor headers2 = StompHeaderAccessor.wrap(sent.get(1)); + StompHeaderAccessor headers2 = this.tcpClient.getSentHeaders(1); + assertEquals(StompCommand.CONNECT, headers2.getCommand()); assertEquals("clientlogin", headers2.getLogin()); assertEquals("clientpasscode", headers2.getPasscode()); } @Test - public void testDestinationExcluded() throws Exception { + public void destinationExcluded() throws Exception { this.brokerRelay.start(); @@ -138,89 +133,113 @@ public class StompBrokerRelayMessageHandlerTests { headers.setDestination("/user/daisy/foo"); this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); - List> sent = this.tcpClient.connection.messages; - assertEquals(1, sent.size()); - assertEquals(StompCommand.CONNECT, StompHeaderAccessor.wrap(sent.get(0)).getCommand()); - assertNotNull("The prepared message does not have an accessor", - MessageHeaderAccessor.getAccessor(sent.get(0), MessageHeaderAccessor.class)); + assertEquals(1, this.tcpClient.getSentMessages().size()); + StompHeaderAccessor headers1 = this.tcpClient.getSentHeaders(0); + assertEquals(StompCommand.CONNECT, headers1.getCommand()); + assertEquals(StompBrokerRelayMessageHandler.SYSTEM_SESSION_ID, headers1.getSessionId()); } @Test - public void testOutboundMessageIsEnriched() throws Exception { + public void messageFromBrokerIsEnriched() throws Exception { this.brokerRelay.start(); + this.brokerRelay.handleMessage(connectMessage("sess1", "joe")); - String sessionId = "sess1"; - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); - headers.setSessionId(sessionId); - headers.setUser(new TestPrincipal("joe")); - this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); + assertEquals(2, this.tcpClient.getSentMessages().size()); + assertEquals(StompCommand.CONNECT, this.tcpClient.getSentHeaders(0).getCommand()); + assertEquals(StompCommand.CONNECT, this.tcpClient.getSentHeaders(1).getCommand()); - List> sent = this.tcpClient.connection.messages; - assertEquals(2, sent.size()); + this.tcpClient.handleMessage(message(StompCommand.MESSAGE, null, null, null)); - StompHeaderAccessor responseHeaders = StompHeaderAccessor.create(StompCommand.MESSAGE); - responseHeaders.setLeaveMutable(true); - Message response = MessageBuilder.createMessage(new byte[0], responseHeaders.getMessageHeaders()); - this.tcpClient.connectionHandler.handleMessage(response); - - Message actual = this.outboundChannel.getMessages().get(0); - StompHeaderAccessor actualHeaders = StompHeaderAccessor.getAccessor(actual, StompHeaderAccessor.class); - assertEquals(sessionId, actualHeaders.getSessionId()); - assertEquals("joe", actualHeaders.getUser().getName()); + Message message = this.outboundChannel.getMessages().get(0); + StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + assertEquals("sess1", accessor.getSessionId()); + assertEquals("joe", accessor.getUser().getName()); } // SPR-12820 @Test - public void testConnectWhenBrokerNotAvailable() throws Exception { + public void connectWhenBrokerNotAvailable() throws Exception { this.brokerRelay.start(); this.brokerRelay.stopInternal(); + this.brokerRelay.handleMessage(connectMessage("sess1", "joe")); - String sessionId = "sess1"; - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); - headers.setSessionId(sessionId); - headers.setUser(new TestPrincipal("joe")); - this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); - - Message actual = this.outboundChannel.getMessages().get(0); - StompHeaderAccessor actualHeaders = StompHeaderAccessor.getAccessor(actual, StompHeaderAccessor.class); - assertEquals(StompCommand.ERROR, actualHeaders.getCommand()); - assertEquals(sessionId, actualHeaders.getSessionId()); - assertEquals("joe", actualHeaders.getUser().getName()); - assertEquals("Broker not available.", actualHeaders.getMessage()); + Message message = this.outboundChannel.getMessages().get(0); + StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + assertEquals(StompCommand.ERROR, accessor.getCommand()); + assertEquals("sess1", accessor.getSessionId()); + assertEquals("joe", accessor.getUser().getName()); + assertEquals("Broker not available.", accessor.getMessage()); } @Test - public void testSendAfterBrokerUnavailable() throws Exception { + public void sendAfterBrokerUnavailable() throws Exception { this.brokerRelay.start(); + assertEquals(1, this.brokerRelay.getConnectionCount()); - String sessionId = "sess1"; - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); - headers.setSessionId(sessionId); - headers.setUser(new TestPrincipal("joe")); - this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); - + this.brokerRelay.handleMessage(connectMessage("sess1", "joe")); assertEquals(2, this.brokerRelay.getConnectionCount()); this.brokerRelay.stopInternal(); - - headers = StompHeaderAccessor.create(StompCommand.SEND); - headers.setSessionId(sessionId); - headers.setUser(new TestPrincipal("joe")); - headers.setDestination("/foo"); - this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); - + this.brokerRelay.handleMessage(message(StompCommand.SEND, "sess1", "joe", "/foo")); assertEquals(1, this.brokerRelay.getConnectionCount()); - Message actual = this.outboundChannel.getMessages().get(0); - StompHeaderAccessor actualHeaders = StompHeaderAccessor.getAccessor(actual, StompHeaderAccessor.class); - assertEquals(StompCommand.ERROR, actualHeaders.getCommand()); - assertEquals(sessionId, actualHeaders.getSessionId()); - assertEquals("joe", actualHeaders.getUser().getName()); - assertEquals("Broker not available.", actualHeaders.getMessage()); + Message message = this.outboundChannel.getMessages().get(0); + StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + assertEquals(StompCommand.ERROR, accessor.getCommand()); + assertEquals("sess1", accessor.getSessionId()); + assertEquals("joe", accessor.getUser().getName()); + assertEquals("Broker not available.", accessor.getMessage()); + } + + @Test + public void systemSubscription() throws Exception { + + MessageHandler handler = mock(MessageHandler.class); + this.brokerRelay.setSystemSubscriptions(Collections.singletonMap("/topic/foo", handler)); + this.brokerRelay.start(); + + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED); + accessor.setLeaveMutable(true); + MessageHeaders headers = accessor.getMessageHeaders(); + this.tcpClient.handleMessage(MessageBuilder.createMessage(new byte[0], headers)); + + assertEquals(2, this.tcpClient.getSentMessages().size()); + assertEquals(StompCommand.CONNECT, this.tcpClient.getSentHeaders(0).getCommand()); + assertEquals(StompCommand.SUBSCRIBE, this.tcpClient.getSentHeaders(1).getCommand()); + assertEquals("/topic/foo", this.tcpClient.getSentHeaders(1).getDestination()); + + Message message = message(StompCommand.MESSAGE, null, null, "/topic/foo"); + this.tcpClient.handleMessage(message); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); + verify(handler).handleMessage(captor.capture()); + assertSame(message, captor.getValue()); + } + + private Message connectMessage(String sessionId, String user) { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + headers.setSessionId(sessionId); + headers.setUser(new TestPrincipal(user)); + return MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); + } + + private Message message(StompCommand command, String sessionId, String user, String destination) { + StompHeaderAccessor accessor = StompHeaderAccessor.create(command); + if (sessionId != null) { + accessor.setSessionId(sessionId); + } + if (user != null) { + accessor.setUser(new TestPrincipal(user)); + } + if (destination != null) { + accessor.setDestination(destination); + } + accessor.setLeaveMutable(true); + return MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); } @@ -254,17 +273,29 @@ public class StompBrokerRelayMessageHandlerTests { private TcpConnectionHandler connectionHandler; + public List> getSentMessages() { + return this.connection.getMessages(); + } + + public StompHeaderAccessor getSentHeaders(int index) { + assertTrue("Size: " + getSentMessages().size(), getSentMessages().size() > index); + Message message = getSentMessages().get(index); + StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + assertNotNull(accessor); + return accessor; + } + @Override - public ListenableFuture connect(TcpConnectionHandler connectionHandler) { - this.connectionHandler = connectionHandler; - connectionHandler.afterConnected(this.connection); + public ListenableFuture connect(TcpConnectionHandler handler) { + this.connectionHandler = handler; + handler.afterConnected(this.connection); return getVoidFuture(); } @Override - public ListenableFuture connect(TcpConnectionHandler connectionHandler, ReconnectStrategy reconnectStrategy) { - this.connectionHandler = connectionHandler; - connectionHandler.afterConnected(this.connection); + public ListenableFuture connect(TcpConnectionHandler handler, ReconnectStrategy strategy) { + this.connectionHandler = handler; + handler.afterConnected(this.connection); return getVoidFuture(); } @@ -272,6 +303,11 @@ public class StompBrokerRelayMessageHandlerTests { public ListenableFuture shutdown() { return getBooleanFuture(); } + + public void handleMessage(Message message) { + this.connectionHandler.handleMessage(message); + } + } @@ -280,6 +316,10 @@ public class StompBrokerRelayMessageHandlerTests { private final List> messages = new ArrayList<>(); + public List> getMessages() { + return this.messages; + } + @Override public ListenableFuture send(Message message) { this.messages.add(message); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java index b112c51f3b5..785aee50fc3 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java @@ -18,7 +18,9 @@ package org.springframework.messaging.simp.user; import static org.junit.Assert.*; import static org.mockito.BDDMockito.*; -import static org.springframework.messaging.simp.SimpMessageHeaderAccessor.ORIGINAL_DESTINATION; +import static org.springframework.messaging.simp.SimpMessageHeaderAccessor.*; + +import java.nio.charset.Charset; import org.junit.Before; import org.junit.Test; @@ -33,6 +35,8 @@ import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.TestPrincipal; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.support.MessageBuilder; /** @@ -62,7 +66,6 @@ public class UserDestinationMessageHandlerTests { @Test - @SuppressWarnings("rawtypes") public void handleSubscribe() { given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true); this.handler.handleMessage(createWith(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo")); @@ -75,7 +78,6 @@ public class UserDestinationMessageHandlerTests { } @Test - @SuppressWarnings("rawtypes") public void handleUnsubscribe() { given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true); this.handler.handleMessage(createWith(SimpMessageType.UNSUBSCRIBE, "joe", "123", "/user/queue/foo")); @@ -88,7 +90,6 @@ public class UserDestinationMessageHandlerTests { } @Test - @SuppressWarnings("rawtypes") public void handleMessage() { this.registry.registerSessionId("joe", "123"); given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true); @@ -102,6 +103,69 @@ public class UserDestinationMessageHandlerTests { assertEquals("/user/queue/foo", accessor.getFirstNativeHeader(ORIGINAL_DESTINATION)); } + @Test + public void handleMessageWithoutActiveSession() { + this.handler.setUserDestinationBroadcast("/topic/unresolved"); + given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true); + this.handler.handleMessage(createWith(SimpMessageType.MESSAGE, "joe", "123", "/user/joe/queue/foo")); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); + Mockito.verify(this.brokerChannel).send(captor.capture()); + + Message message = captor.getValue(); + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(message); + assertEquals("/topic/unresolved", accessor.getDestination()); + assertEquals("/user/joe/queue/foo", accessor.getFirstNativeHeader(ORIGINAL_DESTINATION)); + + // Should ignore our own broadcast to brokerChannel + + this.handler.handleMessage(message); + Mockito.verifyNoMoreInteractions(this.brokerChannel); + } + + @Test + public void handleMessageFromBrokerWithActiveSession() { + + this.registry.registerSessionId("joe", "123"); + + this.handler.setUserDestinationBroadcast("/topic/unresolved"); + given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true); + + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.MESSAGE); + accessor.setSessionId("system123"); + accessor.setDestination("/topic/unresolved"); + accessor.setNativeHeader(ORIGINAL_DESTINATION, "/user/joe/queue/foo"); + accessor.setNativeHeader("customHeader", "customHeaderValue"); + accessor.setLeaveMutable(true); + byte[] payload = "payload".getBytes(Charset.forName("UTF-8")); + this.handler.handleMessage(MessageBuilder.createMessage(payload, accessor.getMessageHeaders())); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); + Mockito.verify(this.brokerChannel).send(captor.capture()); + assertNotNull(captor.getValue()); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(captor.getValue()); + assertEquals("/queue/foo-user123", headers.getDestination()); + assertEquals("/user/queue/foo", headers.getFirstNativeHeader(ORIGINAL_DESTINATION)); + assertEquals("customHeaderValue", headers.getFirstNativeHeader("customHeader")); + assertArrayEquals(payload, (byte[]) captor.getValue().getPayload()); + } + + @Test + public void handleMessageFromBrokerWithoutActiveSession() { + this.handler.setUserDestinationBroadcast("/topic/unresolved"); + given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true); + + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.MESSAGE); + accessor.setSessionId("system123"); + accessor.setDestination("/topic/unresolved"); + accessor.setNativeHeader(ORIGINAL_DESTINATION, "/user/joe/queue/foo"); + accessor.setLeaveMutable(true); + byte[] payload = "payload".getBytes(Charset.forName("UTF-8")); + this.handler.handleMessage(MessageBuilder.createMessage(payload, accessor.getMessageHeaders())); + + // No re-broadcast + verifyNoMoreInteractions(this.brokerChannel); + } @Test public void ignoreMessage() { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java index 0cb4348b947..1b69b4c75dc 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java @@ -159,15 +159,18 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { channelElem = DomUtils.getChildElementByTagName(element, "broker-channel"); RuntimeBeanReference brokerChannel = getMessageChannel("brokerChannel", channelElem, context, source); - RootBeanDefinition broker = registerMessageBroker(element, inChannel, outChannel, brokerChannel, context, source); + + RuntimeBeanReference resolver = registerUserDestResolver(element, sessionRegistry, context, source); + RuntimeBeanReference userDestHandler = registerUserDestHandler(element, inChannel, + brokerChannel, resolver, context, source); + + RootBeanDefinition broker = registerMessageBroker(element, userDestHandler, inChannel, + outChannel, brokerChannel, context, source); RuntimeBeanReference converter = registerMessageConverter(element, context, source); RuntimeBeanReference template = registerMessagingTemplate(element, brokerChannel, converter, context, source); registerAnnotationMethodMessageHandler(element, inChannel, outChannel,converter, template, context, source); - RuntimeBeanReference resolver = registerUserDestinationResolver(element, sessionRegistry, context, source); - registerUserDestinationMessageHandler(inChannel, brokerChannel, resolver, context, source); - Map scopeMap = Collections.singletonMap("websocket", new SimpSessionScope()); RootBeanDefinition scopeConfigurer = new RootBeanDefinition(CustomScopeConfigurer.class); scopeConfigurer.getPropertyValues().add("scopes", scopeMap); @@ -308,11 +311,13 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { return new RuntimeBeanReference(registerBeanDef(beanDef, context, source)); } - private RootBeanDefinition registerMessageBroker(Element messageBrokerElement, RuntimeBeanReference inChannel, - RuntimeBeanReference outChannel, RuntimeBeanReference brokerChannel, ParserContext context, Object source) { + private RootBeanDefinition registerMessageBroker(Element brokerElement, + RuntimeBeanReference userDestHandler, RuntimeBeanReference inChannel, + RuntimeBeanReference outChannel, RuntimeBeanReference brokerChannel, + ParserContext context, Object source) { - Element simpleBrokerElem = DomUtils.getChildElementByTagName(messageBrokerElement, "simple-broker"); - Element brokerRelayElem = DomUtils.getChildElementByTagName(messageBrokerElement, "stomp-broker-relay"); + Element simpleBrokerElem = DomUtils.getChildElementByTagName(brokerElement, "simple-broker"); + Element brokerRelayElem = DomUtils.getChildElementByTagName(brokerElement, "stomp-broker-relay"); ConstructorArgumentValues cavs = new ConstructorArgumentValues(); cavs.addIndexedArgumentValue(0, inChannel); @@ -324,8 +329,8 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { String prefix = simpleBrokerElem.getAttribute("prefix"); cavs.addIndexedArgumentValue(3, Arrays.asList(StringUtils.tokenizeToStringArray(prefix, ","))); brokerDef = new RootBeanDefinition(SimpleBrokerMessageHandler.class, cavs, null); - if (messageBrokerElement.hasAttribute("path-matcher")) { - String pathMatcherRef = messageBrokerElement.getAttribute("path-matcher"); + if (brokerElement.hasAttribute("path-matcher")) { + String pathMatcherRef = brokerElement.getAttribute("path-matcher"); brokerDef.getPropertyValues().add("pathMatcher", new RuntimeBeanReference(pathMatcherRef)); } if (simpleBrokerElem.hasAttribute("scheduler")) { @@ -369,6 +374,13 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { if (brokerRelayElem.hasAttribute("virtual-host")) { values.add("virtualHost", brokerRelayElem.getAttribute("virtual-host")); } + if (brokerElement.hasAttribute("user-destination-broadcast")) { + String destination = brokerElement.getAttribute("user-destination-broadcast"); + ManagedMap map = new ManagedMap(); + map.setSource(source); + map.put(destination, userDestHandler); + values.add("systemSubscriptions", map); + } Class handlerType = StompBrokerRelayMessageHandler.class; brokerDef = new RootBeanDefinition(handlerType, cavs, values); } @@ -471,7 +483,7 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { return list; } - private RuntimeBeanReference registerUserDestinationResolver(Element brokerElem, + private RuntimeBeanReference registerUserDestResolver(Element brokerElem, RuntimeBeanReference userSessionRegistry, ParserContext context, Object source) { ConstructorArgumentValues cavs = new ConstructorArgumentValues(); @@ -483,15 +495,19 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { return new RuntimeBeanReference(registerBeanDef(beanDef, context, source)); } - private RuntimeBeanReference registerUserDestinationMessageHandler(RuntimeBeanReference inChannel, - RuntimeBeanReference brokerChannel, RuntimeBeanReference userDestinationResolver, - ParserContext context, Object source) { + private RuntimeBeanReference registerUserDestHandler(Element brokerElem, + RuntimeBeanReference inChannel, RuntimeBeanReference brokerChannel, + RuntimeBeanReference userDestinationResolver, ParserContext context, Object source) { ConstructorArgumentValues cavs = new ConstructorArgumentValues(); cavs.addIndexedArgumentValue(0, inChannel); cavs.addIndexedArgumentValue(1, brokerChannel); cavs.addIndexedArgumentValue(2, userDestinationResolver); RootBeanDefinition beanDef = new RootBeanDefinition(UserDestinationMessageHandler.class, cavs, null); + if (brokerElem.hasAttribute("user-destination-broadcast")) { + String destination = brokerElem.getAttribute("user-destination-broadcast"); + beanDef.getPropertyValues().add("userDestinationBroadcast", destination); + } return new RuntimeBeanReference(registerBeanDef(beanDef, context, source)); } diff --git a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.2.xsd b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.2.xsd index 8ec7535b3c2..749f1637a1f 100644 --- a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.2.xsd +++ b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.2.xsd @@ -853,6 +853,17 @@ The prefix used to identify user destinations. Any destinations that do not start with the given prefix are not be resolved. The default value is "/user/". + ]]> + + + + + diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 20a574aa1b5..1be3a4e9ebf 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -275,6 +275,12 @@ public class MessageBrokerBeanDefinitionParserTests { // expected } + UserDestinationMessageHandler userDestHandler = this.appContext.getBean(UserDestinationMessageHandler.class); + assertEquals("/topic/unresolved", userDestHandler.getUserDestinationBroadcast()); + assertNotNull(messageBroker.getSystemSubscriptions()); + assertSame(userDestHandler, messageBroker.getSystemSubscriptions().get("/topic/unresolved")); + + String name = "webSocketMessageBrokerStats"; WebSocketMessageBrokerStats stats = this.appContext.getBean(name, WebSocketMessageBrokerStats.class); String actual = stats.toString(); diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-relay.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-relay.xml index eb3b9c4d6eb..5462d17935d 100644 --- a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-relay.xml +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-relay.xml @@ -4,7 +4,7 @@ xsi:schemaLocation="http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd http://www.springframework.org/schema/websocket http://www.springframework.org/schema/websocket/spring-websocket.xsd"> - + diff --git a/src/asciidoc/web-websocket.adoc b/src/asciidoc/web-websocket.adoc index 71f1d74aa96..9ecd118bcf5 100644 --- a/src/asciidoc/web-websocket.adoc +++ b/src/asciidoc/web-websocket.adoc @@ -1741,6 +1741,13 @@ http://activemq.apache.org/delete-inactive-destinations.html[configuration optio for purging inactive destinations. ==== +In a multi-application server scenario a user destination may remain unresolved because +the user is connected to a different server. In such cases you can configure a +destination to broadcast unresolved messages to so that other servers have a chance to try. +This can be done through the `userDestinationBroadcast` property of the +`MessageBrokerRegistry` in Java config and the `user-destination-broadcast` attribute +of the `message-broker` element in XML. +