Make use of enhanced MessageHeaderAccessor support

Mutate rather than re-create headers when decoding STOMP messages
before a message is sent on a message channel.

Use MessageBuilder.createMessage to ensure the fully prepared
MessageHeaders is used directly MessageHeaderAccessor instance.

Issue: SPR-11468
This commit is contained in:
Rossen Stoyanchev 2014-04-10 23:57:45 -04:00
parent 4867546aec
commit ae942ffdb8
22 changed files with 322 additions and 236 deletions

View File

@ -26,6 +26,7 @@ import org.springframework.messaging.MessagingException;
import org.springframework.messaging.converter.MessageConversionException; import org.springframework.messaging.converter.MessageConversionException;
import org.springframework.messaging.converter.MessageConverter; import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.converter.SimpleMessageConverter; import org.springframework.messaging.converter.SimpleMessageConverter;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
@ -130,6 +131,13 @@ public abstract class AbstractMessageSendingTemplate<D> implements MessageSendin
MessagePostProcessor postProcessor) throws MessagingException { MessagePostProcessor postProcessor) throws MessagingException {
headers = processHeadersToSend(headers); headers = processHeadersToSend(headers);
MessageHeaders messageHeaders;
if (headers != null && headers instanceof MessageHeaders) {
MessageHeaderAccessor.getAccessor()
}
MessageHeaders messageHeaders = (headers != null) ? new MessageHeaders(headers) : null; MessageHeaders messageHeaders = (headers != null) ? new MessageHeaders(headers) : null;
Message<?> message = this.converter.toMessage(payload, messageHeaders); Message<?> message = this.converter.toMessage(payload, messageHeaders);

View File

@ -26,7 +26,7 @@ import org.springframework.messaging.handler.annotation.Header;
import org.springframework.messaging.handler.annotation.Headers; import org.springframework.messaging.handler.annotation.Headers;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.ClassUtils; import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
/** /**
@ -42,7 +42,6 @@ import org.springframework.util.ReflectionUtils;
*/ */
public class HeadersMethodArgumentResolver implements HandlerMethodArgumentResolver { public class HeadersMethodArgumentResolver implements HandlerMethodArgumentResolver {
@Override @Override
public boolean supportsParameter(MethodParameter parameter) { public boolean supportsParameter(MethodParameter parameter) {
Class<?> paramType = parameter.getParameterType(); Class<?> paramType = parameter.getParameterType();
@ -60,15 +59,23 @@ public class HeadersMethodArgumentResolver implements HandlerMethodArgumentResol
return message.getHeaders(); return message.getHeaders();
} }
else if (MessageHeaderAccessor.class.equals(paramType)) { else if (MessageHeaderAccessor.class.equals(paramType)) {
return new MessageHeaderAccessor(message); MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
return (accessor != null ? accessor : new MessageHeaderAccessor(message));
} }
else if (MessageHeaderAccessor.class.isAssignableFrom(paramType)) { else if (MessageHeaderAccessor.class.isAssignableFrom(paramType)) {
Method factoryMethod = ClassUtils.getMethod(paramType, "wrap", Message.class); MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
return ReflectionUtils.invokeMethod(factoryMethod, null, message); if (accessor != null && paramType.isAssignableFrom(accessor.getClass())) {
return accessor;
}
else {
Method method = ReflectionUtils.findMethod(paramType, "wrap", Message.class);
Assert.notNull(method, "Cannot create accessor of type " + paramType + " for message " + message);
return ReflectionUtils.invokeMethod(method, null, message);
}
} }
else { else {
throw new IllegalStateException("Unexpected method parameter type " throw new IllegalStateException(
+ paramType + "in method " + parameter.getMethod() + ". " "Unexpected method parameter type " + paramType + "in method " + parameter.getMethod() + ". "
+ "@Headers method arguments must be assignable to java.util.Map."); + "@Headers method arguments must be assignable to java.util.Map.");
} }
} }

View File

@ -42,6 +42,7 @@ import org.springframework.messaging.handler.DestinationPatternsMessageCondition
import org.springframework.messaging.handler.HandlerMethod; import org.springframework.messaging.handler.HandlerMethod;
import org.springframework.messaging.handler.HandlerMethodSelector; import org.springframework.messaging.handler.HandlerMethodSelector;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@ -339,8 +340,9 @@ public abstract class AbstractMethodMessageHandler<T>
logger.debug("Handling message, lookupDestination=" + lookupDestination); logger.debug("Handling message, lookupDestination=" + lookupDestination);
} }
message = MessageBuilder.fromMessage(message).setHeader( MessageHeaderAccessor headerAccessor = MessageHeaderAccessor.getMutableAccessor(message);
DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, lookupDestination).build(); headerAccessor.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, lookupDestination);
message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders());
handleMessageInternal(message, lookupDestination); handleMessageInternal(message, lookupDestination);
} }

View File

