From 281588d7bb9b4c15e72d743f5d88cc5ff39ff2de Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 6 May 2015 18:31:26 -0400 Subject: [PATCH] Add SimpUserRegistry with multi-server support This change introduces SimpUserRegistry exposing an API to access information about connected users, their sessions, and subscriptions with STOMP/WebSocket messaging. Provides are methods to access users as well as a method to find subscriptions given a Matcher strategy. The DefaultSimpUserRegistry implementation is also a SmartApplicationListener which listesn for ApplicationContext events when users connect, disconnect, subscribe, and unsubscribe to destinations. The MultiServerUserRegistry implementation is a composite that aggregates user information from the local SimpUserRegistry as well as snapshots of user on remote application servers. UserRegistryMessageHandler is used with MultiServerUserRegistry. It broadcats user registry information through the broker and listens for similar broadcasts from other servers. This must be enabled explicitly when configuring the STOMP broker relay. The existing UserSessionRegistry which was primiarly used internally to resolve a user name to session id's has been deprecated and is no longer used. If an application configures a custom UserSessionRegistr still, it will be adapted accordingly to SimpUserRegistry but the effect is rather limited (comparable to pre-existing functionality) and will not work in multi-server scenarios. Issue: SPR-12029 --- .../AbstractMessageBrokerConfiguration.java | 73 ++- .../simp/config/MessageBrokerRegistry.java | 23 +- .../config/StompBrokerRelayRegistration.java | 46 +- .../user/DefaultUserDestinationResolver.java | 50 +- .../simp/user/DefaultUserSessionRegistry.java | 6 +- .../simp/user/MultiServerUserRegistry.java | 488 ++++++++++++++++++ .../messaging/simp/user/SimpSession.java | 43 ++ .../messaging/simp/user/SimpSubscription.java | 41 ++ .../simp/user/SimpSubscriptionMatcher.java | 33 ++ .../messaging/simp/user/SimpUser.java | 52 ++ .../messaging/simp/user/SimpUserRegistry.java | 49 ++ .../user/UserDestinationMessageHandler.java | 4 +- .../simp/user/UserDestinationResult.java | 2 +- .../simp/user/UserRegistryMessageHandler.java | 136 +++++ .../simp/user/UserSessionRegistry.java | 29 +- .../simp/user/UserSessionRegistryAdapter.java | 121 +++++ .../messaging/simp/user/package-info.java | 2 +- .../MessageBrokerConfigurationTests.java | 92 ++-- .../StompBrokerRelayRegistrationTests.java | 37 +- .../DefaultUserDestinationResolverTests.java | 91 ++-- .../user/DefaultUserSessionRegistryTests.java | 82 --- .../user/MultiServerUserRegistryTests.java | 167 ++++++ .../messaging/simp/user/TestSimpSession.java | 62 +++ .../simp/user/TestSimpSubscription.java | 52 ++ .../messaging/simp/user/TestSimpUser.java | 62 +++ .../UserDestinationMessageHandlerTests.java | 23 +- .../user/UserRegistryMessageHandlerTests.java | 183 +++++++ .../MessageBrokerBeanDefinitionParser.java | 214 +++++--- .../config/WebSocketNamespaceUtils.java | 6 +- .../WebMvcStompEndpointRegistry.java | 5 +- ...cketMessageBrokerConfigurationSupport.java | 51 +- .../messaging/AbstractSubProtocolEvent.java | 24 + .../messaging/DefaultSimpUserRegistry.java | 336 ++++++++++++ .../socket/messaging/SessionConnectEvent.java | 6 + .../messaging/SessionConnectedEvent.java | 6 + .../messaging/SessionDisconnectEvent.java | 13 +- .../messaging/SessionSubscribeEvent.java | 6 + .../messaging/SessionUnsubscribeEvent.java | 6 + .../messaging/StompSubProtocolHandler.java | 53 +- .../socket/config/spring-websocket-4.2.xsd | 32 +- ...essageBrokerBeanDefinitionParserTests.java | 46 +- .../WebMvcStompEndpointRegistryTests.java | 19 +- ...essageBrokerConfigurationSupportTests.java | 17 +- .../DefaultSimpUserRegistryTests.java | 199 +++++++ .../StompSubProtocolHandlerTests.java | 17 - .../config/websocket-config-broker-relay.xml | 6 +- 46 files changed, 2627 insertions(+), 484 deletions(-) create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSession.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSubscription.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSubscriptionMatcher.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUser.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserRegistryMessageHandler.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistryAdapter.java delete mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserSessionRegistryTests.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpSession.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpSubscription.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpUser.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java index 3964f7339ca..e22e76d0f84 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java @@ -17,7 +17,6 @@ package org.springframework.messaging.simp.config; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -43,14 +42,16 @@ import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler; import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler; import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; import org.springframework.messaging.simp.user.DefaultUserDestinationResolver; -import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; +import org.springframework.messaging.simp.user.MultiServerUserRegistry; +import org.springframework.messaging.simp.user.SimpUserRegistry; import org.springframework.messaging.simp.user.UserDestinationMessageHandler; import org.springframework.messaging.simp.user.UserDestinationResolver; -import org.springframework.messaging.simp.user.UserSessionRegistry; +import org.springframework.messaging.simp.user.UserRegistryMessageHandler; import org.springframework.messaging.support.AbstractSubscribableChannel; import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.ClassUtils; import org.springframework.util.MimeTypeUtils; import org.springframework.util.PathMatcher; @@ -88,14 +89,14 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC "com.fasterxml.jackson.databind.ObjectMapper", AbstractMessageBrokerConfiguration.class.getClassLoader()); + private ApplicationContext applicationContext; + private ChannelRegistration clientInboundChannelRegistration; private ChannelRegistration clientOutboundChannelRegistration; private MessageBrokerRegistry brokerRegistry; - private ApplicationContext applicationContext; - /** * Protected constructor. @@ -287,12 +288,16 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC if (handler == null) { return new NoOpBrokerMessageHandler(); } + Map subscriptions = new HashMap(1); String destination = getBrokerRegistry().getUserDestinationBroadcast(); if (destination != null) { - Map map = new HashMap(1); - map.put(destination, userDestinationMessageHandler()); - handler.setSystemSubscriptions(map); + subscriptions.put(destination, userDestinationMessageHandler()); } + destination = getBrokerRegistry().getUserRegistryBroadcast(); + if (destination != null) { + subscriptions.put(destination, userRegistryMessageHandler()); + } + handler.setSystemSubscriptions(subscriptions); return handler; } @@ -301,10 +306,30 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC UserDestinationMessageHandler handler = new UserDestinationMessageHandler(clientInboundChannel(), brokerChannel(), userDestinationResolver()); String destination = getBrokerRegistry().getUserDestinationBroadcast(); - handler.setUserDestinationBroadcast(destination); + handler.setBroadcastDestination(destination); return handler; } + @Bean + public MessageHandler userRegistryMessageHandler() { + if (getBrokerRegistry().getUserRegistryBroadcast() == null) { + return new NoOpMessageHandler(); + } + return new UserRegistryMessageHandler(userRegistry(), brokerMessagingTemplate(), + getBrokerRegistry().getUserRegistryBroadcast(), messageBrokerTaskScheduler()); + } + + // Expose alias for 4.1 compatibility + + @Bean(name={"messageBrokerTaskScheduler", "messageBrokerSockJsTaskScheduler"}) + public ThreadPoolTaskScheduler messageBrokerTaskScheduler() { + ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); + scheduler.setThreadNamePrefix("MessageBroker-"); + scheduler.setPoolSize(Runtime.getRuntime().availableProcessors()); + scheduler.setRemoveOnCancelPolicy(true); + return scheduler; + } + @Bean public SimpMessagingTemplate brokerMessagingTemplate() { SimpMessagingTemplate template = new SimpMessagingTemplate(brokerChannel()); @@ -350,7 +375,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC @Bean public UserDestinationResolver userDestinationResolver() { - DefaultUserDestinationResolver resolver = new DefaultUserDestinationResolver(userSessionRegistry()); + DefaultUserDestinationResolver resolver = new DefaultUserDestinationResolver(userRegistry()); String prefix = getBrokerRegistry().getUserDestinationPrefix(); if (prefix != null) { resolver.setUserDestinationPrefix(prefix); @@ -359,8 +384,24 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC } @Bean - public UserSessionRegistry userSessionRegistry() { - return new DefaultUserSessionRegistry(); + @SuppressWarnings("deprecation") + public SimpUserRegistry userRegistry() { + return (getBrokerRegistry().getUserRegistryBroadcast() != null ? + new MultiServerUserRegistry(createLocalUserRegistry()) : createLocalUserRegistry()); + } + + protected abstract SimpUserRegistry createLocalUserRegistry(); + + /** + * As of 4.2, UserSessionRegistry is deprecated in favor of SimpUserRegistry + * exposing information about all connected users. The MultiServerUserRegistry + * implementation in combination with UserRegistryMessageHandler can be used + * to share user registries across multiple servers. + */ + @Deprecated + @SuppressWarnings("deprecation") + protected org.springframework.messaging.simp.user.UserSessionRegistry userSessionRegistry() { + return null; } /** @@ -417,6 +458,14 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC } + private static class NoOpMessageHandler implements MessageHandler { + + @Override + public void handleMessage(Message message) { + } + + } + private class NoOpBrokerMessageHandler extends AbstractBrokerMessageHandler { public NoOpBrokerMessageHandler() { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java index 6970905322e..0fb15519984 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/MessageBrokerRegistry.java @@ -23,6 +23,7 @@ import org.springframework.messaging.MessageChannel; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler; import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; +import org.springframework.messaging.simp.user.SimpUserRegistry; import org.springframework.util.Assert; import org.springframework.util.PathMatcher; @@ -49,8 +50,6 @@ public class MessageBrokerRegistry { private String userDestinationPrefix; - private String userDestinationBroadcast; - private PathMatcher pathMatcher; @@ -139,22 +138,14 @@ public class MessageBrokerRegistry { return this.userDestinationPrefix; } - /** - * Set a destination to broadcast messages to that remain unresolved because - * the user is not connected. In a multi-application server scenario this - * gives other application servers a chance to try. - *

