Polish and fix issues in STOMP broker relay

Fix error in te code that handles the result of sending a heartbeat

Fix error in processing DISCONNECTED frames that closed the TCP
connection before the message was sent.
This commit is contained in:
Rossen Stoyanchev 2013-09-30 16:37:18 -04:00
parent 34dd844716
commit 48caeef4de
4 changed files with 70 additions and 65 deletions

View File

@ -100,14 +100,13 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
else if (SimpMessageType.DISCONNECT.equals(messageType)) { else if (SimpMessageType.DISCONNECT.equals(messageType)) {
String sessionId = headers.getSessionId(); String sessionId = headers.getSessionId();
this.subscriptionRegistry.unregisterAllSubscriptions(sessionId); this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
} else if (SimpMessageType.CONNECT.equals(messageType)) { }
String sessionId = headers.getSessionId(); else if (SimpMessageType.CONNECT.equals(messageType)) {
SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor replyHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); replyHeaders.setSessionId(headers.getSessionId());
connectAckHeaders.setSessionId(sessionId); replyHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message);
connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, message);
Message<byte[]> connectAck = Message<byte[]> connectAck = MessageBuilder.withPayloadAndHeaders(EMPTY_PAYLOAD, replyHeaders).build();
MessageBuilder.withPayloadAndHeaders(EMPTY_PAYLOAD, connectAckHeaders).build();
this.messageChannel.send(connectAck); this.messageChannel.send(connectAck);
} }
} }

View File