@ -73,7 +73,7 @@ public class SimpMessageTypeMessageCondition extends AbstractMessageCondition<Si
@Override @Override
public SimpMessageTypeMessageCondition getMatchingCondition(Message<?> message) { public SimpMessageTypeMessageCondition getMatchingCondition(Message<?> message) {
Object actualMessageType = message.getHeaders().get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER); Object actualMessageType = SimpMessageHeaderAccessor.getMessageType(message.getHeaders());
if (actualMessageType == null) { if (actualMessageType == null) {
return null; return null;
} }
@ -83,7 +83,7 @@ public class SimpMessageTypeMessageCondition extends AbstractMessageCondition<Si
@Override @Override
public int compareTo(SimpMessageTypeMessageCondition other, Message<?> message) { public int compareTo(SimpMessageTypeMessageCondition other, Message<?> message) {
Object actualMessageType = message.getHeaders().get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER); Object actualMessageType = SimpMessageHeaderAccessor.getMessageType(message.getHeaders());
if (actualMessageType != null) { if (actualMessageType != null) {
if (actualMessageType.equals(this.getMessageType()) && actualMessageType.equals(other.getMessageType())) { if (actualMessageType.equals(this.getMessageType()) && actualMessageType.equals(other.getMessageType())) {
return 0; return 0;

View File

@ -108,8 +108,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
@Override @Override
public void send(Message<?> message) { public void send(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders());
String destination = headers.getDestination();
destination = (destination != null) ? destination : getRequiredDefaultDestination(); destination = (destination != null) ? destination : getRequiredDefaultDestination();
doSend(destination, message); doSend(destination, message);
} }
@ -118,10 +117,10 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
protected void doSend(String destination, Message<?> message) { protected void doSend(String destination, Message<?> message) {
Assert.notNull(destination, "Destination must not be null"); Assert.notNull(destination, "Destination must not be null");
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message);
headers.setDestination(destination); headerAccessor.setDestination(destination);
headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); headerAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders());
long timeout = this.sendTimeout; long timeout = this.sendTimeout;
boolean sent = (timeout >= 0) boolean sent = (timeout >= 0)

View File

@ -37,8 +37,7 @@ public class PrincipalMethodArgumentResolver implements HandlerMethodArgumentRes
@Override @Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception { public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders());
Principal user = headers.getUser();
if (user == null) { if (user == null) {
throw new MissingSessionUserException(message); throw new MissingSessionUserException(message);
} }

View File

@ -23,6 +23,7 @@ import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.core.MessagePostProcessor; import org.springframework.messaging.core.MessagePostProcessor;
import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.handler.DestinationPatternsMessageCondition;
import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.annotation.SendTo;
@ -113,29 +114,28 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
} }
@Override @Override
public void handleReturnValue(Object returnValue, MethodParameter returnType, Message<?> inputMessage) public void handleReturnValue(Object returnValue, MethodParameter returnType, Message<?> message)
throws Exception { throws Exception {
if (returnValue == null) { if (returnValue == null) {
return; return;
} }
SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(inputMessage); MessageHeaders headers = message.getHeaders();
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
String sessionId = inputHeaders.getSessionId();
MessagePostProcessor postProcessor = new SessionHeaderPostProcessor(sessionId); MessagePostProcessor postProcessor = new SessionHeaderPostProcessor(sessionId);
SendToUser sendToUser = returnType.getMethodAnnotation(SendToUser.class); SendToUser sendToUser = returnType.getMethodAnnotation(SendToUser.class);
if (sendToUser != null) { if (sendToUser != null) {
Principal principal = inputHeaders.getUser(); Principal principal = SimpMessageHeaderAccessor.getUser(headers);
if (principal == null) { if (principal == null) {
throw new MissingSessionUserException(inputMessage); throw new MissingSessionUserException(message);
} }
String userName = principal.getName(); String userName = principal.getName();
if (principal instanceof DestinationUserNameProvider) { if (principal instanceof DestinationUserNameProvider) {
userName = ((DestinationUserNameProvider) principal).getDestinationUserName(); userName = ((DestinationUserNameProvider) principal).getDestinationUserName();
} }
String[] destinations = getTargetDestinations(sendToUser, inputHeaders, this.defaultUserDestinationPrefix); String[] destinations = getTargetDestinations(sendToUser, message, this.defaultUserDestinationPrefix);
for (String destination : destinations) { for (String destination : destinations) {
this.messagingTemplate.convertAndSendToUser(userName, destination, returnValue, postProcessor); this.messagingTemplate.convertAndSendToUser(userName, destination, returnValue, postProcessor);
} }
@ -143,15 +143,14 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
} }
else { else {
SendTo sendTo = returnType.getMethodAnnotation(SendTo.class); SendTo sendTo = returnType.getMethodAnnotation(SendTo.class);
String[] destinations = getTargetDestinations(sendTo, inputHeaders, this.defaultDestinationPrefix); String[] destinations = getTargetDestinations(sendTo, message, this.defaultDestinationPrefix);
for (String destination : destinations) { for (String destination : destinations) {
this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor); this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor);
} }
} }
} }
protected String[] getTargetDestinations(Annotation annot, SimpMessageHeaderAccessor inputHeaders, protected String[] getTargetDestinations(Annotation annot, Message<?> message, String defaultPrefix) {
String defaultPrefix) {
if (annot != null) { if (annot != null) {
String[] value = (String[]) AnnotationUtils.getValue(annot); String[] value = (String[]) AnnotationUtils.getValue(annot);
@ -159,8 +158,8 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
return value; return value;
} }
} }
return new String[] { defaultPrefix + String name = DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER;
inputHeaders.getHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER) }; return new String[] { defaultPrefix + message.getHeaders().get(name) };
} }
@ -176,7 +175,7 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
public Message<?> postProcessMessage(Message<?> message) { public Message<?> postProcessMessage(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setSessionId(this.sessionId); headers.setSessionId(this.sessionId);
return MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); return MessageBuilder.createMessage(message.getPayload(), headers.getMessageHeaders());
} }
} }

View File

@ -63,6 +63,7 @@ import org.springframework.stereotype.Controller;
import org.springframework.util.AntPathMatcher; import org.springframework.util.AntPathMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.PathMatcher; import org.springframework.util.PathMatcher;
import org.springframework.validation.Errors; import org.springframework.validation.Errors;
import org.springframework.validation.Validator; import org.springframework.validation.Validator;
@ -329,7 +330,7 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan
@Override @Override
protected String getDestination(Message<?> message) { protected String getDestination(Message<?> message) {
return (String) message.getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER); return (String) SimpMessageHeaderAccessor.getDestination(message.getHeaders());
} }
@Override @Override
@ -352,13 +353,15 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan
protected void handleMatch(SimpMessageMappingInfo mapping, HandlerMethod handlerMethod, protected void handleMatch(SimpMessageMappingInfo mapping, HandlerMethod handlerMethod,
String lookupDestination, Message<?> message) { String lookupDestination, Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
String matchedPattern = mapping.getDestinationConditions().getPatterns().iterator().next(); String matchedPattern = mapping.getDestinationConditions().getPatterns().iterator().next();
Map<String, String> vars = getPathMatcher().extractUriTemplateVariables(matchedPattern, lookupDestination); Map<String, String> vars = getPathMatcher().extractUriTemplateVariables(matchedPattern, lookupDestination);
headers.setHeader(DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER, vars); if (!CollectionUtils.isEmpty(vars)) {
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setHeader(DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER, vars);
message = MessageBuilder.createMessage(message.getPayload(), headers.getMessageHeaders());
}
super.handleMatch(mapping, handlerMethod, lookupDestination, message); super.handleMatch(mapping, handlerMethod, lookupDestination, message);
} }

View File

@ -18,6 +18,7 @@ package org.springframework.messaging.simp.annotation.support;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.core.MessagePostProcessor; import org.springframework.messaging.core.MessagePostProcessor;
import org.springframework.messaging.core.MessageSendingOperations; import org.springframework.messaging.core.MessageSendingOperations;
import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.annotation.SendTo;
@ -71,13 +72,12 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
return; return;
} }
SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(message); MessageHeaders headers = message.getHeaders();
String sessionId = inputHeaders.getSessionId(); String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
String subscriptionId = inputHeaders.getSubscriptionId(); String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers);
String destination = inputHeaders.getDestination(); String destination = SimpMessageHeaderAccessor.getDestination(headers);
Assert.state(inputHeaders.getSubscriptionId() != null, Assert.state(subscriptionId != null, "No subsriptiondId in input message to method " + returnType.getMethod());
"No subsriptiondId in input message to method " + returnType.getMethod());
MessagePostProcessor postProcessor = new SubscriptionHeaderPostProcessor(sessionId, subscriptionId); MessagePostProcessor postProcessor = new SubscriptionHeaderPostProcessor(sessionId, subscriptionId);
this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor); this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor);
@ -98,11 +98,11 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
@Override @Override
public Message<?> postProcessMessage(Message<?> message) { public Message<?> postProcessMessage(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message);
headers.setSessionId(this.sessionId); headerAccessor.setSessionId(this.sessionId);
headers.setSubscriptionId(this.subscriptionId); headerAccessor.setSubscriptionId(this.subscriptionId);
headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); headerAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
return MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); return MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders());
} }
} }
} }

View File

