Add reconnect logic to the relay's system session

Upgrade to Reactor snapshot builds to take advantage of TcpClient's
reconnect support that was added post-M1. Now, the system relay session
will try every 5 seconds to establish a connection with the broker, both
when first connecting and in the event of subsequently becoming
disconnected.

A more sophisticated reconnection policy, including back off and
failover to different brokers, is possible with the Reactor API. We may
want to enhance the relay's reconnection policy in the future.

Typically, a broken connection is identified by the failure to forward
a message to the broker. As things stand, the message id then discarded.
Any further messages that are forwarded before the connection's been
re-established are queued for forwarding once the CONNECTED frame's been
received. We may want to consider also queueing the message that failed
to send, however we would then need to consider the possibility of the
message itself being what caused the broker to close the connection
and resending it would simply cause the connection to be closed again.
This commit is contained in:
Andy Wilkinson 2013-08-14 11:46:09 +01:00 committed by Rossen Stoyanchev
parent 8b48d8f445
commit 131b5de6f9
4 changed files with 170 additions and 81 deletions

View File

@ -318,8 +318,8 @@ project("spring-messaging") {
compile(project(":spring-context")) compile(project(":spring-context"))
optional(project(":spring-websocket")) optional(project(":spring-websocket"))
optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") optional("com.fasterxml.jackson.core:jackson-databind:2.2.0")
optional("org.projectreactor:reactor-core:1.0.0.M1") optional("org.projectreactor:reactor-core:1.0.0.BUILD-SNAPSHOT")
optional("org.projectreactor:reactor-tcp:1.0.0.M1") optional("org.projectreactor:reactor-tcp:1.0.0.BUILD-SNAPSHOT")
optional("com.lmax:disruptor:3.1.1") optional("com.lmax:disruptor:3.1.1")
testCompile("commons-dbcp:commons-dbcp:1.2.2") testCompile("commons-dbcp:commons-dbcp:1.2.2")
testCompile("javax.inject:javax.inject-tck:1") testCompile("javax.inject:javax.inject-tck:1")
@ -328,6 +328,7 @@ project("spring-messaging") {
repositories { repositories {
maven { url 'http://repo.springsource.org/libs-milestone' } // reactor maven { url 'http://repo.springsource.org/libs-milestone' } // reactor
maven { url 'http://repo.springsource.org/libs-snapshot' } // reactor
} }
} }

View File

