Refactor SubscriptionRegistry

The SubscriptionRegistry and implementations are now in a package
together with SimpleBrokerWebMessageHandler and primarily support
with matching subscriptions to messages. Subscriptions can contain
patterns as supported by AntPathMatcher.

StopmWebSocketHandler no longer keeps track of subscriptions and simply
ignores messages without a subscription id, since it has no way of
knowing broker-specific destination semantics for patterns.
This commit is contained in:
Rossen Stoyanchev 2013-07-07 14:18:58 -04:00
parent f25ccac1a1
commit 3a2f5e71b7
21 changed files with 688 additions and 836 deletions

View File

@ -1,43 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging;
import java.util.Set;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public interface SessionSubscriptionRegistration {
String getSessionId();
void addSubscription(String destination, String subscriptionId);
/**
* @param subscriptionId the subscription to remove
* @return the destination to which the subscriptionId was registered, or {@code null}
* if no matching subscriptionId was found
*/
String removeSubscription(String subscriptionId);
Set<String> getSubscriptionsByDestination(String destination);
Set<String> getDestinations();
}

View File

@ -1,38 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging;
import java.util.Set;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public interface SessionSubscriptionRegistry {
SessionSubscriptionRegistration getRegistration(String sessionId);
SessionSubscriptionRegistration getOrCreateRegistration(String sessionId);
SessionSubscriptionRegistration removeRegistration(String sessionId);
Set<String> getSessionSubscriptions(String sessionId, String destination);
Set<SessionSubscriptionRegistration> getRegistrationsByDestination(String destination);
}

View File

@ -0,0 +1,108 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.service.broker;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.util.MultiValueMap;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractSubscriptionRegistry implements SubscriptionRegistry {
protected final Log logger = LogFactory.getLog(getClass());
@Override
public void addSubscription(Message<?> message) {
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message);
if (!MessageType.SUBSCRIBE.equals(headers.getMessageType())) {
logger.error("Expected SUBSCRIBE message: " + message);
return;
}
String sessionId = headers.getSessionId();
if (sessionId == null) {
logger.error("Ignoring subscription. No sessionId in message: " + message);
return;
}
String subscriptionId = headers.getSubscriptionId();
if (subscriptionId == null) {
logger.error("Ignoring subscription. No subscriptionId in message: " + message);
return;
}
String destination = headers.getDestination();
if (destination == null) {
logger.error("Ignoring destination. No destination in message: " + message);
return;
}
addSubscriptionInternal(sessionId, subscriptionId, destination, message);
}
protected abstract void addSubscriptionInternal(String sessionId, String subscriptionId,
String destination, Message<?> message);
@Override
public void removeSubscription(Message<?> message) {
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message);
if (!MessageType.UNSUBSCRIBE.equals(headers.getMessageType())) {
logger.error("Expected UNSUBSCRIBE message: " + message);
return;
}
String sessionId = headers.getSessionId();
if (sessionId == null) {
logger.error("Ignoring subscription. No sessionId in message: " + message);
return;
}
String subscriptionId = headers.getSubscriptionId();
if (subscriptionId == null) {
logger.error("Ignoring subscription. No subscriptionId in message: " + message);
return;
}
removeSubscriptionInternal(sessionId, subscriptionId, message);
}
protected abstract void removeSubscriptionInternal(String sessionId, String subscriptionId, Message<?> message);
@Override
public void removeSessionSubscriptions(String sessionId) {
}
@Override
public MultiValueMap<String, String> findSubscriptions(Message<?> message) {
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message);
if (!MessageType.MESSAGE.equals(headers.getMessageType())) {
logger.error("Unexpected message type: " + message);
return null;
}
String destination = headers.getDestination();
if (destination == null) {
logger.error("Ignoring destination. No destination in message: " + message);
return null;
}
return findSubscriptionsInternal(destination, message);
}
protected abstract MultiValueMap<String, String> findSubscriptionsInternal(
String destination, Message<?> message);
}

View File

