From a489c2cf3873fa07f60ff15d86518e3c9224228f Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Fri, 30 Aug 2013 11:22:37 +0100 Subject: [PATCH 1/6] Add StompCodec Previously, the broker relay's TCP client used Reactor's built in delimited codec as part of its parsing of STOMP frames. \0 was used as the delimiter. This worked for most STOMP frames but, crucially, not for frames with a body that contained \0: when such a frame was received it would be truncated. This commit adds a custom codec that parses STOMP frames more intelligently. It honours the content-length header allowing it to correctly parse frames with a body that contains \0. The codec largely delegates to two new classes: StompEncoder and StompDecoder. For consistency, code that previously used StompMessageConverter has been reworked to use these new encoder and decoder classes. Issue: SPR-10818 --- .../stomp/StompBrokerRelayMessageHandler.java | 65 +++-- .../messaging/simp/stomp/StompCodec.java | 68 ++++++ .../messaging/simp/stomp/StompDecoder.java | 157 ++++++++++++ .../messaging/simp/stomp/StompEncoder.java | 115 +++++++++ .../simp/stomp/StompMessageConverter.java | 231 ------------------ .../simp/stomp/StompProtocolHandler.java | 23 +- ...erRelayMessageHandlerIntegrationTests.java | 3 +- .../messaging/simp/stomp/StompCodecTests.java | 212 ++++++++++++++++ .../stomp/StompMessageConverterTests.java | 153 ------------ .../simp/stomp/StompProtocolHandlerTests.java | 3 +- 10 files changed, 598 insertions(+), 432 deletions(-) create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java delete mode 100644 spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java create mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java delete mode 100644 spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 88c73ca645f..d2783cc94b0 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -17,7 +17,6 @@ package org.springframework.messaging.simp.stomp; import java.net.InetSocketAddress; -import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -34,7 +33,6 @@ import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.handler.AbstractBrokerMessageHandler; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; import reactor.core.Environment; import reactor.core.composable.Composable; @@ -45,8 +43,6 @@ import reactor.function.Consumer; import reactor.tcp.Reconnect; import reactor.tcp.TcpClient; import reactor.tcp.TcpConnection; -import reactor.tcp.encoding.DelimitedCodec; -import reactor.tcp.encoding.StandardCodecs; import reactor.tcp.netty.NettyTcpClient; import reactor.tcp.spec.TcpClientSpec; import reactor.tuple.Tuple; @@ -74,11 +70,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private String systemPasscode = "guest"; - private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); - private Environment environment; - private TcpClient tcpClient; + private TcpClient, Message> tcpClient; private final Map relaySessions = new ConcurrentHashMap(); @@ -159,9 +153,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @Override protected void startInternal() { this.environment = new Environment(); - this.tcpClient = new TcpClientSpec(NettyTcpClient.class) + this.tcpClient = new TcpClientSpec, Message>(NettyTcpClient.class) .env(this.environment) - .codec(new DelimitedCodec((byte) 0, true, StandardCodecs.STRING_CODEC)) + .codec(new StompCodec()) .connect(this.relayHost, this.relayPort) .get(); @@ -275,14 +269,14 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler public void connect(final Message connectMessage) { Assert.notNull(connectMessage, "connectMessage is required"); - Composable> connectionComposable = openTcpConnection(); - connectionComposable.consume(new Consumer>() { + Composable, Message>> promise = openTcpConnection(); + promise.consume(new Consumer, Message>>() { @Override - public void accept(TcpConnection connection) { + public void accept(TcpConnection, Message> connection) { handleTcpConnection(connection, connectMessage); } }); - connectionComposable.when(Throwable.class, new Consumer() { + promise.when(Throwable.class, new Consumer() { @Override public void accept(Throwable ex) { relaySessions.remove(sessionId); @@ -291,29 +285,22 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler }); } - protected Composable> openTcpConnection() { + protected Composable, Message>> openTcpConnection() { return tcpClient.open(); } - protected void handleTcpConnection(TcpConnection tcpConn, final Message connectMessage) { + protected void handleTcpConnection(TcpConnection, Message> tcpConn, final Message connectMessage) { this.stompConnection.setTcpConnection(tcpConn); - tcpConn.in().consume(new Consumer() { + tcpConn.in().consume(new Consumer>() { @Override - public void accept(String message) { + public void accept(Message message) { readStompFrame(message); } }); forwardInternal(tcpConn, connectMessage); } - private void readStompFrame(String stompFrame) { - - // heartbeat - if (StringUtils.isEmpty(stompFrame)) { - return; - } - - Message message = stompMessageConverter.toMessage(stompFrame); + private void readStompFrame(Message message) { if (logger.isTraceEnabled()) { logger.trace("Reading message " + message); } @@ -378,24 +365,24 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } private boolean forwardInternal(final Message message) { - TcpConnection tcpConnection = this.stompConnection.getReadyConnection(); + TcpConnection, Message> tcpConnection = this.stompConnection.getReadyConnection(); if (tcpConnection == null) { return false; } return forwardInternal(tcpConnection, message); } - private boolean forwardInternal(TcpConnection tcpConnection, final Message message) { + @SuppressWarnings("unchecked") + private boolean forwardInternal(TcpConnection, Message> tcpConnection, final Message message) { + + Assert.isInstanceOf(byte[].class, message.getPayload(), "Message's payload must be a byte[]"); if (logger.isTraceEnabled()) { logger.trace("Forwarding to STOMP broker, message: " + message); } - byte[] bytes = stompMessageConverter.fromMessage(message); - String payload = new String(bytes, Charset.forName("UTF-8")); - final Deferred> deferred = new DeferredPromiseSpec().get(); - tcpConnection.send(payload, new Consumer() { + tcpConnection.send((Message)message, new Consumer() { @Override public void accept(Boolean success) { deferred.accept(success); @@ -434,18 +421,22 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private static class StompConnection { - private volatile TcpConnection connection; + private volatile TcpConnection, Message> connection; - private AtomicReference> readyConnection = - new AtomicReference>(); + private AtomicReference, Message>> readyConnection = + new AtomicReference, Message>>(); - public void setTcpConnection(TcpConnection connection) { + public void setTcpConnection(TcpConnection, Message> connection) { Assert.notNull(connection, "connection must not be null"); this.connection = connection; } - public TcpConnection getReadyConnection() { + /** + * Return the underlying {@link TcpConnection} but only after the CONNECTED STOMP + * frame is received. + */ + public TcpConnection, Message> getReadyConnection() { return this.readyConnection.get(); } @@ -488,7 +479,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } @Override - protected Composable> openTcpConnection() { + protected Composable, Message>> openTcpConnection() { return tcpClient.open(new Reconnect() { @Override public Tuple2 reconnect(InetSocketAddress address, int attempt) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java new file mode 100644 index 00000000000..bef1726ae56 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import org.springframework.messaging.Message; + +import reactor.function.Consumer; +import reactor.function.Function; +import reactor.io.Buffer; +import reactor.tcp.encoding.Codec; + +/** + * A Reactor TCP {@link Codec} for sending and receiving STOMP messages + * + * @author Andy Wilkinson + * @since 4.0 + */ +public class StompCodec implements Codec, Message> { + + private static final StompDecoder DECODER = new StompDecoder(); + + private static final Function, Buffer> ENCODER_FUNCTION = new Function, Buffer>() { + + private final StompEncoder encoder = new StompEncoder(); + + @Override + public Buffer apply(Message message) { + return Buffer.wrap(this.encoder.encode(message)); + } + }; + + @Override + public Function> decoder(final Consumer> next) { + return new Function>() { + + @Override + public Message apply(Buffer buffer) { + while (buffer.remaining() > 0) { + Message message = DECODER.decode(buffer.byteBuffer()); + if (message != null) { + next.accept(message); + } + } + return null; + } + }; + } + + @Override + public Function, Buffer> encoder() { + return ENCODER_FUNCTION; + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java new file mode 100644 index 00000000000..5c45244a204 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java @@ -0,0 +1,157 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; + +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * A decoder for STOMP frames + * + * @author awilkinson + * @since 4.0 + */ +public class StompDecoder { + + private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + + + /** + * Decodes a STOMP frame in the given {@code buffer} into a {@link Message}. + * + * @param buffer The buffer to decode the frame from + * @return The decoded message + */ + public Message decode(ByteBuffer buffer) { + skipLeadingEol(buffer); + String command = readCommand(buffer); + if (command.length() > 0) { + MultiValueMap headers = readHeaders(buffer); + byte[] payload = readPayload(buffer, headers); + + return MessageBuilder.withPayloadAndHeaders(payload, + StompHeaderAccessor.create(StompCommand.valueOf(command), headers)).build(); + } + else { + // Heartbeat + return null; + } + + } + + private String readCommand(ByteBuffer buffer) { + ByteArrayOutputStream command = new ByteArrayOutputStream(); + while (buffer.remaining() > 0 && !isEol(buffer)) { + command.write(buffer.get()); + } + return new String(command.toByteArray(), UTF8_CHARSET); + } + + private MultiValueMap readHeaders(ByteBuffer buffer) { + MultiValueMap headers = new LinkedMultiValueMap(); + while (true) { + ByteArrayOutputStream headerStream = new ByteArrayOutputStream(); + while (buffer.remaining() > 0 && !isEol(buffer)) { + headerStream.write(buffer.get()); + } + if (headerStream.size() > 0) { + String header = new String(headerStream.toByteArray(), UTF8_CHARSET); + int colonIndex = header.indexOf(':'); + if (colonIndex <= 0 || colonIndex == header.length() - 1) { + throw new StompConversionException( + "Illegal header: '" + header + "'. A header must be of the form : headers) { + String contentLengthString = headers.getFirst("content-length"); + if (contentLengthString != null) { + int contentLength = Integer.valueOf(contentLengthString); + byte[] payload = new byte[contentLength]; + buffer.get(payload); + if (buffer.remaining() < 1 || buffer.get() != 0) { + throw new StompConversionException("Frame must be terminated with a null octect"); + } + return payload; + } + else { + ByteArrayOutputStream payload = new ByteArrayOutputStream(); + while (buffer.remaining() > 0) { + byte b = buffer.get(); + if (b == 0) { + return payload.toByteArray(); + } + else { + payload.write(b); + } + } + } + + throw new StompConversionException("Frame must be terminated with a null octect"); + } + + private void skipLeadingEol(ByteBuffer buffer) { + while (true) { + if (!isEol(buffer)) { + break; + } + } + } + + private boolean isEol(ByteBuffer buffer) { + if (buffer.remaining() > 0) { + byte b = buffer.get(); + if (b == '\n') { + return true; + } + else if (b == '\r') { + if (buffer.remaining() > 0 && buffer.get() == '\n') { + return true; + } + else { + throw new StompConversionException("'\\r' must be followed by '\\n'"); + } + } + buffer.position(buffer.position() - 1); + } + return false; + } +} \ No newline at end of file diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java new file mode 100644 index 00000000000..f7760d57def --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.List; +import java.util.Map.Entry; + +import org.springframework.messaging.Message; + +/** + * An encoder for STOMP frames + * + * @author Andy Wilkinson + * @since 4.0 + */ +public final class StompEncoder { + + private static final byte LF = '\n'; + + private static final byte COLON = ':'; + + private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + + + /** + * Encodes the given STOMP {@code message} into a {@code byte[]} + * + * @param message The message to encode + * + * @return The encoded message + */ + public byte[] encode(Message message) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream output = new DataOutputStream(baos); + + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + + writeCommand(headers, output); + writeHeaders(headers, message, output); + output.write(LF); + writeBody(message, output); + output.write((byte)0); + + return baos.toByteArray(); + } + catch (IOException e) { + throw new StompConversionException("Failed to encode STOMP frame", e); + } + } + + private void writeCommand(StompHeaderAccessor headers, DataOutputStream output) throws IOException { + output.write(headers.getCommand().toString().getBytes(UTF8_CHARSET)); + output.write(LF); + } + + private void writeHeaders(StompHeaderAccessor headers, Message message, DataOutputStream output) + throws IOException { + + for (Entry> entry : headers.toStompHeaderMap().entrySet()) { + byte[] key = getUtf8BytesEscapingIfNecessary(entry.getKey(), headers); + for (String value : entry.getValue()) { + output.write(key); + output.write(COLON); + output.write(getUtf8BytesEscapingIfNecessary(value, headers)); + output.write(LF); + } + } + if (headers.getCommand() == StompCommand.SEND || + headers.getCommand() == StompCommand.MESSAGE || + headers.getCommand() == StompCommand.ERROR) { + output.write("content-length:".getBytes(UTF8_CHARSET)); + output.write(Integer.toString(message.getPayload().length).getBytes(UTF8_CHARSET)); + output.write(LF); + } + } + + private void writeBody(Message message, DataOutputStream output) throws IOException { + output.write(message.getPayload()); + } + + private byte[] getUtf8BytesEscapingIfNecessary(String input, StompHeaderAccessor headers) { + if (headers.getCommand() != StompCommand.CONNECT && headers.getCommand() != StompCommand.CONNECTED) { + return escape(input).getBytes(UTF8_CHARSET); + } + else { + return input.getBytes(UTF8_CHARSET); + } + } + + private String escape(String input) { + return input.replaceAll("\\\\", "\\\\\\\\") + .replaceAll(":", "\\\\c") + .replaceAll("\n", "\\\\n") + .replaceAll("\r", "\\\\r"); + } +} \ No newline at end of file diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java deleted file mode 100644 index a410c615c5a..00000000000 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.messaging.simp.stomp; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.nio.charset.Charset; -import java.util.List; -import java.util.Map.Entry; - -import org.springframework.messaging.Message; -import org.springframework.messaging.support.MessageBuilder; -import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; - - -/** - * @author Gary Russell - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class StompMessageConverter { - - private static final Charset STOMP_CHARSET = Charset.forName("UTF-8"); - - public static final byte LF = 0x0a; - - public static final byte CR = 0x0d; - - private static final byte COLON = ':'; - - /** - * @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String. - */ - public Message toMessage(Object stompContent) { - - byte[] byteContent = null; - if (stompContent instanceof String) { - byteContent = ((String) stompContent).getBytes(STOMP_CHARSET); - } - else if (stompContent instanceof byte[]){ - byteContent = (byte[]) stompContent; - } - else { - throw new IllegalArgumentException( - "stompContent is neither String nor byte[]: " + stompContent.getClass()); - } - - int totalLength = byteContent.length; - if (byteContent[totalLength-1] == 0) { - totalLength--; - } - - int payloadIndex = findIndexOfPayload(byteContent); - if (payloadIndex == 0) { - throw new StompConversionException("No command found"); - } - - String headerContent = new String(byteContent, 0, payloadIndex, STOMP_CHARSET); - Parser parser = new Parser(headerContent); - - StompCommand command = StompCommand.valueOf(parser.nextToken(LF).trim()); - Assert.notNull(command, "No command found"); - - MultiValueMap headers = new LinkedMultiValueMap(); - while (parser.hasNext()) { - String header = parser.nextToken(COLON); - if (header != null) { - if (parser.hasNext()) { - String value = parser.nextToken(LF); - headers.add(header, value); - } - else { - throw new StompConversionException("Parse exception for " + headerContent); - } - } - } - - byte[] payload = new byte[totalLength - payloadIndex]; - System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex); - StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers); - return MessageBuilder.withPayloadAndHeaders(payload, stompHeaders).build(); - } - - private int findIndexOfPayload(byte[] bytes) { - int i; - // ignore any leading EOL from the previous message - for (i = 0; i < bytes.length; i++) { - if (bytes[i] != '\n' && bytes[i] != '\r') { - break; - } - bytes[i] = ' '; - } - int index = 0; - for (; i < bytes.length - 1; i++) { - if (bytes[i] == LF && bytes[i+1] == LF) { - index = i + 2; - break; - } - if ((i < (bytes.length - 3)) && - (bytes[i] == CR && bytes[i+1] == LF && bytes[i+2] == CR && bytes[i+3] == LF)) { - index = i + 4; - break; - } - } - if (i >= bytes.length) { - throw new StompConversionException("No end of headers found"); - } - return index; - } - - public byte[] fromMessage(Message message) { - - byte[] payload; - if (message.getPayload() instanceof byte[]) { - payload = (byte[]) message.getPayload(); - } - else { - throw new IllegalArgumentException( - "stompContent is not byte[]: " + message.getPayload().getClass()); - } - - ByteArrayOutputStream out = new ByteArrayOutputStream(); - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - - try { - out.write(stompHeaders.getCommand().toString().getBytes("UTF-8")); - out.write(LF); - for (Entry> entry : stompHeaders.toStompHeaderMap().entrySet()) { - String key = entry.getKey(); - key = replaceAllOutbound(key); - for (String value : entry.getValue()) { - out.write(key.getBytes("UTF-8")); - out.write(COLON); - value = replaceAllOutbound(value); - out.write(value.getBytes("UTF-8")); - out.write(LF); - } - } - out.write(LF); - out.write(payload); - out.write(0); - return out.toByteArray(); - } - catch (IOException e) { - throw new StompConversionException("Failed to serialize " + message, e); - } - } - - private String replaceAllOutbound(String key) { - return key.replaceAll("\\\\", "\\\\") - .replaceAll(":", "\\\\c") - .replaceAll("\n", "\\\\n") - .replaceAll("\r", "\\\\r"); - } - - - private class Parser { - - private final String content; - - private int offset; - - public Parser(String content) { - this.content = content; - } - - public boolean hasNext() { - return this.offset < this.content.length(); - } - - public String nextToken(byte delimiter) { - if (this.offset >= this.content.length()) { - return null; - } - int delimAt = this.content.indexOf(delimiter, this.offset); - if (delimAt == -1) { - if (this.offset == this.content.length() - 1 && delimiter == COLON && - this.content.charAt(this.offset) == LF) { - this.offset++; - return null; - } - else if (this.offset == this.content.length() - 2 && delimiter == COLON && - this.content.charAt(this.offset) == CR && - this.content.charAt(this.offset + 1) == LF) { - this.offset += 2; - return null; - } - else { - throw new StompConversionException("No delimiter found at offset " + offset + " in " + this.content); - } - } - int escapeAt = this.content.indexOf('\\', this.offset); - String token = this.content.substring(this.offset, delimAt + 1); - this.offset += token.length(); - if (escapeAt >= 0 && escapeAt < delimAt) { - char escaped = this.content.charAt(escapeAt + 1); - if (escaped == 'n' || escaped == 'c' || escaped == '\\') { - token = token.replaceAll("\\\\n", "\n") - .replaceAll("\\\\r", "\r") - .replaceAll("\\\\c", ":") - .replaceAll("\\\\\\\\", "\\\\"); - } - else { - throw new StompConversionException("Invalid escape sequence \\" + escaped); - } - } - int length = token.length(); - if (delimiter == LF && length > 1 && token.charAt(length - 2) == CR) { - return token.substring(0, length - 2); - } - else { - return token.substring(0, length - 1); - } - } - } -} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java index ce5af6da96b..f4ead2a35c9 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java @@ -17,6 +17,7 @@ package org.springframework.messaging.simp.stomp; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.security.Principal; import java.util.Arrays; @@ -63,7 +64,9 @@ public class StompProtocolHandler implements SubProtocolHandler { private final Log logger = LogFactory.getLog(StompProtocolHandler.class); - private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); + private final StompDecoder stompDecoder = new StompDecoder(); + + private final StompEncoder stompEncoder = new StompEncoder(); private MutableUserQueueSuffixResolver queueSuffixResolver = new SimpleUserQueueSuffixResolver(); @@ -98,7 +101,8 @@ public class StompProtocolHandler implements SubProtocolHandler { try { Assert.isInstanceOf(TextMessage.class, webSocketMessage); String payload = ((TextMessage)webSocketMessage).getPayload(); - message = this.stompMessageConverter.toMessage(payload); + ByteBuffer byteBuffer = ByteBuffer.wrap(payload.getBytes(Charset.forName("UTF-8"))); + message = this.stompDecoder.decode(byteBuffer); } catch (Throwable error) { logger.error("Failed to parse STOMP frame, WebSocket message payload: ", error); @@ -133,6 +137,7 @@ public class StompProtocolHandler implements SubProtocolHandler { /** * Handle STOMP messages going back out to WebSocket clients. */ + @SuppressWarnings("unchecked") @Override public void handleMessageToClient(WebSocketSession session, Message message) { @@ -156,7 +161,7 @@ public class StompProtocolHandler implements SubProtocolHandler { try { message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); - byte[] bytes = this.stompMessageConverter.fromMessage(message); + byte[] bytes = this.stompEncoder.encode((Message)message); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } catch (Throwable t) { @@ -204,19 +209,19 @@ public class StompProtocolHandler implements SubProtocolHandler { } } - Message connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build(); - byte[] bytes = this.stompMessageConverter.fromMessage(connectedMessage); - session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); + Message connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build(); + String payload = new String(this.stompEncoder.encode(connectedMessage), Charset.forName("UTF-8")); + session.sendMessage(new TextMessage(payload)); } protected void sendErrorMessage(WebSocketSession session, Throwable error) { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); headers.setMessage(error.getMessage()); - Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); - byte[] bytes = this.stompMessageConverter.fromMessage(message); + Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); + String payload = new String(this.stompEncoder.encode(message), Charset.forName("UTF-8")); try { - session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); + session.sendMessage(new TextMessage(payload)); } catch (Throwable t) { // ignore diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index ab53dc4d40f..078b77c1008 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -253,7 +253,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } public void awaitAndAssert() throws InterruptedException { - boolean result = this.latch.await(5000, TimeUnit.MILLISECONDS); + boolean result = this.latch.await(10000, TimeUnit.MILLISECONDS); assertTrue(getAsString(), result && this.unexpected.isEmpty()); } @@ -356,6 +356,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { public static MessageExchangeBuilder connect(String sessionId) { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); headers.setSessionId(sessionId); + headers.setAcceptVersion("1.1,1.2"); Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); return new MessageExchangeBuilder(message); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java new file mode 100644 index 00000000000..887c0f9e3ae --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java @@ -0,0 +1,212 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; + +import reactor.function.Consumer; +import reactor.function.Function; +import reactor.io.Buffer; + +import static org.junit.Assert.*; + + +/** + * + * @author awilkinson + */ +public class StompCodecTests { + + private final ArgumentCapturingConsumer> consumer = new ArgumentCapturingConsumer>(); + + private final Function> decoder = new StompCodec().decoder(consumer); + + @Test + public void decodeFrameWithCrLfEols() { + Message frame = decode("DISCONNECT\r\n\r\n\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.DISCONNECT, headers.getCommand()); + assertEquals(0, headers.toStompHeaderMap().size()); + assertEquals(0, frame.getPayload().length); + } + + @Test + public void decodeFrameWithNoHeadersAndNoBody() { + Message frame = decode("DISCONNECT\n\n\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.DISCONNECT, headers.getCommand()); + assertEquals(0, headers.toStompHeaderMap().size()); + assertEquals(0, frame.getPayload().length); + } + + @Test + public void decodeFrameWithNoBody() { + String accept = "accept-version:1.1\n"; + String host = "host:github.org\n"; + + Message frame = decode("CONNECT\n" + accept + host + "\n\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.CONNECT, headers.getCommand()); + + assertEquals(2, headers.toStompHeaderMap().size()); + assertEquals("1.1", headers.getFirstNativeHeader("accept-version")); + assertEquals("github.org", headers.getHost()); + + assertEquals(0, frame.getPayload().length); + } + + @Test + public void decodeFrame() throws UnsupportedEncodingException { + Message frame = decode("SEND\ndestination:test\n\nThe body of the message\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.SEND, headers.getCommand()); + + assertEquals(1, headers.toStompHeaderMap().size()); + assertEquals("test", headers.getDestination()); + + String bodyText = new String(frame.getPayload()); + assertEquals("The body of the message", bodyText); + } + + @Test + public void decodeFrameWithContentLength() { + Message frame = decode("SEND\ncontent-length:23\n\nThe body of the message\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.SEND, headers.getCommand()); + + assertEquals(1, headers.toStompHeaderMap().size()); + assertEquals(Integer.valueOf(23), headers.getContentLength()); + + String bodyText = new String(frame.getPayload()); + assertEquals("The body of the message", bodyText); + } + + @Test + public void decodeFrameWithNullOctectsInTheBody() { + Message frame = decode("SEND\ncontent-length:23\n\nThe b\0dy \0f the message\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.SEND, headers.getCommand()); + + assertEquals(1, headers.toStompHeaderMap().size()); + assertEquals(Integer.valueOf(23), headers.getContentLength()); + + String bodyText = new String(frame.getPayload()); + assertEquals("The b\0dy \0f the message", bodyText); + } + + @Test + public void decodeFrameWithEscapedHeaders() { + Message frame = decode("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.DISCONNECT, headers.getCommand()); + + assertEquals(1, headers.toStompHeaderMap().size()); + assertEquals("alpha:bravo\r\n\\", headers.getFirstNativeHeader("a:\r\n\\b")); + } + + @Test + public void decodeMultipleFramesFromSameBuffer() { + String frame1 = "SEND\ndestination:test\n\nThe body of the message\0"; + String frame2 = "DISCONNECT\n\n\0"; + + Buffer buffer = Buffer.wrap(frame1 + frame2); + + final List> messages = new ArrayList>(); + new StompCodec().decoder(new Consumer>() { + @Override + public void accept(Message message) { + messages.add(message); + } + }).apply(buffer); + + assertEquals(2, messages.size()); + assertEquals(StompCommand.SEND, StompHeaderAccessor.wrap(messages.get(0)).getCommand()); + assertEquals(StompCommand.DISCONNECT, StompHeaderAccessor.wrap(messages.get(1)).getCommand()); + } + + @Test + public void encodeFrameWithNoHeadersAndNoBody() { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + + Message frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); + + assertEquals("DISCONNECT\n\n\0", new StompCodec().encoder().apply(frame).asString()); + } + + @Test + public void encodeFrameWithHeaders() { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + headers.setAcceptVersion("1.2"); + headers.setHost("github.org"); + + Message frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); + + String frameString = new StompCodec().encoder().apply(frame).asString(); + + assertTrue(frameString.equals("CONNECT\naccept-version:1.2\nhost:github.org\n\n\0") || + frameString.equals("CONNECT\nhost:github.org\naccept-version:1.2\n\n\0")); + } + + @Test + public void encodeFrameWithHeadersThatShouldBeEscaped() { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + headers.addNativeHeader("a:\r\n\\b", "alpha:bravo\r\n\\"); + + Message frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); + + assertEquals("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0", new StompCodec().encoder().apply(frame).asString()); + } + + @Test + public void encodeFrameWithHeadersBody() { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.addNativeHeader("a", "alpha"); + + Message frame = MessageBuilder.withPayloadAndHeaders("Message body".getBytes(), headers).build(); + + assertEquals("SEND\na:alpha\ncontent-length:12\n\nMessage body\0", new StompCodec().encoder().apply(frame).asString()); + } + + private Message decode(String stompFrame) { + this.decoder.apply(Buffer.wrap(stompFrame)); + return consumer.arguments.get(0); + } + + private static final class ArgumentCapturingConsumer implements Consumer { + + private final List arguments = new ArrayList(); + + @Override + public void accept(T t) { + arguments.add(t); + } + + } +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java deleted file mode 100644 index f2716271378..00000000000 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.messaging.simp.stomp; - -import java.util.Collections; -import java.util.Map; - -import org.junit.Before; -import org.junit.Test; -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageHeaders; -import org.springframework.messaging.simp.SimpMessageHeaderAccessor; -import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.web.socket.TextMessage; - -import static org.junit.Assert.*; - -/** - * @author Gary Russell - * @author Rossen Stoyanchev - */ -public class StompMessageConverterTests { - - private StompMessageConverter converter; - - - @Before - public void setup() { - this.converter = new StompMessageConverter(); - } - - @Test - public void connectFrame() throws Exception { - - String accept = "accept-version:1.1"; - String host = "host:github.org"; - - TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT) - .headers(accept, host).build(); - - @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(textMessage.getPayload()); - - assertEquals(0, message.getPayload().length); - - MessageHeaders headers = message.getHeaders(); - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - Map map = stompHeaders.toMap(); - assertEquals(5, map.size()); - assertNotNull(stompHeaders.getId()); - assertNotNull(stompHeaders.getTimestamp()); - assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType()); - assertEquals(StompCommand.CONNECT, stompHeaders.getCommand()); - assertNotNull(map.get(SimpMessageHeaderAccessor.NATIVE_HEADERS)); - - assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); - assertEquals("github.org", stompHeaders.getHost()); - - assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType()); - assertEquals(StompCommand.CONNECT, stompHeaders.getCommand()); - assertNotNull(headers.get(MessageHeaders.ID)); - assertNotNull(headers.get(MessageHeaders.TIMESTAMP)); - - String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); - - assertEquals("CONNECT\n", convertedBack.substring(0,8)); - assertTrue(convertedBack.contains(accept)); - assertTrue(convertedBack.contains(host)); - } - - @Test - public void connectWithEscapes() throws Exception { - - String accept = "accept-version:1.1"; - String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org"; - - TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT) - .headers(accept, host).build(); - - @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(textMessage.getPayload()); - - assertEquals(0, message.getPayload().length); - - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); - assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.toNativeHeaderMap().get("ho:\ns\rt").get(0)); - - String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); - - assertEquals("CONNECT\n", convertedBack.substring(0,8)); - assertTrue(convertedBack.contains(accept)); - assertTrue(convertedBack.contains(host)); - } - - @Test - public void connectCR12() throws Exception { - - String accept = "accept-version:1.2\n"; - String host = "host:github.org\n"; - String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(test.getBytes("UTF-8")); - - assertEquals(0, message.getPayload().length); - - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion()); - assertEquals("github.org", stompHeaders.getHost()); - - String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); - - assertEquals("CONNECT\n", convertedBack.substring(0,8)); - assertTrue(convertedBack.contains(accept)); - assertTrue(convertedBack.contains(host)); - } - - @Test - public void connectWithEscapesAndCR12() throws Exception { - - String accept = "accept-version:1.1\n"; - String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; - String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(test.getBytes("UTF-8")); - - assertEquals(0, message.getPayload().length); - - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); - assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.toNativeHeaderMap().get("ho:\ns\rt").get(0)); - - String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); - - assertEquals("CONNECT\n", convertedBack.substring(0,8)); - assertTrue(convertedBack.contains(accept)); - assertTrue(convertedBack.contains(host)); - } - -} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java index 62d16318fca..a6e2afa842b 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java @@ -16,6 +16,7 @@ package org.springframework.messaging.simp.stomp; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.HashSet; @@ -84,7 +85,7 @@ public class StompProtocolHandlerTests { assertEquals(1, this.session.getSentMessages().size()); textMessage = (TextMessage) this.session.getSentMessages().get(0); - Message message = new StompMessageConverter().toMessage(textMessage.getPayload()); + Message message = new StompDecoder().decode(ByteBuffer.wrap(textMessage.getPayload().getBytes())); StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message); assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand()); From 8d2a376b0f0da9670eb1d465a2c5e7341e1d418d Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Mon, 2 Sep 2013 09:45:36 +0100 Subject: [PATCH 2/6] Remove CONNECT-related message buffer from STOMP relay Before this change, the StompProtocolHandler always responded to clients with a CONNECTED frame, while the STOMP broker relay independantly forwarded the client CONNECT to the broker and waited for the CONNECTED frame back. That meant the relay had to buffer client messages until it received the CONNECTED response from the message broker. This change ensures that clients wait for a CONNECTED frame from the message broker. The broker relay forwards the CONNECT frame to the broker. The broker responds with a CONNECTED frame, which the relay then forwards to the client. As a result, a (well-written) client will not send any messages to the relay until the connection to the broker is fully established. The StompProtcolHandler can now be configured whether to send CONNECTED frame back. By default that is off. So when using the simple broker, the StompProtocolHandler can still respond with CONNECTED frames. The relay's handling of a connection being dropped has also been improved. When a connection for a client relay session is dropped an ERROR frame will be sent back to the client. If a connection is closed as part of a DISCONNECT frame being sent, no ERROR frame is sent back to the client. When the connection for the system relay session is dropped, an event is published indicating that the broker is unavailable. Reactor's TcpClient will then attempt to re-restablish the connection. --- .../config/ServletStompEndpointRegistry.java | 4 +- ...cketMessageBrokerConfigurationSupport.java | 6 +- .../stomp/StompBrokerRelayMessageHandler.java | 71 +++--- .../simp/stomp/StompProtocolHandler.java | 36 ++- .../ServletStompEndpointRegistryTests.java | 2 +- ...erRelayMessageHandlerIntegrationTests.java | 222 +++++++++++------- .../simp/stomp/StompProtocolHandlerTests.java | 41 +++- 7 files changed, 232 insertions(+), 150 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java index 3e1050c9205..0f54c67fd80 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java @@ -54,7 +54,8 @@ public class ServletStompEndpointRegistry implements StompEndpointRegistry { public ServletStompEndpointRegistry(WebSocketHandler webSocketHandler, - MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler) { + MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler, + boolean handleConnect) { Assert.notNull(webSocketHandler); Assert.notNull(userQueueSuffixResolver); @@ -63,6 +64,7 @@ public class ServletStompEndpointRegistry implements StompEndpointRegistry { this.subProtocolWebSocketHandler = findSubProtocolWebSocketHandler(webSocketHandler); this.stompHandler = new StompProtocolHandler(); this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver); + this.stompHandler.setHandleConnect(handleConnect); this.sockJsScheduler = defaultSockJsTaskScheduler; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java index ef077487fab..ba5fb13baa2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java @@ -57,8 +57,10 @@ public abstract class WebSocketMessageBrokerConfigurationSupport { @Bean public HandlerMapping brokerWebSocketHandlerMapping() { - ServletStompEndpointRegistry registry = new ServletStompEndpointRegistry( - subProtocolWebSocketHandler(), userQueueSuffixResolver(), brokerDefaultSockJsTaskScheduler()); + boolean brokerRelayConfigured = getMessageBrokerConfigurer().getStompBrokerRelay() != null; + ServletStompEndpointRegistry registry = new ServletStompEndpointRegistry(subProtocolWebSocketHandler(), + userQueueSuffixResolver(), brokerDefaultSockJsTaskScheduler(), !brokerRelayConfigured); + registerStompEndpoints(registry); AbstractHandlerMapping hm = registry.getHandlerMapping(); hm.setOrder(1); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index d2783cc94b0..5d93031ba0a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -17,13 +17,9 @@ package org.springframework.messaging.simp.stomp; import java.net.InetSocketAddress; -import java.util.ArrayList; import java.util.Collection; -import java.util.List; import java.util.Map; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicReference; import org.springframework.messaging.Message; @@ -249,12 +245,8 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private final String sessionId; - private final BlockingQueue> messageQueue = new LinkedBlockingQueue>(50); - private volatile StompConnection stompConnection = new StompConnection(); - private final Object monitor = new Object(); - private RelaySession(String sessionId) { Assert.notNull(sessionId, "sessionId is required"); @@ -291,6 +283,12 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler protected void handleTcpConnection(TcpConnection, Message> tcpConn, final Message connectMessage) { this.stompConnection.setTcpConnection(tcpConn); + tcpConn.on().close(new Runnable() { + @Override + public void run() { + connectionClosed(); + } + }); tcpConn.in().consume(new Consumer>() { @Override public void accept(Message message) { @@ -307,12 +305,8 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); if (StompCommand.CONNECTED == headers.getCommand()) { - synchronized(this.monitor) { - this.stompConnection.setReady(); - publishBrokerAvailableEvent(); - flushMessages(); - } - return; + this.stompConnection.setReady(); + publishBrokerAvailableEvent(); } headers.setSessionId(this.sessionId); @@ -344,24 +338,11 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler public void forward(Message message) { if (!this.stompConnection.isReady()) { - synchronized(this.monitor) { - if (!this.stompConnection.isReady()) { - this.messageQueue.add(message); - if (logger.isTraceEnabled()) { - logger.trace("Not connected, message queued. Queue size=" + this.messageQueue.size()); - } - return; - } - } + logger.warn("Message sent to relay before it was CONNECTED. Discarding message: " + message); + return; } - if (this.messageQueue.isEmpty()) { - forwardInternal(message); - } - else { - this.messageQueue.add(message); - flushMessages(); - } + forwardInternal(message); } private boolean forwardInternal(final Message message) { @@ -381,6 +362,12 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler logger.trace("Forwarding to STOMP broker, message: " + message); } + StompCommand command = StompHeaderAccessor.wrap(message).getCommand(); + + if (command == StompCommand.DISCONNECT) { + this.stompConnection.setDisconnected(); + } + final Deferred> deferred = new DeferredPromiseSpec().get(); tcpConnection.send((Message)message, new Consumer() { @Override @@ -396,7 +383,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler handleTcpClientFailure("Timed out waiting for message to be forwarded to the broker", null); } else if (!success) { - if (StompHeaderAccessor.wrap(message).getCommand() != StompCommand.DISCONNECT) { + if (command != StompCommand.DISCONNECT) { handleTcpClientFailure("Failed to forward message to the broker", null); } } @@ -408,13 +395,10 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler return (success != null) ? success : false; } - private void flushMessages() { - List> messages = new ArrayList>(); - this.messageQueue.drainTo(messages); - for (Message message : messages) { - if (!forwardInternal(message)) { - return; - } + protected void connectionClosed() { + relaySessions.remove(this.sessionId); + if (this.stompConnection.isReady()) { + sendError("Lost connection to the broker"); } } } @@ -461,6 +445,10 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private class SystemRelaySession extends RelaySession { + private static final long HEARTBEAT_SEND_INTERVAL = 10000; + + private static final long HEARTBEAT_RECEIVE_INTERVAL = 10000; + public static final String ID = "stompRelaySystemSessionId"; @@ -473,7 +461,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler headers.setAcceptVersion("1.1,1.2"); headers.setLogin(systemLogin); headers.setPasscode(systemPasscode); - headers.setHeartbeat(0,0); + headers.setHeartbeat(HEARTBEAT_SEND_INTERVAL, HEARTBEAT_RECEIVE_INTERVAL); Message connectMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); super.connect(connectMessage); } @@ -488,6 +476,11 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler }); } + @Override + protected void connectionClosed() { + publishBrokerUnavailableEvent(); + } + @Override protected void sendMessageToClient(Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java index f4ead2a35c9..1c70bbb669a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java @@ -70,6 +70,7 @@ public class StompProtocolHandler implements SubProtocolHandler { private MutableUserQueueSuffixResolver queueSuffixResolver = new SimpleUserQueueSuffixResolver(); + private volatile boolean handleConnect = false; /** * Configure a resolver to use to maintain queue suffixes for user @@ -86,6 +87,29 @@ public class StompProtocolHandler implements SubProtocolHandler { return this.queueSuffixResolver; } + /** + * Configures the handling of CONNECT frames. When {@code true}, CONNECT + * frames will be handled by this handler, and a CONNECTED response will be + * sent. When {@code false}, CONNECT frames will be forwarded for + * handling by another component. + * + * @param handleConnect {@code true} if connect frames should be handled + * by this handler, {@code false} otherwise. + */ + public void setHandleConnect(boolean handleConnect) { + this.handleConnect = handleConnect; + } + + /** + * Returns whether or not this handler will handle CONNECT frames. + * + * @return Returns {@code true} if this handler will handle CONNECT frames, + * otherwise {@code false}. + */ + public boolean willHandleConnect() { + return this.handleConnect; + } + @Override public List getSupportedProtocols() { return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"); @@ -121,17 +145,17 @@ public class StompProtocolHandler implements SubProtocolHandler { message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); - if (SimpMessageType.CONNECT.equals(headers.getMessageType())) { + if (this.handleConnect && SimpMessageType.CONNECT.equals(headers.getMessageType())) { handleConnect(session, message); } - - outputChannel.send(message); + else { + outputChannel.send(message); + } } catch (Throwable t) { logger.error("Terminating STOMP session due to failure to send message: ", t); sendErrorMessage(session, t); } - } /** @@ -144,8 +168,8 @@ public class StompProtocolHandler implements SubProtocolHandler { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); headers.setCommandIfNotSet(StompCommand.MESSAGE); - if (StompCommand.CONNECTED.equals(headers.getCommand())) { - // Ignore for now since we already sent it + if (this.handleConnect && StompCommand.CONNECTED.equals(headers.getCommand())) { + // Ignore since we already sent it return; } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java index 7531e8ed208..f4c190ba39d 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java @@ -53,7 +53,7 @@ public class ServletStompEndpointRegistryTests { this.webSocketHandler = new SubProtocolWebSocketHandler(channel); this.queueSuffixResolver = new SimpleUserQueueSuffixResolver(); TaskScheduler taskScheduler = Mockito.mock(TaskScheduler.class); - this.registry = new ServletStompEndpointRegistry(webSocketHandler, queueSuffixResolver, taskScheduler); + this.registry = new ServletStompEndpointRegistry(webSocketHandler, queueSuffixResolver, taskScheduler, false); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index 078b77c1008..7441536ef9b 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -30,7 +30,6 @@ import org.apache.commons.logging.LogFactory; import org.junit.After; import org.junit.Before; import org.junit.Test; - import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEventPublisher; import org.springframework.messaging.Message; @@ -63,16 +62,14 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { private ExpectationMatchingEventPublisher eventPublisher; + private int port; + @Before public void setUp() throws Exception { - int port = SocketUtils.findAvailableTcpPort(61613); + this.port = SocketUtils.findAvailableTcpPort(61613); - this.activeMQBroker = new BrokerService(); - this.activeMQBroker.addConnector("stomp://localhost:" + port); - this.activeMQBroker.setStartAsync(false); - this.activeMQBroker.setDeleteAllMessagesOnStartup(true); - this.activeMQBroker.start(); + createAndStartBroker(); this.responseChannel = new ExecutorSubscribableChannel(); this.responseHandler = new ExpectationMatchingMessageHandler(); @@ -86,6 +83,14 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { this.relay.start(); } + private void createAndStartBroker() throws Exception { + this.activeMQBroker = new BrokerService(); + this.activeMQBroker.addConnector("stomp://localhost:" + port); + this.activeMQBroker.setStartAsync(false); + this.activeMQBroker.setDeleteAllMessagesOnStartup(true); + this.activeMQBroker.start(); + } + @After public void tearDown() throws Exception { try { @@ -102,22 +107,24 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { String sess1 = "sess1"; MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build(); this.relay.handleMessage(conn1.message); + this.responseHandler.expect(conn1); String sess2 = "sess2"; MessageExchange conn2 = MessageExchangeBuilder.connect(sess2).build(); this.relay.handleMessage(conn2.message); + this.responseHandler.expect(conn2); + + this.responseHandler.awaitAndAssert(); String subs1 = "subs1"; String destination = "/topic/test"; MessageExchange subscribe = MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build(); - this.responseHandler.expect(subscribe); - this.relay.handleMessage(subscribe.message); + this.responseHandler.expect(subscribe); this.responseHandler.awaitAndAssert(); MessageExchange send = MessageExchangeBuilder.send(destination, "foo").andExpectMessage(sess1, subs1).build(); - this.responseHandler.reset(); this.responseHandler.expect(send); this.relay.handleMessage(send.message); @@ -129,7 +136,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { stopBrokerAndAwait(); - MessageExchange connect = MessageExchangeBuilder.connect("sess1").andExpectError().build(); + MessageExchange connect = MessageExchangeBuilder.connectWithError("sess1").build(); this.responseHandler.expect(connect); this.relay.handleMessage(connect.message); @@ -137,37 +144,31 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } @Test - public void brokerUnvailableErrorFrameOnSend() throws Exception { + public void brokerBecomingUnvailableTriggersErrorFrame() throws Exception { String sess1 = "sess1"; MessageExchange connect = MessageExchangeBuilder.connect(sess1).build(); + this.responseHandler.expect(connect); + this.relay.handleMessage(connect.message); - // TODO: expect CONNECTED - Thread.sleep(2000); + this.responseHandler.awaitAndAssert(); + + this.responseHandler.expect(MessageExchangeBuilder.error(sess1).build()); stopBrokerAndAwait(); - MessageExchange subscribe = MessageExchangeBuilder.subscribe(sess1, "s1", "/topic/a").andExpectError().build(); - this.responseHandler.expect(subscribe); - - this.relay.handleMessage(subscribe.message); this.responseHandler.awaitAndAssert(); } @Test public void brokerAvailabilityEvents() throws Exception { - // TODO: expect CONNECTED - Thread.sleep(2000); - - this.eventPublisher.expect(true, false); + this.eventPublisher.expect(true); + this.eventPublisher.awaitAndAssert(); + this.eventPublisher.expect(false); stopBrokerAndAwait(); - - // TODO: remove when stop is detecteded - this.relay.handleMessage(MessageExchangeBuilder.connect("sess1").build().message); - this.eventPublisher.awaitAndAssert(); } @@ -176,37 +177,55 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { String sess1 = "sess1"; MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build(); + this.responseHandler.expect(conn1); this.relay.handleMessage(conn1.message); + this.responseHandler.awaitAndAssert(); String subs1 = "subs1"; String destination = "/topic/test"; - MessageExchange subscribe = MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build(); + MessageExchange subscribe = + MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build(); this.responseHandler.expect(subscribe); this.relay.handleMessage(subscribe.message); this.responseHandler.awaitAndAssert(); + this.responseHandler.expect(MessageExchangeBuilder.error(sess1).build()); + stopBrokerAndAwait(); - // 1st message will see ERROR frame (broker shutdown is not but should be detected) - // 2nd message will be queued (a side effect of CONNECT/CONNECTED-buffering, likely to be removed) - // Finish this once the above changes are made. + this.responseHandler.awaitAndAssert(); + + this.eventPublisher.expect(true, false); + this.eventPublisher.awaitAndAssert(); + + this.eventPublisher.expect(true); + createAndStartBroker(); + this.eventPublisher.awaitAndAssert(); + + // TODO The event publisher assertions show that the broker's back up and the system relay session + // has reconnected. We need to decide what we want the reconnect behaviour to be for client relay + // sessions and add further message sending and assertions as appropriate. At the moment any client + // sessions will be closed and an ERROR from will be sent. + } + + @Test + public void disconnectClosesRelaySessionCleanly() throws Exception { + String sess1 = "sess1"; + MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build(); + this.responseHandler.expect(conn1); + this.relay.handleMessage(conn1.message); + this.responseHandler.awaitAndAssert(); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + headers.setSessionId(sess1); + + this.relay.handleMessage(MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build()); -/* MessageExchange send = MessageExchangeBuilder.send(destination, "foo").build(); - this.responseHandler.reset(); - this.relay.handleMessage(send.message); Thread.sleep(2000); - this.activeMQBroker.start(); - Thread.sleep(5000); - - send = MessageExchangeBuilder.send(destination, "foo").andExpectMessage(sess1, subs1).build(); - this.responseHandler.reset(); - this.responseHandler.expect(send); - this.relay.handleMessage(send.message); - + // Check that we have not received an ERROR as a result of the connection closing this.responseHandler.awaitAndAssert(); -*/ } @@ -234,58 +253,66 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { */ private static class ExpectationMatchingMessageHandler implements MessageHandler { + private final Object monitor = new Object(); + private final List expected; - private final List actual = new CopyOnWriteArrayList<>(); + private final List actual = new ArrayList<>(); - private final List> unexpected = new CopyOnWriteArrayList<>(); - - private CountDownLatch latch = new CountDownLatch(1); + private final List> unexpected = new ArrayList<>(); public ExpectationMatchingMessageHandler(MessageExchange... expected) { - this.expected = new CopyOnWriteArrayList<>(expected); + synchronized (this.monitor) { + this.expected = new CopyOnWriteArrayList<>(expected); + } } - public void expect(MessageExchange... expected) { - this.expected.addAll(Arrays.asList(expected)); + synchronized (this.monitor) { + this.expected.addAll(Arrays.asList(expected)); + } } public void awaitAndAssert() throws InterruptedException { - boolean result = this.latch.await(10000, TimeUnit.MILLISECONDS); - assertTrue(getAsString(), result && this.unexpected.isEmpty()); - } - - public void reset() { - this.latch = new CountDownLatch(1); - this.expected.clear(); - this.actual.clear(); - this.unexpected.clear(); + long endTime = System.currentTimeMillis() + 10000; + synchronized (this.monitor) { + while (!this.expected.isEmpty() && System.currentTimeMillis() < endTime) { + this.monitor.wait(500); + } + boolean result = this.expected.isEmpty(); + assertTrue(getAsString(), result && this.unexpected.isEmpty()); + } } @Override public void handleMessage(Message message) throws MessagingException { - for (MessageExchange exch : this.expected) { - if (exch.matchMessage(message)) { - if (exch.isDone()) { - this.expected.remove(exch); - this.actual.add(exch); - if (this.expected.isEmpty()) { - this.latch.countDown(); + synchronized(this.monitor) { + for (MessageExchange exch : this.expected) { + if (exch.matchMessage(message)) { + if (exch.isDone()) { + this.expected.remove(exch); + this.actual.add(exch); + if (this.expected.isEmpty()) { + this.monitor.notifyAll(); + } } + return; } - return; } + this.unexpected.add(message); } - this.unexpected.add(message); } public String getAsString() { StringBuilder sb = new StringBuilder("\n"); - sb.append("INCOMPLETE:\n").append(this.expected).append("\n"); - sb.append("COMPLETE:\n").append(this.actual).append("\n"); - sb.append("UNMATCHED MESSAGES:\n").append(this.unexpected).append("\n"); + + synchronized (this.monitor) { + sb.append("INCOMPLETE:\n").append(this.expected).append("\n"); + sb.append("COMPLETE:\n").append(this.actual).append("\n"); + sb.append("UNMATCHED MESSAGES:\n").append(this.unexpected).append("\n"); + } + return sb.toString(); } } @@ -352,22 +379,28 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { this.headers = StompHeaderAccessor.wrap(message); } + public static MessageExchangeBuilder error(String sessionId) { + return new MessageExchangeBuilder(null).andExpectError(sessionId); + } public static MessageExchangeBuilder connect(String sessionId) { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); headers.setSessionId(sessionId); headers.setAcceptVersion("1.1,1.2"); Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); - return new MessageExchangeBuilder(message); + + MessageExchangeBuilder builder = new MessageExchangeBuilder(message); + builder.expected.add(new StompConnectedFrameMessageMatcher(sessionId)); + return builder; } - public static MessageExchangeBuilder subscribe(String sessionId, String subscriptionId, String destination) { - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); + public static MessageExchangeBuilder connectWithError(String sessionId) { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); headers.setSessionId(sessionId); - headers.setSubscriptionId(subscriptionId); - headers.setDestination(destination); + headers.setAcceptVersion("1.1,1.2"); Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); - return new MessageExchangeBuilder(message); + MessageExchangeBuilder builder = new MessageExchangeBuilder(message); + return builder.andExpectError(); } public static MessageExchangeBuilder subscribeWithReceipt(String sessionId, String subscriptionId, @@ -515,35 +548,48 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } } + private static class StompConnectedFrameMessageMatcher extends StompFrameMessageMatcher { + + + public StompConnectedFrameMessageMatcher(String sessionId) { + super(StompCommand.CONNECTED, sessionId); + } + + } + private static class ExpectationMatchingEventPublisher implements ApplicationEventPublisher { - private final List expected = new CopyOnWriteArrayList<>(); + private final List expected = new ArrayList<>(); - private final List actual = new CopyOnWriteArrayList<>(); + private final List actual = new ArrayList<>(); - private CountDownLatch latch = new CountDownLatch(1); + private final Object monitor = new Object(); public void expect(Boolean... expected) { - this.expected.addAll(Arrays.asList(expected)); + synchronized (this.monitor) { + this.expected.addAll(Arrays.asList(expected)); + } } public void awaitAndAssert() throws InterruptedException { - if (this.expected.size() == this.actual.size()) { + synchronized(this.monitor) { + long endTime = System.currentTimeMillis() + 5000; + while (this.expected.size() != this.actual.size() && System.currentTimeMillis() < endTime) { + this.monitor.wait(500); + } assertEquals(this.expected, this.actual); } - else { - assertTrue("Expected=" + this.expected + ", actual=" + this.actual, - this.latch.await(5, TimeUnit.SECONDS)); - } } @Override public void publishEvent(ApplicationEvent event) { if (event instanceof BrokerAvailabilityEvent) { - this.actual.add(((BrokerAvailabilityEvent) event).isBrokerAvailable()); - if (this.actual.size() == this.expected.size()) { - this.latch.countDown(); + synchronized(this.monitor) { + this.actual.add(((BrokerAvailabilityEvent) event).isBrokerAvailable()); + if (this.actual.size() == this.expected.size()) { + this.monitor.notifyAll(); + } } } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java index a6e2afa842b..78311b36d87 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java @@ -61,7 +61,33 @@ public class StompProtocolHandlerTests { } @Test - public void handleConnect() { + public void connectedResponseIsSentWhenHandlingConnect() { + this.stompHandler.setHandleConnect(true); + + TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers( + "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build(); + + this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel); + + verifyNoMoreInteractions(this.channel); + + // Check CONNECTED reply + + assertEquals(1, this.session.getSentMessages().size()); + textMessage = (TextMessage) this.session.getSentMessages().get(0); + Message message = new StompDecoder().decode(ByteBuffer.wrap(textMessage.getPayload().getBytes())); + StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message); + + assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand()); + assertEquals("1.1", replyHeaders.getVersion()); + assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat()); + assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0)); + assertEquals("s1", replyHeaders.getNativeHeader("queue-suffix").get(0)); + } + + @Test + public void connectIsForwardedWhenNotHandlingConnect() { + this.stompHandler.setHandleConnect(false); TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers( "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build(); @@ -81,18 +107,7 @@ public class StompProtocolHandlerTests { assertArrayEquals(new long[] {10000, 10000}, headers.getHeartbeat()); assertEquals(new HashSet<>(Arrays.asList("1.1","1.0")), headers.getAcceptVersion()); - // Check CONNECTED reply - - assertEquals(1, this.session.getSentMessages().size()); - textMessage = (TextMessage) this.session.getSentMessages().get(0); - Message message = new StompDecoder().decode(ByteBuffer.wrap(textMessage.getPayload().getBytes())); - StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message); - - assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand()); - assertEquals("1.1", replyHeaders.getVersion()); - assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat()); - assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0)); - assertEquals("s1", replyHeaders.getNativeHeader("queue-suffix").get(0)); + assertEquals(0, this.session.getSentMessages().size()); } } From 496d8321c3d8618b1f42923ec9caf0403422111e Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Mon, 2 Sep 2013 15:31:08 +0100 Subject: [PATCH 3/6] Add heart-beat support to STOMP broker relay Previously, the STOMP broker relay did not support heart-beats. It sent 0,0 in the heart-beats header for its own CONNECTED message, and set the heart-beats header to 0,0 when it was forwarding a CONNECTED from from a client to the broker. The broker relay now supports heart-beats for the system relay session. It will send heart-beats at the send interval that's been negotiated with the broker and will also expect to receive heart-beats at the receive interval that's been negotiated with the broker. The receive interval is multiplied by a factor of three to satisfy the STOMP spec's suggestion of lenience and ActiveMQ 5.8.0's heart-beat behaviour (see AMQ-4710). The broker relay also supports heart-beats between clients and the broker. For any given client's relay session, any heart-beats received from the client are forwarded on to the broker and any heart-beats received from the broker are sent back to the client. Internally, a heart-beat is represented as a Message with a byte array payload containing the single byte of new line ('\n') character and 'empty' headers. SubscriptionMethodReturnValueHandler has been updated to default the message type to SimpMessageType.MESSAGE. This eases the distinction between a heartbeat and a message that's been created from a return value from application code. --- build.gradle | 5 +- .../SubscriptionMethodReturnValueHandler.java | 2 + .../stomp/StompBrokerRelayMessageHandler.java | 63 +++++++++++++++---- .../messaging/simp/stomp/StompDecoder.java | 20 +++++- .../messaging/simp/stomp/StompEncoder.java | 24 +++++-- .../simp/stomp/StompHeaderAccessor.java | 6 +- .../simp/stomp/StompProtocolHandler.java | 32 +++++++--- ...erRelayMessageHandlerIntegrationTests.java | 22 ++++--- 8 files changed, 131 insertions(+), 43 deletions(-) diff --git a/build.gradle b/build.gradle index 90bf8b5ee53..381e9895cd7 100644 --- a/build.gradle +++ b/build.gradle @@ -71,6 +71,7 @@ configure(allprojects) { project -> maven { url "https://repository.apache.org/content/repositories/releases" } // tomcat 8 RC3 maven { url "https://repository.apache.org/content/repositories/snapshots" } // tomcat-websocket-* snapshots maven { url "https://maven.java.net/content/repositories/releases" } // javax.websocket, tyrus + maven { url 'http://repo.springsource.org/libs-snapshot' } // reactor } dependencies { @@ -352,8 +353,8 @@ project("spring-messaging") { optional(project(":spring-websocket")) optional(project(":spring-webmvc")) optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") - optional("org.projectreactor:reactor-core:1.0.0.M2") - optional("org.projectreactor:reactor-tcp:1.0.0.M2") + optional("org.projectreactor:reactor-core:1.0.0.BUILD-SNAPSHOT") + optional("org.projectreactor:reactor-tcp:1.0.0.BUILD-SNAPSHOT") optional("com.lmax:disruptor:3.1.1") optional("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815") optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815") diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java index 683e493e4ab..9b0e4fdde26 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java @@ -23,6 +23,7 @@ import org.springframework.messaging.core.MessageSendingOperations; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.method.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.annotation.SubscribeEvent; import org.springframework.messaging.support.MessageBuilder; @@ -97,6 +98,7 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); headers.setSessionId(this.sessionId); headers.setSubscriptionId(this.subscriptionId); + headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); return MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 5d93031ba0a..4596be012bf 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -196,21 +196,17 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); } - if (headers.getCommand() == null) { - logger.error("No STOMP command, ignoring message: " + message); - return; - } if (sessionId == null) { logger.error("No sessionId, ignoring message: " + message); return; } - if (command.requiresDestination() && !checkDestinationPrefix(destination)) { + + if (command != null && command.requiresDestination() && !checkDestinationPrefix(destination)) { return; } try { if (SimpMessageType.CONNECT.equals(messageType)) { - headers.setHeartbeat(0, 0); message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); RelaySession session = new RelaySession(sessionId); this.relaySessions.put(sessionId, session); @@ -305,8 +301,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); if (StompCommand.CONNECTED == headers.getCommand()) { - this.stompConnection.setReady(); - publishBrokerAvailableEvent(); + connected(headers, this.stompConnection); } headers.setSessionId(this.sessionId); @@ -314,12 +309,21 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler sendMessageToClient(message); } + protected void connected(StompHeaderAccessor headers, StompConnection stompConnection) { + this.stompConnection.setReady(); + publishBrokerAvailableEvent(); + } + private void handleTcpClientFailure(String message, Throwable ex) { if (logger.isErrorEnabled()) { logger.error(message + ", sessionId=" + this.sessionId, ex); } + disconnected(message); + } + + protected void disconnected(String errorMessage) { this.stompConnection.setDisconnected(); - sendError(message); + sendError(errorMessage); publishBrokerUnavailableEvent(); } @@ -445,12 +449,16 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private class SystemRelaySession extends RelaySession { - private static final long HEARTBEAT_SEND_INTERVAL = 10000; + private static final long HEARTBEAT_RECEIVE_MULTIPLIER = 3; - private static final long HEARTBEAT_RECEIVE_INTERVAL = 10000; + private static final long HEARTBEAT_SEND_INTERVAL = 10000; + + private static final long HEARTBEAT_RECEIVE_INTERVAL = 10000; public static final String ID = "stompRelaySystemSessionId"; + private final byte[] heartbeatPayload = new byte[] {'\n'}; + public SystemRelaySession() { super(ID); @@ -481,6 +489,39 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler publishBrokerUnavailableEvent(); } + @Override + protected void connected(StompHeaderAccessor headers, final StompConnection stompConnection) { + long brokerReceiveInterval = headers.getHeartbeat()[1]; + + if (HEARTBEAT_SEND_INTERVAL > 0 && brokerReceiveInterval > 0) { + long interval = Math.max(HEARTBEAT_SEND_INTERVAL, brokerReceiveInterval); + stompConnection.connection.on().writeIdle(interval, new Runnable() { + + @Override + public void run() { + stompConnection.connection.send(MessageBuilder.withPayload(heartbeatPayload).build()); + } + + }); + } + + long brokerSendInterval = headers.getHeartbeat()[0]; + if (HEARTBEAT_RECEIVE_INTERVAL > 0 && brokerSendInterval > 0) { + final long interval = + Math.max(HEARTBEAT_RECEIVE_INTERVAL, brokerSendInterval) * HEARTBEAT_RECEIVE_MULTIPLIER; + stompConnection.connection.on().readIdle(interval, new Runnable() { + @Override + public void run() { + String message = "Broker hearbeat missed: connection idle for more than " + interval + "ms"; + logger.warn(message); + disconnected(message); + } + }); + } + + super.connected(headers, stompConnection); + } + @Override protected void sendMessageToClient(Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java index 5c45244a204..e876df99ad2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java @@ -20,6 +20,8 @@ import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; import java.nio.charset.Charset; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.LinkedMultiValueMap; @@ -35,6 +37,10 @@ public class StompDecoder { private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + private static final byte[] HEARTBEAT_PAYLOAD = new byte[] {'\n'}; + + private final Log logger = LogFactory.getLog(StompDecoder.class); + /** * Decodes a STOMP frame in the given {@code buffer} into a {@link Message}. @@ -49,12 +55,20 @@ public class StompDecoder { MultiValueMap headers = readHeaders(buffer); byte[] payload = readPayload(buffer, headers); - return MessageBuilder.withPayloadAndHeaders(payload, + Message decodedMessage = MessageBuilder.withPayloadAndHeaders(payload, StompHeaderAccessor.create(StompCommand.valueOf(command), headers)).build(); + + if (logger.isTraceEnabled()) { + logger.trace("Decoded " + decodedMessage); + } + + return decodedMessage; } else { - // Heartbeat - return null; + if (logger.isTraceEnabled()) { + logger.trace("Decoded heartbeat"); + } + return MessageBuilder.withPayload(HEARTBEAT_PAYLOAD).build(); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java index f7760d57def..aa342c2e0f8 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java @@ -23,6 +23,8 @@ import java.nio.charset.Charset; import java.util.List; import java.util.Map.Entry; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.messaging.Message; /** @@ -39,6 +41,7 @@ public final class StompEncoder { private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + private final Log logger = LogFactory.getLog(StompEncoder.class); /** * Encodes the given STOMP {@code message} into a {@code byte[]} @@ -49,16 +52,23 @@ public final class StompEncoder { */ public byte[] encode(Message message) { try { + if (logger.isTraceEnabled()) { + logger.trace("Encoding " + message); + } ByteArrayOutputStream baos = new ByteArrayOutputStream(); DataOutputStream output = new DataOutputStream(baos); StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - writeCommand(headers, output); - writeHeaders(headers, message, output); - output.write(LF); - writeBody(message, output); - output.write((byte)0); + if (isHeartbeat(headers)) { + output.write(message.getPayload()); + } else { + writeCommand(headers, output); + writeHeaders(headers, message, output); + output.write(LF); + writeBody(message, output); + output.write((byte)0); + } return baos.toByteArray(); } @@ -67,6 +77,10 @@ public final class StompEncoder { } } + private boolean isHeartbeat(StompHeaderAccessor headers) { + return headers.getCommand() == null; + } + private void writeCommand(StompHeaderAccessor headers, DataOutputStream output) throws IOException { output.write(headers.getCommand().toString().getBytes(UTF8_CHARSET)); output.write(LF); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java index f39c43fe664..f3771059446 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java @@ -82,6 +82,8 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { public static final String STOMP_HEARTBEAT_HEADER = "heart-beat"; + private static final long[] DEFAULT_HEARTBEAT = new long[] {0, 0}; + // Other header names @@ -185,7 +187,7 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { result.put(STOMP_CONTENT_TYPE_HEADER, Arrays.asList(contentType.toString())); } - if (getCommand().requiresSubscriptionId()) { + if (getCommand() != null && getCommand().requiresSubscriptionId()) { String subscriptionId = getSubscriptionId(); if (subscriptionId != null) { String name = StompCommand.MESSAGE.equals(getCommand()) ? STOMP_SUBSCRIPTION_HEADER : STOMP_ID_HEADER; @@ -252,7 +254,7 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { public long[] getHeartbeat() { String rawValue = getFirstNativeHeader(STOMP_HEARTBEAT_HEADER); if (!StringUtils.hasText(rawValue)) { - return null; + return Arrays.copyOf(DEFAULT_HEARTBEAT, 2); } String[] rawValues = StringUtils.commaDelimitedListToStringArray(rawValue); return new long[] { Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])}; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java index 1c70bbb669a..6bbd655e91b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java @@ -166,11 +166,17 @@ public class StompProtocolHandler implements SubProtocolHandler { public void handleMessageToClient(WebSocketSession session, Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - headers.setCommandIfNotSet(StompCommand.MESSAGE); + if (headers.getCommand() == null && SimpMessageType.MESSAGE == headers.getMessageType()) { + headers.setCommandIfNotSet(StompCommand.MESSAGE); + } - if (this.handleConnect && StompCommand.CONNECTED.equals(headers.getCommand())) { - // Ignore since we already sent it - return; + if (headers.getCommand() == StompCommand.CONNECTED) { + if (this.handleConnect) { + // Ignore since we already sent it + return; + } else { + augmentConnectedHeaders(headers, session); + } } if (StompCommand.MESSAGE.equals(headers.getCommand()) && (headers.getSubscriptionId() == null)) { @@ -222,20 +228,26 @@ public class StompProtocolHandler implements SubProtocolHandler { } connectedHeaders.setHeartbeat(0,0); + augmentConnectedHeaders(connectedHeaders, session); + + // TODO: security + + Message connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build(); + String payload = new String(this.stompEncoder.encode(connectedMessage), Charset.forName("UTF-8")); + session.sendMessage(new TextMessage(payload)); + } + + private void augmentConnectedHeaders(StompHeaderAccessor headers, WebSocketSession session) { Principal principal = session.getPrincipal(); if (principal != null) { - connectedHeaders.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); - connectedHeaders.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId()); + headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); + headers.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId()); if (this.queueSuffixResolver != null) { String suffix = session.getId(); this.queueSuffixResolver.addQueueSuffix(principal.getName(), session.getId(), suffix); } } - - Message connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build(); - String payload = new String(this.stompEncoder.encode(connectedMessage), Charset.forName("UTF-8")); - session.sendMessage(new TextMessage(payload)); } protected void sendErrorMessage(WebSocketSession session, Throwable error) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index 7441536ef9b..3cf1544313c 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -287,20 +287,22 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { @Override public void handleMessage(Message message) throws MessagingException { - synchronized(this.monitor) { - for (MessageExchange exch : this.expected) { - if (exch.matchMessage(message)) { - if (exch.isDone()) { - this.expected.remove(exch); - this.actual.add(exch); - if (this.expected.isEmpty()) { - this.monitor.notifyAll(); + if (StompHeaderAccessor.wrap(message).getCommand() != null) { + synchronized(this.monitor) { + for (MessageExchange exch : this.expected) { + if (exch.matchMessage(message)) { + if (exch.isDone()) { + this.expected.remove(exch); + this.actual.add(exch); + if (this.expected.isEmpty()) { + this.monitor.notifyAll(); + } } + return; } - return; } + this.unexpected.add(message); } - this.unexpected.add(message); } } From bae9134a6e00de75155054e32d29f5aa0fe1aa4b Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Wed, 25 Sep 2013 16:36:22 +0100 Subject: [PATCH 4/6] Upgrade to Reactor 1.0.0.M3 --- build.gradle | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build.gradle b/build.gradle index 381e9895cd7..3c35a4d35f6 100644 --- a/build.gradle +++ b/build.gradle @@ -353,8 +353,8 @@ project("spring-messaging") { optional(project(":spring-websocket")) optional(project(":spring-webmvc")) optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") - optional("org.projectreactor:reactor-core:1.0.0.BUILD-SNAPSHOT") - optional("org.projectreactor:reactor-tcp:1.0.0.BUILD-SNAPSHOT") + optional("org.projectreactor:reactor-core:1.0.0.M3") + optional("org.projectreactor:reactor-tcp:1.0.0.M3") optional("com.lmax:disruptor:3.1.1") optional("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815") optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815") From 6679feb77b49ef3f779f1dcfb40b5ee25c5d9f76 Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Thu, 26 Sep 2013 15:08:34 +0100 Subject: [PATCH 5/6] Improve handling of missed heartbeats Previously, when a broker heartbeat was mnissed, the STOMP connection would be left in a semi-disconnected state such that, for example, the read and write idle callbacks would still be active, even though the underlying TCP connection had been nulled out. As part of disconnecting the STOMP connection, this commit closes the underlying TCP connection when a heartbeat's missed which cancels the read and write idle callbacks. It also now copes with the underlying TCP connection being null when sending a heartbeat to the broker. This protects again a race condition between the write idle callback being fired, such that a heartbeat needs to be sent, and the connection being nulled out due to it being closed. --- .../simp/stomp/StompBrokerRelayMessageHandler.java | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 4596be012bf..b65a032d726 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -438,7 +438,12 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler public void setDisconnected() { this.readyConnection.set(null); - this.connection = null; + + TcpConnection, Message> localConnection = this.connection; + if (localConnection != null) { + localConnection.close(); + this.connection = null; + } } @Override @@ -499,7 +504,10 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @Override public void run() { - stompConnection.connection.send(MessageBuilder.withPayload(heartbeatPayload).build()); + TcpConnection, Message> connection = stompConnection.connection; + if (connection != null) { + connection.send(MessageBuilder.withPayload(heartbeatPayload).build()); + } } }); From 469aaa875492f9b86bd92b80f5f87d18233829c5 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 26 Sep 2013 16:04:31 -0400 Subject: [PATCH 6/6] Polish --- .../stomp/StompBrokerRelayMessageHandler.java | 71 +++++++++---------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index b65a032d726..beaef084538 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -70,7 +70,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private TcpClient, Message> tcpClient; - private final Map relaySessions = new ConcurrentHashMap(); + private final Map relaySessions = new ConcurrentHashMap(); /** @@ -158,7 +158,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler if (logger.isDebugEnabled()) { logger.debug("Initializing \"system\" TCP connection"); } - SystemRelaySession session = new SystemRelaySession(); + SystemStompRelaySession session = new SystemStompRelaySession(); this.relaySessions.put(session.getId(), session); session.connect(); } @@ -189,7 +189,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler SimpMessageType messageType = headers.getMessageType(); if (SimpMessageType.MESSAGE.equals(messageType)) { - sessionId = (sessionId == null) ? SystemRelaySession.ID : sessionId; + sessionId = (sessionId == null) ? SystemStompRelaySession.ID : sessionId; headers.setSessionId(sessionId); command = (command == null) ? StompCommand.SEND : command; headers.setCommandIfNotSet(command); @@ -208,12 +208,12 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler try { if (SimpMessageType.CONNECT.equals(messageType)) { message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); - RelaySession session = new RelaySession(sessionId); + StompRelaySession session = new StompRelaySession(sessionId); this.relaySessions.put(sessionId, session); session.connect(message); } else if (SimpMessageType.DISCONNECT.equals(messageType)) { - RelaySession session = this.relaySessions.remove(sessionId); + StompRelaySession session = this.relaySessions.remove(sessionId); if (session == null) { if (logger.isTraceEnabled()) { logger.trace("Session already removed, sessionId=" + sessionId); @@ -223,7 +223,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler session.forward(message); } else { - RelaySession session = this.relaySessions.get(sessionId); + StompRelaySession session = this.relaySessions.get(sessionId); if (session == null) { logger.warn("Session id=" + sessionId + " not found. Ignoring message: " + message); return; @@ -237,14 +237,14 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } - private class RelaySession { + private class StompRelaySession { private final String sessionId; private volatile StompConnection stompConnection = new StompConnection(); - private RelaySession(String sessionId) { + private StompRelaySession(String sessionId) { Assert.notNull(sessionId, "sessionId is required"); this.sessionId = sessionId; } @@ -257,11 +257,11 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler public void connect(final Message connectMessage) { Assert.notNull(connectMessage, "connectMessage is required"); - Composable, Message>> promise = openTcpConnection(); + Composable, Message>> promise = initConnection(); promise.consume(new Consumer, Message>>() { @Override public void accept(TcpConnection, Message> connection) { - handleTcpConnection(connection, connectMessage); + handleConnectionReady(connection, connectMessage); } }); promise.when(Throwable.class, new Consumer() { @@ -273,11 +273,13 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler }); } - protected Composable, Message>> openTcpConnection() { + protected Composable, Message>> initConnection() { return tcpClient.open(); } - protected void handleTcpConnection(TcpConnection, Message> tcpConn, final Message connectMessage) { + protected void handleConnectionReady( + TcpConnection, Message> tcpConn, final Message connectMessage) { + this.stompConnection.setTcpConnection(tcpConn); tcpConn.on().close(new Runnable() { @Override @@ -294,6 +296,13 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler forwardInternal(tcpConn, connectMessage); } + protected void connectionClosed() { + relaySessions.remove(this.sessionId); + if (this.stompConnection.isReady()) { + sendError("Lost connection to the broker"); + } + } + private void readStompFrame(Message message) { if (logger.isTraceEnabled()) { logger.trace("Reading message " + message); @@ -340,40 +349,33 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } public void forward(Message message) { - - if (!this.stompConnection.isReady()) { - logger.warn("Message sent to relay before it was CONNECTED. Discarding message: " + message); - return; - } - - forwardInternal(message); - } - - private boolean forwardInternal(final Message message) { TcpConnection, Message> tcpConnection = this.stompConnection.getReadyConnection(); if (tcpConnection == null) { - return false; + logger.warn("Connection to STOMP broker is not active, discarding message: " + message); + return; } - return forwardInternal(tcpConnection, message); + forwardInternal(tcpConnection, message); } - @SuppressWarnings("unchecked") - private boolean forwardInternal(TcpConnection, Message> tcpConnection, final Message message) { + private boolean forwardInternal( + TcpConnection, Message> tcpConnection, Message message) { Assert.isInstanceOf(byte[].class, message.getPayload(), "Message's payload must be a byte[]"); + @SuppressWarnings("unchecked") + Message byteMessage = (Message) message; + if (logger.isTraceEnabled()) { logger.trace("Forwarding to STOMP broker, message: " + message); } StompCommand command = StompHeaderAccessor.wrap(message).getCommand(); - if (command == StompCommand.DISCONNECT) { this.stompConnection.setDisconnected(); } final Deferred> deferred = new DeferredPromiseSpec().get(); - tcpConnection.send((Message)message, new Consumer() { + tcpConnection.send(byteMessage, new Consumer() { @Override public void accept(Boolean success) { deferred.accept(success); @@ -398,13 +400,6 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } return (success != null) ? success : false; } - - protected void connectionClosed() { - relaySessions.remove(this.sessionId); - if (this.stompConnection.isReady()) { - sendError("Lost connection to the broker"); - } - } } private static class StompConnection { @@ -452,7 +447,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } } - private class SystemRelaySession extends RelaySession { + private class SystemStompRelaySession extends StompRelaySession { private static final long HEARTBEAT_RECEIVE_MULTIPLIER = 3; @@ -465,7 +460,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private final byte[] heartbeatPayload = new byte[] {'\n'}; - public SystemRelaySession() { + public SystemStompRelaySession() { super(ID); } @@ -480,7 +475,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } @Override - protected Composable, Message>> openTcpConnection() { + protected Composable, Message>> initConnection() { return tcpClient.open(new Reconnect() { @Override public Tuple2 reconnect(InetSocketAddress address, int attempt) {