@ -19,8 +19,10 @@ package org.springframework.messaging.simp.broker;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
/** /**
@ -38,29 +40,31 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist
@Override @Override
public final void registerSubscription(Message<?> message) { public final void registerSubscription(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
if (!SimpMessageType.SUBSCRIBE.equals(headers.getMessageType())) { MessageHeaders headers = message.getHeaders();
SimpMessageType type = SimpMessageHeaderAccessor.getMessageType(headers);
if (!SimpMessageType.SUBSCRIBE.equals(type)) {
logger.error("Expected SUBSCRIBE message: " + message); logger.error("Expected SUBSCRIBE message: " + message);
return; return;
} }
String sessionId = headers.getSessionId(); String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
if (sessionId == null) { if (sessionId == null) {
logger.error("Ignoring subscription. No sessionId in message: " + message); logger.error("Ignoring subscription. No sessionId in message: " + message);
return; return;
} }
String subscriptionId = headers.getSubscriptionId(); String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers);
if (subscriptionId == null) { if (subscriptionId == null) {
logger.error("Ignoring subscription. No subscriptionId in message: " + message); logger.error("Ignoring subscription. No subscriptionId in message: " + message);
return; return;
} }
String destination = headers.getDestination(); String destination = SimpMessageHeaderAccessor.getDestination(headers);
if (destination == null) { if (destination == null) {
logger.error("Ignoring destination. No destination in message: " + message); logger.error("Ignoring destination. No destination in message: " + message);
return; return;
} }
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Adding subscription id=" + headers.getSubscriptionId() logger.debug("Adding subscription id=" + subscriptionId + ", destination=" + destination);
+ ", destination=" + headers.getDestination());
} }
addSubscriptionInternal(sessionId, subscriptionId, destination, message); addSubscriptionInternal(sessionId, subscriptionId, destination, message);
} }
@ -70,17 +74,20 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist
@Override @Override
public final void unregisterSubscription(Message<?> message) { public final void unregisterSubscription(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
if (!SimpMessageType.UNSUBSCRIBE.equals(headers.getMessageType())) { MessageHeaders headers = message.getHeaders();
SimpMessageType type = SimpMessageHeaderAccessor.getMessageType(headers);
if (!SimpMessageType.UNSUBSCRIBE.equals(type)) {
logger.error("Expected UNSUBSCRIBE message: " + message); logger.error("Expected UNSUBSCRIBE message: " + message);
return; return;
} }
String sessionId = headers.getSessionId(); String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
if (sessionId == null) { if (sessionId == null) {
logger.error("Ignoring subscription. No sessionId in message: " + message); logger.error("Ignoring subscription. No sessionId in message: " + message);
return; return;
} }
String subscriptionId = headers.getSubscriptionId(); String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers);
if (subscriptionId == null) { if (subscriptionId == null) {
logger.error("Ignoring subscription. No subscriptionId in message: " + message); logger.error("Ignoring subscription. No subscriptionId in message: " + message);
return; return;
@ -98,19 +105,22 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist
@Override @Override
public final MultiValueMap<String, String> findSubscriptions(Message<?> message) { public final MultiValueMap<String, String> findSubscriptions(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
if (!SimpMessageType.MESSAGE.equals(headers.getMessageType())) { MessageHeaders headers = message.getHeaders();
logger.trace("Ignoring message type " + headers.getMessageType()); SimpMessageType type = SimpMessageHeaderAccessor.getMessageType(headers);
if (!SimpMessageType.MESSAGE.equals(type)) {
logger.trace("Ignoring message type " + type);
return null; return null;
} }
String destination = headers.getDestination(); String destination = SimpMessageHeaderAccessor.getDestination(headers);
if (destination == null) { if (destination == null) {
logger.trace("Ignoring message, no destination"); logger.trace("Ignoring message, no destination");
return null; return null;
} }
MultiValueMap<String, String> result = findSubscriptionsInternal(destination, message); MultiValueMap<String, String> result = findSubscriptionsInternal(destination, message);
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Found " + result.size() + " subscriptions for destination=" + headers.getDestination()); logger.trace("Found " + result.size() + " subscriptions for destination=" + destination);
} }
return result; return result;
} }

View File

@ -20,6 +20,7 @@ import java.util.Collection;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.SimpMessageType;
@ -111,9 +112,10 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
@Override @Override
protected void handleMessageInternal(Message<?> message) { protected void handleMessageInternal(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); MessageHeaders headers = message.getHeaders();
SimpMessageType messageType = headers.getMessageType(); SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
String destination = headers.getDestination(); String destination = SimpMessageHeaderAccessor.getDestination(headers);
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
if (!checkDestinationPrefix(destination)) { if (!checkDestinationPrefix(destination)) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
@ -122,27 +124,30 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
return; return;
} }
if (SimpMessageType.SUBSCRIBE.equals(messageType)) { if (SimpMessageType.MESSAGE.equals(messageType)) {
sendMessageToSubscribers(destination, message);
}
else if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
this.subscriptionRegistry.registerSubscription(message); this.subscriptionRegistry.registerSubscription(message);
} }
else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) { else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
this.subscriptionRegistry.unregisterSubscription(message); this.subscriptionRegistry.unregisterSubscription(message);
} }
else if (SimpMessageType.MESSAGE.equals(messageType)) {
sendMessageToSubscribers(headers.getDestination(), message);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) { else if (SimpMessageType.DISCONNECT.equals(messageType)) {
String sessionId = headers.getSessionId(); this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
} }
else if (SimpMessageType.CONNECT.equals(messageType)) { else if (SimpMessageType.CONNECT.equals(messageType)) {
SimpMessageHeaderAccessor replyHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
replyHeaders.setSessionId(headers.getSessionId()); accessor.setSessionId(sessionId);
replyHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message); accessor.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message);
Message<byte[]> connectAck = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
Message<byte[]> connectAck = MessageBuilder.withPayload(EMPTY_PAYLOAD).setHeaders(replyHeaders).build();
this.clientOutboundChannel.send(connectAck); this.clientOutboundChannel.send(connectAck);
} }
else {
if (logger.isTraceEnabled()) {
logger.trace("Message type not supported. Ignoring: " + message);
}
}
} }
protected void sendMessageToSubscribers(String destination, Message<?> message) { protected void sendMessageToSubscribers(String destination, Message<?> message) {
@ -153,17 +158,17 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
} }
for (String sessionId : subscriptions.keySet()) { for (String sessionId : subscriptions.keySet()) {
for (String subscriptionId : subscriptions.get(sessionId)) { for (String subscriptionId : subscriptions.get(sessionId)) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
headers.setSessionId(sessionId); headerAccessor.setSessionId(sessionId);
headers.setSubscriptionId(subscriptionId); headerAccessor.setSubscriptionId(subscriptionId);
headerAccessor.copyHeadersIfAbsent(message.getHeaders());
Object payload = message.getPayload(); Object payload = message.getPayload();
Message<?> clientMessage = MessageBuilder.withPayload(payload).setHeaders(headers).build(); Message<?> reply = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
try { try {
this.clientOutboundChannel.send(clientMessage); this.clientOutboundChannel.send(reply);
} }
catch (Throwable ex) { catch (Throwable ex) {
logger.error("Failed to send message to destination=" + destination + logger.error("Failed to send message=" + message, ex);
", sessionId=" + sessionId + ", subscriptionId=" + subscriptionId, ex);
} }
} }
} }

View File

@ -20,7 +20,6 @@ import java.util.Collection;
import java.util.Map; import java.util.Map;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageChannel;
@ -30,6 +29,7 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler; import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.tcp.FixedIntervalReconnectStrategy; import org.springframework.messaging.tcp.FixedIntervalReconnectStrategy;
import org.springframework.messaging.tcp.TcpConnection; import org.springframework.messaging.tcp.TcpConnection;
import org.springframework.messaging.tcp.TcpConnectionHandler; import org.springframework.messaging.tcp.TcpConnectionHandler;
@ -79,9 +79,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
private static final Message<byte[]> HEARTBEAT_MESSAGE; private static final Message<byte[]> HEARTBEAT_MESSAGE;
static { static {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
HEARTBEAT_MESSAGE = MessageBuilder.withPayload(new byte[] {'\n'}).setHeaders(headers).build();
EMPTY_TASK.run(); EMPTY_TASK.run();
StompHeaderAccessor accessor = StompHeaderAccessor.createForHeartbeat();
HEARTBEAT_MESSAGE = MessageBuilder.createMessage(StompDecoder.HEARTBEAT_PAYLOAD, accessor.getMessageHeaders());
} }
@ -370,37 +370,53 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
@Override @Override
protected void handleMessageInternal(Message<?> message) { protected void handleMessageInternal(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); String sessionId = SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
String sessionId = headers.getSessionId();
if (!isBrokerAvailable()) { if (!isBrokerAvailable()) {
if (sessionId == null || sessionId == SystemStompConnectionHandler.SESSION_ID) { if (sessionId == null || SystemStompConnectionHandler.SESSION_ID.equals(sessionId)) {
throw new MessageDeliveryException("Message broker is not active."); throw new MessageDeliveryException("Message broker is not active.");
} }
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Message broker is not active. Ignoring message id=" + message.getHeaders().getId()); logger.trace("Message broker is not active. Ignoring: " + message);
} }
return; return;
} }
String destination = headers.getDestination(); StompHeaderAccessor stompAccessor;
StompCommand command = headers.getCommand(); StompCommand command;
SimpMessageType messageType = headers.getMessageType();
if (SimpMessageType.MESSAGE.equals(messageType)) { MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
sessionId = (sessionId == null) ? SystemStompConnectionHandler.SESSION_ID : sessionId; if (accessor == null) {
headers.setSessionId(sessionId); logger.error("No header accessor, please use SimpMessagingTemplate. Ignoring: " + message);
command = headers.updateStompCommandAsClientMessage(); return;
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); }
else if (accessor instanceof StompHeaderAccessor) {
stompAccessor = (StompHeaderAccessor) accessor;
command = stompAccessor.getCommand();
}
else if (accessor instanceof SimpMessageHeaderAccessor) {
stompAccessor = StompHeaderAccessor.wrap(message);
command = stompAccessor.getCommand();
if (command == null) {
command = stompAccessor.updateStompCommandAsClientMessage();
}
}
else {
// Should not happen
logger.error("Unexpected header accessor type: " + accessor + ". Ignoring: " + message);
return;
} }
if (sessionId == null) { if (sessionId == null) {
if (logger.isWarnEnabled()) { if (!SimpMessageType.MESSAGE.equals(stompAccessor.getMessageType())) {
logger.warn("No sessionId, ignoring message: " + message); logger.error("Only STOMP SEND frames supported on \"system\" connection. Ignoring: " + message);
return;
} }
return; sessionId = SystemStompConnectionHandler.SESSION_ID;
stompAccessor.setSessionId(sessionId);
} }
String destination = stompAccessor.getDestination();
if ((command != null) && command.requiresDestination() && !checkDestinationPrefix(destination)) { if ((command != null) && command.requiresDestination() && !checkDestinationPrefix(destination)) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Ignoring message to destination=" + destination); logger.trace("Ignoring message to destination=" + destination);
@ -412,20 +428,21 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
logger.trace("Processing message=" + message); logger.trace("Processing message=" + message);
} }
if (SimpMessageType.CONNECT.equals(messageType)) { if (StompCommand.CONNECT.equals(command)) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Processing CONNECT (total connected=" + this.connectionHandlers.size() + ")"); logger.debug("Processing CONNECT (total connected=" + this.connectionHandlers.size() + ")");
} }
headers.setLogin(this.clientLogin); stompAccessor = (stompAccessor.isMutable() ? stompAccessor : StompHeaderAccessor.wrap(message));
headers.setPasscode(this.clientPasscode); stompAccessor.setLogin(this.clientLogin);
stompAccessor.setPasscode(this.clientPasscode);
if (getVirtualHost() != null) { if (getVirtualHost() != null) {
headers.setHost(getVirtualHost()); stompAccessor.setHost(getVirtualHost());
} }
StompConnectionHandler handler = new StompConnectionHandler(sessionId, headers); StompConnectionHandler handler = new StompConnectionHandler(sessionId, stompAccessor);
this.connectionHandlers.put(sessionId, handler); this.connectionHandlers.put(sessionId, handler);
this.tcpClient.connect(handler); this.tcpClient.connect(handler);
} }
else if (SimpMessageType.DISCONNECT.equals(messageType)) { else if (StompCommand.DISCONNECT.equals(command)) {
StompConnectionHandler handler = this.connectionHandlers.get(sessionId); StompConnectionHandler handler = this.connectionHandlers.get(sessionId);
if (handler == null) { if (handler == null) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
@ -433,7 +450,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
} }
return; return;
} }
handler.forward(message); handler.forward(message, stompAccessor);
} }
else { else {
StompConnectionHandler handler = this.connectionHandlers.get(sessionId); StompConnectionHandler handler = this.connectionHandlers.get(sessionId);
@ -443,7 +460,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
} }
return; return;
} }
handler.forward(message); handler.forward(message, stompAccessor);
} }
} }
@ -486,7 +503,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
logger.debug("Established TCP connection to broker in session '" + this.sessionId + "'"); logger.debug("Established TCP connection to broker in session '" + this.sessionId + "'");
} }
this.tcpConnection = connection; this.tcpConnection = connection;
connection.send(MessageBuilder.withPayload(EMPTY_PAYLOAD).setHeaders(this.connectHeaders).build()); connection.send(MessageBuilder.createMessage(EMPTY_PAYLOAD, this.connectHeaders.getMessageHeaders()));
} }
@Override @Override
@ -522,7 +539,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setSessionId(this.sessionId); headers.setSessionId(this.sessionId);
headers.setMessage(errorText); headers.setMessage(errorText);
Message<?> errorMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<?> errorMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
sendMessageToClient(errorMessage); sendMessageToClient(errorMessage);
} }
} }
@ -536,20 +553,23 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
@Override @Override
public void handleMessage(Message<byte[]> message) { public void handleMessage(Message<byte[]> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); StompHeaderAccessor headerAccessor =
if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) { MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
if (headerAccessor.isHeartbeat()) {
logger.trace("Received broker heartbeat"); logger.trace("Received broker heartbeat");
} }
else if (logger.isDebugEnabled()) { else if (logger.isDebugEnabled()) {
logger.debug("Received message from broker in session '" + this.sessionId + "'"); logger.debug("Received message from broker in session '" + this.sessionId + "'");
} }
if (StompCommand.CONNECTED == headers.getCommand()) { if (StompCommand.CONNECTED == headerAccessor.getCommand()) {
afterStompConnected(headers); afterStompConnected(headerAccessor);
} }
headers.setSessionId(this.sessionId); headerAccessor.setSessionId(this.sessionId);
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); headerAccessor.setImmutable();
sendMessageToClient(message); sendMessageToClient(message);
} }
@ -630,9 +650,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
clearConnection(); clearConnection();
} }
catch (Throwable t) { catch (Throwable t) {
if (logger.isErrorEnabled()) { // Ignore
// Ignore
}
} }
} }
} }
@ -661,7 +679,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
* @return a future to wait for the result * @return a future to wait for the result
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public ListenableFuture<Void> forward(final Message<?> message) { public ListenableFuture<Void> forward(Message<?> message, final StompHeaderAccessor headerAccessor) {
TcpConnection<byte[]> conn = this.tcpConnection; TcpConnection<byte[]> conn = this.tcpConnection;
@ -682,8 +700,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
} }
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); if (headerAccessor.isHeartbeat()) {
if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) {
logger.trace("Forwarding heartbeat to broker"); logger.trace("Forwarding heartbeat to broker");
} }
else { else {
@ -691,13 +708,16 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
} }
} }
if (headerAccessor.isMutable() && headerAccessor.isModified()) {
message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders());
}
ListenableFuture<Void> future = conn.send((Message<byte[]>) message); ListenableFuture<Void> future = conn.send((Message<byte[]>) message);
future.addCallback(new ListenableFutureCallback<Void>() { future.addCallback(new ListenableFutureCallback<Void>() {
@Override @Override
public void onSuccess(Void result) { public void onSuccess(Void result) {
StompCommand command = StompHeaderAccessor.wrap(message).getCommand(); if (headerAccessor.getCommand() == StompCommand.DISCONNECT) {
if (command == StompCommand.DISCONNECT) {
clearConnection(); clearConnection();
} }
} }
@ -707,7 +727,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
// already reset // already reset
} }
else { else {
handleTcpConnectionFailure("Failed to send message " + message, t); handleTcpConnectionFailure("Failed to send message " + headerAccessor, t);
} }
} }
}); });
@ -777,9 +797,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
} }
@Override @Override
public ListenableFuture<Void> forward(Message<?> message) { public ListenableFuture<Void> forward(Message<?> message, StompHeaderAccessor headerAccessor) {
try { try {
ListenableFuture<Void> future = super.forward(message); ListenableFuture<Void> future = super.forward(message, headerAccessor);
future.get(); future.get();
return future; return future;
} }

View File

@ -19,6 +19,7 @@ package org.springframework.messaging.simp.user;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -100,34 +101,34 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
@Override @Override
public UserDestinationResult resolveDestination(Message<?> message) { public UserDestinationResult resolveDestination(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders());
DestinationInfo info = parseUserDestination(headers); DestinationInfo info = parseUserDestination(message);
if (info == null) { if (info == null) {
return null; return null;
} }
Set<String> targetDestinations = new HashSet<String>(); Set<String> targetDestinations = new HashSet<String>();
for (String sessionId : info.getSessionIds()) { for (String sessionId : info.getSessionIds()) {
targetDestinations.add(getTargetDestination( targetDestinations.add(getTargetDestination(destination,
headers.getDestination(), info.getDestinationWithoutPrefix(), sessionId, info.getUser())); info.getDestinationWithoutPrefix(), sessionId, info.getUser()));
} }
return new UserDestinationResult(headers.getDestination(), return new UserDestinationResult(destination,
targetDestinations, info.getSubscribeDestination(), info.getUser()); targetDestinations, info.getSubscribeDestination(), info.getUser());
} }
private DestinationInfo parseUserDestination(SimpMessageHeaderAccessor headers) { private DestinationInfo parseUserDestination(Message<?> message) {
String destination = headers.getDestination(); MessageHeaders headers = message.getHeaders();
SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
String destination = SimpMessageHeaderAccessor.getDestination(headers);
Principal principal = SimpMessageHeaderAccessor.getUser(headers);
String destinationWithoutPrefix; String destinationWithoutPrefix;
String subscribeDestination; String subscribeDestination;
String user; String user;
Set<String> sessionIds; Set<String> sessionIds;
Principal principal = headers.getUser();
SimpMessageType messageType = headers.getMessageType();
if (SimpMessageType.SUBSCRIBE.equals(messageType) || SimpMessageType.UNSUBSCRIBE.equals(messageType)) { if (SimpMessageType.SUBSCRIBE.equals(messageType) || SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
if (!checkDestination(destination, this.destinationPrefix)) { if (!checkDestination(destination, this.destinationPrefix)) {
return null; return null;
@ -136,14 +137,15 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
logger.error("Ignoring message, no principal info available"); logger.error("Ignoring message, no principal info available");
return null; return null;
} }
if (headers.getSessionId() == null) { String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
if (sessionId == null) {
logger.error("Ignoring message, no session id available"); logger.error("Ignoring message, no session id available");
return null; return null;
} }
destinationWithoutPrefix = destination.substring(this.destinationPrefix.length()-1); destinationWithoutPrefix = destination.substring(this.destinationPrefix.length()-1);
subscribeDestination = destination; subscribeDestination = destination;
user = principal.getName(); user = principal.getName();
sessionIds = Collections.singleton(headers.getSessionId()); sessionIds = Collections.singleton(sessionId);
} }
else if (SimpMessageType.MESSAGE.equals(messageType)) { else if (SimpMessageType.MESSAGE.equals(messageType)) {
if (!checkDestination(destination, this.destinationPrefix)) { if (!checkDestination(destination, this.destinationPrefix)) {

View File

@ -152,16 +152,16 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
if (destinations.isEmpty()) { if (destinations.isEmpty()) {
return; return;
} }
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message); if (SimpMessageType.MESSAGE.equals(SimpMessageHeaderAccessor.getMessageType(message.getHeaders()))) {
if (SimpMessageType.MESSAGE.equals(headerAccessor.getMessageType())) { SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message);
headerAccessor.setHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION, result.getSubscribeDestination()); headerAccessor.setHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION, result.getSubscribeDestination());
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headerAccessor).build(); message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders());
} }
for (String targetDestination : destinations) { for (String destination : destinations) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Sending message to resolved destination=" + targetDestination); logger.debug("Sending message to resolved destination=" + destination);
} }
this.brokerMessagingTemplate.send(targetDestination, message); this.brokerMessagingTemplate.send(destination, message);
} }
} }

View File

@ -112,10 +112,12 @@ public class MessageConverterTests {
public void toMessageHeadersCopied() { public void toMessageHeadersCopied() {
Map<String, Object> map = new HashMap<String, Object>(); Map<String, Object> map = new HashMap<String, Object>();
map.put("foo", "bar"); map.put("foo", "bar");
MessageHeaders headers = new MessageHeaders(map ); MessageHeaders headers = new MessageHeaders(map);
Message<?> message = this.converter.toMessage("ABC", headers); Message<?> message = this.converter.toMessage("ABC", headers);
assertEquals("bar", message.getHeaders().get("foo")); assertEquals("bar", message.getHeaders().get("foo"));
assertNotNull(message.getHeaders().getId());
assertNotNull(message.getHeaders().getTimestamp());
} }
@Test @Test

View File

@ -31,7 +31,6 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.context.support.StaticApplicationContext; import org.springframework.context.support.StaticApplicationContext;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.converter.*; import org.springframework.messaging.converter.*;
import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.MessageMapping;

View File

@ -38,6 +38,7 @@ import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException; import org.springframework.messaging.MessagingException;
import org.springframework.messaging.StubMessageChannel; import org.springframework.messaging.StubMessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.broker.BrokerAvailabilityEvent; import org.springframework.messaging.simp.broker.BrokerAvailabilityEvent;
import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.ExecutorSubscribableChannel; import org.springframework.messaging.support.ExecutorSubscribableChannel;
@ -168,7 +169,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
public void messageDeliverExceptionIfSystemSessionForwardFails() throws Exception { public void messageDeliverExceptionIfSystemSessionForwardFails() throws Exception {
stopActiveMqBrokerAndAwait(); stopActiveMqBrokerAndAwait();
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
this.relay.handleMessage(MessageBuilder.withPayload("test".getBytes()).setHeaders(headers).build()); this.relay.handleMessage(MessageBuilder.createMessage("test".getBytes(), headers.getMessageHeaders()));
} }
@Test @Test
@ -244,7 +245,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId("sess1"); headers.setSessionId("sess1");
this.relay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build()); this.relay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()));
Thread.sleep(2000); Thread.sleep(2000);
@ -394,7 +395,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
headers.setSessionId(sessionId); headers.setSessionId(sessionId);
headers.setAcceptVersion("1.1,1.2"); headers.setAcceptVersion("1.1,1.2");
headers.setHeartbeat(0, 0); headers.setHeartbeat(0, 0);
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<?> message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
MessageExchangeBuilder builder = new MessageExchangeBuilder(message); MessageExchangeBuilder builder = new MessageExchangeBuilder(message);
builder.expected.add(new StompConnectedFrameMessageMatcher(sessionId)); builder.expected.add(new StompConnectedFrameMessageMatcher(sessionId));
@ -405,7 +406,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setSessionId(sessionId); headers.setSessionId(sessionId);
headers.setAcceptVersion("1.1,1.2"); headers.setAcceptVersion("1.1,1.2");
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<?> message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
MessageExchangeBuilder builder = new MessageExchangeBuilder(message); MessageExchangeBuilder builder = new MessageExchangeBuilder(message);
return builder.andExpectError(); return builder.andExpectError();
} }
@ -418,7 +419,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
headers.setSubscriptionId(subscriptionId); headers.setSubscriptionId(subscriptionId);
headers.setDestination(destination); headers.setDestination(destination);
headers.setReceipt(receiptId); headers.setReceipt(receiptId);
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<?> message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
MessageExchangeBuilder builder = new MessageExchangeBuilder(message); MessageExchangeBuilder builder = new MessageExchangeBuilder(message);
builder.expected.add(new StompReceiptFrameMessageMatcher(sessionId, receiptId)); builder.expected.add(new StompReceiptFrameMessageMatcher(sessionId, receiptId));
@ -426,14 +427,14 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
} }
public static MessageExchangeBuilder send(String destination, String payload) { public static MessageExchangeBuilder send(String destination, String payload) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
headers.setDestination(destination); headers.setDestination(destination);
Message<?> message = MessageBuilder.withPayload(payload.getBytes(UTF_8)).setHeaders(headers).build(); Message<?> message = MessageBuilder.createMessage(payload.getBytes(UTF_8), headers.getMessageHeaders());
return new MessageExchangeBuilder(message); return new MessageExchangeBuilder(message);
} }
public MessageExchangeBuilder andExpectMessage(String sessionId, String subscriptionId) { public MessageExchangeBuilder andExpectMessage(String sessionId, String subscriptionId) {
Assert.isTrue(StompCommand.SEND.equals(headers.getCommand()), "MESSAGE can only be expected after SEND"); Assert.isTrue(SimpMessageType.MESSAGE.equals(headers.getMessageType()));
String destination = this.headers.getDestination(); String destination = this.headers.getDestination();
Object payload = this.message.getPayload(); Object payload = this.message.getPayload();
this.expected.add(new StompMessageFrameMessageMatcher(sessionId, subscriptionId, destination, payload)); this.expected.add(new StompMessageFrameMessageMatcher(sessionId, subscriptionId, destination, payload));

View File

@ -27,6 +27,7 @@ import org.springframework.messaging.StubMessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.tcp.ReconnectStrategy; import org.springframework.messaging.tcp.ReconnectStrategy;
import org.springframework.messaging.tcp.TcpConnection; import org.springframework.messaging.tcp.TcpConnection;
import org.springframework.messaging.tcp.TcpConnectionHandler; import org.springframework.messaging.tcp.TcpConnectionHandler;
@ -77,17 +78,21 @@ public class StompBrokerRelayMessageHandlerTests {
String sessionId = "sess1"; String sessionId = "sess1";
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setSessionId(sessionId); headers.setSessionId(sessionId);
this.brokerRelay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build()); this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()));
List<Message<byte[]>> sent = this.tcpClient.connection.messages; List<Message<byte[]>> sent = this.tcpClient.connection.messages;
assertEquals(2, sent.size()); assertEquals(2, sent.size());
StompHeaderAccessor headers1 = StompHeaderAccessor.wrap(sent.get(0)); StompHeaderAccessor headers1 = StompHeaderAccessor.wrap(sent.get(0));
assertEquals(virtualHost, headers1.getHost()); assertEquals(virtualHost, headers1.getHost());
assertNotNull("The prepared message does not have an accessor",
MessageHeaderAccessor.getAccessor(sent.get(0), MessageHeaderAccessor.class));
StompHeaderAccessor headers2 = StompHeaderAccessor.wrap(sent.get(1)); StompHeaderAccessor headers2 = StompHeaderAccessor.wrap(sent.get(1));
assertEquals(sessionId, headers2.getSessionId()); assertEquals(sessionId, headers2.getSessionId());
assertEquals(virtualHost, headers2.getHost()); assertEquals(virtualHost, headers2.getHost());
assertNotNull("The prepared message does not have an accessor",
MessageHeaderAccessor.getAccessor(sent.get(1), MessageHeaderAccessor.class));
} }
@Test @Test
@ -104,7 +109,7 @@ public class StompBrokerRelayMessageHandlerTests {
String sessionId = "sess1"; String sessionId = "sess1";
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setSessionId(sessionId); headers.setSessionId(sessionId);
this.brokerRelay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build()); this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()));
List<Message<byte[]>> sent = this.tcpClient.connection.messages; List<Message<byte[]>> sent = this.tcpClient.connection.messages;
assertEquals(2, sent.size()); assertEquals(2, sent.size());
@ -126,11 +131,13 @@ public class StompBrokerRelayMessageHandlerTests {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
headers.setSessionId("sess1"); headers.setSessionId("sess1");
headers.setDestination("/user/daisy/foo"); headers.setDestination("/user/daisy/foo");
this.brokerRelay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build()); this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()));
List<Message<byte[]>> sent = this.tcpClient.connection.messages; List<Message<byte[]>> sent = this.tcpClient.connection.messages;
assertEquals(1, sent.size()); assertEquals(1, sent.size());
assertEquals(StompCommand.CONNECT, StompHeaderAccessor.wrap(sent.get(0)).getCommand()); assertEquals(StompCommand.CONNECT, StompHeaderAccessor.wrap(sent.get(0)).getCommand());
assertNotNull("The prepared message does not have an accessor",
MessageHeaderAccessor.getAccessor(sent.get(0), MessageHeaderAccessor.class));
} }

View File

@ -66,8 +66,7 @@ public class UserDestinationMessageHandlerTests {
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture()); Mockito.verify(this.brokerChannel).send(captor.capture());
assertEquals("/queue/foo-user123", assertEquals("/queue/foo-user123", SimpMessageHeaderAccessor.getDestination(captor.getValue().getHeaders()));
captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER));
} }
@Test @Test
@ -79,8 +78,7 @@ public class UserDestinationMessageHandlerTests {
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture()); Mockito.verify(this.brokerChannel).send(captor.capture());
assertEquals("/queue/foo-user123", assertEquals("/queue/foo-user123", SimpMessageHeaderAccessor.getDestination(captor.getValue().getHeaders()));
captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER));
} }
@Test @Test
@ -93,10 +91,8 @@ public class UserDestinationMessageHandlerTests {
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class); ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture()); Mockito.verify(this.brokerChannel).send(captor.capture());
assertEquals("/queue/foo-user123", assertEquals("/queue/foo-user123", SimpMessageHeaderAccessor.getDestination(captor.getValue().getHeaders()));
captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER)); assertEquals("/user/queue/foo", captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION));
assertEquals("/user/queue/foo",
captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION));
} }

View File

@ -41,9 +41,9 @@ import org.springframework.messaging.simp.stomp.StompConversionException;
import org.springframework.messaging.simp.stomp.StompEncoder; import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider; import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
@ -79,6 +79,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class); private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class);
private static final byte[] EMPTY_PAYLOAD = new byte[0];
private int messageSizeLimit = 64 * 1024; private int messageSizeLimit = 64 * 1024;
@ -172,9 +174,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
for (Message<byte[]> message : messages) { for (Message<byte[]> message : messages) {
try { try {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
StompHeaderAccessor headerAccessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) { if (headerAccessor.isHeartbeat()) {
logger.trace("Received heartbeat from client session=" + session.getId()); logger.trace("Received heartbeat from client session=" + session.getId());
} }
else { else {
@ -182,13 +187,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
} }
} }
headers.setSessionId(session.getId()); headerAccessor.setSessionId(session.getId());
headers.setSessionAttributes(session.getAttributes()); headerAccessor.setSessionAttributes(session.getAttributes());
headers.setUser(session.getPrincipal()); headerAccessor.setUser(session.getPrincipal());
headerAccessor.setImmutable();
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); if (this.eventPublisher != null && StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
if (this.eventPublisher != null && StompCommand.CONNECT.equals(headers.getCommand())) {
publishEvent(new SessionConnectEvent(this, message)); publishEvent(new SessionConnectEvent(this, message));
} }
@ -212,10 +216,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
protected void sendErrorMessage(WebSocketSession session, Throwable error) { protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setMessage(error.getMessage()); headerAccessor.setMessage(error.getMessage());
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD);
byte[] bytes = this.stompEncoder.encode(message);
try { try {
session.sendMessage(new TextMessage(bytes)); session.sendMessage(new TextMessage(bytes));
} }
@ -231,46 +234,60 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
@Override @Override
public void handleMessageToClient(WebSocketSession session, Message<?> message) { public void handleMessageToClient(WebSocketSession session, Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (headers.getMessageType() == SimpMessageType.CONNECT_ACK) {
StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
connectedHeaders.setVersion(getVersion(headers));
connectedHeaders.setHeartbeat(0, 0); // no heart-beat support with simple broker
headers = connectedHeaders;
}
else if (SimpMessageType.MESSAGE.equals(headers.getMessageType())) {
headers.updateStompCommandAsServerMessage();
}
if (headers.getCommand() == StompCommand.CONNECTED) {
afterStompSessionConnected(headers, session);
}
if (StompCommand.MESSAGE.equals(headers.getCommand())) {
if (headers.getSubscriptionId() == null) {
logger.error("Ignoring message, no subscriptionId header: " + message);
return;
}
String header = SimpMessageHeaderAccessor.ORIGINAL_DESTINATION;
if (message.getHeaders().containsKey(header)) {
headers.setDestination((String) message.getHeaders().get(header));
}
}
if (!(message.getPayload() instanceof byte[])) { if (!(message.getPayload() instanceof byte[])) {
logger.error("Ignoring message, expected byte[] content: " + message); logger.error("Ignoring message, expected byte[] content: " + message);
return; return;
} }
try { MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); if (accessor == null) {
logger.error("No header accessor: " + message);
return;
}
if (this.eventPublisher != null && StompCommand.CONNECTED.equals(headers.getCommand())) { StompHeaderAccessor stompAccessor;
if (accessor instanceof StompHeaderAccessor) {
stompAccessor = (StompHeaderAccessor) accessor;
}
else if (accessor instanceof SimpMessageHeaderAccessor) {
stompAccessor = StompHeaderAccessor.wrap(message);
if (SimpMessageType.CONNECT_ACK.equals(stompAccessor.getMessageType())) {
StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
connectedHeaders.setVersion(getVersion(stompAccessor));
connectedHeaders.setHeartbeat(0, 0); // no heart-beat support with simple broker
stompAccessor = connectedHeaders;
}
else if (stompAccessor.getCommand() == null || StompCommand.SEND.equals(stompAccessor.getCommand())) {
stompAccessor.updateStompCommandAsServerMessage();
}
}
else {
// Should not happen
logger.error("Unexpected header accessor type: " + accessor);
return;
}
StompCommand command = stompAccessor.getCommand();
if (StompCommand.MESSAGE.equals(command)) {
if (stompAccessor.getSubscriptionId() == null) {
logger.error("Ignoring message, no subscriptionId header: " + message);
return;
}
String header = SimpMessageHeaderAccessor.ORIGINAL_DESTINATION;
if (message.getHeaders().containsKey(header)) {
stompAccessor = toMutableAccessor(stompAccessor, message);
stompAccessor.setDestination((String) message.getHeaders().get(header));
}
}
else if (StompCommand.CONNECTED.equals(command)) {
stompAccessor = afterStompSessionConnected(message, stompAccessor, session);
if (this.eventPublisher != null && StompCommand.CONNECTED.equals(command)) {
publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message)); publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message));
} }
}
byte[] bytes = this.stompEncoder.encode((Message<byte[]>) message); try {
byte[] bytes = this.stompEncoder.encode(stompAccessor.getMessageHeaders(), (byte[]) message.getPayload());
TextMessage textMessage = new TextMessage(bytes); TextMessage textMessage = new TextMessage(bytes);
session.sendMessage(textMessage); session.sendMessage(textMessage);
@ -283,7 +300,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
sendErrorMessage(session, ex); sendErrorMessage(session, ex);
} }
finally { finally {
if (StompCommand.ERROR.equals(headers.getCommand())) { if (StompCommand.ERROR.equals(command)) {
try { try {
session.close(CloseStatus.PROTOCOL_ERROR); session.close(CloseStatus.PROTOCOL_ERROR);
} }
@ -294,13 +311,19 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
} }
} }
protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccessor, Message<?> message) {
return (headerAccessor.isMutable() ? headerAccessor : StompHeaderAccessor.wrap(message));
}
private String getVersion(StompHeaderAccessor connectAckHeaders) { private String getVersion(StompHeaderAccessor connectAckHeaders) {
String name = StompHeaderAccessor.CONNECT_MESSAGE_HEADER; String name = StompHeaderAccessor.CONNECT_MESSAGE_HEADER;
Message<?> connectMessage = (Message<?>) connectAckHeaders.getHeader(name); Message<?> connectMessage = (Message<?>) connectAckHeaders.getHeader(name);
StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(connectMessage);
Assert.notNull(connectMessage, "CONNECT_ACK does not contain original CONNECT " + connectAckHeaders); Assert.notNull(connectMessage, "CONNECT_ACK does not contain original CONNECT " + connectAckHeaders);
StompHeaderAccessor connectHeaders =
MessageHeaderAccessor.getAccessor(connectMessage, StompHeaderAccessor.class);
Set<String> acceptVersions = connectHeaders.getAcceptVersion(); Set<String> acceptVersions = connectHeaders.getAcceptVersion();
if (acceptVersions.contains("1.2")) { if (acceptVersions.contains("1.2")) {
return "1.2"; return "1.2";
@ -316,16 +339,19 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
} }
} }
private void afterStompSessionConnected(StompHeaderAccessor headers, WebSocketSession session) { private StompHeaderAccessor afterStompSessionConnected(
Message<?> message, StompHeaderAccessor headerAccessor, WebSocketSession session) {
Principal principal = session.getPrincipal(); Principal principal = session.getPrincipal();
if (principal != null) { if (principal != null) {
headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); headerAccessor = toMutableAccessor(headerAccessor, message);
headerAccessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
if (this.userSessionRegistry != null) { if (this.userSessionRegistry != null) {
String userName = resolveNameForUserSessionRegistry(principal); String userName = resolveNameForUserSessionRegistry(principal);
this.userSessionRegistry.registerSessionId(userName, session.getId()); this.userSessionRegistry.registerSessionId(userName, session.getId());
} }
} }
long[] heartbeat = headers.getHeartbeat(); long[] heartbeat = headerAccessor.getHeartbeat();
if (heartbeat[1] > 0) { if (heartbeat[1] > 0) {
session = WebSocketSessionDecorator.unwrap(session); session = WebSocketSessionDecorator.unwrap(session);
if (session instanceof SockJsSession) { if (session instanceof SockJsSession) {
@ -333,6 +359,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
((SockJsSession) session).disableHeartbeat(); ((SockJsSession) session).disableHeartbeat();
} }
} }
return headerAccessor;
} }
private String resolveNameForUserSessionRegistry(Principal principal) { private String resolveNameForUserSessionRegistry(Principal principal) {
@ -345,8 +372,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
@Override @Override
public String resolveSessionId(Message<?> message) { public String resolveSessionId(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); return SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
return headers.getSessionId();
} }
@Override @Override
@ -374,7 +400,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId(session.getId()); headers.setSessionId(session.getId());
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<?> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
if (this.eventPublisher != null) { if (this.eventPublisher != null) {
publishEvent(new SessionDisconnectEvent(this, session.getId(), closeStatus)); publishEvent(new SessionDisconnectEvent(this, session.getId(), closeStatus));

View File

@ -16,7 +16,6 @@
package org.springframework.web.socket.messaging; package org.springframework.web.socket.messaging;
import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -41,7 +40,6 @@ import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.DestinationUserNameProvider; import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.CloseStatus;
@ -61,6 +59,8 @@ import static org.mockito.Mockito.*;
*/ */
public class StompSubProtocolHandlerTests { public class StompSubProtocolHandlerTests {
public static final byte[] EMPTY_PAYLOAD = new byte[0];
private StompSubProtocolHandler protocolHandler; private StompSubProtocolHandler protocolHandler;
private TestWebSocketSession session; private TestWebSocketSession session;
@ -89,7 +89,7 @@ public class StompSubProtocolHandlerTests {
this.protocolHandler.setUserSessionRegistry(registry); this.protocolHandler.setUserSessionRegistry(registry);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message); this.protocolHandler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size()); assertEquals(1, this.session.getSentMessages().size());
@ -108,7 +108,7 @@ public class StompSubProtocolHandlerTests {
this.protocolHandler.setUserSessionRegistry(registry); this.protocolHandler.setUserSessionRegistry(registry);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message); this.protocolHandler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size()); assertEquals(1, this.session.getSentMessages().size());
@ -126,7 +126,7 @@ public class StompSubProtocolHandlerTests {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
headers.setHeartbeat(0,10); headers.setHeartbeat(0,10);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(sockJsSession, message); this.protocolHandler.handleMessageToClient(sockJsSession, message);
verify(sockJsSession).disableHeartbeat(); verify(sockJsSession).disableHeartbeat();
@ -137,12 +137,12 @@ public class StompSubProtocolHandlerTests {
StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT); StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT);
connectHeaders.setHeartbeat(10000, 10000); connectHeaders.setHeartbeat(10000, 10000);
connectHeaders.setNativeHeader(StompHeaderAccessor.STOMP_ACCEPT_VERSION_HEADER, "1.0,1.1"); connectHeaders.setAcceptVersion("1.0,1.1");
Message<?> connectMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectHeaders).build(); Message<?> connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectHeaders.getMessageHeaders());
SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage); connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage);
Message<byte[]> connectAckMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build(); Message<byte[]> connectAckMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAckHeaders.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, connectAckMessage); this.protocolHandler.handleMessageToClient(this.session, connectAckMessage);
@ -174,12 +174,12 @@ public class StompSubProtocolHandlerTests {
this.protocolHandler.afterSessionStarted(this.session, this.channel); this.protocolHandler.afterSessionStarted(this.session, this.channel);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
TextMessage textMessage = new TextMessage(new StompEncoder().encode(message)); TextMessage textMessage = new TextMessage(new StompEncoder().encode(message));
this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
headers = StompHeaderAccessor.create(StompCommand.CONNECTED); headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message); this.protocolHandler.handleMessageToClient(this.session, message);
this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, this.channel); this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, this.channel);
@ -207,7 +207,7 @@ public class StompSubProtocolHandlerTests {
this.protocolHandler.afterSessionStarted(this.session, this.channel); this.protocolHandler.afterSessionStarted(this.session, this.channel);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
TextMessage textMessage = new TextMessage(new StompEncoder().encode(message)); TextMessage textMessage = new TextMessage(new StompEncoder().encode(message));
this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
@ -218,7 +218,7 @@ public class StompSubProtocolHandlerTests {
reset(this.channel); reset(this.channel);
headers = StompHeaderAccessor.create(StompCommand.CONNECTED); headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message); this.protocolHandler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size()); assertEquals(1, this.session.getSentMessages().size());
@ -241,7 +241,7 @@ public class StompSubProtocolHandlerTests {
headers.setSubscriptionId("sub0"); headers.setSubscriptionId("sub0");
headers.setDestination("/queue/foo-user123"); headers.setDestination("/queue/foo-user123");
headers.setHeader(StompHeaderAccessor.ORIGINAL_DESTINATION, "/user/queue/foo"); headers.setHeader(StompHeaderAccessor.ORIGINAL_DESTINATION, "/user/queue/foo");
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message); this.protocolHandler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size()); assertEquals(1, this.session.getSentMessages().size());
@ -278,8 +278,9 @@ public class StompSubProtocolHandlerTests {
@Test @Test
public void handleMessageFromClientInvalidStompCommand() { public void handleMessageFromClientInvalidStompCommand() {
TextMessage textMessage = new TextMessage("FOO"); TextMessage textMessage = new TextMessage("FOO\n\n\0");
this.protocolHandler.afterSessionStarted(this.session, this.channel);
this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
verifyZeroInteractions(this.channel); verifyZeroInteractions(this.channel);

View File

@ -128,7 +128,7 @@ public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegration
assertTrue(clientHandler.latch.await(2, TimeUnit.SECONDS)); assertTrue(clientHandler.latch.await(2, TimeUnit.SECONDS));
String payload = clientHandler.actual.get(0).getPayload(); String payload = clientHandler.actual.get(0).getPayload();
assertTrue("Expected STOMP Command=MESSAGE, got " + payload, payload.startsWith("MESSAGE\n")); assertTrue("Expected STOMP MESSAGE, got " + payload, payload.startsWith("MESSAGE\n"));
} }
finally { finally {
session.close(); session.close();