diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 4deba94bf5..5b6447e71a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -247,9 +247,19 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler if (logger.isDebugEnabled()) { logger.debug("Initializing \"system\" TCP connection"); } + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + headers.setAcceptVersion("1.1,1.2"); + headers.setLogin(this.systemLogin); + headers.setPasscode(this.systemPasscode); + headers.setHeartbeat(this.systemHeartbeatSendInterval, this.systemHeartbeatReceiveInterval); + headers.setHost(getVirtualHost()); + Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + SystemStompRelaySession session = new SystemStompRelaySession(); + session.connect(message); + this.relaySessions.put(session.getId(), session); - session.connect(); } @Override @@ -302,8 +312,8 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); } StompRelaySession session = new StompRelaySession(sessionId); - this.relaySessions.put(sessionId, session); session.connect(message); + this.relaySessions.put(sessionId, session); } else if (SimpMessageType.DISCONNECT.equals(messageType)) { StompRelaySession session = this.relaySessions.remove(sessionId); @@ -328,14 +338,30 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private class StompRelaySession { + private static final long HEARTBEAT_MULTIPLIER = 3; + private final String sessionId; + private final boolean isRemoteClientSession; + + private final long reconnectInterval; + private volatile StompConnection stompConnection = new StompConnection(); + private volatile StompHeaderAccessor connectHeaders; + + private volatile StompHeaderAccessor connectedHeaders; + private StompRelaySession(String sessionId) { + this(sessionId, true, 0); + } + + private StompRelaySession(String sessionId, boolean isRemoteClientSession, long reconnectInterval) { Assert.notNull(sessionId, "sessionId is required"); this.sessionId = sessionId; + this.isRemoteClientSession = isRemoteClientSession; + this.reconnectInterval = reconnectInterval; } @@ -344,9 +370,23 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } public void connect(final Message connectMessage) { - Assert.notNull(connectMessage, "connectMessage is required"); - Composable, Message>> promise = initConnection(); + Assert.notNull(connectMessage, "connectMessage is required"); + this.connectHeaders = StompHeaderAccessor.wrap(connectMessage); + + Composable, Message>> promise; + if (this.reconnectInterval > 0) { + promise = tcpClient.open(new Reconnect() { + @Override + public Tuple2 reconnect(InetSocketAddress address, int attempt) { + return Tuple.of(address, 5000L); + } + }); + } + else { + promise = tcpClient.open(); + } + promise.consume(new Consumer, Message>>() { @Override public void accept(TcpConnection, Message> connection) { @@ -366,10 +406,6 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler this.stompConnection.setDisconnected(); } - protected Composable, Message>> initConnection() { - return tcpClient.open(); - } - protected void handleConnectionReady( TcpConnection, Message> tcpConn, final Message connectMessage) { @@ -403,7 +439,8 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); if (StompCommand.CONNECTED == headers.getCommand()) { - connected(headers, this.stompConnection); + this.connectedHeaders = headers; + connected(); } headers.setSessionId(this.sessionId); @@ -411,7 +448,56 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler sendMessageToClient(message); } - protected void connected(StompHeaderAccessor headers, StompConnection stompConnection) { + private void initHeartbeats() { + + long clientSendInterval = this.connectHeaders.getHeartbeat()[0]; + long clientReceiveInterval = this.connectHeaders.getHeartbeat()[1]; + + long serverSendInterval = this.connectedHeaders.getHeartbeat()[0]; + long serverReceiveInterval = this.connectedHeaders.getHeartbeat()[1]; + + if ((clientSendInterval > 0) && (serverReceiveInterval > 0)) { + long interval = Math.max(clientSendInterval, serverReceiveInterval); + stompConnection.connection.on().writeIdle(interval, new Runnable() { + + @Override + public void run() { + TcpConnection, Message> tcpConn = stompConnection.connection; + if (tcpConn != null) { + tcpConn.send(MessageBuilder.withPayload(new byte[] {'\n'}).build(), + new Consumer() { + @Override + public void accept(Boolean result) { + if (!result) { + handleTcpClientFailure("Failed to send heartbeat to the broker", null); + } + } + }); + } + } + }); + } + + if (clientReceiveInterval > 0 && serverSendInterval > 0) { + final long interval = Math.max(clientReceiveInterval, serverSendInterval) * HEARTBEAT_MULTIPLIER; + stompConnection.connection.on().readIdle(interval, new Runnable() { + + @Override + public void run() { + String message = "Broker hearbeat missed: connection idle for more than " + interval + "ms"; + if (logger.isWarnEnabled()) { + logger.warn(message); + } + disconnected(message); + } + }); + } + } + + protected void connected() { + if (!this.isRemoteClientSession) { + initHeartbeats(); + } this.stompConnection.setReady(); } @@ -436,7 +522,18 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } protected void sendMessageToClient(Message message) { - messageChannel.send(message); + if (this.isRemoteClientSession) { + messageChannel.send(message); + } + else { + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + if (StompCommand.ERROR.equals(headers.getCommand())) { + if (logger.isErrorEnabled()) { + logger.error("STOMP ERROR on sessionId=" + this.sessionId + ": " + message); + } + } + // ignore otherwise + } } private void forward(Message message) { @@ -547,89 +644,16 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private class SystemStompRelaySession extends StompRelaySession { - private static final long HEARTBEAT_RECEIVE_MULTIPLIER = 3; - public static final String ID = "stompRelaySystemSessionId"; - private final byte[] heartbeatPayload = new byte[] {'\n'}; - public SystemStompRelaySession() { - super(ID); - } - - public void connect() { - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); - headers.setAcceptVersion("1.1,1.2"); - headers.setLogin(systemLogin); - headers.setPasscode(systemPasscode); - headers.setHeartbeat(systemHeartbeatSendInterval, systemHeartbeatReceiveInterval); - if (getVirtualHost() != null) { - headers.setHost(getVirtualHost()); - } - Message connectMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); - super.connect(connectMessage); + super(ID, false, 5000); } @Override - protected Composable, Message>> initConnection() { - return tcpClient.open(new Reconnect() { - @Override - public Tuple2 reconnect(InetSocketAddress address, int attempt) { - return Tuple.of(address, 5000L); - } - }); - } - - @Override - protected void connectionClosed() { - publishBrokerUnavailableEvent(); - } - - @Override - protected void connected(StompHeaderAccessor headers, final StompConnection stompConnection) { - - long brokerReceiveInterval = headers.getHeartbeat()[1]; - if ((systemHeartbeatSendInterval > 0) && (brokerReceiveInterval > 0)) { - long interval = Math.max(systemHeartbeatSendInterval, brokerReceiveInterval); - stompConnection.connection.on().writeIdle(interval, new Runnable() { - - @Override - public void run() { - TcpConnection, Message> tcpConn = stompConnection.connection; - if (tcpConn != null) { - tcpConn.send(MessageBuilder.withPayload(heartbeatPayload).build(), - new Consumer() { - @Override - public void accept(Boolean result) { - if (!result) { - handleTcpClientFailure("Failed to send heartbeat to the broker", null); - } - } - }); - } - } - }); - } - - long brokerSendInterval = headers.getHeartbeat()[0]; - if (systemHeartbeatReceiveInterval > 0 && brokerSendInterval > 0) { - final long interval = Math.max(systemHeartbeatReceiveInterval, brokerSendInterval) - * HEARTBEAT_RECEIVE_MULTIPLIER; - stompConnection.connection.on().readIdle(interval, new Runnable() { - - @Override - public void run() { - String message = "Broker hearbeat missed: connection idle for more than " + interval + "ms"; - if (logger.isWarnEnabled()) { - logger.warn(message); - } - disconnected(message); - } - }); - } - - super.connected(headers, stompConnection); + protected void connected() { + super.connected(); publishBrokerAvailableEvent(); } @@ -640,16 +664,8 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } @Override - protected void sendMessageToClient(Message message) { - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - if (StompCommand.ERROR.equals(headers.getCommand())) { - if (logger.isErrorEnabled()) { - logger.error("STOMP ERROR frame on system session: " + message); - } - } - else { - // Ignore - } + protected void connectionClosed() { + publishBrokerUnavailableEvent(); } @Override diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index 25c3cbe9bc..40eb1a5636 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -46,6 +46,8 @@ import org.springframework.util.SocketUtils; import static org.junit.Assert.*; /** + * Integration tests for {@link StompBrokerRelayMessageHandler} running against ActiveMQ. + * * @author Rossen Stoyanchev */ public class StompBrokerRelayMessageHandlerIntegrationTests { @@ -91,14 +93,13 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } private void createAndStartRelay() throws InterruptedException { - this.relay = new StompBrokerRelayMessageHandler( - this.responseChannel, Arrays.asList("/queue/", "/topic/")); + this.relay = new StompBrokerRelayMessageHandler(this.responseChannel, Arrays.asList("/queue/", "/topic/")); this.relay.setRelayPort(port); this.relay.setApplicationEventPublisher(this.eventPublisher); this.relay.setSystemHeartbeatReceiveInterval(0); this.relay.setSystemHeartbeatSendInterval(0); - this.eventPublisher.expect(true); + this.eventPublisher.expectAvailabilityStatusChanges(true); this.relay.start(); this.eventPublisher.awaitAndAssert(); } @@ -186,7 +187,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { @Test public void brokerAvailabilityEventWhenStopped() throws Exception { - this.eventPublisher.expect(false); + this.eventPublisher.expectAvailabilityStatusChanges(false); stopBrokerAndAwait(); this.eventPublisher.awaitAndAssert(); } @@ -215,10 +216,10 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { this.responseHandler.awaitAndAssert(); - this.eventPublisher.expect(false); + this.eventPublisher.expectAvailabilityStatusChanges(false); this.eventPublisher.awaitAndAssert(); - this.eventPublisher.expect(true); + this.eventPublisher.expectAvailabilityStatusChanges(true); createAndStartBroker(); this.eventPublisher.awaitAndAssert(); @@ -587,7 +588,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { private final Object monitor = new Object(); - public void expect(Boolean... expected) { + public void expectAvailabilityStatusChanges(Boolean... expected) { synchronized (this.monitor) { this.expected.addAll(Arrays.asList(expected)); }