diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java index fded86db5bf..969924b6b22 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java @@ -175,10 +175,8 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec } if (SimpMessageType.MESSAGE.equals(SimpMessageHeaderAccessor.getMessageType(message.getHeaders()))) { SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message); - if (getHeaderInitializer() != null) { - getHeaderInitializer().initHeaders(headerAccessor); - } - headerAccessor.setHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION, result.getSubscribeDestination()); + initHeaders(headerAccessor); + headerAccessor.setNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION, result.getSubscribeDestination()); message = MessageBuilder.createMessage(message.getPayload(), headerAccessor.getMessageHeaders()); } for (String destination : destinations) { @@ -189,4 +187,10 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec } } + private void initHeaders(SimpMessageHeaderAccessor headerAccessor) { + if (getHeaderInitializer() != null) { + getHeaderInitializer().initHeaders(headerAccessor); + } + } + } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java index 120d4e2d2a3..5712ece96b6 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java @@ -91,8 +91,9 @@ public class UserDestinationMessageHandlerTests { ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); Mockito.verify(this.brokerChannel).send(captor.capture()); - assertEquals("/queue/foo-user123", SimpMessageHeaderAccessor.getDestination(captor.getValue().getHeaders())); - assertEquals("/user/queue/foo", captor.getValue().getHeaders().get(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION)); + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(captor.getValue()); + assertEquals("/queue/foo-user123", accessor.getDestination()); + assertEquals("/user/queue/foo", accessor.getFirstNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION)); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 60a8bf2aaad..2d45fa4188a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -298,10 +298,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE logger.error("Ignoring message, no subscriptionId header: " + message); return; } - String header = SimpMessageHeaderAccessor.ORIGINAL_DESTINATION; - if (message.getHeaders().containsKey(header)) { + String origDestination = stompAccessor.getFirstNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION); + if (origDestination != null) { stompAccessor = toMutableAccessor(stompAccessor, message); - stompAccessor.setDestination((String) message.getHeaders().get(header)); + stompAccessor.setDestination(origDestination); } } else if (StompCommand.CONNECTED.equals(command)) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 032aaae0a08..3f9ffb3d2f4 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -239,7 +239,7 @@ public class StompSubProtocolHandlerTests { headers.setMessageId("mess0"); headers.setSubscriptionId("sub0"); headers.setDestination("/queue/foo-user123"); - headers.setHeader(StompHeaderAccessor.ORIGINAL_DESTINATION, "/user/queue/foo"); + headers.setNativeHeader(StompHeaderAccessor.ORIGINAL_DESTINATION, "/user/queue/foo"); Message message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders()); this.protocolHandler.handleMessageToClient(this.session, message);