Note: this option applies only when the - * {@link #enableStompBrokerRelay "broker relay"} is enabled. - *

By default this is not set. - * @param destination the destination to forward unresolved - * messages to, e.g. "/topic/unresolved-user-destination". - */ - public void setUserDestinationBroadcast(String destination) { - this.userDestinationBroadcast = destination; + protected String getUserDestinationBroadcast() { + return (this.brokerRelayRegistration != null ? + this.brokerRelayRegistration.getUserDestinationBroadcast() : null); } - protected String getUserDestinationBroadcast() { - return this.userDestinationBroadcast; + protected String getUserRegistryBroadcast() { + return (this.brokerRelayRegistration != null ? + this.brokerRelayRegistration.getUserRegistryBroadcast() : null); } /** diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistration.java index f8534381da7..b6fbfef02e8 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistration.java @@ -49,6 +49,10 @@ public class StompBrokerRelayRegistration extends AbstractBrokerRegistration { private boolean autoStartup = true; + private String userDestinationBroadcast; + + private String userRegistryBroadcast; + public StompBrokerRelayRegistration(SubscribableChannel clientInboundChannel, MessageChannel clientOutboundChannel, String[] destinationPrefixes) { @@ -166,10 +170,48 @@ public class StompBrokerRelayRegistration extends AbstractBrokerRegistration { return this; } + /** + * Set a destination to broadcast messages to user destinations that remain + * unresolved because the user appears not to be connected. In a + * multi-application server scenario this gives other application servers + * a chance to try. + *

By default this is not set. + * @param destination the destination to broadcast unresolved messages to, + * e.g. "/topic/unresolved-user-destination" + */ + public StompBrokerRelayRegistration setUserDestinationBroadcast(String destination) { + this.userDestinationBroadcast = destination; + return this; + } + + protected String getUserDestinationBroadcast() { + return this.userDestinationBroadcast; + } + + /** + * Set a destination to broadcast the content of the local user registry to + * and to listen for such broadcasts from other servers. In a multi-application + * server scenarios this allows each server's user registry to be aware of + * users connected to other servers. + *

By default this is not set. + * @param destination the destination for broadcasting user registry details, + * e.g. "/topic/simp-user-registry". + */ + public StompBrokerRelayRegistration setUserRegistryBroadcast(String destination) { + this.userRegistryBroadcast = destination; + return this; + } + + protected String getUserRegistryBroadcast() { + return this.userRegistryBroadcast; + } + protected StompBrokerRelayMessageHandler getMessageHandler(SubscribableChannel brokerChannel) { - StompBrokerRelayMessageHandler handler = new StompBrokerRelayMessageHandler(getClientInboundChannel(), - getClientOutboundChannel(), brokerChannel, getDestinationPrefixes()); + + StompBrokerRelayMessageHandler handler = new StompBrokerRelayMessageHandler( + getClientInboundChannel(), getClientOutboundChannel(), + brokerChannel, getDestinationPrefixes()); handler.setRelayHost(this.relayHost); handler.setRelayPort(this.relayPort); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java index 64363a1fe34..d8593c71a61 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java @@ -33,8 +33,7 @@ import org.springframework.util.StringUtils; /** * A default implementation of {@code UserDestinationResolver} that relies - * on a {@link org.springframework.messaging.simp.user.UserSessionRegistry} to - * find active sessions for a user. + * on a {@link SimpUserRegistry} to find active sessions for a user. * *

When a user attempts to subscribe, e.g. to "/user/queue/position-updates", * the "/user" prefix is removed and a unique suffix added based on the session @@ -54,7 +53,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { private static final Log logger = LogFactory.getLog(DefaultUserDestinationResolver.class); - private final UserSessionRegistry sessionRegistry; + private final SimpUserRegistry userRegistry; private String prefix = "/user/"; @@ -62,19 +61,19 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { /** * Create an instance that will access user session id information through * the provided registry. - * @param sessionRegistry the registry, never {@code null} + * @param userRegistry the registry, never {@code null} */ - public DefaultUserDestinationResolver(UserSessionRegistry sessionRegistry) { - Assert.notNull(sessionRegistry, "'sessionRegistry' must not be null"); - this.sessionRegistry = sessionRegistry; + public DefaultUserDestinationResolver(SimpUserRegistry userRegistry) { + Assert.notNull(userRegistry, "'userRegistry' must not be null"); + this.userRegistry = userRegistry; } /** - * Return the configured {@link UserSessionRegistry}. + * Return the configured {@link SimpUserRegistry}. */ - public UserSessionRegistry getUserSessionRegistry() { - return this.sessionRegistry; + public SimpUserRegistry getSimpUserRegistry() { + return this.userRegistry; } /** @@ -141,20 +140,32 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { Assert.isTrue(userEnd > 0, "Expected destination pattern \"/user/{userId}/**\""); String actualDestination = destination.substring(userEnd); String subscribeDestination = this.prefix.substring(0, prefixEnd - 1) + actualDestination; - String user = destination.substring(prefixEnd, userEnd); - user = StringUtils.replace(user, "%2F", "/"); + String userName = destination.substring(prefixEnd, userEnd); + userName = StringUtils.replace(userName, "%2F", "/"); Set sessionIds; - if (user.equals(sessionId)) { - user = null; - sessionIds = Collections.singleton(sessionId); - } - else if (this.sessionRegistry.getSessionIds(user).contains(sessionId)) { + if (userName.equals(sessionId)) { + userName = null; sessionIds = Collections.singleton(sessionId); } else { - sessionIds = this.sessionRegistry.getSessionIds(user); + SimpUser user = this.userRegistry.getUser(userName); + if (user != null) { + if (user.getSession(sessionId) != null) { + sessionIds = Collections.singleton(sessionId); + } + else { + Set sessions = user.getSessions(); + sessionIds = new HashSet(sessions.size()); + for (SimpSession session : sessions) { + sessionIds.add(session.getId()); + } + } + } + else { + sessionIds = Collections.emptySet(); + } } - return new ParseResult(actualDestination, subscribeDestination, sessionIds, user); + return new ParseResult(actualDestination, subscribeDestination, sessionIds, userName); } else { return null; @@ -174,6 +185,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { * @param user the target user, possibly {@code null}, e.g if not authenticated. * @return a target destination, or {@code null} if none */ + @SuppressWarnings("unused") protected String getTargetDestination(String sourceDestination, String actualDestination, String sessionId, String user) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserSessionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserSessionRegistry.java index 313cbecefff..4903dd11ab2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserSessionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserSessionRegistry.java @@ -29,7 +29,11 @@ import org.springframework.util.Assert; * * @author Rossen Stoyanchev * @since 4.0 + * @deprecated as of 4.2 this class is no longer used, see deprecation notes + * on {@link UserSessionRegistry} for more details. */ +@Deprecated +@SuppressWarnings({"deprecation", "unused"}) public class DefaultUserSessionRegistry implements UserSessionRegistry { // userId -> sessionId @@ -72,4 +76,4 @@ public class DefaultUserSessionRegistry implements UserSessionRegistry { } } -} +} \ No newline at end of file diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java new file mode 100644 index 00000000000..e43ab15c761 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java @@ -0,0 +1,488 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +import org.springframework.context.ApplicationEvent; +import org.springframework.context.event.SmartApplicationListener; +import org.springframework.core.Ordered; +import org.springframework.messaging.Message; +import org.springframework.messaging.converter.MessageConverter; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; + +/** + * A user registry that is a composite of the "local" user registry as well as + * snapshots of remote user registries. For use with + * {@link UserRegistryMessageHandler} which broadcasts periodically the content + * of the local registry and receives updates from other servers. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +@SuppressWarnings("serial") +public class MultiServerUserRegistry implements SimpUserRegistry, SmartApplicationListener { + + private final String id; + + private final SimpUserRegistry localRegistry; + + private final SmartApplicationListener listener; + + private final Map remoteRegistries = + new ConcurrentHashMap(); + + + /** + * Create an instance wrapping the local user registry. + */ + public MultiServerUserRegistry(SimpUserRegistry localRegistry) { + Assert.notNull(localRegistry, "'localRegistry' is required."); + this.localRegistry = localRegistry; + this.listener = (this.localRegistry instanceof SmartApplicationListener ? + (SmartApplicationListener) this.localRegistry : new NoOpSmartApplicationListener()); + this.id = generateId(); + } + + private static String generateId() { + String host; + try { + host = InetAddress.getLocalHost().getHostAddress(); + } + catch (UnknownHostException e) { + host = "unknown"; + } + return host + "-" + UUID.randomUUID(); + } + + + @Override + public SimpUser getUser(String userName) { + SimpUser user = this.localRegistry.getUser(userName); + if (user != null) { + return user; + } + for (UserRegistryDto registry : this.remoteRegistries.values()) { + user = registry.getUsers().get(userName); + if (user != null) { + return user; + } + } + return null; + } + + @Override + public Set getUsers() { + Set result = new HashSet(this.localRegistry.getUsers()); + for (UserRegistryDto registry : this.remoteRegistries.values()) { + result.addAll(registry.getUsers().values()); + } + return result; + } + + @Override + public Set findSubscriptions(SimpSubscriptionMatcher matcher) { + Set result = new HashSet(this.localRegistry.findSubscriptions(matcher)); + for (UserRegistryDto registry : this.remoteRegistries.values()) { + result.addAll(registry.findSubscriptions(matcher)); + } + return result; + } + + @Override + public boolean supportsEventType(Class eventType) { + return this.listener.supportsEventType(eventType); + } + + @Override + public boolean supportsSourceType(Class sourceType) { + return this.listener.supportsSourceType(sourceType); + } + + @Override + public void onApplicationEvent(ApplicationEvent event) { + this.listener.onApplicationEvent(event); + } + + @Override + public int getOrder() { + return this.listener.getOrder(); + } + + Object getLocalRegistryDto() { + return new UserRegistryDto(this.id, this.localRegistry); + } + + void addRemoteRegistryDto(Message message, MessageConverter converter, long expirationPeriod) { + UserRegistryDto registryDto = (UserRegistryDto) converter.fromMessage(message, UserRegistryDto.class); + if (registryDto != null && !registryDto.getId().equals(this.id)) { + long expirationTime = System.currentTimeMillis() + expirationPeriod; + registryDto.setExpirationTime(expirationTime); + registryDto.restoreParentReferences(); + this.remoteRegistries.put(registryDto.getId(), registryDto); + } + } + + void purgeExpiredRegistries() { + long now = System.currentTimeMillis(); + Iterator> iterator = this.remoteRegistries.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + if (now > entry.getValue().getExpirationTime()) { + iterator.remove(); + } + } + } + + @Override + public String toString() { + return "local=[" + this.localRegistry + "], remote=" + this.remoteRegistries + "]"; + } + + + @SuppressWarnings("unused") + private static class UserRegistryDto { + + private String id; + + private Map users; + + private long expirationTime; + + + public UserRegistryDto() { + } + + public UserRegistryDto(String id, SimpUserRegistry registry) { + this.id = id; + Set users = registry.getUsers(); + this.users = new HashMap(users.size()); + for (SimpUser user : users) { + this.users.put(user.getName(), new SimpUserDto(user)); + } + } + + public void setId(String id) { + this.id = id; + } + + public String getId() { + return this.id; + } + + public void setUsers(Map users) { + this.users = users; + } + + public Map getUsers() { + return this.users; + } + + public Set findSubscriptions(SimpSubscriptionMatcher matcher) { + Set result = new HashSet(); + for (SimpUserDto user : this.users.values()) { + for (SimpSessionDto session : user.sessions) { + for (SimpSubscription subscription : session.subscriptions) { + if (matcher.match(subscription)) { + result.add(subscription); + } + } + } + } + return result; + } + + public void setExpirationTime(long expirationTime) { + this.expirationTime = expirationTime; + } + + public long getExpirationTime() { + return this.expirationTime; + } + + private void restoreParentReferences() { + for (SimpUserDto user : this.users.values()) { + user.restoreParentReferences(); + } + } + @Override + public String toString() { + return "id=" + this.id + ", users=" + this.users; + } + } + + @SuppressWarnings("unused") + private static class SimpUserDto implements SimpUser { + + private String name; + + private Set sessions; + + + public SimpUserDto() { + this.sessions = new HashSet(1); + } + + public SimpUserDto(SimpUser user) { + this.name = user.getName(); + Set sessions = user.getSessions(); + this.sessions = new HashSet(sessions.size()); + for (SimpSession session : sessions) { + this.sessions.add(new SimpSessionDto(session)); + } + } + + @Override + public String getName() { + return this.name; + } + + public void setName(String name) { + this.name = name; + } + + @Override + public boolean hasSessions() { + return !this.sessions.isEmpty(); + } + + @Override + public Set getSessions() { + return new HashSet(this.sessions); + } + + public void setSessions(Set sessions) { + this.sessions.addAll(sessions); + } + + @Override + public SimpSessionDto getSession(String sessionId) { + for (SimpSessionDto session : this.sessions) { + if (session.getId().equals(sessionId)) { + return session; + } + } + return null; + } + + private void restoreParentReferences() { + for (SimpSessionDto session : this.sessions) { + session.setUser(this); + session.restoreParentReferences(); + } + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || !(other instanceof SimpUser)) { + return false; + } + return this.name.equals(((SimpUser) other).getName()); + } + + @Override + public int hashCode() { + return this.name.hashCode(); + } + + @Override + public String toString() { + return "name=" + this.name + ", sessions=" + this.sessions; + } + } + + @SuppressWarnings("unused") + private static class SimpSessionDto implements SimpSession { + + private String id; + + private SimpUserDto user; + + private Set subscriptions; + + + public SimpSessionDto() { + this.subscriptions = new HashSet(4); + } + + public SimpSessionDto(SimpSession session) { + this.id = session.getId(); + Set subscriptions = session.getSubscriptions(); + this.subscriptions = new HashSet(subscriptions.size()); + for (SimpSubscription subscription : subscriptions) { + this.subscriptions.add(new SimpSubscriptionDto(subscription)); + } + } + + @Override + public String getId() { + return this.id; + } + + public void setId(String id) { + this.id = id; + } + + @Override + public SimpUserDto getUser() { + return this.user; + } + + public void setUser(SimpUserDto user) { + this.user = user; + } + + @Override + public Set getSubscriptions() { + return new HashSet(this.subscriptions); + } + + public void setSubscriptions(Set subscriptions) { + this.subscriptions.addAll(subscriptions); + } + + private void restoreParentReferences() { + for (SimpSubscriptionDto subscription : this.subscriptions) { + subscription.setSession(this); + } + } + + @Override + public int hashCode() { + return this.id.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || !(other instanceof SimpSession)) { + return false; + } + return this.id.equals(((SimpSession) other).getId()); + } + + @Override + public String toString() { + return "id=" + this.id + ", subscriptions=" + this.subscriptions; + } + } + + @SuppressWarnings("unused") + private static class SimpSubscriptionDto implements SimpSubscription { + + private String id; + + private SimpSessionDto session; + + private String destination; + + + public SimpSubscriptionDto() { + } + + public SimpSubscriptionDto(SimpSubscription subscription) { + this.id = subscription.getId(); + this.destination = subscription.getDestination(); + } + + @Override + public String getId() { + return this.id; + } + + public void setId(String id) { + this.id = id; + } + + @Override + public SimpSessionDto getSession() { + return this.session; + } + + public void setSession(SimpSessionDto session) { + this.session = session; + } + + @Override + public String getDestination() { + return this.destination; + } + + public void setDestination(String destination) { + this.destination = destination; + } + + @Override + public int hashCode() { + return 31 * this.id.hashCode() + ObjectUtils.nullSafeHashCode(getSession()); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || !(other instanceof SimpSubscription)) { + return false; + } + SimpSubscription otherSubscription = (SimpSubscription) other; + return (ObjectUtils.nullSafeEquals(getSession(), otherSubscription.getSession()) && + this.id.equals(otherSubscription.getId())); + } + + @Override + public String toString() { + return "destination=" + this.destination; + } + } + + private static class NoOpSmartApplicationListener implements SmartApplicationListener { + + @Override + public boolean supportsEventType(Class eventType) { + return false; + } + + @Override + public boolean supportsSourceType(Class sourceType) { + return false; + } + + @Override + public void onApplicationEvent(ApplicationEvent event) { + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE; + } + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSession.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSession.java new file mode 100644 index 00000000000..691783fbf66 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSession.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +import java.util.Set; + +/** + * Represents a session of connected user. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +public interface SimpSession { + + /** + * Return the session id. + */ + String getId(); + + /** + * Return the user associated with the session. + */ + SimpUser getUser(); + + /** + * Return the subscriptions for this session. + */ + Set getSubscriptions(); + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSubscription.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSubscription.java new file mode 100644 index 00000000000..9819e3a1a8d --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSubscription.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +/** + * Represents a subscription within a user session. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +public interface SimpSubscription { + + /** + * Return the id associated of the subscription, never {@code null}. + */ + String getId(); + + /** + * Return the session of the subscription, never {@code null}. + */ + SimpSession getSession(); + + /** + * Return the subscription's destination, never {@code null}. + */ + String getDestination(); + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSubscriptionMatcher.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSubscriptionMatcher.java new file mode 100644 index 00000000000..5c71bc97aad --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpSubscriptionMatcher.java @@ -0,0 +1,33 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +/** + * A strategy for matching subscriptions. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +public interface SimpSubscriptionMatcher { + + /** + * Match the given subscription. + * @param subscription the subscription to match + * @return {@code true} in case of match, {@code false} otherwise. + */ + boolean match(SimpSubscription subscription); + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUser.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUser.java new file mode 100644 index 00000000000..3b7846bf534 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUser.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +import java.util.Set; + +/** + * Represents a connected user. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +public interface SimpUser { + + /** + * The unique user name. + */ + String getName(); + + /** + * Whether the user has any sessions. + */ + boolean hasSessions(); + + /** + * Look up the session for the given id. + * @param sessionId the session id + * @return the matching session of {@code null}. + */ + SimpSession getSession(String sessionId); + + /** + * Return the sessions for the user. + * The returned set is a copy and will never be modified. + * @return a set of session ids, or an empty set. + */ + Set getSessions(); + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java new file mode 100644 index 00000000000..c93a3eebead --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +import java.util.Set; + +/** + * A registry of currently connected users. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +public interface SimpUserRegistry { + + /** + * Get the user for the given name. + * @param userName the name of the user to look up + * @return the user or {@code null} if not connected + */ + SimpUser getUser(String userName); + + /** + * Return a snapshot of all connected users. The returned set is a copy and + * will never be modified. + * @return the connected users or an empty set. + */ + Set getUsers(); + + /** + * Find subscriptions with the given matcher. + * @param matcher the matcher to use + * @return a set of matching subscriptions or an empty set. + */ + Set findSubscriptions(SimpSubscriptionMatcher matcher); + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java index 9cf7730e5ff..6d3ecd41b14 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java @@ -108,7 +108,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec *

By default this is not set. * @param destination the target destination. */ - public void setUserDestinationBroadcast(String destination) { + public void setBroadcastDestination(String destination) { this.broadcastHandler = (StringUtils.hasText(destination) ? new BroadcastHandler(this.messagingTemplate, destination) : null); } @@ -116,7 +116,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec /** * Return the configured destination for unresolved messages. */ - public String getUserDestinationBroadcast() { + public String getBroadcastDestination() { return (this.broadcastHandler != null ? this.broadcastHandler.getBroadcastDestination() : null); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationResult.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationResult.java index 4d7765a945a..7e2d8503e63 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationResult.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationResult.java @@ -87,7 +87,7 @@ public class UserDestinationResult { * @return the user name or {@code null} if we have a session id only such as * when the user is not authenticated; in such cases it is possible to use * sessionId in place of a user name thus removing the need for a user-to-session - * lookup via {@link org.springframework.messaging.simp.user.UserSessionRegistry}. + * lookup via {@link SimpUserRegistry}. */ public String getUser() { return this.user; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserRegistryMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserRegistryMessageHandler.java new file mode 100644 index 00000000000..66769af4a69 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserRegistryMessageHandler.java @@ -0,0 +1,136 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +import java.util.concurrent.ScheduledFuture; + +import org.springframework.context.ApplicationListener; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; +import org.springframework.messaging.converter.MessageConverter; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.messaging.simp.broker.BrokerAvailabilityEvent; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.util.Assert; + +/** + * A MessageHandler that is subscribed to listen to broadcasts of user registry + * information from other application servers as well as to periodically + * broadcast the content of the local user registry. The aggregated information + * is maintained in a {@link MultiServerUserRegistry}. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +public class UserRegistryMessageHandler implements MessageHandler, ApplicationListener { + + private final MultiServerUserRegistry userRegistry; + + private final SimpMessagingTemplate brokerTemplate; + + private final String broadcastDestination; + + private final TaskScheduler scheduler; + + private final UserRegistryTask schedulerTask = new UserRegistryTask(); + + private volatile ScheduledFuture scheduledFuture; + + private long registryExpirationPeriod = 20 * 1000; + + + public UserRegistryMessageHandler(SimpUserRegistry userRegistry, SimpMessagingTemplate brokerTemplate, + String broadcastDestination, TaskScheduler scheduler) { + + Assert.notNull(userRegistry, "'userRegistry' is required"); + Assert.isInstanceOf(MultiServerUserRegistry.class, userRegistry); + Assert.notNull(brokerTemplate, "'brokerTemplate' is required"); + Assert.hasText(broadcastDestination, "'broadcastDestination' is required"); + Assert.notNull(scheduler, "'scheduler' is required"); + + + this.userRegistry = (MultiServerUserRegistry) userRegistry; + this.brokerTemplate = brokerTemplate; + this.broadcastDestination = broadcastDestination; + this.scheduler = scheduler; + } + + + /** + * Return the destination for broadcasting user registry information to. + */ + public String getBroadcastDestination() { + return this.broadcastDestination; + } + + /** + * Configure how long before a remote registry snapshot expires. + *

By default this is set to 20000 (20 seconds). + * @param expirationPeriod the expiration period in milliseconds + */ + @SuppressWarnings("unused") + public void setRegistryExpirationPeriod(long expirationPeriod) { + this.registryExpirationPeriod = expirationPeriod; + } + + /** + * Return the configured registry expiration period. + */ + public long getRegistryExpirationPeriod() { + return this.registryExpirationPeriod; + } + + + @Override + public void onApplicationEvent(BrokerAvailabilityEvent event) { + if (event.isBrokerAvailable()) { + long delay = getRegistryExpirationPeriod() / 2; + this.scheduledFuture = this.scheduler.scheduleWithFixedDelay(this.schedulerTask, delay); + } + else if (this.scheduledFuture != null ){ + this.scheduledFuture.cancel(true); + this.scheduledFuture = null; + } + } + + @Override + public void handleMessage(Message message) throws MessagingException { + MessageConverter converter = this.brokerTemplate.getMessageConverter(); + this.userRegistry.addRemoteRegistryDto(message, converter, getRegistryExpirationPeriod()); + } + + + private class UserRegistryTask implements Runnable { + + @Override + public void run() { + try { + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + accessor.setHeader(SimpMessageHeaderAccessor.IGNORE_ERROR, true); + accessor.setLeaveMutable(true); + Object payload = userRegistry.getLocalRegistryDto(); + brokerTemplate.convertAndSend(getBroadcastDestination(), payload, accessor.getMessageHeaders()); + } + finally { + userRegistry.purgeExpiredRegistries(); + } + } + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistry.java index 05b780b816c..fbccc0c3b0b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistry.java @@ -19,34 +19,41 @@ package org.springframework.messaging.simp.user; import java.util.Set; /** - * A registry for looking up active user sessions. For use when resolving user - * destinations. + * A contract for adding and removing user sessions. + * + *

As of 4.2 this interface extends {@link SimpUserRegistry}. + * exposing methods to return all registered users as well as to provide more + * extensive information for each user. * * @author Rossen Stoyanchev * @since 4.0 - * @see DefaultUserDestinationResolver + * @deprecated in favor of {@link SimpUserRegistry} in combination with + * {@link org.springframework.context.ApplicationListener} listening for + * {@link org.springframework.web.socket.messaging.AbstractSubProtocolEvent} events. */ +@Deprecated public interface UserSessionRegistry { /** - * Return the active session id's for the user. - * @param user the user - * @return a set with 0 or more session id's, never {@code null}. + * Return the active session ids for the user. + * The returned set is a snapshot that will never be modified. + * @param userName the user to look up + * @return a set with 0 or more session ids, never {@code null}. */ - Set getSessionIds(String user); + Set getSessionIds(String userName); /** * Register an active session id for a user. - * @param user the user + * @param userName the user name * @param sessionId the session id */ - void registerSessionId(String user, String sessionId); + void registerSessionId(String userName, String sessionId); /** * Unregister an active session id for a user. - * @param user the user + * @param userName the user name * @param sessionId the session id */ - void unregisterSessionId(String user, String sessionId); + void unregisterSessionId(String userName, String sessionId); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistryAdapter.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistryAdapter.java new file mode 100644 index 00000000000..acd1e3bd158 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserSessionRegistryAdapter.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.springframework.util.CollectionUtils; + +/** + * A temporary adapter to allow use of deprecated {@link UserSessionRegistry}. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +@SuppressWarnings("deprecation") +public class UserSessionRegistryAdapter implements SimpUserRegistry { + + private final UserSessionRegistry delegate; + + + public UserSessionRegistryAdapter(UserSessionRegistry delegate) { + this.delegate = delegate; + } + + + @Override + public SimpUser getUser(String userName) { + Set sessionIds = this.delegate.getSessionIds(userName); + return (!CollectionUtils.isEmpty(sessionIds) ? new SimpleSimpUser(userName, sessionIds) : null); + } + + @Override + public Set getUsers() { + throw new UnsupportedOperationException("UserSessionRegistry does not expose a listing of users."); + } + + @Override + public Set findSubscriptions(SimpSubscriptionMatcher matcher) { + throw new UnsupportedOperationException("UserSessionRegistry does not support operations across users."); + } + + + private static class SimpleSimpUser implements SimpUser { + + private final String name; + + private final Map sessions; + + + public SimpleSimpUser(String name, Set sessionIds) { + this.name = name; + this.sessions = new HashMap(sessionIds.size()); + for (String sessionId : sessionIds) { + this.sessions.put(sessionId, new SimpleSimpSession(sessionId)); + } + } + + @Override + public String getName() { + return this.name; + } + + @Override + public boolean hasSessions() { + return !this.sessions.isEmpty(); + } + + @Override + public SimpSession getSession(String sessionId) { + return this.sessions.get(sessionId); + } + + @Override + public Set getSessions() { + return new HashSet(this.sessions.values()); + } + } + + private static class SimpleSimpSession implements SimpSession { + + private final String id; + + + public SimpleSimpSession(String id) { + this.id = id; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public SimpUser getUser() { + return null; + } + + @Override + public Set getSubscriptions() { + return Collections.emptySet(); + } + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/package-info.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/package-info.java index 2a2a7ff00fe..9cf68e14592 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/package-info.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/package-info.java @@ -3,7 +3,7 @@ * unique to a user's sessions), primarily translating the destinations and then * forwarding the updated message to the broker. * - *

Also included is {@link org.springframework.messaging.simp.user.UserSessionRegistry} + *

Also included is {@link org.springframework.messaging.simp.user.SimpUserRegistry} * for keeping track of connected user sessions. */ package org.springframework.messaging.simp.user; diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java index 2584d851a85..e72a77cea2e 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java @@ -16,6 +16,9 @@ package org.springframework.messaging.simp.config; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -51,8 +54,10 @@ import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler; import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.simp.user.MultiServerUserRegistry; +import org.springframework.messaging.simp.user.SimpUserRegistry; import org.springframework.messaging.simp.user.UserDestinationMessageHandler; -import org.springframework.messaging.simp.user.UserSessionRegistry; +import org.springframework.messaging.simp.user.UserRegistryMessageHandler; import org.springframework.messaging.support.AbstractSubscribableChannel; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ChannelInterceptorAdapter; @@ -66,9 +71,6 @@ import org.springframework.validation.Errors; import org.springframework.validation.Validator; import org.springframework.validation.beanvalidation.OptionalValidatorFactoryBean; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - /** * Test fixture for {@link AbstractMessageBrokerConfiguration}. * @@ -235,26 +237,6 @@ public class MessageBrokerConfigurationTests { assertEquals("bar", new String((byte[]) message.getPayload())); } - @Test - public void brokerChannelUsedByUserDestinationMessageHandler() { - TestChannel channel = this.simpleBrokerContext.getBean("brokerChannel", TestChannel.class); - UserDestinationMessageHandler messageHandler = this.simpleBrokerContext.getBean(UserDestinationMessageHandler.class); - - this.simpleBrokerContext.getBean(UserSessionRegistry.class).registerSessionId("joe", "s1"); - - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); - headers.setDestination("/user/joe/foo"); - Message message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); - - messageHandler.handleMessage(message); - - message = channel.messages.get(0); - headers = StompHeaderAccessor.wrap(message); - - assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); - assertEquals("/foo-users1", headers.getDestination()); - } - @Test public void brokerChannelCustomized() { AbstractSubscribableChannel channel = this.customContext.getBean( @@ -272,7 +254,7 @@ public class MessageBrokerConfigurationTests { @Test public void configureMessageConvertersDefault() { - AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() {}; + AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig(); CompositeMessageConverter compositeConverter = config.brokerMessageConverter(); List converters = compositeConverter.getConverters(); @@ -305,7 +287,7 @@ public class MessageBrokerConfigurationTests { @Test public void configureMessageConvertersCustom() { final MessageConverter testConverter = mock(MessageConverter.class); - AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() { + AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig() { @Override protected boolean configureMessageConverters(List messageConverters) { messageConverters.add(testConverter); @@ -323,7 +305,7 @@ public class MessageBrokerConfigurationTests { public void configureMessageConvertersCustomAndDefault() { final MessageConverter testConverter = mock(MessageConverter.class); - AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() { + AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig() { @Override protected boolean configureMessageConverters(List messageConverters) { messageConverters.add(testConverter); @@ -355,7 +337,7 @@ public class MessageBrokerConfigurationTests { @Test public void simpValidatorDefault() { - AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() {}; + AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig() {}; config.setApplicationContext(new StaticApplicationContext()); assertThat(config.simpValidator(), Matchers.notNullValue()); @@ -365,7 +347,7 @@ public class MessageBrokerConfigurationTests { @Test public void simpValidatorCustom() { final Validator validator = mock(Validator.class); - AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() { + AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig() { @Override public Validator getValidator() { return validator; @@ -379,7 +361,7 @@ public class MessageBrokerConfigurationTests { public void simpValidatorMvc() { StaticApplicationContext appCxt = new StaticApplicationContext(); appCxt.registerSingleton("mvcValidator", TestValidator.class); - AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() {}; + AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig() {}; config.setApplicationContext(appCxt); assertThat(config.simpValidator(), Matchers.notNullValue()); @@ -405,12 +387,35 @@ public class MessageBrokerConfigurationTests { } @Test - public void userDestinationBroadcast() throws Exception { + public void userBroadcasts() throws Exception { + SimpUserRegistry userRegistry = this.brokerRelayContext.getBean(SimpUserRegistry.class); + assertEquals(MultiServerUserRegistry.class, userRegistry.getClass()); + + UserDestinationMessageHandler handler1 = this.brokerRelayContext.getBean(UserDestinationMessageHandler.class); + assertEquals("/topic/unresolved-user-destination", handler1.getBroadcastDestination()); + + UserRegistryMessageHandler handler2 = this.brokerRelayContext.getBean(UserRegistryMessageHandler.class); + assertEquals("/topic/simp-user-registry", handler2.getBroadcastDestination()); + StompBrokerRelayMessageHandler relay = this.brokerRelayContext.getBean(StompBrokerRelayMessageHandler.class); - UserDestinationMessageHandler userHandler = this.brokerRelayContext.getBean(UserDestinationMessageHandler.class); - assertEquals("/topic/unresolved", userHandler.getUserDestinationBroadcast()); assertNotNull(relay.getSystemSubscriptions()); - assertSame(userHandler, relay.getSystemSubscriptions().get("/topic/unresolved")); + assertEquals(2, relay.getSystemSubscriptions().size()); + assertSame(handler1, relay.getSystemSubscriptions().get("/topic/unresolved-user-destination")); + assertSame(handler2, relay.getSystemSubscriptions().get("/topic/simp-user-registry")); + } + + @Test + public void userBroadcastsDisabledWithSimpleBroker() throws Exception { + SimpUserRegistry registry = this.simpleBrokerContext.getBean(SimpUserRegistry.class); + assertNotNull(registry); + assertNotEquals(MultiServerUserRegistry.class, registry.getClass()); + + UserDestinationMessageHandler handler = this.simpleBrokerContext.getBean(UserDestinationMessageHandler.class); + assertNull(handler.getBroadcastDestination()); + + String name = "userRegistryMessageHandler"; + MessageHandler messageHandler = this.simpleBrokerContext.getBean(name, MessageHandler.class); + assertNotEquals(UserRegistryMessageHandler.class, messageHandler.getClass()); } @@ -430,9 +435,17 @@ public class MessageBrokerConfigurationTests { } } + static class BaseTestMessageBrokerConfig extends AbstractMessageBrokerConfiguration { + + @Override + protected SimpUserRegistry createLocalUserRegistry() { + return mock(SimpUserRegistry.class); + } + } + @SuppressWarnings("unused") @Configuration - static class SimpleBrokerConfig extends AbstractMessageBrokerConfiguration { + static class SimpleBrokerConfig extends BaseTestMessageBrokerConfig { @Bean public TestController subscriptionController() { @@ -463,17 +476,18 @@ public class MessageBrokerConfigurationTests { @Override public void configureMessageBroker(MessageBrokerRegistry registry) { - registry.enableStompBrokerRelay("/topic", "/queue").setAutoStartup(true); - registry.setUserDestinationBroadcast("/topic/unresolved"); + registry.enableStompBrokerRelay("/topic", "/queue").setAutoStartup(true) + .setUserDestinationBroadcast("/topic/unresolved-user-destination") + .setUserRegistryBroadcast("/topic/simp-user-registry"); } } @Configuration - static class DefaultConfig extends AbstractMessageBrokerConfiguration { + static class DefaultConfig extends BaseTestMessageBrokerConfig { } @Configuration - static class CustomConfig extends AbstractMessageBrokerConfiguration { + static class CustomConfig extends BaseTestMessageBrokerConfig { private ChannelInterceptor interceptor = new ChannelInterceptorAdapter() {}; diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistrationTests.java index 87dc99d3a95..b9273f2921d 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/StompBrokerRelayRegistrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,8 @@ import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; import static org.junit.Assert.*; /** - * Unit tests for {@link org.springframework.messaging.simp.config.StompBrokerRelayRegistration}. + * Unit tests for + * {@link org.springframework.messaging.simp.config.StompBrokerRelayRegistration}. * * @author Rossen Stoyanchev */ @@ -39,15 +40,11 @@ public class StompBrokerRelayRegistrationTests { @Test public void test() { - SubscribableChannel clientInboundChannel = new StubMessageChannel(); - MessageChannel clientOutboundChannel = new StubMessageChannel(); - SubscribableChannel brokerChannel = new StubMessageChannel(); - - String[] destinationPrefixes = new String[] { "/foo", "/bar" }; - - StompBrokerRelayRegistration registration = new StompBrokerRelayRegistration( - clientInboundChannel, clientOutboundChannel, destinationPrefixes); + SubscribableChannel inChannel = new StubMessageChannel(); + MessageChannel outChannel = new StubMessageChannel(); + String[] prefixes = new String[] { "/foo", "/bar" }; + StompBrokerRelayRegistration registration = new StompBrokerRelayRegistration(inChannel, outChannel, prefixes); registration.setClientLogin("clientlogin"); registration.setClientPasscode("clientpasscode"); registration.setSystemLogin("syslogin"); @@ -56,18 +53,16 @@ public class StompBrokerRelayRegistrationTests { registration.setSystemHeartbeatSendInterval(456); registration.setVirtualHost("example.org"); - StompBrokerRelayMessageHandler relayMessageHandler = registration.getMessageHandler(brokerChannel); + StompBrokerRelayMessageHandler handler = registration.getMessageHandler(new StubMessageChannel()); - assertEquals(Arrays.asList(destinationPrefixes), - new ArrayList(relayMessageHandler.getDestinationPrefixes())); - - assertEquals("clientlogin", relayMessageHandler.getClientLogin()); - assertEquals("clientpasscode", relayMessageHandler.getClientPasscode()); - assertEquals("syslogin", relayMessageHandler.getSystemLogin()); - assertEquals("syspasscode", relayMessageHandler.getSystemPasscode()); - assertEquals(123, relayMessageHandler.getSystemHeartbeatReceiveInterval()); - assertEquals(456, relayMessageHandler.getSystemHeartbeatSendInterval()); - assertEquals("example.org", relayMessageHandler.getVirtualHost()); + assertArrayEquals(prefixes, handler.getDestinationPrefixes().toArray(new String[2])); + assertEquals("clientlogin", handler.getClientLogin()); + assertEquals("clientpasscode", handler.getClientPasscode()); + assertEquals("syslogin", handler.getSystemLogin()); + assertEquals("syspasscode", handler.getSystemPasscode()); + assertEquals(123, handler.getSystemHeartbeatReceiveInterval()); + assertEquals(456, handler.getSystemHeartbeatSendInterval()); + assertEquals("example.org", handler.getVirtualHost()); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java index 455bfd4d06b..5a062f9e4f6 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java @@ -17,6 +17,9 @@ package org.springframework.messaging.simp.user; import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import java.security.Principal; import org.junit.Before; import org.junit.Test; @@ -36,35 +39,36 @@ import org.springframework.util.StringUtils; */ public class DefaultUserDestinationResolverTests { - public static final String SESSION_ID = "123"; - private DefaultUserDestinationResolver resolver; - private UserSessionRegistry registry; - - private TestPrincipal user; + private SimpUserRegistry registry; @Before public void setup() { - this.user = new TestPrincipal("joe"); - this.registry = new DefaultUserSessionRegistry(); - this.registry.registerSessionId(this.user.getName(), SESSION_ID); + + TestSimpUser simpUser = new TestSimpUser("joe"); + simpUser.addSessions(new TestSimpSession("123")); + + this.registry = mock(SimpUserRegistry.class); + when(this.registry.getUser("joe")).thenReturn(simpUser); + this.resolver = new DefaultUserDestinationResolver(this.registry); } - @Test public void handleSubscribe() { + TestPrincipal user = new TestPrincipal("joe"); String sourceDestination = "/user/queue/foo"; - Message message = createWith(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, sourceDestination); + + Message message = createMessage(SimpMessageType.SUBSCRIBE, user, "123", sourceDestination); UserDestinationResult actual = this.resolver.resolveDestination(message); assertEquals(sourceDestination, actual.getSourceDestination()); assertEquals(1, actual.getTargetDestinations().size()); assertEquals("/queue/foo-user123", actual.getTargetDestinations().iterator().next()); assertEquals(sourceDestination, actual.getSubscribeDestination()); - assertEquals(this.user.getName(), actual.getUser()); + assertEquals(user.getName(), actual.getUser()); } // SPR-11325 @@ -72,32 +76,35 @@ public class DefaultUserDestinationResolverTests { @Test public void handleSubscribeOneUserMultipleSessions() { - this.registry.registerSessionId("joe", "456"); - this.registry.registerSessionId("joe", "789"); + TestSimpUser simpUser = new TestSimpUser("joe"); + simpUser.addSessions(new TestSimpSession("123"), new TestSimpSession("456")); + when(this.registry.getUser("joe")).thenReturn(simpUser); - Message message = createWith(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo"); + TestPrincipal user = new TestPrincipal("joe"); + Message message = createMessage(SimpMessageType.SUBSCRIBE, user, "456", "/user/queue/foo"); UserDestinationResult actual = this.resolver.resolveDestination(message); assertEquals(1, actual.getTargetDestinations().size()); - assertEquals("/queue/foo-user123", actual.getTargetDestinations().iterator().next()); + assertEquals("/queue/foo-user456", actual.getTargetDestinations().iterator().next()); } @Test public void handleSubscribeNoUser() { String sourceDestination = "/user/queue/foo"; - Message message = createWith(SimpMessageType.SUBSCRIBE, null, SESSION_ID, sourceDestination); + Message message = createMessage(SimpMessageType.SUBSCRIBE, null, "123", sourceDestination); UserDestinationResult actual = this.resolver.resolveDestination(message); assertEquals(sourceDestination, actual.getSourceDestination()); assertEquals(1, actual.getTargetDestinations().size()); - assertEquals("/queue/foo-user" + SESSION_ID, actual.getTargetDestinations().iterator().next()); + assertEquals("/queue/foo-user" + "123", actual.getTargetDestinations().iterator().next()); assertEquals(sourceDestination, actual.getSubscribeDestination()); assertNull(actual.getUser()); } @Test public void handleUnsubscribe() { - Message message = createWith(SimpMessageType.UNSUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo"); + TestPrincipal user = new TestPrincipal("joe"); + Message message = createMessage(SimpMessageType.UNSUBSCRIBE, user, "123", "/user/queue/foo"); UserDestinationResult actual = this.resolver.resolveDestination(message); assertEquals(1, actual.getTargetDestinations().size()); @@ -106,32 +113,37 @@ public class DefaultUserDestinationResolverTests { @Test public void handleMessage() { + TestPrincipal user = new TestPrincipal("joe"); String sourceDestination = "/user/joe/queue/foo"; - Message message = createWith(SimpMessageType.MESSAGE, this.user, SESSION_ID, sourceDestination); + Message message = createMessage(SimpMessageType.MESSAGE, user, "123", sourceDestination); UserDestinationResult actual = this.resolver.resolveDestination(message); assertEquals(sourceDestination, actual.getSourceDestination()); assertEquals(1, actual.getTargetDestinations().size()); assertEquals("/queue/foo-user123", actual.getTargetDestinations().iterator().next()); assertEquals("/user/queue/foo", actual.getSubscribeDestination()); - assertEquals(this.user.getName(), actual.getUser()); + assertEquals(user.getName(), actual.getUser()); } // SPR-12444 + @Test public void handleMessageToOtherUser() { - final String OTHER_SESSION_ID = "456"; - final String OTHER_USER_NAME = "anna"; + + TestSimpUser otherSimpUser = new TestSimpUser("anna"); + otherSimpUser.addSessions(new TestSimpSession("456")); + when(this.registry.getUser("anna")).thenReturn(otherSimpUser); + + TestPrincipal user = new TestPrincipal("joe"); + TestPrincipal otherUser = new TestPrincipal("anna"); + String sourceDestination = "/user/anna/queue/foo"; + Message message = createMessage(SimpMessageType.MESSAGE, user, "456", sourceDestination); - String sourceDestination = "/user/"+OTHER_USER_NAME+"/queue/foo"; - TestPrincipal otherUser = new TestPrincipal(OTHER_USER_NAME); - this.registry.registerSessionId(otherUser.getName(), OTHER_SESSION_ID); - Message message = createWith(SimpMessageType.MESSAGE, this.user, SESSION_ID, sourceDestination); UserDestinationResult actual = this.resolver.resolveDestination(message); assertEquals(sourceDestination, actual.getSourceDestination()); assertEquals(1, actual.getTargetDestinations().size()); - assertEquals("/queue/foo-user" + OTHER_SESSION_ID, actual.getTargetDestinations().iterator().next()); + assertEquals("/queue/foo-user456", actual.getTargetDestinations().iterator().next()); assertEquals("/user/queue/foo", actual.getSubscribeDestination()); assertEquals(otherUser.getName(), actual.getUser()); } @@ -140,9 +152,14 @@ public class DefaultUserDestinationResolverTests { public void handleMessageEncodedUserName() { String userName = "http://joe.openid.example.org/"; - this.registry.registerSessionId(userName, "openid123"); + + TestSimpUser simpUser = new TestSimpUser(userName); + simpUser.addSessions(new TestSimpSession("openid123")); + when(this.registry.getUser(userName)).thenReturn(simpUser); + String destination = "/user/" + StringUtils.replace(userName, "/", "%2F") + "/queue/foo"; - Message message = createWith(SimpMessageType.MESSAGE, this.user, null, destination); + + Message message = createMessage(SimpMessageType.MESSAGE, new TestPrincipal("joe"), null, destination); UserDestinationResult actual = this.resolver.resolveDestination(message); assertEquals(1, actual.getTargetDestinations().size()); @@ -151,8 +168,8 @@ public class DefaultUserDestinationResolverTests { @Test public void handleMessageWithNoUser() { - String sourceDestination = "/user/" + SESSION_ID + "/queue/foo"; - Message message = createWith(SimpMessageType.MESSAGE, null, SESSION_ID, sourceDestination); + String sourceDestination = "/user/" + "123" + "/queue/foo"; + Message message = createMessage(SimpMessageType.MESSAGE, null, "123", sourceDestination); UserDestinationResult actual = this.resolver.resolveDestination(message); assertEquals(sourceDestination, actual.getSourceDestination()); @@ -166,28 +183,28 @@ public class DefaultUserDestinationResolverTests { public void ignoreMessage() { // no destination - Message message = createWith(SimpMessageType.MESSAGE, this.user, SESSION_ID, null); + TestPrincipal user = new TestPrincipal("joe"); + Message message = createMessage(SimpMessageType.MESSAGE, user, "123", null); UserDestinationResult actual = this.resolver.resolveDestination(message); assertNull(actual); // not a user destination - message = createWith(SimpMessageType.MESSAGE, this.user, SESSION_ID, "/queue/foo"); + message = createMessage(SimpMessageType.MESSAGE, user, "123", "/queue/foo"); actual = this.resolver.resolveDestination(message); assertNull(actual); // subscribe + not a user destination - message = createWith(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/queue/foo"); + message = createMessage(SimpMessageType.SUBSCRIBE, user, "123", "/queue/foo"); actual = this.resolver.resolveDestination(message); assertNull(actual); // no match on message type - message = createWith(SimpMessageType.CONNECT, this.user, SESSION_ID, "user/joe/queue/foo"); + message = createMessage(SimpMessageType.CONNECT, user, "123", "user/joe/queue/foo"); actual = this.resolver.resolveDestination(message); assertNull(actual); } - - private Message createWith(SimpMessageType type, TestPrincipal user, String sessionId, String destination) { + private Message createMessage(SimpMessageType type, Principal user, String sessionId, String destination) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type); if (destination != null) { headers.setDestination(destination); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserSessionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserSessionRegistryTests.java deleted file mode 100644 index 8d93deec6ab..00000000000 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserSessionRegistryTests.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright 2002-2015 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.messaging.simp.user; - -import static org.junit.Assert.*; - -import java.util.Arrays; -import java.util.Collections; -import java.util.LinkedHashSet; -import java.util.List; - -import org.junit.Test; - -/** - * Test fixture for - * {@link org.springframework.messaging.simp.user.DefaultUserSessionRegistry} - * - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class DefaultUserSessionRegistryTests { - - private static final String user = "joe"; - - private static final List sessionIds = Arrays.asList("sess01", "sess02", "sess03"); - - - @Test - public void addOneSessionId() { - - DefaultUserSessionRegistry resolver = new DefaultUserSessionRegistry(); - resolver.registerSessionId(user, sessionIds.get(0)); - - assertEquals(Collections.singleton(sessionIds.get(0)), resolver.getSessionIds(user)); - assertSame(Collections.emptySet(), resolver.getSessionIds("jane")); - } - - @Test - public void addMultipleSessionIds() { - - DefaultUserSessionRegistry resolver = new DefaultUserSessionRegistry(); - for (String sessionId : sessionIds) { - resolver.registerSessionId(user, sessionId); - } - - assertEquals(new LinkedHashSet<>(sessionIds), resolver.getSessionIds(user)); - assertEquals(Collections.emptySet(), resolver.getSessionIds("jane")); - } - - @Test - public void removeSessionIds() { - - DefaultUserSessionRegistry resolver = new DefaultUserSessionRegistry(); - for (String sessionId : sessionIds) { - resolver.registerSessionId(user, sessionId); - } - - assertEquals(new LinkedHashSet<>(sessionIds), resolver.getSessionIds(user)); - - resolver.unregisterSessionId(user, sessionIds.get(1)); - resolver.unregisterSessionId(user, sessionIds.get(2)); - assertEquals(Collections.singleton(sessionIds.get(0)), resolver.getSessionIds(user)); - - resolver.unregisterSessionId(user, sessionIds.get(0)); - assertSame(Collections.emptySet(), resolver.getSessionIds(user)); - } - -} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java new file mode 100644 index 00000000000..c01efafd1be --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java @@ -0,0 +1,167 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +import org.springframework.messaging.Message; +import org.springframework.messaging.converter.MappingJackson2MessageConverter; +import org.springframework.messaging.converter.MessageConverter; + +/** + * Unit tests for {@link MultiServerUserRegistry}. + * + * @author Rossen Stoyanchev + */ +public class MultiServerUserRegistryTests { + + private SimpUserRegistry localRegistry; + + private MultiServerUserRegistry multiServerRegistry; + + private MessageConverter converter; + + + @Before + public void setUp() throws Exception { + this.localRegistry = Mockito.mock(SimpUserRegistry.class); + this.multiServerRegistry = new MultiServerUserRegistry(this.localRegistry); + this.converter = new MappingJackson2MessageConverter(); + } + + @Test + public void getUserFromLocalRegistry() throws Exception { + + SimpUser user = Mockito.mock(SimpUser.class); + Set users = Collections.singleton(user); + when(this.localRegistry.getUsers()).thenReturn(users); + when(this.localRegistry.getUser("joe")).thenReturn(user); + + assertEquals(1, this.multiServerRegistry.getUsers().size()); + assertSame(user, this.multiServerRegistry.getUser("joe")); + } + + @Test + public void getUserFromRemoteRegistry() throws Exception { + + TestSimpSession remoteSession = new TestSimpSession("remote-sess"); + remoteSession.addSubscriptions(new TestSimpSubscription("remote-sub", "/remote-dest")); + TestSimpUser remoteUser = new TestSimpUser("joe"); + remoteUser.addSessions(remoteSession); + SimpUserRegistry remoteUserRegistry = mock(SimpUserRegistry.class); + when(remoteUserRegistry.getUsers()).thenReturn(Collections.singleton(remoteUser)); + + MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(remoteUserRegistry); + Message message = this.converter.toMessage(remoteRegistry.getLocalRegistryDto(), null); + + this.multiServerRegistry.addRemoteRegistryDto(message, this.converter, 20000); + assertEquals(1, this.multiServerRegistry.getUsers().size()); + + SimpUser user = this.multiServerRegistry.getUser("joe"); + assertNotNull(user); + assertEquals(1, user.getSessions().size()); + + SimpSession session = user.getSession("remote-sess"); + assertNotNull(session); + assertEquals("remote-sess", session.getId()); + assertSame(user, session.getUser()); + assertEquals(1, session.getSubscriptions().size()); + + SimpSubscription subscription = session.getSubscriptions().iterator().next(); + assertEquals("remote-sub", subscription.getId()); + assertSame(session, subscription.getSession()); + assertEquals("/remote-dest", subscription.getDestination()); + } + + @Test + public void findUserFromRemoteRegistry() throws Exception { + + TestSimpSubscription subscription1 = new TestSimpSubscription("sub1", "/match"); + TestSimpSession session1 = new TestSimpSession("sess1"); + session1.addSubscriptions(subscription1); + TestSimpUser user1 = new TestSimpUser("joe"); + user1.addSessions(session1); + + TestSimpSubscription subscription2 = new TestSimpSubscription("sub1", "/match"); + TestSimpSession session2 = new TestSimpSession("sess2"); + session2.addSubscriptions(subscription2); + TestSimpUser user2 = new TestSimpUser("jane"); + user2.addSessions(session2); + + TestSimpSubscription subscription3 = new TestSimpSubscription("sub1", "/not-a-match"); + TestSimpSession session3 = new TestSimpSession("sess3"); + session3.addSubscriptions(subscription3); + TestSimpUser user3 = new TestSimpUser("jack"); + user3.addSessions(session3); + + SimpUserRegistry remoteUserRegistry = mock(SimpUserRegistry.class); + when(remoteUserRegistry.getUsers()).thenReturn(new HashSet(Arrays.asList(user1, user2, user3))); + + MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(remoteUserRegistry); + Message message = this.converter.toMessage(remoteRegistry.getLocalRegistryDto(), null); + + this.multiServerRegistry.addRemoteRegistryDto(message, this.converter, 20000); + assertEquals(3, this.multiServerRegistry.getUsers().size()); + + Set matches = this.multiServerRegistry.findSubscriptions(new SimpSubscriptionMatcher() { + @Override + public boolean match(SimpSubscription subscription) { + return subscription.getDestination().equals("/match"); + } + }); + + assertEquals(2, matches.size()); + + Iterator iterator = matches.iterator(); + Set sessionIds = new HashSet<>(2); + sessionIds.add(iterator.next().getSession().getId()); + sessionIds.add(iterator.next().getSession().getId()); + assertEquals(new HashSet<>(Arrays.asList("sess1", "sess2")), sessionIds); + } + + @Test + public void purgeExpiredRegistries() throws Exception { + + TestSimpUser remoteUser = new TestSimpUser("joe"); + remoteUser.addSessions(new TestSimpSession("remote-sub")); + SimpUserRegistry remoteUserRegistry = mock(SimpUserRegistry.class); + when(remoteUserRegistry.getUsers()).thenReturn(Collections.singleton(remoteUser)); + + MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(remoteUserRegistry); + Message message = this.converter.toMessage(remoteRegistry.getLocalRegistryDto(), null); + + long expirationPeriod = -1; + this.multiServerRegistry.addRemoteRegistryDto(message, this.converter, expirationPeriod); + assertEquals(1, this.multiServerRegistry.getUsers().size()); + + this.multiServerRegistry.purgeExpiredRegistries(); + assertEquals(0, this.multiServerRegistry.getUsers().size()); + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpSession.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpSession.java new file mode 100644 index 00000000000..d1be4a19543 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpSession.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + + +public class TestSimpSession implements SimpSession { + + private String id; + + private TestSimpUser user; + + private Set subscriptions = new HashSet<>(); + + + public TestSimpSession(String id) { + this.id = id; + } + + @Override + public String getId() { + return id; + } + + @Override + public TestSimpUser getUser() { + return user; + } + + public void setUser(TestSimpUser user) { + this.user = user; + } + + @Override + public Set getSubscriptions() { + return subscriptions; + } + + public void addSubscriptions(TestSimpSubscription... subscriptions) { + for (TestSimpSubscription subscription : subscriptions) { + subscription.setSession(this); + this.subscriptions.add(subscription); + } + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpSubscription.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpSubscription.java new file mode 100644 index 00000000000..28f92823ce7 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpSubscription.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + + +public class TestSimpSubscription implements SimpSubscription { + + private String id; + + private TestSimpSession session; + + private String destination; + + + public TestSimpSubscription(String id, String destination) { + this.destination = destination; + this.id = id; + } + + @Override + public String getId() { + return id; + } + + @Override + public TestSimpSession getSession() { + return this.session; + } + + public void setSession(TestSimpSession session) { + this.session = session; + } + + @Override + public String getDestination() { + return destination; + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpUser.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpUser.java new file mode 100644 index 00000000000..19d2099b2d9 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/TestSimpUser.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + + +public class TestSimpUser implements SimpUser { + + private String name; + + private Map sessions = new HashMap<>(); + + + public TestSimpUser(String name) { + this.name = name; + } + + @Override + public String getName() { + return name; + } + + @Override + public Set getSessions() { + return new HashSet<>(this.sessions.values()); + } + + @Override + public boolean hasSessions() { + return !this.sessions.isEmpty(); + } + + @Override + public SimpSession getSession(String sessionId) { + return this.sessions.get(sessionId); + } + + public void addSessions(TestSimpSession... sessions) { + for (TestSimpSession session : sessions) { + session.setUser(this); + this.sessions.put(session.getId(), session); + } + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java index 785aee50fc3..20a3e975f60 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java @@ -25,9 +25,7 @@ import java.nio.charset.Charset; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; -import org.mockito.Mock; import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; import org.springframework.messaging.Message; import org.springframework.messaging.StubMessageChannel; @@ -50,16 +48,15 @@ public class UserDestinationMessageHandlerTests { private UserDestinationMessageHandler handler; - private UserSessionRegistry registry; + private SimpUserRegistry registry; - @Mock private SubscribableChannel brokerChannel; @Before public void setup() { - MockitoAnnotations.initMocks(this); - this.registry = new DefaultUserSessionRegistry(); + this.registry = mock(SimpUserRegistry.class); + this.brokerChannel = mock(SubscribableChannel.class); UserDestinationResolver resolver = new DefaultUserDestinationResolver(this.registry); this.handler = new UserDestinationMessageHandler(new StubMessageChannel(), this.brokerChannel, resolver); } @@ -91,7 +88,9 @@ public class UserDestinationMessageHandlerTests { @Test public void handleMessage() { - this.registry.registerSessionId("joe", "123"); + TestSimpUser simpUser = new TestSimpUser("joe"); + simpUser.addSessions(new TestSimpSession("123")); + when(this.registry.getUser("joe")).thenReturn(simpUser); given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true); this.handler.handleMessage(createWith(SimpMessageType.MESSAGE, "joe", "123", "/user/joe/queue/foo")); @@ -105,7 +104,7 @@ public class UserDestinationMessageHandlerTests { @Test public void handleMessageWithoutActiveSession() { - this.handler.setUserDestinationBroadcast("/topic/unresolved"); + this.handler.setBroadcastDestination("/topic/unresolved"); given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true); this.handler.handleMessage(createWith(SimpMessageType.MESSAGE, "joe", "123", "/user/joe/queue/foo")); @@ -126,9 +125,11 @@ public class UserDestinationMessageHandlerTests { @Test public void handleMessageFromBrokerWithActiveSession() { - this.registry.registerSessionId("joe", "123"); + TestSimpUser simpUser = new TestSimpUser("joe"); + simpUser.addSessions(new TestSimpSession("123")); + when(this.registry.getUser("joe")).thenReturn(simpUser); - this.handler.setUserDestinationBroadcast("/topic/unresolved"); + this.handler.setBroadcastDestination("/topic/unresolved"); given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true); StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.MESSAGE); @@ -152,7 +153,7 @@ public class UserDestinationMessageHandlerTests { @Test public void handleMessageFromBrokerWithoutActiveSession() { - this.handler.setUserDestinationBroadcast("/topic/unresolved"); + this.handler.setBroadcastDestination("/topic/unresolved"); given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true); StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.MESSAGE); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java new file mode 100644 index 00000000000..223abf0908e --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java @@ -0,0 +1,183 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.simp.user; + +import static org.junit.Assert.*; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.*; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.concurrent.ScheduledFuture; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.converter.MappingJackson2MessageConverter; +import org.springframework.messaging.converter.MessageConverter; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.messaging.simp.broker.BrokerAvailabilityEvent; +import org.springframework.scheduling.TaskScheduler; + +/** + * User tests for {@link UserRegistryMessageHandler}. + * @author Rossen Stoyanchev + */ +public class UserRegistryMessageHandlerTests { + + private UserRegistryMessageHandler handler; + + private SimpUserRegistry localRegistry; + + private MultiServerUserRegistry multiServerRegistry; + + private MessageConverter converter; + + @Mock + private MessageChannel brokerChannel; + + @Mock + private TaskScheduler taskScheduler; + + + @Before + public void setUp() throws Exception { + + MockitoAnnotations.initMocks(this); + + when(this.brokerChannel.send(any())).thenReturn(true); + this.converter = new MappingJackson2MessageConverter(); + + SimpMessagingTemplate brokerTemplate = new SimpMessagingTemplate(this.brokerChannel); + brokerTemplate.setMessageConverter(this.converter); + + this.localRegistry = mock(SimpUserRegistry.class); + this.multiServerRegistry = new MultiServerUserRegistry(this.localRegistry); + + this.handler = new UserRegistryMessageHandler(this.multiServerRegistry, brokerTemplate, + "/topic/simp-user-registry", this.taskScheduler); + } + + @Test + public void brokerAvailableEvent() throws Exception { + Runnable runnable = getUserRegistryTask(); + assertNotNull(runnable); + } + + @SuppressWarnings("unchecked") + @Test + public void brokerUnavailableEvent() throws Exception { + + ScheduledFuture future = Mockito.mock(ScheduledFuture.class); + when(this.taskScheduler.scheduleWithFixedDelay(any(Runnable.class), any(Long.class))).thenReturn(future); + + BrokerAvailabilityEvent event = new BrokerAvailabilityEvent(true, this); + this.handler.onApplicationEvent(event); + verifyNoMoreInteractions(future); + + event = new BrokerAvailabilityEvent(false, this); + this.handler.onApplicationEvent(event); + verify(future).cancel(true); + } + + @Test + public void broadcastRegistry() throws Exception { + + TestSimpUser simpUser1 = new TestSimpUser("joe"); + TestSimpUser simpUser2 = new TestSimpUser("jane"); + + simpUser1.addSessions(new TestSimpSession("123")); + simpUser1.addSessions(new TestSimpSession("456")); + + HashSet simpUsers = new HashSet<>(Arrays.asList(simpUser1, simpUser2)); + when(this.localRegistry.getUsers()).thenReturn(simpUsers); + + getUserRegistryTask().run(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); + verify(this.brokerChannel).send(captor.capture()); + + Message message = captor.getValue(); + assertNotNull(message); + MessageHeaders headers = message.getHeaders(); + assertEquals("/topic/simp-user-registry", SimpMessageHeaderAccessor.getDestination(headers)); + + MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(mock(SimpUserRegistry.class)); + remoteRegistry.addRemoteRegistryDto(message, this.converter, 20000); + assertEquals(2, remoteRegistry.getUsers().size()); + assertNotNull(remoteRegistry.getUser("joe")); + assertNotNull(remoteRegistry.getUser("jane")); + } + + @Test + public void handleMessage() throws Exception { + + TestSimpUser simpUser1 = new TestSimpUser("joe"); + TestSimpUser simpUser2 = new TestSimpUser("jane"); + + simpUser1.addSessions(new TestSimpSession("123")); + simpUser2.addSessions(new TestSimpSession("456")); + + HashSet simpUsers = new HashSet<>(Arrays.asList(simpUser1, simpUser2)); + SimpUserRegistry remoteUserRegistry = mock(SimpUserRegistry.class); + when(remoteUserRegistry.getUsers()).thenReturn(simpUsers); + + MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(remoteUserRegistry); + Message message = this.converter.toMessage(remoteRegistry.getLocalRegistryDto(), null); + + this.handler.handleMessage(message); + + assertEquals(2, remoteRegistry.getUsers().size()); + assertNotNull(this.multiServerRegistry.getUser("joe")); + assertNotNull(this.multiServerRegistry.getUser("jane")); + } + + @Test + public void handleMessageFromOwnBroadcast() throws Exception { + + TestSimpUser simpUser = new TestSimpUser("joe"); + simpUser.addSessions(new TestSimpSession("123")); + when(this.localRegistry.getUsers()).thenReturn(Collections.singleton(simpUser)); + + assertEquals(1, this.multiServerRegistry.getUsers().size()); + + Message message = this.converter.toMessage(this.multiServerRegistry.getLocalRegistryDto(), null); + this.multiServerRegistry.addRemoteRegistryDto(message, this.converter, 20000); + assertEquals(1, this.multiServerRegistry.getUsers().size()); + } + + + private Runnable getUserRegistryTask() { + BrokerAvailabilityEvent event = new BrokerAvailabilityEvent(true, this); + this.handler.onApplicationEvent(event); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(this.taskScheduler).scheduleWithFixedDelay(captor.capture(), eq(10000L)); + + return captor.getValue(); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java index 1b69b4c75dc..1982e8b85b1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java @@ -48,8 +48,9 @@ import org.springframework.messaging.simp.SimpSessionScope; import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler; import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; import org.springframework.messaging.simp.user.DefaultUserDestinationResolver; -import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; +import org.springframework.messaging.simp.user.MultiServerUserRegistry; import org.springframework.messaging.simp.user.UserDestinationMessageHandler; +import org.springframework.messaging.simp.user.UserRegistryMessageHandler; import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; @@ -61,6 +62,7 @@ import org.springframework.util.xml.DomUtils; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory; +import org.springframework.web.socket.messaging.DefaultSimpUserRegistry; import org.springframework.web.socket.messaging.StompSubProtocolHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; import org.springframework.web.socket.messaging.WebSocketAnnotationMethodMessageHandler; @@ -98,6 +100,8 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { public static final String WEB_SOCKET_HANDLER_BEAN_NAME = "subProtocolWebSocketHandler"; + public static final String SCHEDULER_BEAN_NAME = "messageBrokerScheduler"; + public static final String SOCKJS_SCHEDULER_BEAN_NAME = "messageBrokerSockJsScheduler"; private static final int DEFAULT_MAPPING_ORDER = 1; @@ -108,10 +112,82 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { @Override public BeanDefinition parse(Element element, ParserContext context) { + Object source = context.extractSource(element); CompositeComponentDefinition compDefinition = new CompositeComponentDefinition(element.getTagName(), source); context.pushContainingComponent(compDefinition); + Element channelElem = DomUtils.getChildElementByTagName(element, "client-inbound-channel"); + RuntimeBeanReference inChannel = getMessageChannel("clientInboundChannel", channelElem, context, source); + + channelElem = DomUtils.getChildElementByTagName(element, "client-outbound-channel"); + RuntimeBeanReference outChannel = getMessageChannel("clientOutboundChannel", channelElem, context, source); + + channelElem = DomUtils.getChildElementByTagName(element, "broker-channel"); + RuntimeBeanReference brokerChannel = getMessageChannel("brokerChannel", channelElem, context, source); + + RuntimeBeanReference userRegistry = registerUserRegistry(element, context, source); + Object userDestHandler = registerUserDestHandler(element, userRegistry, inChannel, brokerChannel, context, source); + + RuntimeBeanReference converter = registerMessageConverter(element, context, source); + RuntimeBeanReference template = registerMessagingTemplate(element, brokerChannel, converter, context, source); + registerAnnotationMethodMessageHandler(element, inChannel, outChannel,converter, template, context, source); + + RootBeanDefinition broker = registerMessageBroker(element, inChannel, outChannel, brokerChannel, + userDestHandler, template, userRegistry, context, source); + + // WebSocket and sub-protocol handling + + ManagedMap urlMap = registerHandlerMapping(element, context, source); + RuntimeBeanReference stompHandler = registerStompHandler(element, inChannel, outChannel, context, source); + for (Element endpointElem : DomUtils.getChildElementsByTagName(element, "stomp-endpoint")) { + RuntimeBeanReference requestHandler = registerRequestHandler(endpointElem, stompHandler, context, source); + String pathAttribute = endpointElem.getAttribute("path"); + Assert.state(StringUtils.hasText(pathAttribute), "Invalid (no path mapping)"); + List paths = Arrays.asList(StringUtils.tokenizeToStringArray(pathAttribute, ",")); + for (String path : paths) { + path = path.trim(); + Assert.state(StringUtils.hasText(path), "Invalid path attribute: " + pathAttribute); + if (DomUtils.getChildElementByTagName(endpointElem, "sockjs") != null) { + path = path.endsWith("/") ? path + "**" : path + "/**"; + } + urlMap.put(path, requestHandler); + } + } + + Map scopeMap = Collections.singletonMap("websocket", new SimpSessionScope()); + RootBeanDefinition scopeConfigurer = new RootBeanDefinition(CustomScopeConfigurer.class); + scopeConfigurer.getPropertyValues().add("scopes", scopeMap); + registerBeanDefByName("webSocketScopeConfigurer", scopeConfigurer, context, source); + + registerWebSocketMessageBrokerStats(broker, inChannel, outChannel, context, source); + + context.popAndRegisterContainingComponent(); + return null; + } + + private RuntimeBeanReference registerUserRegistry(Element element, ParserContext context, Object source) { + + Element relayElement = DomUtils.getChildElementByTagName(element, "stomp-broker-relay"); + boolean multiServer = (relayElement != null && relayElement.hasAttribute("user-registry-broadcast")); + + if (multiServer) { + RootBeanDefinition localRegistryBeanDef = new RootBeanDefinition(DefaultSimpUserRegistry.class); + RootBeanDefinition beanDef = new RootBeanDefinition(MultiServerUserRegistry.class); + beanDef.getConstructorArgumentValues().addIndexedArgumentValue(0, localRegistryBeanDef); + String beanName = registerBeanDef(beanDef, context, source); + return new RuntimeBeanReference(beanName); + } + else { + RootBeanDefinition beanDef = new RootBeanDefinition(DefaultSimpUserRegistry.class); + String beanName = registerBeanDef(beanDef, context, source); + return new RuntimeBeanReference(beanName); + } + } + + private ManagedMap registerHandlerMapping(Element element, + ParserContext context, Object source) { + RootBeanDefinition handlerMappingDef = new RootBeanDefinition(SimpleUrlHandlerMapping.class); String orderAttribute = element.getAttribute("order"); @@ -128,58 +204,7 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { handlerMappingDef.getPropertyValues().add("urlMap", urlMap); registerBeanDef(handlerMappingDef, context, source); - - Element channelElem = DomUtils.getChildElementByTagName(element, "client-inbound-channel"); - RuntimeBeanReference inChannel = getMessageChannel("clientInboundChannel", channelElem, context, source); - - channelElem = DomUtils.getChildElementByTagName(element, "client-outbound-channel"); - RuntimeBeanReference outChannel = getMessageChannel("clientOutboundChannel", channelElem, context, source); - - RootBeanDefinition registryBeanDef = new RootBeanDefinition(DefaultUserSessionRegistry.class); - String registryBeanName = registerBeanDef(registryBeanDef, context, source); - RuntimeBeanReference sessionRegistry = new RuntimeBeanReference(registryBeanName); - - RuntimeBeanReference subProtoHandler = registerSubProtoHandler(element, inChannel, outChannel, - sessionRegistry, context, source); - - for (Element endpointElem : DomUtils.getChildElementsByTagName(element, "stomp-endpoint")) { - RuntimeBeanReference requestHandler = registerRequestHandler(endpointElem, subProtoHandler, context, source); - String pathAttribute = endpointElem.getAttribute("path"); - Assert.state(StringUtils.hasText(pathAttribute), "Invalid (no path mapping)"); - List paths = Arrays.asList(StringUtils.tokenizeToStringArray(pathAttribute, ",")); - for (String path : paths) { - path = path.trim(); - Assert.state(StringUtils.hasText(path), "Invalid path attribute: " + pathAttribute); - if (DomUtils.getChildElementByTagName(endpointElem, "sockjs") != null) { - path = path.endsWith("/") ? path + "**" : path + "/**"; - } - urlMap.put(path, requestHandler); - } - } - - channelElem = DomUtils.getChildElementByTagName(element, "broker-channel"); - RuntimeBeanReference brokerChannel = getMessageChannel("brokerChannel", channelElem, context, source); - - RuntimeBeanReference resolver = registerUserDestResolver(element, sessionRegistry, context, source); - RuntimeBeanReference userDestHandler = registerUserDestHandler(element, inChannel, - brokerChannel, resolver, context, source); - - RootBeanDefinition broker = registerMessageBroker(element, userDestHandler, inChannel, - outChannel, brokerChannel, context, source); - - RuntimeBeanReference converter = registerMessageConverter(element, context, source); - RuntimeBeanReference template = registerMessagingTemplate(element, brokerChannel, converter, context, source); - registerAnnotationMethodMessageHandler(element, inChannel, outChannel,converter, template, context, source); - - Map scopeMap = Collections.singletonMap("websocket", new SimpSessionScope()); - RootBeanDefinition scopeConfigurer = new RootBeanDefinition(CustomScopeConfigurer.class); - scopeConfigurer.getPropertyValues().add("scopes", scopeMap); - registerBeanDefByName("webSocketScopeConfigurer", scopeConfigurer, context, source); - - registerWebSocketMessageBrokerStats(broker, inChannel, outChannel, context, source); - - context.popAndRegisterContainingComponent(); - return null; + return urlMap; } private RuntimeBeanReference getMessageChannel(String name, Element element, ParserContext context, Object source) { @@ -240,11 +265,10 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { return executorDef; } - private RuntimeBeanReference registerSubProtoHandler(Element element, RuntimeBeanReference inChannel, - RuntimeBeanReference outChannel, RuntimeBeanReference registry, ParserContext context, Object source) { + private RuntimeBeanReference registerStompHandler(Element element, RuntimeBeanReference inChannel, + RuntimeBeanReference outChannel, ParserContext context, Object source) { RootBeanDefinition stompHandlerDef = new RootBeanDefinition(StompSubProtocolHandler.class); - stompHandlerDef.getPropertyValues().add("userSessionRegistry", registry); registerBeanDef(stompHandlerDef, context, source); ConstructorArgumentValues cavs = new ConstructorArgumentValues(); @@ -285,13 +309,16 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { RootBeanDefinition beanDef; RuntimeBeanReference sockJsService = WebSocketNamespaceUtils.registerSockJsService( - element, SOCKJS_SCHEDULER_BEAN_NAME, context, source); + element, SCHEDULER_BEAN_NAME, context, source); if (sockJsService != null) { ConstructorArgumentValues cavs = new ConstructorArgumentValues(); cavs.addIndexedArgumentValue(0, sockJsService); cavs.addIndexedArgumentValue(1, subProtoHandler); beanDef = new RootBeanDefinition(SockJsHttpRequestHandler.class, cavs, null); + + // Register alias for backwards compatibility with 4.1 + context.getRegistry().registerAlias(SCHEDULER_BEAN_NAME, SOCKJS_SCHEDULER_BEAN_NAME); } else { RuntimeBeanReference handshakeHandler = WebSocketNamespaceUtils.registerHandshakeHandler(element, context, source); @@ -312,9 +339,9 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { } private RootBeanDefinition registerMessageBroker(Element brokerElement, - RuntimeBeanReference userDestHandler, RuntimeBeanReference inChannel, - RuntimeBeanReference outChannel, RuntimeBeanReference brokerChannel, - ParserContext context, Object source) { + RuntimeBeanReference inChannel, RuntimeBeanReference outChannel, RuntimeBeanReference brokerChannel, + Object userDestHandler, RuntimeBeanReference brokerTemplate, + RuntimeBeanReference userRegistry, ParserContext context, Object source) { Element simpleBrokerElem = DomUtils.getChildElementByTagName(brokerElement, "simple-broker"); Element brokerRelayElem = DomUtils.getChildElementByTagName(brokerElement, "stomp-broker-relay"); @@ -374,11 +401,18 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { if (brokerRelayElem.hasAttribute("virtual-host")) { values.add("virtualHost", brokerRelayElem.getAttribute("virtual-host")); } - if (brokerElement.hasAttribute("user-destination-broadcast")) { - String destination = brokerElement.getAttribute("user-destination-broadcast"); - ManagedMap map = new ManagedMap(); - map.setSource(source); + ManagedMap map = new ManagedMap(); + map.setSource(source); + if (brokerRelayElem.hasAttribute("user-destination-broadcast")) { + String destination = brokerRelayElem.getAttribute("user-destination-broadcast"); map.put(destination, userDestHandler); + } + if (brokerRelayElem.hasAttribute("user-registry-broadcast")) { + String destination = brokerRelayElem.getAttribute("user-registry-broadcast"); + map.put(destination, registerUserRegistryMessageHandler(userRegistry, + brokerTemplate, destination, context, source)); + } + if (!map.isEmpty()) { values.add("systemSubscriptions", map); } Class handlerType = StompBrokerRelayMessageHandler.class; @@ -392,6 +426,22 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { return brokerDef; } + private RuntimeBeanReference registerUserRegistryMessageHandler( + RuntimeBeanReference userRegistry, RuntimeBeanReference brokerTemplate, + String destination, ParserContext context, Object source) { + + Object scheduler = WebSocketNamespaceUtils.registerScheduler(SCHEDULER_BEAN_NAME, context, source); + + RootBeanDefinition beanDef = new RootBeanDefinition(UserRegistryMessageHandler.class); + beanDef.getConstructorArgumentValues().addIndexedArgumentValue(0, userRegistry); + beanDef.getConstructorArgumentValues().addIndexedArgumentValue(1, brokerTemplate); + beanDef.getConstructorArgumentValues().addIndexedArgumentValue(2, destination); + beanDef.getConstructorArgumentValues().addIndexedArgumentValue(3, scheduler); + + String beanName = registerBeanDef(beanDef, context, source); + return new RuntimeBeanReference(beanName); + } + private RuntimeBeanReference registerMessageConverter(Element element, ParserContext context, Object source) { Element convertersElement = DomUtils.getChildElementByTagName(element, "message-converters"); ManagedList converters = new ManagedList(); @@ -484,11 +534,10 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { } private RuntimeBeanReference registerUserDestResolver(Element brokerElem, - RuntimeBeanReference userSessionRegistry, ParserContext context, Object source) { + RuntimeBeanReference userRegistry, ParserContext context, Object source) { - ConstructorArgumentValues cavs = new ConstructorArgumentValues(); - cavs.addIndexedArgumentValue(0, userSessionRegistry); - RootBeanDefinition beanDef = new RootBeanDefinition(DefaultUserDestinationResolver.class, cavs, null); + RootBeanDefinition beanDef = new RootBeanDefinition(DefaultUserDestinationResolver.class); + beanDef.getConstructorArgumentValues().addIndexedArgumentValue(0, userRegistry); if (brokerElem.hasAttribute("user-destination-prefix")) { beanDef.getPropertyValues().add("userDestinationPrefix", brokerElem.getAttribute("user-destination-prefix")); } @@ -496,19 +545,24 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { } private RuntimeBeanReference registerUserDestHandler(Element brokerElem, - RuntimeBeanReference inChannel, RuntimeBeanReference brokerChannel, - RuntimeBeanReference userDestinationResolver, ParserContext context, Object source) { + RuntimeBeanReference userRegistry, RuntimeBeanReference inChannel, + RuntimeBeanReference brokerChannel, ParserContext context, Object source) { - ConstructorArgumentValues cavs = new ConstructorArgumentValues(); - cavs.addIndexedArgumentValue(0, inChannel); - cavs.addIndexedArgumentValue(1, brokerChannel); - cavs.addIndexedArgumentValue(2, userDestinationResolver); - RootBeanDefinition beanDef = new RootBeanDefinition(UserDestinationMessageHandler.class, cavs, null); - if (brokerElem.hasAttribute("user-destination-broadcast")) { - String destination = brokerElem.getAttribute("user-destination-broadcast"); - beanDef.getPropertyValues().add("userDestinationBroadcast", destination); + Object userDestResolver = registerUserDestResolver(brokerElem, userRegistry, context, source); + + RootBeanDefinition beanDef = new RootBeanDefinition(UserDestinationMessageHandler.class); + beanDef.getConstructorArgumentValues().addIndexedArgumentValue(0, inChannel); + beanDef.getConstructorArgumentValues().addIndexedArgumentValue(1, brokerChannel); + beanDef.getConstructorArgumentValues().addIndexedArgumentValue(2, userDestResolver); + + Element relayElement = DomUtils.getChildElementByTagName(brokerElem, "stomp-broker-relay"); + if (relayElement != null && relayElement.hasAttribute("user-destination-broadcast")) { + String destination = relayElement.getAttribute("user-destination-broadcast"); + beanDef.getPropertyValues().add("broadcastDestination", destination); } - return new RuntimeBeanReference(registerBeanDef(beanDef, context, source)); + + String beanName = registerBeanDef(beanDef, context, source); + return new RuntimeBeanReference(beanName); } private void registerWebSocketMessageBrokerStats(RootBeanDefinition broker, RuntimeBeanReference inChannel, @@ -530,7 +584,7 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { if (context.getRegistry().containsBeanDefinition(name)) { beanDef.getPropertyValues().add("outboundChannelExecutor", context.getRegistry().getBeanDefinition(name)); } - name = SOCKJS_SCHEDULER_BEAN_NAME; + name = SCHEDULER_BEAN_NAME; if (context.getRegistry().containsBeanDefinition(name)) { beanDef.getPropertyValues().add("sockJsTaskScheduler", context.getRegistry().getBeanDefinition(name)); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java index ed93ccaa2af..8f931cf1546 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java @@ -62,7 +62,7 @@ class WebSocketNamespaceUtils { return handlerRef; } - public static RuntimeBeanReference registerSockJsService(Element element, String sockJsSchedulerName, + public static RuntimeBeanReference registerSockJsService(Element element, String schedulerName, ParserContext context, Object source) { Element sockJsElement = DomUtils.getChildElementByTagName(element, "sockjs"); @@ -79,7 +79,7 @@ class WebSocketNamespaceUtils { scheduler = new RuntimeBeanReference(customTaskSchedulerName); } else { - scheduler = registerSockJsScheduler(sockJsSchedulerName, context, source); + scheduler = registerScheduler(schedulerName, context, source); } sockJsServiceDef.getConstructorArgumentValues().addIndexedArgumentValue(0, scheduler); @@ -156,7 +156,7 @@ class WebSocketNamespaceUtils { return null; } - private static RuntimeBeanReference registerSockJsScheduler(String schedulerName, ParserContext context, Object source) { + public static RuntimeBeanReference registerScheduler(String schedulerName, ParserContext context, Object source) { if (!context.getRegistry().containsBeanDefinition(schedulerName)) { RootBeanDefinition taskSchedulerDef = new RootBeanDefinition(ThreadPoolTaskScheduler.class); taskSchedulerDef.setSource(source); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java index 22c9f724949..74038ae1373 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java @@ -22,7 +22,6 @@ import java.util.List; import java.util.Map; import org.springframework.context.ApplicationContext; -import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.scheduling.TaskScheduler; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; @@ -63,11 +62,10 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry { public WebMvcStompEndpointRegistry(WebSocketHandler webSocketHandler, WebSocketTransportRegistration transportRegistration, - UserSessionRegistry userSessionRegistry, TaskScheduler defaultSockJsTaskScheduler) { + TaskScheduler defaultSockJsTaskScheduler) { Assert.notNull(webSocketHandler, "'webSocketHandler' is required "); Assert.notNull(transportRegistration, "'transportRegistration' is required"); - Assert.notNull(userSessionRegistry, "'userSessionRegistry' is required"); this.webSocketHandler = webSocketHandler; this.subProtocolWebSocketHandler = unwrapSubProtocolWebSocketHandler(webSocketHandler); @@ -80,7 +78,6 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry { } this.stompHandler = new StompSubProtocolHandler(); - this.stompHandler.setUserSessionRegistry(userSessionRegistry); if (transportRegistration.getMessageSizeLimit() != null) { this.stompHandler.setMessageSizeLimit(transportRegistration.getMessageSizeLimit()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java index 0afefe9b504..2421d5f5609 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java @@ -25,11 +25,14 @@ import org.springframework.messaging.simp.annotation.support.SimpAnnotationMetho import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler; import org.springframework.messaging.simp.config.AbstractMessageBrokerConfiguration; import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; +import org.springframework.messaging.simp.user.SimpUserRegistry; +import org.springframework.messaging.simp.user.UserSessionRegistryAdapter; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.config.WebSocketMessageBrokerStats; import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory; +import org.springframework.web.socket.messaging.DefaultSimpUserRegistry; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; import org.springframework.web.socket.messaging.WebSocketAnnotationMethodMessageHandler; @@ -58,10 +61,10 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac @Bean public HandlerMapping stompWebSocketHandlerMapping() { - WebSocketHandler handler = subProtocolWebSocketHandler(); - handler = decorateWebSocketHandler(handler); - WebMvcStompEndpointRegistry registry = new WebMvcStompEndpointRegistry(handler, - getTransportRegistration(), userSessionRegistry(), messageBrokerSockJsTaskScheduler()); + WebSocketHandler handler = decorateWebSocketHandler(subProtocolWebSocketHandler()); + WebSocketTransportRegistration transport = getTransportRegistration(); + ThreadPoolTaskScheduler scheduler = messageBrokerTaskScheduler(); + WebMvcStompEndpointRegistry registry = new WebMvcStompEndpointRegistry(handler, transport, scheduler); registry.setApplicationContext(getApplicationContext()); registerStompEndpoints(registry); return registry.getHandlerMapping(); @@ -90,33 +93,21 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac protected void configureWebSocketTransport(WebSocketTransportRegistration registry) { } - protected abstract void registerStompEndpoints(StompEndpointRegistry registry); - - /** - * The default TaskScheduler to use if none is configured via - * {@link SockJsServiceRegistration#setTaskScheduler(org.springframework.scheduling.TaskScheduler)}, i.e. - *
-	 * @Configuration
-	 * @EnableWebSocketMessageBroker
-	 * public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
-	 *
-	 *   public void registerStompEndpoints(StompEndpointRegistry registry) {
-	 *     registry.addEndpoint("/stomp").withSockJS().setTaskScheduler(myScheduler());
-	 *   }
-	 *
-	 *   // ...
-	 * }
-	 * 
- */ - @Bean - public ThreadPoolTaskScheduler messageBrokerSockJsTaskScheduler() { - ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); - scheduler.setThreadNamePrefix("MessageBrokerSockJS-"); - scheduler.setPoolSize(Runtime.getRuntime().availableProcessors()); - scheduler.setRemoveOnCancelPolicy(true); - return scheduler; + @Override + @SuppressWarnings("deprecation") + protected SimpUserRegistry createLocalUserRegistry() { + org.springframework.messaging.simp.user.UserSessionRegistry sessionRegistry = userSessionRegistry(); + if (sessionRegistry == null) { + return new DefaultSimpUserRegistry(); + } + else { + return (userSessionRegistry() instanceof SimpUserRegistry ? + (SimpUserRegistry) userSessionRegistry() : new UserSessionRegistryAdapter(sessionRegistry)); + } } + protected abstract void registerStompEndpoints(StompEndpointRegistry registry); + @Bean public static CustomScopeConfigurer webSocketScopeConfigurer() { CustomScopeConfigurer configurer = new CustomScopeConfigurer(); @@ -138,7 +129,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac stats.setStompBrokerRelay(brokerRelay); stats.setInboundChannelExecutor(clientInboundChannelExecutor()); stats.setOutboundChannelExecutor(clientOutboundChannelExecutor()); - stats.setSockJsTaskScheduler(messageBrokerSockJsTaskScheduler()); + stats.setSockJsTaskScheduler(messageBrokerTaskScheduler()); return stats; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/AbstractSubProtocolEvent.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/AbstractSubProtocolEvent.java index 1b52111586a..0827275df73 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/AbstractSubProtocolEvent.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/AbstractSubProtocolEvent.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.messaging; +import java.security.Principal; + import org.springframework.context.ApplicationEvent; import org.springframework.messaging.Message; import org.springframework.util.Assert; @@ -32,6 +34,8 @@ public abstract class AbstractSubProtocolEvent extends ApplicationEvent { private final Message message; + private final Principal user; + /** * Create a new AbstractSubProtocolEvent. @@ -42,6 +46,19 @@ public abstract class AbstractSubProtocolEvent extends ApplicationEvent { super(source); Assert.notNull(message, "Message must not be null"); this.message = message; + this.user = null; + } + + /** + * Create a new AbstractSubProtocolEvent. + * @param source the component that published the event (never {@code null}) + * @param message the incoming message + */ + protected AbstractSubProtocolEvent(Object source, Message message, Principal user) { + super(source); + Assert.notNull(message, "Message must not be null"); + this.message = message; + this.user = user; } @@ -60,6 +77,13 @@ public abstract class AbstractSubProtocolEvent extends ApplicationEvent { return this.message; } + /** + * Return the user for the session associated with the event. + */ + public Principal getUser() { + return this.user; + } + @Override public String toString() { return getClass().getSimpleName() + "[" + this.message + "]"; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java new file mode 100644 index 00000000000..2921979a795 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java @@ -0,0 +1,336 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.messaging; + +import java.security.Principal; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import org.springframework.context.ApplicationEvent; +import org.springframework.context.event.SmartApplicationListener; +import org.springframework.core.Ordered; +import org.springframework.messaging.Message; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.user.DestinationUserNameProvider; +import org.springframework.messaging.simp.user.SimpSession; +import org.springframework.messaging.simp.user.SimpSubscription; +import org.springframework.messaging.simp.user.SimpSubscriptionMatcher; +import org.springframework.messaging.simp.user.SimpUser; +import org.springframework.messaging.simp.user.SimpUserRegistry; +import org.springframework.messaging.support.MessageHeaderAccessor; +import org.springframework.util.Assert; + +/** + * Default, mutable, thread-safe implementation of {@link SimpUserRegistry}. + * + * @author Rossen Stoyanchev + * @since 4.2 + */ +public class DefaultSimpUserRegistry implements SimpUserRegistry, SmartApplicationListener { + + private final Map users = new ConcurrentHashMap(); + + private final Map sessions = new ConcurrentHashMap(); + + + @Override + public SimpUser getUser(String userName) { + return this.users.get(userName); + } + + @Override + public Set getUsers() { + return new HashSet(this.users.values()); + } + + public Set findSubscriptions(SimpSubscriptionMatcher matcher) { + Set result = new HashSet(); + for (DefaultSimpSession session : this.sessions.values()) { + for (SimpSubscription subscription : session.subscriptions.values()) { + if (matcher.match(subscription)) { + result.add(subscription); + } + } + } + return result; + } + + @Override + public boolean supportsEventType(Class eventType) { + return AbstractSubProtocolEvent.class.isAssignableFrom(eventType); + } + + @Override + public boolean supportsSourceType(Class sourceType) { + return true; + } + + @Override + public void onApplicationEvent(ApplicationEvent event) { + + AbstractSubProtocolEvent subProtocolEvent = (AbstractSubProtocolEvent) event; + Message message = subProtocolEvent.getMessage(); + SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class); + String sessionId = accessor.getSessionId(); + + if (event instanceof SessionSubscribeEvent) { + DefaultSimpSession session = this.sessions.get(sessionId); + if (session != null) { + String id = accessor.getSubscriptionId(); + String destination = accessor.getDestination(); + session.addSubscription(id, destination); + } + } + else if (event instanceof SessionConnectedEvent) { + Principal user = subProtocolEvent.getUser(); + if (user == null) { + return; + } + String name = user.getName(); + if (user instanceof DestinationUserNameProvider) { + name = ((DestinationUserNameProvider) user).getDestinationUserName(); + } + synchronized (this) { + DefaultSimpUser simpUser = this.users.get(name); + if (simpUser == null) { + simpUser = new DefaultSimpUser(name, sessionId); + this.users.put(name, simpUser); + } + else { + simpUser.addSession(sessionId); + } + this.sessions.put(sessionId, (DefaultSimpSession) simpUser.getSession(sessionId)); + } + } + else if (event instanceof SessionDisconnectEvent) { + synchronized (this) { + DefaultSimpSession session = this.sessions.remove(sessionId); + if (session != null) { + DefaultSimpUser user = session.getUser(); + user.removeSession(sessionId); + if (!user.hasSessions()) { + this.users.remove(user.getName()); + } + } + } + } + else if (event instanceof SessionUnsubscribeEvent) { + DefaultSimpSession session = this.sessions.get(sessionId); + if (session != null) { + String subscriptionId = accessor.getSubscriptionId(); + session.removeSubscription(subscriptionId); + } + } + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE; + } + + @Override + public String toString() { + return "users=" + this.users; + } + + private static class DefaultSimpUser implements SimpUser { + + private final String name; + + private final Map sessions = + new ConcurrentHashMap(1); + + + public DefaultSimpUser(String userName, String sessionId) { + Assert.notNull(userName); + Assert.notNull(sessionId); + this.name = userName; + this.sessions.put(sessionId, new DefaultSimpSession(sessionId, this)); + } + + @Override + public String getName() { + return this.name; + } + + @Override + public boolean hasSessions() { + return !this.sessions.isEmpty(); + } + + @Override + public SimpSession getSession(String sessionId) { + return this.sessions.get(sessionId); + } + + @Override + public Set getSessions() { + return new HashSet(this.sessions.values()); + } + + void addSession(String sessionId) { + DefaultSimpSession session = new DefaultSimpSession(sessionId, this); + this.sessions.put(sessionId, session); + } + + void removeSession(String sessionId) { + this.sessions.remove(sessionId); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || !(other instanceof SimpUser)) { + return false; + } + return this.name.equals(((SimpUser) other).getName()); + } + + @Override + public int hashCode() { + return this.name.hashCode(); + } + + @Override + public String toString() { + return "name=" + this.name + ", sessions=" + this.sessions; + } + } + + private static class DefaultSimpSession implements SimpSession { + + private final String id; + + private final DefaultSimpUser user; + + private final Map subscriptions = new ConcurrentHashMap(4); + + + public DefaultSimpSession(String id, DefaultSimpUser user) { + Assert.notNull(id); + Assert.notNull(user); + this.id = id; + this.user = user; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public DefaultSimpUser getUser() { + return this.user; + } + + @Override + public Set getSubscriptions() { + return new HashSet(this.subscriptions.values()); + } + + void addSubscription(String id, String destination) { + this.subscriptions.put(id, new DefaultSimpSubscription(id, destination, this)); + } + + void removeSubscription(String id) { + this.subscriptions.remove(id); + } + + @Override + public int hashCode() { + return this.id.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || !(other instanceof SimpSubscription)) { + return false; + } + return this.id.equals(((SimpSubscription) other).getId()); + } + + @Override + public String toString() { + return "id=" + this.id + ", subscriptions=" + this.subscriptions; + } + } + + private static class DefaultSimpSubscription implements SimpSubscription { + + private final String id; + + private final DefaultSimpSession session; + + private final String destination; + + + public DefaultSimpSubscription(String id, String destination, DefaultSimpSession session) { + Assert.notNull(id); + Assert.hasText(destination); + Assert.notNull(session); + this.id = id; + this.destination = destination; + this.session = session; + } + + @Override + public String getId() { + return this.id; + } + + @Override + public DefaultSimpSession getSession() { + return this.session; + } + + @Override + public String getDestination() { + return this.destination; + } + + @Override + public int hashCode() { + return 31 * this.id.hashCode() + getSession().hashCode(); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || !(other instanceof SimpSubscription)) { + return false; + } + SimpSubscription otherSubscription = (SimpSubscription) other; + return (getSession().getId().equals(otherSubscription.getSession().getId()) && + this.id.equals(otherSubscription.getId())); + } + + @Override + public String toString() { + return "destination=" + this.destination; + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectEvent.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectEvent.java index 8d4cdcb9018..41006feae83 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectEvent.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectEvent.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.messaging; +import java.security.Principal; + import org.springframework.messaging.Message; /** @@ -41,4 +43,8 @@ public class SessionConnectEvent extends AbstractSubProtocolEvent { super(source, message); } + public SessionConnectEvent(Object source, Message message, Principal user) { + super(source, message, user); + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectedEvent.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectedEvent.java index 1a5c9279ace..bee9d56433b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectedEvent.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectedEvent.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.messaging; +import java.security.Principal; + import org.springframework.messaging.Message; /** @@ -37,4 +39,8 @@ public class SessionConnectedEvent extends AbstractSubProtocolEvent { super(source, message); } + public SessionConnectedEvent(Object source, Message message, Principal user) { + super(source, message, user); + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionDisconnectEvent.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionDisconnectEvent.java index 733349c6e7e..a4167377c40 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionDisconnectEvent.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionDisconnectEvent.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.messaging; +import java.security.Principal; + import org.springframework.messaging.Message; import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; @@ -45,14 +47,21 @@ public class SessionDisconnectEvent extends AbstractSubProtocolEvent { * @param sessionId the disconnect message * @param closeStatus the status object */ - public SessionDisconnectEvent(Object source, Message message, String sessionId, CloseStatus closeStatus) { + public SessionDisconnectEvent(Object source, Message message, String sessionId, + CloseStatus closeStatus) { + + this(source, message, sessionId, closeStatus, null); + } + + public SessionDisconnectEvent(Object source, Message message, String sessionId, + CloseStatus closeStatus, Principal user) { + super(source, message); Assert.notNull(sessionId, "'sessionId' must not be null"); this.sessionId = sessionId; this.status = closeStatus; } - /** * Return the session id. */ diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionSubscribeEvent.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionSubscribeEvent.java index 8d18c50abbc..5a3498f7bd2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionSubscribeEvent.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionSubscribeEvent.java @@ -17,6 +17,8 @@ package org.springframework.web.socket.messaging; +import java.security.Principal; + import org.springframework.messaging.Message; /** @@ -34,4 +36,8 @@ public class SessionSubscribeEvent extends AbstractSubProtocolEvent { super(source, message); } + public SessionSubscribeEvent(Object source, Message message, Principal user) { + super(source, message, user); + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionUnsubscribeEvent.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionUnsubscribeEvent.java index 62e0540698a..cd3352716fa 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionUnsubscribeEvent.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionUnsubscribeEvent.java @@ -17,6 +17,8 @@ package org.springframework.web.socket.messaging; +import java.security.Principal; + import org.springframework.messaging.Message; /** @@ -34,4 +36,8 @@ public class SessionUnsubscribeEvent extends AbstractSubProtocolEvent { super(source, message); } + public SessionUnsubscribeEvent(Object source, Message message, Principal user) { + super(source, message, user); + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 6bdea8f1e5e..9b4ad470bd9 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -34,7 +34,6 @@ import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.simp.SimpAttributes; import org.springframework.messaging.simp.SimpAttributesContextHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -44,8 +43,6 @@ import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompDecoder; import org.springframework.messaging.simp.stomp.StompEncoder; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; -import org.springframework.messaging.simp.user.DestinationUserNameProvider; -import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.AbstractMessageChannel; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; @@ -94,8 +91,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE private int messageSizeLimit = 64 * 1024; - private UserSessionRegistry userSessionRegistry; - private final StompEncoder stompEncoder = new StompEncoder(); private final StompDecoder stompDecoder = new StompDecoder(); @@ -134,21 +129,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE return this.messageSizeLimit; } - /** - * Provide a registry with which to register active user session ids. - * @see org.springframework.messaging.simp.user.UserDestinationMessageHandler - */ - public void setUserSessionRegistry(UserSessionRegistry registry) { - this.userSessionRegistry = registry; - } - - /** - * @return the configured UserSessionRegistry. - */ - public UserSessionRegistry getUserSessionRegistry() { - return this.userSessionRegistry; - } - /** * Configure a {@link MessageHeaderInitializer} to apply to the headers of all * messages created from decoded STOMP frames and other messages sent to the @@ -234,9 +214,11 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE StompHeaderAccessor headerAccessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + Principal user = session.getPrincipal(); + headerAccessor.setSessionId(session.getId()); headerAccessor.setSessionAttributes(session.getAttributes()); - headerAccessor.setUser(session.getPrincipal()); + headerAccessor.setUser(user); headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat()); if (!detectImmutableMessageInterceptor(outputChannel)) { headerAccessor.setImmutable(); @@ -257,13 +239,13 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE SimpAttributesContextHolder.setAttributesFromMessage(message); if (this.eventPublisher != null) { if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) { - publishEvent(new SessionConnectEvent(this, message)); + publishEvent(new SessionConnectEvent(this, message, user)); } else if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) { - publishEvent(new SessionSubscribeEvent(this, message)); + publishEvent(new SessionSubscribeEvent(this, message, user)); } else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) { - publishEvent(new SessionUnsubscribeEvent(this, message)); + publishEvent(new SessionUnsubscribeEvent(this, message, user)); } } outputChannel.send(message); @@ -349,7 +331,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE try { SimpAttributes simpAttributes = new SimpAttributes(session.getId(), session.getAttributes()); SimpAttributesContextHolder.setAttributes(simpAttributes); - publishEvent(new SessionConnectedEvent(this, (Message) message)); + Principal user = session.getPrincipal(); + publishEvent(new SessionConnectedEvent(this, (Message) message, user)); } finally { SimpAttributesContextHolder.resetAttributes(); @@ -466,10 +449,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE if (principal != null) { accessor = toMutableAccessor(accessor, message); accessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); - if (this.userSessionRegistry != null) { - String userName = getSessionRegistryUserName(principal); - this.userSessionRegistry.registerSessionId(userName, session.getId()); - } } long[] heartbeat = accessor.getHeartbeat(); if (heartbeat[1] > 0) { @@ -481,14 +460,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE return accessor; } - private String getSessionRegistryUserName(Principal principal) { - String userName = principal.getName(); - if (principal instanceof DestinationUserNameProvider) { - userName = ((DestinationUserNameProvider) principal).getDestinationUserName(); - } - return userName; - } - @Override public String resolveSessionId(Message message) { return SimpMessageHeaderAccessor.getSessionId(message.getHeaders()); @@ -505,17 +476,13 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @Override public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) { this.decoders.remove(session.getId()); - Principal principal = session.getPrincipal(); - if (principal != null && this.userSessionRegistry != null) { - String userName = getSessionRegistryUserName(principal); - this.userSessionRegistry.unregisterSessionId(userName, session.getId()); - } Message message = createDisconnectMessage(session); SimpAttributes simpAttributes = SimpAttributes.fromMessage(message); try { SimpAttributesContextHolder.setAttributes(simpAttributes); if (this.eventPublisher != null) { - publishEvent(new SessionDisconnectEvent(this, message, session.getId(), closeStatus)); + Principal user = session.getPrincipal(); + publishEvent(new SessionDisconnectEvent(this, message, session.getId(), closeStatus, user)); } outputChannel.send(message); } diff --git a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.2.xsd b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.2.xsd index 749f1637a1f..90a535c5b68 100644 --- a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.2.xsd +++ b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.2.xsd @@ -344,6 +344,27 @@ ]]> + + + + + + + + + + @@ -853,17 +874,6 @@ The prefix used to identify user destinations. Any destinations that do not start with the given prefix are not be resolved. The default value is "/user/". - ]]> - - - - - diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 1be3a4e9ebf..862e0ade48a 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -16,14 +16,20 @@ package org.springframework.web.socket.config; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; + import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.List; +import java.util.Map; import org.hamcrest.Matchers; import org.junit.Before; import org.junit.Test; + import org.springframework.beans.DirectFieldAccessor; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.config.CustomScopeConfigurer; @@ -46,9 +52,11 @@ import org.springframework.messaging.simp.annotation.support.SimpAnnotationMetho import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler; import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; import org.springframework.messaging.simp.user.DefaultUserDestinationResolver; +import org.springframework.messaging.simp.user.MultiServerUserRegistry; +import org.springframework.messaging.simp.user.SimpUserRegistry; import org.springframework.messaging.simp.user.UserDestinationMessageHandler; import org.springframework.messaging.simp.user.UserDestinationResolver; -import org.springframework.messaging.simp.user.UserSessionRegistry; +import org.springframework.messaging.simp.user.UserRegistryMessageHandler; import org.springframework.messaging.support.AbstractSubscribableChannel; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; @@ -64,7 +72,9 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.TestWebSocketSession; import org.springframework.web.socket.handler.WebSocketHandlerDecorator; import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory; +import org.springframework.web.socket.messaging.DefaultSimpUserRegistry; import org.springframework.web.socket.messaging.StompSubProtocolHandler; +import org.springframework.web.socket.messaging.SubProtocolHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeInterceptor; @@ -75,9 +85,6 @@ import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; -import static org.hamcrest.Matchers.*; -import static org.junit.Assert.*; - /** * Test fixture for MessageBrokerBeanDefinitionParser. * See test configuration files websocket-config-broker-*.xml. @@ -133,7 +140,8 @@ public class MessageBrokerBeanDefinitionParserTests { assertEquals(25 * 1000, subProtocolWsHandler.getSendTimeLimit()); assertEquals(1024 * 1024, subProtocolWsHandler.getSendBufferSizeLimit()); - StompSubProtocolHandler stompHandler = (StompSubProtocolHandler) subProtocolWsHandler.getProtocolHandlerMap().get("v12.stomp"); + Map handlerMap = subProtocolWsHandler.getProtocolHandlerMap(); + StompSubProtocolHandler stompHandler = (StompSubProtocolHandler) handlerMap.get("v12.stomp"); assertNotNull(stompHandler); assertEquals(128 * 1024, stompHandler.getMessageSizeLimit()); @@ -166,15 +174,15 @@ public class MessageBrokerBeanDefinitionParserTests { instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class))); assertEquals(Arrays.asList("http://mydomain3.com", "http://mydomain4.com"), defaultSockJsService.getAllowedOrigins()); - UserSessionRegistry userSessionRegistry = this.appContext.getBean(UserSessionRegistry.class); - assertNotNull(userSessionRegistry); + SimpUserRegistry userRegistry = this.appContext.getBean(SimpUserRegistry.class); + assertNotNull(userRegistry); + assertEquals(DefaultSimpUserRegistry.class, userRegistry.getClass()); UserDestinationResolver userDestResolver = this.appContext.getBean(UserDestinationResolver.class); assertNotNull(userDestResolver); assertThat(userDestResolver, Matchers.instanceOf(DefaultUserDestinationResolver.class)); DefaultUserDestinationResolver defaultUserDestResolver = (DefaultUserDestinationResolver) userDestResolver; assertEquals("/personal/", defaultUserDestResolver.getDestinationPrefix()); - assertSame(stompHandler.getUserSessionRegistry(), defaultUserDestResolver.getUserSessionRegistry()); UserDestinationMessageHandler userDestHandler = this.appContext.getBean(UserDestinationMessageHandler.class); assertNotNull(userDestHandler); @@ -192,11 +200,12 @@ public class MessageBrokerBeanDefinitionParserTests { testChannel("clientInboundChannel", subscriberTypes, 2); testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); - subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); + subscriberTypes = Collections.singletonList(SubProtocolWebSocketHandler.class); testChannel("clientOutboundChannel", subscriberTypes, 1); testExecutor("clientOutboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); - subscriberTypes = Arrays.>asList(SimpleBrokerMessageHandler.class, UserDestinationMessageHandler.class); + subscriberTypes = Arrays.>asList( + SimpleBrokerMessageHandler.class, UserDestinationMessageHandler.class); testChannel("brokerChannel", subscriberTypes, 1); try { this.appContext.getBean("brokerChannelExecutor", ThreadPoolTaskExecutor.class); @@ -260,7 +269,7 @@ public class MessageBrokerBeanDefinitionParserTests { testChannel("clientInboundChannel", subscriberTypes, 2); testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); - subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); + subscriberTypes = Collections.singletonList(SubProtocolWebSocketHandler.class); testChannel("clientOutboundChannel", subscriberTypes, 1); testExecutor("clientOutboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); @@ -275,11 +284,20 @@ public class MessageBrokerBeanDefinitionParserTests { // expected } + String destination = "/topic/unresolved-user-destination"; UserDestinationMessageHandler userDestHandler = this.appContext.getBean(UserDestinationMessageHandler.class); - assertEquals("/topic/unresolved", userDestHandler.getUserDestinationBroadcast()); + assertEquals(destination, userDestHandler.getBroadcastDestination()); assertNotNull(messageBroker.getSystemSubscriptions()); - assertSame(userDestHandler, messageBroker.getSystemSubscriptions().get("/topic/unresolved")); + assertSame(userDestHandler, messageBroker.getSystemSubscriptions().get(destination)); + destination = "/topic/simp-user-registry"; + UserRegistryMessageHandler userRegistryHandler = this.appContext.getBean(UserRegistryMessageHandler.class); + assertEquals(destination, userRegistryHandler.getBroadcastDestination()); + assertNotNull(messageBroker.getSystemSubscriptions()); + assertSame(userRegistryHandler, messageBroker.getSystemSubscriptions().get(destination)); + + SimpUserRegistry userRegistry = this.appContext.getBean(SimpUserRegistry.class); + assertEquals(MultiServerUserRegistry.class, userRegistry.getClass()); String name = "webSocketMessageBrokerStats"; WebSocketMessageBrokerStats stats = this.appContext.getBean(name, WebSocketMessageBrokerStats.class); @@ -339,7 +357,7 @@ public class MessageBrokerBeanDefinitionParserTests { testChannel("clientInboundChannel", subscriberTypes, 3); testExecutor("clientInboundChannel", 100, 200, 600); - subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); + subscriberTypes = Collections.singletonList(SubProtocolWebSocketHandler.class); testChannel("clientOutboundChannel", subscriberTypes, 3); testExecutor("clientOutboundChannel", 101, 201, 601); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistryTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistryTests.java index 604303df736..8651638405e 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistryTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistryTests.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.config.annotation; +import static org.junit.Assert.*; + import java.util.Map; import org.junit.Before; @@ -23,17 +25,12 @@ import org.junit.Test; import org.mockito.Mockito; import org.springframework.messaging.SubscribableChannel; -import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; -import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.scheduling.TaskScheduler; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; -import org.springframework.web.socket.messaging.StompSubProtocolHandler; import org.springframework.web.socket.messaging.SubProtocolHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; import org.springframework.web.util.UrlPathHelper; -import static org.junit.Assert.*; - /** * Test fixture for * {@link org.springframework.web.socket.config.annotation.WebMvcStompEndpointRegistry}. @@ -46,17 +43,16 @@ public class WebMvcStompEndpointRegistryTests { private SubProtocolWebSocketHandler webSocketHandler; - private UserSessionRegistry userSessionRegistry; - @Before public void setup() { SubscribableChannel inChannel = Mockito.mock(SubscribableChannel.class); SubscribableChannel outChannel = Mockito.mock(SubscribableChannel.class); this.webSocketHandler = new SubProtocolWebSocketHandler(inChannel, outChannel); - this.userSessionRegistry = new DefaultUserSessionRegistry(); - this.endpointRegistry = new WebMvcStompEndpointRegistry(this.webSocketHandler, - new WebSocketTransportRegistration(), this.userSessionRegistry, Mockito.mock(TaskScheduler.class)); + + WebSocketTransportRegistration transport = new WebSocketTransportRegistration(); + TaskScheduler scheduler = Mockito.mock(TaskScheduler.class); + this.endpointRegistry = new WebMvcStompEndpointRegistry(this.webSocketHandler, transport, scheduler); } @@ -69,9 +65,6 @@ public class WebMvcStompEndpointRegistryTests { assertNotNull(protocolHandlers.get("v10.stomp")); assertNotNull(protocolHandlers.get("v11.stomp")); assertNotNull(protocolHandlers.get("v12.stomp")); - - StompSubProtocolHandler stompHandler = (StompSubProtocolHandler) protocolHandlers.get("v10.stomp"); - assertSame(this.userSessionRegistry, stompHandler.getUserSessionRegistry()); } @Test diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java index ac219f4fb31..5a8da5016b2 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java @@ -136,19 +136,16 @@ public class WebSocketMessageBrokerConfigurationSupportTests { } @Test - public void webSocketTransportOptions() { + public void webSocketHandler() { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); - SubProtocolWebSocketHandler subProtocolWebSocketHandler = - config.getBean("subProtocolWebSocketHandler", SubProtocolWebSocketHandler.class); + SubProtocolWebSocketHandler subWsHandler = config.getBean(SubProtocolWebSocketHandler.class); - assertEquals(1024 * 1024, subProtocolWebSocketHandler.getSendBufferSizeLimit()); - assertEquals(25 * 1000, subProtocolWebSocketHandler.getSendTimeLimit()); + assertEquals(1024 * 1024, subWsHandler.getSendBufferSizeLimit()); + assertEquals(25 * 1000, subWsHandler.getSendTimeLimit()); - List protocolHandlers = subProtocolWebSocketHandler.getProtocolHandlers(); - for(SubProtocolHandler protocolHandler : protocolHandlers) { - assertTrue(protocolHandler instanceof StompSubProtocolHandler); - assertEquals(128 * 1024, ((StompSubProtocolHandler) protocolHandler).getMessageSizeLimit()); - } + Map handlerMap = subWsHandler.getProtocolHandlerMap(); + StompSubProtocolHandler protocolHandler = (StompSubProtocolHandler) handlerMap.get("v12.stomp"); + assertEquals(128 * 1024, protocolHandler.getMessageSizeLimit()); } @Test diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java new file mode 100644 index 00000000000..a176e40a026 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java @@ -0,0 +1,199 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.messaging; + +import static org.junit.Assert.*; + +import java.security.Principal; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; + +import org.junit.Test; + +import org.springframework.messaging.Message; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.simp.user.SimpSubscription; +import org.springframework.messaging.simp.user.SimpSubscriptionMatcher; +import org.springframework.messaging.simp.user.SimpUser; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.web.socket.CloseStatus; + +/** + * Test fixture for + * {@link DefaultSimpUserRegistry} + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class DefaultSimpUserRegistryTests { + + @Test + public void addOneSessionId() { + + TestPrincipal user = new TestPrincipal("joe"); + Message message = createMessage(SimpMessageType.CONNECT_ACK, "123"); + SessionConnectedEvent event = new SessionConnectedEvent(this, message, user); + + DefaultSimpUserRegistry registry = new DefaultSimpUserRegistry(); + registry.onApplicationEvent(event); + + SimpUser simpUser = registry.getUser("joe"); + assertNotNull(simpUser); + + assertEquals(1, simpUser.getSessions().size()); + assertNotNull(simpUser.getSession("123")); + } + + @Test + public void addMultipleSessionIds() { + + DefaultSimpUserRegistry registry = new DefaultSimpUserRegistry(); + + TestPrincipal user = new TestPrincipal("joe"); + Message message = createMessage(SimpMessageType.CONNECT_ACK, "123"); + SessionConnectedEvent event = new SessionConnectedEvent(this, message, user); + registry.onApplicationEvent(event); + + message = createMessage(SimpMessageType.CONNECT_ACK, "456"); + event = new SessionConnectedEvent(this, message, user); + registry.onApplicationEvent(event); + + message = createMessage(SimpMessageType.CONNECT_ACK, "789"); + event = new SessionConnectedEvent(this, message, user); + registry.onApplicationEvent(event); + + SimpUser simpUser = registry.getUser("joe"); + assertNotNull(simpUser); + + assertEquals(3, simpUser.getSessions().size()); + assertNotNull(simpUser.getSession("123")); + assertNotNull(simpUser.getSession("456")); + assertNotNull(simpUser.getSession("789")); + } + + @Test + public void removeSessionIds() { + + DefaultSimpUserRegistry registry = new DefaultSimpUserRegistry(); + + TestPrincipal user = new TestPrincipal("joe"); + Message message = createMessage(SimpMessageType.CONNECT_ACK, "123"); + SessionConnectedEvent connectedEvent = new SessionConnectedEvent(this, message, user); + registry.onApplicationEvent(connectedEvent); + + message = createMessage(SimpMessageType.CONNECT_ACK, "456"); + connectedEvent = new SessionConnectedEvent(this, message, user); + registry.onApplicationEvent(connectedEvent); + + message = createMessage(SimpMessageType.CONNECT_ACK, "789"); + connectedEvent = new SessionConnectedEvent(this, message, user); + registry.onApplicationEvent(connectedEvent); + + SimpUser simpUser = registry.getUser("joe"); + assertNotNull(simpUser); + assertEquals(3, simpUser.getSessions().size()); + + + CloseStatus status = CloseStatus.GOING_AWAY; + message = createMessage(SimpMessageType.DISCONNECT, "456"); + SessionDisconnectEvent disconnectEvent = new SessionDisconnectEvent(this, message, "456", status, user); + registry.onApplicationEvent(disconnectEvent); + + message = createMessage(SimpMessageType.DISCONNECT, "789"); + disconnectEvent = new SessionDisconnectEvent(this, message, "789", status, user); + registry.onApplicationEvent(disconnectEvent); + + assertEquals(1, simpUser.getSessions().size()); + assertNotNull(simpUser.getSession("123")); + } + + @Test + public void findSubscriptions() throws Exception { + + DefaultSimpUserRegistry registry = new DefaultSimpUserRegistry(); + + TestPrincipal user = new TestPrincipal("joe"); + Message message = createMessage(SimpMessageType.CONNECT_ACK, "123"); + SessionConnectedEvent event = new SessionConnectedEvent(this, message, user); + registry.onApplicationEvent(event); + + message = createMessage(SimpMessageType.SUBSCRIBE, "123", "sub1", "/match"); + SessionSubscribeEvent subscribeEvent = new SessionSubscribeEvent(this, message, user); + registry.onApplicationEvent(subscribeEvent); + + message = createMessage(SimpMessageType.SUBSCRIBE, "123", "sub2", "/match"); + subscribeEvent = new SessionSubscribeEvent(this, message, user); + registry.onApplicationEvent(subscribeEvent); + + message = createMessage(SimpMessageType.SUBSCRIBE, "123", "sub3", "/not-a-match"); + subscribeEvent = new SessionSubscribeEvent(this, message, user); + registry.onApplicationEvent(subscribeEvent); + + Set matches = registry.findSubscriptions(new SimpSubscriptionMatcher() { + @Override + public boolean match(SimpSubscription subscription) { + return subscription.getDestination().equals("/match"); + } + }); + + assertEquals(2, matches.size()); + + Iterator iterator = matches.iterator(); + Set sessionIds = new HashSet<>(2); + sessionIds.add(iterator.next().getId()); + sessionIds.add(iterator.next().getId()); + assertEquals(new HashSet<>(Arrays.asList("sub1", "sub2")), sessionIds); + } + + private Message createMessage(SimpMessageType type, String sessionId) { + return createMessage(type, sessionId, null, null); + } + + private Message createMessage(SimpMessageType type, String sessionId, String subscriptionId, + String destination) { + + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(type); + accessor.setSessionId(sessionId); + if (destination != null) { + accessor.setDestination(destination); + } + if (subscriptionId != null) { + accessor.setSubscriptionId(subscriptionId); + } + return MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders()); + } + + + private static class TestPrincipal implements Principal { + + private String name; + + public TestPrincipal(String name) { + this.name = name; + } + + @Override + public String getName() { + return this.name; + } + + } + +} \ No newline at end of file diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 052e4c8d3b6..002e6e16f38 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -47,9 +47,7 @@ import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompEncoder; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; -import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; import org.springframework.messaging.simp.user.DestinationUserNameProvider; -import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.ChannelInterceptorAdapter; import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; @@ -96,9 +94,6 @@ public class StompSubProtocolHandlerTests { @Test public void handleMessageToClientWithConnectedFrame() { - UserSessionRegistry registry = new DefaultUserSessionRegistry(); - this.protocolHandler.setUserSessionRegistry(registry); - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); this.protocolHandler.handleMessageToClient(this.session, message); @@ -106,8 +101,6 @@ public class StompSubProtocolHandlerTests { assertEquals(1, this.session.getSentMessages().size()); WebSocketMessage textMessage = this.session.getSentMessages().get(0); assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload()); - - assertEquals(Collections.singleton("s1"), registry.getSessionIds("joe")); } @Test @@ -115,9 +108,6 @@ public class StompSubProtocolHandlerTests { this.session.setPrincipal(new UniqueUser("joe")); - UserSessionRegistry registry = new DefaultUserSessionRegistry(); - this.protocolHandler.setUserSessionRegistry(registry); - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); this.protocolHandler.handleMessageToClient(this.session, message); @@ -125,9 +115,6 @@ public class StompSubProtocolHandlerTests { assertEquals(1, this.session.getSentMessages().size()); WebSocketMessage textMessage = this.session.getSentMessages().get(0); assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload()); - - assertEquals(Collections.emptySet(), registry.getSessionIds("joe")); - assertEquals(Collections.singleton("s1"), registry.getSessionIds("Me myself and I")); } @Test @@ -348,8 +335,6 @@ public class StompSubProtocolHandlerTests { TestPublisher publisher = new TestPublisher(); - UserSessionRegistry registry = new DefaultUserSessionRegistry(); - this.protocolHandler.setUserSessionRegistry(registry); this.protocolHandler.setApplicationEventPublisher(publisher); this.protocolHandler.afterSessionStarted(this.session, this.channel); @@ -387,8 +372,6 @@ public class StompSubProtocolHandlerTests { ApplicationEventPublisher publisher = mock(ApplicationEventPublisher.class); - UserSessionRegistry registry = new DefaultUserSessionRegistry(); - this.protocolHandler.setUserSessionRegistry(registry); this.protocolHandler.setApplicationEventPublisher(publisher); this.protocolHandler.afterSessionStarted(this.session, this.channel); diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-relay.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-relay.xml index 5462d17935d..1fdc4171a8e 100644 --- a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-relay.xml +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-relay.xml @@ -4,7 +4,7 @@ xsi:schemaLocation="http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd http://www.springframework.org/schema/websocket http://www.springframework.org/schema/websocket/spring-websocket.xsd"> - + @@ -12,7 +12,9 @@ client-login="clientlogin" client-passcode="clientpass" system-login="syslogin" system-passcode="syspass" heartbeat-send-interval="5000" heartbeat-receive-interval="5000" - virtual-host="spring.io"/> + virtual-host="spring.io" + user-destination-broadcast="/topic/unresolved-user-destination" + user-registry-broadcast="/topic/simp-user-registry"/>