Add SimpUserRegistry with multi-server support

This change introduces SimpUserRegistry exposing an API to access
information about connected users, their sessions, and subscriptions
with STOMP/WebSocket messaging. Provides are methods to access users
as well as a method to find subscriptions given a Matcher strategy.

The DefaultSimpUserRegistry implementation is also a
SmartApplicationListener which listesn for ApplicationContext events
when users connect, disconnect, subscribe, and unsubscribe to
destinations.

The MultiServerUserRegistry implementation is a composite that
aggregates user information from the local SimpUserRegistry as well
as snapshots of user  on remote application servers.

UserRegistryMessageHandler is used with MultiServerUserRegistry. It
broadcats user registry information through the broker and listens
for similar broadcasts from other servers. This must be enabled
explicitly when configuring the STOMP broker relay.

The existing UserSessionRegistry which was primiarly used internally
to resolve a user name to session id's has been deprecated and is no
longer used. If an application configures a custom UserSessionRegistr
still, it will be adapted accordingly to SimpUserRegistry but the
effect is rather limited (comparable to pre-existing functionality)
and will not work in multi-server scenarios.

Issue: SPR-12029
This commit is contained in:
Rossen Stoyanchev 2015-05-06 18:31:26 -04:00
parent 52153bd454
commit 281588d7bb
46 changed files with 2627 additions and 484 deletions

View File

@ -17,7 +17,6 @@
package org.springframework.messaging.simp.config;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -43,14 +42,16 @@ import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler;
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import org.springframework.messaging.simp.user.DefaultUserDestinationResolver;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.MultiServerUserRegistry;
import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserDestinationResolver;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.simp.user.UserRegistryMessageHandler;
import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.ClassUtils;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.PathMatcher;
@ -88,14 +89,14 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
"com.fasterxml.jackson.databind.ObjectMapper", AbstractMessageBrokerConfiguration.class.getClassLoader());
private ApplicationContext applicationContext;
private ChannelRegistration clientInboundChannelRegistration;
private ChannelRegistration clientOutboundChannelRegistration;
private MessageBrokerRegistry brokerRegistry;
private ApplicationContext applicationContext;
/**
* Protected constructor.
@ -287,12 +288,16 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
if (handler == null) {
return new NoOpBrokerMessageHandler();
}
Map<String, MessageHandler> subscriptions = new HashMap<String, MessageHandler>(1);
String destination = getBrokerRegistry().getUserDestinationBroadcast();
if (destination != null) {
Map<String, MessageHandler> map = new HashMap<String, MessageHandler>(1);
map.put(destination, userDestinationMessageHandler());
handler.setSystemSubscriptions(map);
subscriptions.put(destination, userDestinationMessageHandler());
}
destination = getBrokerRegistry().getUserRegistryBroadcast();
if (destination != null) {
subscriptions.put(destination, userRegistryMessageHandler());
}
handler.setSystemSubscriptions(subscriptions);
return handler;
}
@ -301,10 +306,30 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
UserDestinationMessageHandler handler = new UserDestinationMessageHandler(clientInboundChannel(),
brokerChannel(), userDestinationResolver());
String destination = getBrokerRegistry().getUserDestinationBroadcast();
handler.setUserDestinationBroadcast(destination);
handler.setBroadcastDestination(destination);
return handler;
}
@Bean
public MessageHandler userRegistryMessageHandler() {
if (getBrokerRegistry().getUserRegistryBroadcast() == null) {
return new NoOpMessageHandler();
}
return new UserRegistryMessageHandler(userRegistry(), brokerMessagingTemplate(),
getBrokerRegistry().getUserRegistryBroadcast(), messageBrokerTaskScheduler());
}
// Expose alias for 4.1 compatibility
@Bean(name={"messageBrokerTaskScheduler", "messageBrokerSockJsTaskScheduler"})
public ThreadPoolTaskScheduler messageBrokerTaskScheduler() {
ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
scheduler.setThreadNamePrefix("MessageBroker-");
scheduler.setPoolSize(Runtime.getRuntime().availableProcessors());
scheduler.setRemoveOnCancelPolicy(true);
return scheduler;
}
@Bean
public SimpMessagingTemplate brokerMessagingTemplate() {
SimpMessagingTemplate template = new SimpMessagingTemplate(brokerChannel());
@ -350,7 +375,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
@Bean
public UserDestinationResolver userDestinationResolver() {
DefaultUserDestinationResolver resolver = new DefaultUserDestinationResolver(userSessionRegistry());
DefaultUserDestinationResolver resolver = new DefaultUserDestinationResolver(userRegistry());
String prefix = getBrokerRegistry().getUserDestinationPrefix();
if (prefix != null) {
resolver.setUserDestinationPrefix(prefix);
@ -359,8 +384,24 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
}
@Bean
public UserSessionRegistry userSessionRegistry() {
return new DefaultUserSessionRegistry();
@SuppressWarnings("deprecation")
public SimpUserRegistry userRegistry() {
return (getBrokerRegistry().getUserRegistryBroadcast() != null ?
new MultiServerUserRegistry(createLocalUserRegistry()) : createLocalUserRegistry());
}
protected abstract SimpUserRegistry createLocalUserRegistry();
/**
* As of 4.2, UserSessionRegistry is deprecated in favor of SimpUserRegistry
* exposing information about all connected users. The MultiServerUserRegistry
* implementation in combination with UserRegistryMessageHandler can be used
* to share user registries across multiple servers.
*/
@Deprecated
@SuppressWarnings("deprecation")
protected org.springframework.messaging.simp.user.UserSessionRegistry userSessionRegistry() {
return null;
}
/**
@ -417,6 +458,14 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
}
private static class NoOpMessageHandler implements MessageHandler {
@Override
public void handleMessage(Message<?> message) {
}
}
private class NoOpBrokerMessageHandler extends AbstractBrokerMessageHandler {
public NoOpBrokerMessageHandler() {

View File

@ -23,6 +23,7 @@ import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.util.Assert;
import org.springframework.util.PathMatcher;
@ -49,8 +50,6 @@ public class MessageBrokerRegistry {
private String userDestinationPrefix;
private String userDestinationBroadcast;
private PathMatcher pathMatcher;
@ -139,22 +138,14 @@ public class MessageBrokerRegistry {
return this.userDestinationPrefix;
}
/**
* Set a destination to broadcast messages to that remain unresolved because
* the user is not connected. In a multi-application server scenario this
* gives other application servers a chance to try.
* <p><strong>Note:</strong> this option applies only when the
* {@link #enableStompBrokerRelay "broker relay"} is enabled.
* <p>By default this is not set.
* @param destination the destination to forward unresolved
* messages to, e.g. "/topic/unresolved-user-destination".
*/
public void setUserDestinationBroadcast(String destination) {
this.userDestinationBroadcast = destination;
protected String getUserDestinationBroadcast() {
return (this.brokerRelayRegistration != null ?
this.brokerRelayRegistration.getUserDestinationBroadcast() : null);
}
protected String getUserDestinationBroadcast() {
return this.userDestinationBroadcast;
protected String getUserRegistryBroadcast() {
return (this.brokerRelayRegistration != null ?
this.brokerRelayRegistration.getUserRegistryBroadcast() : null);
}
/**

View File

@ -49,6 +49,10 @@ public class StompBrokerRelayRegistration extends AbstractBrokerRegistration {
private boolean autoStartup = true;
private String userDestinationBroadcast;
private String userRegistryBroadcast;
public StompBrokerRelayRegistration(SubscribableChannel clientInboundChannel,
MessageChannel clientOutboundChannel, String[] destinationPrefixes) {
@ -166,10 +170,48 @@ public class StompBrokerRelayRegistration extends AbstractBrokerRegistration {
return this;
}
/**
* Set a destination to broadcast messages to user destinations that remain
* unresolved because the user appears not to be connected. In a
* multi-application server scenario this gives other application servers
* a chance to try.
* <p>By default this is not set.
* @param destination the destination to broadcast unresolved messages to,
* e.g. "/topic/unresolved-user-destination"
*/
public StompBrokerRelayRegistration setUserDestinationBroadcast(String destination) {
this.userDestinationBroadcast = destination;
return this;
}
protected String getUserDestinationBroadcast() {
return this.userDestinationBroadcast;
}
/**
* Set a destination to broadcast the content of the local user registry to
* and to listen for such broadcasts from other servers. In a multi-application
* server scenarios this allows each server's user registry to be aware of
* users connected to other servers.
* <p>By default this is not set.
* @param destination the destination for broadcasting user registry details,
* e.g. "/topic/simp-user-registry".
*/
public StompBrokerRelayRegistration setUserRegistryBroadcast(String destination) {
this.userRegistryBroadcast = destination;
return this;
}
protected String getUserRegistryBroadcast() {
return this.userRegistryBroadcast;
}
protected StompBrokerRelayMessageHandler getMessageHandler(SubscribableChannel brokerChannel) {
StompBrokerRelayMessageHandler handler = new StompBrokerRelayMessageHandler(getClientInboundChannel(),
getClientOutboundChannel(), brokerChannel, getDestinationPrefixes());
StompBrokerRelayMessageHandler handler = new StompBrokerRelayMessageHandler(
getClientInboundChannel(), getClientOutboundChannel(),
brokerChannel, getDestinationPrefixes());
handler.setRelayHost(this.relayHost);
handler.setRelayPort(this.relayPort);

View File

@ -33,8 +33,7 @@ import org.springframework.util.StringUtils;
/**
* A default implementation of {@code UserDestinationResolver} that relies
* on a {@link org.springframework.messaging.simp.user.UserSessionRegistry} to
* find active sessions for a user.
* on a {@link SimpUserRegistry} to find active sessions for a user.
*
* <p>When a user attempts to subscribe, e.g. to "/user/queue/position-updates",
* the "/user" prefix is removed and a unique suffix added based on the session
@ -54,7 +53,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
private static final Log logger = LogFactory.getLog(DefaultUserDestinationResolver.class);
private final UserSessionRegistry sessionRegistry;
private final SimpUserRegistry userRegistry;
private String prefix = "/user/";
@ -62,19 +61,19 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
/**
* Create an instance that will access user session id information through
* the provided registry.
* @param sessionRegistry the registry, never {@code null}
* @param userRegistry the registry, never {@code null}
*/
public DefaultUserDestinationResolver(UserSessionRegistry sessionRegistry) {
Assert.notNull(sessionRegistry, "'sessionRegistry' must not be null");
this.sessionRegistry = sessionRegistry;
public DefaultUserDestinationResolver(SimpUserRegistry userRegistry) {
Assert.notNull(userRegistry, "'userRegistry' must not be null");
this.userRegistry = userRegistry;
}
/**
* Return the configured {@link UserSessionRegistry}.
* Return the configured {@link SimpUserRegistry}.
*/
public UserSessionRegistry getUserSessionRegistry() {
return this.sessionRegistry;
public SimpUserRegistry getSimpUserRegistry() {
return this.userRegistry;
}
/**
@ -141,20 +140,32 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
Assert.isTrue(userEnd > 0, "Expected destination pattern \"/user/{userId}/**\"");
String actualDestination = destination.substring(userEnd);
String subscribeDestination = this.prefix.substring(0, prefixEnd - 1) + actualDestination;
String user = destination.substring(prefixEnd, userEnd);
user = StringUtils.replace(user, "%2F", "/");
String userName = destination.substring(prefixEnd, userEnd);
userName = StringUtils.replace(userName, "%2F", "/");
Set<String> sessionIds;
if (user.equals(sessionId)) {
user = null;
sessionIds = Collections.singleton(sessionId);
}
else if (this.sessionRegistry.getSessionIds(user).contains(sessionId)) {
if (userName.equals(sessionId)) {
userName = null;
sessionIds = Collections.singleton(sessionId);
}
else {
sessionIds = this.sessionRegistry.getSessionIds(user);
SimpUser user = this.userRegistry.getUser(userName);
if (user != null) {
if (user.getSession(sessionId) != null) {
sessionIds = Collections.singleton(sessionId);
}
else {
Set<SimpSession> sessions = user.getSessions();
sessionIds = new HashSet<String>(sessions.size());
for (SimpSession session : sessions) {
sessionIds.add(session.getId());
}
}
}
else {
sessionIds = Collections.<String>emptySet();
}
}
return new ParseResult(actualDestination, subscribeDestination, sessionIds, user);
return new ParseResult(actualDestination, subscribeDestination, sessionIds, userName);
}
else {
return null;
@ -174,6 +185,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
* @param user the target user, possibly {@code null}, e.g if not authenticated.
* @return a target destination, or {@code null} if none
*/
@SuppressWarnings("unused")
protected String getTargetDestination(String sourceDestination, String actualDestination,
String sessionId, String user) {

View File

@ -29,7 +29,11 @@ import org.springframework.util.Assert;
*
* @author Rossen Stoyanchev
* @since 4.0
* @deprecated as of 4.2 this class is no longer used, see deprecation notes
* on {@link UserSessionRegistry} for more details.
*/
@Deprecated
@SuppressWarnings({"deprecation", "unused"})
public class DefaultUserSessionRegistry implements UserSessionRegistry {
// userId -> sessionId
@ -72,4 +76,4 @@ public class DefaultUserSessionRegistry implements UserSessionRegistry {
}
}
}
}

View File

@ -0,0 +1,488 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.event.SmartApplicationListener;
import org.springframework.core.Ordered;
import org.springframework.messaging.Message;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
/**
* A user registry that is a composite of the "local" user registry as well as
* snapshots of remote user registries. For use with
* {@link UserRegistryMessageHandler} which broadcasts periodically the content
* of the local registry and receives updates from other servers.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
@SuppressWarnings("serial")
public class MultiServerUserRegistry implements SimpUserRegistry, SmartApplicationListener {
private final String id;
private final SimpUserRegistry localRegistry;
private final SmartApplicationListener listener;
private final Map<String, UserRegistryDto> remoteRegistries =
new ConcurrentHashMap<String, UserRegistryDto>();
/**
* Create an instance wrapping the local user registry.
*/
public MultiServerUserRegistry(SimpUserRegistry localRegistry) {
Assert.notNull(localRegistry, "'localRegistry' is required.");
this.localRegistry = localRegistry;
this.listener = (this.localRegistry instanceof SmartApplicationListener ?
(SmartApplicationListener) this.localRegistry : new NoOpSmartApplicationListener());
this.id = generateId();
}
private static String generateId() {
String host;
try {
host = InetAddress.getLocalHost().getHostAddress();
}
catch (UnknownHostException e) {
host = "unknown";
}
return host + "-" + UUID.randomUUID();
}
@Override
public SimpUser getUser(String userName) {
SimpUser user = this.localRegistry.getUser(userName);
if (user != null) {
return user;
}
for (UserRegistryDto registry : this.remoteRegistries.values()) {
user = registry.getUsers().get(userName);
if (user != null) {
return user;
}
}
return null;
}
@Override
public Set<SimpUser> getUsers() {
Set<SimpUser> result = new HashSet<SimpUser>(this.localRegistry.getUsers());
for (UserRegistryDto registry : this.remoteRegistries.values()) {
result.addAll(registry.getUsers().values());
}
return result;
}
@Override
public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) {
Set<SimpSubscription> result = new HashSet<SimpSubscription>(this.localRegistry.findSubscriptions(matcher));
for (UserRegistryDto registry : this.remoteRegistries.values()) {
result.addAll(registry.findSubscriptions(matcher));
}
return result;
}
@Override
public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) {
return this.listener.supportsEventType(eventType);
}
@Override
public boolean supportsSourceType(Class<?> sourceType) {
return this.listener.supportsSourceType(sourceType);
}
@Override
public void onApplicationEvent(ApplicationEvent event) {
this.listener.onApplicationEvent(event);
}
@Override
public int getOrder() {
return this.listener.getOrder();
}
Object getLocalRegistryDto() {
return new UserRegistryDto(this.id, this.localRegistry);
}
void addRemoteRegistryDto(Message<?> message, MessageConverter converter, long expirationPeriod) {
UserRegistryDto registryDto = (UserRegistryDto) converter.fromMessage(message, UserRegistryDto.class);
if (registryDto != null && !registryDto.getId().equals(this.id)) {
long expirationTime = System.currentTimeMillis() + expirationPeriod;
registryDto.setExpirationTime(expirationTime);
registryDto.restoreParentReferences();
this.remoteRegistries.put(registryDto.getId(), registryDto);
}
}
void purgeExpiredRegistries() {
long now = System.currentTimeMillis();
Iterator<Map.Entry<String, UserRegistryDto>> iterator = this.remoteRegistries.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<String, UserRegistryDto> entry = iterator.next();
if (now > entry.getValue().getExpirationTime()) {
iterator.remove();
}
}
}
@Override
public String toString() {
return "local=[" + this.localRegistry + "], remote=" + this.remoteRegistries + "]";
}
@SuppressWarnings("unused")
private static class UserRegistryDto {
private String id;
private Map<String, SimpUserDto> users;
private long expirationTime;
public UserRegistryDto() {
}
public UserRegistryDto(String id, SimpUserRegistry registry) {
this.id = id;
Set<SimpUser> users = registry.getUsers();
this.users = new HashMap<String, SimpUserDto>(users.size());
for (SimpUser user : users) {
this.users.put(user.getName(), new SimpUserDto(user));
}
}
public void setId(String id) {
this.id = id;
}
public String getId() {
return this.id;
}
public void setUsers(Map<String, SimpUserDto> users) {
this.users = users;
}
public Map<String, SimpUserDto> getUsers() {
return this.users;
}
public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) {
Set<SimpSubscription> result = new HashSet<SimpSubscription>();
for (SimpUserDto user : this.users.values()) {
for (SimpSessionDto session : user.sessions) {
for (SimpSubscription subscription : session.subscriptions) {
if (matcher.match(subscription)) {
result.add(subscription);
}
}
}
}
return result;
}
public void setExpirationTime(long expirationTime) {
this.expirationTime = expirationTime;
}
public long getExpirationTime() {
return this.expirationTime;
}
private void restoreParentReferences() {
for (SimpUserDto user : this.users.values()) {
user.restoreParentReferences();
}
}
@Override
public String toString() {
return "id=" + this.id + ", users=" + this.users;
}
}
@SuppressWarnings("unused")
private static class SimpUserDto implements SimpUser {
private String name;
private Set<SimpSessionDto> sessions;
public SimpUserDto() {
this.sessions = new HashSet<SimpSessionDto>(1);
}
public SimpUserDto(SimpUser user) {
this.name = user.getName();
Set<SimpSession> sessions = user.getSessions();
this.sessions = new HashSet<SimpSessionDto>(sessions.size());
for (SimpSession session : sessions) {
this.sessions.add(new SimpSessionDto(session));
}
}
@Override
public String getName() {
return this.name;
}
public void setName(String name) {
this.name = name;
}
@Override
public boolean hasSessions() {
return !this.sessions.isEmpty();
}
@Override
public Set<SimpSession> getSessions() {
return new HashSet<SimpSession>(this.sessions);
}
public void setSessions(Set<SimpSessionDto> sessions) {
this.sessions.addAll(sessions);
}
@Override
public SimpSessionDto getSession(String sessionId) {
for (SimpSessionDto session : this.sessions) {
if (session.getId().equals(sessionId)) {
return session;
}
}
return null;
}
private void restoreParentReferences() {
for (SimpSessionDto session : this.sessions) {
session.setUser(this);
session.restoreParentReferences();
}
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || !(other instanceof SimpUser)) {
return false;
}
return this.name.equals(((SimpUser) other).getName());
}
@Override
public int hashCode() {
return this.name.hashCode();
}
@Override
public String toString() {
return "name=" + this.name + ", sessions=" + this.sessions;
}
}
@SuppressWarnings("unused")
private static class SimpSessionDto implements SimpSession {
private String id;
private SimpUserDto user;
private Set<SimpSubscriptionDto> subscriptions;
public SimpSessionDto() {
this.subscriptions = new HashSet<SimpSubscriptionDto>(4);
}
public SimpSessionDto(SimpSession session) {
this.id = session.getId();
Set<SimpSubscription> subscriptions = session.getSubscriptions();
this.subscriptions = new HashSet<SimpSubscriptionDto>(subscriptions.size());
for (SimpSubscription subscription : subscriptions) {
this.subscriptions.add(new SimpSubscriptionDto(subscription));
}
}
@Override
public String getId() {
return this.id;
}
public void setId(String id) {
this.id = id;
}
@Override
public SimpUserDto getUser() {
return this.user;
}
public void setUser(SimpUserDto user) {
this.user = user;
}
@Override
public Set<SimpSubscription> getSubscriptions() {
return new HashSet<SimpSubscription>(this.subscriptions);
}
public void setSubscriptions(Set<SimpSubscriptionDto> subscriptions) {
this.subscriptions.addAll(subscriptions);
}
private void restoreParentReferences() {
for (SimpSubscriptionDto subscription : this.subscriptions) {
subscription.setSession(this);
}
}
@Override
public int hashCode() {
return this.id.hashCode();
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || !(other instanceof SimpSession)) {
return false;
}
return this.id.equals(((SimpSession) other).getId());
}
@Override
public String toString() {
return "id=" + this.id + ", subscriptions=" + this.subscriptions;
}
}
@SuppressWarnings("unused")
private static class SimpSubscriptionDto implements SimpSubscription {
private String id;
private SimpSessionDto session;
private String destination;
public SimpSubscriptionDto() {
}
public SimpSubscriptionDto(SimpSubscription subscription) {
this.id = subscription.getId();
this.destination = subscription.getDestination();
}
@Override
public String getId() {
return this.id;
}
public void setId(String id) {
this.id = id;
}
@Override
public SimpSessionDto getSession() {
return this.session;
}
public void setSession(SimpSessionDto session) {
this.session = session;
}
@Override
public String getDestination() {
return this.destination;
}
public void setDestination(String destination) {
this.destination = destination;
}
@Override
public int hashCode() {
return 31 * this.id.hashCode() + ObjectUtils.nullSafeHashCode(getSession());
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || !(other instanceof SimpSubscription)) {
return false;
}
SimpSubscription otherSubscription = (SimpSubscription) other;
return (ObjectUtils.nullSafeEquals(getSession(), otherSubscription.getSession()) &&
this.id.equals(otherSubscription.getId()));
}
@Override
public String toString() {
return "destination=" + this.destination;
}
}
private static class NoOpSmartApplicationListener implements SmartApplicationListener {
@Override
public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) {
return false;
}
@Override
public boolean supportsSourceType(Class<?> sourceType) {
return false;
}
@Override
public void onApplicationEvent(ApplicationEvent event) {
}
@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE;
}
}
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import java.util.Set;
/**
* Represents a session of connected user.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public interface SimpSession {
/**
* Return the session id.
*/
String getId();
/**
* Return the user associated with the session.
*/
SimpUser getUser();
/**
* Return the subscriptions for this session.
*/
Set<SimpSubscription> getSubscriptions();
}