@ -0,0 +1,240 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.service.broker;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import org.springframework.messaging.Message;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
private final DestinationCache destinationCache = new DestinationCache();
private final SessionSubscriptionRegistry subscriptionRegistry = new SessionSubscriptionRegistry();
private AntPathMatcher pathMatcher = new AntPathMatcher();
/**
* @param pathMatcher the pathMatcher to set
*/
public void setPathMatcher(AntPathMatcher pathMatcher) {
this.pathMatcher = pathMatcher;
}
public AntPathMatcher getPathMatcher() {
return this.pathMatcher;
}
@Override
protected void addSubscriptionInternal(String sessionId, String subsId, String destination, Message<?> message) {
SessionSubscriptionInfo info = this.subscriptionRegistry.addSubscription(sessionId, subsId, destination);
if (!this.pathMatcher.isPattern(destination)) {
this.destinationCache.mapToDestination(destination, info);
}
}
@Override
protected void removeSubscriptionInternal(String sessionId, String subscriptionId, Message<?> message) {
SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId);
if (info != null) {
String destination = info.removeSubscription(subscriptionId);
if (info.getSubscriptions(destination) == null) {
this.destinationCache.unmapFromDestination(destination, info);
}
}
}
@Override
public void removeSessionSubscriptions(String sessionId) {
SessionSubscriptionInfo info = this.subscriptionRegistry.removeSubscriptions(sessionId);
this.destinationCache.removeSessionSubscriptions(info);
}
@Override
protected MultiValueMap<String, String> findSubscriptionsInternal(String destination, Message<?> message) {
MultiValueMap<String,String> result = this.destinationCache.getSubscriptions(destination);
if (result.isEmpty()) {
result = new LinkedMultiValueMap<String, String>();
for (SessionSubscriptionInfo info : this.subscriptionRegistry.getAllSubscriptions()) {
for (String destinationPattern : info.getDestinations()) {
if (this.pathMatcher.match(destinationPattern, destination)) {
for (String subscriptionId : info.getSubscriptions(destinationPattern)) {
result.add(info.sessionId, subscriptionId);
}
}
}
}
}
return result;
}
/**
* Provide direct lookup of session subscriptions by destination (for non-pattern destinations).
*/
private static class DestinationCache {
// destination -> ..
private final Map<String, Set<SessionSubscriptionInfo>> subscriptionsByDestination =
new ConcurrentHashMap<String, Set<SessionSubscriptionInfo>>();
private final Object monitor = new Object();
public void mapToDestination(String destination, SessionSubscriptionInfo info) {
synchronized (monitor) {
Set<SessionSubscriptionInfo> registrations = this.subscriptionsByDestination.get(destination);
if (registrations == null) {
registrations = new CopyOnWriteArraySet<SessionSubscriptionInfo>();
this.subscriptionsByDestination.put(destination, registrations);
}
registrations.add(info);
}
}
public void unmapFromDestination(String destination, SessionSubscriptionInfo info) {
synchronized (monitor) {
Set<SessionSubscriptionInfo> infos = this.subscriptionsByDestination.get(destination);
if (infos != null) {
infos.remove(info);
if (infos.isEmpty()) {
this.subscriptionsByDestination.remove(destination);
}
}
}
}
public void removeSessionSubscriptions(SessionSubscriptionInfo info) {
for (String destination : info.getDestinations()) {
unmapFromDestination(destination, info);
}
}
public MultiValueMap<String, String> getSubscriptions(String destination) {
MultiValueMap<String, String> result = new LinkedMultiValueMap<String, String>();
Set<SessionSubscriptionInfo> infos = this.subscriptionsByDestination.get(destination);
if (infos != null) {
for (SessionSubscriptionInfo info : infos) {
Set<String> subscriptions = info.getSubscriptions(destination);
if (subscriptions != null) {
for (String subscription : subscriptions) {
result.add(info.getSessionId(), subscription);
}
}
}
}
return result;
}
}
/**
* Provide access to session subscriptions by sessionId.
*/
private static class SessionSubscriptionRegistry {
private final Map<String, SessionSubscriptionInfo> sessions =
new ConcurrentHashMap<String, SessionSubscriptionInfo>();
public SessionSubscriptionInfo getSubscriptions(String sessionId) {
return this.sessions.get(sessionId);
}
public Collection<SessionSubscriptionInfo> getAllSubscriptions() {
return this.sessions.values();
}
public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId, String destination) {
SessionSubscriptionInfo info = this.sessions.get(sessionId);
if (info == null) {
info = new SessionSubscriptionInfo(sessionId);
this.sessions.put(sessionId, info);
}
info.addSubscription(subscriptionId, destination);
return info;
}
public SessionSubscriptionInfo removeSubscriptions(String sessionId) {
return this.sessions.remove(sessionId);
}
}
/**
* Hold subscriptions for a session.
*/
private static class SessionSubscriptionInfo {
private final String sessionId;
private final Map<String, Set<String>> subscriptions = new HashMap<String, Set<String>>(4);
public SessionSubscriptionInfo(String sessionId) {
this.sessionId = sessionId;
}
public String getSessionId() {
return this.sessionId;
}
public Set<String> getDestinations() {
return this.subscriptions.keySet();
}
public Set<String> getSubscriptions(String destination) {
return this.subscriptions.get(destination);
}
public void addSubscription(String subscriptionId, String destination) {
Set<String> subs = this.subscriptions.get(destination);
if (subs == null) {
subs = new HashSet<String>(4);
this.subscriptions.put(destination, subs);
}
subs.add(subscriptionId);
}
public String removeSubscription(String subscriptionId) {
for (String destination : this.subscriptions.keySet()) {
Set<String> subscriptionIds = this.subscriptions.get(destination);
if (subscriptionIds.remove(subscriptionId)) {
if (subscriptionIds.isEmpty()) {
this.subscriptions.remove(destination);
}
return destination;
}
}
return null;
}
}
}

View File

