Refactor SubscriptionRegistry
The SubscriptionRegistry and implementations are now in a package together with SimpleBrokerWebMessageHandler and primarily support with matching subscriptions to messages. Subscriptions can contain patterns as supported by AntPathMatcher. StopmWebSocketHandler no longer keeps track of subscriptions and simply ignores messages without a subscription id, since it has no way of knowing broker-specific destination semantics for patterns.
This commit is contained in:
parent
f25ccac1a1
commit
3a2f5e71b7
|
|
@ -1,43 +0,0 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 4.0
|
||||
*/
|
||||
public interface SessionSubscriptionRegistration {
|
||||
|
||||
|
||||
String getSessionId();
|
||||
|
||||
void addSubscription(String destination, String subscriptionId);
|
||||
|
||||
/**
|
||||
* @param subscriptionId the subscription to remove
|
||||
* @return the destination to which the subscriptionId was registered, or {@code null}
|
||||
* if no matching subscriptionId was found
|
||||
*/
|
||||
String removeSubscription(String subscriptionId);
|
||||
|
||||
Set<String> getSubscriptionsByDestination(String destination);
|
||||
|
||||
Set<String> getDestinations();
|
||||
|
||||
}
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
|
||||
/**
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 4.0
|
||||
*/
|
||||
public interface SessionSubscriptionRegistry {
|
||||
|
||||
SessionSubscriptionRegistration getRegistration(String sessionId);
|
||||
|
||||
SessionSubscriptionRegistration getOrCreateRegistration(String sessionId);
|
||||
|
||||
SessionSubscriptionRegistration removeRegistration(String sessionId);
|
||||
|
||||
Set<String> getSessionSubscriptions(String sessionId, String destination);
|
||||
|
||||
Set<SessionSubscriptionRegistration> getRegistrationsByDestination(String destination);
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging.service.broker;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.messaging.MessageType;
|
||||
import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
|
||||
|
||||
|
||||
/**
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 4.0
|
||||
*/
|
||||
public abstract class AbstractSubscriptionRegistry implements SubscriptionRegistry {
|
||||
|
||||
protected final Log logger = LogFactory.getLog(getClass());
|
||||
|
||||
|
||||
@Override
|
||||
public void addSubscription(Message<?> message) {
|
||||
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message);
|
||||
if (!MessageType.SUBSCRIBE.equals(headers.getMessageType())) {
|
||||
logger.error("Expected SUBSCRIBE message: " + message);
|
||||
return;
|
||||
}
|
||||
String sessionId = headers.getSessionId();
|
||||
if (sessionId == null) {
|
||||
logger.error("Ignoring subscription. No sessionId in message: " + message);
|
||||
return;
|
||||
}
|
||||
String subscriptionId = headers.getSubscriptionId();
|
||||
if (subscriptionId == null) {
|
||||
logger.error("Ignoring subscription. No subscriptionId in message: " + message);
|
||||
return;
|
||||
}
|
||||
String destination = headers.getDestination();
|
||||
if (destination == null) {
|
||||
logger.error("Ignoring destination. No destination in message: " + message);
|
||||
return;
|
||||
}
|
||||
addSubscriptionInternal(sessionId, subscriptionId, destination, message);
|
||||
}
|
||||
|
||||
protected abstract void addSubscriptionInternal(String sessionId, String subscriptionId,
|
||||
String destination, Message<?> message);
|
||||
|
||||
@Override
|
||||
public void removeSubscription(Message<?> message) {
|
||||
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message);
|
||||
if (!MessageType.UNSUBSCRIBE.equals(headers.getMessageType())) {
|
||||
logger.error("Expected UNSUBSCRIBE message: " + message);
|
||||
return;
|
||||
}
|
||||
String sessionId = headers.getSessionId();
|
||||
if (sessionId == null) {
|
||||
logger.error("Ignoring subscription. No sessionId in message: " + message);
|
||||
return;
|
||||
}
|
||||
String subscriptionId = headers.getSubscriptionId();
|
||||
if (subscriptionId == null) {
|
||||
logger.error("Ignoring subscription. No subscriptionId in message: " + message);
|
||||
return;
|
||||
}
|
||||
removeSubscriptionInternal(sessionId, subscriptionId, message);
|
||||
}
|
||||
|
||||
protected abstract void removeSubscriptionInternal(String sessionId, String subscriptionId, Message<?> message);
|
||||
|
||||
@Override
|
||||
public void removeSessionSubscriptions(String sessionId) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public MultiValueMap<String, String> findSubscriptions(Message<?> message) {
|
||||
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message);
|
||||
if (!MessageType.MESSAGE.equals(headers.getMessageType())) {
|
||||
logger.error("Unexpected message type: " + message);
|
||||
return null;
|
||||
}
|
||||
String destination = headers.getDestination();
|
||||
if (destination == null) {
|
||||
logger.error("Ignoring destination. No destination in message: " + message);
|
||||
return null;
|
||||
}
|
||||
return findSubscriptionsInternal(destination, message);
|
||||
}
|
||||
|
||||
protected abstract MultiValueMap<String, String> findSubscriptionsInternal(
|
||||
String destination, Message<?> message);
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging.service.broker;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArraySet;
|
||||
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.util.AntPathMatcher;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
|
||||
|
||||
/**
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 4.0
|
||||
*/
|
||||
public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
|
||||
|
||||
private final DestinationCache destinationCache = new DestinationCache();
|
||||
|
||||
private final SessionSubscriptionRegistry subscriptionRegistry = new SessionSubscriptionRegistry();
|
||||
|
||||
private AntPathMatcher pathMatcher = new AntPathMatcher();
|
||||
|
||||
|
||||
/**
|
||||
* @param pathMatcher the pathMatcher to set
|
||||
*/
|
||||
public void setPathMatcher(AntPathMatcher pathMatcher) {
|
||||
this.pathMatcher = pathMatcher;
|
||||
}
|
||||
|
||||
public AntPathMatcher getPathMatcher() {
|
||||
return this.pathMatcher;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void addSubscriptionInternal(String sessionId, String subsId, String destination, Message<?> message) {
|
||||
SessionSubscriptionInfo info = this.subscriptionRegistry.addSubscription(sessionId, subsId, destination);
|
||||
if (!this.pathMatcher.isPattern(destination)) {
|
||||
this.destinationCache.mapToDestination(destination, info);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void removeSubscriptionInternal(String sessionId, String subscriptionId, Message<?> message) {
|
||||
SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId);
|
||||
if (info != null) {
|
||||
String destination = info.removeSubscription(subscriptionId);
|
||||
if (info.getSubscriptions(destination) == null) {
|
||||
this.destinationCache.unmapFromDestination(destination, info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeSessionSubscriptions(String sessionId) {
|
||||
SessionSubscriptionInfo info = this.subscriptionRegistry.removeSubscriptions(sessionId);
|
||||
this.destinationCache.removeSessionSubscriptions(info);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MultiValueMap<String, String> findSubscriptionsInternal(String destination, Message<?> message) {
|
||||
MultiValueMap<String,String> result = this.destinationCache.getSubscriptions(destination);
|
||||
if (result.isEmpty()) {
|
||||
result = new LinkedMultiValueMap<String, String>();
|
||||
for (SessionSubscriptionInfo info : this.subscriptionRegistry.getAllSubscriptions()) {
|
||||
for (String destinationPattern : info.getDestinations()) {
|
||||
if (this.pathMatcher.match(destinationPattern, destination)) {
|
||||
for (String subscriptionId : info.getSubscriptions(destinationPattern)) {
|
||||
result.add(info.sessionId, subscriptionId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Provide direct lookup of session subscriptions by destination (for non-pattern destinations).
|
||||
*/
|
||||
private static class DestinationCache {
|
||||
|
||||
// destination -> ..
|
||||
private final Map<String, Set<SessionSubscriptionInfo>> subscriptionsByDestination =
|
||||
new ConcurrentHashMap<String, Set<SessionSubscriptionInfo>>();
|
||||
|
||||
private final Object monitor = new Object();
|
||||
|
||||
|
||||
public void mapToDestination(String destination, SessionSubscriptionInfo info) {
|
||||
synchronized (monitor) {
|
||||
Set<SessionSubscriptionInfo> registrations = this.subscriptionsByDestination.get(destination);
|
||||
if (registrations == null) {
|
||||
registrations = new CopyOnWriteArraySet<SessionSubscriptionInfo>();
|
||||
this.subscriptionsByDestination.put(destination, registrations);
|
||||
}
|
||||
registrations.add(info);
|
||||
}
|
||||
}
|
||||
|
||||
public void unmapFromDestination(String destination, SessionSubscriptionInfo info) {
|
||||
synchronized (monitor) {
|
||||
Set<SessionSubscriptionInfo> infos = this.subscriptionsByDestination.get(destination);
|
||||
if (infos != null) {
|
||||
infos.remove(info);
|
||||
if (infos.isEmpty()) {
|
||||
this.subscriptionsByDestination.remove(destination);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void removeSessionSubscriptions(SessionSubscriptionInfo info) {
|
||||
for (String destination : info.getDestinations()) {
|
||||
unmapFromDestination(destination, info);
|
||||
}
|
||||
}
|
||||
|
||||
public MultiValueMap<String, String> getSubscriptions(String destination) {
|
||||
MultiValueMap<String, String> result = new LinkedMultiValueMap<String, String>();
|
||||
Set<SessionSubscriptionInfo> infos = this.subscriptionsByDestination.get(destination);
|
||||
if (infos != null) {
|
||||
for (SessionSubscriptionInfo info : infos) {
|
||||
Set<String> subscriptions = info.getSubscriptions(destination);
|
||||
if (subscriptions != null) {
|
||||
for (String subscription : subscriptions) {
|
||||
result.add(info.getSessionId(), subscription);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Provide access to session subscriptions by sessionId.
|
||||
*/
|
||||
private static class SessionSubscriptionRegistry {
|
||||
|
||||
private final Map<String, SessionSubscriptionInfo> sessions =
|
||||
new ConcurrentHashMap<String, SessionSubscriptionInfo>();
|
||||
|
||||
|
||||
public SessionSubscriptionInfo getSubscriptions(String sessionId) {
|
||||
return this.sessions.get(sessionId);
|
||||
}
|
||||
|
||||
public Collection<SessionSubscriptionInfo> getAllSubscriptions() {
|
||||
return this.sessions.values();
|
||||
}
|
||||
|
||||
public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId, String destination) {
|
||||
SessionSubscriptionInfo info = this.sessions.get(sessionId);
|
||||
if (info == null) {
|
||||
info = new SessionSubscriptionInfo(sessionId);
|
||||
this.sessions.put(sessionId, info);
|
||||
}
|
||||
info.addSubscription(subscriptionId, destination);
|
||||
return info;
|
||||
}
|
||||
|
||||
public SessionSubscriptionInfo removeSubscriptions(String sessionId) {
|
||||
return this.sessions.remove(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Hold subscriptions for a session.
|
||||
*/
|
||||
private static class SessionSubscriptionInfo {
|
||||
|
||||
private final String sessionId;
|
||||
|
||||
private final Map<String, Set<String>> subscriptions = new HashMap<String, Set<String>>(4);
|
||||
|
||||
|
||||
public SessionSubscriptionInfo(String sessionId) {
|
||||
this.sessionId = sessionId;
|
||||
}
|
||||
|
||||
public String getSessionId() {
|
||||
return this.sessionId;
|
||||
}
|
||||
|
||||
public Set<String> getDestinations() {
|
||||
return this.subscriptions.keySet();
|
||||
}
|
||||
|
||||
public Set<String> getSubscriptions(String destination) {
|
||||
return this.subscriptions.get(destination);
|
||||
}
|
||||
|
||||
public void addSubscription(String subscriptionId, String destination) {
|
||||
Set<String> subs = this.subscriptions.get(destination);
|
||||
if (subs == null) {
|
||||
subs = new HashSet<String>(4);
|
||||
this.subscriptions.put(destination, subs);
|
||||
}
|
||||
subs.add(subscriptionId);
|
||||
}
|
||||
|
||||
public String removeSubscription(String subscriptionId) {
|
||||
for (String destination : this.subscriptions.keySet()) {
|
||||
Set<String> subscriptionIds = this.subscriptions.get(destination);
|
||||
if (subscriptionIds.remove(subscriptionId)) {
|
||||
if (subscriptionIds.isEmpty()) {
|
||||
this.subscriptions.remove(destination);
|
||||
}
|
||||
return destination;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -14,21 +14,18 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging.service;
|
||||
package org.springframework.web.messaging.service.broker;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Set;
|
||||
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageChannel;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.messaging.MessageType;
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistration;
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistry;
|
||||
import org.springframework.web.messaging.support.CachingSessionSubscriptionRegistry;
|
||||
import org.springframework.web.messaging.support.DefaultSessionSubscriptionRegistry;
|
||||
import org.springframework.web.messaging.service.AbstractWebMessageHandler;
|
||||
import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
|
||||
|
||||
|
||||
|
|
@ -38,23 +35,22 @@ import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
|
|||
*/
|
||||
public class SimpleBrokerWebMessageHandler extends AbstractWebMessageHandler {
|
||||
|
||||
private final MessageChannel clientChannel;
|
||||
private final MessageChannel outboundChannel;
|
||||
|
||||
private SessionSubscriptionRegistry subscriptionRegistry=
|
||||
new CachingSessionSubscriptionRegistry(new DefaultSessionSubscriptionRegistry());
|
||||
private SubscriptionRegistry subscriptionRegistry = new DefaultSubscriptionRegistry();
|
||||
|
||||
|
||||
/**
|
||||
* @param clientChannel the channel to which messages for clients should be sent
|
||||
* @param outboundChannel the channel to which messages for clients should be sent
|
||||
* @param observable an Observable to use to manage subscriptions
|
||||
*/
|
||||
public SimpleBrokerWebMessageHandler(MessageChannel clientChannel) {
|
||||
Assert.notNull(clientChannel, "clientChannel is required");
|
||||
this.clientChannel = clientChannel;
|
||||
public SimpleBrokerWebMessageHandler(MessageChannel outboundChannel) {
|
||||
Assert.notNull(outboundChannel, "outboundChannel is required");
|
||||
this.outboundChannel = outboundChannel;
|
||||
}
|
||||
|
||||
|
||||
public void setSubscriptionRegistry(SessionSubscriptionRegistry subscriptionRegistry) {
|
||||
public void setSubscriptionRegistry(SubscriptionRegistry subscriptionRegistry) {
|
||||
Assert.notNull(subscriptionRegistry, "subscriptionRegistry is required");
|
||||
this.subscriptionRegistry = subscriptionRegistry;
|
||||
}
|
||||
|
|
@ -71,13 +67,16 @@ public class SimpleBrokerWebMessageHandler extends AbstractWebMessageHandler {
|
|||
logger.debug("Subscribe " + message);
|
||||
}
|
||||
|
||||
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message);
|
||||
String sessionId = headers.getSessionId();
|
||||
String subscriptionId = headers.getSubscriptionId();
|
||||
String destination = headers.getDestination();
|
||||
this.subscriptionRegistry.addSubscription(message);
|
||||
|
||||
SessionSubscriptionRegistration registration = this.subscriptionRegistry.getOrCreateRegistration(sessionId);
|
||||
registration.addSubscription(destination, subscriptionId);
|
||||
// TODO: need a way to communicate back if subscription was successfully created or
|
||||
// not in which case an ERROR should be sent back and close the connection
|
||||
// http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void handleUnsubscribe(Message<?> message) {
|
||||
this.subscriptionRegistry.removeSubscription(message);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -89,29 +88,24 @@ public class SimpleBrokerWebMessageHandler extends AbstractWebMessageHandler {
|
|||
|
||||
String destination = WebMessageHeaderAccesssor.wrap(message).getDestination();
|
||||
|
||||
Set<SessionSubscriptionRegistration> registrations =
|
||||
this.subscriptionRegistry.getRegistrationsByDestination(destination);
|
||||
MultiValueMap<String,String> subscriptions = this.subscriptionRegistry.findSubscriptions(message);
|
||||
|
||||
if (registrations == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (SessionSubscriptionRegistration registration : registrations) {
|
||||
for (String subscriptionId : registration.getSubscriptionsByDestination(destination)) {
|
||||
for (String sessionId : subscriptions.keySet()) {
|
||||
for (String subscriptionId : subscriptions.get(sessionId)) {
|
||||
|
||||
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message);
|
||||
headers.setSessionId(registration.getSessionId());
|
||||
headers.setSessionId(sessionId);
|
||||
headers.setSubscriptionId(subscriptionId);
|
||||
|
||||
Message<?> clientMessage = MessageBuilder.withPayload(
|
||||
message.getPayload()).copyHeaders(headers.toMap()).build();
|
||||
|
||||
try {
|
||||
this.clientChannel.send(clientMessage);
|
||||
this.outboundChannel.send(clientMessage);
|
||||
}
|
||||
catch (Throwable ex) {
|
||||
logger.error("Failed to send message to destination=" + destination +
|
||||
", sessionId=" + registration.getSessionId() + ", subscriptionId=" + subscriptionId, ex);
|
||||
", sessionId=" + sessionId + ", subscriptionId=" + subscriptionId, ex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -120,7 +114,7 @@ public class SimpleBrokerWebMessageHandler extends AbstractWebMessageHandler {
|
|||
@Override
|
||||
public void handleDisconnect(Message<?> message) {
|
||||
String sessionId = WebMessageHeaderAccesssor.wrap(message).getSessionId();
|
||||
this.subscriptionRegistry.removeRegistration(sessionId);
|
||||
this.subscriptionRegistry.removeSessionSubscriptions(sessionId);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -14,22 +14,24 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging.support;
|
||||
package org.springframework.web.messaging.service.broker;
|
||||
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistry;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
|
||||
|
||||
/**
|
||||
* Test fixture for {@link DefaultSessionSubscriptionRegistry}.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 4.0
|
||||
*/
|
||||
public class DefaultSessionSubscriptionRegistryTests extends AbstractSessionSubscriptionRegistryTests {
|
||||
public interface SubscriptionRegistry {
|
||||
|
||||
void addSubscription(Message<?> subscribeMessage);
|
||||
|
||||
@Override
|
||||
protected SessionSubscriptionRegistry createSessionSubscriptionRegistry() {
|
||||
return new DefaultSessionSubscriptionRegistry();
|
||||
}
|
||||
void removeSubscription(Message<?> unsubscribeMessage);
|
||||
|
||||
void removeSessionSubscriptions(String sessionId);
|
||||
|
||||
MultiValueMap<String, String> findSubscriptions(Message<?> message);
|
||||
|
||||
}
|
||||
|
|
@ -57,9 +57,9 @@ import org.springframework.web.method.HandlerMethodSelector;
|
|||
public class AnnotationWebMessageHandler extends AbstractWebMessageHandler
|
||||
implements ApplicationContextAware, InitializingBean {
|
||||
|
||||
private final MessageChannel clientChannel;
|
||||
private final MessageChannel inboundChannel;
|
||||
|
||||
private final MessageChannel brokerChannel;
|
||||
private final MessageChannel outboundChannel;
|
||||
|
||||
private List<MessageConverter> messageConverters;
|
||||
|
||||
|
|
@ -79,11 +79,15 @@ public class AnnotationWebMessageHandler extends AbstractWebMessageHandler
|
|||
private ReturnValueHandlerComposite returnValueHandlers = new ReturnValueHandlerComposite();
|
||||
|
||||
|
||||
public AnnotationWebMessageHandler(MessageChannel clientChannel, MessageChannel brokerChannel) {
|
||||
Assert.notNull(clientChannel, "clientChannel is required");
|
||||
Assert.notNull(brokerChannel, "brokerChannel is required");
|
||||
this.clientChannel = clientChannel;
|
||||
this.brokerChannel = brokerChannel;
|
||||
/**
|
||||
* @param inboundChannel a channel for processing incoming messages from clients
|
||||
* @param outboundChannel a channel for messages going out to clients
|
||||
*/
|
||||
public AnnotationWebMessageHandler(MessageChannel inboundChannel, MessageChannel outboundChannel) {
|
||||
Assert.notNull(inboundChannel, "inboundChannel is required");
|
||||
Assert.notNull(outboundChannel, "outboundChannel is required");
|
||||
this.inboundChannel = inboundChannel;
|
||||
this.outboundChannel = outboundChannel;
|
||||
}
|
||||
|
||||
public void setMessageConverters(List<MessageConverter> converters) {
|
||||
|
|
@ -105,11 +109,11 @@ public class AnnotationWebMessageHandler extends AbstractWebMessageHandler
|
|||
|
||||
initHandlerMethods();
|
||||
|
||||
this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(this.brokerChannel));
|
||||
this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(this.inboundChannel));
|
||||
this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverters));
|
||||
|
||||
this.returnValueHandlers.addHandler(new MessageReturnValueHandler(this.clientChannel));
|
||||
this.returnValueHandlers.addHandler(new PayloadReturnValueHandler(this.clientChannel));
|
||||
this.returnValueHandlers.addHandler(new MessageReturnValueHandler(this.outboundChannel));
|
||||
this.returnValueHandlers.addHandler(new PayloadReturnValueHandler(this.outboundChannel));
|
||||
}
|
||||
|
||||
protected void initHandlerMethods() {
|
||||
|
|
|
|||
|
|
@ -28,12 +28,12 @@ import org.springframework.util.Assert;
|
|||
*/
|
||||
public class MessageChannelArgumentResolver implements ArgumentResolver {
|
||||
|
||||
private MessageChannel messageBrokerChannel;
|
||||
private MessageChannel inboundChannel;
|
||||
|
||||
|
||||
public MessageChannelArgumentResolver(MessageChannel messageBrokerChannel) {
|
||||
Assert.notNull(messageBrokerChannel, "messageBrokerChannel is required");
|
||||
this.messageBrokerChannel = messageBrokerChannel;
|
||||
public MessageChannelArgumentResolver(MessageChannel inboundChannel) {
|
||||
Assert.notNull(inboundChannel, "inboundChannel is required");
|
||||
this.inboundChannel = inboundChannel;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -43,7 +43,7 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
|
|||
|
||||
@Override
|
||||
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
|
||||
return this.messageBrokerChannel;
|
||||
return this.inboundChannel;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,12 +30,12 @@ import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
|
|||
*/
|
||||
public class MessageReturnValueHandler implements ReturnValueHandler {
|
||||
|
||||
private MessageChannel clientChannel;
|
||||
private MessageChannel outboundChannel;
|
||||
|
||||
|
||||
public MessageReturnValueHandler(MessageChannel clientChannel) {
|
||||
Assert.notNull(clientChannel, "clientChannel is required");
|
||||
this.clientChannel = clientChannel;
|
||||
public MessageReturnValueHandler(MessageChannel outboundChannel) {
|
||||
Assert.notNull(outboundChannel, "outboundChannel is required");
|
||||
this.outboundChannel = outboundChannel;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -49,7 +49,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
|
|||
public void handleReturnValue(Object returnValue, MethodParameter returnType, Message<?> message)
|
||||
throws Exception {
|
||||
|
||||
Assert.notNull(this.clientChannel, "No clientChannel to send messages to");
|
||||
Assert.notNull(this.outboundChannel, "No clientChannel to send messages to");
|
||||
|
||||
Message<?> returnMessage = (Message<?>) returnValue;
|
||||
if (message == null) {
|
||||
|
|
@ -70,7 +70,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
|
|||
returnMessage = MessageBuilder.withPayload(
|
||||
returnMessage.getPayload()).copyHeaders(returnHeaders.toMap()).build();
|
||||
|
||||
this.clientChannel.send(returnMessage);
|
||||
this.outboundChannel.send(returnMessage);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,12 +30,12 @@ import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
|
|||
*/
|
||||
public class PayloadReturnValueHandler implements ReturnValueHandler {
|
||||
|
||||
private MessageChannel clientChannel;
|
||||
private MessageChannel outboundChannel;
|
||||
|
||||
|
||||
public PayloadReturnValueHandler(MessageChannel clientChannel) {
|
||||
Assert.notNull(clientChannel, "clientChannel is required");
|
||||
this.clientChannel = clientChannel;
|
||||
public PayloadReturnValueHandler(MessageChannel outboundChannel) {
|
||||
Assert.notNull(outboundChannel, "outboundChannel is required");
|
||||
this.outboundChannel = outboundChannel;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -47,7 +47,7 @@ public class PayloadReturnValueHandler implements ReturnValueHandler {
|
|||
public void handleReturnValue(Object returnValue, MethodParameter returnType, Message<?> message)
|
||||
throws Exception {
|
||||
|
||||
Assert.notNull(this.clientChannel, "No clientChannel to send messages to");
|
||||
Assert.notNull(this.outboundChannel, "No outboundChannel to send messages to");
|
||||
|
||||
if (returnValue == null) {
|
||||
return;
|
||||
|
|
@ -63,7 +63,7 @@ public class PayloadReturnValueHandler implements ReturnValueHandler {
|
|||
Message<?> returnMessage = MessageBuilder.withPayload(
|
||||
returnValue).copyHeaders(returnHeaders.toMap()).build();
|
||||
|
||||
this.clientChannel.send(returnMessage);
|
||||
this.outboundChannel.send(returnMessage);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ public class StompHeaderAccessor extends WebMessageHeaderAccesssor {
|
|||
if (contentType != null) {
|
||||
super.setContentType(MediaType.parseMediaType(contentType));
|
||||
}
|
||||
if (StompCommand.SUBSCRIBE.equals(getStompCommand())) {
|
||||
if (StompCommand.SUBSCRIBE.equals(getStompCommand()) || StompCommand.UNSUBSCRIBE.equals(getStompCommand())) {
|
||||
if (getFirstNativeHeader(STOMP_ID) != null) {
|
||||
super.setSubscriptionId(getFirstNativeHeader(STOMP_ID));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ public class StompRelayWebMessageHandler extends AbstractWebMessageHandler imple
|
|||
private static final String STOMP_RELAY_SYSTEM_SESSION_ID = "stompRelaySystemSessionId";
|
||||
|
||||
|
||||
private MessageChannel clientChannel;
|
||||
private MessageChannel outboundChannel;
|
||||
|
||||
private String relayHost = "127.0.0.1";
|
||||
|
||||
|
|
@ -85,11 +85,11 @@ public class StompRelayWebMessageHandler extends AbstractWebMessageHandler imple
|
|||
|
||||
|
||||
/**
|
||||
* @param clientChannel the channel to which messages for clients should be sent.
|
||||
* @param outboundChannel a channel for messages going out to clients
|
||||
*/
|
||||
public StompRelayWebMessageHandler(MessageChannel clientChannel) {
|
||||
Assert.notNull(clientChannel, "clientChannel is required");
|
||||
this.clientChannel = clientChannel;
|
||||
public StompRelayWebMessageHandler(MessageChannel outboundChannel) {
|
||||
Assert.notNull(outboundChannel, "outboundChannel is required");
|
||||
this.outboundChannel = outboundChannel;
|
||||
this.payloadConverter = new CompositeMessageConverter(null);
|
||||
}
|
||||
|
||||
|
|
@ -387,7 +387,7 @@ public class StompRelayWebMessageHandler extends AbstractWebMessageHandler imple
|
|||
}
|
||||
|
||||
protected void sendMessageToClient(Message<?> message) {
|
||||
clientChannel.send(message);
|
||||
outboundChannel.send(message);
|
||||
}
|
||||
|
||||
private void sendError(String sessionId, String errorText) {
|
||||
|
|
|
|||
|
|
@ -29,15 +29,11 @@ import org.springframework.messaging.Message;
|
|||
import org.springframework.messaging.MessageChannel;
|
||||
import org.springframework.messaging.MessageHandler;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.web.messaging.MessageType;
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistration;
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistry;
|
||||
import org.springframework.web.messaging.converter.CompositeMessageConverter;
|
||||
import org.springframework.web.messaging.converter.MessageConverter;
|
||||
import org.springframework.web.messaging.stomp.StompCommand;
|
||||
import org.springframework.web.messaging.stomp.StompConversionException;
|
||||
import org.springframework.web.messaging.support.DefaultSessionSubscriptionRegistry;
|
||||
import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
|
||||
import org.springframework.web.socket.CloseStatus;
|
||||
import org.springframework.web.socket.TextMessage;
|
||||
|
|
@ -63,8 +59,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
|
|||
|
||||
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
|
||||
|
||||
private SessionSubscriptionRegistry subscriptionRegistry = new DefaultSessionSubscriptionRegistry();
|
||||
|
||||
private MessageConverter payloadConverter = new CompositeMessageConverter(null);
|
||||
|
||||
|
||||
|
|
@ -86,10 +80,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
|
|||
return this.stompMessageConverter;
|
||||
}
|
||||
|
||||
public void setSubscriptionRegistry(SessionSubscriptionRegistry subscriptionRegistry) {
|
||||
this.subscriptionRegistry = subscriptionRegistry;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
|
||||
|
|
@ -179,35 +169,12 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
|
|||
}
|
||||
|
||||
protected void handleSubscribe(Message<?> message) {
|
||||
|
||||
// TODO: need a way to communicate back if subscription was successfully created or
|
||||
// not in which case an ERROR should be sent back and close the connection
|
||||
// http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE
|
||||
|
||||
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
|
||||
String sessionId = headers.getSessionId();
|
||||
String destination = headers.getDestination();
|
||||
|
||||
SessionSubscriptionRegistration registration = this.subscriptionRegistry.getOrCreateRegistration(sessionId);
|
||||
registration.addSubscription(destination, headers.getSubscriptionId());
|
||||
}
|
||||
|
||||
protected void handleUnsubscribe(Message<?> message) {
|
||||
|
||||
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
|
||||
String sessionId = headers.getSessionId();
|
||||
String subscriptionId = headers.getSubscriptionId();
|
||||
|
||||
SessionSubscriptionRegistration registration = this.subscriptionRegistry.getRegistration(sessionId);
|
||||
if (registration == null) {
|
||||
logger.warn("Subscripton=" + subscriptionId + " for session=" + sessionId + " not found");
|
||||
return;
|
||||
}
|
||||
registration.removeSubscription(subscriptionId);
|
||||
}
|
||||
|
||||
protected void handleDisconnect(Message<?> message) {
|
||||
|
||||
}
|
||||
|
||||
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
|
||||
|
|
@ -230,7 +197,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
|
|||
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
|
||||
|
||||
this.sessions.remove(session.getId());
|
||||
this.subscriptionRegistry.removeRegistration(session.getId());
|
||||
|
||||
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.DISCONNECT);
|
||||
headers.setSessionId(session.getId());
|
||||
|
|
@ -254,25 +220,22 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
|
|||
|
||||
String sessionId = headers.getSessionId();
|
||||
if (sessionId == null) {
|
||||
// TODO: failed message delivery mechanism
|
||||
logger.error("No \"sessionId\" header in message: " + message);
|
||||
return;
|
||||
}
|
||||
|
||||
WebSocketSession session = this.sessions.get(sessionId);
|
||||
if (session == null) {
|
||||
logger.error("Session not found: " + message);
|
||||
// TODO: failed message delivery mechanism
|
||||
logger.error("WebSocketSession not found for sessionId=" + sessionId);
|
||||
return;
|
||||
}
|
||||
|
||||
if (headers.getSubscriptionId() == null) {
|
||||
String destination = headers.getDestination();
|
||||
Set<String> subs = this.subscriptionRegistry.getSessionSubscriptions(sessionId, destination);
|
||||
if (!CollectionUtils.isEmpty(subs)) {
|
||||
// TODO: send to all subscriptions ids
|
||||
headers.setSubscriptionId(subs.iterator().next());
|
||||
}
|
||||
else {
|
||||
logger.error("No subscription id: " + message);
|
||||
return;
|
||||
}
|
||||
// TODO: failed message delivery mechanism
|
||||
logger.error("No subscription id: " + message);
|
||||
return;
|
||||
}
|
||||
|
||||
byte[] payload;
|
||||
|
|
|
|||
|
|
@ -1,190 +0,0 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging.support;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArraySet;
|
||||
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistration;
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistry;
|
||||
|
||||
|
||||
/**
|
||||
* A decorator for a {@link SessionSubscriptionRegistry} that intercepts subscriptions
|
||||
* being added and removed and maintains a lookup cache of registrations by destination.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 4.0
|
||||
*/
|
||||
public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRegistry {
|
||||
|
||||
private final SessionSubscriptionRegistry delegate;
|
||||
|
||||
private final DestinationCache destinationCache = new DestinationCache();
|
||||
|
||||
|
||||
public CachingSessionSubscriptionRegistry(SessionSubscriptionRegistry delegate) {
|
||||
Assert.notNull(delegate, "delegate SessionSubscriptionRegistry is required");
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public SessionSubscriptionRegistration getRegistration(String sessionId) {
|
||||
SessionSubscriptionRegistration reg = this.delegate.getRegistration(sessionId);
|
||||
return (reg != null) ? new CachingSessionSubscriptionRegistration(reg) : null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SessionSubscriptionRegistration getOrCreateRegistration(String sessionId) {
|
||||
return new CachingSessionSubscriptionRegistration(this.delegate.getOrCreateRegistration(sessionId));
|
||||
}
|
||||
|
||||
@Override
|
||||
public SessionSubscriptionRegistration removeRegistration(String sessionId) {
|
||||
SessionSubscriptionRegistration registration = this.delegate.removeRegistration(sessionId);
|
||||
if (registration != null) {
|
||||
this.destinationCache.removeRegistration(registration);
|
||||
}
|
||||
return registration;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getSessionSubscriptions(String sessionId, String destination) {
|
||||
return this.delegate.getSessionSubscriptions(sessionId, destination);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<SessionSubscriptionRegistration> getRegistrationsByDestination(String destination) {
|
||||
return this.destinationCache.getRegistrations(destination);
|
||||
}
|
||||
|
||||
|
||||
private static class DestinationCache {
|
||||
|
||||
private final Map<String, Set<SessionSubscriptionRegistration>> cache =
|
||||
new ConcurrentHashMap<String, Set<SessionSubscriptionRegistration>>();
|
||||
|
||||
private final Object monitor = new Object();
|
||||
|
||||
|
||||
public void mapRegistration(String destination, SessionSubscriptionRegistration registration) {
|
||||
synchronized (monitor) {
|
||||
Set<SessionSubscriptionRegistration> registrations = this.cache.get(destination);
|
||||
if (registrations == null) {
|
||||
registrations = new CopyOnWriteArraySet<SessionSubscriptionRegistration>();
|
||||
this.cache.put(destination, registrations);
|
||||
}
|
||||
registrations.add(registration);
|
||||
}
|
||||
}
|
||||
|
||||
public void unmapRegistration(String destination, SessionSubscriptionRegistration registration) {
|
||||
synchronized (monitor) {
|
||||
Set<SessionSubscriptionRegistration> registrations = this.cache.get(destination);
|
||||
if (registrations != null) {
|
||||
registrations.remove(registration);
|
||||
if (registrations.isEmpty()) {
|
||||
this.cache.remove(destination);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void removeRegistration(SessionSubscriptionRegistration registration) {
|
||||
for (String destination : registration.getDestinations()) {
|
||||
unmapRegistration(destination, registration);
|
||||
}
|
||||
}
|
||||
|
||||
public Set<SessionSubscriptionRegistration> getRegistrations(String destination) {
|
||||
return this.cache.get(destination);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "DestinationCache [cache=" + this.cache + "]";
|
||||
}
|
||||
}
|
||||
|
||||
private class CachingSessionSubscriptionRegistration implements SessionSubscriptionRegistration {
|
||||
|
||||
private final SessionSubscriptionRegistration delegate;
|
||||
|
||||
|
||||
public CachingSessionSubscriptionRegistration(SessionSubscriptionRegistration delegate) {
|
||||
Assert.notNull(delegate, "delegate SessionSubscriptionRegistration is required");
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getSessionId() {
|
||||
return this.delegate.getSessionId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addSubscription(String destination, String subscriptionId) {
|
||||
destinationCache.mapRegistration(destination, this);
|
||||
this.delegate.addSubscription(destination, subscriptionId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String removeSubscription(String subscriptionId) {
|
||||
String destination = this.delegate.removeSubscription(subscriptionId);
|
||||
if (destination != null && this.delegate.getSubscriptionsByDestination(destination) == null) {
|
||||
destinationCache.unmapRegistration(destination, this);
|
||||
}
|
||||
return destination;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getSubscriptionsByDestination(String destination) {
|
||||
return this.delegate.getSubscriptionsByDestination(destination);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getDestinations() {
|
||||
return this.delegate.getDestinations();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object other) {
|
||||
if (this == other) {
|
||||
return true;
|
||||
}
|
||||
if (!(other instanceof CachingSessionSubscriptionRegistration)) {
|
||||
return false;
|
||||
}
|
||||
CachingSessionSubscriptionRegistration otherType = (CachingSessionSubscriptionRegistration) other;
|
||||
return this.delegate.equals(otherType.delegate);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return this.delegate.hashCode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "CachingSessionSubscriptionRegistration [delegate=" + delegate + "]";
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -1,115 +0,0 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging.support;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistration;
|
||||
|
||||
|
||||
/**
|
||||
* A default implementation of SessionSubscriptionRegistration. Uses a map to keep track
|
||||
* of subscriptions by destination. This implementation assumes that only one thread will
|
||||
* access and update subscriptions at a time.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 4.0
|
||||
*/
|
||||
public class DefaultSessionSubscriptionRegistration implements SessionSubscriptionRegistration {
|
||||
|
||||
private final String sessionId;
|
||||
|
||||
// destination -> subscriptionIds
|
||||
private final Map<String, Set<String>> subscriptions = new HashMap<String, Set<String>>(4);
|
||||
|
||||
|
||||
public DefaultSessionSubscriptionRegistration(String sessionId) {
|
||||
Assert.notNull(sessionId, "sessionId is required");
|
||||
this.sessionId = sessionId;
|
||||
}
|
||||
|
||||
|
||||
public String getSessionId() {
|
||||
return this.sessionId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getDestinations() {
|
||||
return this.subscriptions.keySet();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addSubscription(String destination, String subscriptionId) {
|
||||
Assert.hasText(destination, "destination must not be empty");
|
||||
Assert.hasText(subscriptionId, "subscriptionId must not be empty");
|
||||
Set<String> subs = this.subscriptions.get(destination);
|
||||
if (subs == null) {
|
||||
subs = new HashSet<String>(4);
|
||||
this.subscriptions.put(destination, subs);
|
||||
}
|
||||
subs.add(subscriptionId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String removeSubscription(String subscriptionId) {
|
||||
Assert.hasText(subscriptionId, "subscriptionId must not be empty");
|
||||
for (String destination : this.subscriptions.keySet()) {
|
||||
Set<String> subscriptionIds = this.subscriptions.get(destination);
|
||||
if (subscriptionIds.remove(subscriptionId)) {
|
||||
if (subscriptionIds.isEmpty()) {
|
||||
this.subscriptions.remove(destination);
|
||||
}
|
||||
return destination;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getSubscriptionsByDestination(String destination) {
|
||||
Assert.hasText(destination, "destination must not be empty");
|
||||
return this.subscriptions.get(destination);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object other) {
|
||||
if (this == other) {
|
||||
return true;
|
||||
}
|
||||
if (!(other instanceof DefaultSessionSubscriptionRegistration)) {
|
||||
return false;
|
||||
}
|
||||
DefaultSessionSubscriptionRegistration otherType = (DefaultSessionSubscriptionRegistration) other;
|
||||
return this.sessionId.equals(otherType.sessionId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 31 + this.sessionId.hashCode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "DefaultSessionSubscriptionRegistration [sessionId=" + this.sessionId
|
||||
+ ", subscriptions=" + this.subscriptions + "]";
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging.support;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistration;
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistry;
|
||||
|
||||
|
||||
/**
|
||||
* A default implementation of SessionSubscriptionRegistry.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 4.0
|
||||
*/
|
||||
public class DefaultSessionSubscriptionRegistry implements SessionSubscriptionRegistry {
|
||||
|
||||
// sessionId -> SessionSubscriptionRegistration
|
||||
private final Map<String, SessionSubscriptionRegistration> registrations =
|
||||
new ConcurrentHashMap<String, SessionSubscriptionRegistration>();
|
||||
|
||||
|
||||
@Override
|
||||
public SessionSubscriptionRegistration getRegistration(String sessionId) {
|
||||
return this.registrations.get(sessionId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SessionSubscriptionRegistration getOrCreateRegistration(String sessionId) {
|
||||
SessionSubscriptionRegistration registration = this.registrations.get(sessionId);
|
||||
if (registration == null) {
|
||||
registration = new DefaultSessionSubscriptionRegistration(sessionId);
|
||||
this.registrations.put(sessionId, registration);
|
||||
}
|
||||
return registration;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SessionSubscriptionRegistration removeRegistration(String sessionId) {
|
||||
return this.registrations.remove(sessionId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getSessionSubscriptions(String sessionId, String destination) {
|
||||
SessionSubscriptionRegistration registration = this.registrations.get(sessionId);
|
||||
return (registration != null) ? registration.getSubscriptionsByDestination(destination) : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* The default implementation performs a lookup by destination on each registration.
|
||||
* For a more efficient algorithm consider decorating an instance of this class with
|
||||
* {@link CachingSessionSubscriptionRegistry}.
|
||||
*/
|
||||
@Override
|
||||
public Set<SessionSubscriptionRegistration> getRegistrationsByDestination(String destination) {
|
||||
Set<SessionSubscriptionRegistration> result = new HashSet<SessionSubscriptionRegistration>();
|
||||
for (SessionSubscriptionRegistration r : this.registrations.values()) {
|
||||
if (r.getSubscriptionsByDestination(destination) != null) {
|
||||
result.add(r);
|
||||
}
|
||||
}
|
||||
return result.isEmpty() ? null : result;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -28,6 +28,7 @@ import org.springframework.messaging.Message;
|
|||
import org.springframework.messaging.MessageChannel;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.web.messaging.MessageType;
|
||||
import org.springframework.web.messaging.service.broker.SimpleBrokerWebMessageHandler;
|
||||
import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
|
@ -89,29 +90,29 @@ public class SimpleBrokerWebMessageHandlerTests {
|
|||
@Test
|
||||
public void subcribeDisconnectPublish() {
|
||||
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub1", "/foo"));
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub2", "/foo"));
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub3", "/bar"));
|
||||
String sess1 = "sess1";
|
||||
String sess2 = "sess2";
|
||||
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub1", "/foo"));
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub2", "/foo"));
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub3", "/bar"));
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub1", "/foo"));
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub2", "/foo"));
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub3", "/bar"));
|
||||
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub1", "/foo"));
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub2", "/foo"));
|
||||
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub3", "/bar"));
|
||||
|
||||
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.DISCONNECT);
|
||||
headers.setSessionId("sess1");
|
||||
headers.setSessionId(sess1);
|
||||
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build();
|
||||
this.messageHandler.handleDisconnect(message);
|
||||
|
||||
this.messageHandler.handlePublish(createMessage("/foo", "message1"));
|
||||
this.messageHandler.handlePublish(createMessage("/bar", "message2"));
|
||||
|
||||
verify(this.clientChannel, times(6)).send(this.messageCaptor.capture());
|
||||
assertCapturedMessage("sess1", "sub1", "/foo");
|
||||
assertCapturedMessage("sess1", "sub2", "/foo");
|
||||
assertCapturedMessage("sess2", "sub1", "/foo");
|
||||
assertCapturedMessage("sess2", "sub2", "/foo");
|
||||
assertCapturedMessage("sess1", "sub3", "/bar");
|
||||
assertCapturedMessage("sess2", "sub3", "/bar");
|
||||
verify(this.clientChannel, times(3)).send(this.messageCaptor.capture());
|
||||
assertCapturedMessage(sess2, "sub1", "/foo");
|
||||
assertCapturedMessage(sess2, "sub2", "/foo");
|
||||
assertCapturedMessage(sess2, "sub3", "/bar");
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,242 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging.service.broker;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.messaging.MessageType;
|
||||
import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
|
||||
/**
|
||||
* Test fixture for {@link DefaultSubscriptionRegistry}.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
*/
|
||||
public class DefaultSubscriptionRegistryTests {
|
||||
|
||||
|
||||
private DefaultSubscriptionRegistry registry;
|
||||
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
this.registry = new DefaultSubscriptionRegistry();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void addSubscriptionInvalidInput() {
|
||||
|
||||
String sessId = "sess01";
|
||||
String subsId = "subs01";
|
||||
String dest = "/foo";
|
||||
|
||||
this.registry.addSubscription(subscribeMessage(null, subsId, dest));
|
||||
assertEquals(0, this.registry.findSubscriptions(message(dest)).size());
|
||||
|
||||
this.registry.addSubscription(subscribeMessage(sessId, null, dest));
|
||||
assertEquals(0, this.registry.findSubscriptions(message(dest)).size());
|
||||
|
||||
this.registry.addSubscription(subscribeMessage(sessId, subsId, null));
|
||||
assertEquals(0, this.registry.findSubscriptions(message(dest)).size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void addSubscription() {
|
||||
|
||||
String sessId = "sess01";
|
||||
String subsId = "subs01";
|
||||
String dest = "/foo";
|
||||
|
||||
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
|
||||
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
|
||||
|
||||
assertEquals("Expected one element " + actual, 1, actual.size());
|
||||
assertEquals(Arrays.asList(subsId), actual.get(sessId));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void addSubscriptionOneSession() {
|
||||
|
||||
String sessId = "sess01";
|
||||
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
|
||||
String dest = "/foo";
|
||||
|
||||
for (String subId : subscriptionIds) {
|
||||
this.registry.addSubscription(subscribeMessage(sessId, subId, dest));
|
||||
}
|
||||
|
||||
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
|
||||
|
||||
assertEquals("Expected one element " + actual, 1, actual.size());
|
||||
assertEquals(subscriptionIds, sort(actual.get(sessId)));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void addSubscriptionMultipleSessions() {
|
||||
|
||||
List<String> sessIds = Arrays.asList("sess01", "sess02", "sess03");
|
||||
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
|
||||
String dest = "/foo";
|
||||
|
||||
for (String sessId : sessIds) {
|
||||
for (String subsId : subscriptionIds) {
|
||||
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
|
||||
}
|
||||
}
|
||||
|
||||
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
|
||||
|
||||
assertEquals("Expected three elements " + actual, 3, actual.size());
|
||||
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(0))));
|
||||
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(1))));
|
||||
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(2))));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void addSubscriptionWithDestinationPattern() {
|
||||
|
||||
String sessId = "sess01";
|
||||
String subsId = "subs01";
|
||||
String destPattern = "/topic/PRICE.STOCK.*.IBM";
|
||||
String dest = "/topic/PRICE.STOCK.NASDAQ.IBM";
|
||||
|
||||
this.registry.addSubscription(subscribeMessage(sessId, subsId, destPattern));
|
||||
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
|
||||
|
||||
assertEquals("Expected one element " + actual, 1, actual.size());
|
||||
assertEquals(Arrays.asList(subsId), actual.get(sessId));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void addSubscriptionWithDestinationPatternRegex() {
|
||||
|
||||
String sessId = "sess01";
|
||||
String subsId = "subs01";
|
||||
String destPattern = "/topic/PRICE.STOCK.*.{ticker:(IBM|MSFT)}";
|
||||
|
||||
this.registry.addSubscription(subscribeMessage(sessId, subsId, destPattern));
|
||||
Message<?> message = message("/topic/PRICE.STOCK.NASDAQ.IBM");
|
||||
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message);
|
||||
|
||||
assertEquals("Expected one element " + actual, 1, actual.size());
|
||||
assertEquals(Arrays.asList(subsId), actual.get(sessId));
|
||||
|
||||
message = message("/topic/PRICE.STOCK.NASDAQ.MSFT");
|
||||
actual = this.registry.findSubscriptions(message);
|
||||
|
||||
assertEquals("Expected one element " + actual, 1, actual.size());
|
||||
assertEquals(Arrays.asList(subsId), actual.get(sessId));
|
||||
|
||||
message = message("/topic/PRICE.STOCK.NASDAQ.VMW");
|
||||
actual = this.registry.findSubscriptions(message);
|
||||
|
||||
assertEquals("Expected no elements " + actual, 0, actual.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void removeSubscription() {
|
||||
|
||||
List<String> sessIds = Arrays.asList("sess01", "sess02", "sess03");
|
||||
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
|
||||
String dest = "/foo";
|
||||
|
||||
for (String sessId : sessIds) {
|
||||
for (String subsId : subscriptionIds) {
|
||||
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
|
||||
}
|
||||
}
|
||||
|
||||
this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(0)));
|
||||
this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(1)));
|
||||
this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(2)));
|
||||
|
||||
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
|
||||
|
||||
assertEquals("Expected three elements " + actual, 2, actual.size());
|
||||
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(1))));
|
||||
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(2))));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void removeSessionSubscriptions() {
|
||||
|
||||
List<String> sessIds = Arrays.asList("sess01", "sess02", "sess03");
|
||||
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
|
||||
String dest = "/foo";
|
||||
|
||||
for (String sessId : sessIds) {
|
||||
for (String subsId : subscriptionIds) {
|
||||
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
|
||||
}
|
||||
}
|
||||
|
||||
this.registry.removeSessionSubscriptions(sessIds.get(0));
|
||||
this.registry.removeSessionSubscriptions(sessIds.get(1));
|
||||
|
||||
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
|
||||
|
||||
assertEquals("Expected three elements " + actual, 1, actual.size());
|
||||
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(2))));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void findSubscriptionsNoMatches() {
|
||||
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message("/foo"));
|
||||
assertEquals("Expected no elements " + actual, 0, actual.size());
|
||||
}
|
||||
|
||||
|
||||
private Message<?> subscribeMessage(String sessionId, String subscriptionId, String destination) {
|
||||
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.SUBSCRIBE);
|
||||
headers.setSessionId(sessionId);
|
||||
headers.setSubscriptionId(subscriptionId);
|
||||
if (destination != null) {
|
||||
headers.setDestination(destination);
|
||||
}
|
||||
return MessageBuilder.withPayload("").copyHeaders(headers.toMap()).build();
|
||||
}
|
||||
|
||||
private Message<?> unsubscribeMessage(String sessionId, String subscriptionId) {
|
||||
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.UNSUBSCRIBE);
|
||||
headers.setSessionId(sessionId);
|
||||
headers.setSubscriptionId(subscriptionId);
|
||||
return MessageBuilder.withPayload("").copyHeaders(headers.toMap()).build();
|
||||
}
|
||||
|
||||
private Message<?> message(String destination) {
|
||||
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create();
|
||||
headers.setDestination(destination);
|
||||
return MessageBuilder.withPayload("").copyHeaders(headers.toMap()).build();
|
||||
}
|
||||
|
||||
private List<String> sort(List<String> list) {
|
||||
Collections.sort(list);
|
||||
return list;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -1,116 +0,0 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging.support;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistration;
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistry;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
|
||||
/**
|
||||
* A test fixture for {@link AbstractSessionSubscriptionRegistry}.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
*/
|
||||
public abstract class AbstractSessionSubscriptionRegistryTests {
|
||||
|
||||
protected SessionSubscriptionRegistry registry;
|
||||
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
this.registry = createSessionSubscriptionRegistry();
|
||||
}
|
||||
|
||||
protected abstract SessionSubscriptionRegistry createSessionSubscriptionRegistry();
|
||||
|
||||
|
||||
@Test
|
||||
public void getRegistration() {
|
||||
String sessionId = "sess1";
|
||||
assertNull(this.registry.getRegistration(sessionId));
|
||||
|
||||
this.registry.getOrCreateRegistration(sessionId);
|
||||
assertNotNull(this.registry.getRegistration(sessionId));
|
||||
assertEquals(sessionId, this.registry.getRegistration(sessionId).getSessionId());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getOrCreateRegistration() {
|
||||
String sessionId = "sess1";
|
||||
assertNull(this.registry.getRegistration(sessionId));
|
||||
|
||||
SessionSubscriptionRegistration registration = this.registry.getOrCreateRegistration(sessionId);
|
||||
assertEquals(registration, this.registry.getOrCreateRegistration(sessionId));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void removeRegistration() {
|
||||
String sessionId = "sess1";
|
||||
this.registry.getOrCreateRegistration(sessionId);
|
||||
assertNotNull(this.registry.getRegistration(sessionId));
|
||||
assertEquals(sessionId, this.registry.getRegistration(sessionId).getSessionId());
|
||||
|
||||
this.registry.removeRegistration(sessionId);
|
||||
assertNull(this.registry.getRegistration(sessionId));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getSessionSubscriptions() {
|
||||
String sessionId = "sess1";
|
||||
SessionSubscriptionRegistration registration = this.registry.getOrCreateRegistration(sessionId);
|
||||
registration.addSubscription("/foo", "sub1");
|
||||
registration.addSubscription("/foo", "sub2");
|
||||
|
||||
Set<String> subscriptions = this.registry.getSessionSubscriptions(sessionId, "/foo");
|
||||
assertEquals("Wrong number of subscriptions " + subscriptions, 2, subscriptions.size());
|
||||
assertTrue(subscriptions.contains("sub1"));
|
||||
assertTrue(subscriptions.contains("sub2"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getRegistrationsByDestination() {
|
||||
|
||||
SessionSubscriptionRegistration reg1 = this.registry.getOrCreateRegistration("sess1");
|
||||
reg1.addSubscription("/foo", "sub1");
|
||||
|
||||
SessionSubscriptionRegistration reg2 = this.registry.getOrCreateRegistration("sess2");
|
||||
reg2.addSubscription("/foo", "sub1");
|
||||
|
||||
Set<SessionSubscriptionRegistration> actual = this.registry.getRegistrationsByDestination("/foo");
|
||||
assertEquals(2, actual.size());
|
||||
assertTrue(actual.contains(reg1));
|
||||
assertTrue(actual.contains(reg2));
|
||||
|
||||
reg1.removeSubscription("sub1");
|
||||
|
||||
actual = this.registry.getRegistrationsByDestination("/foo");
|
||||
assertEquals("Invalid set of registrations " + actual, 1, actual.size());
|
||||
assertTrue(actual.contains(reg2));
|
||||
|
||||
reg2.removeSubscription("sub1");
|
||||
|
||||
actual = this.registry.getRegistrationsByDestination("/foo");
|
||||
assertNull("Unexpected registrations " + actual, actual);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging.support;
|
||||
|
||||
import org.springframework.web.messaging.SessionSubscriptionRegistry;
|
||||
|
||||
|
||||
/**
|
||||
* Test fixture for {@link CachingSessionSubscriptionRegistry}.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
*/
|
||||
public class CachingSessionSubscriptionRegistryTests extends AbstractSessionSubscriptionRegistryTests {
|
||||
|
||||
|
||||
@Override
|
||||
protected SessionSubscriptionRegistry createSessionSubscriptionRegistry() {
|
||||
return new CachingSessionSubscriptionRegistry(new DefaultSessionSubscriptionRegistry());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
/*
|
||||
* Copyright 2002-2013 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.messaging.support;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
|
||||
/**
|
||||
* Test fixture for {@link DefaultSessionSubscriptionRegistration}.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
*/
|
||||
public class DefaultSessionSubscriptionRegistrationTests {
|
||||
|
||||
private DefaultSessionSubscriptionRegistration registration;
|
||||
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
this.registration = new DefaultSessionSubscriptionRegistration("sess1");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void addSubscriptions() {
|
||||
this.registration.addSubscription("/foo", "sub1");
|
||||
this.registration.addSubscription("/foo", "sub2");
|
||||
this.registration.addSubscription("/bar", "sub3");
|
||||
this.registration.addSubscription("/bar", "sub4");
|
||||
|
||||
assertSet(this.registration.getSubscriptionsByDestination("/foo"), 2, "sub1", "sub2");
|
||||
assertSet(this.registration.getSubscriptionsByDestination("/bar"), 2, "sub3", "sub4");
|
||||
assertSet(this.registration.getDestinations(), 2, "/foo", "/bar");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void removeSubscriptions() {
|
||||
this.registration.addSubscription("/foo", "sub1");
|
||||
this.registration.addSubscription("/foo", "sub2");
|
||||
this.registration.addSubscription("/bar", "sub3");
|
||||
this.registration.addSubscription("/bar", "sub4");
|
||||
|
||||
assertEquals("/foo", this.registration.removeSubscription("sub1"));
|
||||
assertEquals("/foo", this.registration.removeSubscription("sub2"));
|
||||
|
||||
assertNull(this.registration.getSubscriptionsByDestination("/foo"));
|
||||
assertSet(this.registration.getDestinations(), 1, "/bar");
|
||||
|
||||
assertEquals("/bar", this.registration.removeSubscription("sub3"));
|
||||
assertEquals("/bar", this.registration.removeSubscription("sub4"));
|
||||
|
||||
assertNull(this.registration.getSubscriptionsByDestination("/bar"));
|
||||
assertSet(this.registration.getDestinations(), 0);
|
||||
}
|
||||
|
||||
|
||||
private void assertSet(Set<String> set, int size, String... elements) {
|
||||
assertEquals("Wrong number of elements in " + set, size, set.size());
|
||||
for (String element : elements) {
|
||||
assertTrue("Set does not contain element " + element, set.contains(element));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Reference in New Issue