Make use of enhanced MessageHeaderAccessor support

Mutate rather than re-create headers when decoding STOMP messages
before a message is sent on a message channel.

Use MessageBuilder.createMessage to ensure the fully prepared
MessageHeaders is used directly MessageHeaderAccessor instance.

Issue: SPR-11468
This commit is contained in:
Rossen Stoyanchev 2014-04-10 23:57:45 -04:00
parent 4867546aec
commit ae942ffdb8
22 changed files with 322 additions and 236 deletions

View File

@ -26,6 +26,7 @@ import org.springframework.messaging.MessagingException;
import org.springframework.messaging.converter.MessageConversionException;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.converter.SimpleMessageConverter;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert;
/**
@ -130,6 +131,13 @@ public abstract class AbstractMessageSendingTemplate<D> implements MessageSendin
MessagePostProcessor postProcessor) throws MessagingException {
headers = processHeadersToSend(headers);
MessageHeaders messageHeaders;
if (headers != null && headers instanceof MessageHeaders) {
MessageHeaderAccessor.getAccessor()
}
MessageHeaders messageHeaders = (headers != null) ? new MessageHeaders(headers) : null;
Message<?> message = this.converter.toMessage(payload, messageHeaders);

View File

@ -26,7 +26,7 @@ import org.springframework.messaging.handler.annotation.Header;
import org.springframework.messaging.handler.annotation.Headers;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.ClassUtils;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
/**
@ -42,7 +42,6 @@ import org.springframework.util.ReflectionUtils;
*/
public class HeadersMethodArgumentResolver implements HandlerMethodArgumentResolver {
@Override
public boolean supportsParameter(MethodParameter parameter) {
Class<?> paramType = parameter.getParameterType();
@ -60,15 +59,23 @@ public class HeadersMethodArgumentResolver implements HandlerMethodArgumentResol
return message.getHeaders();
}
else if (MessageHeaderAccessor.class.equals(paramType)) {
return new MessageHeaderAccessor(message);
MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
return (accessor != null ? accessor : new MessageHeaderAccessor(message));
}
else if (MessageHeaderAccessor.class.isAssignableFrom(paramType)) {
Method factoryMethod = ClassUtils.getMethod(paramType, "wrap", Message.class);
return ReflectionUtils.invokeMethod(factoryMethod, null, message);
MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
if (accessor != null && paramType.isAssignableFrom(accessor.getClass())) {
return accessor;
}
else {
Method method = ReflectionUtils.findMethod(paramType, "wrap", Message.class);
Assert.notNull(method, "Cannot create accessor of type " + paramType + " for message " + message);
return ReflectionUtils.invokeMethod(method, null, message);
}
}
else {
throw new IllegalStateException("Unexpected method parameter type "
+ paramType + "in method " + parameter.getMethod() + ". "
throw new IllegalStateException(
"Unexpected method parameter type " + paramType + "in method " + parameter.getMethod() + ". "
+ "@Headers method arguments must be assignable to java.util.Map.");
}
}

View File

@ -42,6 +42,7 @@ import org.springframework.messaging.handler.DestinationPatternsMessageCondition
import org.springframework.messaging.handler.HandlerMethod;
import org.springframework.messaging.handler.HandlerMethodSelector;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
@ -339,8 +340,9 @@ public abstract class AbstractMethodMessageHandler<T>
logger.debug("Handling message, lookupDestination=" + lookupDestination);
}
message = MessageBuilder.fromMessage(message).setHeader(
DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, lookupDestination).build();
MessageHeaderAccessor headerAccessor = MessageHeaderAccessor.getMutableAccessor(message);
headerAccessor.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, lookupDestination);
message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders());
handleMessageInternal(message, lookupDestination);
}

View File

@ -73,7 +73,7 @@ public class SimpMessageTypeMessageCondition extends AbstractMessageCondition<Si
@Override
public SimpMessageTypeMessageCondition getMatchingCondition(Message<?> message) {
Object actualMessageType = message.getHeaders().get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER);
Object actualMessageType = SimpMessageHeaderAccessor.getMessageType(message.getHeaders());
if (actualMessageType == null) {
return null;
}
@ -83,7 +83,7 @@ public class SimpMessageTypeMessageCondition extends AbstractMessageCondition<Si
@Override
public int compareTo(SimpMessageTypeMessageCondition other, Message<?> message) {
Object actualMessageType = message.getHeaders().get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER);
Object actualMessageType = SimpMessageHeaderAccessor.getMessageType(message.getHeaders());
if (actualMessageType != null) {
if (actualMessageType.equals(this.getMessageType()) && actualMessageType.equals(other.getMessageType())) {
return 0;

View File

@ -108,8 +108,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
@Override
public void send(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
String destination = headers.getDestination();
String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders());
destination = (destination != null) ? destination : getRequiredDefaultDestination();
doSend(destination, message);
}
@ -118,10 +117,10 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
protected void doSend(String destination, Message<?> message) {
Assert.notNull(destination, "Destination must not be null");
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setDestination(destination);
headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message);
headerAccessor.setDestination(destination);
headerAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders());
long timeout = this.sendTimeout;
boolean sent = (timeout >= 0)

View File

@ -37,8 +37,7 @@ public class PrincipalMethodArgumentResolver implements HandlerMethodArgumentRes
@Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
Principal user = headers.getUser();
Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders());
if (user == null) {
throw new MissingSessionUserException(message);
}

View File