View File

@ -0,0 +1,41 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
/**
* Represents a subscription within a user session.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public interface SimpSubscription {
/**
* Return the id associated of the subscription, never {@code null}.
*/
String getId();
/**
* Return the session of the subscription, never {@code null}.
*/
SimpSession getSession();
/**
* Return the subscription's destination, never {@code null}.
*/
String getDestination();
}

View File

@ -0,0 +1,33 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
/**
* A strategy for matching subscriptions.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public interface SimpSubscriptionMatcher {
/**
* Match the given subscription.
* @param subscription the subscription to match
* @return {@code true} in case of match, {@code false} otherwise.
*/
boolean match(SimpSubscription subscription);
}

View File

@ -0,0 +1,52 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import java.util.Set;
/**
* Represents a connected user.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public interface SimpUser {
/**
* The unique user name.
*/
String getName();
/**
* Whether the user has any sessions.
*/
boolean hasSessions();
/**
* Look up the session for the given id.
* @param sessionId the session id
* @return the matching session of {@code null}.
*/
SimpSession getSession(String sessionId);
/**
* Return the sessions for the user.
* The returned set is a copy and will never be modified.
* @return a set of session ids, or an empty set.
*/
Set<SimpSession> getSessions();
}

View File

@ -0,0 +1,49 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import java.util.Set;
/**
* A registry of currently connected users.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public interface SimpUserRegistry {
/**
* Get the user for the given name.
* @param userName the name of the user to look up
* @return the user or {@code null} if not connected
*/
SimpUser getUser(String userName);
/**
* Return a snapshot of all connected users. The returned set is a copy and
* will never be modified.
* @return the connected users or an empty set.
*/
Set<SimpUser> getUsers();
/**
* Find subscriptions with the given matcher.
* @param matcher the matcher to use
* @return a set of matching subscriptions or an empty set.
*/
Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher);
}

View File

