diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java index c9aaeaffe7a..e51a0f9fb8d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java @@ -160,6 +160,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { SimpMessageHeaderAccessor connectAck = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); initHeaders(connectAck); connectAck.setSessionId(sessionId); + connectAck.setUser(SimpMessageHeaderAccessor.getUser(headers)); connectAck.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message); Message messageOut = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAck.getMessageHeaders()); getClientOutboundChannel().send(messageOut); @@ -172,6 +173,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler { SimpMessageHeaderAccessor disconnectAck = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); initHeaders(disconnectAck); disconnectAck.setSessionId(sessionId); + disconnectAck.setUser(SimpMessageHeaderAccessor.getUser(headers)); Message messageOut = MessageBuilder.createMessage(EMPTY_PAYLOAD, disconnectAck.getMessageHeaders()); getClientOutboundChannel().send(messageOut); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index a5bd08a50d7..a9a1f496df2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -566,8 +566,10 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler getHeaderInitializer().initHeaders(headerAccessor); } headerAccessor.setSessionId(this.sessionId); + headerAccessor.setUser(this.connectHeaders.getUser()); headerAccessor.setMessage(errorText); Message errorMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders()); + headerAccessor.setImmutable(); sendMessageToClient(errorMessage); } } @@ -582,6 +584,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler public void handleMessage(Message message) { StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); accessor.setSessionId(this.sessionId); + accessor.setUser(this.connectHeaders.getUser()); StompCommand command = accessor.getCommand(); if (StompCommand.CONNECTED.equals(command)) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java index f36fa893824..4504e5c4690 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java @@ -33,6 +33,7 @@ import org.springframework.messaging.MessageChannel; import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.support.MessageBuilder; /** @@ -109,6 +110,7 @@ public class SimpleBrokerMessageHandlerTests { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); headers.setSessionId(sess1); + headers.setUser(new TestPrincipal("joe")); Message message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); this.messageHandler.handleMessage(message); @@ -120,6 +122,7 @@ public class SimpleBrokerMessageHandlerTests { Message captured = this.messageCaptor.getAllValues().get(0); assertEquals(SimpMessageType.DISCONNECT_ACK, SimpMessageHeaderAccessor.getMessageType(captured.getHeaders())); assertEquals(sess1, SimpMessageHeaderAccessor.getSessionId(captured.getHeaders())); + assertEquals("joe", SimpMessageHeaderAccessor.getUser(captured.getHeaders()).getName()); assertCapturedMessage(sess2, "sub1", "/foo"); assertCapturedMessage(sess2, "sub2", "/foo"); @@ -142,6 +145,7 @@ public class SimpleBrokerMessageHandlerTests { SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.wrap(connectAckMessage); assertEquals(connectMessage, connectAckHeaders.getHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER)); assertEquals(sess1, connectAckHeaders.getSessionId()); + assertEquals("joe", connectAckHeaders.getUser().getName()); } @@ -156,6 +160,7 @@ public class SimpleBrokerMessageHandlerTests { protected Message createConnectMessage(String sessionId) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); headers.setSessionId(sessionId); + headers.setUser(new TestPrincipal("joe")); return MessageBuilder.createMessage("", headers.getMessageHeaders()); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java index 1dc1b6e0c15..6a4f2d7b6e6 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import org.springframework.messaging.Message; import org.springframework.messaging.StubMessageChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.messaging.tcp.ReconnectStrategy; @@ -47,16 +48,18 @@ public class StompBrokerRelayMessageHandlerTests { private StompBrokerRelayMessageHandler brokerRelay; + private StubMessageChannel outboundChannel; + private StubTcpOperations tcpClient; @Before public void setup() { - this.tcpClient = new StubTcpOperations(); + this.outboundChannel = new StubMessageChannel(); this.brokerRelay = new StompBrokerRelayMessageHandler(new StubMessageChannel(), - new StubMessageChannel(), new StubMessageChannel(), Arrays.asList("/topic")) { + this.outboundChannel, new StubMessageChannel(), Arrays.asList("/topic")) { @Override protected void startInternal() { @@ -65,6 +68,7 @@ public class StompBrokerRelayMessageHandlerTests { } }; + this.tcpClient = new StubTcpOperations(); this.brokerRelay.setTcpClient(this.tcpClient); } @@ -141,6 +145,31 @@ public class StompBrokerRelayMessageHandlerTests { MessageHeaderAccessor.getAccessor(sent.get(0), MessageHeaderAccessor.class)); } + @Test + public void testOutboundMessage() throws Exception { + + this.brokerRelay.start(); + + String sessionId = "sess1"; + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + headers.setSessionId(sessionId); + headers.setUser(new TestPrincipal("joe")); + this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders())); + + List> sent = this.tcpClient.connection.messages; + assertEquals(2, sent.size()); + + StompHeaderAccessor responseHeaders = StompHeaderAccessor.create(StompCommand.MESSAGE); + responseHeaders.setLeaveMutable(true); + Message response = MessageBuilder.createMessage(new byte[0], responseHeaders.getMessageHeaders()); + this.tcpClient.connectionHandler.handleMessage(response); + + Message actual = this.outboundChannel.getMessages().get(0); + StompHeaderAccessor actualHeaders = StompHeaderAccessor.getAccessor(actual, StompHeaderAccessor.class); + assertEquals(sessionId, actualHeaders.getSessionId()); + assertEquals("joe", actualHeaders.getUser().getName()); + } + private static ListenableFutureTask getVoidFuture() { ListenableFutureTask futureTask = new ListenableFutureTask<>(new Callable() { @@ -169,15 +198,19 @@ public class StompBrokerRelayMessageHandlerTests { private StubTcpConnection connection = new StubTcpConnection(); + private TcpConnectionHandler connectionHandler; + @Override public ListenableFuture connect(TcpConnectionHandler connectionHandler) { + this.connectionHandler = connectionHandler; connectionHandler.afterConnected(this.connection); return getVoidFuture(); } @Override public ListenableFuture connect(TcpConnectionHandler connectionHandler, ReconnectStrategy reconnectStrategy) { + this.connectionHandler = connectionHandler; connectionHandler.afterConnected(this.connection); return getVoidFuture(); }