@ -23,6 +23,7 @@ import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.core.MessagePostProcessor;
import org.springframework.messaging.handler.DestinationPatternsMessageCondition;
import org.springframework.messaging.handler.annotation.SendTo;
@ -113,29 +114,28 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
}
@Override
public void handleReturnValue(Object returnValue, MethodParameter returnType, Message<?> inputMessage)
public void handleReturnValue(Object returnValue, MethodParameter returnType, Message<?> message)
throws Exception {
if (returnValue == null) {
return;
}
SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(inputMessage);
String sessionId = inputHeaders.getSessionId();
MessageHeaders headers = message.getHeaders();
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
MessagePostProcessor postProcessor = new SessionHeaderPostProcessor(sessionId);
SendToUser sendToUser = returnType.getMethodAnnotation(SendToUser.class);
if (sendToUser != null) {
Principal principal = inputHeaders.getUser();
Principal principal = SimpMessageHeaderAccessor.getUser(headers);
if (principal == null) {
throw new MissingSessionUserException(inputMessage);
throw new MissingSessionUserException(message);
}
String userName = principal.getName();
if (principal instanceof DestinationUserNameProvider) {
userName = ((DestinationUserNameProvider) principal).getDestinationUserName();
}
String[] destinations = getTargetDestinations(sendToUser, inputHeaders, this.defaultUserDestinationPrefix);
String[] destinations = getTargetDestinations(sendToUser, message, this.defaultUserDestinationPrefix);
for (String destination : destinations) {
this.messagingTemplate.convertAndSendToUser(userName, destination, returnValue, postProcessor);
}
@ -143,15 +143,14 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
}
else {
SendTo sendTo = returnType.getMethodAnnotation(SendTo.class);
String[] destinations = getTargetDestinations(sendTo, inputHeaders, this.defaultDestinationPrefix);
String[] destinations = getTargetDestinations(sendTo, message, this.defaultDestinationPrefix);
for (String destination : destinations) {
this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor);
}
}
}
protected String[] getTargetDestinations(Annotation annot, SimpMessageHeaderAccessor inputHeaders,
String defaultPrefix) {
protected String[] getTargetDestinations(Annotation annot, Message<?> message, String defaultPrefix) {
if (annot != null) {
String[] value = (String[]) AnnotationUtils.getValue(annot);
@ -159,8 +158,8 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
return value;
}
}
return new String[] { defaultPrefix +
inputHeaders.getHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER) };
String name = DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER;
return new String[] { defaultPrefix + message.getHeaders().get(name) };
}
@ -176,7 +175,7 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
public Message<?> postProcessMessage(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setSessionId(this.sessionId);
return MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
return MessageBuilder.createMessage(message.getPayload(), headers.getMessageHeaders());
}
}

View File

@ -63,6 +63,7 @@ import org.springframework.stereotype.Controller;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.PathMatcher;
import org.springframework.validation.Errors;
import org.springframework.validation.Validator;
@ -329,7 +330,7 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan
@Override
protected String getDestination(Message<?> message) {
return (String) message.getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER);
return (String) SimpMessageHeaderAccessor.getDestination(message.getHeaders());
}
@Override
@ -352,13 +353,15 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan
protected void handleMatch(SimpMessageMappingInfo mapping, HandlerMethod handlerMethod,
String lookupDestination, Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
String matchedPattern = mapping.getDestinationConditions().getPatterns().iterator().next();
Map<String, String> vars = getPathMatcher().extractUriTemplateVariables(matchedPattern, lookupDestination);
headers.setHeader(DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER, vars);
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
if (!CollectionUtils.isEmpty(vars)) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setHeader(DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER, vars);
message = MessageBuilder.createMessage(message.getPayload(), headers.getMessageHeaders());
}
super.handleMatch(mapping, handlerMethod, lookupDestination, message);
}

View File

@ -18,6 +18,7 @@ package org.springframework.messaging.simp.annotation.support;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.core.MessagePostProcessor;
import org.springframework.messaging.core.MessageSendingOperations;
import org.springframework.messaging.handler.annotation.SendTo;
@ -71,13 +72,12 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
return;
}
SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(message);
String sessionId = inputHeaders.getSessionId();
String subscriptionId = inputHeaders.getSubscriptionId();
String destination = inputHeaders.getDestination();
MessageHeaders headers = message.getHeaders();
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers);
String destination = SimpMessageHeaderAccessor.getDestination(headers);
Assert.state(inputHeaders.getSubscriptionId() != null,
"No subsriptiondId in input message to method " + returnType.getMethod());
Assert.state(subscriptionId != null, "No subsriptiondId in input message to method " + returnType.getMethod());
MessagePostProcessor postProcessor = new SubscriptionHeaderPostProcessor(sessionId, subscriptionId);
this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor);
@ -98,11 +98,11 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
@Override
public Message<?> postProcessMessage(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setSessionId(this.sessionId);
headers.setSubscriptionId(this.subscriptionId);
headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
return MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message);
headerAccessor.setSessionId(this.sessionId);
headerAccessor.setSubscriptionId(this.subscriptionId);
headerAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
return MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders());
}
}
}

View File

