diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java index 12162ce51d..cde55e78ec 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java @@ -135,6 +135,16 @@ public class MultiServerUserRegistry implements SimpUserRegistry, SmartApplicati return result; } + @Override + public int getUserCount() { + int userCount = 0; + for (UserRegistrySnapshot registry : this.remoteRegistries.values()) { + userCount += registry.getUserMap().size(); + } + userCount += this.localRegistry.getUserCount(); + return userCount; + } + @Override public Set findSubscriptions(SimpSubscriptionMatcher matcher) { Set result = new HashSet<>(); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java index a989f628fa..b76be434cc 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java @@ -40,6 +40,13 @@ public interface SimpUserRegistry { */ Set getUsers(); + /** + * Return the count of all connected users. + * @return the connected user count. + * @since 4.3.5 + */ + int getUserCount(); + /** * Find subscriptions with the given matcher. * @param matcher the matcher to use diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java index 5ce5dd708a..2e93ce2083 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java @@ -61,9 +61,10 @@ public class MultiServerUserRegistryTests { SimpUser user = Mockito.mock(SimpUser.class); Set users = Collections.singleton(user); when(this.localRegistry.getUsers()).thenReturn(users); + when(this.localRegistry.getUserCount()).thenReturn(1); when(this.localRegistry.getUser("joe")).thenReturn(user); - assertEquals(1, this.registry.getUsers().size()); + assertEquals(1, this.registry.getUserCount()); assertSame(user, this.registry.getUser("joe")); } @@ -84,7 +85,7 @@ public class MultiServerUserRegistryTests { this.registry.addRemoteRegistryDto(message, this.converter, 20000); - assertEquals(1, this.registry.getUsers().size()); + assertEquals(1, this.registry.getUserCount()); SimpUser user = this.registry.getUser("joe"); assertNotNull(user); assertTrue(user.hasSessions()); @@ -125,7 +126,7 @@ public class MultiServerUserRegistryTests { this.registry.addRemoteRegistryDto(message, this.converter, 20000); - assertEquals(3, this.registry.getUsers().size()); + assertEquals(3, this.registry.getUserCount()); Set matches = this.registry.findSubscriptions(s -> s.getDestination().equals("/match")); assertEquals(2, matches.size()); Iterator iterator = matches.iterator(); @@ -157,7 +158,7 @@ public class MultiServerUserRegistryTests { this.registry.addRemoteRegistryDto(message, this.converter, 20000); - assertEquals(1, this.registry.getUsers().size()); + assertEquals(1, this.registry.getUserCount()); SimpUser user = this.registry.getUsers().iterator().next(); assertTrue(user.hasSessions()); assertEquals(2, user.getSessions().size()); @@ -187,9 +188,9 @@ public class MultiServerUserRegistryTests { this.registry.addRemoteRegistryDto(message, this.converter, -1); - assertEquals(1, this.registry.getUsers().size()); + assertEquals(1, this.registry.getUserCount()); this.registry.purgeExpiredRegistries(); - assertEquals(0, this.registry.getUsers().size()); + assertEquals(0, this.registry.getUserCount()); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java index 4f441c9e50..790181be25 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java @@ -126,7 +126,7 @@ public class UserRegistryMessageHandlerTests { MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(mock(SimpUserRegistry.class)); remoteRegistry.addRemoteRegistryDto(message, this.converter, 20000); - assertEquals(2, remoteRegistry.getUsers().size()); + assertEquals(2, remoteRegistry.getUserCount()); assertNotNull(remoteRegistry.getUser("joe")); assertNotNull(remoteRegistry.getUser("jane")); } @@ -142,6 +142,7 @@ public class UserRegistryMessageHandlerTests { HashSet simpUsers = new HashSet<>(Arrays.asList(simpUser1, simpUser2)); SimpUserRegistry remoteUserRegistry = mock(SimpUserRegistry.class); + when(remoteUserRegistry.getUserCount()).thenReturn(2); when(remoteUserRegistry.getUsers()).thenReturn(simpUsers); MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(remoteUserRegistry); @@ -149,7 +150,7 @@ public class UserRegistryMessageHandlerTests { this.handler.handleMessage(message); - assertEquals(2, remoteRegistry.getUsers().size()); + assertEquals(2, remoteRegistry.getUserCount()); assertNotNull(this.multiServerRegistry.getUser("joe")); assertNotNull(this.multiServerRegistry.getUser("jane")); } @@ -159,13 +160,14 @@ public class UserRegistryMessageHandlerTests { TestSimpUser simpUser = new TestSimpUser("joe"); simpUser.addSessions(new TestSimpSession("123")); + when(this.localRegistry.getUserCount()).thenReturn(1); when(this.localRegistry.getUsers()).thenReturn(Collections.singleton(simpUser)); - assertEquals(1, this.multiServerRegistry.getUsers().size()); + assertEquals(1, this.multiServerRegistry.getUserCount()); Message message = this.converter.toMessage(this.multiServerRegistry.getLocalRegistryDto(), null); this.multiServerRegistry.addRemoteRegistryDto(message, this.converter, 20000); - assertEquals(1, this.multiServerRegistry.getUsers().size()); + assertEquals(1, this.multiServerRegistry.getUserCount()); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java index 753e053a0d..6889a6fbe3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java @@ -141,6 +141,11 @@ public class DefaultSimpUserRegistry implements SimpUserRegistry, SmartApplicati return new HashSet<>(this.users.values()); } + @Override + public int getUserCount() { + return this.users.size(); + } + public Set findSubscriptions(SimpSubscriptionMatcher matcher) { Set result = new HashSet<>(); for (LocalSimpSession session : this.sessions.values()) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java index cbccc5a567..bd43f95e13 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java @@ -57,6 +57,7 @@ public class DefaultSimpUserRegistryTests { SimpUser simpUser = registry.getUser("joe"); assertNotNull(simpUser); + assertEquals(1, registry.getUserCount()); assertEquals(1, simpUser.getSessions().size()); assertNotNull(simpUser.getSession("123")); } @@ -82,6 +83,7 @@ public class DefaultSimpUserRegistryTests { SimpUser simpUser = registry.getUser("joe"); assertNotNull(simpUser); + assertEquals(1, registry.getUserCount()); assertEquals(3, simpUser.getSessions().size()); assertNotNull(simpUser.getSession("123")); assertNotNull(simpUser.getSession("456"));