Enrich CONNECTED frames with Principal

Issue: SPR-12479
This commit is contained in:
Rossen Stoyanchev 2014-12-02 14:30:36 -05:00
parent dc5b5ca8ee
commit fa89ae244f
4 changed files with 46 additions and 3 deletions

View File

@ -160,6 +160,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
SimpMessageHeaderAccessor connectAck = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
initHeaders(connectAck);
connectAck.setSessionId(sessionId);
connectAck.setUser(SimpMessageHeaderAccessor.getUser(headers));
connectAck.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message);
Message<byte[]> messageOut = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAck.getMessageHeaders());
getClientOutboundChannel().send(messageOut);
@ -172,6 +173,7 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
SimpMessageHeaderAccessor disconnectAck = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK);
initHeaders(disconnectAck);
disconnectAck.setSessionId(sessionId);
disconnectAck.setUser(SimpMessageHeaderAccessor.getUser(headers));
Message<byte[]> messageOut = MessageBuilder.createMessage(EMPTY_PAYLOAD, disconnectAck.getMessageHeaders());
getClientOutboundChannel().send(messageOut);
}

View File

@ -566,8 +566,10 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
getHeaderInitializer().initHeaders(headerAccessor);
}
headerAccessor.setSessionId(this.sessionId);
headerAccessor.setUser(this.connectHeaders.getUser());
headerAccessor.setMessage(errorText);
Message<?> errorMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders());
headerAccessor.setImmutable();
sendMessageToClient(errorMessage);
}
}
@ -582,6 +584,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
public void handleMessage(Message<byte[]> message) {
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
accessor.setSessionId(this.sessionId);
accessor.setUser(this.connectHeaders.getUser());
StompCommand command = accessor.getCommand();
if (StompCommand.CONNECTED.equals(command)) {

View File

@ -33,6 +33,7 @@ import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.support.MessageBuilder;
/**
@ -109,6 +110,7 @@ public class SimpleBrokerMessageHandlerTests {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT);
headers.setSessionId(sess1);
headers.setUser(new TestPrincipal("joe"));
Message<byte[]> message = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders());
this.messageHandler.handleMessage(message);
@ -120,6 +122,7 @@ public class SimpleBrokerMessageHandlerTests {
Message<?> captured = this.messageCaptor.getAllValues().get(0);
assertEquals(SimpMessageType.DISCONNECT_ACK, SimpMessageHeaderAccessor.getMessageType(captured.getHeaders()));
assertEquals(sess1, SimpMessageHeaderAccessor.getSessionId(captured.getHeaders()));
assertEquals("joe", SimpMessageHeaderAccessor.getUser(captured.getHeaders()).getName());
assertCapturedMessage(sess2, "sub1", "/foo");
assertCapturedMessage(sess2, "sub2", "/foo");
@ -142,6 +145,7 @@ public class SimpleBrokerMessageHandlerTests {
SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.wrap(connectAckMessage);
assertEquals(connectMessage, connectAckHeaders.getHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER));
assertEquals(sess1, connectAckHeaders.getSessionId());
assertEquals("joe", connectAckHeaders.getUser().getName());
}
@ -156,6 +160,7 @@ public class SimpleBrokerMessageHandlerTests {
protected Message<String> createConnectMessage(String sessionId) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
headers.setSessionId(sessionId);
headers.setUser(new TestPrincipal("joe"));
return MessageBuilder.createMessage("", headers.getMessageHeaders());
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,7 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.StubMessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.tcp.ReconnectStrategy;
@ -47,16 +48,18 @@ public class StompBrokerRelayMessageHandlerTests {
private StompBrokerRelayMessageHandler brokerRelay;
private StubMessageChannel outboundChannel;
private StubTcpOperations tcpClient;
@Before
public void setup() {
this.tcpClient = new StubTcpOperations();
this.outboundChannel = new StubMessageChannel();
this.brokerRelay = new StompBrokerRelayMessageHandler(new StubMessageChannel(),
new StubMessageChannel(), new StubMessageChannel(), Arrays.asList("/topic")) {
this.outboundChannel, new StubMessageChannel(), Arrays.asList("/topic")) {
@Override
protected void startInternal() {
@ -65,6 +68,7 @@ public class StompBrokerRelayMessageHandlerTests {
}
};
this.tcpClient = new StubTcpOperations();
this.brokerRelay.setTcpClient(this.tcpClient);
}
@ -141,6 +145,31 @@ public class StompBrokerRelayMessageHandlerTests {
MessageHeaderAccessor.getAccessor(sent.get(0), MessageHeaderAccessor.class));
}
@Test
public void testOutboundMessage() throws Exception {
this.brokerRelay.start();
String sessionId = "sess1";
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setSessionId(sessionId);
headers.setUser(new TestPrincipal("joe"));
this.brokerRelay.handleMessage(MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()));
List<Message<byte[]>> sent = this.tcpClient.connection.messages;
assertEquals(2, sent.size());
StompHeaderAccessor responseHeaders = StompHeaderAccessor.create(StompCommand.MESSAGE);
responseHeaders.setLeaveMutable(true);
Message<byte[]> response = MessageBuilder.createMessage(new byte[0], responseHeaders.getMessageHeaders());
this.tcpClient.connectionHandler.handleMessage(response);
Message<byte[]> actual = this.outboundChannel.getMessages().get(0);
StompHeaderAccessor actualHeaders = StompHeaderAccessor.getAccessor(actual, StompHeaderAccessor.class);
assertEquals(sessionId, actualHeaders.getSessionId());
assertEquals("joe", actualHeaders.getUser().getName());
}
private static ListenableFutureTask<Void> getVoidFuture() {
ListenableFutureTask<Void> futureTask = new ListenableFutureTask<>(new Callable<Void>() {
@ -169,15 +198,19 @@ public class StompBrokerRelayMessageHandlerTests {
private StubTcpConnection connection = new StubTcpConnection();
private TcpConnectionHandler<byte[]> connectionHandler;
@Override
public ListenableFuture<Void> connect(TcpConnectionHandler<byte[]> connectionHandler) {
this.connectionHandler = connectionHandler;
connectionHandler.afterConnected(this.connection);
return getVoidFuture();
}
@Override
public ListenableFuture<Void> connect(TcpConnectionHandler<byte[]> connectionHandler, ReconnectStrategy reconnectStrategy) {
this.connectionHandler = connectionHandler;
connectionHandler.afterConnected(this.connection);
return getVoidFuture();
}