From ae942ffdb89ae103b6f9e076ec9548594317e2f9 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 10 Apr 2014 23:57:45 -0400 Subject: [PATCH] Make use of enhanced MessageHeaderAccessor support Mutate rather than re-create headers when decoding STOMP messages before a message is sent on a message channel. Use MessageBuilder.createMessage to ensure the fully prepared MessageHeaders is used directly MessageHeaderAccessor instance. Issue: SPR-11468 --- .../core/AbstractMessageSendingTemplate.java | 8 ++ .../HeadersMethodArgumentResolver.java | 21 ++- .../AbstractMethodMessageHandler.java | 6 +- .../simp/SimpMessageTypeMessageCondition.java | 4 +- .../messaging/simp/SimpMessagingTemplate.java | 11 +- .../PrincipalMethodArgumentResolver.java | 3 +- .../SendToMethodReturnValueHandler.java | 25 ++-- .../SimpAnnotationMethodMessageHandler.java | 13 +- .../SubscriptionMethodReturnValueHandler.java | 22 +-- .../broker/AbstractSubscriptionRegistry.java | 42 +++--- .../broker/SimpleBrokerMessageHandler.java | 47 ++++--- .../stomp/StompBrokerRelayMessageHandler.java | 110 +++++++++------ .../user/DefaultUserDestinationResolver.java | 26 ++-- .../user/UserDestinationMessageHandler.java | 12 +- .../converter/MessageConverterTests.java | 4 +- .../MessageBrokerConfigurationTests.java | 1 - ...erRelayMessageHandlerIntegrationTests.java | 17 +-- .../StompBrokerRelayMessageHandlerTests.java | 13 +- .../UserDestinationMessageHandlerTests.java | 12 +- .../messaging/StompSubProtocolHandler.java | 130 +++++++++++------- .../StompSubProtocolHandlerTests.java | 29 ++-- .../StompWebSocketIntegrationTests.java | 2 +- 22 files changed, 322 insertions(+), 236 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/core/AbstractMessageSendingTemplate.java b/spring-messaging/src/main/java/org/springframework/messaging/core/AbstractMessageSendingTemplate.java index 01f004609e..a2c82570b0 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/core/AbstractMessageSendingTemplate.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/core/AbstractMessageSendingTemplate.java @@ -26,6 +26,7 @@ import org.springframework.messaging.MessagingException; import org.springframework.messaging.converter.MessageConversionException; import org.springframework.messaging.converter.MessageConverter; import org.springframework.messaging.converter.SimpleMessageConverter; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.Assert; /** @@ -130,6 +131,13 @@ public abstract class AbstractMessageSendingTemplate implements MessageSendin MessagePostProcessor postProcessor) throws MessagingException { headers = processHeadersToSend(headers); + + MessageHeaders messageHeaders; + if (headers != null && headers instanceof MessageHeaders) { + MessageHeaderAccessor.getAccessor() + + } + MessageHeaders messageHeaders = (headers != null) ? new MessageHeaders(headers) : null; Message message = this.converter.toMessage(payload, messageHeaders); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/HeadersMethodArgumentResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/HeadersMethodArgumentResolver.java index 079fa62117..2b2cb64733 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/HeadersMethodArgumentResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/HeadersMethodArgumentResolver.java @@ -26,7 +26,7 @@ import org.springframework.messaging.handler.annotation.Header; import org.springframework.messaging.handler.annotation.Headers; import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; import org.springframework.messaging.support.MessageHeaderAccessor; -import org.springframework.util.ClassUtils; +import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; /** @@ -42,7 +42,6 @@ import org.springframework.util.ReflectionUtils; */ public class HeadersMethodArgumentResolver implements HandlerMethodArgumentResolver { - @Override public boolean supportsParameter(MethodParameter parameter) { Class paramType = parameter.getParameterType(); @@ -60,15 +59,23 @@ public class HeadersMethodArgumentResolver implements HandlerMethodArgumentResol return message.getHeaders(); } else if (MessageHeaderAccessor.class.equals(paramType)) { - return new MessageHeaderAccessor(message); + MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class); + return (accessor != null ? accessor : new MessageHeaderAccessor(message)); } else if (MessageHeaderAccessor.class.isAssignableFrom(paramType)) { - Method factoryMethod = ClassUtils.getMethod(paramType, "wrap", Message.class); - return ReflectionUtils.invokeMethod(factoryMethod, null, message); + MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class); + if (accessor != null && paramType.isAssignableFrom(accessor.getClass())) { + return accessor; + } + else { + Method method = ReflectionUtils.findMethod(paramType, "wrap", Message.class); + Assert.notNull(method, "Cannot create accessor of type " + paramType + " for message " + message); + return ReflectionUtils.invokeMethod(method, null, message); + } } else { - throw new IllegalStateException("Unexpected method parameter type " - + paramType + "in method " + parameter.getMethod() + ". " + throw new IllegalStateException( + "Unexpected method parameter type " + paramType + "in method " + parameter.getMethod() + ". " + "@Headers method arguments must be assignable to java.util.Map."); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java index e52be029e2..62912cf7df 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/AbstractMethodMessageHandler.java @@ -42,6 +42,7 @@ import org.springframework.messaging.handler.DestinationPatternsMessageCondition import org.springframework.messaging.handler.HandlerMethod; import org.springframework.messaging.handler.HandlerMethodSelector; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; @@ -339,8 +340,9 @@ public abstract class AbstractMethodMessageHandler logger.debug("Handling message, lookupDestination=" + lookupDestination); } - message = MessageBuilder.fromMessage(message).setHeader( - DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, lookupDestination).build(); + MessageHeaderAccessor headerAccessor = MessageHeaderAccessor.getMutableAccessor(message); + headerAccessor.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, lookupDestination); + message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders()); handleMessageInternal(message, lookupDestination); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageTypeMessageCondition.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageTypeMessageCondition.java index a5af929966..fa698d038a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageTypeMessageCondition.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageTypeMessageCondition.java @@ -73,7 +73,7 @@ public class SimpMessageTypeMessageCondition extends AbstractMessageCondition message) { - Object actualMessageType = message.getHeaders().get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER); + Object actualMessageType = SimpMessageHeaderAccessor.getMessageType(message.getHeaders()); if (actualMessageType == null) { return null; } @@ -83,7 +83,7 @@ public class SimpMessageTypeMessageCondition extends AbstractMessageCondition message) { - Object actualMessageType = message.getHeaders().get(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER); + Object actualMessageType = SimpMessageHeaderAccessor.getMessageType(message.getHeaders()); if (actualMessageType != null) { if (actualMessageType.equals(this.getMessageType()) && actualMessageType.equals(other.getMessageType())) { return 0; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java index 39da6ca69d..dd8a902856 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessagingTemplate.java @@ -108,8 +108,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate message) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - String destination = headers.getDestination(); + String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders()); destination = (destination != null) ? destination : getRequiredDefaultDestination(); doSend(destination, message); } @@ -118,10 +117,10 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate message) { Assert.notNull(destination, "Destination must not be null"); - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - headers.setDestination(destination); - headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); - message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); + SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message); + headerAccessor.setDestination(destination); + headerAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); + message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders()); long timeout = this.sendTimeout; boolean sent = (timeout >= 0) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMethodArgumentResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMethodArgumentResolver.java index dedbaed0c6..6a08d95f7c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMethodArgumentResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMethodArgumentResolver.java @@ -37,8 +37,7 @@ public class PrincipalMethodArgumentResolver implements HandlerMethodArgumentRes @Override public Object resolveArgument(MethodParameter parameter, Message message) throws Exception { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - Principal user = headers.getUser(); + Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders()); if (user == null) { throw new MissingSessionUserException(message); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java index 957505cb26..69c301182c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java @@ -23,6 +23,7 @@ import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.core.MessagePostProcessor; import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.handler.annotation.SendTo; @@ -113,29 +114,28 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH } @Override - public void handleReturnValue(Object returnValue, MethodParameter returnType, Message inputMessage) + public void handleReturnValue(Object returnValue, MethodParameter returnType, Message message) throws Exception { if (returnValue == null) { return; } - SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(inputMessage); - - String sessionId = inputHeaders.getSessionId(); + MessageHeaders headers = message.getHeaders(); + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); MessagePostProcessor postProcessor = new SessionHeaderPostProcessor(sessionId); SendToUser sendToUser = returnType.getMethodAnnotation(SendToUser.class); if (sendToUser != null) { - Principal principal = inputHeaders.getUser(); + Principal principal = SimpMessageHeaderAccessor.getUser(headers); if (principal == null) { - throw new MissingSessionUserException(inputMessage); + throw new MissingSessionUserException(message); } String userName = principal.getName(); if (principal instanceof DestinationUserNameProvider) { userName = ((DestinationUserNameProvider) principal).getDestinationUserName(); } - String[] destinations = getTargetDestinations(sendToUser, inputHeaders, this.defaultUserDestinationPrefix); + String[] destinations = getTargetDestinations(sendToUser, message, this.defaultUserDestinationPrefix); for (String destination : destinations) { this.messagingTemplate.convertAndSendToUser(userName, destination, returnValue, postProcessor); } @@ -143,15 +143,14 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH } else { SendTo sendTo = returnType.getMethodAnnotation(SendTo.class); - String[] destinations = getTargetDestinations(sendTo, inputHeaders, this.defaultDestinationPrefix); + String[] destinations = getTargetDestinations(sendTo, message, this.defaultDestinationPrefix); for (String destination : destinations) { this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor); } } } - protected String[] getTargetDestinations(Annotation annot, SimpMessageHeaderAccessor inputHeaders, - String defaultPrefix) { + protected String[] getTargetDestinations(Annotation annot, Message message, String defaultPrefix) { if (annot != null) { String[] value = (String[]) AnnotationUtils.getValue(annot); @@ -159,8 +158,8 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH return value; } } - return new String[] { defaultPrefix + - inputHeaders.getHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER) }; + String name = DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER; + return new String[] { defaultPrefix + message.getHeaders().get(name) }; } @@ -176,7 +175,7 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH public Message postProcessMessage(Message message) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); headers.setSessionId(this.sessionId); - return MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); + return MessageBuilder.createMessage(message.getPayload(), headers.getMessageHeaders()); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java index 3bd8cfa142..17e82eae17 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SimpAnnotationMethodMessageHandler.java @@ -63,6 +63,7 @@ import org.springframework.stereotype.Controller; import org.springframework.util.AntPathMatcher; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; import org.springframework.util.PathMatcher; import org.springframework.validation.Errors; import org.springframework.validation.Validator; @@ -329,7 +330,7 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan @Override protected String getDestination(Message message) { - return (String) message.getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER); + return (String) SimpMessageHeaderAccessor.getDestination(message.getHeaders()); } @Override @@ -352,13 +353,15 @@ public class SimpAnnotationMethodMessageHandler extends AbstractMethodMessageHan protected void handleMatch(SimpMessageMappingInfo mapping, HandlerMethod handlerMethod, String lookupDestination, Message message) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - String matchedPattern = mapping.getDestinationConditions().getPatterns().iterator().next(); Map vars = getPathMatcher().extractUriTemplateVariables(matchedPattern, lookupDestination); - headers.setHeader(DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER, vars); - message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); + if (!CollectionUtils.isEmpty(vars)) { + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + headers.setHeader(DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER, vars); + message = MessageBuilder.createMessage(message.getPayload(), headers.getMessageHeaders()); + + } super.handleMatch(mapping, handlerMethod, lookupDestination, message); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java index 9bca795fd4..4d91326f8e 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java @@ -18,6 +18,7 @@ package org.springframework.messaging.simp.annotation.support; import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.core.MessagePostProcessor; import org.springframework.messaging.core.MessageSendingOperations; import org.springframework.messaging.handler.annotation.SendTo; @@ -71,13 +72,12 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn return; } - SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(message); - String sessionId = inputHeaders.getSessionId(); - String subscriptionId = inputHeaders.getSubscriptionId(); - String destination = inputHeaders.getDestination(); + MessageHeaders headers = message.getHeaders(); + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); + String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers); + String destination = SimpMessageHeaderAccessor.getDestination(headers); - Assert.state(inputHeaders.getSubscriptionId() != null, - "No subsriptiondId in input message to method " + returnType.getMethod()); + Assert.state(subscriptionId != null, "No subsriptiondId in input message to method " + returnType.getMethod()); MessagePostProcessor postProcessor = new SubscriptionHeaderPostProcessor(sessionId, subscriptionId); this.messagingTemplate.convertAndSend(destination, returnValue, postProcessor); @@ -98,11 +98,11 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn @Override public Message postProcessMessage(Message message) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - headers.setSessionId(this.sessionId); - headers.setSubscriptionId(this.subscriptionId); - headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); - return MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); + SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message); + headerAccessor.setSessionId(this.sessionId); + headerAccessor.setSubscriptionId(this.subscriptionId); + headerAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); + return MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders()); } } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractSubscriptionRegistry.java index 34fc89799a..1b991f2781 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/AbstractSubscriptionRegistry.java @@ -19,8 +19,10 @@ package org.springframework.messaging.simp.broker; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.MultiValueMap; /** @@ -38,29 +40,31 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist @Override public final void registerSubscription(Message message) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - if (!SimpMessageType.SUBSCRIBE.equals(headers.getMessageType())) { + + MessageHeaders headers = message.getHeaders(); + SimpMessageType type = SimpMessageHeaderAccessor.getMessageType(headers); + + if (!SimpMessageType.SUBSCRIBE.equals(type)) { logger.error("Expected SUBSCRIBE message: " + message); return; } - String sessionId = headers.getSessionId(); + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); if (sessionId == null) { logger.error("Ignoring subscription. No sessionId in message: " + message); return; } - String subscriptionId = headers.getSubscriptionId(); + String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers); if (subscriptionId == null) { logger.error("Ignoring subscription. No subscriptionId in message: " + message); return; } - String destination = headers.getDestination(); + String destination = SimpMessageHeaderAccessor.getDestination(headers); if (destination == null) { logger.error("Ignoring destination. No destination in message: " + message); return; } if (logger.isDebugEnabled()) { - logger.debug("Adding subscription id=" + headers.getSubscriptionId() - + ", destination=" + headers.getDestination()); + logger.debug("Adding subscription id=" + subscriptionId + ", destination=" + destination); } addSubscriptionInternal(sessionId, subscriptionId, destination, message); } @@ -70,17 +74,20 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist @Override public final void unregisterSubscription(Message message) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - if (!SimpMessageType.UNSUBSCRIBE.equals(headers.getMessageType())) { + + MessageHeaders headers = message.getHeaders(); + SimpMessageType type = SimpMessageHeaderAccessor.getMessageType(headers); + + if (!SimpMessageType.UNSUBSCRIBE.equals(type)) { logger.error("Expected UNSUBSCRIBE message: " + message); return; } - String sessionId = headers.getSessionId(); + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); if (sessionId == null) { logger.error("Ignoring subscription. No sessionId in message: " + message); return; } - String subscriptionId = headers.getSubscriptionId(); + String subscriptionId = SimpMessageHeaderAccessor.getSubscriptionId(headers); if (subscriptionId == null) { logger.error("Ignoring subscription. No subscriptionId in message: " + message); return; @@ -98,19 +105,22 @@ public abstract class AbstractSubscriptionRegistry implements SubscriptionRegist @Override public final MultiValueMap findSubscriptions(Message message) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - if (!SimpMessageType.MESSAGE.equals(headers.getMessageType())) { - logger.trace("Ignoring message type " + headers.getMessageType()); + + MessageHeaders headers = message.getHeaders(); + SimpMessageType type = SimpMessageHeaderAccessor.getMessageType(headers); + + if (!SimpMessageType.MESSAGE.equals(type)) { + logger.trace("Ignoring message type " + type); return null; } - String destination = headers.getDestination(); + String destination = SimpMessageHeaderAccessor.getDestination(headers); if (destination == null) { logger.trace("Ignoring message, no destination"); return null; } MultiValueMap result = findSubscriptionsInternal(destination, message); if (logger.isTraceEnabled()) { - logger.trace("Found " + result.size() + " subscriptions for destination=" + headers.getDestination()); + logger.trace("Found " + result.size() + " subscriptions for destination=" + destination); } return result; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java index 358f149fb8..8aa18b151e 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java @@ -20,6 +20,7 @@ import java.util.Collection; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; @@ -111,9 +112,10 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { @Override protected void handleMessageInternal(Message message) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - SimpMessageType messageType = headers.getMessageType(); - String destination = headers.getDestination(); + MessageHeaders headers = message.getHeaders(); + SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers); + String destination = SimpMessageHeaderAccessor.getDestination(headers); + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); if (!checkDestinationPrefix(destination)) { if (logger.isTraceEnabled()) { @@ -122,27 +124,30 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { return; } - if (SimpMessageType.SUBSCRIBE.equals(messageType)) { + if (SimpMessageType.MESSAGE.equals(messageType)) { + sendMessageToSubscribers(destination, message); + } + else if (SimpMessageType.SUBSCRIBE.equals(messageType)) { this.subscriptionRegistry.registerSubscription(message); } else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) { this.subscriptionRegistry.unregisterSubscription(message); } - else if (SimpMessageType.MESSAGE.equals(messageType)) { - sendMessageToSubscribers(headers.getDestination(), message); - } else if (SimpMessageType.DISCONNECT.equals(messageType)) { - String sessionId = headers.getSessionId(); - this.subscriptionRegistry.unregisterAllSubscriptions(sessionId); + this.subscriptionRegistry.unregisterAllSubscriptions(sessionId); } else if (SimpMessageType.CONNECT.equals(messageType)) { - SimpMessageHeaderAccessor replyHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); - replyHeaders.setSessionId(headers.getSessionId()); - replyHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message); - - Message connectAck = MessageBuilder.withPayload(EMPTY_PAYLOAD).setHeaders(replyHeaders).build(); + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); + accessor.setSessionId(sessionId); + accessor.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message); + Message connectAck = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); this.clientOutboundChannel.send(connectAck); } + else { + if (logger.isTraceEnabled()) { + logger.trace("Message type not supported. Ignoring: " + message); + } + } } protected void sendMessageToSubscribers(String destination, Message message) { @@ -153,17 +158,17 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { } for (String sessionId : subscriptions.keySet()) { for (String subscriptionId : subscriptions.get(sessionId)) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - headers.setSessionId(sessionId); - headers.setSubscriptionId(subscriptionId); + SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + headerAccessor.setSessionId(sessionId); + headerAccessor.setSubscriptionId(subscriptionId); + headerAccessor.copyHeadersIfAbsent(message.getHeaders()); Object payload = message.getPayload(); - Message clientMessage = MessageBuilder.withPayload(payload).setHeaders(headers).build(); + Message reply = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders()); try { - this.clientOutboundChannel.send(clientMessage); + this.clientOutboundChannel.send(reply); } catch (Throwable ex) { - logger.error("Failed to send message to destination=" + destination + - ", sessionId=" + sessionId + ", subscriptionId=" + subscriptionId, ex); + logger.error("Failed to send message=" + message, ex); } } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 0ecbd410d7..20b005af7d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -20,7 +20,6 @@ import java.util.Collection; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; @@ -30,6 +29,7 @@ import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.messaging.tcp.FixedIntervalReconnectStrategy; import org.springframework.messaging.tcp.TcpConnection; import org.springframework.messaging.tcp.TcpConnectionHandler; @@ -79,9 +79,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private static final Message HEARTBEAT_MESSAGE; static { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); - HEARTBEAT_MESSAGE = MessageBuilder.withPayload(new byte[] {'\n'}).setHeaders(headers).build(); EMPTY_TASK.run(); + StompHeaderAccessor accessor = StompHeaderAccessor.createForHeartbeat(); + HEARTBEAT_MESSAGE = MessageBuilder.createMessage(StompDecoder.HEARTBEAT_PAYLOAD, accessor.getMessageHeaders()); } @@ -370,37 +370,53 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @Override protected void handleMessageInternal(Message message) { - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - String sessionId = headers.getSessionId(); + String sessionId = SimpMessageHeaderAccessor.getSessionId(message.getHeaders()); if (!isBrokerAvailable()) { - if (sessionId == null || sessionId == SystemStompConnectionHandler.SESSION_ID) { + if (sessionId == null || SystemStompConnectionHandler.SESSION_ID.equals(sessionId)) { throw new MessageDeliveryException("Message broker is not active."); } if (logger.isTraceEnabled()) { - logger.trace("Message broker is not active. Ignoring message id=" + message.getHeaders().getId()); + logger.trace("Message broker is not active. Ignoring: " + message); } return; } - String destination = headers.getDestination(); - StompCommand command = headers.getCommand(); - SimpMessageType messageType = headers.getMessageType(); + StompHeaderAccessor stompAccessor; + StompCommand command; - if (SimpMessageType.MESSAGE.equals(messageType)) { - sessionId = (sessionId == null) ? SystemStompConnectionHandler.SESSION_ID : sessionId; - headers.setSessionId(sessionId); - command = headers.updateStompCommandAsClientMessage(); - message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); + MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class); + if (accessor == null) { + logger.error("No header accessor, please use SimpMessagingTemplate. Ignoring: " + message); + return; + } + else if (accessor instanceof StompHeaderAccessor) { + stompAccessor = (StompHeaderAccessor) accessor; + command = stompAccessor.getCommand(); + } + else if (accessor instanceof SimpMessageHeaderAccessor) { + stompAccessor = StompHeaderAccessor.wrap(message); + command = stompAccessor.getCommand(); + if (command == null) { + command = stompAccessor.updateStompCommandAsClientMessage(); + } + } + else { + // Should not happen + logger.error("Unexpected header accessor type: " + accessor + ". Ignoring: " + message); + return; } if (sessionId == null) { - if (logger.isWarnEnabled()) { - logger.warn("No sessionId, ignoring message: " + message); + if (!SimpMessageType.MESSAGE.equals(stompAccessor.getMessageType())) { + logger.error("Only STOMP SEND frames supported on \"system\" connection. Ignoring: " + message); + return; } - return; + sessionId = SystemStompConnectionHandler.SESSION_ID; + stompAccessor.setSessionId(sessionId); } + String destination = stompAccessor.getDestination(); if ((command != null) && command.requiresDestination() && !checkDestinationPrefix(destination)) { if (logger.isTraceEnabled()) { logger.trace("Ignoring message to destination=" + destination); @@ -412,20 +428,21 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler logger.trace("Processing message=" + message); } - if (SimpMessageType.CONNECT.equals(messageType)) { + if (StompCommand.CONNECT.equals(command)) { if (logger.isDebugEnabled()) { logger.debug("Processing CONNECT (total connected=" + this.connectionHandlers.size() + ")"); } - headers.setLogin(this.clientLogin); - headers.setPasscode(this.clientPasscode); + stompAccessor = (stompAccessor.isMutable() ? stompAccessor : StompHeaderAccessor.wrap(message)); + stompAccessor.setLogin(this.clientLogin); + stompAccessor.setPasscode(this.clientPasscode); if (getVirtualHost() != null) { - headers.setHost(getVirtualHost()); + stompAccessor.setHost(getVirtualHost()); } - StompConnectionHandler handler = new StompConnectionHandler(sessionId, headers); + StompConnectionHandler handler = new StompConnectionHandler(sessionId, stompAccessor); this.connectionHandlers.put(sessionId, handler); this.tcpClient.connect(handler); } - else if (SimpMessageType.DISCONNECT.equals(messageType)) { + else if (StompCommand.DISCONNECT.equals(command)) { StompConnectionHandler handler = this.connectionHandlers.get(sessionId); if (handler == null) { if (logger.isTraceEnabled()) { @@ -433,7 +450,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } return; } - handler.forward(message); + handler.forward(message, stompAccessor); } else { StompConnectionHandler handler = this.connectionHandlers.get(sessionId); @@ -443,7 +460,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } return; } - handler.forward(message); + handler.forward(message, stompAccessor); } } @@ -486,7 +503,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler logger.debug("Established TCP connection to broker in session '" + this.sessionId + "'"); } this.tcpConnection = connection; - connection.send(MessageBuilder.withPayload(EMPTY_PAYLOAD).setHeaders(this.connectHeaders).build()); + connection.send(MessageBuilder.createMessage(EMPTY_PAYLOAD, this.connectHeaders.getMessageHeaders())); } @Override @@ -522,7 +539,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); headers.setSessionId(this.sessionId); headers.setMessage(errorText); - Message errorMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + Message errorMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); sendMessageToClient(errorMessage); } } @@ -536,20 +553,23 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @Override public void handleMessage(Message message) { - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) { + StompHeaderAccessor headerAccessor = + MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + + if (headerAccessor.isHeartbeat()) { logger.trace("Received broker heartbeat"); } else if (logger.isDebugEnabled()) { logger.debug("Received message from broker in session '" + this.sessionId + "'"); } - if (StompCommand.CONNECTED == headers.getCommand()) { - afterStompConnected(headers); + if (StompCommand.CONNECTED == headerAccessor.getCommand()) { + afterStompConnected(headerAccessor); } - headers.setSessionId(this.sessionId); - message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); + headerAccessor.setSessionId(this.sessionId); + headerAccessor.setImmutable(); + sendMessageToClient(message); } @@ -630,9 +650,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler clearConnection(); } catch (Throwable t) { - if (logger.isErrorEnabled()) { - // Ignore - } + // Ignore } } } @@ -661,7 +679,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler * @return a future to wait for the result */ @SuppressWarnings("unchecked") - public ListenableFuture forward(final Message message) { + public ListenableFuture forward(Message message, final StompHeaderAccessor headerAccessor) { TcpConnection conn = this.tcpConnection; @@ -682,8 +700,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } if (logger.isDebugEnabled()) { - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) { + if (headerAccessor.isHeartbeat()) { logger.trace("Forwarding heartbeat to broker"); } else { @@ -691,13 +708,16 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } } + if (headerAccessor.isMutable() && headerAccessor.isModified()) { + message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders()); + } + ListenableFuture future = conn.send((Message) message); future.addCallback(new ListenableFutureCallback() { @Override public void onSuccess(Void result) { - StompCommand command = StompHeaderAccessor.wrap(message).getCommand(); - if (command == StompCommand.DISCONNECT) { + if (headerAccessor.getCommand() == StompCommand.DISCONNECT) { clearConnection(); } } @@ -707,7 +727,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler // already reset } else { - handleTcpConnectionFailure("Failed to send message " + message, t); + handleTcpConnectionFailure("Failed to send message " + headerAccessor, t); } } }); @@ -777,9 +797,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } @Override - public ListenableFuture forward(Message message) { + public ListenableFuture forward(Message message, StompHeaderAccessor headerAccessor) { try { - ListenableFuture future = super.forward(message); + ListenableFuture future = super.forward(message, headerAccessor); future.get(); return future; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java index ec6a872c48..375b9309ba 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java @@ -19,6 +19,7 @@ package org.springframework.messaging.simp.user; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.util.Assert; @@ -100,34 +101,34 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { @Override public UserDestinationResult resolveDestination(Message message) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); - DestinationInfo info = parseUserDestination(headers); + String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders()); + DestinationInfo info = parseUserDestination(message); if (info == null) { return null; } Set targetDestinations = new HashSet(); for (String sessionId : info.getSessionIds()) { - targetDestinations.add(getTargetDestination( - headers.getDestination(), info.getDestinationWithoutPrefix(), sessionId, info.getUser())); + targetDestinations.add(getTargetDestination(destination, + info.getDestinationWithoutPrefix(), sessionId, info.getUser())); } - return new UserDestinationResult(headers.getDestination(), + return new UserDestinationResult(destination, targetDestinations, info.getSubscribeDestination(), info.getUser()); } - private DestinationInfo parseUserDestination(SimpMessageHeaderAccessor headers) { + private DestinationInfo parseUserDestination(Message message) { - String destination = headers.getDestination(); + MessageHeaders headers = message.getHeaders(); + SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers); + String destination = SimpMessageHeaderAccessor.getDestination(headers); + Principal principal = SimpMessageHeaderAccessor.getUser(headers); String destinationWithoutPrefix; String subscribeDestination; String user; Set sessionIds; - Principal principal = headers.getUser(); - SimpMessageType messageType = headers.getMessageType(); - if (SimpMessageType.SUBSCRIBE.equals(messageType) || SimpMessageType.UNSUBSCRIBE.equals(messageType)) { if (!checkDestination(destination, this.destinationPrefix)) { return null; @@ -136,14 +137,15 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { logger.error("Ignoring message, no principal info available"); return null; } - if (headers.getSessionId() == null) { + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); + if (sessionId == null) { logger.error("Ignoring message, no session id available"); return null; } destinationWithoutPrefix = destination.substring(this.destinationPrefix.length()-1); subscribeDestination = destination; user = principal.getName(); - sessionIds = Collections.singleton(headers.getSessionId()); + sessionIds = Collections.singleton(sessionId); } else if (SimpMessageType.MESSAGE.equals(messageType)) { if (!checkDestination(destination, this.destinationPrefix)) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java index 5d0686789d..e16bfdebdf 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java @@ -152,16 +152,16 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec if (destinations.isEmpty()) { return; } - SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message); - if (SimpMessageType.MESSAGE.equals(headerAccessor.getMessageType())) { + if (SimpMessageType.MESSAGE.equals(SimpMessageHeaderAccessor.getMessageType(message.getHeaders()))) { + SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message); headerAccessor.setHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION, result.getSubscribeDestination()); - message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headerAccessor).build(); + message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders()); } - for (String targetDestination : destinations) { + for (String destination : destinations) { if (logger.isDebugEnabled()) { - logger.debug("Sending message to resolved destination=" + targetDestination); + logger.debug("Sending message to resolved destination=" + destination); } - this.brokerMessagingTemplate.send(targetDestination, message); + this.brokerMessagingTemplate.send(destination, message); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/converter/MessageConverterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/converter/MessageConverterTests.java index 1f591ae7b5..77e8f974f7 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/converter/MessageConverterTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/converter/MessageConverterTests.java @@ -112,10 +112,12 @@ public class MessageConverterTests { public void toMessageHeadersCopied() { Map map = new HashMap(); map.put("foo", "bar"); - MessageHeaders headers = new MessageHeaders(map ); + MessageHeaders headers = new MessageHeaders(map); Message message = this.converter.toMessage("ABC", headers); assertEquals("bar", message.getHeaders().get("foo")); + assertNotNull(message.getHeaders().getId()); + assertNotNull(message.getHeaders().getTimestamp()); } @Test diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java index dd2e973ba8..747e7cf948 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java @@ -31,7 +31,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.support.StaticApplicationContext; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.converter.*; import org.springframework.messaging.handler.annotation.MessageMapping; diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index c13fa52a73..e2f2ad35a3 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -38,6 +38,7 @@ import org.springframework.messaging.MessageDeliveryException; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessagingException; import org.springframework.messaging.StubMessageChannel; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.broker.BrokerAvailabilityEvent; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.ExecutorSubscribableChannel; @@ -168,7 +169,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { public void messageDeliverExceptionIfSystemSessionForwardFails() throws Exception { stopActiveMqBrokerAndAwait(); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); - this.relay.handleMessage(MessageBuilder.withPayload("test".getBytes()).setHeaders(headers).build()); + this.relay.handleMessage(MessageBuilder.createMessage("test".getBytes(), headers.getMessageHeaders())); } @Test @@ -244,7 +245,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); headers.setSessionId("sess1"); - this.relay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build()); + this.relay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); Thread.sleep(2000); @@ -394,7 +395,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { headers.setSessionId(sessionId); headers.setAcceptVersion("1.1,1.2"); headers.setHeartbeat(0, 0); - Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + Message message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); MessageExchangeBuilder builder = new MessageExchangeBuilder(message); builder.expected.add(new StompConnectedFrameMessageMatcher(sessionId)); @@ -405,7 +406,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); headers.setSessionId(sessionId); headers.setAcceptVersion("1.1,1.2"); - Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + Message message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); MessageExchangeBuilder builder = new MessageExchangeBuilder(message); return builder.andExpectError(); } @@ -418,7 +419,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { headers.setSubscriptionId(subscriptionId); headers.setDestination(destination); headers.setReceipt(receiptId); - Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + Message message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); MessageExchangeBuilder builder = new MessageExchangeBuilder(message); builder.expected.add(new StompReceiptFrameMessageMatcher(sessionId, receiptId)); @@ -426,14 +427,14 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } public static MessageExchangeBuilder send(String destination, String payload) { - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); headers.setDestination(destination); - Message message = MessageBuilder.withPayload(payload.getBytes(UTF_8)).setHeaders(headers).build(); + Message message = MessageBuilder.createMessage(payload.getBytes(UTF_8), headers.getMessageHeaders()); return new MessageExchangeBuilder(message); } public MessageExchangeBuilder andExpectMessage(String sessionId, String subscriptionId) { - Assert.isTrue(StompCommand.SEND.equals(headers.getCommand()), "MESSAGE can only be expected after SEND"); + Assert.isTrue(SimpMessageType.MESSAGE.equals(headers.getMessageType())); String destination = this.headers.getDestination(); Object payload = this.message.getPayload(); this.expected.add(new StompMessageFrameMessageMatcher(sessionId, subscriptionId, destination, payload)); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java index 2fa50726ae..6308721d3c 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java @@ -27,6 +27,7 @@ import org.springframework.messaging.StubMessageChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.messaging.tcp.ReconnectStrategy; import org.springframework.messaging.tcp.TcpConnection; import org.springframework.messaging.tcp.TcpConnectionHandler; @@ -77,17 +78,21 @@ public class StompBrokerRelayMessageHandlerTests { String sessionId = "sess1"; StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); headers.setSessionId(sessionId); - this.brokerRelay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build()); + this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); List> sent = this.tcpClient.connection.messages; assertEquals(2, sent.size()); StompHeaderAccessor headers1 = StompHeaderAccessor.wrap(sent.get(0)); assertEquals(virtualHost, headers1.getHost()); + assertNotNull("The prepared message does not have an accessor", + MessageHeaderAccessor.getAccessor(sent.get(0), MessageHeaderAccessor.class)); StompHeaderAccessor headers2 = StompHeaderAccessor.wrap(sent.get(1)); assertEquals(sessionId, headers2.getSessionId()); assertEquals(virtualHost, headers2.getHost()); + assertNotNull("The prepared message does not have an accessor", + MessageHeaderAccessor.getAccessor(sent.get(1), MessageHeaderAccessor.class)); } @Test @@ -104,7 +109,7 @@ public class StompBrokerRelayMessageHandlerTests { String sessionId = "sess1"; StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); headers.setSessionId(sessionId); - this.brokerRelay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build()); + this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); List> sent = this.tcpClient.connection.messages; assertEquals(2, sent.size()); @@ -126,11 +131,13 @@ public class StompBrokerRelayMessageHandlerTests { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); headers.setSessionId("sess1"); headers.setDestination("/user/daisy/foo"); - this.brokerRelay.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build()); + this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); List> sent = this.tcpClient.connection.messages; assertEquals(1, sent.size()); assertEquals(StompCommand.CONNECT, StompHeaderAccessor.wrap(sent.get(0)).getCommand()); + assertNotNull("The prepared message does not have an accessor", + MessageHeaderAccessor.getAccessor(sent.get(0), MessageHeaderAccessor.class)); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java index 10ff0c9a8d..120d4e2d2a 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java @@ -66,8 +66,7 @@ public class UserDestinationMessageHandlerTests { ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); Mockito.verify(this.brokerChannel).send(captor.capture()); - assertEquals("/queue/foo-user123", - captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER)); + assertEquals("/queue/foo-user123", SimpMessageHeaderAccessor.getDestination(captor.getValue().getHeaders())); } @Test @@ -79,8 +78,7 @@ public class UserDestinationMessageHandlerTests { ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); Mockito.verify(this.brokerChannel).send(captor.capture()); - assertEquals("/queue/foo-user123", - captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER)); + assertEquals("/queue/foo-user123", SimpMessageHeaderAccessor.getDestination(captor.getValue().getHeaders())); } @Test @@ -93,10 +91,8 @@ public class UserDestinationMessageHandlerTests { ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); Mockito.verify(this.brokerChannel).send(captor.capture()); - assertEquals("/queue/foo-user123", - captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.DESTINATION_HEADER)); - assertEquals("/user/queue/foo", - captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION)); + assertEquals("/queue/foo-user123", SimpMessageHeaderAccessor.getDestination(captor.getValue().getHeaders())); + assertEquals("/user/queue/foo", captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION)); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 352fc14577..24413dfa0b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -41,9 +41,9 @@ import org.springframework.messaging.simp.stomp.StompConversionException; import org.springframework.messaging.simp.stomp.StompEncoder; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.user.DestinationUserNameProvider; -import org.springframework.messaging.simp.user.UserDestinationMessageHandler; import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; @@ -79,6 +79,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class); + private static final byte[] EMPTY_PAYLOAD = new byte[0]; + private int messageSizeLimit = 64 * 1024; @@ -172,9 +174,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE for (Message message : messages) { try { - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + + StompHeaderAccessor headerAccessor = + MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + if (logger.isTraceEnabled()) { - if (SimpMessageType.HEARTBEAT.equals(headers.getMessageType())) { + if (headerAccessor.isHeartbeat()) { logger.trace("Received heartbeat from client session=" + session.getId()); } else { @@ -182,13 +187,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } } - headers.setSessionId(session.getId()); - headers.setSessionAttributes(session.getAttributes()); - headers.setUser(session.getPrincipal()); + headerAccessor.setSessionId(session.getId()); + headerAccessor.setSessionAttributes(session.getAttributes()); + headerAccessor.setUser(session.getPrincipal()); + headerAccessor.setImmutable(); - message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); - - if (this.eventPublisher != null && StompCommand.CONNECT.equals(headers.getCommand())) { + if (this.eventPublisher != null && StompCommand.CONNECT.equals(headerAccessor.getCommand())) { publishEvent(new SessionConnectEvent(this, message)); } @@ -212,10 +216,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE protected void sendErrorMessage(WebSocketSession session, Throwable error) { - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); - headers.setMessage(error.getMessage()); - Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); - byte[] bytes = this.stompEncoder.encode(message); + StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR); + headerAccessor.setMessage(error.getMessage()); + byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD); try { session.sendMessage(new TextMessage(bytes)); } @@ -231,46 +234,60 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @Override public void handleMessageToClient(WebSocketSession session, Message message) { - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - - if (headers.getMessageType() == SimpMessageType.CONNECT_ACK) { - StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED); - connectedHeaders.setVersion(getVersion(headers)); - connectedHeaders.setHeartbeat(0, 0); // no heart-beat support with simple broker - headers = connectedHeaders; - } - else if (SimpMessageType.MESSAGE.equals(headers.getMessageType())) { - headers.updateStompCommandAsServerMessage(); - } - - if (headers.getCommand() == StompCommand.CONNECTED) { - afterStompSessionConnected(headers, session); - } - - if (StompCommand.MESSAGE.equals(headers.getCommand())) { - if (headers.getSubscriptionId() == null) { - logger.error("Ignoring message, no subscriptionId header: " + message); - return; - } - String header = SimpMessageHeaderAccessor.ORIGINAL_DESTINATION; - if (message.getHeaders().containsKey(header)) { - headers.setDestination((String) message.getHeaders().get(header)); - } - } - if (!(message.getPayload() instanceof byte[])) { logger.error("Ignoring message, expected byte[] content: " + message); return; } - try { - message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); + MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class); + if (accessor == null) { + logger.error("No header accessor: " + message); + return; + } - if (this.eventPublisher != null && StompCommand.CONNECTED.equals(headers.getCommand())) { + StompHeaderAccessor stompAccessor; + if (accessor instanceof StompHeaderAccessor) { + stompAccessor = (StompHeaderAccessor) accessor; + } + else if (accessor instanceof SimpMessageHeaderAccessor) { + stompAccessor = StompHeaderAccessor.wrap(message); + if (SimpMessageType.CONNECT_ACK.equals(stompAccessor.getMessageType())) { + StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED); + connectedHeaders.setVersion(getVersion(stompAccessor)); + connectedHeaders.setHeartbeat(0, 0); // no heart-beat support with simple broker + stompAccessor = connectedHeaders; + } + else if (stompAccessor.getCommand() == null || StompCommand.SEND.equals(stompAccessor.getCommand())) { + stompAccessor.updateStompCommandAsServerMessage(); + } + } + else { + // Should not happen + logger.error("Unexpected header accessor type: " + accessor); + return; + } + + StompCommand command = stompAccessor.getCommand(); + if (StompCommand.MESSAGE.equals(command)) { + if (stompAccessor.getSubscriptionId() == null) { + logger.error("Ignoring message, no subscriptionId header: " + message); + return; + } + String header = SimpMessageHeaderAccessor.ORIGINAL_DESTINATION; + if (message.getHeaders().containsKey(header)) { + stompAccessor = toMutableAccessor(stompAccessor, message); + stompAccessor.setDestination((String) message.getHeaders().get(header)); + } + } + else if (StompCommand.CONNECTED.equals(command)) { + stompAccessor = afterStompSessionConnected(message, stompAccessor, session); + if (this.eventPublisher != null && StompCommand.CONNECTED.equals(command)) { publishEvent(new SessionConnectedEvent(this, (Message) message)); } + } - byte[] bytes = this.stompEncoder.encode((Message) message); + try { + byte[] bytes = this.stompEncoder.encode(stompAccessor.getMessageHeaders(), (byte[]) message.getPayload()); TextMessage textMessage = new TextMessage(bytes); session.sendMessage(textMessage); @@ -283,7 +300,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE sendErrorMessage(session, ex); } finally { - if (StompCommand.ERROR.equals(headers.getCommand())) { + if (StompCommand.ERROR.equals(command)) { try { session.close(CloseStatus.PROTOCOL_ERROR); } @@ -294,13 +311,19 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } } + protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccessor, Message message) { + return (headerAccessor.isMutable() ? headerAccessor : StompHeaderAccessor.wrap(message)); + } + private String getVersion(StompHeaderAccessor connectAckHeaders) { String name = StompHeaderAccessor.CONNECT_MESSAGE_HEADER; Message connectMessage = (Message) connectAckHeaders.getHeader(name); - StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(connectMessage); Assert.notNull(connectMessage, "CONNECT_ACK does not contain original CONNECT " + connectAckHeaders); + StompHeaderAccessor connectHeaders = + MessageHeaderAccessor.getAccessor(connectMessage, StompHeaderAccessor.class); + Set acceptVersions = connectHeaders.getAcceptVersion(); if (acceptVersions.contains("1.2")) { return "1.2"; @@ -316,16 +339,19 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } } - private void afterStompSessionConnected(StompHeaderAccessor headers, WebSocketSession session) { + private StompHeaderAccessor afterStompSessionConnected( + Message message, StompHeaderAccessor headerAccessor, WebSocketSession session) { + Principal principal = session.getPrincipal(); if (principal != null) { - headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); + headerAccessor = toMutableAccessor(headerAccessor, message); + headerAccessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); if (this.userSessionRegistry != null) { String userName = resolveNameForUserSessionRegistry(principal); this.userSessionRegistry.registerSessionId(userName, session.getId()); } } - long[] heartbeat = headers.getHeartbeat(); + long[] heartbeat = headerAccessor.getHeartbeat(); if (heartbeat[1] > 0) { session = WebSocketSessionDecorator.unwrap(session); if (session instanceof SockJsSession) { @@ -333,6 +359,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE ((SockJsSession) session).disableHeartbeat(); } } + return headerAccessor; } private String resolveNameForUserSessionRegistry(Principal principal) { @@ -345,8 +372,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @Override public String resolveSessionId(Message message) { - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - return headers.getSessionId(); + return SimpMessageHeaderAccessor.getSessionId(message.getHeaders()); } @Override @@ -374,7 +400,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); headers.setSessionId(session.getId()); - Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); if (this.eventPublisher != null) { publishEvent(new SessionDisconnectEvent(this, session.getId(), closeStatus)); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 788c6a6a3c..4a9f62020c 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -16,7 +16,6 @@ package org.springframework.web.socket.messaging; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; @@ -41,7 +40,6 @@ import org.springframework.messaging.simp.stomp.StompEncoder; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; import org.springframework.messaging.simp.user.DestinationUserNameProvider; -import org.springframework.messaging.simp.user.UserDestinationMessageHandler; import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.socket.CloseStatus; @@ -61,6 +59,8 @@ import static org.mockito.Mockito.*; */ public class StompSubProtocolHandlerTests { + public static final byte[] EMPTY_PAYLOAD = new byte[0]; + private StompSubProtocolHandler protocolHandler; private TestWebSocketSession session; @@ -89,7 +89,7 @@ public class StompSubProtocolHandlerTests { this.protocolHandler.setUserSessionRegistry(registry); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); - Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); this.protocolHandler.handleMessageToClient(this.session, message); assertEquals(1, this.session.getSentMessages().size()); @@ -108,7 +108,7 @@ public class StompSubProtocolHandlerTests { this.protocolHandler.setUserSessionRegistry(registry); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); - Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); this.protocolHandler.handleMessageToClient(this.session, message); assertEquals(1, this.session.getSentMessages().size()); @@ -126,7 +126,7 @@ public class StompSubProtocolHandlerTests { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED); headers.setHeartbeat(0,10); - Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); this.protocolHandler.handleMessageToClient(sockJsSession, message); verify(sockJsSession).disableHeartbeat(); @@ -137,12 +137,12 @@ public class StompSubProtocolHandlerTests { StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT); connectHeaders.setHeartbeat(10000, 10000); - connectHeaders.setNativeHeader(StompHeaderAccessor.STOMP_ACCEPT_VERSION_HEADER, "1.0,1.1"); - Message connectMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectHeaders).build(); + connectHeaders.setAcceptVersion("1.0,1.1"); + Message connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectHeaders.getMessageHeaders()); SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage); - Message connectAckMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(connectAckHeaders).build(); + Message connectAckMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAckHeaders.getMessageHeaders()); this.protocolHandler.handleMessageToClient(this.session, connectAckMessage); @@ -174,12 +174,12 @@ public class StompSubProtocolHandlerTests { this.protocolHandler.afterSessionStarted(this.session, this.channel); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); - Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); TextMessage textMessage = new TextMessage(new StompEncoder().encode(message)); this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); headers = StompHeaderAccessor.create(StompCommand.CONNECTED); - message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); this.protocolHandler.handleMessageToClient(this.session, message); this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, this.channel); @@ -207,7 +207,7 @@ public class StompSubProtocolHandlerTests { this.protocolHandler.afterSessionStarted(this.session, this.channel); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); - Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); TextMessage textMessage = new TextMessage(new StompEncoder().encode(message)); this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); @@ -218,7 +218,7 @@ public class StompSubProtocolHandlerTests { reset(this.channel); headers = StompHeaderAccessor.create(StompCommand.CONNECTED); - message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); this.protocolHandler.handleMessageToClient(this.session, message); assertEquals(1, this.session.getSentMessages().size()); @@ -241,7 +241,7 @@ public class StompSubProtocolHandlerTests { headers.setSubscriptionId("sub0"); headers.setDestination("/queue/foo-user123"); headers.setHeader(StompHeaderAccessor.ORIGINAL_DESTINATION, "/user/queue/foo"); - Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); this.protocolHandler.handleMessageToClient(this.session, message); assertEquals(1, this.session.getSentMessages().size()); @@ -278,8 +278,9 @@ public class StompSubProtocolHandlerTests { @Test public void handleMessageFromClientInvalidStompCommand() { - TextMessage textMessage = new TextMessage("FOO"); + TextMessage textMessage = new TextMessage("FOO\n\n\0"); + this.protocolHandler.afterSessionStarted(this.session, this.channel); this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); verifyZeroInteractions(this.channel); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java index 10113ca48c..3318e686d6 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java @@ -128,7 +128,7 @@ public class StompWebSocketIntegrationTests extends AbstractWebSocketIntegration assertTrue(clientHandler.latch.await(2, TimeUnit.SECONDS)); String payload = clientHandler.actual.get(0).getPayload(); - assertTrue("Expected STOMP Command=MESSAGE, got " + payload, payload.startsWith("MESSAGE\n")); + assertTrue("Expected STOMP MESSAGE, got " + payload, payload.startsWith("MESSAGE\n")); } finally { session.close();