@ -16,6 +16,7 @@
package org.springframework.messaging.simp.stomp; package org.springframework.messaging.simp.stomp;
import java.net.InetSocketAddress;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
@ -41,16 +42,20 @@ import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import reactor.core.Environment; import reactor.core.Environment;
import reactor.core.composable.Composable;
import reactor.core.composable.Deferred; import reactor.core.composable.Deferred;
import reactor.core.composable.Promise; import reactor.core.composable.Promise;
import reactor.core.composable.spec.DeferredPromiseSpec; import reactor.core.composable.spec.DeferredPromiseSpec;
import reactor.function.Consumer; import reactor.function.Consumer;
import reactor.tcp.Reconnect;
import reactor.tcp.TcpClient; import reactor.tcp.TcpClient;
import reactor.tcp.TcpConnection; import reactor.tcp.TcpConnection;
import reactor.tcp.encoding.DelimitedCodec; import reactor.tcp.encoding.DelimitedCodec;
import reactor.tcp.encoding.StandardCodecs; import reactor.tcp.encoding.StandardCodecs;
import reactor.tcp.netty.NettyTcpClient; import reactor.tcp.netty.NettyTcpClient;
import reactor.tcp.spec.TcpClientSpec; import reactor.tcp.spec.TcpClientSpec;
import reactor.tuple.Tuple;
import reactor.tuple.Tuple2;
/** /**
@ -219,12 +224,24 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
private void openSystemSession() { private void openSystemSession() {
RelaySession session = new RelaySession(STOMP_RELAY_SYSTEM_SESSION_ID) { RelaySession session = new RelaySession(STOMP_RELAY_SYSTEM_SESSION_ID) {
@Override @Override
protected void sendMessageToClient(Message<?> message) { protected void sendMessageToClient(Message<?> message) {
// ignore, only used to send messages // ignore, only used to send messages
// TODO: ERROR frame/reconnect // TODO: ERROR frame/reconnect
} }
@Override
protected Composable<TcpConnection<String, String>> openConnection() {
return tcpClient.open(new Reconnect() {
@Override
public Tuple2<InetSocketAddress, Long> reconnect(InetSocketAddress currentAddress, int attempt) {
return Tuple.of(currentAddress, 5000L);
}
});
}
}; };
this.relaySessions.put(STOMP_RELAY_SYSTEM_SESSION_ID, session); this.relaySessions.put(STOMP_RELAY_SYSTEM_SESSION_ID, session);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
@ -376,7 +393,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
private final BlockingQueue<Message<?>> messageQueue = new LinkedBlockingQueue<Message<?>>(50); private final BlockingQueue<Message<?>> messageQueue = new LinkedBlockingQueue<Message<?>>(50);
private Promise<TcpConnection<String, String>> promise; private volatile TcpConnection<String, String> connection = null;
private volatile boolean isConnected = false; private volatile boolean isConnected = false;
@ -391,21 +408,24 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
public void open(final Message<?> message) { public void open(final Message<?> message) {
Assert.notNull(message, "message is required"); Assert.notNull(message, "message is required");
this.promise = tcpClient.open(); Composable<TcpConnection<String, String>> connectionComposable = openConnection();
this.promise.consume(new Consumer<TcpConnection<String,String>>() { connectionComposable.consume(new Consumer<TcpConnection<String, String>>() {
@Override @Override
public void accept(TcpConnection<String, String> connection) { public void accept(TcpConnection<String, String> newConnection) {
connection.in().consume(new Consumer<String>() { isConnected = false;
connection = newConnection;
newConnection.in().consume(new Consumer<String>() {
@Override @Override
public void accept(String stompFrame) { public void accept(String stompFrame) {
readStompFrame(stompFrame); readStompFrame(stompFrame);
} }
}); });
forwardInternal(message, connection); forwardInternal(message);
} }
}); });
this.promise.onError(new Consumer<Throwable>() {
connectionComposable.when(Throwable.class, new Consumer<Throwable>() {
@Override @Override
public void accept(Throwable ex) { public void accept(Throwable ex) {
relaySessions.remove(sessionId); relaySessions.remove(sessionId);
@ -415,6 +435,10 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
}); });
} }
protected Composable<TcpConnection<String, String>> openConnection() {
return tcpClient.open();
}
private void readStompFrame(String stompFrame) { private void readStompFrame(String stompFrame) {
// heartbeat // heartbeat
@ -432,7 +456,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
synchronized(this.monitor) { synchronized(this.monitor) {
this.isConnected = true; this.isConnected = true;
brokerAvailable(); brokerAvailable();
flushMessages(this.promise.get()); flushMessages();
} }
return; return;
} }
@ -447,7 +471,7 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
} }
private void sendError(String sessionId, String errorText) { private void sendError(String sessionId, String errorText) {
brokerUnavailable(); disconnect();
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setSessionId(sessionId); headers.setSessionId(sessionId);
@ -456,6 +480,14 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
sendMessageToClient(errorMessage); sendMessageToClient(errorMessage);
} }
private void disconnect() {
this.isConnected = false;
this.connection.close();
this.connection = null;
brokerUnavailable();
}
public void forward(Message<?> message) { public void forward(Message<?> message) {
if (!this.isConnected) { if (!this.isConnected) {
@ -463,25 +495,26 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
if (!this.isConnected) { if (!this.isConnected) {
this.messageQueue.add(message); this.messageQueue.add(message);
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Not connected yet, message queued, queue size=" + this.messageQueue.size()); logger.trace("Not connected, message queued. Queue size=" + this.messageQueue.size());
} }
return; return;
} }
} }
} }
TcpConnection<String, String> connection = this.promise.get();
if (this.messageQueue.isEmpty()) { if (this.messageQueue.isEmpty()) {
forwardInternal(message, connection); forwardInternal(message);
} }
else { else {
this.messageQueue.add(message); this.messageQueue.add(message);
flushMessages(connection); flushMessages();
} }
} }
private boolean forwardInternal(final Message<?> message, TcpConnection<String, String> connection) { private boolean forwardInternal(final Message<?> message) {
TcpConnection<String, String> localConnection = this.connection;
if (localConnection != null) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId()); logger.trace("Forwarding message to STOMP broker, message id=" + message.getHeaders().getId());
} }
@ -489,7 +522,8 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
final Deferred<Boolean, Promise<Boolean>> deferred = new DeferredPromiseSpec<Boolean>().get(); final Deferred<Boolean, Promise<Boolean>> deferred = new DeferredPromiseSpec<Boolean>().get();
connection.send(new String(bytes, Charset.forName("UTF-8")), new Consumer<Boolean>() { String payload = new String(bytes, Charset.forName("UTF-8"));
localConnection.send(payload, new Consumer<Boolean>() {
@Override @Override
public void accept(Boolean success) { public void accept(Boolean success) {
@ -522,13 +556,16 @@ public class StompBrokerRelayMessageHandler implements MessageHandler, SmartLife
} }
return success; return success;
} else {
return false;
}
} }
private void flushMessages(TcpConnection<String, String> connection) { private void flushMessages() {
List<Message<?>> messages = new ArrayList<Message<?>>(); List<Message<?>> messages = new ArrayList<Message<?>>();
this.messageQueue.drainTo(messages); this.messageQueue.drainTo(messages);
for (Message<?> message : messages) { for (Message<?> message : messages) {
if (!forwardInternal(message, connection)) { if (!forwardInternal(message)) {
return; return;
} }
} }

View File

@ -43,8 +43,8 @@ import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.util.SocketUtils; import org.springframework.util.SocketUtils;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertEquals;
/** /**
@ -81,7 +81,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
final CountDownLatch messageLatch = new CountDownLatch(1); final CountDownLatch messageLatch = new CountDownLatch(1);
messageChannel.subscribe(new MessageHandler() { this.messageChannel.subscribe(new MessageHandler() {
@Override @Override
public void handleMessage(Message<?> message) throws MessagingException { public void handleMessage(Message<?> message) throws MessagingException {
@ -93,18 +93,18 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
}); });
relay.handleMessage(createConnectMessage(client1SessionId)); this.relay.handleMessage(createConnectMessage(client1SessionId));
relay.handleMessage(createConnectMessage(client2SessionId)); this.relay.handleMessage(createConnectMessage(client2SessionId));
relay.handleMessage(createSubscribeMessage(client1SessionId, "/topic/test")); this.relay.handleMessage(createSubscribeMessage(client1SessionId, "/topic/test"));
stompBroker.awaitMessages(4); this.stompBroker.awaitMessages(4);
relay.handleMessage(createSendMessage(client2SessionId, "/topic/test", "fromClient2")); this.relay.handleMessage(createSendMessage(client2SessionId, "/topic/test", "fromClient2"));
assertTrue(messageLatch.await(30, TimeUnit.SECONDS)); assertTrue(messageLatch.await(30, TimeUnit.SECONDS));
assertEquals(1, brokerAvailabilityListener.availabilityEvents.size()); List<BrokerAvailabilityEvent> availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1);
assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent);
} }
@Test @Test
@ -115,7 +115,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
final CountDownLatch errorLatch = new CountDownLatch(1); final CountDownLatch errorLatch = new CountDownLatch(1);
messageChannel.subscribe(new MessageHandler() { this.messageChannel.subscribe(new MessageHandler() {
@Override @Override
public void handleMessage(Message<?> message) throws MessagingException { public void handleMessage(Message<?> message) throws MessagingException {
@ -127,20 +127,20 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
}); });
stompBroker.awaitMessages(1); this.stompBroker.awaitMessages(1);
assertEquals(1, brokerAvailabilityListener.availabilityEvents.size()); List<BrokerAvailabilityEvent> availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1);
assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent);
stompBroker.stop(); this.stompBroker.stop();
relay.handleMessage(createConnectMessage(sessionId)); this.relay.handleMessage(createConnectMessage(sessionId));
errorLatch.await(30, TimeUnit.SECONDS); errorLatch.await(30, TimeUnit.SECONDS);
assertEquals(2, brokerAvailabilityListener.availabilityEvents.size()); availabilityEvents = brokerAvailabilityListener.awaitAvailabilityEvents(2);
assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent);
assertTrue(brokerAvailabilityListener.availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); assertTrue(availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent);
} }
@Test @Test
@ -151,7 +151,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
final CountDownLatch errorLatch = new CountDownLatch(1); final CountDownLatch errorLatch = new CountDownLatch(1);
messageChannel.subscribe(new MessageHandler() { this.messageChannel.subscribe(new MessageHandler() {
@Override @Override
public void handleMessage(Message<?> message) throws MessagingException { public void handleMessage(Message<?> message) throws MessagingException {
@ -163,22 +163,51 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
}); });
relay.handleMessage(createConnectMessage(sessionId)); this.relay.handleMessage(createConnectMessage(sessionId));
stompBroker.awaitMessages(2); this.stompBroker.awaitMessages(2);
assertEquals(1, brokerAvailabilityListener.availabilityEvents.size()); List<BrokerAvailabilityEvent> availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1);
assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent);
stompBroker.stop(); this.stompBroker.stop();
relay.handleMessage(createSubscribeMessage(sessionId, "/topic/test/")); this.relay.handleMessage(createSubscribeMessage(sessionId, "/topic/test/"));
errorLatch.await(30, TimeUnit.SECONDS); errorLatch.await(30, TimeUnit.SECONDS);
assertEquals(2, brokerAvailabilityListener.availabilityEvents.size()); availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1);
assertTrue(brokerAvailabilityListener.availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent); assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent);
assertTrue(brokerAvailabilityListener.availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent); assertTrue(availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent);
}
@Test
public void relayReconnectsIfTheBrokerComesBackUp() throws InterruptedException {
List<BrokerAvailabilityEvent> availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(1);
assertTrue(availabilityEvents.get(0) instanceof BrokerBecameAvailableEvent);
List<Message<?>> messages = this.stompBroker.awaitMessages(1);
assertEquals(1, messages.size());
assertStompCommand(messages.get(0), StompCommand.CONNECT);
this.stompBroker.stop();
this.relay.handleMessage(createSendMessage(null, "/topic/test", "test"));
availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(2);
assertTrue(availabilityEvents.get(1) instanceof BrokerBecameUnavailableEvent);
this.relay.handleMessage(createSendMessage(null, "/topic/test", "test-again"));
this.stompBroker.start();
messages = this.stompBroker.awaitMessages(3);
assertEquals(3, messages.size());
assertStompCommand(messages.get(1), StompCommand.CONNECT);
assertStompCommandAndPayload(messages.get(2), StompCommand.SEND, "test-again");
availabilityEvents = this.brokerAvailabilityListener.awaitAvailabilityEvents(3);
assertTrue(availabilityEvents.get(2) instanceof BrokerBecameAvailableEvent);
} }
private Message<?> createConnectMessage(String sessionId) { private Message<?> createConnectMessage(String sessionId) {
@ -204,6 +233,16 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
return MessageBuilder.withPayloadAndHeaders(payload.getBytes(), headers).build(); return MessageBuilder.withPayloadAndHeaders(payload.getBytes(), headers).build();
} }
private void assertStompCommand(Message<?> message, StompCommand expectedCommand) {
assertEquals(expectedCommand, StompHeaderAccessor.wrap(message).getCommand());
}
private void assertStompCommandAndPayload(Message<?> message, StompCommand expectedCommand,
String expectedPayload) {
assertStompCommand(message, expectedCommand);
assertEquals(expectedPayload, new String(((byte[])message.getPayload())));
}
@Configuration @Configuration
public static class TestConfiguration { public static class TestConfiguration {
@ -233,14 +272,27 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
} }
} }
private static class BrokerAvailabilityListener implements ApplicationListener<BrokerAvailabilityEvent> { private static class BrokerAvailabilityListener implements ApplicationListener<BrokerAvailabilityEvent> {
private final List<BrokerAvailabilityEvent> availabilityEvents = new ArrayList<BrokerAvailabilityEvent>(); private final List<BrokerAvailabilityEvent> availabilityEvents = new ArrayList<BrokerAvailabilityEvent>();
private final Object monitor = new Object();
@Override @Override
public void onApplicationEvent(BrokerAvailabilityEvent event) { public void onApplicationEvent(BrokerAvailabilityEvent event) {
synchronized (this.monitor) {
this.availabilityEvents.add(event); this.availabilityEvents.add(event);
this.monitor.notifyAll();
}
}
private List<BrokerAvailabilityEvent> awaitAvailabilityEvents(int eventCount) throws InterruptedException {
synchronized (this.monitor) {
while (this.availabilityEvents.size() < eventCount) {
this.monitor.wait();
}
return new ArrayList<BrokerAvailabilityEvent>(this.availabilityEvents);
}
} }
} }
} }

View File

@ -16,7 +16,6 @@
package org.springframework.messaging.simp.stomp; package org.springframework.messaging.simp.stomp;
import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;