@ -368,15 +368,11 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Message<byte[]> byteMessage = (Message<byte[]>) message; Message<byte[]> byteMessage = (Message<byte[]>) message;
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Forwarding to STOMP broker, message: " + message); logger.trace("Forwarding to STOMP broker, message: " + message);
} }
StompCommand command = StompHeaderAccessor.wrap(message).getCommand(); StompCommand command = StompHeaderAccessor.wrap(message).getCommand();
if (command == StompCommand.DISCONNECT) {
this.stompConnection.setDisconnected();
}
final Deferred<Boolean, Promise<Boolean>> deferred = new DeferredPromiseSpec<Boolean>().get(); final Deferred<Boolean, Promise<Boolean>> deferred = new DeferredPromiseSpec<Boolean>().get();
tcpConnection.send(byteMessage, new Consumer<Boolean>() { tcpConnection.send(byteMessage, new Consumer<Boolean>() {
@ -393,8 +389,11 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
handleTcpClientFailure("Timed out waiting for message to be forwarded to the broker", null); handleTcpClientFailure("Timed out waiting for message to be forwarded to the broker", null);
} }
else if (!success) { else if (!success) {
if (command != StompCommand.DISCONNECT) { handleTcpClientFailure("Failed to forward message to the broker", null);
handleTcpClientFailure("Failed to forward message to the broker", null); }
else {
if (command == StompCommand.DISCONNECT) {
this.stompConnection.setDisconnected();
} }
} }
} }
@ -508,8 +507,10 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
tcpConn.send(MessageBuilder.withPayload(heartbeatPayload).build(), tcpConn.send(MessageBuilder.withPayload(heartbeatPayload).build(),
new Consumer<Boolean>() { new Consumer<Boolean>() {
@Override @Override
public void accept(Boolean t) { public void accept(Boolean result) {
handleTcpClientFailure("Failed to send heartbeat to the broker", null); if (!result) {
handleTcpClientFailure("Failed to send heartbeat to the broker", null);
}
} }
}); });
} }
@ -542,7 +543,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (StompCommand.ERROR.equals(headers.getCommand())) { if (StompCommand.ERROR.equals(headers.getCommand())) {
if (logger.isErrorEnabled()) { if (logger.isErrorEnabled()) {
logger.error("System session received ERROR frame from broker: " + message); logger.error("STOMP ERROR frame on system session: " + message);
} }
} }
else { else {

View File

@ -128,6 +128,20 @@ public class StompProtocolHandler implements SubProtocolHandler {
} }
} }
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setMessage(error.getMessage());
Message<byte[]> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
String payload = new String(this.stompEncoder.encode(message), Charset.forName("UTF-8"));
try {
session.sendMessage(new TextMessage(payload));
}
catch (Throwable t) {
// ignore
}
}
/** /**
* Handle STOMP messages going back out to WebSocket clients. * Handle STOMP messages going back out to WebSocket clients.
*/ */
@ -143,7 +157,7 @@ public class StompProtocolHandler implements SubProtocolHandler {
if (headers.getMessageType() == SimpMessageType.CONNECT_ACK) { if (headers.getMessageType() == SimpMessageType.CONNECT_ACK) {
StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED); StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
connectedHeaders.setVersion(getVersion(headers)); connectedHeaders.setVersion(getVersion(headers));
connectedHeaders.setHeartbeat(0, 0); connectedHeaders.setHeartbeat(0, 0); // no heart-beat support with simple broker
headers = connectedHeaders; headers = connectedHeaders;
} }
@ -180,6 +194,41 @@ public class StompProtocolHandler implements SubProtocolHandler {
} }
} }
private String getVersion(StompHeaderAccessor connectAckHeaders) {
String name = StompHeaderAccessor.CONNECT_MESSAGE_HEADER;
Message<?> connectMessage = (Message<?>) connectAckHeaders.getHeader(name);
StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(connectMessage);
Assert.notNull(connectMessage, "CONNECT_ACK does not contain original CONNECT " + connectAckHeaders);
Set<String> acceptVersions = connectHeaders.getAcceptVersion();
if (acceptVersions.contains("1.2")) {
return "1.2";
}
else if (acceptVersions.contains("1.1")) {
return "1.1";
}
else if (acceptVersions.isEmpty()) {
return null;
}
else {
throw new StompConversionException("Unsupported version '" + acceptVersions + "'");
}
}
private void augmentConnectedHeaders(StompHeaderAccessor headers, WebSocketSession session) {
Principal principal = session.getPrincipal();
if (principal != null) {
headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
headers.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId());
if (this.queueSuffixResolver != null) {
String suffix = session.getId();
this.queueSuffixResolver.addQueueSuffix(principal.getName(), session.getId(), suffix);
}
}
}
@Override @Override
public String resolveSessionId(Message<?> message) { public String resolveSessionId(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
@ -203,50 +252,4 @@ public class StompProtocolHandler implements SubProtocolHandler {
outputChannel.send(message); outputChannel.send(message);
} }
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setMessage(error.getMessage());
Message<byte[]> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
String payload = new String(this.stompEncoder.encode(message), Charset.forName("UTF-8"));
try {
session.sendMessage(new TextMessage(payload));
}
catch (Throwable t) {
// ignore
}
}
private void augmentConnectedHeaders(StompHeaderAccessor headers, WebSocketSession session) {
Principal principal = session.getPrincipal();
if (principal != null) {
headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
headers.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId());
if (this.queueSuffixResolver != null) {
String suffix = session.getId();
this.queueSuffixResolver.addQueueSuffix(principal.getName(), session.getId(), suffix);
}
}
}
private String getVersion(StompHeaderAccessor connectAckHeaders) {
Message<?> connectMessage =
(Message<?>) connectAckHeaders.getHeader(StompHeaderAccessor.CONNECT_MESSAGE_HEADER);
StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(connectMessage);
Set<String> acceptVersions = connectHeaders.getAcceptVersion();
if (acceptVersions.contains("1.2")) {
return "1.2";
}
else if (acceptVersions.contains("1.1")) {
return "1.1";
}
else if (acceptVersions.isEmpty()) {
return null;
}
else {
throw new StompConversionException("Unsupported version '" + acceptVersions + "'");
}
}
} }

View File

@ -65,7 +65,9 @@ public class ExecutorSubscribableChannel extends AbstractSubscribableChannel {
@Override @Override
public boolean sendInternal(final Message<?> message, long timeout) { public boolean sendInternal(final Message<?> message, long timeout) {
logger.trace("subscribers " + this.handlers); if (logger.isTraceEnabled()) {
logger.trace("subscribers " + this.handlers);
}
for (final MessageHandler handler : this.handlers) { for (final MessageHandler handler : this.handlers) {
if (this.executor == null) { if (this.executor == null) {