diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java index 498a8451b15..66326986928 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java @@ -140,7 +140,7 @@ public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRe @Override public void addSubscription(String destination, String subscriptionId) { - CachingSessionSubscriptionRegistry.this.destinationCache.mapRegistration(destination, this.delegate); + destinationCache.mapRegistration(destination, this); this.delegate.addSubscription(destination, subscriptionId); } @@ -148,7 +148,7 @@ public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRe public String removeSubscription(String subscriptionId) { String destination = this.delegate.removeSubscription(subscriptionId); if (destination != null && this.delegate.getSubscriptionsByDestination(destination) == null) { - CachingSessionSubscriptionRegistry.this.destinationCache.unmapRegistration(destination, this); + destinationCache.unmapRegistration(destination, this); } return destination; } @@ -163,6 +163,23 @@ public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRe return this.delegate.getDestinations(); } + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof CachingSessionSubscriptionRegistration)) { + return false; + } + CachingSessionSubscriptionRegistration otherType = (CachingSessionSubscriptionRegistration) other; + return this.delegate.equals(otherType.delegate); + } + + @Override + public int hashCode() { + return this.delegate.hashCode(); + } + @Override public String toString() { return "CachingSessionSubscriptionRegistration [delegate=" + delegate + "]"; diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistration.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistration.java index f2de00f0459..b0602835b2f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistration.java @@ -89,6 +89,22 @@ public class DefaultSessionSubscriptionRegistration implements SessionSubscripti return this.subscriptions.get(destination); } + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof DefaultSessionSubscriptionRegistration)) { + return false; + } + DefaultSessionSubscriptionRegistration otherType = (DefaultSessionSubscriptionRegistration) other; + return this.sessionId.equals(otherType.sessionId); + } + + @Override + public int hashCode() { + return 31 + this.sessionId.hashCode(); + } @Override public String toString() { diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandlerTests.java index 396affa7853..03532906847 100644 --- a/spring-websocket/src/test/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandlerTests.java @@ -78,12 +78,12 @@ public class SimpleBrokerWebMessageHandlerTests { this.messageHandler.handlePublish(createMessage("/bar", "message2")); verify(this.clientChannel, times(6)).send(this.messageCaptor.capture()); - assertCapturedMessage(this.messageCaptor.getAllValues().get(0), "sess1", "sub1", "/foo"); - assertCapturedMessage(this.messageCaptor.getAllValues().get(1), "sess1", "sub2", "/foo"); - assertCapturedMessage(this.messageCaptor.getAllValues().get(2), "sess2", "sub1", "/foo"); - assertCapturedMessage(this.messageCaptor.getAllValues().get(3), "sess2", "sub2", "/foo"); - assertCapturedMessage(this.messageCaptor.getAllValues().get(4), "sess1", "sub3", "/bar"); - assertCapturedMessage(this.messageCaptor.getAllValues().get(5), "sess2", "sub3", "/bar"); + assertCapturedMessage("sess1", "sub1", "/foo"); + assertCapturedMessage("sess1", "sub2", "/foo"); + assertCapturedMessage("sess2", "sub1", "/foo"); + assertCapturedMessage("sess2", "sub2", "/foo"); + assertCapturedMessage("sess1", "sub3", "/bar"); + assertCapturedMessage("sess2", "sub3", "/bar"); } @Test @@ -105,10 +105,13 @@ public class SimpleBrokerWebMessageHandlerTests { this.messageHandler.handlePublish(createMessage("/foo", "message1")); this.messageHandler.handlePublish(createMessage("/bar", "message2")); - verify(this.clientChannel, times(3)).send(this.messageCaptor.capture()); - assertCapturedMessage(this.messageCaptor.getAllValues().get(0), "sess2", "sub1", "/foo"); - assertCapturedMessage(this.messageCaptor.getAllValues().get(1), "sess2", "sub2", "/foo"); - assertCapturedMessage(this.messageCaptor.getAllValues().get(2), "sess2", "sub3", "/bar"); + verify(this.clientChannel, times(6)).send(this.messageCaptor.capture()); + assertCapturedMessage("sess1", "sub1", "/foo"); + assertCapturedMessage("sess1", "sub2", "/foo"); + assertCapturedMessage("sess2", "sub1", "/foo"); + assertCapturedMessage("sess2", "sub2", "/foo"); + assertCapturedMessage("sess1", "sub3", "/bar"); + assertCapturedMessage("sess2", "sub3", "/bar"); } @@ -130,13 +133,18 @@ public class SimpleBrokerWebMessageHandlerTests { return MessageBuilder.withPayload(payload).copyHeaders(headers.toMap()).build(); } - protected void assertCapturedMessage(Message message, String sessionId, - String subcriptionId, String destination) { - - WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message); - assertEquals(sessionId, headers.getSessionId()); - assertEquals(subcriptionId, headers.getSubscriptionId()); - assertEquals(destination, headers.getDestination()); + protected boolean assertCapturedMessage(String sessionId, String subcriptionId, String destination) { + for (Message message : this.messageCaptor.getAllValues()) { + WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message); + if (sessionId.equals(headers.getSessionId())) { + if (subcriptionId.equals(headers.getSubscriptionId())) { + if (destination.equals(headers.getDestination())) { + return true; + } + } + } + } + return false; } } diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java index 0fb729e08b0..e7e878ffe10 100644 --- a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java @@ -16,6 +16,7 @@ package org.springframework.web.messaging.stomp.support; import java.util.Collections; +import java.util.Map; import org.junit.Before; import org.junit.Test; @@ -23,6 +24,7 @@ import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.stomp.StompCommand; +import org.springframework.web.messaging.support.WebMessageHeaderAccesssor; import static org.junit.Assert.*; @@ -53,7 +55,14 @@ public class StompMessageConverterTests { MessageHeaders headers = message.getHeaders(); StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - assertEquals(7, stompHeaders.toMap().size()); + Map map = stompHeaders.toMap(); + assertEquals(6, map.size()); + assertNotNull(map.get(MessageHeaders.ID)); + assertNotNull(map.get(MessageHeaders.TIMESTAMP)); + assertNotNull(map.get(WebMessageHeaderAccesssor.SESSION_ID)); + assertNotNull(map.get(WebMessageHeaderAccesssor.NATIVE_HEADERS)); + assertNotNull(map.get(WebMessageHeaderAccesssor.MESSAGE_TYPE)); + assertNotNull(map.get(WebMessageHeaderAccesssor.PROTOCOL_MESSAGE_TYPE)); assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); assertEquals("github.org", stompHeaders.getHost()); diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java index 770a7f40a7a..aed8431b6af 100644 --- a/spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java @@ -47,11 +47,9 @@ public class CachingSessionSubscriptionRegistryTests { SessionSubscriptionRegistration reg1 = this.registry.getOrCreateRegistration("sess1"); reg1.addSubscription("/foo", "sub1"); - reg1.addSubscription("/foo", "sub1"); SessionSubscriptionRegistration reg2 = this.registry.getOrCreateRegistration("sess2"); reg2.addSubscription("/foo", "sub1"); - reg2.addSubscription("/foo", "sub1"); Set actual = this.registry.getRegistrationsByDestination("/foo"); assertEquals(2, actual.size()); @@ -59,14 +57,12 @@ public class CachingSessionSubscriptionRegistryTests { assertTrue(actual.contains(reg2)); reg1.removeSubscription("sub1"); - reg1.removeSubscription("sub2"); actual = this.registry.getRegistrationsByDestination("/foo"); assertEquals("Invalid set of registrations " + actual, 1, actual.size()); assertTrue(actual.contains(reg2)); reg2.removeSubscription("sub1"); - reg2.removeSubscription("sub2"); actual = this.registry.getRegistrationsByDestination("/foo"); assertNull("Unexpected registrations " + actual, actual);