@ -14,21 +14,18 @@
* limitations under the License.
*/
package org.springframework.web.messaging.service;
package org.springframework.web.messaging.service.broker;
import java.util.Arrays;
import java.util.Collection;
import java.util.Set;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.SessionSubscriptionRegistration;
import org.springframework.web.messaging.SessionSubscriptionRegistry;
import org.springframework.web.messaging.support.CachingSessionSubscriptionRegistry;
import org.springframework.web.messaging.support.DefaultSessionSubscriptionRegistry;
import org.springframework.web.messaging.service.AbstractWebMessageHandler;
import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
@ -38,23 +35,22 @@ import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
*/
public class SimpleBrokerWebMessageHandler extends AbstractWebMessageHandler {
private final MessageChannel clientChannel;
private final MessageChannel outboundChannel;
private SessionSubscriptionRegistry subscriptionRegistry=
new CachingSessionSubscriptionRegistry(new DefaultSessionSubscriptionRegistry());
private SubscriptionRegistry subscriptionRegistry = new DefaultSubscriptionRegistry();
/**
* @param clientChannel the channel to which messages for clients should be sent
* @param outboundChannel the channel to which messages for clients should be sent
* @param observable an Observable to use to manage subscriptions
*/
public SimpleBrokerWebMessageHandler(MessageChannel clientChannel) {
Assert.notNull(clientChannel, "clientChannel is required");
this.clientChannel = clientChannel;
public SimpleBrokerWebMessageHandler(MessageChannel outboundChannel) {
Assert.notNull(outboundChannel, "outboundChannel is required");
this.outboundChannel = outboundChannel;
}
public void setSubscriptionRegistry(SessionSubscriptionRegistry subscriptionRegistry) {
public void setSubscriptionRegistry(SubscriptionRegistry subscriptionRegistry) {
Assert.notNull(subscriptionRegistry, "subscriptionRegistry is required");
this.subscriptionRegistry = subscriptionRegistry;
}
@ -71,13 +67,16 @@ public class SimpleBrokerWebMessageHandler extends AbstractWebMessageHandler {
logger.debug("Subscribe " + message);
}
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message);
String sessionId = headers.getSessionId();
String subscriptionId = headers.getSubscriptionId();
String destination = headers.getDestination();
this.subscriptionRegistry.addSubscription(message);
SessionSubscriptionRegistration registration = this.subscriptionRegistry.getOrCreateRegistration(sessionId);
registration.addSubscription(destination, subscriptionId);
// TODO: need a way to communicate back if subscription was successfully created or
// not in which case an ERROR should be sent back and close the connection
// http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE
}
@Override
protected void handleUnsubscribe(Message<?> message) {
this.subscriptionRegistry.removeSubscription(message);
}
@Override
@ -89,29 +88,24 @@ public class SimpleBrokerWebMessageHandler extends AbstractWebMessageHandler {
String destination = WebMessageHeaderAccesssor.wrap(message).getDestination();
Set<SessionSubscriptionRegistration> registrations =
this.subscriptionRegistry.getRegistrationsByDestination(destination);
MultiValueMap<String,String> subscriptions = this.subscriptionRegistry.findSubscriptions(message);
if (registrations == null) {
return;
}
for (SessionSubscriptionRegistration registration : registrations) {
for (String subscriptionId : registration.getSubscriptionsByDestination(destination)) {
for (String sessionId : subscriptions.keySet()) {
for (String subscriptionId : subscriptions.get(sessionId)) {
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.wrap(message);
headers.setSessionId(registration.getSessionId());
headers.setSessionId(sessionId);
headers.setSubscriptionId(subscriptionId);
Message<?> clientMessage = MessageBuilder.withPayload(
message.getPayload()).copyHeaders(headers.toMap()).build();
try {
this.clientChannel.send(clientMessage);
this.outboundChannel.send(clientMessage);
}
catch (Throwable ex) {
logger.error("Failed to send message to destination=" + destination +
", sessionId=" + registration.getSessionId() + ", subscriptionId=" + subscriptionId, ex);
", sessionId=" + sessionId + ", subscriptionId=" + subscriptionId, ex);
}
}
}
@ -120,7 +114,7 @@ public class SimpleBrokerWebMessageHandler extends AbstractWebMessageHandler {
@Override
public void handleDisconnect(Message<?> message) {
String sessionId = WebMessageHeaderAccesssor.wrap(message).getSessionId();
this.subscriptionRegistry.removeRegistration(sessionId);
this.subscriptionRegistry.removeSessionSubscriptions(sessionId);
}
}

View File

@ -14,22 +14,24 @@
* limitations under the License.
*/
package org.springframework.web.messaging.support;
package org.springframework.web.messaging.service.broker;
import org.springframework.web.messaging.SessionSubscriptionRegistry;
import org.springframework.messaging.Message;
import org.springframework.util.MultiValueMap;
/**
* Test fixture for {@link DefaultSessionSubscriptionRegistry}.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class DefaultSessionSubscriptionRegistryTests extends AbstractSessionSubscriptionRegistryTests {
public interface SubscriptionRegistry {
void addSubscription(Message<?> subscribeMessage);
@Override
protected SessionSubscriptionRegistry createSessionSubscriptionRegistry() {
return new DefaultSessionSubscriptionRegistry();
}
void removeSubscription(Message<?> unsubscribeMessage);
void removeSessionSubscriptions(String sessionId);
MultiValueMap<String, String> findSubscriptions(Message<?> message);
}

View File

@ -57,9 +57,9 @@ import org.springframework.web.method.HandlerMethodSelector;
public class AnnotationWebMessageHandler extends AbstractWebMessageHandler
implements ApplicationContextAware, InitializingBean {
private final MessageChannel clientChannel;
private final MessageChannel inboundChannel;
private final MessageChannel brokerChannel;
private final MessageChannel outboundChannel;
private List<MessageConverter> messageConverters;
@ -79,11 +79,15 @@ public class AnnotationWebMessageHandler extends AbstractWebMessageHandler
private ReturnValueHandlerComposite returnValueHandlers = new ReturnValueHandlerComposite();
public AnnotationWebMessageHandler(MessageChannel clientChannel, MessageChannel brokerChannel) {
Assert.notNull(clientChannel, "clientChannel is required");
Assert.notNull(brokerChannel, "brokerChannel is required");
this.clientChannel = clientChannel;
this.brokerChannel = brokerChannel;
/**
* @param inboundChannel a channel for processing incoming messages from clients
* @param outboundChannel a channel for messages going out to clients
*/
public AnnotationWebMessageHandler(MessageChannel inboundChannel, MessageChannel outboundChannel) {
Assert.notNull(inboundChannel, "inboundChannel is required");
Assert.notNull(outboundChannel, "outboundChannel is required");
this.inboundChannel = inboundChannel;
this.outboundChannel = outboundChannel;
}
public void setMessageConverters(List<MessageConverter> converters) {
@ -105,11 +109,11 @@ public class AnnotationWebMessageHandler extends AbstractWebMessageHandler
initHandlerMethods();
this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(this.brokerChannel));
this.argumentResolvers.addResolver(new MessageChannelArgumentResolver(this.inboundChannel));
this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverters));
this.returnValueHandlers.addHandler(new MessageReturnValueHandler(this.clientChannel));
this.returnValueHandlers.addHandler(new PayloadReturnValueHandler(this.clientChannel));
this.returnValueHandlers.addHandler(new MessageReturnValueHandler(this.outboundChannel));
this.returnValueHandlers.addHandler(new PayloadReturnValueHandler(this.outboundChannel));
}
protected void initHandlerMethods() {

View File

@ -28,12 +28,12 @@ import org.springframework.util.Assert;
*/
public class MessageChannelArgumentResolver implements ArgumentResolver {
private MessageChannel messageBrokerChannel;
private MessageChannel inboundChannel;
public MessageChannelArgumentResolver(MessageChannel messageBrokerChannel) {
Assert.notNull(messageBrokerChannel, "messageBrokerChannel is required");
this.messageBrokerChannel = messageBrokerChannel;
public MessageChannelArgumentResolver(MessageChannel inboundChannel) {
Assert.notNull(inboundChannel, "inboundChannel is required");
this.inboundChannel = inboundChannel;
}
@Override
@ -43,7 +43,7 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
@Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
return this.messageBrokerChannel;
return this.inboundChannel;
}
}

View File

@ -30,12 +30,12 @@ import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
*/
public class MessageReturnValueHandler implements ReturnValueHandler {
private MessageChannel clientChannel;
private MessageChannel outboundChannel;
public MessageReturnValueHandler(MessageChannel clientChannel) {
Assert.notNull(clientChannel, "clientChannel is required");
this.clientChannel = clientChannel;
public MessageReturnValueHandler(MessageChannel outboundChannel) {
Assert.notNull(outboundChannel, "outboundChannel is required");
this.outboundChannel = outboundChannel;
}
@Override
@ -49,7 +49,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
public void handleReturnValue(Object returnValue, MethodParameter returnType, Message<?> message)
throws Exception {
Assert.notNull(this.clientChannel, "No clientChannel to send messages to");
Assert.notNull(this.outboundChannel, "No clientChannel to send messages to");
Message<?> returnMessage = (Message<?>) returnValue;
if (message == null) {
@ -70,7 +70,7 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
returnMessage = MessageBuilder.withPayload(
returnMessage.getPayload()).copyHeaders(returnHeaders.toMap()).build();
this.clientChannel.send(returnMessage);
this.outboundChannel.send(returnMessage);
}
}

View File

@ -30,12 +30,12 @@ import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
*/
public class PayloadReturnValueHandler implements ReturnValueHandler {
private MessageChannel clientChannel;
private MessageChannel outboundChannel;
public PayloadReturnValueHandler(MessageChannel clientChannel) {
Assert.notNull(clientChannel, "clientChannel is required");
this.clientChannel = clientChannel;
public PayloadReturnValueHandler(MessageChannel outboundChannel) {
Assert.notNull(outboundChannel, "outboundChannel is required");
this.outboundChannel = outboundChannel;
}
@Override
@ -47,7 +47,7 @@ public class PayloadReturnValueHandler implements ReturnValueHandler {
public void handleReturnValue(Object returnValue, MethodParameter returnType, Message<?> message)
throws Exception {
Assert.notNull(this.clientChannel, "No clientChannel to send messages to");
Assert.notNull(this.outboundChannel, "No outboundChannel to send messages to");
if (returnValue == null) {
return;
@ -63,7 +63,7 @@ public class PayloadReturnValueHandler implements ReturnValueHandler {
Message<?> returnMessage = MessageBuilder.withPayload(
returnValue).copyHeaders(returnHeaders.toMap()).build();
this.clientChannel.send(returnMessage);
this.outboundChannel.send(returnMessage);
}
}

View File

@ -97,7 +97,7 @@ public class StompHeaderAccessor extends WebMessageHeaderAccesssor {
if (contentType != null) {
super.setContentType(MediaType.parseMediaType(contentType));
}
if (StompCommand.SUBSCRIBE.equals(getStompCommand())) {
if (StompCommand.SUBSCRIBE.equals(getStompCommand()) || StompCommand.UNSUBSCRIBE.equals(getStompCommand())) {
if (getFirstNativeHeader(STOMP_ID) != null) {
super.setSubscriptionId(getFirstNativeHeader(STOMP_ID));
}

View File

@ -59,7 +59,7 @@ public class StompRelayWebMessageHandler extends AbstractWebMessageHandler imple
private static final String STOMP_RELAY_SYSTEM_SESSION_ID = "stompRelaySystemSessionId";
private MessageChannel clientChannel;
private MessageChannel outboundChannel;
private String relayHost = "127.0.0.1";
@ -85,11 +85,11 @@ public class StompRelayWebMessageHandler extends AbstractWebMessageHandler imple
/**
* @param clientChannel the channel to which messages for clients should be sent.
* @param outboundChannel a channel for messages going out to clients
*/
public StompRelayWebMessageHandler(MessageChannel clientChannel) {
Assert.notNull(clientChannel, "clientChannel is required");
this.clientChannel = clientChannel;
public StompRelayWebMessageHandler(MessageChannel outboundChannel) {
Assert.notNull(outboundChannel, "outboundChannel is required");
this.outboundChannel = outboundChannel;
this.payloadConverter = new CompositeMessageConverter(null);
}
@ -387,7 +387,7 @@ public class StompRelayWebMessageHandler extends AbstractWebMessageHandler imple
}
protected void sendMessageToClient(Message<?> message) {
clientChannel.send(message);
outboundChannel.send(message);
}
private void sendError(String sessionId, String errorText) {

View File

@ -29,15 +29,11 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.CollectionUtils;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.SessionSubscriptionRegistration;
import org.springframework.web.messaging.SessionSubscriptionRegistry;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompConversionException;
import org.springframework.web.messaging.support.DefaultSessionSubscriptionRegistry;
import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
@ -63,8 +59,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
private SessionSubscriptionRegistry subscriptionRegistry = new DefaultSessionSubscriptionRegistry();
private MessageConverter payloadConverter = new CompositeMessageConverter(null);
@ -86,10 +80,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
return this.stompMessageConverter;
}
public void setSubscriptionRegistry(SessionSubscriptionRegistry subscriptionRegistry) {
this.subscriptionRegistry = subscriptionRegistry;
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
@ -179,35 +169,12 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
}
protected void handleSubscribe(Message<?> message) {
// TODO: need a way to communicate back if subscription was successfully created or
// not in which case an ERROR should be sent back and close the connection
// http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
String sessionId = headers.getSessionId();
String destination = headers.getDestination();
SessionSubscriptionRegistration registration = this.subscriptionRegistry.getOrCreateRegistration(sessionId);
registration.addSubscription(destination, headers.getSubscriptionId());
}
protected void handleUnsubscribe(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
String sessionId = headers.getSessionId();
String subscriptionId = headers.getSubscriptionId();
SessionSubscriptionRegistration registration = this.subscriptionRegistry.getRegistration(sessionId);
if (registration == null) {
logger.warn("Subscripton=" + subscriptionId + " for session=" + sessionId + " not found");
return;
}
registration.removeSubscription(subscriptionId);
}
protected void handleDisconnect(Message<?> message) {
}
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
@ -230,7 +197,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
this.sessions.remove(session.getId());
this.subscriptionRegistry.removeRegistration(session.getId());
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.DISCONNECT);
headers.setSessionId(session.getId());
@ -254,25 +220,22 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
String sessionId = headers.getSessionId();
if (sessionId == null) {
// TODO: failed message delivery mechanism
logger.error("No \"sessionId\" header in message: " + message);
return;
}
WebSocketSession session = this.sessions.get(sessionId);
if (session == null) {
logger.error("Session not found: " + message);
// TODO: failed message delivery mechanism
logger.error("WebSocketSession not found for sessionId=" + sessionId);
return;
}
if (headers.getSubscriptionId() == null) {
String destination = headers.getDestination();
Set<String> subs = this.subscriptionRegistry.getSessionSubscriptions(sessionId, destination);
if (!CollectionUtils.isEmpty(subs)) {
// TODO: send to all subscriptions ids
headers.setSubscriptionId(subs.iterator().next());
}
else {
logger.error("No subscription id: " + message);
return;
}
// TODO: failed message delivery mechanism
logger.error("No subscription id: " + message);
return;
}
byte[] payload;

View File

@ -1,190 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.support;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import org.springframework.util.Assert;
import org.springframework.web.messaging.SessionSubscriptionRegistration;
import org.springframework.web.messaging.SessionSubscriptionRegistry;
/**
* A decorator for a {@link SessionSubscriptionRegistry} that intercepts subscriptions
* being added and removed and maintains a lookup cache of registrations by destination.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class CachingSessionSubscriptionRegistry implements SessionSubscriptionRegistry {
private final SessionSubscriptionRegistry delegate;
private final DestinationCache destinationCache = new DestinationCache();
public CachingSessionSubscriptionRegistry(SessionSubscriptionRegistry delegate) {
Assert.notNull(delegate, "delegate SessionSubscriptionRegistry is required");
this.delegate = delegate;
}
@Override
public SessionSubscriptionRegistration getRegistration(String sessionId) {
SessionSubscriptionRegistration reg = this.delegate.getRegistration(sessionId);
return (reg != null) ? new CachingSessionSubscriptionRegistration(reg) : null;
}
@Override
public SessionSubscriptionRegistration getOrCreateRegistration(String sessionId) {
return new CachingSessionSubscriptionRegistration(this.delegate.getOrCreateRegistration(sessionId));
}
@Override
public SessionSubscriptionRegistration removeRegistration(String sessionId) {
SessionSubscriptionRegistration registration = this.delegate.removeRegistration(sessionId);
if (registration != null) {
this.destinationCache.removeRegistration(registration);
}
return registration;
}
@Override
public Set<String> getSessionSubscriptions(String sessionId, String destination) {
return this.delegate.getSessionSubscriptions(sessionId, destination);
}
@Override
public Set<SessionSubscriptionRegistration> getRegistrationsByDestination(String destination) {
return this.destinationCache.getRegistrations(destination);
}
private static class DestinationCache {
private final Map<String, Set<SessionSubscriptionRegistration>> cache =
new ConcurrentHashMap<String, Set<SessionSubscriptionRegistration>>();
private final Object monitor = new Object();
public void mapRegistration(String destination, SessionSubscriptionRegistration registration) {
synchronized (monitor) {
Set<SessionSubscriptionRegistration> registrations = this.cache.get(destination);
if (registrations == null) {
registrations = new CopyOnWriteArraySet<SessionSubscriptionRegistration>();
this.cache.put(destination, registrations);
}
registrations.add(registration);
}
}
public void unmapRegistration(String destination, SessionSubscriptionRegistration registration) {
synchronized (monitor) {
Set<SessionSubscriptionRegistration> registrations = this.cache.get(destination);
if (registrations != null) {
registrations.remove(registration);
if (registrations.isEmpty()) {
this.cache.remove(destination);
}
}
}
}
private void removeRegistration(SessionSubscriptionRegistration registration) {
for (String destination : registration.getDestinations()) {
unmapRegistration(destination, registration);
}
}
public Set<SessionSubscriptionRegistration> getRegistrations(String destination) {
return this.cache.get(destination);
}
@Override
public String toString() {
return "DestinationCache [cache=" + this.cache + "]";
}
}
private class CachingSessionSubscriptionRegistration implements SessionSubscriptionRegistration {
private final SessionSubscriptionRegistration delegate;
public CachingSessionSubscriptionRegistration(SessionSubscriptionRegistration delegate) {
Assert.notNull(delegate, "delegate SessionSubscriptionRegistration is required");
this.delegate = delegate;
}
@Override
public String getSessionId() {
return this.delegate.getSessionId();
}
@Override
public void addSubscription(String destination, String subscriptionId) {
destinationCache.mapRegistration(destination, this);
this.delegate.addSubscription(destination, subscriptionId);
}
@Override
public String removeSubscription(String subscriptionId) {
String destination = this.delegate.removeSubscription(subscriptionId);
if (destination != null && this.delegate.getSubscriptionsByDestination(destination) == null) {
destinationCache.unmapRegistration(destination, this);
}
return destination;
}
@Override
public Set<String> getSubscriptionsByDestination(String destination) {
return this.delegate.getSubscriptionsByDestination(destination);
}
@Override
public Set<String> getDestinations() {
return this.delegate.getDestinations();
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof CachingSessionSubscriptionRegistration)) {
return false;
}
CachingSessionSubscriptionRegistration otherType = (CachingSessionSubscriptionRegistration) other;
return this.delegate.equals(otherType.delegate);
}
@Override
public int hashCode() {
return this.delegate.hashCode();
}
@Override
public String toString() {
return "CachingSessionSubscriptionRegistration [delegate=" + delegate + "]";
}
}
}

View File

@ -1,115 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.support;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.util.Assert;
import org.springframework.web.messaging.SessionSubscriptionRegistration;
/**
* A default implementation of SessionSubscriptionRegistration. Uses a map to keep track
* of subscriptions by destination. This implementation assumes that only one thread will
* access and update subscriptions at a time.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class DefaultSessionSubscriptionRegistration implements SessionSubscriptionRegistration {
private final String sessionId;
// destination -> subscriptionIds
private final Map<String, Set<String>> subscriptions = new HashMap<String, Set<String>>(4);
public DefaultSessionSubscriptionRegistration(String sessionId) {
Assert.notNull(sessionId, "sessionId is required");
this.sessionId = sessionId;
}
public String getSessionId() {
return this.sessionId;
}
@Override
public Set<String> getDestinations() {
return this.subscriptions.keySet();
}
@Override
public void addSubscription(String destination, String subscriptionId) {
Assert.hasText(destination, "destination must not be empty");
Assert.hasText(subscriptionId, "subscriptionId must not be empty");
Set<String> subs = this.subscriptions.get(destination);
if (subs == null) {
subs = new HashSet<String>(4);
this.subscriptions.put(destination, subs);
}
subs.add(subscriptionId);
}
@Override
public String removeSubscription(String subscriptionId) {
Assert.hasText(subscriptionId, "subscriptionId must not be empty");
for (String destination : this.subscriptions.keySet()) {
Set<String> subscriptionIds = this.subscriptions.get(destination);
if (subscriptionIds.remove(subscriptionId)) {
if (subscriptionIds.isEmpty()) {
this.subscriptions.remove(destination);
}
return destination;
}
}
return null;
}
@Override
public Set<String> getSubscriptionsByDestination(String destination) {
Assert.hasText(destination, "destination must not be empty");
return this.subscriptions.get(destination);
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof DefaultSessionSubscriptionRegistration)) {
return false;
}
DefaultSessionSubscriptionRegistration otherType = (DefaultSessionSubscriptionRegistration) other;
return this.sessionId.equals(otherType.sessionId);
}
@Override
public int hashCode() {
return 31 + this.sessionId.hashCode();
}
@Override
public String toString() {
return "DefaultSessionSubscriptionRegistration [sessionId=" + this.sessionId
+ ", subscriptions=" + this.subscriptions + "]";
}
}

View File

@ -1,83 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.support;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.web.messaging.SessionSubscriptionRegistration;
import org.springframework.web.messaging.SessionSubscriptionRegistry;
/**
* A default implementation of SessionSubscriptionRegistry.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class DefaultSessionSubscriptionRegistry implements SessionSubscriptionRegistry {
// sessionId -> SessionSubscriptionRegistration
private final Map<String, SessionSubscriptionRegistration> registrations =
new ConcurrentHashMap<String, SessionSubscriptionRegistration>();
@Override
public SessionSubscriptionRegistration getRegistration(String sessionId) {
return this.registrations.get(sessionId);
}
@Override
public SessionSubscriptionRegistration getOrCreateRegistration(String sessionId) {
SessionSubscriptionRegistration registration = this.registrations.get(sessionId);
if (registration == null) {
registration = new DefaultSessionSubscriptionRegistration(sessionId);
this.registrations.put(sessionId, registration);
}
return registration;
}
@Override
public SessionSubscriptionRegistration removeRegistration(String sessionId) {
return this.registrations.remove(sessionId);
}
@Override
public Set<String> getSessionSubscriptions(String sessionId, String destination) {
SessionSubscriptionRegistration registration = this.registrations.get(sessionId);
return (registration != null) ? registration.getSubscriptionsByDestination(destination) : null;
}
/**
* The default implementation performs a lookup by destination on each registration.
* For a more efficient algorithm consider decorating an instance of this class with
* {@link CachingSessionSubscriptionRegistry}.
*/
@Override
public Set<SessionSubscriptionRegistration> getRegistrationsByDestination(String destination) {
Set<SessionSubscriptionRegistration> result = new HashSet<SessionSubscriptionRegistration>();
for (SessionSubscriptionRegistration r : this.registrations.values()) {
if (r.getSubscriptionsByDestination(destination) != null) {
result.add(r);
}
}
return result.isEmpty() ? null : result;
}
}

View File

@ -28,6 +28,7 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.service.broker.SimpleBrokerWebMessageHandler;
import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
import static org.junit.Assert.*;
@ -89,29 +90,29 @@ public class SimpleBrokerWebMessageHandlerTests {
@Test
public void subcribeDisconnectPublish() {
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub1", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub2", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess1", "sub3", "/bar"));
String sess1 = "sess1";
String sess2 = "sess2";
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub1", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub2", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage("sess2", "sub3", "/bar"));
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub1", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub2", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess1, "sub3", "/bar"));
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub1", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub2", "/foo"));
this.messageHandler.handleSubscribe(createSubscriptionMessage(sess2, "sub3", "/bar"));
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.DISCONNECT);
headers.setSessionId("sess1");
headers.setSessionId(sess1);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toMap()).build();
this.messageHandler.handleDisconnect(message);
this.messageHandler.handlePublish(createMessage("/foo", "message1"));
this.messageHandler.handlePublish(createMessage("/bar", "message2"));
verify(this.clientChannel, times(6)).send(this.messageCaptor.capture());
assertCapturedMessage("sess1", "sub1", "/foo");
assertCapturedMessage("sess1", "sub2", "/foo");
assertCapturedMessage("sess2", "sub1", "/foo");
assertCapturedMessage("sess2", "sub2", "/foo");
assertCapturedMessage("sess1", "sub3", "/bar");
assertCapturedMessage("sess2", "sub3", "/bar");
verify(this.clientChannel, times(3)).send(this.messageCaptor.capture());
assertCapturedMessage(sess2, "sub1", "/foo");
assertCapturedMessage(sess2, "sub2", "/foo");
assertCapturedMessage(sess2, "sub3", "/bar");
}

View File

@ -0,0 +1,242 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.service.broker;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.MultiValueMap;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.support.WebMessageHeaderAccesssor;
import static org.junit.Assert.*;
/**
* Test fixture for {@link DefaultSubscriptionRegistry}.
*
* @author Rossen Stoyanchev
*/
public class DefaultSubscriptionRegistryTests {
private DefaultSubscriptionRegistry registry;
@Before
public void setup() {
this.registry = new DefaultSubscriptionRegistry();
}
@Test
public void addSubscriptionInvalidInput() {
String sessId = "sess01";
String subsId = "subs01";
String dest = "/foo";
this.registry.addSubscription(subscribeMessage(null, subsId, dest));
assertEquals(0, this.registry.findSubscriptions(message(dest)).size());
this.registry.addSubscription(subscribeMessage(sessId, null, dest));
assertEquals(0, this.registry.findSubscriptions(message(dest)).size());
this.registry.addSubscription(subscribeMessage(sessId, subsId, null));
assertEquals(0, this.registry.findSubscriptions(message(dest)).size());
}
@Test
public void addSubscription() {
String sessId = "sess01";
String subsId = "subs01";
String dest = "/foo";
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
assertEquals("Expected one element " + actual, 1, actual.size());
assertEquals(Arrays.asList(subsId), actual.get(sessId));
}
@Test
public void addSubscriptionOneSession() {
String sessId = "sess01";
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
String dest = "/foo";
for (String subId : subscriptionIds) {
this.registry.addSubscription(subscribeMessage(sessId, subId, dest));
}
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
assertEquals("Expected one element " + actual, 1, actual.size());
assertEquals(subscriptionIds, sort(actual.get(sessId)));
}
@Test
public void addSubscriptionMultipleSessions() {
List<String> sessIds = Arrays.asList("sess01", "sess02", "sess03");
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
String dest = "/foo";
for (String sessId : sessIds) {
for (String subsId : subscriptionIds) {
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
}
}
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
assertEquals("Expected three elements " + actual, 3, actual.size());
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(0))));
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(1))));
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(2))));
}
@Test
public void addSubscriptionWithDestinationPattern() {
String sessId = "sess01";
String subsId = "subs01";
String destPattern = "/topic/PRICE.STOCK.*.IBM";
String dest = "/topic/PRICE.STOCK.NASDAQ.IBM";
this.registry.addSubscription(subscribeMessage(sessId, subsId, destPattern));
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
assertEquals("Expected one element " + actual, 1, actual.size());
assertEquals(Arrays.asList(subsId), actual.get(sessId));
}
@Test
public void addSubscriptionWithDestinationPatternRegex() {
String sessId = "sess01";
String subsId = "subs01";
String destPattern = "/topic/PRICE.STOCK.*.{ticker:(IBM|MSFT)}";
this.registry.addSubscription(subscribeMessage(sessId, subsId, destPattern));
Message<?> message = message("/topic/PRICE.STOCK.NASDAQ.IBM");
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message);
assertEquals("Expected one element " + actual, 1, actual.size());
assertEquals(Arrays.asList(subsId), actual.get(sessId));
message = message("/topic/PRICE.STOCK.NASDAQ.MSFT");
actual = this.registry.findSubscriptions(message);
assertEquals("Expected one element " + actual, 1, actual.size());
assertEquals(Arrays.asList(subsId), actual.get(sessId));
message = message("/topic/PRICE.STOCK.NASDAQ.VMW");
actual = this.registry.findSubscriptions(message);
assertEquals("Expected no elements " + actual, 0, actual.size());
}
@Test
public void removeSubscription() {
List<String> sessIds = Arrays.asList("sess01", "sess02", "sess03");
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
String dest = "/foo";
for (String sessId : sessIds) {
for (String subsId : subscriptionIds) {
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
}
}
this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(0)));
this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(1)));
this.registry.removeSubscription(unsubscribeMessage(sessIds.get(0), subscriptionIds.get(2)));
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
assertEquals("Expected three elements " + actual, 2, actual.size());
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(1))));
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(2))));
}
@Test
public void removeSessionSubscriptions() {
List<String> sessIds = Arrays.asList("sess01", "sess02", "sess03");
List<String> subscriptionIds = Arrays.asList("subs01", "subs02", "subs03");
String dest = "/foo";
for (String sessId : sessIds) {
for (String subsId : subscriptionIds) {
this.registry.addSubscription(subscribeMessage(sessId, subsId, dest));
}
}
this.registry.removeSessionSubscriptions(sessIds.get(0));
this.registry.removeSessionSubscriptions(sessIds.get(1));
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message(dest));
assertEquals("Expected three elements " + actual, 1, actual.size());
assertEquals(subscriptionIds, sort(actual.get(sessIds.get(2))));
}
@Test
public void findSubscriptionsNoMatches() {
MultiValueMap<String, String> actual = this.registry.findSubscriptions(message("/foo"));
assertEquals("Expected no elements " + actual, 0, actual.size());
}
private Message<?> subscribeMessage(String sessionId, String subscriptionId, String destination) {
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.SUBSCRIBE);
headers.setSessionId(sessionId);
headers.setSubscriptionId(subscriptionId);
if (destination != null) {
headers.setDestination(destination);
}
return MessageBuilder.withPayload("").copyHeaders(headers.toMap()).build();
}
private Message<?> unsubscribeMessage(String sessionId, String subscriptionId) {
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create(MessageType.UNSUBSCRIBE);
headers.setSessionId(sessionId);
headers.setSubscriptionId(subscriptionId);
return MessageBuilder.withPayload("").copyHeaders(headers.toMap()).build();
}
private Message<?> message(String destination) {
WebMessageHeaderAccesssor headers = WebMessageHeaderAccesssor.create();
headers.setDestination(destination);
return MessageBuilder.withPayload("").copyHeaders(headers.toMap()).build();
}
private List<String> sort(List<String> list) {
Collections.sort(list);
return list;
}
}

