diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java index fc3d1b44f3f..305db97bab7 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -96,17 +96,13 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { return Collections.emptySet(); } - Set set = new HashSet(); - for (String sessionId : this.userSessionRegistry.getSessionIds(info.getUser())) { - set.add(getTargetDestination(headers.getDestination(), info.getDestination(), sessionId, info.getUser())); + Set result = new HashSet(); + for (String sessionId : info.getSessionIds()) { + result.add(getTargetDestination( + headers.getDestination(), info.getDestination(), sessionId, info.getUser())); } - return set; - } - protected String getTargetDestination(String originalDestination, String targetDestination, - String sessionId, String user) { - - return targetDestination + "-user" + sessionId; + return result; } private UserDestinationInfo getUserDestinationInfo(SimpMessageHeaderAccessor headers) { @@ -115,6 +111,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { String targetUser; String targetDestination; + Set sessionIds; Principal user = headers.getUser(); SimpMessageType messageType = headers.getMessageType(); @@ -124,11 +121,16 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { return null; } if (user == null) { - logger.warn("Ignoring message, no user information"); + logger.error("Ignoring message, no user info available"); + return null; + } + if (headers.getSessionId() == null) { + logger.error("Ignoring message, no session id available"); return null; } targetUser = user.getName(); targetDestination = destination.substring(this.destinationPrefix.length()-1); + sessionIds = Collections.singleton(headers.getSessionId()); } else if (SimpMessageType.MESSAGE.equals(messageType)) { if (!checkDestination(destination, this.destinationPrefix)) { @@ -139,7 +141,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { Assert.isTrue(endIndex > 0, "Expected destination pattern \"/user/{userId}/**\""); targetUser = destination.substring(startIndex, endIndex); targetDestination = destination.substring(endIndex); - + sessionIds = this.userSessionRegistry.getSessionIds(targetUser); } else { if (logger.isTraceEnabled()) { @@ -148,7 +150,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { return null; } - return new UserDestinationInfo(targetUser, targetDestination); + return new UserDestinationInfo(targetUser, targetDestination, sessionIds); } protected boolean checkDestination(String destination, String requiredPrefix) { @@ -165,6 +167,10 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { return true; } + protected String getTargetDestination(String origDestination, String targetDestination, String sessionId, String user) { + return targetDestination + "-user" + sessionId; + } + private static class UserDestinationInfo { @@ -172,18 +178,25 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { private final String destination; - private UserDestinationInfo(String user, String destination) { + private final Set sessionIds; + + private UserDestinationInfo(String user, String destination, Set sessionIds) { this.user = user; this.destination = destination; + this.sessionIds = sessionIds; } - private String getUser() { + public String getUser() { return this.user; } - private String getDestination() { + public String getDestination() { return this.destination; } + + public Set getSessionIds() { + return this.sessionIds; + } } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java index ed0f04735cc..bf4f4b50003 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,9 +22,6 @@ import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.TestPrincipal; -import org.springframework.messaging.simp.user.DefaultUserDestinationResolver; -import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; -import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.MessageBuilder; import java.util.Set; @@ -36,6 +33,8 @@ import static org.junit.Assert.assertEquals; */ public class DefaultUserDestinationResolverTests { + public static final String SESSION_ID = "123"; + private DefaultUserDestinationResolver resolver; private UserSessionRegistry registry; @@ -44,14 +43,30 @@ public class DefaultUserDestinationResolverTests { @Before public void setup() { this.registry = new DefaultUserSessionRegistry(); + this.registry.registerSessionId("joe", SESSION_ID); + this.resolver = new DefaultUserDestinationResolver(this.registry); } @Test public void handleSubscribe() { - Message message = createMessage(SimpMessageType.SUBSCRIBE, "joe", "/user/queue/foo"); - this.registry.registerSessionId("joe", "123"); + Message message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo"); + Set actual = this.resolver.resolveDestination(message); + + assertEquals(1, actual.size()); + assertEquals("/queue/foo-user123", actual.iterator().next()); + } + + // SPR-11325 + + @Test + public void handleSubscribeOneUserMultipleSessions() { + + this.registry.registerSessionId("joe", "456"); + this.registry.registerSessionId("joe", "789"); + + Message message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo"); Set actual = this.resolver.resolveDestination(message); assertEquals(1, actual.size()); @@ -60,8 +75,7 @@ public class DefaultUserDestinationResolverTests { @Test public void handleUnsubscribe() { - Message message = createMessage(SimpMessageType.UNSUBSCRIBE, "joe", "/user/queue/foo"); - this.registry.registerSessionId("joe", "123"); + Message message = createMessage(SimpMessageType.UNSUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo"); Set actual = this.resolver.resolveDestination(message); assertEquals(1, actual.size()); @@ -70,8 +84,7 @@ public class DefaultUserDestinationResolverTests { @Test public void handleMessage() { - Message message = createMessage(SimpMessageType.MESSAGE, "joe", "/user/joe/queue/foo"); - this.registry.registerSessionId("joe", "123"); + Message message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/user/joe/queue/foo"); Set actual = this.resolver.resolveDestination(message); assertEquals(1, actual.size()); @@ -83,33 +96,33 @@ public class DefaultUserDestinationResolverTests { public void ignoreMessage() { // no destination - Message message = createMessage(SimpMessageType.MESSAGE, "joe", null); + Message message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, null); Set actual = this.resolver.resolveDestination(message); assertEquals(0, actual.size()); // not a user destination - message = createMessage(SimpMessageType.MESSAGE, "joe", "/queue/foo"); + message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/queue/foo"); actual = this.resolver.resolveDestination(message); assertEquals(0, actual.size()); // subscribe + no user - message = createMessage(SimpMessageType.SUBSCRIBE, null, "/user/queue/foo"); + message = createMessage(SimpMessageType.SUBSCRIBE, null, SESSION_ID, "/user/queue/foo"); actual = this.resolver.resolveDestination(message); assertEquals(0, actual.size()); // subscribe + not a user destination - message = createMessage(SimpMessageType.SUBSCRIBE, "joe", "/queue/foo"); + message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/queue/foo"); actual = this.resolver.resolveDestination(message); assertEquals(0, actual.size()); // no match on message type - message = createMessage(SimpMessageType.CONNECT, "joe", "user/joe/queue/foo"); + message = createMessage(SimpMessageType.CONNECT, "joe", SESSION_ID, "user/joe/queue/foo"); actual = this.resolver.resolveDestination(message); assertEquals(0, actual.size()); } - private Message createMessage(SimpMessageType messageType, String user, String destination) { + private Message createMessage(SimpMessageType messageType, String user, String sessionId, String destination) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(messageType); if (destination != null) { headers.setDestination(destination); @@ -117,6 +130,9 @@ public class DefaultUserDestinationResolverTests { if (user != null) { headers.setUser(new TestPrincipal(user)); } + if (sessionId != null) { + headers.setSessionId(sessionId); + } return MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java index 0405a7918c0..bfc0ac224c1 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,10 +28,6 @@ import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.TestPrincipal; -import org.springframework.messaging.simp.user.DefaultUserDestinationResolver; -import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; -import org.springframework.messaging.simp.user.UserDestinationMessageHandler; -import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.MessageBuilder; import static org.junit.Assert.assertEquals; @@ -42,6 +38,7 @@ import static org.mockito.Mockito.*; */ public class UserDestinationMessageHandlerTests { + public static final String SESSION_ID = "123"; private UserDestinationMessageHandler messageHandler; @@ -63,9 +60,8 @@ public class UserDestinationMessageHandlerTests { @Test public void handleSubscribe() { - this.registry.registerSessionId("joe", "123"); when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true); - this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", "/user/queue/foo")); + this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo")); ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); Mockito.verify(this.brokerChannel).send(captor.capture()); @@ -76,9 +72,8 @@ public class UserDestinationMessageHandlerTests { @Test public void handleUnsubscribe() { - this.registry.registerSessionId("joe", "123"); when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true); - this.messageHandler.handleMessage(createMessage(SimpMessageType.UNSUBSCRIBE, "joe", "/user/queue/foo")); + this.messageHandler.handleMessage(createMessage(SimpMessageType.UNSUBSCRIBE, "joe", "123", "/user/queue/foo")); ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); Mockito.verify(this.brokerChannel).send(captor.capture()); @@ -91,7 +86,7 @@ public class UserDestinationMessageHandlerTests { public void handleMessage() { this.registry.registerSessionId("joe", "123"); when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true); - this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "/user/joe/queue/foo")); + this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "123", "/user/joe/queue/foo")); ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); Mockito.verify(this.brokerChannel).send(captor.capture()); @@ -105,28 +100,28 @@ public class UserDestinationMessageHandlerTests { public void ignoreMessage() { // no destination - this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", null)); + this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "123", null)); Mockito.verifyZeroInteractions(this.brokerChannel); // not a user destination - this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "/queue/foo")); + this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "123", "/queue/foo")); Mockito.verifyZeroInteractions(this.brokerChannel); // subscribe + no user - this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, null, "/user/queue/foo")); + this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, null, "123", "/user/queue/foo")); Mockito.verifyZeroInteractions(this.brokerChannel); // subscribe + not a user destination - this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", "/queue/foo")); + this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", "123", "/queue/foo")); Mockito.verifyZeroInteractions(this.brokerChannel); // no match on message type - this.messageHandler.handleMessage(createMessage(SimpMessageType.CONNECT, "joe", "user/joe/queue/foo")); + this.messageHandler.handleMessage(createMessage(SimpMessageType.CONNECT, "joe", "123", "user/joe/queue/foo")); Mockito.verifyZeroInteractions(this.brokerChannel); } - private Message createMessage(SimpMessageType messageType, String user, String destination) { + private Message createMessage(SimpMessageType messageType, String user, String sessionId, String destination) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(messageType); if (destination != null) { headers.setDestination(destination); @@ -134,6 +129,9 @@ public class UserDestinationMessageHandlerTests { if (user != null) { headers.setUser(new TestPrincipal(user)); } + if (sessionId != null) { + headers.setSessionId(sessionId); + } return MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); }