diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index d15fdac6229..e4081df3bfd 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -662,6 +662,10 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler return this.sessionId; } + public StompHeaderAccessor getConnectHeaders() { + return this.connectHeaders; + } + @Nullable protected TcpConnection getTcpConnection() { return this.tcpConnection; @@ -768,39 +772,13 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler protected void afterStompConnected(StompHeaderAccessor connectedHeaders) { this.isStompConnected = true; stats.incrementConnectedCount(); - if (this.isRemoteClientSession) { - if (taskScheduler != null) { - long interval = connectedHeaders.getHeartbeat()[1]; - this.clientSendInterval = Math.max(interval, this.clientSendInterval); - } - } - else { - // system session - initHeartbeats(connectedHeaders); - } + initHeartbeats(connectedHeaders); } - private void initHeartbeats(StompHeaderAccessor connectedHeaders) { - TcpConnection con = this.tcpConnection; - Assert.state(con != null, "No TcpConnection available"); - - long clientSendInterval = this.connectHeaders.getHeartbeat()[0]; - long clientReceiveInterval = this.connectHeaders.getHeartbeat()[1]; - long serverSendInterval = connectedHeaders.getHeartbeat()[0]; - long serverReceiveInterval = connectedHeaders.getHeartbeat()[1]; - - if (clientSendInterval > 0 && serverReceiveInterval > 0) { - long interval = Math.max(clientSendInterval, serverReceiveInterval); - con.onWriteInactivity(() -> - con.send(HEARTBEAT_MESSAGE).addCallback( - result -> {}, - ex -> handleTcpConnectionFailure( - "Failed to forward heartbeat: " + ex.getMessage(), ex)), interval); - } - if (clientReceiveInterval > 0 && serverSendInterval > 0) { - final long interval = Math.max(clientReceiveInterval, serverSendInterval) * HEARTBEAT_MULTIPLIER; - con.onReadInactivity( - () -> handleTcpConnectionFailure("No messages received in " + interval + " ms.", null), interval); + protected void initHeartbeats(StompHeaderAccessor connectedHeaders) { + if (taskScheduler != null) { + long interval = connectedHeaders.getHeartbeat()[1]; + this.clientSendInterval = Math.max(interval, this.clientSendInterval); } } @@ -1006,6 +984,30 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler sendSystemSubscriptions(); } + protected void initHeartbeats(StompHeaderAccessor connectedHeaders) { + TcpConnection con = getTcpConnection(); + Assert.state(con != null, "No TcpConnection available"); + + long clientSendInterval = getConnectHeaders().getHeartbeat()[0]; + long clientReceiveInterval = getConnectHeaders().getHeartbeat()[1]; + long serverSendInterval = connectedHeaders.getHeartbeat()[0]; + long serverReceiveInterval = connectedHeaders.getHeartbeat()[1]; + + if (clientSendInterval > 0 && serverReceiveInterval > 0) { + long interval = Math.max(clientSendInterval, serverReceiveInterval); + con.onWriteInactivity(() -> + con.send(HEARTBEAT_MESSAGE).addCallback( + result -> {}, + ex -> handleTcpConnectionFailure( + "Failed to forward heartbeat: " + ex.getMessage(), ex)), interval); + } + if (clientReceiveInterval > 0 && serverSendInterval > 0) { + final long interval = Math.max(clientReceiveInterval, serverSendInterval) * HEARTBEAT_MULTIPLIER; + con.onReadInactivity( + () -> handleTcpConnectionFailure("No messages received in " + interval + " ms.", null), interval); + } + } + private void sendSystemSubscriptions() { int i = 0; for (String destination : getSystemSubscriptions().keySet()) {