View File

@ -1,116 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.support;
import java.util.Set;
import org.junit.Before;
import org.junit.Test;
import org.springframework.web.messaging.SessionSubscriptionRegistration;
import org.springframework.web.messaging.SessionSubscriptionRegistry;
import static org.junit.Assert.*;
/**
* A test fixture for {@link AbstractSessionSubscriptionRegistry}.
*
* @author Rossen Stoyanchev
*/
public abstract class AbstractSessionSubscriptionRegistryTests {
protected SessionSubscriptionRegistry registry;
@Before
public void setup() {
this.registry = createSessionSubscriptionRegistry();
}
protected abstract SessionSubscriptionRegistry createSessionSubscriptionRegistry();
@Test
public void getRegistration() {
String sessionId = "sess1";
assertNull(this.registry.getRegistration(sessionId));
this.registry.getOrCreateRegistration(sessionId);
assertNotNull(this.registry.getRegistration(sessionId));
assertEquals(sessionId, this.registry.getRegistration(sessionId).getSessionId());
}
@Test
public void getOrCreateRegistration() {
String sessionId = "sess1";
assertNull(this.registry.getRegistration(sessionId));
SessionSubscriptionRegistration registration = this.registry.getOrCreateRegistration(sessionId);
assertEquals(registration, this.registry.getOrCreateRegistration(sessionId));
}
@Test
public void removeRegistration() {
String sessionId = "sess1";
this.registry.getOrCreateRegistration(sessionId);
assertNotNull(this.registry.getRegistration(sessionId));
assertEquals(sessionId, this.registry.getRegistration(sessionId).getSessionId());
this.registry.removeRegistration(sessionId);
assertNull(this.registry.getRegistration(sessionId));
}
@Test
public void getSessionSubscriptions() {
String sessionId = "sess1";
SessionSubscriptionRegistration registration = this.registry.getOrCreateRegistration(sessionId);
registration.addSubscription("/foo", "sub1");
registration.addSubscription("/foo", "sub2");
Set<String> subscriptions = this.registry.getSessionSubscriptions(sessionId, "/foo");
assertEquals("Wrong number of subscriptions " + subscriptions, 2, subscriptions.size());
assertTrue(subscriptions.contains("sub1"));
assertTrue(subscriptions.contains("sub2"));
}
@Test
public void getRegistrationsByDestination() {
SessionSubscriptionRegistration reg1 = this.registry.getOrCreateRegistration("sess1");
reg1.addSubscription("/foo", "sub1");
SessionSubscriptionRegistration reg2 = this.registry.getOrCreateRegistration("sess2");
reg2.addSubscription("/foo", "sub1");
Set<SessionSubscriptionRegistration> actual = this.registry.getRegistrationsByDestination("/foo");
assertEquals(2, actual.size());
assertTrue(actual.contains(reg1));
assertTrue(actual.contains(reg2));
reg1.removeSubscription("sub1");
actual = this.registry.getRegistrationsByDestination("/foo");
assertEquals("Invalid set of registrations " + actual, 1, actual.size());
assertTrue(actual.contains(reg2));
reg2.removeSubscription("sub1");
actual = this.registry.getRegistrationsByDestination("/foo");
assertNull("Unexpected registrations " + actual, actual);
}
}

