diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java index a7193e8e861..4e15a242fae 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java @@ -121,12 +121,12 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { @Override public UserDestinationResult resolveDestination(Message message) { - String sourceDestination = SimpMessageHeaderAccessor.getDestination(message.getHeaders()); ParseResult parseResult = parse(message); if (parseResult == null) { return null; } String user = parseResult.getUser(); + String sourceDestination = parseResult.getSourceDestination(); Set targetSet = new HashSet<>(); for (String sessionId : parseResult.getSessionIds()) { String actualDestination = parseResult.getActualDestination(); @@ -142,65 +142,81 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { @Nullable private ParseResult parse(Message message) { MessageHeaders headers = message.getHeaders(); - String destination = SimpMessageHeaderAccessor.getDestination(headers); - if (destination == null || !checkDestination(destination, this.prefix)) { + String sourceDestination = SimpMessageHeaderAccessor.getDestination(headers); + if (sourceDestination == null || !checkDestination(sourceDestination, this.prefix)) { return null; } SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers); - Principal principal = SimpMessageHeaderAccessor.getUser(headers); - String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); - if (SimpMessageType.SUBSCRIBE.equals(messageType) || SimpMessageType.UNSUBSCRIBE.equals(messageType)) { - if (sessionId == null) { - logger.error("No session id. Ignoring " + message); + switch (messageType) { + case SUBSCRIBE: + case UNSUBSCRIBE: + return parseSubscriptionMessage(message, headers, sourceDestination); + case MESSAGE: + return parseMessage(headers, sourceDestination); + default: return null; - } - int prefixEnd = this.prefix.length() - 1; - String actualDestination = destination.substring(prefixEnd); - if (!this.keepLeadingSlash) { - actualDestination = actualDestination.substring(1); - } - String user = (principal != null ? principal.getName() : null); - return new ParseResult(actualDestination, destination, Collections.singleton(sessionId), user); } - else if (SimpMessageType.MESSAGE.equals(messageType)) { - int prefixEnd = this.prefix.length(); - int userEnd = destination.indexOf('/', prefixEnd); - Assert.isTrue(userEnd > 0, "Expected destination pattern \"/user/{userId}/**\""); - String actualDestination = destination.substring(userEnd); - String subscribeDestination = this.prefix.substring(0, prefixEnd - 1) + actualDestination; - String userName = destination.substring(prefixEnd, userEnd); - userName = StringUtils.replace(userName, "%2F", "/"); - Set sessionIds; - if (userName.equals(sessionId)) { - userName = null; + } + + private ParseResult parseSubscriptionMessage(Message message, MessageHeaders headers, String sourceDestination) { + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); + if (sessionId == null) { + logger.error("No session id. Ignoring " + message); + return null; + } + int prefixEnd = this.prefix.length() - 1; + String actualDestination = sourceDestination.substring(prefixEnd); + if (!this.keepLeadingSlash) { + actualDestination = actualDestination.substring(1); + } + Principal principal = SimpMessageHeaderAccessor.getUser(headers); + String user = (principal != null ? principal.getName() : null); + Set sessionIds = Collections.singleton(sessionId); + return new ParseResult(sourceDestination, actualDestination, sourceDestination, sessionIds, user); + } + + private ParseResult parseMessage(MessageHeaders headers, String sourceDestination) { + int prefixEnd = this.prefix.length(); + int userEnd = sourceDestination.indexOf('/', prefixEnd); + Assert.isTrue(userEnd > 0, "Expected destination pattern \"/user/{userId}/**\""); + String actualDestination = sourceDestination.substring(userEnd); + String subscribeDestination = this.prefix.substring(0, prefixEnd - 1) + actualDestination; + String userName = sourceDestination.substring(prefixEnd, userEnd); + userName = StringUtils.replace(userName, "%2F", "/"); + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); + Set sessionIds; + if (userName.equals(sessionId)) { + userName = null; + sessionIds = Collections.singleton(sessionId); + } + else { + sessionIds = getSessionIdsByUser(userName, sessionId); + } + if (!this.keepLeadingSlash) { + actualDestination = actualDestination.substring(1); + } + return new ParseResult(sourceDestination, actualDestination, subscribeDestination, sessionIds, userName); + } + + private Set getSessionIdsByUser(String userName, String sessionId) { + Set sessionIds; + SimpUser user = this.userRegistry.getUser(userName); + if (user != null) { + if (user.getSession(sessionId) != null) { sessionIds = Collections.singleton(sessionId); } else { - SimpUser user = this.userRegistry.getUser(userName); - if (user != null) { - if (user.getSession(sessionId) != null) { - sessionIds = Collections.singleton(sessionId); - } - else { - Set sessions = user.getSessions(); - sessionIds = new HashSet<>(sessions.size()); - for (SimpSession session : sessions) { - sessionIds.add(session.getId()); - } - } - } - else { - sessionIds = Collections.emptySet(); + Set sessions = user.getSessions(); + sessionIds = new HashSet<>(sessions.size()); + for (SimpSession session : sessions) { + sessionIds.add(session.getId()); } } - if (!this.keepLeadingSlash) { - actualDestination = actualDestination.substring(1); - } - return new ParseResult(actualDestination, subscribeDestination, sessionIds, userName); } else { - return null; + sessionIds = Collections.emptySet(); } + return sessionIds; } protected boolean checkDestination(String destination, String requiredPrefix) { @@ -243,8 +259,11 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { private final String user; + private final String sourceDestination; - public ParseResult(String actualDest, String subscribeDest, Set sessionIds, String user) { + public ParseResult(String sourceDest, String actualDest, String subscribeDest, + Set sessionIds, String user) { + this.sourceDestination = sourceDest; this.actualDestination = actualDest; this.subscribeDestination = subscribeDest; this.sessionIds = sessionIds; @@ -267,6 +286,10 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { public String getUser() { return this.user; } + + public String getSourceDestination() { + return this.sourceDestination; + } } }