@ -19,8 +19,10 @@ package org.springframework.messaging.simp.broker;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.MultiValueMap;
/**
@ -38,29 +40,31 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist
@Override
public final void registerSubscription(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
if (!SimpMessageType.SUBSCRIBE.equals(headers.getMessageType())) {
MessageHeaders headers = message.getHeaders();
SimpMessageType type = SimpMessageHeaderAccessor.getMessageType(headers);
if (!SimpMessageType.SUBSCRIBE.equals(type)) {
logger.error("Expected SUBSCRIBE message: " + message);
return;
}
String sessionId = headers.getSessionId();
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
if (sessionId == null) {
logger.error("Ignoring subscription. No sessionId in message: " + message);
return;
}
String subscriptionId = headers.getSubscriptionId();
String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers);
if (subscriptionId == null) {
logger.error("Ignoring subscription. No subscriptionId in message: " + message);
return;
}
String destination = headers.getDestination();
String destination = SimpMessageHeaderAccessor.getDestination(headers);
if (destination == null) {
logger.error("Ignoring destination. No destination in message: " + message);
return;
}
if (logger.isDebugEnabled()) {
logger.debug("Adding subscription id=" + headers.getSubscriptionId()
+ ", destination=" + headers.getDestination());
logger.debug("Adding subscription id=" + subscriptionId + ", destination=" + destination);
}
addSubscriptionInternal(sessionId, subscriptionId, destination, message);
}
@ -70,17 +74,20 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist
@Override
public final void unregisterSubscription(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
if (!SimpMessageType.UNSUBSCRIBE.equals(headers.getMessageType())) {
MessageHeaders headers = message.getHeaders();
SimpMessageType type = SimpMessageHeaderAccessor.getMessageType(headers);
if (!SimpMessageType.UNSUBSCRIBE.equals(type)) {
logger.error("Expected UNSUBSCRIBE message: " + message);
return;
}
String sessionId = headers.getSessionId();
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
if (sessionId == null) {
logger.error("Ignoring subscription. No sessionId in message: " + message);
return;
}
String subscriptionId = headers.getSubscriptionId();
String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers);
if (subscriptionId == null) {
logger.error("Ignoring subscription. No subscriptionId in message: " + message);
return;
@ -98,19 +105,22 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist
@Override
public final MultiValueMap<String, String> findSubscriptions(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
if (!SimpMessageType.MESSAGE.equals(headers.getMessageType())) {
logger.trace("Ignoring message type " + headers.getMessageType());
MessageHeaders headers = message.getHeaders();
SimpMessageType type = SimpMessageHeaderAccessor.getMessageType(headers);
if (!SimpMessageType.MESSAGE.equals(type)) {
logger.trace("Ignoring message type " + type);
return null;
}
String destination = headers.getDestination();
String destination = SimpMessageHeaderAccessor.getDestination(headers);
if (destination == null) {
logger.trace("Ignoring message, no destination");
return null;
}
MultiValueMap<String, String> result = findSubscriptionsInternal(destination, message);
if (logger.isTraceEnabled()) {
logger.trace("Found " + result.size() + " subscriptions for destination=" + headers.getDestination());
logger.trace("Found " + result.size() + " subscriptions for destination=" + destination);
}
return result;
}

View File

@ -20,6 +20,7 @@ import java.util.Collection;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
@ -111,9 +112,10 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
@Override
protected void handleMessageInternal(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
SimpMessageType messageType = headers.getMessageType();
String destination = headers.getDestination();
MessageHeaders headers = message.getHeaders();
SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
String destination = SimpMessageHeaderAccessor.getDestination(headers);
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
if (!checkDestinationPrefix(destination)) {
if (logger.isTraceEnabled()) {
@ -122,27 +124,30 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
return;
}
if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
if (SimpMessageType.MESSAGE.equals(messageType)) {
sendMessageToSubscribers(destination, message);
}
else if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
this.subscriptionRegistry.registerSubscription(message);
}
else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
this.subscriptionRegistry.unregisterSubscription(message);
}
else if (SimpMessageType.MESSAGE.equals(messageType)) {
sendMessageToSubscribers(headers.getDestination(), message);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
String sessionId = headers.getSessionId();
this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
}
else if (SimpMessageType.CONNECT.equals(messageType)) {
SimpMessageHeaderAccessor replyHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
replyHeaders.setSessionId(headers.getSessionId());
replyHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message);
Message<byte[]> connectAck = MessageBuilder.withPayload(EMPTY_PAYLOAD).setHeaders(replyHeaders).build();
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
accessor.setSessionId(sessionId);
accessor.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message);
Message<byte[]> connectAck = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
this.clientOutboundChannel.send(connectAck);
}
else {
if (logger.isTraceEnabled()) {
logger.trace("Message type not supported. Ignoring: " + message);
}
}
}
protected void sendMessageToSubscribers(String destination, Message<?> message) {
@ -153,17 +158,17 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
}
for (String sessionId : subscriptions.keySet()) {
for (String subscriptionId : subscriptions.get(sessionId)) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setSessionId(sessionId);
headers.setSubscriptionId(subscriptionId);
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
headerAccessor.setSessionId(sessionId);
headerAccessor.setSubscriptionId(subscriptionId);
headerAccessor.copyHeadersIfAbsent(message.getHeaders());
Object payload = message.getPayload();
Message<?> clientMessage = MessageBuilder.withPayload(payload).setHeaders(headers).build();
Message<?> reply = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
try {
this.clientOutboundChannel.send(clientMessage);
this.clientOutboundChannel.send(reply);
}
catch (Throwable ex) {
logger.error("Failed to send message to destination=" + destination +
", sessionId=" + sessionId + ", subscriptionId=" + subscriptionId, ex);
logger.error("Failed to send message=" + message, ex);
}
}
}

View File

@ -20,7 +20,6 @@ import java.util.Collection;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
@ -30,6 +29,7 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.tcp.FixedIntervalReconnectStrategy;
import org.springframework.messaging.tcp.TcpConnection;
import org.springframework.messaging.tcp.TcpConnectionHandler;
@ -79,9 +79,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
private static final Message<byte[]> HEARTBEAT_MESSAGE;
static {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
HEARTBEAT_MESSAGE = MessageBuilder.withPayload(new byte[] {'\n'}).setHeaders(headers).build();
EMPTY_TASK.run();
StompHeaderAccessor accessor = StompHeaderAccessor.createForHeartbeat();
HEARTBEAT_MESSAGE = MessageBuilder.createMessage(StompDecoder.HEARTBEAT_PAYLOAD, accessor.getMessageHeaders());
}
@ -370,37 +370,53 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
@Override
protected void handleMessageInternal(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
String sessionId = headers.getSessionId();
String sessionId = SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
if (!isBrokerAvailable()) {
if (sessionId == null || sessionId == SystemStompConnectionHandler.SESSION_ID) {
if (sessionId == null || SystemStompConnectionHandler.SESSION_ID.equals(sessionId)) {
throw new MessageDeliveryException("Message broker is not active.");
}
if (logger.isTraceEnabled()) {
logger.trace("Message broker is not active. Ignoring message id=" + message.getHeaders().getId());
logger.trace("Message broker is not active. Ignoring: " + message);
}
return;
}
String destination = headers.getDestination();
StompCommand command = headers.getCommand();
SimpMessageType messageType = headers.getMessageType();
StompHeaderAccessor stompAccessor;
StompCommand command;
if (SimpMessageType.MESSAGE.equals(messageType)) {
sessionId = (sessionId == null) ? SystemStompConnectionHandler.SESSION_ID : sessionId;
headers.setSessionId(sessionId);
command = headers.updateStompCommandAsClientMessage();
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
if (accessor == null) {
logger.error("No header accessor, please use SimpMessagingTemplate. Ignoring: " + message);
return;
}
else if (accessor instanceof StompHeaderAccessor) {
stompAccessor = (StompHeaderAccessor) accessor;
command = stompAccessor.getCommand();
}
else if (accessor instanceof SimpMessageHeaderAccessor) {
stompAccessor = StompHeaderAccessor.wrap(message);
command = stompAccessor.getCommand();
if (command == null) {
command = stompAccessor.updateStompCommandAsClientMessage();
}
}
else {
// Should not happen
logger.error("Unexpected header accessor type: " + accessor + ". Ignoring: " + message);
return;
}
if (sessionId == null) {
if (logger.isWarnEnabled()) {
logger.warn("No sessionId, ignoring message: " + message);
if (!SimpMessageType.MESSAGE.equals(stompAccessor.getMessageType())) {
logger.error("Only STOMP SEND frames supported on \"system\" connection. Ignoring: " + message);
return;
}
return;
sessionId = SystemStompConnectionHandler.SESSION_ID;
stompAccessor.setSessionId(sessionId);
}
String destination = stompAccessor.getDestination();
if ((command != null) && command.requiresDestination() && !checkDestinationPrefix(destination)) {
if (logger.isTraceEnabled()) {
logger.trace("Ignoring message to destination=" + destination);
@ -412,20 +428,21 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
logger.trace("Processing message=" + message);
}
if (SimpMessageType.CONNECT.equals(messageType)) {
if (StompCommand.CONNECT.equals(command)) {
if (logger.isDebugEnabled()) {
logger.debug("Processing CONNECT (total connected=" + this.connectionHandlers.size() + ")");
}
headers.setLogin(this.clientLogin);
headers.setPasscode(this.clientPasscode);
stompAccessor = (stompAccessor.isMutable() ? stompAccessor : StompHeaderAccessor.wrap(message));
stompAccessor.setLogin(this.clientLogin);
stompAccessor.setPasscode(this.clientPasscode);
if (getVirtualHost() != null) {
headers.setHost(getVirtualHost());
stompAccessor.setHost(getVirtualHost());
}
StompConnectionHandler handler = new StompConnectionHandler(sessionId, headers);
StompConnectionHandler handler = new StompConnectionHandler(sessionId, stompAccessor);
this.connectionHandlers.put(sessionId, handler);
this.tcpClient.connect(handler);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
else if (StompCommand.DISCONNECT.equals(command)) {
StompConnectionHandler handler = this.connectionHandlers.get(sessionId);
if (handler == null) {
if (logger.isTraceEnabled()) {
@ -433,7 +450,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
}
return;
}
handler.forward(message);
handler.forward(message, stompAccessor);
}
else {
StompConnectionHandler handler = this.connectionHandlers.get(sessionId);
@ -443,7 +460,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
}
return;
}
handler.forward(message);
handler.forward(message, stompAccessor);
}
}
@ -486,7 +503,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
logger.debug("Established TCP connection to broker in session '" + this.sessionId + "'");
}
this.tcpConnection = connection;
connection.send(MessageBuilder.withPayload(EMPTY_PAYLOAD).setHeaders(this.connectHeaders).build());
connection.send(MessageBuilder.createMessage(EMPTY_PAYLOAD, this.connectHeaders.getMessageHeaders()));
}
@Override
@ -522,7 +539,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setSessionId(this.sessionId);
headers.setMessage(errorText);
Message<?> errorMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
Message<?> errorMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
sendMessageToClient(errorMessage);
}
}
@ -536,20 +553,23 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
@Override
public void handleMessage(Message<byte[]> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) {
StompHeaderAccessor headerAccessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
if (headerAccessor.isHeartbeat()) {
logger.trace("Received broker heartbeat");
}
else if (logger.isDebugEnabled()) {
logger.debug("Received message from broker in session '" + this.sessionId + "'");
}
if (StompCommand.CONNECTED == headers.getCommand()) {
afterStompConnected(headers);
if (StompCommand.CONNECTED == headerAccessor.getCommand()) {
afterStompConnected(headerAccessor);
}
headers.setSessionId(this.sessionId);
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
headerAccessor.setSessionId(this.sessionId);
headerAccessor.setImmutable();
sendMessageToClient(message);
}
@ -630,9 +650,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
clearConnection();
}
catch (Throwable t) {
if (logger.isErrorEnabled()) {
// Ignore
}
// Ignore
}
}
}
@ -661,7 +679,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
* @return a future to wait for the result
*/
@SuppressWarnings("unchecked")
public ListenableFuture<Void> forward(final Message<?> message) {
public ListenableFuture<Void> forward(Message<?> message, final StompHeaderAccessor headerAccessor) {
TcpConnection<byte[]> conn = this.tcpConnection;
@ -682,8 +700,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
}
if (logger.isDebugEnabled()) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) {
if (headerAccessor.isHeartbeat()) {
logger.trace("Forwarding heartbeat to broker");
}
else {
@ -691,13 +708,16 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
}
}
if (headerAccessor.isMutable() && headerAccessor.isModified()) {
message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders());
}
ListenableFuture<Void> future = conn.send((Message<byte[]>) message);
future.addCallback(new ListenableFutureCallback<Void>() {
@Override
public void onSuccess(Void result) {
StompCommand command = StompHeaderAccessor.wrap(message).getCommand();
if (command == StompCommand.DISCONNECT) {
if (headerAccessor.getCommand() == StompCommand.DISCONNECT) {
clearConnection();
}
}
@ -707,7 +727,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
// already reset
}
else {
handleTcpConnectionFailure("Failed to send message " + message, t);
handleTcpConnectionFailure("Failed to send message " + headerAccessor, t);
}
}
});
@ -777,9 +797,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
}
@Override
public ListenableFuture<Void> forward(Message<?> message) {
public ListenableFuture<Void> forward(Message<?> message, StompHeaderAccessor headerAccessor) {
try {
ListenableFuture<Void> future = super.forward(message);
ListenableFuture<Void> future = super.forward(message, headerAccessor);
future.get();
return future;
}

View File

@ -19,6 +19,7 @@ package org.springframework.messaging.simp.user;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.util.Assert;
@ -100,34 +101,34 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
@Override
public UserDestinationResult resolveDestination(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
DestinationInfo info = parseUserDestination(headers);
String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders());
DestinationInfo info = parseUserDestination(message);
if (info == null) {
return null;
}
Set<String> targetDestinations = new HashSet<String>();
for (String sessionId : info.getSessionIds()) {
targetDestinations.add(getTargetDestination(
headers.getDestination(), info.getDestinationWithoutPrefix(), sessionId, info.getUser()));
targetDestinations.add(getTargetDestination(destination,
info.getDestinationWithoutPrefix(), sessionId, info.getUser()));
}
return new UserDestinationResult(headers.getDestination(),
return new UserDestinationResult(destination,
targetDestinations, info.getSubscribeDestination(), info.getUser());
}
private DestinationInfo parseUserDestination(SimpMessageHeaderAccessor headers) {
private DestinationInfo parseUserDestination(Message<?> message) {
String destination = headers.getDestination();
MessageHeaders headers = message.getHeaders();
SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
String destination = SimpMessageHeaderAccessor.getDestination(headers);
Principal principal = SimpMessageHeaderAccessor.getUser(headers);
String destinationWithoutPrefix;
String subscribeDestination;
String user;
Set<String> sessionIds;
Principal principal = headers.getUser();
SimpMessageType messageType = headers.getMessageType();
if (SimpMessageType.SUBSCRIBE.equals(messageType) || SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
if (!checkDestination(destination, this.destinationPrefix)) {
return null;
@ -136,14 +137,15 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
logger.error("Ignoring message, no principal info available");
return null;
}
if (headers.getSessionId() == null) {
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
if (sessionId == null) {
logger.error("Ignoring message, no session id available");
return null;
}
destinationWithoutPrefix = destination.substring(this.destinationPrefix.length()-1);
subscribeDestination = destination;
user = principal.getName();
sessionIds = Collections.singleton(headers.getSessionId());
sessionIds = Collections.singleton(sessionId);
}
else if (SimpMessageType.MESSAGE.equals(messageType)) {
if (!checkDestination(destination, this.destinationPrefix)) {

View File

@ -152,16 +152,16 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
if (destinations.isEmpty()) {
return;
}
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message);
if (SimpMessageType.MESSAGE.equals(headerAccessor.getMessageType())) {
if (SimpMessageType.MESSAGE.equals(SimpMessageHeaderAccessor.getMessageType(message.getHeaders()))) {
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message);
headerAccessor.setHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION, result.getSubscribeDestination());
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headerAccessor).build();
message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders());
}
for (String targetDestination : destinations) {
for (String destination : destinations) {
if (logger.isDebugEnabled()) {
logger.debug("Sending message to resolved destination=" + targetDestination);
logger.debug("Sending message to resolved destination=" + destination);
}
this.brokerMessagingTemplate.send(targetDestination, message);
this.brokerMessagingTemplate.send(destination, message);
}
}

View File

@ -112,10 +112,12 @@ public class MessageConverterTests {
public void toMessageHeadersCopied() {
Map<String, Object> map = new HashMap<String, Object>();
map.put("foo", "bar");
MessageHeaders headers = new MessageHeaders(map );
MessageHeaders headers = new MessageHeaders(map);
Message<?> message = this.converter.toMessage("ABC", headers);
assertEquals("bar", message.getHeaders().get("foo"));
assertNotNull(message.getHeaders().getId());
assertNotNull(message.getHeaders().getTimestamp());
}
@Test

View File

@ -31,7 +31,6 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.support.StaticApplicationContext;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.converter.*;
import org.springframework.messaging.handler.annotation.MessageMapping;

View File

@ -38,6 +38,7 @@ import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.StubMessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.broker.BrokerAvailabilityEvent;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
@ -168,7 +169,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
public void messageDeliverExceptionIfSystemSessionForwardFails() throws Exception {
stopActiveMqBrokerAndAwait();
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
this.relay.handleMessage(MessageBuilder.withPayload("test".getBytes()).setHeaders(headers).build());
this.relay.handleMessage(MessageBuilder.createMessage("test".getBytes(), headers.getMessageHeaders()));
}
@Test
@ -244,7 +245,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId("sess1");
this.relay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build());
this.relay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()));
Thread.sleep(2000);
@ -394,7 +395,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
headers.setSessionId(sessionId);
headers.setAcceptVersion("1.1,1.2");
headers.setHeartbeat(0, 0);
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
Message<?> message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
MessageExchangeBuilder builder = new MessageExchangeBuilder(message);
builder.expected.add(new StompConnectedFrameMessageMatcher(sessionId));
@ -405,7 +406,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setSessionId(sessionId);
headers.setAcceptVersion("1.1,1.2");
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
Message<?> message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
MessageExchangeBuilder builder = new MessageExchangeBuilder(message);
return builder.andExpectError();
}
@ -418,7 +419,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
headers.setSubscriptionId(subscriptionId);
headers.setDestination(destination);
headers.setReceipt(receiptId);
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
Message<?> message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
MessageExchangeBuilder builder = new MessageExchangeBuilder(message);
builder.expected.add(new StompReceiptFrameMessageMatcher(sessionId, receiptId));
@ -426,14 +427,14 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
}
public static MessageExchangeBuilder send(String destination, String payload) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
headers.setDestination(destination);
Message<?> message = MessageBuilder.withPayload(payload.getBytes(UTF_8)).setHeaders(headers).build();
Message<?> message = MessageBuilder.createMessage(payload.getBytes(UTF_8), headers.getMessageHeaders());
return new MessageExchangeBuilder(message);
}
public MessageExchangeBuilder andExpectMessage(String sessionId, String subscriptionId) {
Assert.isTrue(StompCommand.SEND.equals(headers.getCommand()), "MESSAGE can only be expected after SEND");
Assert.isTrue(SimpMessageType.MESSAGE.equals(headers.getMessageType()));
String destination = this.headers.getDestination();
Object payload = this.message.getPayload();
this.expected.add(new StompMessageFrameMessageMatcher(sessionId, subscriptionId, destination, payload));

View File

@ -27,6 +27,7 @@ import org.springframework.messaging.StubMessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.tcp.ReconnectStrategy;
import org.springframework.messaging.tcp.TcpConnection;
import org.springframework.messaging.tcp.TcpConnectionHandler;
@ -77,17 +78,21 @@ public class StompBrokerRelayMessageHandlerTests {
String sessionId = "sess1";
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setSessionId(sessionId);
this.brokerRelay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build());
this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()));
List<Message<byte[]>> sent = this.tcpClient.connection.messages;
assertEquals(2, sent.size());
StompHeaderAccessor headers1 = StompHeaderAccessor.wrap(sent.get(0));
assertEquals(virtualHost, headers1.getHost());
assertNotNull("The prepared message does not have an accessor",
MessageHeaderAccessor.getAccessor(sent.get(0), MessageHeaderAccessor.class));
StompHeaderAccessor headers2 = StompHeaderAccessor.wrap(sent.get(1));
assertEquals(sessionId, headers2.getSessionId());
assertEquals(virtualHost, headers2.getHost());
assertNotNull("The prepared message does not have an accessor",
MessageHeaderAccessor.getAccessor(sent.get(1), MessageHeaderAccessor.class));
}
@Test
@ -104,7 +109,7 @@ public class StompBrokerRelayMessageHandlerTests {
String sessionId = "sess1";
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setSessionId(sessionId);
this.brokerRelay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build());
this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()));
List<Message<byte[]>> sent = this.tcpClient.connection.messages;
assertEquals(2, sent.size());
@ -126,11 +131,13 @@ public class StompBrokerRelayMessageHandlerTests {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
headers.setSessionId("sess1");
headers.setDestination("/user/daisy/foo");
this.brokerRelay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build());
this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()));
List<Message<byte[]>> sent = this.tcpClient.connection.messages;
assertEquals(1, sent.size());
assertEquals(StompCommand.CONNECT, StompHeaderAccessor.wrap(sent.get(0)).getCommand());
assertNotNull("The prepared message does not have an accessor",
MessageHeaderAccessor.getAccessor(sent.get(0), MessageHeaderAccessor.class));
}