View File

@ -1,35 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.support;
import org.springframework.web.messaging.SessionSubscriptionRegistry;
/**
* Test fixture for {@link CachingSessionSubscriptionRegistry}.
*
* @author Rossen Stoyanchev
*/
public class CachingSessionSubscriptionRegistryTests extends AbstractSessionSubscriptionRegistryTests {
@Override
protected SessionSubscriptionRegistry createSessionSubscriptionRegistry() {
return new CachingSessionSubscriptionRegistry(new DefaultSessionSubscriptionRegistry());
}
}

View File

@ -1,82 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.support;
import java.util.Set;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.*;
/**
* Test fixture for {@link DefaultSessionSubscriptionRegistration}.
*
* @author Rossen Stoyanchev
*/
public class DefaultSessionSubscriptionRegistrationTests {
private DefaultSessionSubscriptionRegistration registration;
@Before
public void setup() {
this.registration = new DefaultSessionSubscriptionRegistration("sess1");
}
@Test
public void addSubscriptions() {
this.registration.addSubscription("/foo", "sub1");
this.registration.addSubscription("/foo", "sub2");
this.registration.addSubscription("/bar", "sub3");
this.registration.addSubscription("/bar", "sub4");
assertSet(this.registration.getSubscriptionsByDestination("/foo"), 2, "sub1", "sub2");
assertSet(this.registration.getSubscriptionsByDestination("/bar"), 2, "sub3", "sub4");
assertSet(this.registration.getDestinations(), 2, "/foo", "/bar");
}
@Test
public void removeSubscriptions() {
this.registration.addSubscription("/foo", "sub1");
this.registration.addSubscription("/foo", "sub2");
this.registration.addSubscription("/bar", "sub3");
this.registration.addSubscription("/bar", "sub4");
assertEquals("/foo", this.registration.removeSubscription("sub1"));
assertEquals("/foo", this.registration.removeSubscription("sub2"));
assertNull(this.registration.getSubscriptionsByDestination("/foo"));
assertSet(this.registration.getDestinations(), 1, "/bar");
assertEquals("/bar", this.registration.removeSubscription("sub3"));
assertEquals("/bar", this.registration.removeSubscription("sub4"));
assertNull(this.registration.getSubscriptionsByDestination("/bar"));
assertSet(this.registration.getDestinations(), 0);
}
private void assertSet(Set<String> set, int size, String... elements) {
assertEquals("Wrong number of elements in " + set, size, set.size());
for (String element : elements) {
assertTrue("Set does not contain element " + element, set.contains(element));
}
}
}