@ -108,7 +108,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
* <p>By default this is not set.
* @param destination the target destination.
*/
public void setUserDestinationBroadcast(String destination) {
public void setBroadcastDestination(String destination) {
this.broadcastHandler = (StringUtils.hasText(destination) ?
new BroadcastHandler(this.messagingTemplate, destination) : null);
}
@ -116,7 +116,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
/**
* Return the configured destination for unresolved messages.
*/
public String getUserDestinationBroadcast() {
public String getBroadcastDestination() {
return (this.broadcastHandler != null ? this.broadcastHandler.getBroadcastDestination() : null);
}

View File

@ -87,7 +87,7 @@ public class UserDestinationResult {
* @return the user name or {@code null} if we have a session id only such as
* when the user is not authenticated; in such cases it is possible to use
* sessionId in place of a user name thus removing the need for a user-to-session
* lookup via {@link org.springframework.messaging.simp.user.UserSessionRegistry}.
* lookup via {@link SimpUserRegistry}.
*/
public String getUser() {
return this.user;

View File

@ -0,0 +1,136 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import java.util.concurrent.ScheduledFuture;
import org.springframework.context.ApplicationListener;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.broker.BrokerAvailabilityEvent;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
/**
* A MessageHandler that is subscribed to listen to broadcasts of user registry
* information from other application servers as well as to periodically
* broadcast the content of the local user registry. The aggregated information
* is maintained in a {@link MultiServerUserRegistry}.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public class UserRegistryMessageHandler implements MessageHandler, ApplicationListener<BrokerAvailabilityEvent> {
private final MultiServerUserRegistry userRegistry;
private final SimpMessagingTemplate brokerTemplate;
private final String broadcastDestination;
private final TaskScheduler scheduler;
private final UserRegistryTask schedulerTask = new UserRegistryTask();
private volatile ScheduledFuture<?> scheduledFuture;
private long registryExpirationPeriod = 20 * 1000;
public UserRegistryMessageHandler(SimpUserRegistry userRegistry, SimpMessagingTemplate brokerTemplate,
String broadcastDestination, TaskScheduler scheduler) {
Assert.notNull(userRegistry, "'userRegistry' is required");
Assert.isInstanceOf(MultiServerUserRegistry.class, userRegistry);
Assert.notNull(brokerTemplate, "'brokerTemplate' is required");
Assert.hasText(broadcastDestination, "'broadcastDestination' is required");
Assert.notNull(scheduler, "'scheduler' is required");
this.userRegistry = (MultiServerUserRegistry) userRegistry;
this.brokerTemplate = brokerTemplate;
this.broadcastDestination = broadcastDestination;
this.scheduler = scheduler;
}
/**
* Return the destination for broadcasting user registry information to.
*/
public String getBroadcastDestination() {
return this.broadcastDestination;
}
/**
* Configure how long before a remote registry snapshot expires.
* <p>By default this is set to 20000 (20 seconds).
* @param expirationPeriod the expiration period in milliseconds
*/
@SuppressWarnings("unused")
public void setRegistryExpirationPeriod(long expirationPeriod) {
this.registryExpirationPeriod = expirationPeriod;
}
/**
* Return the configured registry expiration period.
*/
public long getRegistryExpirationPeriod() {
return this.registryExpirationPeriod;
}
@Override
public void onApplicationEvent(BrokerAvailabilityEvent event) {
if (event.isBrokerAvailable()) {
long delay = getRegistryExpirationPeriod() / 2;
this.scheduledFuture = this.scheduler.scheduleWithFixedDelay(this.schedulerTask, delay);
}
else if (this.scheduledFuture != null ){
this.scheduledFuture.cancel(true);
this.scheduledFuture = null;
}
}
@Override
public void handleMessage(Message<?> message) throws MessagingException {
MessageConverter converter = this.brokerTemplate.getMessageConverter();
this.userRegistry.addRemoteRegistryDto(message, converter, getRegistryExpirationPeriod());
}
private class UserRegistryTask implements Runnable {
@Override
public void run() {
try {
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
accessor.setHeader(SimpMessageHeaderAccessor.IGNORE_ERROR, true);
accessor.setLeaveMutable(true);
Object payload = userRegistry.getLocalRegistryDto();
brokerTemplate.convertAndSend(getBroadcastDestination(), payload, accessor.getMessageHeaders());
}
finally {
userRegistry.purgeExpiredRegistries();
}
}
}
}

View File

@ -19,34 +19,41 @@ package org.springframework.messaging.simp.user;
import java.util.Set;
/**
* A registry for looking up active user sessions. For use when resolving user
* destinations.
* A contract for adding and removing user sessions.
*
* <p>As of 4.2 this interface extends {@link SimpUserRegistry}.
* exposing methods to return all registered users as well as to provide more
* extensive information for each user.
*
* @author Rossen Stoyanchev
* @since 4.0
* @see DefaultUserDestinationResolver
* @deprecated in favor of {@link SimpUserRegistry} in combination with
* {@link org.springframework.context.ApplicationListener} listening for
* {@link org.springframework.web.socket.messaging.AbstractSubProtocolEvent} events.
*/
@Deprecated
public interface UserSessionRegistry {
/**
* Return the active session id's for the user.
* @param user the user
* @return a set with 0 or more session id's, never {@code null}.
* Return the active session ids for the user.
* The returned set is a snapshot that will never be modified.
* @param userName the user to look up
* @return a set with 0 or more session ids, never {@code null}.
*/
Set<String> getSessionIds(String user);
Set<String> getSessionIds(String userName);
/**
* Register an active session id for a user.
* @param user the user
* @param userName the user name
* @param sessionId the session id
*/
void registerSessionId(String user, String sessionId);
void registerSessionId(String userName, String sessionId);
/**
* Unregister an active session id for a user.
* @param user the user
* @param userName the user name
* @param sessionId the session id
*/
void unregisterSessionId(String user, String sessionId);
void unregisterSessionId(String userName, String sessionId);
}

View File

@ -0,0 +1,121 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.util.CollectionUtils;
/**
* A temporary adapter to allow use of deprecated {@link UserSessionRegistry}.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
@SuppressWarnings("deprecation")
public class UserSessionRegistryAdapter implements SimpUserRegistry {
private final UserSessionRegistry delegate;
public UserSessionRegistryAdapter(UserSessionRegistry delegate) {
this.delegate = delegate;
}
@Override
public SimpUser getUser(String userName) {
Set<String> sessionIds = this.delegate.getSessionIds(userName);
return (!CollectionUtils.isEmpty(sessionIds) ? new SimpleSimpUser(userName, sessionIds) : null);
}
@Override
public Set<SimpUser> getUsers() {
throw new UnsupportedOperationException("UserSessionRegistry does not expose a listing of users.");
}
@Override
public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) {
throw new UnsupportedOperationException("UserSessionRegistry does not support operations across users.");
}
private static class SimpleSimpUser implements SimpUser {
private final String name;
private final Map<String, SimpSession> sessions;
public SimpleSimpUser(String name, Set<String> sessionIds) {
this.name = name;
this.sessions = new HashMap<String, SimpSession>(sessionIds.size());
for (String sessionId : sessionIds) {
this.sessions.put(sessionId, new SimpleSimpSession(sessionId));
}
}
@Override
public String getName() {
return this.name;
}
@Override
public boolean hasSessions() {
return !this.sessions.isEmpty();
}
@Override
public SimpSession getSession(String sessionId) {
return this.sessions.get(sessionId);
}
@Override
public Set<SimpSession> getSessions() {
return new HashSet<SimpSession>(this.sessions.values());
}
}
private static class SimpleSimpSession implements SimpSession {
private final String id;
public SimpleSimpSession(String id) {
this.id = id;
}
@Override
public String getId() {
return this.id;
}
@Override
public SimpUser getUser() {
return null;
}
@Override
public Set<SimpSubscription> getSubscriptions() {
return Collections.<SimpSubscription>emptySet();
}
}
}

View File

@ -3,7 +3,7 @@
* unique to a user's sessions), primarily translating the destinations and then
* forwarding the updated message to the broker.
*
* <p>Also included is {@link org.springframework.messaging.simp.user.UserSessionRegistry}
* <p>Also included is {@link org.springframework.messaging.simp.user.SimpUserRegistry}
* for keeping track of connected user sessions.
*/
package org.springframework.messaging.simp.user;

View File

@ -16,6 +16,9 @@
package org.springframework.messaging.simp.config;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
@ -51,8 +54,10 @@ import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.MultiServerUserRegistry;
import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.simp.user.UserRegistryMessageHandler;
import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ChannelInterceptorAdapter;
@ -66,9 +71,6 @@ import org.springframework.validation.Errors;
import org.springframework.validation.Validator;
import org.springframework.validation.beanvalidation.OptionalValidatorFactoryBean;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
/**
* Test fixture for {@link AbstractMessageBrokerConfiguration}.
*
@ -235,26 +237,6 @@ public class MessageBrokerConfigurationTests {
assertEquals("bar", new String((byte[]) message.getPayload()));
}
@Test
public void brokerChannelUsedByUserDestinationMessageHandler() {
TestChannel channel = this.simpleBrokerContext.getBean("brokerChannel", TestChannel.class);
UserDestinationMessageHandler messageHandler = this.simpleBrokerContext.getBean(UserDestinationMessageHandler.class);
this.simpleBrokerContext.getBean(UserSessionRegistry.class).registerSessionId("joe", "s1");
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.setDestination("/user/joe/foo");
Message<?> message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
messageHandler.handleMessage(message);
message = channel.messages.get(0);
headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
assertEquals("/foo-users1", headers.getDestination());
}
@Test
public void brokerChannelCustomized() {
AbstractSubscribableChannel channel = this.customContext.getBean(
@ -272,7 +254,7 @@ public class MessageBrokerConfigurationTests {
@Test
public void configureMessageConvertersDefault() {
AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() {};
AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig();
CompositeMessageConverter compositeConverter = config.brokerMessageConverter();
List<MessageConverter> converters = compositeConverter.getConverters();
@ -305,7 +287,7 @@ public class MessageBrokerConfigurationTests {
@Test
public void configureMessageConvertersCustom() {
final MessageConverter testConverter = mock(MessageConverter.class);
AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() {
AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig() {
@Override
protected boolean configureMessageConverters(List<MessageConverter> messageConverters) {
messageConverters.add(testConverter);
@ -323,7 +305,7 @@ public class MessageBrokerConfigurationTests {
public void configureMessageConvertersCustomAndDefault() {
final MessageConverter testConverter = mock(MessageConverter.class);
AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() {
AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig() {
@Override
protected boolean configureMessageConverters(List<MessageConverter> messageConverters) {
messageConverters.add(testConverter);
@ -355,7 +337,7 @@ public class MessageBrokerConfigurationTests {
@Test
public void simpValidatorDefault() {
AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() {};
AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig() {};
config.setApplicationContext(new StaticApplicationContext());
assertThat(config.simpValidator(), Matchers.notNullValue());
@ -365,7 +347,7 @@ public class MessageBrokerConfigurationTests {
@Test
public void simpValidatorCustom() {
final Validator validator = mock(Validator.class);
AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() {
AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig() {
@Override
public Validator getValidator() {
return validator;
@ -379,7 +361,7 @@ public class MessageBrokerConfigurationTests {
public void simpValidatorMvc() {
StaticApplicationContext appCxt = new StaticApplicationContext();
appCxt.registerSingleton("mvcValidator", TestValidator.class);
AbstractMessageBrokerConfiguration config = new AbstractMessageBrokerConfiguration() {};
AbstractMessageBrokerConfiguration config = new BaseTestMessageBrokerConfig() {};
config.setApplicationContext(appCxt);
assertThat(config.simpValidator(), Matchers.notNullValue());
@ -405,12 +387,35 @@ public class MessageBrokerConfigurationTests {
}
@Test
public void userDestinationBroadcast() throws Exception {
public void userBroadcasts() throws Exception {
SimpUserRegistry userRegistry = this.brokerRelayContext.getBean(SimpUserRegistry.class);
assertEquals(MultiServerUserRegistry.class, userRegistry.getClass());
UserDestinationMessageHandler handler1 = this.brokerRelayContext.getBean(UserDestinationMessageHandler.class);
assertEquals("/topic/unresolved-user-destination", handler1.getBroadcastDestination());
UserRegistryMessageHandler handler2 = this.brokerRelayContext.getBean(UserRegistryMessageHandler.class);
assertEquals("/topic/simp-user-registry", handler2.getBroadcastDestination());
StompBrokerRelayMessageHandler relay = this.brokerRelayContext.getBean(StompBrokerRelayMessageHandler.class);
UserDestinationMessageHandler userHandler = this.brokerRelayContext.getBean(UserDestinationMessageHandler.class);
assertEquals("/topic/unresolved", userHandler.getUserDestinationBroadcast());
assertNotNull(relay.getSystemSubscriptions());
assertSame(userHandler, relay.getSystemSubscriptions().get("/topic/unresolved"));
assertEquals(2, relay.getSystemSubscriptions().size());
assertSame(handler1, relay.getSystemSubscriptions().get("/topic/unresolved-user-destination"));
assertSame(handler2, relay.getSystemSubscriptions().get("/topic/simp-user-registry"));
}
@Test
public void userBroadcastsDisabledWithSimpleBroker() throws Exception {
SimpUserRegistry registry = this.simpleBrokerContext.getBean(SimpUserRegistry.class);
assertNotNull(registry);
assertNotEquals(MultiServerUserRegistry.class, registry.getClass());
UserDestinationMessageHandler handler = this.simpleBrokerContext.getBean(UserDestinationMessageHandler.class);
assertNull(handler.getBroadcastDestination());
String name = "userRegistryMessageHandler";
MessageHandler messageHandler = this.simpleBrokerContext.getBean(name, MessageHandler.class);
assertNotEquals(UserRegistryMessageHandler.class, messageHandler.getClass());
}
@ -430,9 +435,17 @@ public class MessageBrokerConfigurationTests {
}
}
static class BaseTestMessageBrokerConfig extends AbstractMessageBrokerConfiguration {
@Override
protected SimpUserRegistry createLocalUserRegistry() {
return mock(SimpUserRegistry.class);
}
}
@SuppressWarnings("unused")
@Configuration
static class SimpleBrokerConfig extends AbstractMessageBrokerConfiguration {
static class SimpleBrokerConfig extends BaseTestMessageBrokerConfig {
@Bean
public TestController subscriptionController() {
@ -463,17 +476,18 @@ public class MessageBrokerConfigurationTests {
@Override
public void configureMessageBroker(MessageBrokerRegistry registry) {
registry.enableStompBrokerRelay("/topic", "/queue").setAutoStartup(true);
registry.setUserDestinationBroadcast("/topic/unresolved");
registry.enableStompBrokerRelay("/topic", "/queue").setAutoStartup(true)
.setUserDestinationBroadcast("/topic/unresolved-user-destination")
.setUserRegistryBroadcast("/topic/simp-user-registry");
}
}
@Configuration
static class DefaultConfig extends AbstractMessageBrokerConfiguration {
static class DefaultConfig extends BaseTestMessageBrokerConfig {
}
@Configuration
static class CustomConfig extends AbstractMessageBrokerConfiguration {
static class CustomConfig extends BaseTestMessageBrokerConfig {
private ChannelInterceptor interceptor = new ChannelInterceptorAdapter() {};

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,7 +29,8 @@ import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import static org.junit.Assert.*;
/**
* Unit tests for {@link org.springframework.messaging.simp.config.StompBrokerRelayRegistration}.
* Unit tests for
* {@link org.springframework.messaging.simp.config.StompBrokerRelayRegistration}.
*
* @author Rossen Stoyanchev
*/
@ -39,15 +40,11 @@ public class StompBrokerRelayRegistrationTests {
@Test
public void test() {
SubscribableChannel clientInboundChannel = new StubMessageChannel();
MessageChannel clientOutboundChannel = new StubMessageChannel();
SubscribableChannel brokerChannel = new StubMessageChannel();
String[] destinationPrefixes = new String[] { "/foo", "/bar" };
StompBrokerRelayRegistration registration = new StompBrokerRelayRegistration(
clientInboundChannel, clientOutboundChannel, destinationPrefixes);
SubscribableChannel inChannel = new StubMessageChannel();
MessageChannel outChannel = new StubMessageChannel();
String[] prefixes = new String[] { "/foo", "/bar" };
StompBrokerRelayRegistration registration = new StompBrokerRelayRegistration(inChannel, outChannel, prefixes);
registration.setClientLogin("clientlogin");
registration.setClientPasscode("clientpasscode");
registration.setSystemLogin("syslogin");
@ -56,18 +53,16 @@ public class StompBrokerRelayRegistrationTests {
registration.setSystemHeartbeatSendInterval(456);
registration.setVirtualHost("example.org");
StompBrokerRelayMessageHandler relayMessageHandler = registration.getMessageHandler(brokerChannel);
StompBrokerRelayMessageHandler handler = registration.getMessageHandler(new StubMessageChannel());
assertEquals(Arrays.asList(destinationPrefixes),
new ArrayList<String>(relayMessageHandler.getDestinationPrefixes()));
assertEquals("clientlogin", relayMessageHandler.getClientLogin());
assertEquals("clientpasscode", relayMessageHandler.getClientPasscode());
assertEquals("syslogin", relayMessageHandler.getSystemLogin());
assertEquals("syspasscode", relayMessageHandler.getSystemPasscode());
assertEquals(123, relayMessageHandler.getSystemHeartbeatReceiveInterval());
assertEquals(456, relayMessageHandler.getSystemHeartbeatSendInterval());
assertEquals("example.org", relayMessageHandler.getVirtualHost());
assertArrayEquals(prefixes, handler.getDestinationPrefixes().toArray(new String[2]));
assertEquals("clientlogin", handler.getClientLogin());
assertEquals("clientpasscode", handler.getClientPasscode());
assertEquals("syslogin", handler.getSystemLogin());
assertEquals("syspasscode", handler.getSystemPasscode());
assertEquals(123, handler.getSystemHeartbeatReceiveInterval());
assertEquals(456, handler.getSystemHeartbeatSendInterval());
assertEquals("example.org", handler.getVirtualHost());
}
}

View File

@ -17,6 +17,9 @@
package org.springframework.messaging.simp.user;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import java.security.Principal;
import org.junit.Before;
import org.junit.Test;
@ -36,35 +39,36 @@ import org.springframework.util.StringUtils;
*/
public class DefaultUserDestinationResolverTests {
public static final String SESSION_ID = "123";
private DefaultUserDestinationResolver resolver;
private UserSessionRegistry registry;
private TestPrincipal user;
private SimpUserRegistry registry;
@Before
public void setup() {
this.user = new TestPrincipal("joe");
this.registry = new DefaultUserSessionRegistry();
this.registry.registerSessionId(this.user.getName(), SESSION_ID);
TestSimpUser simpUser = new TestSimpUser("joe");
simpUser.addSessions(new TestSimpSession("123"));
this.registry = mock(SimpUserRegistry.class);
when(this.registry.getUser("joe")).thenReturn(simpUser);
this.resolver = new DefaultUserDestinationResolver(this.registry);
}
@Test
public void handleSubscribe() {
TestPrincipal user = new TestPrincipal("joe");
String sourceDestination = "/user/queue/foo";
Message<?> message = createWith(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, sourceDestination);
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, user, "123", sourceDestination);
UserDestinationResult actual = this.resolver.resolveDestination(message);
assertEquals(sourceDestination, actual.getSourceDestination());
assertEquals(1, actual.getTargetDestinations().size());
assertEquals("/queue/foo-user123", actual.getTargetDestinations().iterator().next());
assertEquals(sourceDestination, actual.getSubscribeDestination());
assertEquals(this.user.getName(), actual.getUser());
assertEquals(user.getName(), actual.getUser());
}
// SPR-11325
@ -72,32 +76,35 @@ public class DefaultUserDestinationResolverTests {
@Test
public void handleSubscribeOneUserMultipleSessions() {
this.registry.registerSessionId("joe", "456");
this.registry.registerSessionId("joe", "789");
TestSimpUser simpUser = new TestSimpUser("joe");
simpUser.addSessions(new TestSimpSession("123"), new TestSimpSession("456"));
when(this.registry.getUser("joe")).thenReturn(simpUser);
Message<?> message = createWith(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo");
TestPrincipal user = new TestPrincipal("joe");
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, user, "456", "/user/queue/foo");
UserDestinationResult actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.getTargetDestinations().size());
assertEquals("/queue/foo-user123", actual.getTargetDestinations().iterator().next());
assertEquals("/queue/foo-user456", actual.getTargetDestinations().iterator().next());
}
@Test
public void handleSubscribeNoUser() {
String sourceDestination = "/user/queue/foo";
Message<?> message = createWith(SimpMessageType.SUBSCRIBE, null, SESSION_ID, sourceDestination);
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, null, "123", sourceDestination);
UserDestinationResult actual = this.resolver.resolveDestination(message);
assertEquals(sourceDestination, actual.getSourceDestination());
assertEquals(1, actual.getTargetDestinations().size());
assertEquals("/queue/foo-user" + SESSION_ID, actual.getTargetDestinations().iterator().next());
assertEquals("/queue/foo-user" + "123", actual.getTargetDestinations().iterator().next());
assertEquals(sourceDestination, actual.getSubscribeDestination());
assertNull(actual.getUser());
}
@Test
public void handleUnsubscribe() {
Message<?> message = createWith(SimpMessageType.UNSUBSCRIBE, this.user, SESSION_ID, "/user/queue/foo");
TestPrincipal user = new TestPrincipal("joe");
Message<?> message = createMessage(SimpMessageType.UNSUBSCRIBE, user, "123", "/user/queue/foo");
UserDestinationResult actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.getTargetDestinations().size());
@ -106,32 +113,37 @@ public class DefaultUserDestinationResolverTests {
@Test
public void handleMessage() {
TestPrincipal user = new TestPrincipal("joe");
String sourceDestination = "/user/joe/queue/foo";
Message<?> message = createWith(SimpMessageType.MESSAGE, this.user, SESSION_ID, sourceDestination);
Message<?> message = createMessage(SimpMessageType.MESSAGE, user, "123", sourceDestination);
UserDestinationResult actual = this.resolver.resolveDestination(message);
assertEquals(sourceDestination, actual.getSourceDestination());
assertEquals(1, actual.getTargetDestinations().size());
assertEquals("/queue/foo-user123", actual.getTargetDestinations().iterator().next());
assertEquals("/user/queue/foo", actual.getSubscribeDestination());
assertEquals(this.user.getName(), actual.getUser());
assertEquals(user.getName(), actual.getUser());
}
// SPR-12444
@Test
public void handleMessageToOtherUser() {
final String OTHER_SESSION_ID = "456";
final String OTHER_USER_NAME = "anna";
TestSimpUser otherSimpUser = new TestSimpUser("anna");
otherSimpUser.addSessions(new TestSimpSession("456"));
when(this.registry.getUser("anna")).thenReturn(otherSimpUser);
TestPrincipal user = new TestPrincipal("joe");
TestPrincipal otherUser = new TestPrincipal("anna");
String sourceDestination = "/user/anna/queue/foo";
Message<?> message = createMessage(SimpMessageType.MESSAGE, user, "456", sourceDestination);
String sourceDestination = "/user/"+OTHER_USER_NAME+"/queue/foo";
TestPrincipal otherUser = new TestPrincipal(OTHER_USER_NAME);
this.registry.registerSessionId(otherUser.getName(), OTHER_SESSION_ID);
Message<?> message = createWith(SimpMessageType.MESSAGE, this.user, SESSION_ID, sourceDestination);
UserDestinationResult actual = this.resolver.resolveDestination(message);
assertEquals(sourceDestination, actual.getSourceDestination());
assertEquals(1, actual.getTargetDestinations().size());
assertEquals("/queue/foo-user" + OTHER_SESSION_ID, actual.getTargetDestinations().iterator().next());
assertEquals("/queue/foo-user456", actual.getTargetDestinations().iterator().next());
assertEquals("/user/queue/foo", actual.getSubscribeDestination());
assertEquals(otherUser.getName(), actual.getUser());
}
@ -140,9 +152,14 @@ public class DefaultUserDestinationResolverTests {
public void handleMessageEncodedUserName() {
String userName = "http://joe.openid.example.org/";
this.registry.registerSessionId(userName, "openid123");
TestSimpUser simpUser = new TestSimpUser(userName);
simpUser.addSessions(new TestSimpSession("openid123"));
when(this.registry.getUser(userName)).thenReturn(simpUser);
String destination = "/user/" + StringUtils.replace(userName, "/", "%2F") + "/queue/foo";
Message<?> message = createWith(SimpMessageType.MESSAGE, this.user, null, destination);
Message<?> message = createMessage(SimpMessageType.MESSAGE, new TestPrincipal("joe"), null, destination);
UserDestinationResult actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.getTargetDestinations().size());
@ -151,8 +168,8 @@ public class DefaultUserDestinationResolverTests {
@Test
public void handleMessageWithNoUser() {
String sourceDestination = "/user/" + SESSION_ID + "/queue/foo";
Message<?> message = createWith(SimpMessageType.MESSAGE, null, SESSION_ID, sourceDestination);
String sourceDestination = "/user/" + "123" + "/queue/foo";
Message<?> message = createMessage(SimpMessageType.MESSAGE, null, "123", sourceDestination);
UserDestinationResult actual = this.resolver.resolveDestination(message);
assertEquals(sourceDestination, actual.getSourceDestination());
@ -166,28 +183,28 @@ public class DefaultUserDestinationResolverTests {
public void ignoreMessage() {
// no destination
Message<?> message = createWith(SimpMessageType.MESSAGE, this.user, SESSION_ID, null);
TestPrincipal user = new TestPrincipal("joe");
Message<?> message = createMessage(SimpMessageType.MESSAGE, user, "123", null);
UserDestinationResult actual = this.resolver.resolveDestination(message);
assertNull(actual);
// not a user destination
message = createWith(SimpMessageType.MESSAGE, this.user, SESSION_ID, "/queue/foo");
message = createMessage(SimpMessageType.MESSAGE, user, "123", "/queue/foo");
actual = this.resolver.resolveDestination(message);
assertNull(actual);
// subscribe + not a user destination
message = createWith(SimpMessageType.SUBSCRIBE, this.user, SESSION_ID, "/queue/foo");
message = createMessage(SimpMessageType.SUBSCRIBE, user, "123", "/queue/foo");
actual = this.resolver.resolveDestination(message);
assertNull(actual);
// no match on message type
message = createWith(SimpMessageType.CONNECT, this.user, SESSION_ID, "user/joe/queue/foo");
message = createMessage(SimpMessageType.CONNECT, user, "123", "user/joe/queue/foo");
actual = this.resolver.resolveDestination(message);
assertNull(actual);
}
private Message<?> createWith(SimpMessageType type, TestPrincipal user, String sessionId, String destination) {
private Message<?> createMessage(SimpMessageType type, Principal user, String sessionId, String destination) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type);
if (destination != null) {
headers.setDestination(destination);

View File

@ -1,82 +0,0 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import static org.junit.Assert.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import org.junit.Test;
/**
* Test fixture for
* {@link org.springframework.messaging.simp.user.DefaultUserSessionRegistry}
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class DefaultUserSessionRegistryTests {
private static final String user = "joe";
private static final List<String> sessionIds = Arrays.asList("sess01", "sess02", "sess03");
@Test
public void addOneSessionId() {
DefaultUserSessionRegistry resolver = new DefaultUserSessionRegistry();
resolver.registerSessionId(user, sessionIds.get(0));
assertEquals(Collections.singleton(sessionIds.get(0)), resolver.getSessionIds(user));
assertSame(Collections.emptySet(), resolver.getSessionIds("jane"));
}
@Test
public void addMultipleSessionIds() {
DefaultUserSessionRegistry resolver = new DefaultUserSessionRegistry();
for (String sessionId : sessionIds) {
resolver.registerSessionId(user, sessionId);
}
assertEquals(new LinkedHashSet<>(sessionIds), resolver.getSessionIds(user));
assertEquals(Collections.<String>emptySet(), resolver.getSessionIds("jane"));
}
@Test
public void removeSessionIds() {
DefaultUserSessionRegistry resolver = new DefaultUserSessionRegistry();
for (String sessionId : sessionIds) {
resolver.registerSessionId(user, sessionId);
}
assertEquals(new LinkedHashSet<>(sessionIds), resolver.getSessionIds(user));
resolver.unregisterSessionId(user, sessionIds.get(1));
resolver.unregisterSessionId(user, sessionIds.get(2));
assertEquals(Collections.singleton(sessionIds.get(0)), resolver.getSessionIds(user));
resolver.unregisterSessionId(user, sessionIds.get(0));
assertSame(Collections.emptySet(), resolver.getSessionIds(user));
}
}

View File

@ -0,0 +1,167 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.messaging.Message;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.converter.MessageConverter;
/**
* Unit tests for {@link MultiServerUserRegistry}.
*
* @author Rossen Stoyanchev
*/
public class MultiServerUserRegistryTests {
private SimpUserRegistry localRegistry;
private MultiServerUserRegistry multiServerRegistry;
private MessageConverter converter;
@Before
public void setUp() throws Exception {
this.localRegistry = Mockito.mock(SimpUserRegistry.class);
this.multiServerRegistry = new MultiServerUserRegistry(this.localRegistry);
this.converter = new MappingJackson2MessageConverter();
}
@Test
public void getUserFromLocalRegistry() throws Exception {
SimpUser user = Mockito.mock(SimpUser.class);
Set<SimpUser> users = Collections.singleton(user);
when(this.localRegistry.getUsers()).thenReturn(users);
when(this.localRegistry.getUser("joe")).thenReturn(user);
assertEquals(1, this.multiServerRegistry.getUsers().size());
assertSame(user, this.multiServerRegistry.getUser("joe"));
}
@Test
public void getUserFromRemoteRegistry() throws Exception {
TestSimpSession remoteSession = new TestSimpSession("remote-sess");
remoteSession.addSubscriptions(new TestSimpSubscription("remote-sub", "/remote-dest"));
TestSimpUser remoteUser = new TestSimpUser("joe");
remoteUser.addSessions(remoteSession);
SimpUserRegistry remoteUserRegistry = mock(SimpUserRegistry.class);
when(remoteUserRegistry.getUsers()).thenReturn(Collections.singleton(remoteUser));
MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(remoteUserRegistry);
Message<?> message = this.converter.toMessage(remoteRegistry.getLocalRegistryDto(), null);
this.multiServerRegistry.addRemoteRegistryDto(message, this.converter, 20000);
assertEquals(1, this.multiServerRegistry.getUsers().size());
SimpUser user = this.multiServerRegistry.getUser("joe");
assertNotNull(user);
assertEquals(1, user.getSessions().size());
SimpSession session = user.getSession("remote-sess");
assertNotNull(session);
assertEquals("remote-sess", session.getId());
assertSame(user, session.getUser());
assertEquals(1, session.getSubscriptions().size());
SimpSubscription subscription = session.getSubscriptions().iterator().next();
assertEquals("remote-sub", subscription.getId());
assertSame(session, subscription.getSession());
assertEquals("/remote-dest", subscription.getDestination());
}
@Test
public void findUserFromRemoteRegistry() throws Exception {
TestSimpSubscription subscription1 = new TestSimpSubscription("sub1", "/match");
TestSimpSession session1 = new TestSimpSession("sess1");
session1.addSubscriptions(subscription1);
TestSimpUser user1 = new TestSimpUser("joe");
user1.addSessions(session1);
TestSimpSubscription subscription2 = new TestSimpSubscription("sub1", "/match");
TestSimpSession session2 = new TestSimpSession("sess2");
session2.addSubscriptions(subscription2);
TestSimpUser user2 = new TestSimpUser("jane");
user2.addSessions(session2);
TestSimpSubscription subscription3 = new TestSimpSubscription("sub1", "/not-a-match");
TestSimpSession session3 = new TestSimpSession("sess3");
session3.addSubscriptions(subscription3);
TestSimpUser user3 = new TestSimpUser("jack");
user3.addSessions(session3);
SimpUserRegistry remoteUserRegistry = mock(SimpUserRegistry.class);
when(remoteUserRegistry.getUsers()).thenReturn(new HashSet<SimpUser>(Arrays.asList(user1, user2, user3)));
MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(remoteUserRegistry);
Message<?> message = this.converter.toMessage(remoteRegistry.getLocalRegistryDto(), null);
this.multiServerRegistry.addRemoteRegistryDto(message, this.converter, 20000);
assertEquals(3, this.multiServerRegistry.getUsers().size());
Set<SimpSubscription> matches = this.multiServerRegistry.findSubscriptions(new SimpSubscriptionMatcher() {
@Override
public boolean match(SimpSubscription subscription) {
return subscription.getDestination().equals("/match");
}
});
assertEquals(2, matches.size());
Iterator<SimpSubscription> iterator = matches.iterator();
Set<String> sessionIds = new HashSet<>(2);
sessionIds.add(iterator.next().getSession().getId());
sessionIds.add(iterator.next().getSession().getId());
assertEquals(new HashSet<>(Arrays.asList("sess1", "sess2")), sessionIds);
}
@Test
public void purgeExpiredRegistries() throws Exception {
TestSimpUser remoteUser = new TestSimpUser("joe");
remoteUser.addSessions(new TestSimpSession("remote-sub"));
SimpUserRegistry remoteUserRegistry = mock(SimpUserRegistry.class);
when(remoteUserRegistry.getUsers()).thenReturn(Collections.singleton(remoteUser));
MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(remoteUserRegistry);
Message<?> message = this.converter.toMessage(remoteRegistry.getLocalRegistryDto(), null);
long expirationPeriod = -1;
this.multiServerRegistry.addRemoteRegistryDto(message, this.converter, expirationPeriod);
assertEquals(1, this.multiServerRegistry.getUsers().size());
this.multiServerRegistry.purgeExpiredRegistries();
assertEquals(0, this.multiServerRegistry.getUsers().size());
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
public class TestSimpSession implements SimpSession {
private String id;
private TestSimpUser user;
private Set<SimpSubscription> subscriptions = new HashSet<>();
public TestSimpSession(String id) {
this.id = id;
}
@Override
public String getId() {
return id;
}
@Override
public TestSimpUser getUser() {
return user;
}
public void setUser(TestSimpUser user) {
this.user = user;
}
@Override
public Set<SimpSubscription> getSubscriptions() {
return subscriptions;
}
public void addSubscriptions(TestSimpSubscription... subscriptions) {
for (TestSimpSubscription subscription : subscriptions) {
subscription.setSession(this);
this.subscriptions.add(subscription);
}
}
}

View File

@ -0,0 +1,52 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
public class TestSimpSubscription implements SimpSubscription {
private String id;
private TestSimpSession session;
private String destination;
public TestSimpSubscription(String id, String destination) {
this.destination = destination;
this.id = id;
}
@Override
public String getId() {
return id;
}
@Override
public TestSimpSession getSession() {
return this.session;
}
public void setSession(TestSimpSession session) {
this.session = session;
}
@Override
public String getDestination() {
return destination;
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
public class TestSimpUser implements SimpUser {
private String name;
private Map<String, SimpSession> sessions = new HashMap<>();
public TestSimpUser(String name) {
this.name = name;
}
@Override
public String getName() {
return name;
}
@Override
public Set<SimpSession> getSessions() {
return new HashSet<>(this.sessions.values());
}
@Override
public boolean hasSessions() {
return !this.sessions.isEmpty();
}
@Override
public SimpSession getSession(String sessionId) {
return this.sessions.get(sessionId);
}
public void addSessions(TestSimpSession... sessions) {
for (TestSimpSession session : sessions) {
session.setUser(this);
this.sessions.put(session.getId(), session);
}
}
}

View File

@ -25,9 +25,7 @@ import java.nio.charset.Charset;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.springframework.messaging.Message;
import org.springframework.messaging.StubMessageChannel;
@ -50,16 +48,15 @@ public class UserDestinationMessageHandlerTests {
private UserDestinationMessageHandler handler;
private UserSessionRegistry registry;
private SimpUserRegistry registry;
@Mock
private SubscribableChannel brokerChannel;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
this.registry = new DefaultUserSessionRegistry();
this.registry = mock(SimpUserRegistry.class);
this.brokerChannel = mock(SubscribableChannel.class);
UserDestinationResolver resolver = new DefaultUserDestinationResolver(this.registry);
this.handler = new UserDestinationMessageHandler(new StubMessageChannel(), this.brokerChannel, resolver);
}
@ -91,7 +88,9 @@ public class UserDestinationMessageHandlerTests {
@Test
public void handleMessage() {
this.registry.registerSessionId("joe", "123");
TestSimpUser simpUser = new TestSimpUser("joe");
simpUser.addSessions(new TestSimpSession("123"));
when(this.registry.getUser("joe")).thenReturn(simpUser);
given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true);
this.handler.handleMessage(createWith(SimpMessageType.MESSAGE, "joe", "123", "/user/joe/queue/foo"));
@ -105,7 +104,7 @@ public class UserDestinationMessageHandlerTests {
@Test
public void handleMessageWithoutActiveSession() {
this.handler.setUserDestinationBroadcast("/topic/unresolved");
this.handler.setBroadcastDestination("/topic/unresolved");
given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true);
this.handler.handleMessage(createWith(SimpMessageType.MESSAGE, "joe", "123", "/user/joe/queue/foo"));
@ -126,9 +125,11 @@ public class UserDestinationMessageHandlerTests {
@Test
public void handleMessageFromBrokerWithActiveSession() {
this.registry.registerSessionId("joe", "123");
TestSimpUser simpUser = new TestSimpUser("joe");
simpUser.addSessions(new TestSimpSession("123"));
when(this.registry.getUser("joe")).thenReturn(simpUser);
this.handler.setUserDestinationBroadcast("/topic/unresolved");
this.handler.setBroadcastDestination("/topic/unresolved");
given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true);
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.MESSAGE);
@ -152,7 +153,7 @@ public class UserDestinationMessageHandlerTests {
@Test
public void handleMessageFromBrokerWithoutActiveSession() {
this.handler.setUserDestinationBroadcast("/topic/unresolved");
this.handler.setBroadcastDestination("/topic/unresolved");
given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true);
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.MESSAGE);

View File

@ -0,0 +1,183 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import static org.junit.Assert.*;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.concurrent.ScheduledFuture;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.broker.BrokerAvailabilityEvent;
import org.springframework.scheduling.TaskScheduler;
/**
* User tests for {@link UserRegistryMessageHandler}.
* @author Rossen Stoyanchev
*/
public class UserRegistryMessageHandlerTests {
private UserRegistryMessageHandler handler;
private SimpUserRegistry localRegistry;
private MultiServerUserRegistry multiServerRegistry;
private MessageConverter converter;
@Mock
private MessageChannel brokerChannel;
@Mock
private TaskScheduler taskScheduler;
@Before
public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
when(this.brokerChannel.send(any())).thenReturn(true);
this.converter = new MappingJackson2MessageConverter();
SimpMessagingTemplate brokerTemplate = new SimpMessagingTemplate(this.brokerChannel);
brokerTemplate.setMessageConverter(this.converter);
this.localRegistry = mock(SimpUserRegistry.class);
this.multiServerRegistry = new MultiServerUserRegistry(this.localRegistry);
this.handler = new UserRegistryMessageHandler(this.multiServerRegistry, brokerTemplate,
"/topic/simp-user-registry", this.taskScheduler);
}
@Test
public void brokerAvailableEvent() throws Exception {
Runnable runnable = getUserRegistryTask();
assertNotNull(runnable);
}
@SuppressWarnings("unchecked")
@Test
public void brokerUnavailableEvent() throws Exception {
ScheduledFuture future = Mockito.mock(ScheduledFuture.class);
when(this.taskScheduler.scheduleWithFixedDelay(any(Runnable.class), any(Long.class))).thenReturn(future);
BrokerAvailabilityEvent event = new BrokerAvailabilityEvent(true, this);
this.handler.onApplicationEvent(event);
verifyNoMoreInteractions(future);
event = new BrokerAvailabilityEvent(false, this);
this.handler.onApplicationEvent(event);
verify(future).cancel(true);
}
@Test
public void broadcastRegistry() throws Exception {
TestSimpUser simpUser1 = new TestSimpUser("joe");
TestSimpUser simpUser2 = new TestSimpUser("jane");
simpUser1.addSessions(new TestSimpSession("123"));
simpUser1.addSessions(new TestSimpSession("456"));
HashSet<SimpUser> simpUsers = new HashSet<>(Arrays.asList(simpUser1, simpUser2));
when(this.localRegistry.getUsers()).thenReturn(simpUsers);
getUserRegistryTask().run();
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
verify(this.brokerChannel).send(captor.capture());
Message<?> message = captor.getValue();
assertNotNull(message);
MessageHeaders headers = message.getHeaders();
assertEquals("/topic/simp-user-registry", SimpMessageHeaderAccessor.getDestination(headers));
MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(mock(SimpUserRegistry.class));
remoteRegistry.addRemoteRegistryDto(message, this.converter, 20000);
assertEquals(2, remoteRegistry.getUsers().size());
assertNotNull(remoteRegistry.getUser("joe"));
assertNotNull(remoteRegistry.getUser("jane"));
}
@Test
public void handleMessage() throws Exception {
TestSimpUser simpUser1 = new TestSimpUser("joe");
TestSimpUser simpUser2 = new TestSimpUser("jane");
simpUser1.addSessions(new TestSimpSession("123"));
simpUser2.addSessions(new TestSimpSession("456"));
HashSet<SimpUser> simpUsers = new HashSet<>(Arrays.asList(simpUser1, simpUser2));
SimpUserRegistry remoteUserRegistry = mock(SimpUserRegistry.class);
when(remoteUserRegistry.getUsers()).thenReturn(simpUsers);
MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(remoteUserRegistry);
Message<?> message = this.converter.toMessage(remoteRegistry.getLocalRegistryDto(), null);
this.handler.handleMessage(message);
assertEquals(2, remoteRegistry.getUsers().size());
assertNotNull(this.multiServerRegistry.getUser("joe"));
assertNotNull(this.multiServerRegistry.getUser("jane"));
}
@Test
public void handleMessageFromOwnBroadcast() throws Exception {
TestSimpUser simpUser = new TestSimpUser("joe");
simpUser.addSessions(new TestSimpSession("123"));
when(this.localRegistry.getUsers()).thenReturn(Collections.singleton(simpUser));
assertEquals(1, this.multiServerRegistry.getUsers().size());
Message<?> message = this.converter.toMessage(this.multiServerRegistry.getLocalRegistryDto(), null);
this.multiServerRegistry.addRemoteRegistryDto(message, this.converter, 20000);
assertEquals(1, this.multiServerRegistry.getUsers().size());
}
private Runnable getUserRegistryTask() {
BrokerAvailabilityEvent event = new BrokerAvailabilityEvent(true, this);
this.handler.onApplicationEvent(event);
ArgumentCaptor<? extends Runnable> captor = ArgumentCaptor.forClass(Runnable.class);
verify(this.taskScheduler).scheduleWithFixedDelay(captor.capture(), eq(10000L));
return captor.getValue();
}
}

View File

@ -48,8 +48,9 @@ import org.springframework.messaging.simp.SimpSessionScope;
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import org.springframework.messaging.simp.user.DefaultUserDestinationResolver;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.MultiServerUserRegistry;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserRegistryMessageHandler;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
@ -61,6 +62,7 @@ import org.springframework.util.xml.DomUtils;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
import org.springframework.web.socket.messaging.DefaultSimpUserRegistry;
import org.springframework.web.socket.messaging.StompSubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.messaging.WebSocketAnnotationMethodMessageHandler;
@ -98,6 +100,8 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
public static final String WEB_SOCKET_HANDLER_BEAN_NAME = "subProtocolWebSocketHandler";
public static final String SCHEDULER_BEAN_NAME = "messageBrokerScheduler";
public static final String SOCKJS_SCHEDULER_BEAN_NAME = "messageBrokerSockJsScheduler";
private static final int DEFAULT_MAPPING_ORDER = 1;
@ -108,10 +112,82 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
@Override
public BeanDefinition parse(Element element, ParserContext context) {
Object source = context.extractSource(element);
CompositeComponentDefinition compDefinition = new CompositeComponentDefinition(element.getTagName(), source);
context.pushContainingComponent(compDefinition);
Element channelElem = DomUtils.getChildElementByTagName(element, "client-inbound-channel");
RuntimeBeanReference inChannel = getMessageChannel("clientInboundChannel", channelElem, context, source);
channelElem = DomUtils.getChildElementByTagName(element, "client-outbound-channel");
RuntimeBeanReference outChannel = getMessageChannel("clientOutboundChannel", channelElem, context, source);
channelElem = DomUtils.getChildElementByTagName(element, "broker-channel");
RuntimeBeanReference brokerChannel = getMessageChannel("brokerChannel", channelElem, context, source);
RuntimeBeanReference userRegistry = registerUserRegistry(element, context, source);
Object userDestHandler = registerUserDestHandler(element, userRegistry, inChannel, brokerChannel, context, source);
RuntimeBeanReference converter = registerMessageConverter(element, context, source);
RuntimeBeanReference template = registerMessagingTemplate(element, brokerChannel, converter, context, source);
registerAnnotationMethodMessageHandler(element, inChannel, outChannel,converter, template, context, source);
RootBeanDefinition broker = registerMessageBroker(element, inChannel, outChannel, brokerChannel,
userDestHandler, template, userRegistry, context, source);
// WebSocket and sub-protocol handling
ManagedMap<String, Object> urlMap = registerHandlerMapping(element, context, source);
RuntimeBeanReference stompHandler = registerStompHandler(element, inChannel, outChannel, context, source);
for (Element endpointElem : DomUtils.getChildElementsByTagName(element, "stomp-endpoint")) {
RuntimeBeanReference requestHandler = registerRequestHandler(endpointElem, stompHandler, context, source);
String pathAttribute = endpointElem.getAttribute("path");
Assert.state(StringUtils.hasText(pathAttribute), "Invalid <stomp-endpoint> (no path mapping)");
List<String> paths = Arrays.asList(StringUtils.tokenizeToStringArray(pathAttribute, ","));
for (String path : paths) {
path = path.trim();
Assert.state(StringUtils.hasText(path), "Invalid <stomp-endpoint> path attribute: " + pathAttribute);
if (DomUtils.getChildElementByTagName(endpointElem, "sockjs") != null) {
path = path.endsWith("/") ? path + "**" : path + "/**";
}
urlMap.put(path, requestHandler);
}
}
Map<String, Object> scopeMap = Collections.<String, Object>singletonMap("websocket", new SimpSessionScope());
RootBeanDefinition scopeConfigurer = new RootBeanDefinition(CustomScopeConfigurer.class);
scopeConfigurer.getPropertyValues().add("scopes", scopeMap);
registerBeanDefByName("webSocketScopeConfigurer", scopeConfigurer, context, source);
registerWebSocketMessageBrokerStats(broker, inChannel, outChannel, context, source);
context.popAndRegisterContainingComponent();
return null;
}
private RuntimeBeanReference registerUserRegistry(Element element, ParserContext context, Object source) {
Element relayElement = DomUtils.getChildElementByTagName(element, "stomp-broker-relay");
boolean multiServer = (relayElement != null && relayElement.hasAttribute("user-registry-broadcast"));
if (multiServer) {
RootBeanDefinition localRegistryBeanDef = new RootBeanDefinition(DefaultSimpUserRegistry.class);
RootBeanDefinition beanDef = new RootBeanDefinition(MultiServerUserRegistry.class);
beanDef.getConstructorArgumentValues().addIndexedArgumentValue(0, localRegistryBeanDef);
String beanName = registerBeanDef(beanDef, context, source);
return new RuntimeBeanReference(beanName);
}
else {
RootBeanDefinition beanDef = new RootBeanDefinition(DefaultSimpUserRegistry.class);
String beanName = registerBeanDef(beanDef, context, source);
return new RuntimeBeanReference(beanName);
}
}
private ManagedMap<String, Object> registerHandlerMapping(Element element,
ParserContext context, Object source) {
RootBeanDefinition handlerMappingDef = new RootBeanDefinition(SimpleUrlHandlerMapping.class);
String orderAttribute = element.getAttribute("order");
@ -128,58 +204,7 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
handlerMappingDef.getPropertyValues().add("urlMap", urlMap);
registerBeanDef(handlerMappingDef, context, source);
Element channelElem = DomUtils.getChildElementByTagName(element, "client-inbound-channel");
RuntimeBeanReference inChannel = getMessageChannel("clientInboundChannel", channelElem, context, source);
channelElem = DomUtils.getChildElementByTagName(element, "client-outbound-channel");
RuntimeBeanReference outChannel = getMessageChannel("clientOutboundChannel", channelElem, context, source);
RootBeanDefinition registryBeanDef = new RootBeanDefinition(DefaultUserSessionRegistry.class);
String registryBeanName = registerBeanDef(registryBeanDef, context, source);
RuntimeBeanReference sessionRegistry = new RuntimeBeanReference(registryBeanName);
RuntimeBeanReference subProtoHandler = registerSubProtoHandler(element, inChannel, outChannel,
sessionRegistry, context, source);
for (Element endpointElem : DomUtils.getChildElementsByTagName(element, "stomp-endpoint")) {
RuntimeBeanReference requestHandler = registerRequestHandler(endpointElem, subProtoHandler, context, source);
String pathAttribute = endpointElem.getAttribute("path");
Assert.state(StringUtils.hasText(pathAttribute), "Invalid <stomp-endpoint> (no path mapping)");
List<String> paths = Arrays.asList(StringUtils.tokenizeToStringArray(pathAttribute, ","));
for (String path : paths) {
path = path.trim();
Assert.state(StringUtils.hasText(path), "Invalid <stomp-endpoint> path attribute: " + pathAttribute);
if (DomUtils.getChildElementByTagName(endpointElem, "sockjs") != null) {
path = path.endsWith("/") ? path + "**" : path + "/**";
}
urlMap.put(path, requestHandler);
}
}
channelElem = DomUtils.getChildElementByTagName(element, "broker-channel");
RuntimeBeanReference brokerChannel = getMessageChannel("brokerChannel", channelElem, context, source);
RuntimeBeanReference resolver = registerUserDestResolver(element, sessionRegistry, context, source);
RuntimeBeanReference userDestHandler = registerUserDestHandler(element, inChannel,
brokerChannel, resolver, context, source);
RootBeanDefinition broker = registerMessageBroker(element, userDestHandler, inChannel,
outChannel, brokerChannel, context, source);
RuntimeBeanReference converter = registerMessageConverter(element, context, source);
RuntimeBeanReference template = registerMessagingTemplate(element, brokerChannel, converter, context, source);
registerAnnotationMethodMessageHandler(element, inChannel, outChannel,converter, template, context, source);
Map<String, Object> scopeMap = Collections.<String, Object>singletonMap("websocket", new SimpSessionScope());
RootBeanDefinition scopeConfigurer = new RootBeanDefinition(CustomScopeConfigurer.class);
scopeConfigurer.getPropertyValues().add("scopes", scopeMap);
registerBeanDefByName("webSocketScopeConfigurer", scopeConfigurer, context, source);
registerWebSocketMessageBrokerStats(broker, inChannel, outChannel, context, source);
context.popAndRegisterContainingComponent();
return null;
return urlMap;
}
private RuntimeBeanReference getMessageChannel(String name, Element element, ParserContext context, Object source) {
@ -240,11 +265,10 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
return executorDef;
}
private RuntimeBeanReference registerSubProtoHandler(Element element, RuntimeBeanReference inChannel,
RuntimeBeanReference outChannel, RuntimeBeanReference registry, ParserContext context, Object source) {
private RuntimeBeanReference registerStompHandler(Element element, RuntimeBeanReference inChannel,
RuntimeBeanReference outChannel, ParserContext context, Object source) {
RootBeanDefinition stompHandlerDef = new RootBeanDefinition(StompSubProtocolHandler.class);
stompHandlerDef.getPropertyValues().add("userSessionRegistry", registry);
registerBeanDef(stompHandlerDef, context, source);
ConstructorArgumentValues cavs = new ConstructorArgumentValues();
@ -285,13 +309,16 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
RootBeanDefinition beanDef;
RuntimeBeanReference sockJsService = WebSocketNamespaceUtils.registerSockJsService(
element, SOCKJS_SCHEDULER_BEAN_NAME, context, source);
element, SCHEDULER_BEAN_NAME, context, source);
if (sockJsService != null) {
ConstructorArgumentValues cavs = new ConstructorArgumentValues();
cavs.addIndexedArgumentValue(0, sockJsService);
cavs.addIndexedArgumentValue(1, subProtoHandler);
beanDef = new RootBeanDefinition(SockJsHttpRequestHandler.class, cavs, null);
// Register alias for backwards compatibility with 4.1
context.getRegistry().registerAlias(SCHEDULER_BEAN_NAME, SOCKJS_SCHEDULER_BEAN_NAME);
}
else {
RuntimeBeanReference handshakeHandler = WebSocketNamespaceUtils.registerHandshakeHandler(element, context, source);
@ -312,9 +339,9 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
}
private RootBeanDefinition registerMessageBroker(Element brokerElement,
RuntimeBeanReference userDestHandler, RuntimeBeanReference inChannel,
RuntimeBeanReference outChannel, RuntimeBeanReference brokerChannel,
ParserContext context, Object source) {
RuntimeBeanReference inChannel, RuntimeBeanReference outChannel, RuntimeBeanReference brokerChannel,
Object userDestHandler, RuntimeBeanReference brokerTemplate,
RuntimeBeanReference userRegistry, ParserContext context, Object source) {
Element simpleBrokerElem = DomUtils.getChildElementByTagName(brokerElement, "simple-broker");
Element brokerRelayElem = DomUtils.getChildElementByTagName(brokerElement, "stomp-broker-relay");
@ -374,11 +401,18 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
if (brokerRelayElem.hasAttribute("virtual-host")) {
values.add("virtualHost", brokerRelayElem.getAttribute("virtual-host"));
}
if (brokerElement.hasAttribute("user-destination-broadcast")) {
String destination = brokerElement.getAttribute("user-destination-broadcast");
ManagedMap<String, Object> map = new ManagedMap<String, Object>();
map.setSource(source);
ManagedMap<String, Object> map = new ManagedMap<String, Object>();
map.setSource(source);
if (brokerRelayElem.hasAttribute("user-destination-broadcast")) {
String destination = brokerRelayElem.getAttribute("user-destination-broadcast");
map.put(destination, userDestHandler);
}
if (brokerRelayElem.hasAttribute("user-registry-broadcast")) {
String destination = brokerRelayElem.getAttribute("user-registry-broadcast");
map.put(destination, registerUserRegistryMessageHandler(userRegistry,
brokerTemplate, destination, context, source));
}
if (!map.isEmpty()) {
values.add("systemSubscriptions", map);
}
Class<?> handlerType = StompBrokerRelayMessageHandler.class;
@ -392,6 +426,22 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
return brokerDef;
}
private RuntimeBeanReference registerUserRegistryMessageHandler(
RuntimeBeanReference userRegistry, RuntimeBeanReference brokerTemplate,
String destination, ParserContext context, Object source) {
Object scheduler = WebSocketNamespaceUtils.registerScheduler(SCHEDULER_BEAN_NAME, context, source);
RootBeanDefinition beanDef = new RootBeanDefinition(UserRegistryMessageHandler.class);
beanDef.getConstructorArgumentValues().addIndexedArgumentValue(0, userRegistry);
beanDef.getConstructorArgumentValues().addIndexedArgumentValue(1, brokerTemplate);
beanDef.getConstructorArgumentValues().addIndexedArgumentValue(2, destination);
beanDef.getConstructorArgumentValues().addIndexedArgumentValue(3, scheduler);
String beanName = registerBeanDef(beanDef, context, source);
return new RuntimeBeanReference(beanName);
}
private RuntimeBeanReference registerMessageConverter(Element element, ParserContext context, Object source) {
Element convertersElement = DomUtils.getChildElementByTagName(element, "message-converters");
ManagedList<? super Object> converters = new ManagedList<Object>();
@ -484,11 +534,10 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
}
private RuntimeBeanReference registerUserDestResolver(Element brokerElem,
RuntimeBeanReference userSessionRegistry, ParserContext context, Object source) {
RuntimeBeanReference userRegistry, ParserContext context, Object source) {
ConstructorArgumentValues cavs = new ConstructorArgumentValues();
cavs.addIndexedArgumentValue(0, userSessionRegistry);
RootBeanDefinition beanDef = new RootBeanDefinition(DefaultUserDestinationResolver.class, cavs, null);
RootBeanDefinition beanDef = new RootBeanDefinition(DefaultUserDestinationResolver.class);
beanDef.getConstructorArgumentValues().addIndexedArgumentValue(0, userRegistry);
if (brokerElem.hasAttribute("user-destination-prefix")) {
beanDef.getPropertyValues().add("userDestinationPrefix", brokerElem.getAttribute("user-destination-prefix"));
}
@ -496,19 +545,24 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
}
private RuntimeBeanReference registerUserDestHandler(Element brokerElem,
RuntimeBeanReference inChannel, RuntimeBeanReference brokerChannel,
RuntimeBeanReference userDestinationResolver, ParserContext context, Object source) {
RuntimeBeanReference userRegistry, RuntimeBeanReference inChannel,
RuntimeBeanReference brokerChannel, ParserContext context, Object source) {
ConstructorArgumentValues cavs = new ConstructorArgumentValues();
cavs.addIndexedArgumentValue(0, inChannel);
cavs.addIndexedArgumentValue(1, brokerChannel);
cavs.addIndexedArgumentValue(2, userDestinationResolver);
RootBeanDefinition beanDef = new RootBeanDefinition(UserDestinationMessageHandler.class, cavs, null);
if (brokerElem.hasAttribute("user-destination-broadcast")) {
String destination = brokerElem.getAttribute("user-destination-broadcast");
beanDef.getPropertyValues().add("userDestinationBroadcast", destination);
Object userDestResolver = registerUserDestResolver(brokerElem, userRegistry, context, source);
RootBeanDefinition beanDef = new RootBeanDefinition(UserDestinationMessageHandler.class);
beanDef.getConstructorArgumentValues().addIndexedArgumentValue(0, inChannel);
beanDef.getConstructorArgumentValues().addIndexedArgumentValue(1, brokerChannel);
beanDef.getConstructorArgumentValues().addIndexedArgumentValue(2, userDestResolver);
Element relayElement = DomUtils.getChildElementByTagName(brokerElem, "stomp-broker-relay");
if (relayElement != null && relayElement.hasAttribute("user-destination-broadcast")) {
String destination = relayElement.getAttribute("user-destination-broadcast");
beanDef.getPropertyValues().add("broadcastDestination", destination);
}
return new RuntimeBeanReference(registerBeanDef(beanDef, context, source));
String beanName = registerBeanDef(beanDef, context, source);
return new RuntimeBeanReference(beanName);
}
private void registerWebSocketMessageBrokerStats(RootBeanDefinition broker, RuntimeBeanReference inChannel,
@ -530,7 +584,7 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
if (context.getRegistry().containsBeanDefinition(name)) {
beanDef.getPropertyValues().add("outboundChannelExecutor", context.getRegistry().getBeanDefinition(name));
}
name = SOCKJS_SCHEDULER_BEAN_NAME;
name = SCHEDULER_BEAN_NAME;
if (context.getRegistry().containsBeanDefinition(name)) {
beanDef.getPropertyValues().add("sockJsTaskScheduler", context.getRegistry().getBeanDefinition(name));
}

View File

@ -62,7 +62,7 @@ class WebSocketNamespaceUtils {
return handlerRef;
}
public static RuntimeBeanReference registerSockJsService(Element element, String sockJsSchedulerName,
public static RuntimeBeanReference registerSockJsService(Element element, String schedulerName,
ParserContext context, Object source) {
Element sockJsElement = DomUtils.getChildElementByTagName(element, "sockjs");
@ -79,7 +79,7 @@ class WebSocketNamespaceUtils {
scheduler = new RuntimeBeanReference(customTaskSchedulerName);
}
else {
scheduler = registerSockJsScheduler(sockJsSchedulerName, context, source);
scheduler = registerScheduler(schedulerName, context, source);
}
sockJsServiceDef.getConstructorArgumentValues().addIndexedArgumentValue(0, scheduler);
@ -156,7 +156,7 @@ class WebSocketNamespaceUtils {
return null;
}
private static RuntimeBeanReference registerSockJsScheduler(String schedulerName, ParserContext context, Object source) {
public static RuntimeBeanReference registerScheduler(String schedulerName, ParserContext context, Object source) {
if (!context.getRegistry().containsBeanDefinition(schedulerName)) {
RootBeanDefinition taskSchedulerDef = new RootBeanDefinition(ThreadPoolTaskScheduler.class);
taskSchedulerDef.setSource(source);

View File

@ -22,7 +22,6 @@ import java.util.List;
import java.util.Map;
import org.springframework.context.ApplicationContext;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
@ -63,11 +62,10 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry {
public WebMvcStompEndpointRegistry(WebSocketHandler webSocketHandler,
WebSocketTransportRegistration transportRegistration,
UserSessionRegistry userSessionRegistry, TaskScheduler defaultSockJsTaskScheduler) {
TaskScheduler defaultSockJsTaskScheduler) {
Assert.notNull(webSocketHandler, "'webSocketHandler' is required ");
Assert.notNull(transportRegistration, "'transportRegistration' is required");
Assert.notNull(userSessionRegistry, "'userSessionRegistry' is required");
this.webSocketHandler = webSocketHandler;
this.subProtocolWebSocketHandler = unwrapSubProtocolWebSocketHandler(webSocketHandler);
@ -80,7 +78,6 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry {
}
this.stompHandler = new StompSubProtocolHandler();
this.stompHandler.setUserSessionRegistry(userSessionRegistry);
if (transportRegistration.getMessageSizeLimit() != null) {
this.stompHandler.setMessageSizeLimit(transportRegistration.getMessageSizeLimit());

View File

@ -25,11 +25,14 @@ import org.springframework.messaging.simp.annotation.support.SimpAnnotationMetho
import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler;
import org.springframework.messaging.simp.config.AbstractMessageBrokerConfiguration;
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.messaging.simp.user.UserSessionRegistryAdapter;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.WebSocketMessageBrokerStats;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
import org.springframework.web.socket.messaging.DefaultSimpUserRegistry;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.messaging.WebSocketAnnotationMethodMessageHandler;
@ -58,10 +61,10 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
@Bean
public HandlerMapping stompWebSocketHandlerMapping() {
WebSocketHandler handler = subProtocolWebSocketHandler();
handler = decorateWebSocketHandler(handler);
WebMvcStompEndpointRegistry registry = new WebMvcStompEndpointRegistry(handler,
getTransportRegistration(), userSessionRegistry(), messageBrokerSockJsTaskScheduler());
WebSocketHandler handler = decorateWebSocketHandler(subProtocolWebSocketHandler());
WebSocketTransportRegistration transport = getTransportRegistration();
ThreadPoolTaskScheduler scheduler = messageBrokerTaskScheduler();
WebMvcStompEndpointRegistry registry = new WebMvcStompEndpointRegistry(handler, transport, scheduler);
registry.setApplicationContext(getApplicationContext());
registerStompEndpoints(registry);
return registry.getHandlerMapping();
@ -90,33 +93,21 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
protected void configureWebSocketTransport(WebSocketTransportRegistration registry) {
}
protected abstract void registerStompEndpoints(StompEndpointRegistry registry);
/**
* The default TaskScheduler to use if none is configured via
* {@link SockJsServiceRegistration#setTaskScheduler(org.springframework.scheduling.TaskScheduler)}, i.e.
* <pre class="code">
* &#064;Configuration
* &#064;EnableWebSocketMessageBroker
* public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
*
* public void registerStompEndpoints(StompEndpointRegistry registry) {
* registry.addEndpoint("/stomp").withSockJS().setTaskScheduler(myScheduler());
* }
*
* // ...
* }
* </pre>
*/
@Bean
public ThreadPoolTaskScheduler messageBrokerSockJsTaskScheduler() {
ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
scheduler.setThreadNamePrefix("MessageBrokerSockJS-");
scheduler.setPoolSize(Runtime.getRuntime().availableProcessors());
scheduler.setRemoveOnCancelPolicy(true);
return scheduler;
@Override
@SuppressWarnings("deprecation")
protected SimpUserRegistry createLocalUserRegistry() {
org.springframework.messaging.simp.user.UserSessionRegistry sessionRegistry = userSessionRegistry();
if (sessionRegistry == null) {
return new DefaultSimpUserRegistry();
}
else {
return (userSessionRegistry() instanceof SimpUserRegistry ?
(SimpUserRegistry) userSessionRegistry() : new UserSessionRegistryAdapter(sessionRegistry));
}
}
protected abstract void registerStompEndpoints(StompEndpointRegistry registry);
@Bean
public static CustomScopeConfigurer webSocketScopeConfigurer() {
CustomScopeConfigurer configurer = new CustomScopeConfigurer();
@ -138,7 +129,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
stats.setStompBrokerRelay(brokerRelay);
stats.setInboundChannelExecutor(clientInboundChannelExecutor());
stats.setOutboundChannelExecutor(clientOutboundChannelExecutor());
stats.setSockJsTaskScheduler(messageBrokerSockJsTaskScheduler());
stats.setSockJsTaskScheduler(messageBrokerTaskScheduler());
return stats;
}

View File

@ -16,6 +16,8 @@
package org.springframework.web.socket.messaging;
import java.security.Principal;
import org.springframework.context.ApplicationEvent;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;
@ -32,6 +34,8 @@ public abstract class AbstractSubProtocolEvent extends ApplicationEvent {
private final Message<byte[]> message;
private final Principal user;
/**
* Create a new AbstractSubProtocolEvent.
@ -42,6 +46,19 @@ public abstract class AbstractSubProtocolEvent extends ApplicationEvent {
super(source);
Assert.notNull(message, "Message must not be null");
this.message = message;
this.user = null;
}
/**
* Create a new AbstractSubProtocolEvent.
* @param source the component that published the event (never {@code null})
* @param message the incoming message
*/
protected AbstractSubProtocolEvent(Object source, Message<byte[]> message, Principal user) {
super(source);
Assert.notNull(message, "Message must not be null");
this.message = message;
this.user = user;
}
@ -60,6 +77,13 @@ public abstract class AbstractSubProtocolEvent extends ApplicationEvent {
return this.message;
}
/**
* Return the user for the session associated with the event.
*/
public Principal getUser() {
return this.user;
}
@Override
public String toString() {
return getClass().getSimpleName() + "[" + this.message + "]";

View File

@ -0,0 +1,336 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.messaging;
import java.security.Principal;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.event.SmartApplicationListener;
import org.springframework.core.Ordered;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.SimpSession;
import org.springframework.messaging.simp.user.SimpSubscription;
import org.springframework.messaging.simp.user.SimpSubscriptionMatcher;
import org.springframework.messaging.simp.user.SimpUser;
import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert;
/**
* Default, mutable, thread-safe implementation of {@link SimpUserRegistry}.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
public class DefaultSimpUserRegistry implements SimpUserRegistry, SmartApplicationListener {
private final Map<String, DefaultSimpUser> users = new ConcurrentHashMap<String, DefaultSimpUser>();
private final Map<String, DefaultSimpSession> sessions = new ConcurrentHashMap<String, DefaultSimpSession>();
@Override
public SimpUser getUser(String userName) {
return this.users.get(userName);
}
@Override
public Set<SimpUser> getUsers() {
return new HashSet<SimpUser>(this.users.values());
}
public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) {
Set<SimpSubscription> result = new HashSet<SimpSubscription>();
for (DefaultSimpSession session : this.sessions.values()) {
for (SimpSubscription subscription : session.subscriptions.values()) {
if (matcher.match(subscription)) {
result.add(subscription);
}
}
}
return result;
}
@Override
public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) {
return AbstractSubProtocolEvent.class.isAssignableFrom(eventType);
}
@Override
public boolean supportsSourceType(Class<?> sourceType) {
return true;
}
@Override
public void onApplicationEvent(ApplicationEvent event) {
AbstractSubProtocolEvent subProtocolEvent = (AbstractSubProtocolEvent) event;
Message<?> message = subProtocolEvent.getMessage();
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, SimpMessageHeaderAccessor.class);
String sessionId = accessor.getSessionId();
if (event instanceof SessionSubscribeEvent) {
DefaultSimpSession session = this.sessions.get(sessionId);
if (session != null) {
String id = accessor.getSubscriptionId();
String destination = accessor.getDestination();
session.addSubscription(id, destination);
}
}
else if (event instanceof SessionConnectedEvent) {
Principal user = subProtocolEvent.getUser();
if (user == null) {
return;
}
String name = user.getName();
if (user instanceof DestinationUserNameProvider) {
name = ((DestinationUserNameProvider) user).getDestinationUserName();
}
synchronized (this) {
DefaultSimpUser simpUser = this.users.get(name);
if (simpUser == null) {
simpUser = new DefaultSimpUser(name, sessionId);
this.users.put(name, simpUser);
}
else {
simpUser.addSession(sessionId);
}
this.sessions.put(sessionId, (DefaultSimpSession) simpUser.getSession(sessionId));
}
}
else if (event instanceof SessionDisconnectEvent) {
synchronized (this) {
DefaultSimpSession session = this.sessions.remove(sessionId);
if (session != null) {
DefaultSimpUser user = session.getUser();
user.removeSession(sessionId);
if (!user.hasSessions()) {
this.users.remove(user.getName());
}
}
}
}
else if (event instanceof SessionUnsubscribeEvent) {
DefaultSimpSession session = this.sessions.get(sessionId);
if (session != null) {
String subscriptionId = accessor.getSubscriptionId();
session.removeSubscription(subscriptionId);
}
}
}
@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE;
}
@Override
public String toString() {
return "users=" + this.users;
}
private static class DefaultSimpUser implements SimpUser {
private final String name;
private final Map<String, SimpSession> sessions =
new ConcurrentHashMap<String, SimpSession>(1);
public DefaultSimpUser(String userName, String sessionId) {
Assert.notNull(userName);
Assert.notNull(sessionId);
this.name = userName;
this.sessions.put(sessionId, new DefaultSimpSession(sessionId, this));
}
@Override
public String getName() {
return this.name;
}
@Override
public boolean hasSessions() {
return !this.sessions.isEmpty();
}
@Override
public SimpSession getSession(String sessionId) {
return this.sessions.get(sessionId);
}
@Override
public Set<SimpSession> getSessions() {
return new HashSet<SimpSession>(this.sessions.values());
}
void addSession(String sessionId) {
DefaultSimpSession session = new DefaultSimpSession(sessionId, this);
this.sessions.put(sessionId, session);
}
void removeSession(String sessionId) {
this.sessions.remove(sessionId);
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || !(other instanceof SimpUser)) {
return false;
}
return this.name.equals(((SimpUser) other).getName());
}
@Override
public int hashCode() {
return this.name.hashCode();
}
@Override
public String toString() {
return "name=" + this.name + ", sessions=" + this.sessions;
}
}
private static class DefaultSimpSession implements SimpSession {
private final String id;
private final DefaultSimpUser user;
private final Map<String, SimpSubscription> subscriptions = new ConcurrentHashMap<String, SimpSubscription>(4);
public DefaultSimpSession(String id, DefaultSimpUser user) {
Assert.notNull(id);
Assert.notNull(user);
this.id = id;
this.user = user;
}
@Override
public String getId() {
return this.id;
}
@Override
public DefaultSimpUser getUser() {
return this.user;
}
@Override
public Set<SimpSubscription> getSubscriptions() {
return new HashSet<SimpSubscription>(this.subscriptions.values());
}
void addSubscription(String id, String destination) {
this.subscriptions.put(id, new DefaultSimpSubscription(id, destination, this));
}
void removeSubscription(String id) {
this.subscriptions.remove(id);
}
@Override
public int hashCode() {
return this.id.hashCode();
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || !(other instanceof SimpSubscription)) {
return false;
}
return this.id.equals(((SimpSubscription) other).getId());
}
@Override
public String toString() {
return "id=" + this.id + ", subscriptions=" + this.subscriptions;
}
}
private static class DefaultSimpSubscription implements SimpSubscription {
private final String id;
private final DefaultSimpSession session;
private final String destination;
public DefaultSimpSubscription(String id, String destination, DefaultSimpSession session) {
Assert.notNull(id);
Assert.hasText(destination);
Assert.notNull(session);
this.id = id;
this.destination = destination;
this.session = session;
}
@Override
public String getId() {
return this.id;
}
@Override
public DefaultSimpSession getSession() {
return this.session;
}
@Override
public String getDestination() {
return this.destination;
}
@Override
public int hashCode() {
return 31 * this.id.hashCode() + getSession().hashCode();
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || !(other instanceof SimpSubscription)) {
return false;
}
SimpSubscription otherSubscription = (SimpSubscription) other;
return (getSession().getId().equals(otherSubscription.getSession().getId()) &&
this.id.equals(otherSubscription.getId()));
}
@Override
public String toString() {
return "destination=" + this.destination;
}
}
}

View File

@ -16,6 +16,8 @@
package org.springframework.web.socket.messaging;
import java.security.Principal;
import org.springframework.messaging.Message;
/**
@ -41,4 +43,8 @@ public class SessionConnectEvent extends AbstractSubProtocolEvent {
super(source, message);
}
public SessionConnectEvent(Object source, Message<byte[]> message, Principal user) {
super(source, message, user);
}
}

View File

@ -16,6 +16,8 @@
package org.springframework.web.socket.messaging;
import java.security.Principal;
import org.springframework.messaging.Message;
/**
@ -37,4 +39,8 @@ public class SessionConnectedEvent extends AbstractSubProtocolEvent {
super(source, message);
}
public SessionConnectedEvent(Object source, Message<byte[]> message, Principal user) {
super(source, message, user);
}
}

View File

@ -16,6 +16,8 @@
package org.springframework.web.socket.messaging;
import java.security.Principal;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
@ -45,14 +47,21 @@ public class SessionDisconnectEvent extends AbstractSubProtocolEvent {
* @param sessionId the disconnect message
* @param closeStatus the status object
*/
public SessionDisconnectEvent(Object source, Message<byte[]> message, String sessionId, CloseStatus closeStatus) {
public SessionDisconnectEvent(Object source, Message<byte[]> message, String sessionId,
CloseStatus closeStatus) {
this(source, message, sessionId, closeStatus, null);
}
public SessionDisconnectEvent(Object source, Message<byte[]> message, String sessionId,
CloseStatus closeStatus, Principal user) {
super(source, message);
Assert.notNull(sessionId, "'sessionId' must not be null");
this.sessionId = sessionId;
this.status = closeStatus;
}
/**
* Return the session id.
*/

View File

@ -17,6 +17,8 @@
package org.springframework.web.socket.messaging;
import java.security.Principal;
import org.springframework.messaging.Message;
/**
@ -34,4 +36,8 @@ public class SessionSubscribeEvent extends AbstractSubProtocolEvent {
super(source, message);
}
public SessionSubscribeEvent(Object source, Message<byte[]> message, Principal user) {
super(source, message, user);
}
}

View File

@ -17,6 +17,8 @@
package org.springframework.web.socket.messaging;
import java.security.Principal;
import org.springframework.messaging.Message;
/**
@ -34,4 +36,8 @@ public class SessionUnsubscribeEvent extends AbstractSubProtocolEvent {
super(source, message);
}
public SessionUnsubscribeEvent(Object source, Message<byte[]> message, Principal user) {
super(source, message, user);
}
}

View File

@ -34,7 +34,6 @@ import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpAttributes;
import org.springframework.messaging.simp.SimpAttributesContextHolder;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
@ -44,8 +43,6 @@ import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.AbstractMessageChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
@ -94,8 +91,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
private int messageSizeLimit = 64 * 1024;
private UserSessionRegistry userSessionRegistry;
private final StompEncoder stompEncoder = new StompEncoder();
private final StompDecoder stompDecoder = new StompDecoder();
@ -134,21 +129,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
return this.messageSizeLimit;
}
/**
* Provide a registry with which to register active user session ids.
* @see org.springframework.messaging.simp.user.UserDestinationMessageHandler
*/
public void setUserSessionRegistry(UserSessionRegistry registry) {
this.userSessionRegistry = registry;
}
/**
* @return the configured UserSessionRegistry.
*/
public UserSessionRegistry getUserSessionRegistry() {
return this.userSessionRegistry;
}
/**
* Configure a {@link MessageHeaderInitializer} to apply to the headers of all
* messages created from decoded STOMP frames and other messages sent to the
@ -234,9 +214,11 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
StompHeaderAccessor headerAccessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
Principal user = session.getPrincipal();
headerAccessor.setSessionId(session.getId());
headerAccessor.setSessionAttributes(session.getAttributes());
headerAccessor.setUser(session.getPrincipal());
headerAccessor.setUser(user);
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
if (!detectImmutableMessageInterceptor(outputChannel)) {
headerAccessor.setImmutable();
@ -257,13 +239,13 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
SimpAttributesContextHolder.setAttributesFromMessage(message);
if (this.eventPublisher != null) {
if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
publishEvent(new SessionConnectEvent(this, message));
publishEvent(new SessionConnectEvent(this, message, user));
}
else if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) {
publishEvent(new SessionSubscribeEvent(this, message));
publishEvent(new SessionSubscribeEvent(this, message, user));
}
else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
publishEvent(new SessionUnsubscribeEvent(this, message));
publishEvent(new SessionUnsubscribeEvent(this, message, user));
}
}
outputChannel.send(message);
@ -349,7 +331,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
try {
SimpAttributes simpAttributes = new SimpAttributes(session.getId(), session.getAttributes());
SimpAttributesContextHolder.setAttributes(simpAttributes);
publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message));
Principal user = session.getPrincipal();
publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message, user));
}
finally {
SimpAttributesContextHolder.resetAttributes();
@ -466,10 +449,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
if (principal != null) {
accessor = toMutableAccessor(accessor, message);
accessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
if (this.userSessionRegistry != null) {
String userName = getSessionRegistryUserName(principal);
this.userSessionRegistry.registerSessionId(userName, session.getId());
}
}
long[] heartbeat = accessor.getHeartbeat();
if (heartbeat[1] > 0) {
@ -481,14 +460,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
return accessor;
}
private String getSessionRegistryUserName(Principal principal) {
String userName = principal.getName();
if (principal instanceof DestinationUserNameProvider) {
userName = ((DestinationUserNameProvider) principal).getDestinationUserName();
}
return userName;
}
@Override
public String resolveSessionId(Message<?> message) {
return SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
@ -505,17 +476,13 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
@Override
public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) {
this.decoders.remove(session.getId());
Principal principal = session.getPrincipal();
if (principal != null && this.userSessionRegistry != null) {
String userName = getSessionRegistryUserName(principal);
this.userSessionRegistry.unregisterSessionId(userName, session.getId());
}
Message<byte[]> message = createDisconnectMessage(session);
SimpAttributes simpAttributes = SimpAttributes.fromMessage(message);
try {
SimpAttributesContextHolder.setAttributes(simpAttributes);
if (this.eventPublisher != null) {
publishEvent(new SessionDisconnectEvent(this, message, session.getId(), closeStatus));
Principal user = session.getPrincipal();
publishEvent(new SessionDisconnectEvent(this, message, session.getId(), closeStatus, user));
}
outputChannel.send(message);
}

View File

@ -344,6 +344,27 @@
]]></xsd:documentation>
</xsd:annotation>
</xsd:attribute>
<xsd:attribute name="user-destination-broadcast" type="xsd:string">
<xsd:annotation>
<xsd:documentation><![CDATA[
Set a destination to broadcast messages to that remain unresolved because
the user is not connected. In a multi-application server scenario this
gives other application servers a chance to try.
By default this is not set.
]]></xsd:documentation>
</xsd:annotation>
</xsd:attribute>
<xsd:attribute name="user-registry-broadcast" type="xsd:string">
<xsd:annotation>
<xsd:documentation><![CDATA[
Set a destination to broadcast the content of the local user registry to
and to listen for such broadcasts from other servers. In a multi-application
server scenarios this allows each server's user registry to be aware of
users connected to other servers.
By default this is not set.
]]></xsd:documentation>
</xsd:annotation>
</xsd:attribute>
</xsd:complexType>
<xsd:complexType name="simple-broker">
@ -853,17 +874,6 @@
The prefix used to identify user destinations.
Any destinations that do not start with the given prefix are not be resolved.
The default value is "/user/".
]]></xsd:documentation>
</xsd:annotation>
</xsd:attribute>
<xsd:attribute name="user-destination-broadcast" type="xsd:string">
<xsd:annotation>
<xsd:documentation><![CDATA[
Set a destination to broadcast messages to that remain unresolved because
the user is not connected. In a multi-application server scenario this
gives other application servers a chance to try.
Note: this option applies only when the stomp-broker-relay is enabled.
By default this is not set.
]]></xsd:documentation>
</xsd:annotation>
</xsd:attribute>

View File

@ -16,14 +16,20 @@
package org.springframework.web.socket.config;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.hamcrest.Matchers;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.config.CustomScopeConfigurer;
@ -46,9 +52,11 @@ import org.springframework.messaging.simp.annotation.support.SimpAnnotationMetho
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import org.springframework.messaging.simp.user.DefaultUserDestinationResolver;
import org.springframework.messaging.simp.user.MultiServerUserRegistry;
import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserDestinationResolver;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.simp.user.UserRegistryMessageHandler;
import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
@ -64,7 +72,9 @@ import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TestWebSocketSession;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
import org.springframework.web.socket.messaging.DefaultSimpUserRegistry;
import org.springframework.web.socket.messaging.StompSubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
@ -75,9 +85,6 @@ import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
/**
* Test fixture for MessageBrokerBeanDefinitionParser.
* See test configuration files websocket-config-broker-*.xml.
@ -133,7 +140,8 @@ public class MessageBrokerBeanDefinitionParserTests {
assertEquals(25 * 1000, subProtocolWsHandler.getSendTimeLimit());
assertEquals(1024 * 1024, subProtocolWsHandler.getSendBufferSizeLimit());
StompSubProtocolHandler stompHandler = (StompSubProtocolHandler) subProtocolWsHandler.getProtocolHandlerMap().get("v12.stomp");
Map<String, SubProtocolHandler> handlerMap = subProtocolWsHandler.getProtocolHandlerMap();
StompSubProtocolHandler stompHandler = (StompSubProtocolHandler) handlerMap.get("v12.stomp");
assertNotNull(stompHandler);
assertEquals(128 * 1024, stompHandler.getMessageSizeLimit());
@ -166,15 +174,15 @@ public class MessageBrokerBeanDefinitionParserTests {
instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
assertEquals(Arrays.asList("http://mydomain3.com", "http://mydomain4.com"), defaultSockJsService.getAllowedOrigins());
UserSessionRegistry userSessionRegistry = this.appContext.getBean(UserSessionRegistry.class);
assertNotNull(userSessionRegistry);
SimpUserRegistry userRegistry = this.appContext.getBean(SimpUserRegistry.class);
assertNotNull(userRegistry);
assertEquals(DefaultSimpUserRegistry.class, userRegistry.getClass());
UserDestinationResolver userDestResolver = this.appContext.getBean(UserDestinationResolver.class);
assertNotNull(userDestResolver);
assertThat(userDestResolver, Matchers.instanceOf(DefaultUserDestinationResolver.class));
DefaultUserDestinationResolver defaultUserDestResolver = (DefaultUserDestinationResolver) userDestResolver;
assertEquals("/personal/", defaultUserDestResolver.getDestinationPrefix());
assertSame(stompHandler.getUserSessionRegistry(), defaultUserDestResolver.getUserSessionRegistry());
UserDestinationMessageHandler userDestHandler = this.appContext.getBean(UserDestinationMessageHandler.class);
assertNotNull(userDestHandler);
@ -192,11 +200,12 @@ public class MessageBrokerBeanDefinitionParserTests {
testChannel("clientInboundChannel", subscriberTypes, 2);
testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60);
subscriberTypes = Arrays.<Class<? extends MessageHandler>>asList(SubProtocolWebSocketHandler.class);
subscriberTypes = Collections.singletonList(SubProtocolWebSocketHandler.class);
testChannel("clientOutboundChannel", subscriberTypes, 1);
testExecutor("clientOutboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60);
subscriberTypes = Arrays.<Class<? extends MessageHandler>>asList(SimpleBrokerMessageHandler.class, UserDestinationMessageHandler.class);
subscriberTypes = Arrays.<Class<? extends MessageHandler>>asList(
SimpleBrokerMessageHandler.class, UserDestinationMessageHandler.class);
testChannel("brokerChannel", subscriberTypes, 1);
try {
this.appContext.getBean("brokerChannelExecutor", ThreadPoolTaskExecutor.class);
@ -260,7 +269,7 @@ public class MessageBrokerBeanDefinitionParserTests {
testChannel("clientInboundChannel", subscriberTypes, 2);
testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60);
subscriberTypes = Arrays.<Class<? extends MessageHandler>>asList(SubProtocolWebSocketHandler.class);
subscriberTypes = Collections.singletonList(SubProtocolWebSocketHandler.class);
testChannel("clientOutboundChannel", subscriberTypes, 1);
testExecutor("clientOutboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60);
@ -275,11 +284,20 @@ public class MessageBrokerBeanDefinitionParserTests {
// expected
}
String destination = "/topic/unresolved-user-destination";
UserDestinationMessageHandler userDestHandler = this.appContext.getBean(UserDestinationMessageHandler.class);
assertEquals("/topic/unresolved", userDestHandler.getUserDestinationBroadcast());
assertEquals(destination, userDestHandler.getBroadcastDestination());
assertNotNull(messageBroker.getSystemSubscriptions());
assertSame(userDestHandler, messageBroker.getSystemSubscriptions().get("/topic/unresolved"));
assertSame(userDestHandler, messageBroker.getSystemSubscriptions().get(destination));
destination = "/topic/simp-user-registry";
UserRegistryMessageHandler userRegistryHandler = this.appContext.getBean(UserRegistryMessageHandler.class);
assertEquals(destination, userRegistryHandler.getBroadcastDestination());
assertNotNull(messageBroker.getSystemSubscriptions());
assertSame(userRegistryHandler, messageBroker.getSystemSubscriptions().get(destination));
SimpUserRegistry userRegistry = this.appContext.getBean(SimpUserRegistry.class);
assertEquals(MultiServerUserRegistry.class, userRegistry.getClass());
String name = "webSocketMessageBrokerStats";
WebSocketMessageBrokerStats stats = this.appContext.getBean(name, WebSocketMessageBrokerStats.class);
@ -339,7 +357,7 @@ public class MessageBrokerBeanDefinitionParserTests {
testChannel("clientInboundChannel", subscriberTypes, 3);
testExecutor("clientInboundChannel", 100, 200, 600);
subscriberTypes = Arrays.<Class<? extends MessageHandler>>asList(SubProtocolWebSocketHandler.class);
subscriberTypes = Collections.singletonList(SubProtocolWebSocketHandler.class);
testChannel("clientOutboundChannel", subscriberTypes, 3);
testExecutor("clientOutboundChannel", 101, 201, 601);

View File

@ -16,6 +16,8 @@
package org.springframework.web.socket.config.annotation;
import static org.junit.Assert.*;
import java.util.Map;
import org.junit.Before;
@ -23,17 +25,12 @@ import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.messaging.StompSubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.util.UrlPathHelper;
import static org.junit.Assert.*;
/**
* Test fixture for
* {@link org.springframework.web.socket.config.annotation.WebMvcStompEndpointRegistry}.
@ -46,17 +43,16 @@ public class WebMvcStompEndpointRegistryTests {
private SubProtocolWebSocketHandler webSocketHandler;
private UserSessionRegistry userSessionRegistry;
@Before
public void setup() {
SubscribableChannel inChannel = Mockito.mock(SubscribableChannel.class);
SubscribableChannel outChannel = Mockito.mock(SubscribableChannel.class);
this.webSocketHandler = new SubProtocolWebSocketHandler(inChannel, outChannel);
this.userSessionRegistry = new DefaultUserSessionRegistry();
this.endpointRegistry = new WebMvcStompEndpointRegistry(this.webSocketHandler,
new WebSocketTransportRegistration(), this.userSessionRegistry, Mockito.mock(TaskScheduler.class));
WebSocketTransportRegistration transport = new WebSocketTransportRegistration();
TaskScheduler scheduler = Mockito.mock(TaskScheduler.class);
this.endpointRegistry = new WebMvcStompEndpointRegistry(this.webSocketHandler, transport, scheduler);
}
@ -69,9 +65,6 @@ public class WebMvcStompEndpointRegistryTests {
assertNotNull(protocolHandlers.get("v10.stomp"));
assertNotNull(protocolHandlers.get("v11.stomp"));
assertNotNull(protocolHandlers.get("v12.stomp"));
StompSubProtocolHandler stompHandler = (StompSubProtocolHandler) protocolHandlers.get("v10.stomp");
assertSame(this.userSessionRegistry, stompHandler.getUserSessionRegistry());
}
@Test

View File

@ -136,19 +136,16 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
}
@Test
public void webSocketTransportOptions() {
public void webSocketHandler() {
ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class);
SubProtocolWebSocketHandler subProtocolWebSocketHandler =
config.getBean("subProtocolWebSocketHandler", SubProtocolWebSocketHandler.class);
SubProtocolWebSocketHandler subWsHandler = config.getBean(SubProtocolWebSocketHandler.class);
assertEquals(1024 * 1024, subProtocolWebSocketHandler.getSendBufferSizeLimit());
assertEquals(25 * 1000, subProtocolWebSocketHandler.getSendTimeLimit());
assertEquals(1024 * 1024, subWsHandler.getSendBufferSizeLimit());
assertEquals(25 * 1000, subWsHandler.getSendTimeLimit());
List<SubProtocolHandler> protocolHandlers = subProtocolWebSocketHandler.getProtocolHandlers();
for(SubProtocolHandler protocolHandler : protocolHandlers) {
assertTrue(protocolHandler instanceof StompSubProtocolHandler);
assertEquals(128 * 1024, ((StompSubProtocolHandler) protocolHandler).getMessageSizeLimit());
}
Map<String, SubProtocolHandler> handlerMap = subWsHandler.getProtocolHandlerMap();
StompSubProtocolHandler protocolHandler = (StompSubProtocolHandler) handlerMap.get("v12.stomp");
assertEquals(128 * 1024, protocolHandler.getMessageSizeLimit());
}
@Test

View File

@ -0,0 +1,199 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.messaging;
import static org.junit.Assert.*;
import java.security.Principal;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.user.SimpSubscription;
import org.springframework.messaging.simp.user.SimpSubscriptionMatcher;
import org.springframework.messaging.simp.user.SimpUser;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.CloseStatus;
/**
* Test fixture for
* {@link DefaultSimpUserRegistry}
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class DefaultSimpUserRegistryTests {
@Test
public void addOneSessionId() {
TestPrincipal user = new TestPrincipal("joe");
Message<byte[]> message = createMessage(SimpMessageType.CONNECT_ACK, "123");
SessionConnectedEvent event = new SessionConnectedEvent(this, message, user);
DefaultSimpUserRegistry registry = new DefaultSimpUserRegistry();
registry.onApplicationEvent(event);
SimpUser simpUser = registry.getUser("joe");
assertNotNull(simpUser);
assertEquals(1, simpUser.getSessions().size());
assertNotNull(simpUser.getSession("123"));
}
@Test
public void addMultipleSessionIds() {
DefaultSimpUserRegistry registry = new DefaultSimpUserRegistry();
TestPrincipal user = new TestPrincipal("joe");
Message<byte[]> message = createMessage(SimpMessageType.CONNECT_ACK, "123");
SessionConnectedEvent event = new SessionConnectedEvent(this, message, user);
registry.onApplicationEvent(event);
message = createMessage(SimpMessageType.CONNECT_ACK, "456");
event = new SessionConnectedEvent(this, message, user);
registry.onApplicationEvent(event);
message = createMessage(SimpMessageType.CONNECT_ACK, "789");
event = new SessionConnectedEvent(this, message, user);
registry.onApplicationEvent(event);
SimpUser simpUser = registry.getUser("joe");
assertNotNull(simpUser);
assertEquals(3, simpUser.getSessions().size());
assertNotNull(simpUser.getSession("123"));
assertNotNull(simpUser.getSession("456"));
assertNotNull(simpUser.getSession("789"));
}
@Test
public void removeSessionIds() {
DefaultSimpUserRegistry registry = new DefaultSimpUserRegistry();
TestPrincipal user = new TestPrincipal("joe");
Message<byte[]> message = createMessage(SimpMessageType.CONNECT_ACK, "123");
SessionConnectedEvent connectedEvent = new SessionConnectedEvent(this, message, user);
registry.onApplicationEvent(connectedEvent);
message = createMessage(SimpMessageType.CONNECT_ACK, "456");
connectedEvent = new SessionConnectedEvent(this, message, user);
registry.onApplicationEvent(connectedEvent);
message = createMessage(SimpMessageType.CONNECT_ACK, "789");
connectedEvent = new SessionConnectedEvent(this, message, user);
registry.onApplicationEvent(connectedEvent);
SimpUser simpUser = registry.getUser("joe");
assertNotNull(simpUser);
assertEquals(3, simpUser.getSessions().size());
CloseStatus status = CloseStatus.GOING_AWAY;
message = createMessage(SimpMessageType.DISCONNECT, "456");
SessionDisconnectEvent disconnectEvent = new SessionDisconnectEvent(this, message, "456", status, user);
registry.onApplicationEvent(disconnectEvent);
message = createMessage(SimpMessageType.DISCONNECT, "789");
disconnectEvent = new SessionDisconnectEvent(this, message, "789", status, user);
registry.onApplicationEvent(disconnectEvent);
assertEquals(1, simpUser.getSessions().size());
assertNotNull(simpUser.getSession("123"));
}
@Test
public void findSubscriptions() throws Exception {
DefaultSimpUserRegistry registry = new DefaultSimpUserRegistry();
TestPrincipal user = new TestPrincipal("joe");
Message<byte[]> message = createMessage(SimpMessageType.CONNECT_ACK, "123");
SessionConnectedEvent event = new SessionConnectedEvent(this, message, user);
registry.onApplicationEvent(event);
message = createMessage(SimpMessageType.SUBSCRIBE, "123", "sub1", "/match");
SessionSubscribeEvent subscribeEvent = new SessionSubscribeEvent(this, message, user);
registry.onApplicationEvent(subscribeEvent);
message = createMessage(SimpMessageType.SUBSCRIBE, "123", "sub2", "/match");
subscribeEvent = new SessionSubscribeEvent(this, message, user);
registry.onApplicationEvent(subscribeEvent);
message = createMessage(SimpMessageType.SUBSCRIBE, "123", "sub3", "/not-a-match");
subscribeEvent = new SessionSubscribeEvent(this, message, user);
registry.onApplicationEvent(subscribeEvent);
Set<SimpSubscription> matches = registry.findSubscriptions(new SimpSubscriptionMatcher() {
@Override
public boolean match(SimpSubscription subscription) {
return subscription.getDestination().equals("/match");
}
});
assertEquals(2, matches.size());
Iterator<SimpSubscription> iterator = matches.iterator();
Set<String> sessionIds = new HashSet<>(2);
sessionIds.add(iterator.next().getId());
sessionIds.add(iterator.next().getId());
assertEquals(new HashSet<>(Arrays.asList("sub1", "sub2")), sessionIds);
}
private Message<byte[]> createMessage(SimpMessageType type, String sessionId) {
return createMessage(type, sessionId, null, null);
}
private Message<byte[]> createMessage(SimpMessageType type, String sessionId, String subscriptionId,
String destination) {
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(type);
accessor.setSessionId(sessionId);
if (destination != null) {
accessor.setDestination(destination);
}
if (subscriptionId != null) {
accessor.setSubscriptionId(subscriptionId);
}
return MessageBuilder.createMessage(new byte[0], accessor.getMessageHeaders());
}
private static class TestPrincipal implements Principal {
private String name;
public TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return this.name;
}
}
}

View File

@ -47,9 +47,7 @@ import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.ChannelInterceptorAdapter;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
@ -96,9 +94,6 @@ public class StompSubProtocolHandlerTests {
@Test
public void handleMessageToClientWithConnectedFrame() {
UserSessionRegistry registry = new DefaultUserSessionRegistry();
this.protocolHandler.setUserSessionRegistry(registry);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message);
@ -106,8 +101,6 @@ public class StompSubProtocolHandlerTests {
assertEquals(1, this.session.getSentMessages().size());
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload());
assertEquals(Collections.singleton("s1"), registry.getSessionIds("joe"));
}
@Test
@ -115,9 +108,6 @@ public class StompSubProtocolHandlerTests {
this.session.setPrincipal(new UniqueUser("joe"));
UserSessionRegistry registry = new DefaultUserSessionRegistry();
this.protocolHandler.setUserSessionRegistry(registry);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message);
@ -125,9 +115,6 @@ public class StompSubProtocolHandlerTests {
assertEquals(1, this.session.getSentMessages().size());
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload());
assertEquals(Collections.<String>emptySet(), registry.getSessionIds("joe"));
assertEquals(Collections.singleton("s1"), registry.getSessionIds("Me myself and I"));
}
@Test
@ -348,8 +335,6 @@ public class StompSubProtocolHandlerTests {
TestPublisher publisher = new TestPublisher();
UserSessionRegistry registry = new DefaultUserSessionRegistry();
this.protocolHandler.setUserSessionRegistry(registry);
this.protocolHandler.setApplicationEventPublisher(publisher);
this.protocolHandler.afterSessionStarted(this.session, this.channel);
@ -387,8 +372,6 @@ public class StompSubProtocolHandlerTests {
ApplicationEventPublisher publisher = mock(ApplicationEventPublisher.class);
UserSessionRegistry registry = new DefaultUserSessionRegistry();
this.protocolHandler.setUserSessionRegistry(registry);
this.protocolHandler.setApplicationEventPublisher(publisher);
this.protocolHandler.afterSessionStarted(this.session, this.channel);

View File

@ -4,7 +4,7 @@
xsi:schemaLocation="http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd
http://www.springframework.org/schema/websocket http://www.springframework.org/schema/websocket/spring-websocket.xsd">
<websocket:message-broker order="2" user-destination-broadcast="/topic/unresolved">
<websocket:message-broker order="2">
<websocket:stomp-endpoint path="/foo">
<websocket:sockjs/>
</websocket:stomp-endpoint>
@ -12,7 +12,9 @@
client-login="clientlogin" client-passcode="clientpass"
system-login="syslogin" system-passcode="syspass"
heartbeat-send-interval="5000" heartbeat-receive-interval="5000"
virtual-host="spring.io"/>
virtual-host="spring.io"
user-destination-broadcast="/topic/unresolved-user-destination"
user-registry-broadcast="/topic/simp-user-registry"/>
</websocket:message-broker>
<bean id="myHandler" class="org.springframework.web.socket.config.TestWebSocketHandler"/>