Fix issue in DefaultUserDestinationResolver

DefaultUserDestinationResolver now uses the session id of
SUBSCRIBE/UNSUBSCRIBE messages rather than looking up all session id's
associated with a user.

Issue: SPR-11325
This commit is contained in:
Rossen Stoyanchev 2014-01-20 15:41:22 -05:00
parent 809a5f59b3
commit b4e48d6749
3 changed files with 75 additions and 48 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -96,17 +96,13 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
return Collections.emptySet();
}
Set<String> set = new HashSet<String>();
for (String sessionId : this.userSessionRegistry.getSessionIds(info.getUser())) {
set.add(getTargetDestination(headers.getDestination(), info.getDestination(), sessionId, info.getUser()));
Set<String> result = new HashSet<String>();
for (String sessionId : info.getSessionIds()) {
result.add(getTargetDestination(
headers.getDestination(), info.getDestination(), sessionId, info.getUser()));
}
return set;
}
protected String getTargetDestination(String originalDestination, String targetDestination,
String sessionId, String user) {
return targetDestination + "-user" + sessionId;
return result;
}
private UserDestinationInfo getUserDestinationInfo(SimpMessageHeaderAccessor headers) {
@ -115,6 +111,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
String targetUser;
String targetDestination;
Set<String> sessionIds;
Principal user = headers.getUser();
SimpMessageType messageType = headers.getMessageType();
@ -124,11 +121,16 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
return null;
}
if (user == null) {
logger.warn("Ignoring message, no user information");
logger.error("Ignoring message, no user info available");
return null;
}
if (headers.getSessionId() == null) {
logger.error("Ignoring message, no session id available");
return null;
}
targetUser = user.getName();
targetDestination = destination.substring(this.destinationPrefix.length()-1);
sessionIds = Collections.singleton(headers.getSessionId());
}
else if (SimpMessageType.MESSAGE.equals(messageType)) {
if (!checkDestination(destination, this.destinationPrefix)) {
@ -139,7 +141,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
Assert.isTrue(endIndex > 0, "Expected destination pattern \"/user/{userId}/**\"");
targetUser = destination.substring(startIndex, endIndex);
targetDestination = destination.substring(endIndex);
sessionIds = this.userSessionRegistry.getSessionIds(targetUser);
}
else {
if (logger.isTraceEnabled()) {
@ -148,7 +150,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
return null;
}
return new UserDestinationInfo(targetUser, targetDestination);
return new UserDestinationInfo(targetUser, targetDestination, sessionIds);
}
protected boolean checkDestination(String destination, String requiredPrefix) {
@ -165,6 +167,10 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
return true;
}
protected String getTargetDestination(String origDestination, String targetDestination, String sessionId, String user) {
return targetDestination + "-user" + sessionId;
}
private static class UserDestinationInfo {
@ -172,18 +178,25 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
private final String destination;
private UserDestinationInfo(String user, String destination) {
private final Set<String> sessionIds;
private UserDestinationInfo(String user, String destination, Set<String> sessionIds) {
this.user = user;
this.destination = destination;
this.sessionIds = sessionIds;
}
private String getUser() {
public String getUser() {
return this.user;
}
private String getDestination() {
public String getDestination() {
return this.destination;
}
public Set<String> getSessionIds() {
return this.sessionIds;
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,9 +22,6 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.simp.user.DefaultUserDestinationResolver;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder;
import java.util.Set;
@ -36,6 +33,8 @@ import static org.junit.Assert.assertEquals;
*/
public class DefaultUserDestinationResolverTests {
public static final String SESSION_ID = "123";
private DefaultUserDestinationResolver resolver;
private UserSessionRegistry registry;
@ -44,14 +43,30 @@ public class DefaultUserDestinationResolverTests {
@Before
public void setup() {
this.registry = new DefaultUserSessionRegistry();
this.registry.registerSessionId("joe", SESSION_ID);
this.resolver = new DefaultUserDestinationResolver(this.registry);
}
@Test
public void handleSubscribe() {
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, "joe", "/user/queue/foo");
this.registry.registerSessionId("joe", "123");
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size());
assertEquals("/queue/foo-user123", actual.iterator().next());
}
// SPR-11325
@Test
public void handleSubscribeOneUserMultipleSessions() {
this.registry.registerSessionId("joe", "456");
this.registry.registerSessionId("joe", "789");
Message<?> message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size());
@ -60,8 +75,7 @@ public class DefaultUserDestinationResolverTests {
@Test
public void handleUnsubscribe() {
Message<?> message = createMessage(SimpMessageType.UNSUBSCRIBE, "joe", "/user/queue/foo");
this.registry.registerSessionId("joe", "123");
Message<?> message = createMessage(SimpMessageType.UNSUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo");
Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size());
@ -70,8 +84,7 @@ public class DefaultUserDestinationResolverTests {
@Test
public void handleMessage() {
Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", "/user/joe/queue/foo");
this.registry.registerSessionId("joe", "123");
Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/user/joe/queue/foo");
Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(1, actual.size());
@ -83,33 +96,33 @@ public class DefaultUserDestinationResolverTests {
public void ignoreMessage() {
// no destination
Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", null);
Message<?> message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, null);
Set<String> actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size());
// not a user destination
message = createMessage(SimpMessageType.MESSAGE, "joe", "/queue/foo");
message = createMessage(SimpMessageType.MESSAGE, "joe", SESSION_ID, "/queue/foo");
actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size());
// subscribe + no user
message = createMessage(SimpMessageType.SUBSCRIBE, null, "/user/queue/foo");
message = createMessage(SimpMessageType.SUBSCRIBE, null, SESSION_ID, "/user/queue/foo");
actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size());
// subscribe + not a user destination
message = createMessage(SimpMessageType.SUBSCRIBE, "joe", "/queue/foo");
message = createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/queue/foo");
actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size());
// no match on message type
message = createMessage(SimpMessageType.CONNECT, "joe", "user/joe/queue/foo");
message = createMessage(SimpMessageType.CONNECT, "joe", SESSION_ID, "user/joe/queue/foo");
actual = this.resolver.resolveDestination(message);
assertEquals(0, actual.size());
}
private Message<?> createMessage(SimpMessageType messageType, String user, String destination) {
private Message<?> createMessage(SimpMessageType messageType, String user, String sessionId, String destination) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(messageType);
if (destination != null) {
headers.setDestination(destination);
@ -117,6 +130,9 @@ public class DefaultUserDestinationResolverTests {
if (user != null) {
headers.setUser(new TestPrincipal(user));
}
if (sessionId != null) {
headers.setSessionId(sessionId);
}
return MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,10 +28,6 @@ import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.simp.user.DefaultUserDestinationResolver;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder;
import static org.junit.Assert.assertEquals;
@ -42,6 +38,7 @@ import static org.mockito.Mockito.*;
*/
public class UserDestinationMessageHandlerTests {
public static final String SESSION_ID = "123";
private UserDestinationMessageHandler messageHandler;
@ -63,9 +60,8 @@ public class UserDestinationMessageHandlerTests {
@Test
public void handleSubscribe() {
this.registry.registerSessionId("joe", "123");
when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true);
this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", "/user/queue/foo"));
this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", SESSION_ID, "/user/queue/foo"));
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture());
@ -76,9 +72,8 @@ public class UserDestinationMessageHandlerTests {
@Test
public void handleUnsubscribe() {
this.registry.registerSessionId("joe", "123");
when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true);
this.messageHandler.handleMessage(createMessage(SimpMessageType.UNSUBSCRIBE, "joe", "/user/queue/foo"));
this.messageHandler.handleMessage(createMessage(SimpMessageType.UNSUBSCRIBE, "joe", "123", "/user/queue/foo"));
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture());
@ -91,7 +86,7 @@ public class UserDestinationMessageHandlerTests {
public void handleMessage() {
this.registry.registerSessionId("joe", "123");
when(this.brokerChannel.send(Mockito.any(Message.class))).thenReturn(true);
this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "/user/joe/queue/foo"));
this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "123", "/user/joe/queue/foo"));
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture());
@ -105,28 +100,28 @@ public class UserDestinationMessageHandlerTests {
public void ignoreMessage() {
// no destination
this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", null));
this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "123", null));
Mockito.verifyZeroInteractions(this.brokerChannel);
// not a user destination
this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "/queue/foo"));
this.messageHandler.handleMessage(createMessage(SimpMessageType.MESSAGE, "joe", "123", "/queue/foo"));
Mockito.verifyZeroInteractions(this.brokerChannel);
// subscribe + no user
this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, null, "/user/queue/foo"));
this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, null, "123", "/user/queue/foo"));
Mockito.verifyZeroInteractions(this.brokerChannel);
// subscribe + not a user destination
this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", "/queue/foo"));
this.messageHandler.handleMessage(createMessage(SimpMessageType.SUBSCRIBE, "joe", "123", "/queue/foo"));
Mockito.verifyZeroInteractions(this.brokerChannel);
// no match on message type
this.messageHandler.handleMessage(createMessage(SimpMessageType.CONNECT, "joe", "user/joe/queue/foo"));
this.messageHandler.handleMessage(createMessage(SimpMessageType.CONNECT, "joe", "123", "user/joe/queue/foo"));
Mockito.verifyZeroInteractions(this.brokerChannel);
}
private Message<?> createMessage(SimpMessageType messageType, String user, String destination) {
private Message<?> createMessage(SimpMessageType messageType, String user, String sessionId, String destination) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(messageType);
if (destination != null) {
headers.setDestination(destination);
@ -134,6 +129,9 @@ public class UserDestinationMessageHandlerTests {
if (user != null) {
headers.setUser(new TestPrincipal(user));
}
if (sessionId != null) {
headers.setSessionId(sessionId);
}
return MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
}