From 0a68c9930fcd5b94d85c35329603960e1a898a2e Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 3 Jul 2013 20:40:56 -0400 Subject: [PATCH] Add "simple" broker and SessionSubscriptionRegistry SimpleBrokerWebMessageHandler can be used as an alternative to the StompRelayWebMessageHandler. --- .../SessionSubscriptionRegistration.java | 43 +++++ .../SessionSubscriptionRegistry.java | 36 ++++ .../service/ReactorWebMessageHandler.java | 164 ----------------- .../SimpleBrokerWebMessageHandler.java | 126 +++++++++++++ .../stomp/support/StompWebSocketHandler.java | 75 ++++---- .../CachingSessionSubscriptionRegistry.java | 172 ++++++++++++++++++ ...efaultSessionSubscriptionRegistration.java | 99 ++++++++++ .../DefaultSessionSubscriptionRegistry.java | 66 +++++++ .../web/messaging/support/MessageHolder.java | 2 + .../SimpleBrokerWebMessageHandlerTests.java | 142 +++++++++++++++ ...chingSessionSubscriptionRegistryTests.java | 75 ++++++++ ...tSessionSubscriptionRegistrationTests.java | 82 +++++++++ ...faultSessionSubscriptionRegistryTests.java | 86 +++++++++ 13 files changed, 963 insertions(+), 205 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistration.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistry.java delete mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorWebMessageHandler.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandler.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistration.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistry.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandlerTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistrationTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistryTests.java diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistration.java b/spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistration.java new file mode 100644 index 0000000000..0fe7ee3ade --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistration.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging; + +import java.util.Set; + +/** + * @author Rossen Stoyanchev + * @since 4.0 + */ +public interface SessionSubscriptionRegistration { + + + String getSessionId(); + + void addSubscription(String destination, String subscriptionId); + + /** + * @param subscriptionId the subscription to remove + * @return the destination to which the subscriptionId was registered, or {@code null} + * if no matching subscriptionId was found + */ + String removeSubscription(String subscriptionId); + + Set getSubscriptionsByDestination(String destination); + + Set getDestinations(); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistry.java b/spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistry.java new file mode 100644 index 0000000000..03ee40b5e3 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/SessionSubscriptionRegistry.java @@ -0,0 +1,36 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging; + +import java.util.Set; + + +/** + * @author Rossen Stoyanchev + * @since 4.0 + */ +public interface SessionSubscriptionRegistry { + + SessionSubscriptionRegistration getRegistration(String sessionId); + + SessionSubscriptionRegistration getOrCreateRegistration(String sessionId); + + SessionSubscriptionRegistration removeRegistration(String sessionId); + + Set getSessionSubscriptions(String sessionId, String destination); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorWebMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorWebMessageHandler.java deleted file mode 100644 index c60eaf4c1c..0000000000 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorWebMessageHandler.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.messaging.service; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.support.MessageBuilder; -import org.springframework.util.Assert; -import org.springframework.web.messaging.MessageType; -import org.springframework.web.messaging.converter.CompositeMessageConverter; -import org.springframework.web.messaging.converter.MessageConverter; -import org.springframework.web.messaging.support.WebMessageHeaderAccesssor; - -import reactor.core.Reactor; -import reactor.fn.Consumer; -import reactor.fn.Event; -import reactor.fn.registry.Registration; -import reactor.fn.selector.ObjectSelector; - - -/** - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class ReactorWebMessageHandler extends AbstractWebMessageHandler { - - private MessageChannel clientChannel; - - private final Reactor reactor; - - private MessageConverter payloadConverter; - - private Map>> subscriptionsBySession = new ConcurrentHashMap>>(); - - - /** - * @param clientChannel the channel to which messages for clients should be sent. - */ - public ReactorWebMessageHandler(MessageChannel clientChannel, Reactor reactor) { - Assert.notNull(clientChannel, "clientChannel is required"); - this.clientChannel = clientChannel; - this.reactor = reactor; - this.payloadConverter = new CompositeMessageConverter(null); - } - - public void setMessageConverters(List converters) { - this.payloadConverter = new CompositeMessageConverter(converters); - } - - @Override - protected Collection getSupportedMessageTypes() { - return Arrays.asList(MessageType.MESSAGE, MessageType.SUBSCRIBE, MessageType.UNSUBSCRIBE); - } - - @Override - public void handleSubscribe(Message message) { - - if (logger.isDebugEnabled()) { - logger.debug("Subscribe " + message); - } - - WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message); - String subscriptionId = headers.getSubscriptionId(); - BroadcastingConsumer consumer = new BroadcastingConsumer(subscriptionId); - - String key = getPublishKey(headers.getDestination()); - Registration registration = this.reactor.on(new ObjectSelector(key), consumer); - - String sessionId = headers.getSessionId(); - List> list = this.subscriptionsBySession.get(sessionId); - if (list == null) { - list = new ArrayList>(); - this.subscriptionsBySession.put(sessionId, list); - } - list.add(registration); - } - - private String getPublishKey(String destination) { - return "destination:" + destination; - } - - @Override - public void handlePublish(Message message) { - - if (logger.isDebugEnabled()) { - logger.debug("Message received: " + message); - } - - try { - // Convert to byte[] payload before the fan-out - WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message); - byte[] payload = payloadConverter.convertToPayload(message.getPayload(), headers.getContentType()); - Message m = MessageBuilder.withPayload(payload).copyHeaders(message.getHeaders()).build(); - - this.reactor.notify(getPublishKey(headers.getDestination()), Event.wrap(m)); - } - catch (Exception ex) { - logger.error("Failed to publish " + message, ex); - } - } - - @Override - public void handleDisconnect(Message message) { - WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message); - removeSubscriptions(headers.getSessionId()); - } - - private void removeSubscriptions(String sessionId) { - List> registrations = this.subscriptionsBySession.remove(sessionId); - if (logger.isTraceEnabled()) { - logger.trace("Cancelling " + registrations.size() + " subscriptions for session=" + sessionId); - } - for (Registration registration : registrations) { - registration.cancel(); - } - } - - - private final class BroadcastingConsumer implements Consumer>> { - - private final String subscriptionId; - - - private BroadcastingConsumer(String subscriptionId) { - this.subscriptionId = subscriptionId; - } - - @Override - public void accept(Event> event) { - - Message sentMessage = event.getData(); - - WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(sentMessage); - headers.setSubscriptionId(this.subscriptionId); - - Message clientMessage = MessageBuilder.withPayload( - sentMessage.getPayload()).copyHeaders(headers.toMap()).build(); - - clientChannel.send(clientMessage); - } - } - -} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandler.java new file mode 100644 index 0000000000..505bb0b7fb --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandler.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.service; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Set; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.Assert; +import org.springframework.web.messaging.MessageType; +import org.springframework.web.messaging.SessionSubscriptionRegistration; +import org.springframework.web.messaging.SessionSubscriptionRegistry; +import org.springframework.web.messaging.support.CachingSessionSubscriptionRegistry; +import org.springframework.web.messaging.support.DefaultSessionSubscriptionRegistry; +import org.springframework.web.messaging.support.WebMessageHeaderAccesssor; + + +/** + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class SimpleBrokerWebMessageHandler extends AbstractWebMessageHandler { + + private final MessageChannel clientChannel; + + private CachingSessionSubscriptionRegistry subscriptionRegistry= + new CachingSessionSubscriptionRegistry(new DefaultSessionSubscriptionRegistry()); + + + /** + * @param clientChannel the channel to which messages for clients should be sent + * @param observable an Observable to use to manage subscriptions + */ + public SimpleBrokerWebMessageHandler(MessageChannel clientChannel) { + Assert.notNull(clientChannel, "clientChannel is required"); + this.clientChannel = clientChannel; + } + + + public void setSubscriptionRegistry(SessionSubscriptionRegistry subscriptionRegistry) { + Assert.notNull(subscriptionRegistry, "subscriptionRegistry is required"); + this.subscriptionRegistry = new CachingSessionSubscriptionRegistry(subscriptionRegistry); + } + + @Override + protected Collection getSupportedMessageTypes() { + return Arrays.asList(MessageType.MESSAGE, MessageType.SUBSCRIBE, MessageType.UNSUBSCRIBE); + } + + @Override + public void handleSubscribe(Message message) { + + if (logger.isDebugEnabled()) { + logger.debug("Subscribe " + message); + } + + WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message); + String sessionId = headers.getSessionId(); + String subscriptionId = headers.getSubscriptionId(); + String destination = headers.getDestination(); + + SessionSubscriptionRegistration registration = this.subscriptionRegistry.getOrCreateRegistration(sessionId); + registration.addSubscription(destination, subscriptionId); + } + + @Override + public void handlePublish(Message message) { + + if (logger.isDebugEnabled()) { + logger.debug("Message received: " + message); + } + + String destination = WebMessageHeaderAccesssor.wrap(message).getDestination(); + + Set registrations = + this.subscriptionRegistry.getRegistrationsByDestination(destination); + + if (registrations == null) { + return; + } + + for (SessionSubscriptionRegistration registration : registrations) { + for (String subscriptionId : registration.getSubscriptionsByDestination(destination)) { + + WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message); + headers.setSessionId(registration.getSessionId()); + headers.setSubscriptionId(subscriptionId); + + Message clientMessage = MessageBuilder.withPayload( + message.getPayload()).copyHeaders(headers.toMap()).build(); + + try { + this.clientChannel.send(clientMessage); + } + catch (Throwable ex) { + logger.error("Failed to send message to destination=" + destination + + ", sessionId=" + registration.getSessionId() + ", subscriptionId=" + subscriptionId, ex); + } + } + } + } + + @Override + public void handleDisconnect(Message message) { + String sessionId = WebMessageHeaderAccesssor.wrap(message).getSessionId(); + this.subscriptionRegistry.removeRegistration(sessionId); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java index fc1e58c496..844bb9b599 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java @@ -17,7 +17,6 @@ package org.springframework.web.messaging.stomp.support; import java.io.IOException; import java.nio.charset.Charset; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -30,13 +29,15 @@ import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.support.MessageBuilder; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; +import org.springframework.util.CollectionUtils; import org.springframework.web.messaging.MessageType; +import org.springframework.web.messaging.SessionSubscriptionRegistration; +import org.springframework.web.messaging.SessionSubscriptionRegistry; import org.springframework.web.messaging.converter.CompositeMessageConverter; import org.springframework.web.messaging.converter.MessageConverter; import org.springframework.web.messaging.stomp.StompCommand; import org.springframework.web.messaging.stomp.StompConversionException; +import org.springframework.web.messaging.support.DefaultSessionSubscriptionRegistry; import org.springframework.web.messaging.support.WebMessageHeaderAccesssor; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; @@ -60,7 +61,9 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); - private final Map sessionInfos = new ConcurrentHashMap(); + private final Map sessions = new ConcurrentHashMap(); + + private SessionSubscriptionRegistry subscriptionRegistry = new DefaultSessionSubscriptionRegistry(); private MessageConverter payloadConverter = new CompositeMessageConverter(null); @@ -74,6 +77,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement this.outputChannel = outputChannel; } + public void setMessageConverters(List converters) { this.payloadConverter = new CompositeMessageConverter(converters); } @@ -82,11 +86,15 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement return this.stompMessageConverter; } + public void setSubscriptionRegistry(SessionSubscriptionRegistry subscriptionRegistry) { + this.subscriptionRegistry = subscriptionRegistry; + } + @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { Assert.notNull(this.outputChannel, "No output channel for STOMP messages."); - this.sessionInfos.put(session.getId(), new SessionInfo(session)); + this.sessions.put(session.getId(), session); } /** @@ -180,17 +188,26 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement String sessionId = headers.getSessionId(); String destination = headers.getDestination(); - SessionInfo sessionInfo = this.sessionInfos.get(sessionId); - sessionInfo.addSubscription(destination, headers.getSubscriptionId()); + SessionSubscriptionRegistration registration = this.subscriptionRegistry.getOrCreateRegistration(sessionId); + registration.addSubscription(destination, headers.getSubscriptionId()); } protected void handleUnsubscribe(Message message) { - // TODO: remove subscription + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + String sessionId = headers.getSessionId(); + String subscriptionId = headers.getSubscriptionId(); + SessionSubscriptionRegistration registration = this.subscriptionRegistry.getRegistration(sessionId); + if (registration == null) { + logger.warn("Subscripton=" + subscriptionId + " for session=" + sessionId + " not found"); + return; + } + registration.removeSubscription(subscriptionId); } - protected void handleDisconnect(Message stompMessage) { + protected void handleDisconnect(Message message) { + } protected void sendErrorMessage(WebSocketSession session, Throwable error) { @@ -211,7 +228,10 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { - this.sessionInfos.remove(session.getId()); + + this.sessions.remove(session.getId()); + this.subscriptionRegistry.removeRegistration(session.getId()); + WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.DISCONNECT); headers.setSessionId(session.getId()); Message message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); @@ -237,17 +257,16 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement logger.error("No \"sessionId\" header in message: " + message); } - SessionInfo sessionInfo = this.sessionInfos.get(sessionId); - WebSocketSession session = sessionInfo.getWebSocketSession(); + WebSocketSession session = this.sessions.get(sessionId); if (session == null) { logger.error("Session not found: " + message); } if (headers.getSubscriptionId() == null) { String destination = headers.getDestination(); - Set subs = sessionInfo.getSubscriptionsForDestination(destination); - if (subs != null) { - // TODO: send to all sub ids + Set subs = this.subscriptionRegistry.getSessionSubscriptions(sessionId, destination); + if (!CollectionUtils.isEmpty(subs)) { + // TODO: send to all subscriptions ids headers.setSubscriptionId(subs.iterator().next()); } else { @@ -285,30 +304,4 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement } } - - private static class SessionInfo { - - private final WebSocketSession session; - - private final MultiValueMap subscriptions = new LinkedMultiValueMap(4); - - - public SessionInfo(WebSocketSession session) { - this.session = session; - } - - public WebSocketSession getWebSocketSession() { - return this.session; - } - - public void addSubscription(String destination, String subscriptionId) { - this.subscriptions.add(destination, subscriptionId); - } - - public Set getSubscriptionsForDestination(String destination) { - List ids = this.subscriptions.get(destination); - return (ids != null) ? new HashSet(ids) : null; - } - } - } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java new file mode 100644 index 0000000000..498a8451b1 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistry.java @@ -0,0 +1,172 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.support; + +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArraySet; + +import org.springframework.util.Assert; +import org.springframework.web.messaging.SessionSubscriptionRegistration; +import org.springframework.web.messaging.SessionSubscriptionRegistry; + + +/** + * A decorator for a {@link SessionSubscriptionRegistry} that intercepts subscriptions + * being added and removed, and maintains a cache that tracks registrations for a + * given destination. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRegistry { + + private final SessionSubscriptionRegistry delegate; + + private final DestinationCache destinationCache = new DestinationCache(); + + + public CachingSessionSubscriptionRegistry(SessionSubscriptionRegistry delegate) { + Assert.notNull(delegate, "delegate SessionSubscriptionRegistry is required"); + this.delegate = delegate; + } + + + @Override + public SessionSubscriptionRegistration getRegistration(String sessionId) { + return new CachingSessionSubscriptionRegistration(this.delegate.getRegistration(sessionId)); + } + + @Override + public SessionSubscriptionRegistration getOrCreateRegistration(String sessionId) { + return new CachingSessionSubscriptionRegistration(this.delegate.getOrCreateRegistration(sessionId)); + } + + @Override + public SessionSubscriptionRegistration removeRegistration(String sessionId) { + SessionSubscriptionRegistration registration = this.delegate.removeRegistration(sessionId); + if (registration != null) { + this.destinationCache.removeRegistration(registration); + } + return registration; + } + + @Override + public Set getSessionSubscriptions(String sessionId, String destination) { + return this.delegate.getSessionSubscriptions(sessionId, destination); + } + + public Set getRegistrationsByDestination(String destination) { + return this.destinationCache.getRegistrations(destination); + } + + + private static class DestinationCache { + + private final Map> cache = + new ConcurrentHashMap>(); + + private final Object monitor = new Object(); + + + public void mapRegistration(String destination, SessionSubscriptionRegistration registration) { + synchronized (monitor) { + Set registrations = this.cache.get(destination); + if (registrations == null) { + registrations = new CopyOnWriteArraySet(); + this.cache.put(destination, registrations); + } + registrations.add(registration); + } + } + + public void unmapRegistration(String destination, SessionSubscriptionRegistration registration) { + synchronized (monitor) { + Set registrations = this.cache.get(destination); + if (registrations != null) { + registrations.remove(registration); + if (registrations.isEmpty()) { + this.cache.remove(destination); + } + } + } + } + + private void removeRegistration(SessionSubscriptionRegistration registration) { + for (String destination : registration.getDestinations()) { + unmapRegistration(destination, registration); + } + } + + public Set getRegistrations(String destination) { + return this.cache.get(destination); + } + + @Override + public String toString() { + return "DestinationCache [cache=" + this.cache + "]"; + } + } + + private class CachingSessionSubscriptionRegistration implements SessionSubscriptionRegistration { + + private final SessionSubscriptionRegistration delegate; + + + public CachingSessionSubscriptionRegistration(SessionSubscriptionRegistration delegate) { + Assert.notNull(delegate, "delegate SessionSubscriptionRegistration is required"); + this.delegate = delegate; + } + + @Override + public String getSessionId() { + return this.delegate.getSessionId(); + } + + @Override + public void addSubscription(String destination, String subscriptionId) { + CachingSessionSubscriptionRegistry.this.destinationCache.mapRegistration(destination, this.delegate); + this.delegate.addSubscription(destination, subscriptionId); + } + + @Override + public String removeSubscription(String subscriptionId) { + String destination = this.delegate.removeSubscription(subscriptionId); + if (destination != null && this.delegate.getSubscriptionsByDestination(destination) == null) { + CachingSessionSubscriptionRegistry.this.destinationCache.unmapRegistration(destination, this); + } + return destination; + } + + @Override + public Set getSubscriptionsByDestination(String destination) { + return this.delegate.getSubscriptionsByDestination(destination); + } + + @Override + public Set getDestinations() { + return this.delegate.getDestinations(); + } + + @Override + public String toString() { + return "CachingSessionSubscriptionRegistration [delegate=" + delegate + "]"; + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistration.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistration.java new file mode 100644 index 0000000000..f2de00f045 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistration.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.support; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.springframework.util.Assert; +import org.springframework.web.messaging.SessionSubscriptionRegistration; + + +/** + * A default implementation of SessionSubscriptionRegistration. Uses a map to keep track + * of subscriptions by destination. This implementation assumes that only one thread will + * access and update subscriptions at a time. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class DefaultSessionSubscriptionRegistration implements SessionSubscriptionRegistration { + + private final String sessionId; + + // destination -> subscriptionIds + private final Map> subscriptions = new HashMap>(4); + + + public DefaultSessionSubscriptionRegistration(String sessionId) { + Assert.notNull(sessionId, "sessionId is required"); + this.sessionId = sessionId; + } + + + public String getSessionId() { + return this.sessionId; + } + + @Override + public Set getDestinations() { + return this.subscriptions.keySet(); + } + + @Override + public void addSubscription(String destination, String subscriptionId) { + Assert.hasText(destination, "destination must not be empty"); + Assert.hasText(subscriptionId, "subscriptionId must not be empty"); + Set subs = this.subscriptions.get(destination); + if (subs == null) { + subs = new HashSet(4); + this.subscriptions.put(destination, subs); + } + subs.add(subscriptionId); + } + + @Override + public String removeSubscription(String subscriptionId) { + Assert.hasText(subscriptionId, "subscriptionId must not be empty"); + for (String destination : this.subscriptions.keySet()) { + Set subscriptionIds = this.subscriptions.get(destination); + if (subscriptionIds.remove(subscriptionId)) { + if (subscriptionIds.isEmpty()) { + this.subscriptions.remove(destination); + } + return destination; + } + } + return null; + } + + @Override + public Set getSubscriptionsByDestination(String destination) { + Assert.hasText(destination, "destination must not be empty"); + return this.subscriptions.get(destination); + } + + + @Override + public String toString() { + return "DefaultSessionSubscriptionRegistration [sessionId=" + this.sessionId + + ", subscriptions=" + this.subscriptions + "]"; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistry.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistry.java new file mode 100644 index 0000000000..53b3f38fa4 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistry.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.support; + +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import org.springframework.web.messaging.SessionSubscriptionRegistration; +import org.springframework.web.messaging.SessionSubscriptionRegistry; + + +/** + * A default implementation of SessionSubscriptionRegistry. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class DefaultSessionSubscriptionRegistry implements SessionSubscriptionRegistry { + + // sessionId -> SessionSubscriptionRegistration + private final Map registrations = + new ConcurrentHashMap(); + + + @Override + public SessionSubscriptionRegistration getRegistration(String sessionId) { + return this.registrations.get(sessionId); + } + + @Override + public SessionSubscriptionRegistration getOrCreateRegistration(String sessionId) { + SessionSubscriptionRegistration registration = this.registrations.get(sessionId); + if (registration == null) { + registration = new DefaultSessionSubscriptionRegistration(sessionId); + this.registrations.put(sessionId, registration); + } + return registration; + } + + @Override + public SessionSubscriptionRegistration removeRegistration(String sessionId) { + return this.registrations.remove(sessionId); + } + + @Override + public Set getSessionSubscriptions(String sessionId, String destination) { + SessionSubscriptionRegistration registration = this.registrations.get(sessionId); + return (registration != null) ? registration.getSubscriptionsByDestination(destination) : null; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/MessageHolder.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/MessageHolder.java index 0e2af7f027..5db00abb19 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/support/MessageHolder.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/MessageHolder.java @@ -20,6 +20,8 @@ import org.springframework.core.NamedThreadLocal; import org.springframework.messaging.Message; +// TODO: remove? + /** * @author Rossen Stoyanchev * @since 4.0 diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandlerTests.java new file mode 100644 index 0000000000..396affa785 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/service/SimpleBrokerWebMessageHandlerTests.java @@ -0,0 +1,142 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.service; + +import java.util.Arrays; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.web.messaging.MessageType; +import org.springframework.web.messaging.support.WebMessageHeaderAccesssor; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + + +/** + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class SimpleBrokerWebMessageHandlerTests { + + private AbstractWebMessageHandler messageHandler; + + @Mock + private MessageChannel clientChannel; + + @Captor + ArgumentCaptor> messageCaptor; + + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + this.messageHandler = new SimpleBrokerWebMessageHandler(this.clientChannel); + } + + + @Test + public void getSupportedMessageTypes() { + assertEquals(Arrays.asList(MessageType.MESSAGE, MessageType.SUBSCRIBE, MessageType.UNSUBSCRIBE), + this.messageHandler.getSupportedMessageTypes()); + } + + @Test + public void subcribePublish() { + + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub1", "/foo")); + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub2", "/foo")); + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub3", "/bar")); + + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub1", "/foo")); + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub2", "/foo")); + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub3", "/bar")); + + this.messageHandler.handlePublish(createMessage("/foo", "message1")); + this.messageHandler.handlePublish(createMessage("/bar", "message2")); + + verify(this.clientChannel, times(6)).send(this.messageCaptor.capture()); + assertCapturedMessage(this.messageCaptor.getAllValues().get(0), "sess1", "sub1", "/foo"); + assertCapturedMessage(this.messageCaptor.getAllValues().get(1), "sess1", "sub2", "/foo"); + assertCapturedMessage(this.messageCaptor.getAllValues().get(2), "sess2", "sub1", "/foo"); + assertCapturedMessage(this.messageCaptor.getAllValues().get(3), "sess2", "sub2", "/foo"); + assertCapturedMessage(this.messageCaptor.getAllValues().get(4), "sess1", "sub3", "/bar"); + assertCapturedMessage(this.messageCaptor.getAllValues().get(5), "sess2", "sub3", "/bar"); + } + + @Test + public void subcribeDisconnectPublish() { + + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub1", "/foo")); + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub2", "/foo")); + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub3", "/bar")); + + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub1", "/foo")); + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub2", "/foo")); + this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub3", "/bar")); + + WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.DISCONNECT); + headers.setSessionId("sess1"); + Message message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build(); + this.messageHandler.handleDisconnect(message); + + this.messageHandler.handlePublish(createMessage("/foo", "message1")); + this.messageHandler.handlePublish(createMessage("/bar", "message2")); + + verify(this.clientChannel, times(3)).send(this.messageCaptor.capture()); + assertCapturedMessage(this.messageCaptor.getAllValues().get(0), "sess2", "sub1", "/foo"); + assertCapturedMessage(this.messageCaptor.getAllValues().get(1), "sess2", "sub2", "/foo"); + assertCapturedMessage(this.messageCaptor.getAllValues().get(2), "sess2", "sub3", "/bar"); + } + + + protected Message createSubscriptionMessage(String sessionId, String subcriptionId, String destination) { + + WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.SUBSCRIBE); + headers.setSubscriptionId(subcriptionId); + headers.setDestination(destination); + headers.setSessionId(sessionId); + + return MessageBuilder.withPayload("").copyHeaders(headers.toMap()).build(); + } + + protected Message createMessage(String destination, String payload) { + + WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.MESSAGE); + headers.setDestination(destination); + + return MessageBuilder.withPayload(payload).copyHeaders(headers.toMap()).build(); + } + + protected void assertCapturedMessage(Message message, String sessionId, + String subcriptionId, String destination) { + + WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message); + assertEquals(sessionId, headers.getSessionId()); + assertEquals(subcriptionId, headers.getSubscriptionId()); + assertEquals(destination, headers.getDestination()); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java new file mode 100644 index 0000000000..770a7f40a7 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/support/CachingSessionSubscriptionRegistryTests.java @@ -0,0 +1,75 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.support; + +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.web.messaging.SessionSubscriptionRegistration; +import org.springframework.web.messaging.SessionSubscriptionRegistry; + +import static org.junit.Assert.*; + + +/** + * Test fixture for {@link CachingSessionSubscriptionRegistry}. + * + * @author Rossen Stoyanchev + */ +public class CachingSessionSubscriptionRegistryTests { + + private CachingSessionSubscriptionRegistry registry; + + + @Before + public void setup() { + SessionSubscriptionRegistry delegate = new DefaultSessionSubscriptionRegistry(); + this.registry = new CachingSessionSubscriptionRegistry(delegate); + } + + @Test + public void getRegistrationsByDestination() { + + SessionSubscriptionRegistration reg1 = this.registry.getOrCreateRegistration("sess1"); + reg1.addSubscription("/foo", "sub1"); + reg1.addSubscription("/foo", "sub1"); + + SessionSubscriptionRegistration reg2 = this.registry.getOrCreateRegistration("sess2"); + reg2.addSubscription("/foo", "sub1"); + reg2.addSubscription("/foo", "sub1"); + + Set actual = this.registry.getRegistrationsByDestination("/foo"); + assertEquals(2, actual.size()); + assertTrue(actual.contains(reg1)); + assertTrue(actual.contains(reg2)); + + reg1.removeSubscription("sub1"); + reg1.removeSubscription("sub2"); + + actual = this.registry.getRegistrationsByDestination("/foo"); + assertEquals("Invalid set of registrations " + actual, 1, actual.size()); + assertTrue(actual.contains(reg2)); + + reg2.removeSubscription("sub1"); + reg2.removeSubscription("sub2"); + + actual = this.registry.getRegistrationsByDestination("/foo"); + assertNull("Unexpected registrations " + actual, actual); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistrationTests.java new file mode 100644 index 0000000000..2fc7c481a2 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistrationTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.support; + +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + + +/** + * Test fixture for {@link DefaultSessionSubscriptionRegistration}. + * + * @author Rossen Stoyanchev + */ +public class DefaultSessionSubscriptionRegistrationTests { + + private DefaultSessionSubscriptionRegistration registration; + + + @Before + public void setup() { + this.registration = new DefaultSessionSubscriptionRegistration("sess1"); + } + + @Test + public void addSubscriptions() { + this.registration.addSubscription("/foo", "sub1"); + this.registration.addSubscription("/foo", "sub2"); + this.registration.addSubscription("/bar", "sub3"); + this.registration.addSubscription("/bar", "sub4"); + + assertSet(this.registration.getSubscriptionsByDestination("/foo"), 2, "sub1", "sub2"); + assertSet(this.registration.getSubscriptionsByDestination("/bar"), 2, "sub3", "sub4"); + assertSet(this.registration.getDestinations(), 2, "/foo", "/bar"); + } + + @Test + public void removeSubscriptions() { + this.registration.addSubscription("/foo", "sub1"); + this.registration.addSubscription("/foo", "sub2"); + this.registration.addSubscription("/bar", "sub3"); + this.registration.addSubscription("/bar", "sub4"); + + assertEquals("/foo", this.registration.removeSubscription("sub1")); + assertEquals("/foo", this.registration.removeSubscription("sub2")); + + assertNull(this.registration.getSubscriptionsByDestination("/foo")); + assertSet(this.registration.getDestinations(), 1, "/bar"); + + assertEquals("/bar", this.registration.removeSubscription("sub3")); + assertEquals("/bar", this.registration.removeSubscription("sub4")); + + assertNull(this.registration.getSubscriptionsByDestination("/bar")); + assertSet(this.registration.getDestinations(), 0); + } + + + private void assertSet(Set set, int size, String... elements) { + assertEquals("Wrong number of elements in " + set, size, set.size()); + for (String element : elements) { + assertTrue("Set does not contain element " + element, set.contains(element)); + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistryTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistryTests.java new file mode 100644 index 0000000000..46128d631a --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/support/DefaultSessionSubscriptionRegistryTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.messaging.support; + +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.web.messaging.SessionSubscriptionRegistration; + +import static org.junit.Assert.*; + + +/** + * Test fixture for {@link DefaultSessionSubscriptionRegistry}. + * + * @author Rossen Stoyanchev + */ +public class DefaultSessionSubscriptionRegistryTests { + + private DefaultSessionSubscriptionRegistry registry; + + + @Before + public void setup() { + this.registry = new DefaultSessionSubscriptionRegistry(); + } + + @Test + public void getRegistration() { + String sessionId = "sess1"; + assertNull(this.registry.getRegistration(sessionId)); + + this.registry.getOrCreateRegistration(sessionId); + assertNotNull(this.registry.getRegistration(sessionId)); + assertEquals(sessionId, this.registry.getRegistration(sessionId).getSessionId()); + } + + @Test + public void getOrCreateRegistration() { + String sessionId = "sess1"; + assertNull(this.registry.getRegistration(sessionId)); + + SessionSubscriptionRegistration registration = this.registry.getOrCreateRegistration(sessionId); + assertSame(registration, this.registry.getOrCreateRegistration(sessionId)); + } + + @Test + public void removeRegistration() { + String sessionId = "sess1"; + this.registry.getOrCreateRegistration(sessionId); + assertNotNull(this.registry.getRegistration(sessionId)); + assertEquals(sessionId, this.registry.getRegistration(sessionId).getSessionId()); + + this.registry.removeRegistration(sessionId); + assertNull(this.registry.getRegistration(sessionId)); + } + + @Test + public void getSessionSubscriptions() { + String sessionId = "sess1"; + SessionSubscriptionRegistration registration = this.registry.getOrCreateRegistration(sessionId); + registration.addSubscription("/foo", "sub1"); + registration.addSubscription("/foo", "sub2"); + + Set subscriptions = this.registry.getSessionSubscriptions(sessionId, "/foo"); + assertEquals("Wrong number of subscriptions " + subscriptions, 2, subscriptions.size()); + assertTrue(subscriptions.contains("sub1")); + assertTrue(subscriptions.contains("sub2")); + } + +}