View File

@ -66,8 +66,7 @@ public class UserDestinationMessageHandlerTests {
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture());
assertEquals("/queue/foo-user123",
captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER));
assertEquals("/queue/foo-user123", SimpMessageHeaderAccessor.getDestination(captor.getValue().getHeaders()));
}
@Test
@ -79,8 +78,7 @@ public class UserDestinationMessageHandlerTests {
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture());
assertEquals("/queue/foo-user123",
captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER));
assertEquals("/queue/foo-user123", SimpMessageHeaderAccessor.getDestination(captor.getValue().getHeaders()));
}
@Test
@ -93,10 +91,8 @@ public class UserDestinationMessageHandlerTests {
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
Mockito.verify(this.brokerChannel).send(captor.capture());
assertEquals("/queue/foo-user123",
captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER));
assertEquals("/user/queue/foo",
captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION));
assertEquals("/queue/foo-user123", SimpMessageHeaderAccessor.getDestination(captor.getValue().getHeaders()));
assertEquals("/user/queue/foo", captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION));
}

View File

@ -41,9 +41,9 @@ import org.springframework.messaging.simp.stomp.StompConversionException;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
@ -79,6 +79,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class);
private static final byte[] EMPTY_PAYLOAD = new byte[0];
private int messageSizeLimit = 64 * 1024;
@ -172,9 +174,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
for (Message<byte[]> message : messages) {
try {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
StompHeaderAccessor headerAccessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
if (logger.isTraceEnabled()) {
if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) {
if (headerAccessor.isHeartbeat()) {
logger.trace("Received heartbeat from client session=" + session.getId());
}
else {
@ -182,13 +187,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
}
headers.setSessionId(session.getId());
headers.setSessionAttributes(session.getAttributes());
headers.setUser(session.getPrincipal());
headerAccessor.setSessionId(session.getId());
headerAccessor.setSessionAttributes(session.getAttributes());
headerAccessor.setUser(session.getPrincipal());
headerAccessor.setImmutable();
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
if (this.eventPublisher != null && StompCommand.CONNECT.equals(headers.getCommand())) {
if (this.eventPublisher != null && StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
publishEvent(new SessionConnectEvent(this, message));
}
@ -212,10 +216,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setMessage(error.getMessage());
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
byte[] bytes = this.stompEncoder.encode(message);
StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
headerAccessor.setMessage(error.getMessage());
byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD);
try {
session.sendMessage(new TextMessage(bytes));
}
@ -231,46 +234,60 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
@Override
public void handleMessageToClient(WebSocketSession session, Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (headers.getMessageType() == SimpMessageType.CONNECT_ACK) {
StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
connectedHeaders.setVersion(getVersion(headers));
connectedHeaders.setHeartbeat(0, 0); // no heart-beat support with simple broker
headers = connectedHeaders;
}
else if (SimpMessageType.MESSAGE.equals(headers.getMessageType())) {
headers.updateStompCommandAsServerMessage();
}
if (headers.getCommand() == StompCommand.CONNECTED) {
afterStompSessionConnected(headers, session);
}
if (StompCommand.MESSAGE.equals(headers.getCommand())) {
if (headers.getSubscriptionId() == null) {
logger.error("Ignoring message, no subscriptionId header: " + message);
return;
}
String header = SimpMessageHeaderAccessor.ORIGINAL_DESTINATION;
if (message.getHeaders().containsKey(header)) {
headers.setDestination((String) message.getHeaders().get(header));
}
}
if (!(message.getPayload() instanceof byte[])) {
logger.error("Ignoring message, expected byte[] content: " + message);
return;
}
try {
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
if (accessor == null) {
logger.error("No header accessor: " + message);
return;
}
if (this.eventPublisher != null && StompCommand.CONNECTED.equals(headers.getCommand())) {
StompHeaderAccessor stompAccessor;
if (accessor instanceof StompHeaderAccessor) {
stompAccessor = (StompHeaderAccessor) accessor;
}
else if (accessor instanceof SimpMessageHeaderAccessor) {
stompAccessor = StompHeaderAccessor.wrap(message);
if (SimpMessageType.CONNECT_ACK.equals(stompAccessor.getMessageType())) {
StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
connectedHeaders.setVersion(getVersion(stompAccessor));
connectedHeaders.setHeartbeat(0, 0); // no heart-beat support with simple broker
stompAccessor = connectedHeaders;
}
else if (stompAccessor.getCommand() == null || StompCommand.SEND.equals(stompAccessor.getCommand())) {
stompAccessor.updateStompCommandAsServerMessage();
}
}
else {
// Should not happen
logger.error("Unexpected header accessor type: " + accessor);
return;
}
StompCommand command = stompAccessor.getCommand();
if (StompCommand.MESSAGE.equals(command)) {
if (stompAccessor.getSubscriptionId() == null) {
logger.error("Ignoring message, no subscriptionId header: " + message);
return;
}
String header = SimpMessageHeaderAccessor.ORIGINAL_DESTINATION;
if (message.getHeaders().containsKey(header)) {
stompAccessor = toMutableAccessor(stompAccessor, message);
stompAccessor.setDestination((String) message.getHeaders().get(header));
}
}
else if (StompCommand.CONNECTED.equals(command)) {
stompAccessor = afterStompSessionConnected(message, stompAccessor, session);
if (this.eventPublisher != null && StompCommand.CONNECTED.equals(command)) {
publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message));
}
}
byte[] bytes = this.stompEncoder.encode((Message<byte[]>) message);
try {
byte[] bytes = this.stompEncoder.encode(stompAccessor.getMessageHeaders(), (byte[]) message.getPayload());
TextMessage textMessage = new TextMessage(bytes);
session.sendMessage(textMessage);
@ -283,7 +300,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
sendErrorMessage(session, ex);
}
finally {
if (StompCommand.ERROR.equals(headers.getCommand())) {
if (StompCommand.ERROR.equals(command)) {
try {
session.close(CloseStatus.PROTOCOL_ERROR);
}
@ -294,13 +311,19 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
}
protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccessor, Message<?> message) {
return (headerAccessor.isMutable() ? headerAccessor : StompHeaderAccessor.wrap(message));
}
private String getVersion(StompHeaderAccessor connectAckHeaders) {
String name = StompHeaderAccessor.CONNECT_MESSAGE_HEADER;
Message<?> connectMessage = (Message<?>) connectAckHeaders.getHeader(name);
StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(connectMessage);
Assert.notNull(connectMessage, "CONNECT_ACK does not contain original CONNECT " + connectAckHeaders);
StompHeaderAccessor connectHeaders =
MessageHeaderAccessor.getAccessor(connectMessage, StompHeaderAccessor.class);
Set<String> acceptVersions = connectHeaders.getAcceptVersion();
if (acceptVersions.contains("1.2")) {
return "1.2";
@ -316,16 +339,19 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
}
private void afterStompSessionConnected(StompHeaderAccessor headers, WebSocketSession session) {
private StompHeaderAccessor afterStompSessionConnected(
Message<?> message, StompHeaderAccessor headerAccessor, WebSocketSession session) {
Principal principal = session.getPrincipal();
if (principal != null) {
headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
headerAccessor = toMutableAccessor(headerAccessor, message);
headerAccessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
if (this.userSessionRegistry != null) {
String userName = resolveNameForUserSessionRegistry(principal);
this.userSessionRegistry.registerSessionId(userName, session.getId());
}
}
long[] heartbeat = headers.getHeartbeat();
long[] heartbeat = headerAccessor.getHeartbeat();
if (heartbeat[1] > 0) {
session = WebSocketSessionDecorator.unwrap(session);
if (session instanceof SockJsSession) {
@ -333,6 +359,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
((SockJsSession) session).disableHeartbeat();
}
}
return headerAccessor;
}
private String resolveNameForUserSessionRegistry(Principal principal) {
@ -345,8 +372,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
@Override
public String resolveSessionId(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
return headers.getSessionId();
return SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
}
@Override
@ -374,7 +400,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId(session.getId());
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
Message<?> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
if (this.eventPublisher != null) {
publishEvent(new SessionDisconnectEvent(this, session.getId(), closeStatus));

View File

@ -16,7 +16,6 @@
package org.springframework.web.socket.messaging;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
@ -41,7 +40,6 @@ import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.CloseStatus;
@ -61,6 +59,8 @@ import static org.mockito.Mockito.*;
*/
public class StompSubProtocolHandlerTests {
public static final byte[] EMPTY_PAYLOAD = new byte[0];
private StompSubProtocolHandler protocolHandler;
private TestWebSocketSession session;
@ -89,7 +89,7 @@ public class StompSubProtocolHandlerTests {
this.protocolHandler.setUserSessionRegistry(registry);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size());
@ -108,7 +108,7 @@ public class StompSubProtocolHandlerTests {
this.protocolHandler.setUserSessionRegistry(registry);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size());
@ -126,7 +126,7 @@ public class StompSubProtocolHandlerTests {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
headers.setHeartbeat(0,10);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(sockJsSession, message);
verify(sockJsSession).disableHeartbeat();
@ -137,12 +137,12 @@ public class StompSubProtocolHandlerTests {
StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT);
connectHeaders.setHeartbeat(10000, 10000);
connectHeaders.setNativeHeader(StompHeaderAccessor.STOMP_ACCEPT_VERSION_HEADER, "1.0,1.1");
Message<?> connectMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectHeaders).build();
connectHeaders.setAcceptVersion("1.0,1.1");
Message<?> connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectHeaders.getMessageHeaders());
SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage);
Message<byte[]> connectAckMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build();
Message<byte[]> connectAckMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAckHeaders.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, connectAckMessage);
@ -174,12 +174,12 @@ public class StompSubProtocolHandlerTests {
this.protocolHandler.afterSessionStarted(this.session, this.channel);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
TextMessage textMessage = new TextMessage(new StompEncoder().encode(message));
this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message);
this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, this.channel);
@ -207,7 +207,7 @@ public class StompSubProtocolHandlerTests {
this.protocolHandler.afterSessionStarted(this.session, this.channel);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
TextMessage textMessage = new TextMessage(new StompEncoder().encode(message));
this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
@ -218,7 +218,7 @@ public class StompSubProtocolHandlerTests {
reset(this.channel);
headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size());
@ -241,7 +241,7 @@ public class StompSubProtocolHandlerTests {
headers.setSubscriptionId("sub0");
headers.setDestination("/queue/foo-user123");
headers.setHeader(StompHeaderAccessor.ORIGINAL_DESTINATION, "/user/queue/foo");
Message<byte[]> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
this.protocolHandler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size());
@ -278,8 +278,9 @@ public class StompSubProtocolHandlerTests {
@Test
public void handleMessageFromClientInvalidStompCommand() {
TextMessage textMessage = new TextMessage("FOO");
TextMessage textMessage = new TextMessage("FOO\n\n\0");
this.protocolHandler.afterSessionStarted(this.session, this.channel);
this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
verifyZeroInteractions(this.channel);

View File

@ -128,7 +128,7 @@ public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegration
assertTrue(clientHandler.latch.await(2, TimeUnit.SECONDS));
String payload = clientHandler.actual.get(0).getPayload();
assertTrue("Expected STOMP Command=MESSAGE, got " + payload, payload.startsWith("MESSAGE\n"));
assertTrue("Expected STOMP MESSAGE, got " + payload, payload.startsWith("MESSAGE\n"));
}
finally {
session.close();