diff --git a/build.gradle b/build.gradle index 90bf8b5ee53..3c35a4d35f6 100644 --- a/build.gradle +++ b/build.gradle @@ -71,6 +71,7 @@ configure(allprojects) { project -> maven { url "https://repository.apache.org/content/repositories/releases" } // tomcat 8 RC3 maven { url "https://repository.apache.org/content/repositories/snapshots" } // tomcat-websocket-* snapshots maven { url "https://maven.java.net/content/repositories/releases" } // javax.websocket, tyrus + maven { url 'http://repo.springsource.org/libs-snapshot' } // reactor } dependencies { @@ -352,8 +353,8 @@ project("spring-messaging") { optional(project(":spring-websocket")) optional(project(":spring-webmvc")) optional("com.fasterxml.jackson.core:jackson-databind:2.2.0") - optional("org.projectreactor:reactor-core:1.0.0.M2") - optional("org.projectreactor:reactor-tcp:1.0.0.M2") + optional("org.projectreactor:reactor-core:1.0.0.M3") + optional("org.projectreactor:reactor-tcp:1.0.0.M3") optional("com.lmax:disruptor:3.1.1") optional("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815") optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815") diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java index 683e493e4ab..9b0e4fdde26 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java @@ -23,6 +23,7 @@ import org.springframework.messaging.core.MessageSendingOperations; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.method.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.annotation.SubscribeEvent; import org.springframework.messaging.support.MessageBuilder; @@ -97,6 +98,7 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); headers.setSessionId(this.sessionId); headers.setSubscriptionId(this.subscriptionId); + headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); return MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java index 3e1050c9205..0f54c67fd80 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java @@ -54,7 +54,8 @@ public class ServletStompEndpointRegistry implements StompEndpointRegistry { public ServletStompEndpointRegistry(WebSocketHandler webSocketHandler, - MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler) { + MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler, + boolean handleConnect) { Assert.notNull(webSocketHandler); Assert.notNull(userQueueSuffixResolver); @@ -63,6 +64,7 @@ public class ServletStompEndpointRegistry implements StompEndpointRegistry { this.subProtocolWebSocketHandler = findSubProtocolWebSocketHandler(webSocketHandler); this.stompHandler = new StompProtocolHandler(); this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver); + this.stompHandler.setHandleConnect(handleConnect); this.sockJsScheduler = defaultSockJsTaskScheduler; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java index ef077487fab..ba5fb13baa2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java @@ -57,8 +57,10 @@ public abstract class WebSocketMessageBrokerConfigurationSupport { @Bean public HandlerMapping brokerWebSocketHandlerMapping() { - ServletStompEndpointRegistry registry = new ServletStompEndpointRegistry( - subProtocolWebSocketHandler(), userQueueSuffixResolver(), brokerDefaultSockJsTaskScheduler()); + boolean brokerRelayConfigured = getMessageBrokerConfigurer().getStompBrokerRelay() != null; + ServletStompEndpointRegistry registry = new ServletStompEndpointRegistry(subProtocolWebSocketHandler(), + userQueueSuffixResolver(), brokerDefaultSockJsTaskScheduler(), !brokerRelayConfigured); + registerStompEndpoints(registry); AbstractHandlerMapping hm = registry.getHandlerMapping(); hm.setOrder(1); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java index 88c73ca645f..beaef084538 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java @@ -17,14 +17,9 @@ package org.springframework.messaging.simp.stomp; import java.net.InetSocketAddress; -import java.nio.charset.Charset; -import java.util.ArrayList; import java.util.Collection; -import java.util.List; import java.util.Map; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicReference; import org.springframework.messaging.Message; @@ -34,7 +29,6 @@ import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.handler.AbstractBrokerMessageHandler; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; import reactor.core.Environment; import reactor.core.composable.Composable; @@ -45,8 +39,6 @@ import reactor.function.Consumer; import reactor.tcp.Reconnect; import reactor.tcp.TcpClient; import reactor.tcp.TcpConnection; -import reactor.tcp.encoding.DelimitedCodec; -import reactor.tcp.encoding.StandardCodecs; import reactor.tcp.netty.NettyTcpClient; import reactor.tcp.spec.TcpClientSpec; import reactor.tuple.Tuple; @@ -74,13 +66,11 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler private String systemPasscode = "guest"; - private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); - private Environment environment; - private TcpClient tcpClient; + private TcpClient, Message> tcpClient; - private final Map relaySessions = new ConcurrentHashMap(); + private final Map relaySessions = new ConcurrentHashMap(); /** @@ -159,16 +149,16 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @Override protected void startInternal() { this.environment = new Environment(); - this.tcpClient = new TcpClientSpec(NettyTcpClient.class) + this.tcpClient = new TcpClientSpec, Message>(NettyTcpClient.class) .env(this.environment) - .codec(new DelimitedCodec((byte) 0, true, StandardCodecs.STRING_CODEC)) + .codec(new StompCodec()) .connect(this.relayHost, this.relayPort) .get(); if (logger.isDebugEnabled()) { logger.debug("Initializing \"system\" TCP connection"); } - SystemRelaySession session = new SystemRelaySession(); + SystemStompRelaySession session = new SystemStompRelaySession(); this.relaySessions.put(session.getId(), session); session.connect(); } @@ -199,35 +189,31 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler SimpMessageType messageType = headers.getMessageType(); if (SimpMessageType.MESSAGE.equals(messageType)) { - sessionId = (sessionId == null) ? SystemRelaySession.ID : sessionId; + sessionId = (sessionId == null) ? SystemStompRelaySession.ID : sessionId; headers.setSessionId(sessionId); command = (command == null) ? StompCommand.SEND : command; headers.setCommandIfNotSet(command); message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); } - if (headers.getCommand() == null) { - logger.error("No STOMP command, ignoring message: " + message); - return; - } if (sessionId == null) { logger.error("No sessionId, ignoring message: " + message); return; } - if (command.requiresDestination() && !checkDestinationPrefix(destination)) { + + if (command != null && command.requiresDestination() && !checkDestinationPrefix(destination)) { return; } try { if (SimpMessageType.CONNECT.equals(messageType)) { - headers.setHeartbeat(0, 0); message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); - RelaySession session = new RelaySession(sessionId); + StompRelaySession session = new StompRelaySession(sessionId); this.relaySessions.put(sessionId, session); session.connect(message); } else if (SimpMessageType.DISCONNECT.equals(messageType)) { - RelaySession session = this.relaySessions.remove(sessionId); + StompRelaySession session = this.relaySessions.remove(sessionId); if (session == null) { if (logger.isTraceEnabled()) { logger.trace("Session already removed, sessionId=" + sessionId); @@ -237,7 +223,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler session.forward(message); } else { - RelaySession session = this.relaySessions.get(sessionId); + StompRelaySession session = this.relaySessions.get(sessionId); if (session == null) { logger.warn("Session id=" + sessionId + " not found. Ignoring message: " + message); return; @@ -251,18 +237,14 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } - private class RelaySession { + private class StompRelaySession { private final String sessionId; - private final BlockingQueue> messageQueue = new LinkedBlockingQueue>(50); - private volatile StompConnection stompConnection = new StompConnection(); - private final Object monitor = new Object(); - - private RelaySession(String sessionId) { + private StompRelaySession(String sessionId) { Assert.notNull(sessionId, "sessionId is required"); this.sessionId = sessionId; } @@ -275,14 +257,14 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler public void connect(final Message connectMessage) { Assert.notNull(connectMessage, "connectMessage is required"); - Composable> connectionComposable = openTcpConnection(); - connectionComposable.consume(new Consumer>() { + Composable, Message>> promise = initConnection(); + promise.consume(new Consumer, Message>>() { @Override - public void accept(TcpConnection connection) { - handleTcpConnection(connection, connectMessage); + public void accept(TcpConnection, Message> connection) { + handleConnectionReady(connection, connectMessage); } }); - connectionComposable.when(Throwable.class, new Consumer() { + promise.when(Throwable.class, new Consumer() { @Override public void accept(Throwable ex) { relaySessions.remove(sessionId); @@ -291,41 +273,44 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler }); } - protected Composable> openTcpConnection() { + protected Composable, Message>> initConnection() { return tcpClient.open(); } - protected void handleTcpConnection(TcpConnection tcpConn, final Message connectMessage) { + protected void handleConnectionReady( + TcpConnection, Message> tcpConn, final Message connectMessage) { + this.stompConnection.setTcpConnection(tcpConn); - tcpConn.in().consume(new Consumer() { + tcpConn.on().close(new Runnable() { @Override - public void accept(String message) { + public void run() { + connectionClosed(); + } + }); + tcpConn.in().consume(new Consumer>() { + @Override + public void accept(Message message) { readStompFrame(message); } }); forwardInternal(tcpConn, connectMessage); } - private void readStompFrame(String stompFrame) { - - // heartbeat - if (StringUtils.isEmpty(stompFrame)) { - return; + protected void connectionClosed() { + relaySessions.remove(this.sessionId); + if (this.stompConnection.isReady()) { + sendError("Lost connection to the broker"); } + } - Message message = stompMessageConverter.toMessage(stompFrame); + private void readStompFrame(Message message) { if (logger.isTraceEnabled()) { logger.trace("Reading message " + message); } StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); if (StompCommand.CONNECTED == headers.getCommand()) { - synchronized(this.monitor) { - this.stompConnection.setReady(); - publishBrokerAvailableEvent(); - flushMessages(); - } - return; + connected(headers, this.stompConnection); } headers.setSessionId(this.sessionId); @@ -333,12 +318,21 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler sendMessageToClient(message); } + protected void connected(StompHeaderAccessor headers, StompConnection stompConnection) { + this.stompConnection.setReady(); + publishBrokerAvailableEvent(); + } + private void handleTcpClientFailure(String message, Throwable ex) { if (logger.isErrorEnabled()) { logger.error(message + ", sessionId=" + this.sessionId, ex); } + disconnected(message); + } + + protected void disconnected(String errorMessage) { this.stompConnection.setDisconnected(); - sendError(message); + sendError(errorMessage); publishBrokerUnavailableEvent(); } @@ -355,47 +349,33 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } public void forward(Message message) { - - if (!this.stompConnection.isReady()) { - synchronized(this.monitor) { - if (!this.stompConnection.isReady()) { - this.messageQueue.add(message); - if (logger.isTraceEnabled()) { - logger.trace("Not connected, message queued. Queue size=" + this.messageQueue.size()); - } - return; - } - } - } - - if (this.messageQueue.isEmpty()) { - forwardInternal(message); - } - else { - this.messageQueue.add(message); - flushMessages(); - } - } - - private boolean forwardInternal(final Message message) { - TcpConnection tcpConnection = this.stompConnection.getReadyConnection(); + TcpConnection, Message> tcpConnection = this.stompConnection.getReadyConnection(); if (tcpConnection == null) { - return false; + logger.warn("Connection to STOMP broker is not active, discarding message: " + message); + return; } - return forwardInternal(tcpConnection, message); + forwardInternal(tcpConnection, message); } - private boolean forwardInternal(TcpConnection tcpConnection, final Message message) { + private boolean forwardInternal( + TcpConnection, Message> tcpConnection, Message message) { + + Assert.isInstanceOf(byte[].class, message.getPayload(), "Message's payload must be a byte[]"); + + @SuppressWarnings("unchecked") + Message byteMessage = (Message) message; if (logger.isTraceEnabled()) { logger.trace("Forwarding to STOMP broker, message: " + message); } - byte[] bytes = stompMessageConverter.fromMessage(message); - String payload = new String(bytes, Charset.forName("UTF-8")); + StompCommand command = StompHeaderAccessor.wrap(message).getCommand(); + if (command == StompCommand.DISCONNECT) { + this.stompConnection.setDisconnected(); + } final Deferred> deferred = new DeferredPromiseSpec().get(); - tcpConnection.send(payload, new Consumer() { + tcpConnection.send(byteMessage, new Consumer() { @Override public void accept(Boolean success) { deferred.accept(success); @@ -409,7 +389,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler handleTcpClientFailure("Timed out waiting for message to be forwarded to the broker", null); } else if (!success) { - if (StompHeaderAccessor.wrap(message).getCommand() != StompCommand.DISCONNECT) { + if (command != StompCommand.DISCONNECT) { handleTcpClientFailure("Failed to forward message to the broker", null); } } @@ -420,32 +400,26 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } return (success != null) ? success : false; } - - private void flushMessages() { - List> messages = new ArrayList>(); - this.messageQueue.drainTo(messages); - for (Message message : messages) { - if (!forwardInternal(message)) { - return; - } - } - } } private static class StompConnection { - private volatile TcpConnection connection; + private volatile TcpConnection, Message> connection; - private AtomicReference> readyConnection = - new AtomicReference>(); + private AtomicReference, Message>> readyConnection = + new AtomicReference, Message>>(); - public void setTcpConnection(TcpConnection connection) { + public void setTcpConnection(TcpConnection, Message> connection) { Assert.notNull(connection, "connection must not be null"); this.connection = connection; } - public TcpConnection getReadyConnection() { + /** + * Return the underlying {@link TcpConnection} but only after the CONNECTED STOMP + * frame is received. + */ + public TcpConnection, Message> getReadyConnection() { return this.readyConnection.get(); } @@ -459,7 +433,12 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler public void setDisconnected() { this.readyConnection.set(null); - this.connection = null; + + TcpConnection, Message> localConnection = this.connection; + if (localConnection != null) { + localConnection.close(); + this.connection = null; + } } @Override @@ -468,12 +447,20 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler } } - private class SystemRelaySession extends RelaySession { + private class SystemStompRelaySession extends StompRelaySession { + + private static final long HEARTBEAT_RECEIVE_MULTIPLIER = 3; + + private static final long HEARTBEAT_SEND_INTERVAL = 10000; + + private static final long HEARTBEAT_RECEIVE_INTERVAL = 10000; public static final String ID = "stompRelaySystemSessionId"; + private final byte[] heartbeatPayload = new byte[] {'\n'}; - public SystemRelaySession() { + + public SystemStompRelaySession() { super(ID); } @@ -482,13 +469,13 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler headers.setAcceptVersion("1.1,1.2"); headers.setLogin(systemLogin); headers.setPasscode(systemPasscode); - headers.setHeartbeat(0,0); + headers.setHeartbeat(HEARTBEAT_SEND_INTERVAL, HEARTBEAT_RECEIVE_INTERVAL); Message connectMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); super.connect(connectMessage); } @Override - protected Composable> openTcpConnection() { + protected Composable, Message>> initConnection() { return tcpClient.open(new Reconnect() { @Override public Tuple2 reconnect(InetSocketAddress address, int attempt) { @@ -497,6 +484,47 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler }); } + @Override + protected void connectionClosed() { + publishBrokerUnavailableEvent(); + } + + @Override + protected void connected(StompHeaderAccessor headers, final StompConnection stompConnection) { + long brokerReceiveInterval = headers.getHeartbeat()[1]; + + if (HEARTBEAT_SEND_INTERVAL > 0 && brokerReceiveInterval > 0) { + long interval = Math.max(HEARTBEAT_SEND_INTERVAL, brokerReceiveInterval); + stompConnection.connection.on().writeIdle(interval, new Runnable() { + + @Override + public void run() { + TcpConnection, Message> connection = stompConnection.connection; + if (connection != null) { + connection.send(MessageBuilder.withPayload(heartbeatPayload).build()); + } + } + + }); + } + + long brokerSendInterval = headers.getHeartbeat()[0]; + if (HEARTBEAT_RECEIVE_INTERVAL > 0 && brokerSendInterval > 0) { + final long interval = + Math.max(HEARTBEAT_RECEIVE_INTERVAL, brokerSendInterval) * HEARTBEAT_RECEIVE_MULTIPLIER; + stompConnection.connection.on().readIdle(interval, new Runnable() { + @Override + public void run() { + String message = "Broker hearbeat missed: connection idle for more than " + interval + "ms"; + logger.warn(message); + disconnected(message); + } + }); + } + + super.connected(headers, stompConnection); + } + @Override protected void sendMessageToClient(Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java new file mode 100644 index 00000000000..bef1726ae56 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import org.springframework.messaging.Message; + +import reactor.function.Consumer; +import reactor.function.Function; +import reactor.io.Buffer; +import reactor.tcp.encoding.Codec; + +/** + * A Reactor TCP {@link Codec} for sending and receiving STOMP messages + * + * @author Andy Wilkinson + * @since 4.0 + */ +public class StompCodec implements Codec, Message> { + + private static final StompDecoder DECODER = new StompDecoder(); + + private static final Function, Buffer> ENCODER_FUNCTION = new Function, Buffer>() { + + private final StompEncoder encoder = new StompEncoder(); + + @Override + public Buffer apply(Message message) { + return Buffer.wrap(this.encoder.encode(message)); + } + }; + + @Override + public Function> decoder(final Consumer> next) { + return new Function>() { + + @Override + public Message apply(Buffer buffer) { + while (buffer.remaining() > 0) { + Message message = DECODER.decode(buffer.byteBuffer()); + if (message != null) { + next.accept(message); + } + } + return null; + } + }; + } + + @Override + public Function, Buffer> encoder() { + return ENCODER_FUNCTION; + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java new file mode 100644 index 00000000000..e876df99ad2 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java @@ -0,0 +1,171 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * A decoder for STOMP frames + * + * @author awilkinson + * @since 4.0 + */ +public class StompDecoder { + + private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + + private static final byte[] HEARTBEAT_PAYLOAD = new byte[] {'\n'}; + + private final Log logger = LogFactory.getLog(StompDecoder.class); + + + /** + * Decodes a STOMP frame in the given {@code buffer} into a {@link Message}. + * + * @param buffer The buffer to decode the frame from + * @return The decoded message + */ + public Message decode(ByteBuffer buffer) { + skipLeadingEol(buffer); + String command = readCommand(buffer); + if (command.length() > 0) { + MultiValueMap headers = readHeaders(buffer); + byte[] payload = readPayload(buffer, headers); + + Message decodedMessage = MessageBuilder.withPayloadAndHeaders(payload, + StompHeaderAccessor.create(StompCommand.valueOf(command), headers)).build(); + + if (logger.isTraceEnabled()) { + logger.trace("Decoded " + decodedMessage); + } + + return decodedMessage; + } + else { + if (logger.isTraceEnabled()) { + logger.trace("Decoded heartbeat"); + } + return MessageBuilder.withPayload(HEARTBEAT_PAYLOAD).build(); + } + + } + + private String readCommand(ByteBuffer buffer) { + ByteArrayOutputStream command = new ByteArrayOutputStream(); + while (buffer.remaining() > 0 && !isEol(buffer)) { + command.write(buffer.get()); + } + return new String(command.toByteArray(), UTF8_CHARSET); + } + + private MultiValueMap readHeaders(ByteBuffer buffer) { + MultiValueMap headers = new LinkedMultiValueMap(); + while (true) { + ByteArrayOutputStream headerStream = new ByteArrayOutputStream(); + while (buffer.remaining() > 0 && !isEol(buffer)) { + headerStream.write(buffer.get()); + } + if (headerStream.size() > 0) { + String header = new String(headerStream.toByteArray(), UTF8_CHARSET); + int colonIndex = header.indexOf(':'); + if (colonIndex <= 0 || colonIndex == header.length() - 1) { + throw new StompConversionException( + "Illegal header: '" + header + "'. A header must be of the form : headers) { + String contentLengthString = headers.getFirst("content-length"); + if (contentLengthString != null) { + int contentLength = Integer.valueOf(contentLengthString); + byte[] payload = new byte[contentLength]; + buffer.get(payload); + if (buffer.remaining() < 1 || buffer.get() != 0) { + throw new StompConversionException("Frame must be terminated with a null octect"); + } + return payload; + } + else { + ByteArrayOutputStream payload = new ByteArrayOutputStream(); + while (buffer.remaining() > 0) { + byte b = buffer.get(); + if (b == 0) { + return payload.toByteArray(); + } + else { + payload.write(b); + } + } + } + + throw new StompConversionException("Frame must be terminated with a null octect"); + } + + private void skipLeadingEol(ByteBuffer buffer) { + while (true) { + if (!isEol(buffer)) { + break; + } + } + } + + private boolean isEol(ByteBuffer buffer) { + if (buffer.remaining() > 0) { + byte b = buffer.get(); + if (b == '\n') { + return true; + } + else if (b == '\r') { + if (buffer.remaining() > 0 && buffer.get() == '\n') { + return true; + } + else { + throw new StompConversionException("'\\r' must be followed by '\\n'"); + } + } + buffer.position(buffer.position() - 1); + } + return false; + } +} \ No newline at end of file diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java new file mode 100644 index 00000000000..aa342c2e0f8 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.List; +import java.util.Map.Entry; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.messaging.Message; + +/** + * An encoder for STOMP frames + * + * @author Andy Wilkinson + * @since 4.0 + */ +public final class StompEncoder { + + private static final byte LF = '\n'; + + private static final byte COLON = ':'; + + private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + + private final Log logger = LogFactory.getLog(StompEncoder.class); + + /** + * Encodes the given STOMP {@code message} into a {@code byte[]} + * + * @param message The message to encode + * + * @return The encoded message + */ + public byte[] encode(Message message) { + try { + if (logger.isTraceEnabled()) { + logger.trace("Encoding " + message); + } + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream output = new DataOutputStream(baos); + + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); + + if (isHeartbeat(headers)) { + output.write(message.getPayload()); + } else { + writeCommand(headers, output); + writeHeaders(headers, message, output); + output.write(LF); + writeBody(message, output); + output.write((byte)0); + } + + return baos.toByteArray(); + } + catch (IOException e) { + throw new StompConversionException("Failed to encode STOMP frame", e); + } + } + + private boolean isHeartbeat(StompHeaderAccessor headers) { + return headers.getCommand() == null; + } + + private void writeCommand(StompHeaderAccessor headers, DataOutputStream output) throws IOException { + output.write(headers.getCommand().toString().getBytes(UTF8_CHARSET)); + output.write(LF); + } + + private void writeHeaders(StompHeaderAccessor headers, Message message, DataOutputStream output) + throws IOException { + + for (Entry> entry : headers.toStompHeaderMap().entrySet()) { + byte[] key = getUtf8BytesEscapingIfNecessary(entry.getKey(), headers); + for (String value : entry.getValue()) { + output.write(key); + output.write(COLON); + output.write(getUtf8BytesEscapingIfNecessary(value, headers)); + output.write(LF); + } + } + if (headers.getCommand() == StompCommand.SEND || + headers.getCommand() == StompCommand.MESSAGE || + headers.getCommand() == StompCommand.ERROR) { + output.write("content-length:".getBytes(UTF8_CHARSET)); + output.write(Integer.toString(message.getPayload().length).getBytes(UTF8_CHARSET)); + output.write(LF); + } + } + + private void writeBody(Message message, DataOutputStream output) throws IOException { + output.write(message.getPayload()); + } + + private byte[] getUtf8BytesEscapingIfNecessary(String input, StompHeaderAccessor headers) { + if (headers.getCommand() != StompCommand.CONNECT && headers.getCommand() != StompCommand.CONNECTED) { + return escape(input).getBytes(UTF8_CHARSET); + } + else { + return input.getBytes(UTF8_CHARSET); + } + } + + private String escape(String input) { + return input.replaceAll("\\\\", "\\\\\\\\") + .replaceAll(":", "\\\\c") + .replaceAll("\n", "\\\\n") + .replaceAll("\r", "\\\\r"); + } +} \ No newline at end of file diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java index f39c43fe664..f3771059446 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompHeaderAccessor.java @@ -82,6 +82,8 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { public static final String STOMP_HEARTBEAT_HEADER = "heart-beat"; + private static final long[] DEFAULT_HEARTBEAT = new long[] {0, 0}; + // Other header names @@ -185,7 +187,7 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { result.put(STOMP_CONTENT_TYPE_HEADER, Arrays.asList(contentType.toString())); } - if (getCommand().requiresSubscriptionId()) { + if (getCommand() != null && getCommand().requiresSubscriptionId()) { String subscriptionId = getSubscriptionId(); if (subscriptionId != null) { String name = StompCommand.MESSAGE.equals(getCommand()) ? STOMP_SUBSCRIPTION_HEADER : STOMP_ID_HEADER; @@ -252,7 +254,7 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor { public long[] getHeartbeat() { String rawValue = getFirstNativeHeader(STOMP_HEARTBEAT_HEADER); if (!StringUtils.hasText(rawValue)) { - return null; + return Arrays.copyOf(DEFAULT_HEARTBEAT, 2); } String[] rawValues = StringUtils.commaDelimitedListToStringArray(rawValue); return new long[] { Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])}; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java deleted file mode 100644 index a410c615c5a..00000000000 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.messaging.simp.stomp; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.nio.charset.Charset; -import java.util.List; -import java.util.Map.Entry; - -import org.springframework.messaging.Message; -import org.springframework.messaging.support.MessageBuilder; -import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; - - -/** - * @author Gary Russell - * @author Rossen Stoyanchev - * @since 4.0 - */ -public class StompMessageConverter { - - private static final Charset STOMP_CHARSET = Charset.forName("UTF-8"); - - public static final byte LF = 0x0a; - - public static final byte CR = 0x0d; - - private static final byte COLON = ':'; - - /** - * @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String. - */ - public Message toMessage(Object stompContent) { - - byte[] byteContent = null; - if (stompContent instanceof String) { - byteContent = ((String) stompContent).getBytes(STOMP_CHARSET); - } - else if (stompContent instanceof byte[]){ - byteContent = (byte[]) stompContent; - } - else { - throw new IllegalArgumentException( - "stompContent is neither String nor byte[]: " + stompContent.getClass()); - } - - int totalLength = byteContent.length; - if (byteContent[totalLength-1] == 0) { - totalLength--; - } - - int payloadIndex = findIndexOfPayload(byteContent); - if (payloadIndex == 0) { - throw new StompConversionException("No command found"); - } - - String headerContent = new String(byteContent, 0, payloadIndex, STOMP_CHARSET); - Parser parser = new Parser(headerContent); - - StompCommand command = StompCommand.valueOf(parser.nextToken(LF).trim()); - Assert.notNull(command, "No command found"); - - MultiValueMap headers = new LinkedMultiValueMap(); - while (parser.hasNext()) { - String header = parser.nextToken(COLON); - if (header != null) { - if (parser.hasNext()) { - String value = parser.nextToken(LF); - headers.add(header, value); - } - else { - throw new StompConversionException("Parse exception for " + headerContent); - } - } - } - - byte[] payload = new byte[totalLength - payloadIndex]; - System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex); - StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers); - return MessageBuilder.withPayloadAndHeaders(payload, stompHeaders).build(); - } - - private int findIndexOfPayload(byte[] bytes) { - int i; - // ignore any leading EOL from the previous message - for (i = 0; i < bytes.length; i++) { - if (bytes[i] != '\n' && bytes[i] != '\r') { - break; - } - bytes[i] = ' '; - } - int index = 0; - for (; i < bytes.length - 1; i++) { - if (bytes[i] == LF && bytes[i+1] == LF) { - index = i + 2; - break; - } - if ((i < (bytes.length - 3)) && - (bytes[i] == CR && bytes[i+1] == LF && bytes[i+2] == CR && bytes[i+3] == LF)) { - index = i + 4; - break; - } - } - if (i >= bytes.length) { - throw new StompConversionException("No end of headers found"); - } - return index; - } - - public byte[] fromMessage(Message message) { - - byte[] payload; - if (message.getPayload() instanceof byte[]) { - payload = (byte[]) message.getPayload(); - } - else { - throw new IllegalArgumentException( - "stompContent is not byte[]: " + message.getPayload().getClass()); - } - - ByteArrayOutputStream out = new ByteArrayOutputStream(); - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - - try { - out.write(stompHeaders.getCommand().toString().getBytes("UTF-8")); - out.write(LF); - for (Entry> entry : stompHeaders.toStompHeaderMap().entrySet()) { - String key = entry.getKey(); - key = replaceAllOutbound(key); - for (String value : entry.getValue()) { - out.write(key.getBytes("UTF-8")); - out.write(COLON); - value = replaceAllOutbound(value); - out.write(value.getBytes("UTF-8")); - out.write(LF); - } - } - out.write(LF); - out.write(payload); - out.write(0); - return out.toByteArray(); - } - catch (IOException e) { - throw new StompConversionException("Failed to serialize " + message, e); - } - } - - private String replaceAllOutbound(String key) { - return key.replaceAll("\\\\", "\\\\") - .replaceAll(":", "\\\\c") - .replaceAll("\n", "\\\\n") - .replaceAll("\r", "\\\\r"); - } - - - private class Parser { - - private final String content; - - private int offset; - - public Parser(String content) { - this.content = content; - } - - public boolean hasNext() { - return this.offset < this.content.length(); - } - - public String nextToken(byte delimiter) { - if (this.offset >= this.content.length()) { - return null; - } - int delimAt = this.content.indexOf(delimiter, this.offset); - if (delimAt == -1) { - if (this.offset == this.content.length() - 1 && delimiter == COLON && - this.content.charAt(this.offset) == LF) { - this.offset++; - return null; - } - else if (this.offset == this.content.length() - 2 && delimiter == COLON && - this.content.charAt(this.offset) == CR && - this.content.charAt(this.offset + 1) == LF) { - this.offset += 2; - return null; - } - else { - throw new StompConversionException("No delimiter found at offset " + offset + " in " + this.content); - } - } - int escapeAt = this.content.indexOf('\\', this.offset); - String token = this.content.substring(this.offset, delimAt + 1); - this.offset += token.length(); - if (escapeAt >= 0 && escapeAt < delimAt) { - char escaped = this.content.charAt(escapeAt + 1); - if (escaped == 'n' || escaped == 'c' || escaped == '\\') { - token = token.replaceAll("\\\\n", "\n") - .replaceAll("\\\\r", "\r") - .replaceAll("\\\\c", ":") - .replaceAll("\\\\\\\\", "\\\\"); - } - else { - throw new StompConversionException("Invalid escape sequence \\" + escaped); - } - } - int length = token.length(); - if (delimiter == LF && length > 1 && token.charAt(length - 2) == CR) { - return token.substring(0, length - 2); - } - else { - return token.substring(0, length - 1); - } - } - } -} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java index ce5af6da96b..6bbd655e91b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java @@ -17,6 +17,7 @@ package org.springframework.messaging.simp.stomp; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.security.Principal; import java.util.Arrays; @@ -63,10 +64,13 @@ public class StompProtocolHandler implements SubProtocolHandler { private final Log logger = LogFactory.getLog(StompProtocolHandler.class); - private final StompMessageConverter stompMessageConverter = new StompMessageConverter(); + private final StompDecoder stompDecoder = new StompDecoder(); + + private final StompEncoder stompEncoder = new StompEncoder(); private MutableUserQueueSuffixResolver queueSuffixResolver = new SimpleUserQueueSuffixResolver(); + private volatile boolean handleConnect = false; /** * Configure a resolver to use to maintain queue suffixes for user @@ -83,6 +87,29 @@ public class StompProtocolHandler implements SubProtocolHandler { return this.queueSuffixResolver; } + /** + * Configures the handling of CONNECT frames. When {@code true}, CONNECT + * frames will be handled by this handler, and a CONNECTED response will be + * sent. When {@code false}, CONNECT frames will be forwarded for + * handling by another component. + * + * @param handleConnect {@code true} if connect frames should be handled + * by this handler, {@code false} otherwise. + */ + public void setHandleConnect(boolean handleConnect) { + this.handleConnect = handleConnect; + } + + /** + * Returns whether or not this handler will handle CONNECT frames. + * + * @return Returns {@code true} if this handler will handle CONNECT frames, + * otherwise {@code false}. + */ + public boolean willHandleConnect() { + return this.handleConnect; + } + @Override public List getSupportedProtocols() { return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"); @@ -98,7 +125,8 @@ public class StompProtocolHandler implements SubProtocolHandler { try { Assert.isInstanceOf(TextMessage.class, webSocketMessage); String payload = ((TextMessage)webSocketMessage).getPayload(); - message = this.stompMessageConverter.toMessage(payload); + ByteBuffer byteBuffer = ByteBuffer.wrap(payload.getBytes(Charset.forName("UTF-8"))); + message = this.stompDecoder.decode(byteBuffer); } catch (Throwable error) { logger.error("Failed to parse STOMP frame, WebSocket message payload: ", error); @@ -117,31 +145,38 @@ public class StompProtocolHandler implements SubProtocolHandler { message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); - if (SimpMessageType.CONNECT.equals(headers.getMessageType())) { + if (this.handleConnect && SimpMessageType.CONNECT.equals(headers.getMessageType())) { handleConnect(session, message); } - - outputChannel.send(message); + else { + outputChannel.send(message); + } } catch (Throwable t) { logger.error("Terminating STOMP session due to failure to send message: ", t); sendErrorMessage(session, t); } - } /** * Handle STOMP messages going back out to WebSocket clients. */ + @SuppressWarnings("unchecked") @Override public void handleMessageToClient(WebSocketSession session, Message message) { StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - headers.setCommandIfNotSet(StompCommand.MESSAGE); + if (headers.getCommand() == null && SimpMessageType.MESSAGE == headers.getMessageType()) { + headers.setCommandIfNotSet(StompCommand.MESSAGE); + } - if (StompCommand.CONNECTED.equals(headers.getCommand())) { - // Ignore for now since we already sent it - return; + if (headers.getCommand() == StompCommand.CONNECTED) { + if (this.handleConnect) { + // Ignore since we already sent it + return; + } else { + augmentConnectedHeaders(headers, session); + } } if (StompCommand.MESSAGE.equals(headers.getCommand()) && (headers.getSubscriptionId() == null)) { @@ -156,7 +191,7 @@ public class StompProtocolHandler implements SubProtocolHandler { try { message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build(); - byte[] bytes = this.stompMessageConverter.fromMessage(message); + byte[] bytes = this.stompEncoder.encode((Message)message); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } catch (Throwable t) { @@ -193,30 +228,36 @@ public class StompProtocolHandler implements SubProtocolHandler { } connectedHeaders.setHeartbeat(0,0); + augmentConnectedHeaders(connectedHeaders, session); + + // TODO: security + + Message connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build(); + String payload = new String(this.stompEncoder.encode(connectedMessage), Charset.forName("UTF-8")); + session.sendMessage(new TextMessage(payload)); + } + + private void augmentConnectedHeaders(StompHeaderAccessor headers, WebSocketSession session) { Principal principal = session.getPrincipal(); if (principal != null) { - connectedHeaders.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); - connectedHeaders.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId()); + headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); + headers.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId()); if (this.queueSuffixResolver != null) { String suffix = session.getId(); this.queueSuffixResolver.addQueueSuffix(principal.getName(), session.getId(), suffix); } } - - Message connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build(); - byte[] bytes = this.stompMessageConverter.fromMessage(connectedMessage); - session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } protected void sendErrorMessage(WebSocketSession session, Throwable error) { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); headers.setMessage(error.getMessage()); - Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); - byte[] bytes = this.stompMessageConverter.fromMessage(message); + Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); + String payload = new String(this.stompEncoder.encode(message), Charset.forName("UTF-8")); try { - session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); + session.sendMessage(new TextMessage(payload)); } catch (Throwable t) { // ignore diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java index 7531e8ed208..f4c190ba39d 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistryTests.java @@ -53,7 +53,7 @@ public class ServletStompEndpointRegistryTests { this.webSocketHandler = new SubProtocolWebSocketHandler(channel); this.queueSuffixResolver = new SimpleUserQueueSuffixResolver(); TaskScheduler taskScheduler = Mockito.mock(TaskScheduler.class); - this.registry = new ServletStompEndpointRegistry(webSocketHandler, queueSuffixResolver, taskScheduler); + this.registry = new ServletStompEndpointRegistry(webSocketHandler, queueSuffixResolver, taskScheduler, false); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java index ab53dc4d40f..3cf1544313c 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java @@ -30,7 +30,6 @@ import org.apache.commons.logging.LogFactory; import org.junit.After; import org.junit.Before; import org.junit.Test; - import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEventPublisher; import org.springframework.messaging.Message; @@ -63,16 +62,14 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { private ExpectationMatchingEventPublisher eventPublisher; + private int port; + @Before public void setUp() throws Exception { - int port = SocketUtils.findAvailableTcpPort(61613); + this.port = SocketUtils.findAvailableTcpPort(61613); - this.activeMQBroker = new BrokerService(); - this.activeMQBroker.addConnector("stomp://localhost:" + port); - this.activeMQBroker.setStartAsync(false); - this.activeMQBroker.setDeleteAllMessagesOnStartup(true); - this.activeMQBroker.start(); + createAndStartBroker(); this.responseChannel = new ExecutorSubscribableChannel(); this.responseHandler = new ExpectationMatchingMessageHandler(); @@ -86,6 +83,14 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { this.relay.start(); } + private void createAndStartBroker() throws Exception { + this.activeMQBroker = new BrokerService(); + this.activeMQBroker.addConnector("stomp://localhost:" + port); + this.activeMQBroker.setStartAsync(false); + this.activeMQBroker.setDeleteAllMessagesOnStartup(true); + this.activeMQBroker.start(); + } + @After public void tearDown() throws Exception { try { @@ -102,22 +107,24 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { String sess1 = "sess1"; MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build(); this.relay.handleMessage(conn1.message); + this.responseHandler.expect(conn1); String sess2 = "sess2"; MessageExchange conn2 = MessageExchangeBuilder.connect(sess2).build(); this.relay.handleMessage(conn2.message); + this.responseHandler.expect(conn2); + + this.responseHandler.awaitAndAssert(); String subs1 = "subs1"; String destination = "/topic/test"; MessageExchange subscribe = MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build(); - this.responseHandler.expect(subscribe); - this.relay.handleMessage(subscribe.message); + this.responseHandler.expect(subscribe); this.responseHandler.awaitAndAssert(); MessageExchange send = MessageExchangeBuilder.send(destination, "foo").andExpectMessage(sess1, subs1).build(); - this.responseHandler.reset(); this.responseHandler.expect(send); this.relay.handleMessage(send.message); @@ -129,7 +136,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { stopBrokerAndAwait(); - MessageExchange connect = MessageExchangeBuilder.connect("sess1").andExpectError().build(); + MessageExchange connect = MessageExchangeBuilder.connectWithError("sess1").build(); this.responseHandler.expect(connect); this.relay.handleMessage(connect.message); @@ -137,37 +144,31 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } @Test - public void brokerUnvailableErrorFrameOnSend() throws Exception { + public void brokerBecomingUnvailableTriggersErrorFrame() throws Exception { String sess1 = "sess1"; MessageExchange connect = MessageExchangeBuilder.connect(sess1).build(); + this.responseHandler.expect(connect); + this.relay.handleMessage(connect.message); - // TODO: expect CONNECTED - Thread.sleep(2000); + this.responseHandler.awaitAndAssert(); + + this.responseHandler.expect(MessageExchangeBuilder.error(sess1).build()); stopBrokerAndAwait(); - MessageExchange subscribe = MessageExchangeBuilder.subscribe(sess1, "s1", "/topic/a").andExpectError().build(); - this.responseHandler.expect(subscribe); - - this.relay.handleMessage(subscribe.message); this.responseHandler.awaitAndAssert(); } @Test public void brokerAvailabilityEvents() throws Exception { - // TODO: expect CONNECTED - Thread.sleep(2000); - - this.eventPublisher.expect(true, false); + this.eventPublisher.expect(true); + this.eventPublisher.awaitAndAssert(); + this.eventPublisher.expect(false); stopBrokerAndAwait(); - - // TODO: remove when stop is detecteded - this.relay.handleMessage(MessageExchangeBuilder.connect("sess1").build().message); - this.eventPublisher.awaitAndAssert(); } @@ -176,37 +177,55 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { String sess1 = "sess1"; MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build(); + this.responseHandler.expect(conn1); this.relay.handleMessage(conn1.message); + this.responseHandler.awaitAndAssert(); String subs1 = "subs1"; String destination = "/topic/test"; - MessageExchange subscribe = MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build(); + MessageExchange subscribe = + MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build(); this.responseHandler.expect(subscribe); this.relay.handleMessage(subscribe.message); this.responseHandler.awaitAndAssert(); + this.responseHandler.expect(MessageExchangeBuilder.error(sess1).build()); + stopBrokerAndAwait(); - // 1st message will see ERROR frame (broker shutdown is not but should be detected) - // 2nd message will be queued (a side effect of CONNECT/CONNECTED-buffering, likely to be removed) - // Finish this once the above changes are made. + this.responseHandler.awaitAndAssert(); + + this.eventPublisher.expect(true, false); + this.eventPublisher.awaitAndAssert(); + + this.eventPublisher.expect(true); + createAndStartBroker(); + this.eventPublisher.awaitAndAssert(); + + // TODO The event publisher assertions show that the broker's back up and the system relay session + // has reconnected. We need to decide what we want the reconnect behaviour to be for client relay + // sessions and add further message sending and assertions as appropriate. At the moment any client + // sessions will be closed and an ERROR from will be sent. + } + + @Test + public void disconnectClosesRelaySessionCleanly() throws Exception { + String sess1 = "sess1"; + MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build(); + this.responseHandler.expect(conn1); + this.relay.handleMessage(conn1.message); + this.responseHandler.awaitAndAssert(); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + headers.setSessionId(sess1); + + this.relay.handleMessage(MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build()); -/* MessageExchange send = MessageExchangeBuilder.send(destination, "foo").build(); - this.responseHandler.reset(); - this.relay.handleMessage(send.message); Thread.sleep(2000); - this.activeMQBroker.start(); - Thread.sleep(5000); - - send = MessageExchangeBuilder.send(destination, "foo").andExpectMessage(sess1, subs1).build(); - this.responseHandler.reset(); - this.responseHandler.expect(send); - this.relay.handleMessage(send.message); - + // Check that we have not received an ERROR as a result of the connection closing this.responseHandler.awaitAndAssert(); -*/ } @@ -234,58 +253,68 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { */ private static class ExpectationMatchingMessageHandler implements MessageHandler { + private final Object monitor = new Object(); + private final List expected; - private final List actual = new CopyOnWriteArrayList<>(); + private final List actual = new ArrayList<>(); - private final List> unexpected = new CopyOnWriteArrayList<>(); - - private CountDownLatch latch = new CountDownLatch(1); + private final List> unexpected = new ArrayList<>(); public ExpectationMatchingMessageHandler(MessageExchange... expected) { - this.expected = new CopyOnWriteArrayList<>(expected); + synchronized (this.monitor) { + this.expected = new CopyOnWriteArrayList<>(expected); + } } - public void expect(MessageExchange... expected) { - this.expected.addAll(Arrays.asList(expected)); + synchronized (this.monitor) { + this.expected.addAll(Arrays.asList(expected)); + } } public void awaitAndAssert() throws InterruptedException { - boolean result = this.latch.await(5000, TimeUnit.MILLISECONDS); - assertTrue(getAsString(), result && this.unexpected.isEmpty()); - } - - public void reset() { - this.latch = new CountDownLatch(1); - this.expected.clear(); - this.actual.clear(); - this.unexpected.clear(); + long endTime = System.currentTimeMillis() + 10000; + synchronized (this.monitor) { + while (!this.expected.isEmpty() && System.currentTimeMillis() < endTime) { + this.monitor.wait(500); + } + boolean result = this.expected.isEmpty(); + assertTrue(getAsString(), result && this.unexpected.isEmpty()); + } } @Override public void handleMessage(Message message) throws MessagingException { - for (MessageExchange exch : this.expected) { - if (exch.matchMessage(message)) { - if (exch.isDone()) { - this.expected.remove(exch); - this.actual.add(exch); - if (this.expected.isEmpty()) { - this.latch.countDown(); + if (StompHeaderAccessor.wrap(message).getCommand() != null) { + synchronized(this.monitor) { + for (MessageExchange exch : this.expected) { + if (exch.matchMessage(message)) { + if (exch.isDone()) { + this.expected.remove(exch); + this.actual.add(exch); + if (this.expected.isEmpty()) { + this.monitor.notifyAll(); + } + } + return; } } - return; + this.unexpected.add(message); } } - this.unexpected.add(message); } public String getAsString() { StringBuilder sb = new StringBuilder("\n"); - sb.append("INCOMPLETE:\n").append(this.expected).append("\n"); - sb.append("COMPLETE:\n").append(this.actual).append("\n"); - sb.append("UNMATCHED MESSAGES:\n").append(this.unexpected).append("\n"); + + synchronized (this.monitor) { + sb.append("INCOMPLETE:\n").append(this.expected).append("\n"); + sb.append("COMPLETE:\n").append(this.actual).append("\n"); + sb.append("UNMATCHED MESSAGES:\n").append(this.unexpected).append("\n"); + } + return sb.toString(); } } @@ -352,21 +381,28 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { this.headers = StompHeaderAccessor.wrap(message); } + public static MessageExchangeBuilder error(String sessionId) { + return new MessageExchangeBuilder(null).andExpectError(sessionId); + } public static MessageExchangeBuilder connect(String sessionId) { StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); headers.setSessionId(sessionId); + headers.setAcceptVersion("1.1,1.2"); Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); - return new MessageExchangeBuilder(message); + + MessageExchangeBuilder builder = new MessageExchangeBuilder(message); + builder.expected.add(new StompConnectedFrameMessageMatcher(sessionId)); + return builder; } - public static MessageExchangeBuilder subscribe(String sessionId, String subscriptionId, String destination) { - StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); + public static MessageExchangeBuilder connectWithError(String sessionId) { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); headers.setSessionId(sessionId); - headers.setSubscriptionId(subscriptionId); - headers.setDestination(destination); + headers.setAcceptVersion("1.1,1.2"); Message message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); - return new MessageExchangeBuilder(message); + MessageExchangeBuilder builder = new MessageExchangeBuilder(message); + return builder.andExpectError(); } public static MessageExchangeBuilder subscribeWithReceipt(String sessionId, String subscriptionId, @@ -514,35 +550,48 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { } } + private static class StompConnectedFrameMessageMatcher extends StompFrameMessageMatcher { + + + public StompConnectedFrameMessageMatcher(String sessionId) { + super(StompCommand.CONNECTED, sessionId); + } + + } + private static class ExpectationMatchingEventPublisher implements ApplicationEventPublisher { - private final List expected = new CopyOnWriteArrayList<>(); + private final List expected = new ArrayList<>(); - private final List actual = new CopyOnWriteArrayList<>(); + private final List actual = new ArrayList<>(); - private CountDownLatch latch = new CountDownLatch(1); + private final Object monitor = new Object(); public void expect(Boolean... expected) { - this.expected.addAll(Arrays.asList(expected)); + synchronized (this.monitor) { + this.expected.addAll(Arrays.asList(expected)); + } } public void awaitAndAssert() throws InterruptedException { - if (this.expected.size() == this.actual.size()) { + synchronized(this.monitor) { + long endTime = System.currentTimeMillis() + 5000; + while (this.expected.size() != this.actual.size() && System.currentTimeMillis() < endTime) { + this.monitor.wait(500); + } assertEquals(this.expected, this.actual); } - else { - assertTrue("Expected=" + this.expected + ", actual=" + this.actual, - this.latch.await(5, TimeUnit.SECONDS)); - } } @Override public void publishEvent(ApplicationEvent event) { if (event instanceof BrokerAvailabilityEvent) { - this.actual.add(((BrokerAvailabilityEvent) event).isBrokerAvailable()); - if (this.actual.size() == this.expected.size()) { - this.latch.countDown(); + synchronized(this.monitor) { + this.actual.add(((BrokerAvailabilityEvent) event).isBrokerAvailable()); + if (this.actual.size() == this.expected.size()) { + this.monitor.notifyAll(); + } } } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java new file mode 100644 index 00000000000..887c0f9e3ae --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java @@ -0,0 +1,212 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; + +import reactor.function.Consumer; +import reactor.function.Function; +import reactor.io.Buffer; + +import static org.junit.Assert.*; + + +/** + * + * @author awilkinson + */ +public class StompCodecTests { + + private final ArgumentCapturingConsumer> consumer = new ArgumentCapturingConsumer>(); + + private final Function> decoder = new StompCodec().decoder(consumer); + + @Test + public void decodeFrameWithCrLfEols() { + Message frame = decode("DISCONNECT\r\n\r\n\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.DISCONNECT, headers.getCommand()); + assertEquals(0, headers.toStompHeaderMap().size()); + assertEquals(0, frame.getPayload().length); + } + + @Test + public void decodeFrameWithNoHeadersAndNoBody() { + Message frame = decode("DISCONNECT\n\n\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.DISCONNECT, headers.getCommand()); + assertEquals(0, headers.toStompHeaderMap().size()); + assertEquals(0, frame.getPayload().length); + } + + @Test + public void decodeFrameWithNoBody() { + String accept = "accept-version:1.1\n"; + String host = "host:github.org\n"; + + Message frame = decode("CONNECT\n" + accept + host + "\n\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.CONNECT, headers.getCommand()); + + assertEquals(2, headers.toStompHeaderMap().size()); + assertEquals("1.1", headers.getFirstNativeHeader("accept-version")); + assertEquals("github.org", headers.getHost()); + + assertEquals(0, frame.getPayload().length); + } + + @Test + public void decodeFrame() throws UnsupportedEncodingException { + Message frame = decode("SEND\ndestination:test\n\nThe body of the message\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.SEND, headers.getCommand()); + + assertEquals(1, headers.toStompHeaderMap().size()); + assertEquals("test", headers.getDestination()); + + String bodyText = new String(frame.getPayload()); + assertEquals("The body of the message", bodyText); + } + + @Test + public void decodeFrameWithContentLength() { + Message frame = decode("SEND\ncontent-length:23\n\nThe body of the message\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.SEND, headers.getCommand()); + + assertEquals(1, headers.toStompHeaderMap().size()); + assertEquals(Integer.valueOf(23), headers.getContentLength()); + + String bodyText = new String(frame.getPayload()); + assertEquals("The body of the message", bodyText); + } + + @Test + public void decodeFrameWithNullOctectsInTheBody() { + Message frame = decode("SEND\ncontent-length:23\n\nThe b\0dy \0f the message\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.SEND, headers.getCommand()); + + assertEquals(1, headers.toStompHeaderMap().size()); + assertEquals(Integer.valueOf(23), headers.getContentLength()); + + String bodyText = new String(frame.getPayload()); + assertEquals("The b\0dy \0f the message", bodyText); + } + + @Test + public void decodeFrameWithEscapedHeaders() { + Message frame = decode("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertEquals(StompCommand.DISCONNECT, headers.getCommand()); + + assertEquals(1, headers.toStompHeaderMap().size()); + assertEquals("alpha:bravo\r\n\\", headers.getFirstNativeHeader("a:\r\n\\b")); + } + + @Test + public void decodeMultipleFramesFromSameBuffer() { + String frame1 = "SEND\ndestination:test\n\nThe body of the message\0"; + String frame2 = "DISCONNECT\n\n\0"; + + Buffer buffer = Buffer.wrap(frame1 + frame2); + + final List> messages = new ArrayList>(); + new StompCodec().decoder(new Consumer>() { + @Override + public void accept(Message message) { + messages.add(message); + } + }).apply(buffer); + + assertEquals(2, messages.size()); + assertEquals(StompCommand.SEND, StompHeaderAccessor.wrap(messages.get(0)).getCommand()); + assertEquals(StompCommand.DISCONNECT, StompHeaderAccessor.wrap(messages.get(1)).getCommand()); + } + + @Test + public void encodeFrameWithNoHeadersAndNoBody() { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + + Message frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); + + assertEquals("DISCONNECT\n\n\0", new StompCodec().encoder().apply(frame).asString()); + } + + @Test + public void encodeFrameWithHeaders() { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + headers.setAcceptVersion("1.2"); + headers.setHost("github.org"); + + Message frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); + + String frameString = new StompCodec().encoder().apply(frame).asString(); + + assertTrue(frameString.equals("CONNECT\naccept-version:1.2\nhost:github.org\n\n\0") || + frameString.equals("CONNECT\nhost:github.org\naccept-version:1.2\n\n\0")); + } + + @Test + public void encodeFrameWithHeadersThatShouldBeEscaped() { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + headers.addNativeHeader("a:\r\n\\b", "alpha:bravo\r\n\\"); + + Message frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build(); + + assertEquals("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0", new StompCodec().encoder().apply(frame).asString()); + } + + @Test + public void encodeFrameWithHeadersBody() { + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.addNativeHeader("a", "alpha"); + + Message frame = MessageBuilder.withPayloadAndHeaders("Message body".getBytes(), headers).build(); + + assertEquals("SEND\na:alpha\ncontent-length:12\n\nMessage body\0", new StompCodec().encoder().apply(frame).asString()); + } + + private Message decode(String stompFrame) { + this.decoder.apply(Buffer.wrap(stompFrame)); + return consumer.arguments.get(0); + } + + private static final class ArgumentCapturingConsumer implements Consumer { + + private final List arguments = new ArrayList(); + + @Override + public void accept(T t) { + arguments.add(t); + } + + } +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java deleted file mode 100644 index f2716271378..00000000000 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.messaging.simp.stomp; - -import java.util.Collections; -import java.util.Map; - -import org.junit.Before; -import org.junit.Test; -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageHeaders; -import org.springframework.messaging.simp.SimpMessageHeaderAccessor; -import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.web.socket.TextMessage; - -import static org.junit.Assert.*; - -/** - * @author Gary Russell - * @author Rossen Stoyanchev - */ -public class StompMessageConverterTests { - - private StompMessageConverter converter; - - - @Before - public void setup() { - this.converter = new StompMessageConverter(); - } - - @Test - public void connectFrame() throws Exception { - - String accept = "accept-version:1.1"; - String host = "host:github.org"; - - TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT) - .headers(accept, host).build(); - - @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(textMessage.getPayload()); - - assertEquals(0, message.getPayload().length); - - MessageHeaders headers = message.getHeaders(); - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - Map map = stompHeaders.toMap(); - assertEquals(5, map.size()); - assertNotNull(stompHeaders.getId()); - assertNotNull(stompHeaders.getTimestamp()); - assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType()); - assertEquals(StompCommand.CONNECT, stompHeaders.getCommand()); - assertNotNull(map.get(SimpMessageHeaderAccessor.NATIVE_HEADERS)); - - assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); - assertEquals("github.org", stompHeaders.getHost()); - - assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType()); - assertEquals(StompCommand.CONNECT, stompHeaders.getCommand()); - assertNotNull(headers.get(MessageHeaders.ID)); - assertNotNull(headers.get(MessageHeaders.TIMESTAMP)); - - String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); - - assertEquals("CONNECT\n", convertedBack.substring(0,8)); - assertTrue(convertedBack.contains(accept)); - assertTrue(convertedBack.contains(host)); - } - - @Test - public void connectWithEscapes() throws Exception { - - String accept = "accept-version:1.1"; - String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org"; - - TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT) - .headers(accept, host).build(); - - @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(textMessage.getPayload()); - - assertEquals(0, message.getPayload().length); - - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); - assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.toNativeHeaderMap().get("ho:\ns\rt").get(0)); - - String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); - - assertEquals("CONNECT\n", convertedBack.substring(0,8)); - assertTrue(convertedBack.contains(accept)); - assertTrue(convertedBack.contains(host)); - } - - @Test - public void connectCR12() throws Exception { - - String accept = "accept-version:1.2\n"; - String host = "host:github.org\n"; - String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(test.getBytes("UTF-8")); - - assertEquals(0, message.getPayload().length); - - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion()); - assertEquals("github.org", stompHeaders.getHost()); - - String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); - - assertEquals("CONNECT\n", convertedBack.substring(0,8)); - assertTrue(convertedBack.contains(accept)); - assertTrue(convertedBack.contains(host)); - } - - @Test - public void connectWithEscapesAndCR12() throws Exception { - - String accept = "accept-version:1.1\n"; - String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n"; - String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n"; - @SuppressWarnings("unchecked") - Message message = (Message) this.converter.toMessage(test.getBytes("UTF-8")); - - assertEquals(0, message.getPayload().length); - - StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); - assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); - assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.toNativeHeaderMap().get("ho:\ns\rt").get(0)); - - String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); - - assertEquals("CONNECT\n", convertedBack.substring(0,8)); - assertTrue(convertedBack.contains(accept)); - assertTrue(convertedBack.contains(host)); - } - -} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java index 62d16318fca..78311b36d87 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java @@ -16,6 +16,7 @@ package org.springframework.messaging.simp.stomp; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.HashSet; @@ -60,7 +61,33 @@ public class StompProtocolHandlerTests { } @Test - public void handleConnect() { + public void connectedResponseIsSentWhenHandlingConnect() { + this.stompHandler.setHandleConnect(true); + + TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers( + "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build(); + + this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel); + + verifyNoMoreInteractions(this.channel); + + // Check CONNECTED reply + + assertEquals(1, this.session.getSentMessages().size()); + textMessage = (TextMessage) this.session.getSentMessages().get(0); + Message message = new StompDecoder().decode(ByteBuffer.wrap(textMessage.getPayload().getBytes())); + StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message); + + assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand()); + assertEquals("1.1", replyHeaders.getVersion()); + assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat()); + assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0)); + assertEquals("s1", replyHeaders.getNativeHeader("queue-suffix").get(0)); + } + + @Test + public void connectIsForwardedWhenNotHandlingConnect() { + this.stompHandler.setHandleConnect(false); TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers( "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build(); @@ -80,18 +107,7 @@ public class StompProtocolHandlerTests { assertArrayEquals(new long[] {10000, 10000}, headers.getHeartbeat()); assertEquals(new HashSet<>(Arrays.asList("1.1","1.0")), headers.getAcceptVersion()); - // Check CONNECTED reply - - assertEquals(1, this.session.getSentMessages().size()); - textMessage = (TextMessage) this.session.getSentMessages().get(0); - Message message = new StompMessageConverter().toMessage(textMessage.getPayload()); - StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message); - - assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand()); - assertEquals("1.1", replyHeaders.getVersion()); - assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat()); - assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0)); - assertEquals("s1", replyHeaders.getNativeHeader("queue-suffix").get(0)); + assertEquals(0, this.session.getSentMessages().size()); } }