Merge multiple pull requests from wilkinsona

* broker-relay:
  Polish
  Improve handling of missed heartbeats
  Upgrade to Reactor 1.0.0.M3
  Add heart-beat support to STOMP broker relay
  Remove CONNECT-related message buffer from STOMP relay
  Add StompCodec
This commit is contained in:
Rossen Stoyanchev 2013-09-26 16:09:01 -04:00
commit d5dfd1b4ad
16 changed files with 956 additions and 617 deletions

View File

@ -71,6 +71,7 @@ configure(allprojects) { project ->
maven { url "https://repository.apache.org/content/repositories/releases" } // tomcat 8 RC3
maven { url "https://repository.apache.org/content/repositories/snapshots" } // tomcat-websocket-* snapshots
maven { url "https://maven.java.net/content/repositories/releases" } // javax.websocket, tyrus
maven { url 'http://repo.springsource.org/libs-snapshot' } // reactor
}
dependencies {
@ -352,8 +353,8 @@ project("spring-messaging") {
optional(project(":spring-websocket"))
optional(project(":spring-webmvc"))
optional("com.fasterxml.jackson.core:jackson-databind:2.2.0")
optional("org.projectreactor:reactor-core:1.0.0.M2")
optional("org.projectreactor:reactor-tcp:1.0.0.M2")
optional("org.projectreactor:reactor-core:1.0.0.M3")
optional("org.projectreactor:reactor-tcp:1.0.0.M3")
optional("com.lmax:disruptor:3.1.1")
optional("org.eclipse.jetty.websocket:websocket-server:9.0.5.v20130815")
optional("org.eclipse.jetty.websocket:websocket-client:9.0.5.v20130815")

View File

@ -23,6 +23,7 @@ import org.springframework.messaging.core.MessageSendingOperations;
import org.springframework.messaging.handler.annotation.SendTo;
import org.springframework.messaging.handler.method.HandlerMethodReturnValueHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.annotation.SubscribeEvent;
import org.springframework.messaging.support.MessageBuilder;
@ -97,6 +98,7 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setSessionId(this.sessionId);
headers.setSubscriptionId(this.subscriptionId);
headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
return MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
}
}

View File

@ -54,7 +54,8 @@ public class ServletStompEndpointRegistry implements StompEndpointRegistry {
public ServletStompEndpointRegistry(WebSocketHandler webSocketHandler,
MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler) {
MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler,
boolean handleConnect) {
Assert.notNull(webSocketHandler);
Assert.notNull(userQueueSuffixResolver);
@ -63,6 +64,7 @@ public class ServletStompEndpointRegistry implements StompEndpointRegistry {
this.subProtocolWebSocketHandler = findSubProtocolWebSocketHandler(webSocketHandler);
this.stompHandler = new StompProtocolHandler();
this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver);
this.stompHandler.setHandleConnect(handleConnect);
this.sockJsScheduler = defaultSockJsTaskScheduler;
}

View File

@ -57,8 +57,10 @@ public abstract class WebSocketMessageBrokerConfigurationSupport {
@Bean
public HandlerMapping brokerWebSocketHandlerMapping() {
ServletStompEndpointRegistry registry = new ServletStompEndpointRegistry(
subProtocolWebSocketHandler(), userQueueSuffixResolver(), brokerDefaultSockJsTaskScheduler());
boolean brokerRelayConfigured = getMessageBrokerConfigurer().getStompBrokerRelay() != null;
ServletStompEndpointRegistry registry = new ServletStompEndpointRegistry(subProtocolWebSocketHandler(),
userQueueSuffixResolver(), brokerDefaultSockJsTaskScheduler(), !brokerRelayConfigured);
registerStompEndpoints(registry);
AbstractHandlerMapping hm = registry.getHandlerMapping();
hm.setOrder(1);

View File

@ -17,14 +17,9 @@
package org.springframework.messaging.simp.stomp;
import java.net.InetSocketAddress;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicReference;
import org.springframework.messaging.Message;
@ -34,7 +29,6 @@ import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.AbstractBrokerMessageHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import reactor.core.Environment;
import reactor.core.composable.Composable;
@ -45,8 +39,6 @@ import reactor.function.Consumer;
import reactor.tcp.Reconnect;
import reactor.tcp.TcpClient;
import reactor.tcp.TcpConnection;
import reactor.tcp.encoding.DelimitedCodec;
import reactor.tcp.encoding.StandardCodecs;
import reactor.tcp.netty.NettyTcpClient;
import reactor.tcp.spec.TcpClientSpec;
import reactor.tuple.Tuple;
@ -74,13 +66,11 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
private String systemPasscode = "guest";
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
private Environment environment;
private TcpClient<String, String> tcpClient;
private TcpClient<Message<byte[]>, Message<byte[]>> tcpClient;
private final Map<String, RelaySession> relaySessions = new ConcurrentHashMap<String, RelaySession>();
private final Map<String, StompRelaySession> relaySessions = new ConcurrentHashMap<String, StompRelaySession>();
/**
@ -159,16 +149,16 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
@Override
protected void startInternal() {
this.environment = new Environment();
this.tcpClient = new TcpClientSpec<String, String>(NettyTcpClient.class)
this.tcpClient = new TcpClientSpec<Message<byte[]>, Message<byte[]>>(NettyTcpClient.class)
.env(this.environment)
.codec(new DelimitedCodec<String, String>((byte) 0, true, StandardCodecs.STRING_CODEC))
.codec(new StompCodec())
.connect(this.relayHost, this.relayPort)
.get();
if (logger.isDebugEnabled()) {
logger.debug("Initializing \"system\" TCP connection");
}
SystemRelaySession session = new SystemRelaySession();
SystemStompRelaySession session = new SystemStompRelaySession();
this.relaySessions.put(session.getId(), session);
session.connect();
}
@ -199,35 +189,31 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
SimpMessageType messageType = headers.getMessageType();
if (SimpMessageType.MESSAGE.equals(messageType)) {
sessionId = (sessionId == null) ? SystemRelaySession.ID : sessionId;
sessionId = (sessionId == null) ? SystemStompRelaySession.ID : sessionId;
headers.setSessionId(sessionId);
command = (command == null) ? StompCommand.SEND : command;
headers.setCommandIfNotSet(command);
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
}
if (headers.getCommand() == null) {
logger.error("No STOMP command, ignoring message: " + message);
return;
}
if (sessionId == null) {
logger.error("No sessionId, ignoring message: " + message);
return;
}
if (command.requiresDestination() && !checkDestinationPrefix(destination)) {
if (command != null && command.requiresDestination() && !checkDestinationPrefix(destination)) {
return;
}
try {
if (SimpMessageType.CONNECT.equals(messageType)) {
headers.setHeartbeat(0, 0);
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
RelaySession session = new RelaySession(sessionId);
StompRelaySession session = new StompRelaySession(sessionId);
this.relaySessions.put(sessionId, session);
session.connect(message);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
RelaySession session = this.relaySessions.remove(sessionId);
StompRelaySession session = this.relaySessions.remove(sessionId);
if (session == null) {
if (logger.isTraceEnabled()) {
logger.trace("Session already removed, sessionId=" + sessionId);
@ -237,7 +223,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
session.forward(message);
}
else {
RelaySession session = this.relaySessions.get(sessionId);
StompRelaySession session = this.relaySessions.get(sessionId);
if (session == null) {
logger.warn("Session id=" + sessionId + " not found. Ignoring message: " + message);
return;
@ -251,18 +237,14 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
}
private class RelaySession {
private class StompRelaySession {
private final String sessionId;
private final BlockingQueue<Message<?>> messageQueue = new LinkedBlockingQueue<Message<?>>(50);
private volatile StompConnection stompConnection = new StompConnection();
private final Object monitor = new Object();
private RelaySession(String sessionId) {
private StompRelaySession(String sessionId) {
Assert.notNull(sessionId, "sessionId is required");
this.sessionId = sessionId;
}
@ -275,14 +257,14 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
public void connect(final Message<?> connectMessage) {
Assert.notNull(connectMessage, "connectMessage is required");
Composable<TcpConnection<String, String>> connectionComposable = openTcpConnection();
connectionComposable.consume(new Consumer<TcpConnection<String, String>>() {
Composable<TcpConnection<Message<byte[]>, Message<byte[]>>> promise = initConnection();
promise.consume(new Consumer<TcpConnection<Message<byte[]>, Message<byte[]>>>() {
@Override
public void accept(TcpConnection<String, String> connection) {
handleTcpConnection(connection, connectMessage);
public void accept(TcpConnection<Message<byte[]>, Message<byte[]>> connection) {
handleConnectionReady(connection, connectMessage);
}
});
connectionComposable.when(Throwable.class, new Consumer<Throwable>() {
promise.when(Throwable.class, new Consumer<Throwable>() {
@Override
public void accept(Throwable ex) {
relaySessions.remove(sessionId);
@ -291,41 +273,44 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
});
}
protected Composable<TcpConnection<String, String>> openTcpConnection() {
protected Composable<TcpConnection<Message<byte[]>, Message<byte[]>>> initConnection() {
return tcpClient.open();
}
protected void handleTcpConnection(TcpConnection<String, String> tcpConn, final Message<?> connectMessage) {
protected void handleConnectionReady(
TcpConnection<Message<byte[]>, Message<byte[]>> tcpConn, final Message<?> connectMessage) {
this.stompConnection.setTcpConnection(tcpConn);
tcpConn.in().consume(new Consumer<String>() {
tcpConn.on().close(new Runnable() {
@Override
public void accept(String message) {
public void run() {
connectionClosed();
}
});
tcpConn.in().consume(new Consumer<Message<byte[]>>() {
@Override
public void accept(Message<byte[]> message) {
readStompFrame(message);
}
});
forwardInternal(tcpConn, connectMessage);
}
private void readStompFrame(String stompFrame) {
// heartbeat
if (StringUtils.isEmpty(stompFrame)) {
return;
protected void connectionClosed() {
relaySessions.remove(this.sessionId);
if (this.stompConnection.isReady()) {
sendError("Lost connection to the broker");
}
}
Message<?> message = stompMessageConverter.toMessage(stompFrame);
private void readStompFrame(Message<byte[]> message) {
if (logger.isTraceEnabled()) {
logger.trace("Reading message " + message);
}
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (StompCommand.CONNECTED == headers.getCommand()) {
synchronized(this.monitor) {
this.stompConnection.setReady();
publishBrokerAvailableEvent();
flushMessages();
}
return;
connected(headers, this.stompConnection);
}
headers.setSessionId(this.sessionId);
@ -333,12 +318,21 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
sendMessageToClient(message);
}
protected void connected(StompHeaderAccessor headers, StompConnection stompConnection) {
this.stompConnection.setReady();
publishBrokerAvailableEvent();
}
private void handleTcpClientFailure(String message, Throwable ex) {
if (logger.isErrorEnabled()) {
logger.error(message + ", sessionId=" + this.sessionId, ex);
}
disconnected(message);
}
protected void disconnected(String errorMessage) {
this.stompConnection.setDisconnected();
sendError(message);
sendError(errorMessage);
publishBrokerUnavailableEvent();
}
@ -355,47 +349,33 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
}
public void forward(Message<?> message) {
if (!this.stompConnection.isReady()) {
synchronized(this.monitor) {
if (!this.stompConnection.isReady()) {
this.messageQueue.add(message);
if (logger.isTraceEnabled()) {
logger.trace("Not connected, message queued. Queue size=" + this.messageQueue.size());
}
return;
}
}
}
if (this.messageQueue.isEmpty()) {
forwardInternal(message);
}
else {
this.messageQueue.add(message);
flushMessages();
}
}
private boolean forwardInternal(final Message<?> message) {
TcpConnection<String, String> tcpConnection = this.stompConnection.getReadyConnection();
TcpConnection<Message<byte[]>, Message<byte[]>> tcpConnection = this.stompConnection.getReadyConnection();
if (tcpConnection == null) {
return false;
logger.warn("Connection to STOMP broker is not active, discarding message: " + message);
return;
}
return forwardInternal(tcpConnection, message);
forwardInternal(tcpConnection, message);
}
private boolean forwardInternal(TcpConnection<String, String> tcpConnection, final Message<?> message) {
private boolean forwardInternal(
TcpConnection<Message<byte[]>, Message<byte[]>> tcpConnection, Message<?> message) {
Assert.isInstanceOf(byte[].class, message.getPayload(), "Message's payload must be a byte[]");
@SuppressWarnings("unchecked")
Message<byte[]> byteMessage = (Message<byte[]>) message;
if (logger.isTraceEnabled()) {
logger.trace("Forwarding to STOMP broker, message: " + message);
}
byte[] bytes = stompMessageConverter.fromMessage(message);
String payload = new String(bytes, Charset.forName("UTF-8"));
StompCommand command = StompHeaderAccessor.wrap(message).getCommand();
if (command == StompCommand.DISCONNECT) {
this.stompConnection.setDisconnected();
}
final Deferred<Boolean, Promise<Boolean>> deferred = new DeferredPromiseSpec<Boolean>().get();
tcpConnection.send(payload, new Consumer<Boolean>() {
tcpConnection.send(byteMessage, new Consumer<Boolean>() {
@Override
public void accept(Boolean success) {
deferred.accept(success);
@ -409,7 +389,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
handleTcpClientFailure("Timed out waiting for message to be forwarded to the broker", null);
}
else if (!success) {
if (StompHeaderAccessor.wrap(message).getCommand() != StompCommand.DISCONNECT) {
if (command != StompCommand.DISCONNECT) {
handleTcpClientFailure("Failed to forward message to the broker", null);
}
}
@ -420,32 +400,26 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
}
return (success != null) ? success : false;
}
private void flushMessages() {
List<Message<?>> messages = new ArrayList<Message<?>>();
this.messageQueue.drainTo(messages);
for (Message<?> message : messages) {
if (!forwardInternal(message)) {
return;
}
}
}
}
private static class StompConnection {
private volatile TcpConnection<String, String> connection;
private volatile TcpConnection<Message<byte[]>, Message<byte[]>> connection;
private AtomicReference<TcpConnection<String, String>> readyConnection =
new AtomicReference<TcpConnection<String, String>>();
private AtomicReference<TcpConnection<Message<byte[]>, Message<byte[]>>> readyConnection =
new AtomicReference<TcpConnection<Message<byte[]>, Message<byte[]>>>();
public void setTcpConnection(TcpConnection<String, String> connection) {
public void setTcpConnection(TcpConnection<Message<byte[]>, Message<byte[]>> connection) {
Assert.notNull(connection, "connection must not be null");
this.connection = connection;
}
public TcpConnection<String, String> getReadyConnection() {
/**
* Return the underlying {@link TcpConnection} but only after the CONNECTED STOMP
* frame is received.
*/
public TcpConnection<Message<byte[]>, Message<byte[]>> getReadyConnection() {
return this.readyConnection.get();
}
@ -459,7 +433,12 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
public void setDisconnected() {
this.readyConnection.set(null);
this.connection = null;
TcpConnection<Message<byte[]>, Message<byte[]>> localConnection = this.connection;
if (localConnection != null) {
localConnection.close();
this.connection = null;
}
}
@Override
@ -468,12 +447,20 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
}
}
private class SystemRelaySession extends RelaySession {
private class SystemStompRelaySession extends StompRelaySession {
private static final long HEARTBEAT_RECEIVE_MULTIPLIER = 3;
private static final long HEARTBEAT_SEND_INTERVAL = 10000;
private static final long HEARTBEAT_RECEIVE_INTERVAL = 10000;
public static final String ID = "stompRelaySystemSessionId";
private final byte[] heartbeatPayload = new byte[] {'\n'};
public SystemRelaySession() {
public SystemStompRelaySession() {
super(ID);
}
@ -482,13 +469,13 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
headers.setAcceptVersion("1.1,1.2");
headers.setLogin(systemLogin);
headers.setPasscode(systemPasscode);
headers.setHeartbeat(0,0);
headers.setHeartbeat(HEARTBEAT_SEND_INTERVAL, HEARTBEAT_RECEIVE_INTERVAL);
Message<?> connectMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
super.connect(connectMessage);
}
@Override
protected Composable<TcpConnection<String, String>> openTcpConnection() {
protected Composable<TcpConnection<Message<byte[]>, Message<byte[]>>> initConnection() {
return tcpClient.open(new Reconnect() {
@Override
public Tuple2<InetSocketAddress, Long> reconnect(InetSocketAddress address, int attempt) {
@ -497,6 +484,47 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
});
}
@Override
protected void connectionClosed() {
publishBrokerUnavailableEvent();
}
@Override
protected void connected(StompHeaderAccessor headers, final StompConnection stompConnection) {
long brokerReceiveInterval = headers.getHeartbeat()[1];
if (HEARTBEAT_SEND_INTERVAL > 0 && brokerReceiveInterval > 0) {
long interval = Math.max(HEARTBEAT_SEND_INTERVAL, brokerReceiveInterval);
stompConnection.connection.on().writeIdle(interval, new Runnable() {
@Override
public void run() {
TcpConnection<Message<byte[]>, Message<byte[]>> connection = stompConnection.connection;
if (connection != null) {
connection.send(MessageBuilder.withPayload(heartbeatPayload).build());
}
}
});
}
long brokerSendInterval = headers.getHeartbeat()[0];
if (HEARTBEAT_RECEIVE_INTERVAL > 0 && brokerSendInterval > 0) {
final long interval =
Math.max(HEARTBEAT_RECEIVE_INTERVAL, brokerSendInterval) * HEARTBEAT_RECEIVE_MULTIPLIER;
stompConnection.connection.on().readIdle(interval, new Runnable() {
@Override
public void run() {
String message = "Broker hearbeat missed: connection idle for more than " + interval + "ms";
logger.warn(message);
disconnected(message);
}
});
}
super.connected(headers, stompConnection);
}
@Override
protected void sendMessageToClient(Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);

View File

@ -0,0 +1,68 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import org.springframework.messaging.Message;
import reactor.function.Consumer;
import reactor.function.Function;
import reactor.io.Buffer;
import reactor.tcp.encoding.Codec;
/**
* A Reactor TCP {@link Codec} for sending and receiving STOMP messages
*
* @author Andy Wilkinson
* @since 4.0
*/
public class StompCodec implements Codec<Buffer, Message<byte[]>, Message<byte[]>> {
private static final StompDecoder DECODER = new StompDecoder();
private static final Function<Message<byte[]>, Buffer> ENCODER_FUNCTION = new Function<Message<byte[]>, Buffer>() {
private final StompEncoder encoder = new StompEncoder();
@Override
public Buffer apply(Message<byte[]> message) {
return Buffer.wrap(this.encoder.encode(message));
}
};
@Override
public Function<Buffer, Message<byte[]>> decoder(final Consumer<Message<byte[]>> next) {
return new Function<Buffer, Message<byte[]>>() {
@Override
public Message<byte[]> apply(Buffer buffer) {
while (buffer.remaining() > 0) {
Message<byte[]> message = DECODER.decode(buffer.byteBuffer());
if (message != null) {
next.accept(message);
}
}
return null;
}
};
}
@Override
public Function<Message<byte[]>, Buffer> encoder() {
return ENCODER_FUNCTION;
}
}

View File

@ -0,0 +1,171 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* A decoder for STOMP frames
*
* @author awilkinson
* @since 4.0
*/
public class StompDecoder {
private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
private static final byte[] HEARTBEAT_PAYLOAD = new byte[] {'\n'};
private final Log logger = LogFactory.getLog(StompDecoder.class);
/**
* Decodes a STOMP frame in the given {@code buffer} into a {@link Message}.
*
* @param buffer The buffer to decode the frame from
* @return The decoded message
*/
public Message<byte[]> decode(ByteBuffer buffer) {
skipLeadingEol(buffer);
String command = readCommand(buffer);
if (command.length() > 0) {
MultiValueMap<String, String> headers = readHeaders(buffer);
byte[] payload = readPayload(buffer, headers);
Message<byte[]> decodedMessage = MessageBuilder.withPayloadAndHeaders(payload,
StompHeaderAccessor.create(StompCommand.valueOf(command), headers)).build();
if (logger.isTraceEnabled()) {
logger.trace("Decoded " + decodedMessage);
}
return decodedMessage;
}
else {
if (logger.isTraceEnabled()) {
logger.trace("Decoded heartbeat");
}
return MessageBuilder.withPayload(HEARTBEAT_PAYLOAD).build();
}
}
private String readCommand(ByteBuffer buffer) {
ByteArrayOutputStream command = new ByteArrayOutputStream();
while (buffer.remaining() > 0 && !isEol(buffer)) {
command.write(buffer.get());
}
return new String(command.toByteArray(), UTF8_CHARSET);
}
private MultiValueMap<String, String> readHeaders(ByteBuffer buffer) {
MultiValueMap<String, String> headers = new LinkedMultiValueMap<String, String>();
while (true) {
ByteArrayOutputStream headerStream = new ByteArrayOutputStream();
while (buffer.remaining() > 0 && !isEol(buffer)) {
headerStream.write(buffer.get());
}
if (headerStream.size() > 0) {
String header = new String(headerStream.toByteArray(), UTF8_CHARSET);
int colonIndex = header.indexOf(':');
if (colonIndex <= 0 || colonIndex == header.length() - 1) {
throw new StompConversionException(
"Illegal header: '" + header + "'. A header must be of the form <name>:<value");
}
else {
String headerName = unescape(header.substring(0, colonIndex));
String headerValue = unescape(header.substring(colonIndex + 1));
headers.add(headerName, headerValue);
}
}
else {
break;
}
}
return headers;
}
private String unescape(String input) {
return input.replaceAll("\\\\n", "\n")
.replaceAll("\\\\r", "\r")
.replaceAll("\\\\c", ":")
.replaceAll("\\\\\\\\", "\\\\");
}
private byte[] readPayload(ByteBuffer buffer, MultiValueMap<String, String> headers) {
String contentLengthString = headers.getFirst("content-length");
if (contentLengthString != null) {
int contentLength = Integer.valueOf(contentLengthString);
byte[] payload = new byte[contentLength];
buffer.get(payload);
if (buffer.remaining() < 1 || buffer.get() != 0) {
throw new StompConversionException("Frame must be terminated with a null octect");
}
return payload;
}
else {
ByteArrayOutputStream payload = new ByteArrayOutputStream();
while (buffer.remaining() > 0) {
byte b = buffer.get();
if (b == 0) {
return payload.toByteArray();
}
else {
payload.write(b);
}
}
}
throw new StompConversionException("Frame must be terminated with a null octect");
}
private void skipLeadingEol(ByteBuffer buffer) {
while (true) {
if (!isEol(buffer)) {
break;
}
}
}
private boolean isEol(ByteBuffer buffer) {
if (buffer.remaining() > 0) {
byte b = buffer.get();
if (b == '\n') {
return true;
}
else if (b == '\r') {
if (buffer.remaining() > 0 && buffer.get() == '\n') {
return true;
}
else {
throw new StompConversionException("'\\r' must be followed by '\\n'");
}
}
buffer.position(buffer.position() - 1);
}
return false;
}
}

View File

@ -0,0 +1,129 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map.Entry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
/**
* An encoder for STOMP frames
*
* @author Andy Wilkinson
* @since 4.0
*/
public final class StompEncoder {
private static final byte LF = '\n';
private static final byte COLON = ':';
private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
private final Log logger = LogFactory.getLog(StompEncoder.class);
/**
* Encodes the given STOMP {@code message} into a {@code byte[]}
*
* @param message The message to encode
*
* @return The encoded message
*/
public byte[] encode(Message<byte[]> message) {
try {
if (logger.isTraceEnabled()) {
logger.trace("Encoding " + message);
}
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream output = new DataOutputStream(baos);
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (isHeartbeat(headers)) {
output.write(message.getPayload());
} else {
writeCommand(headers, output);
writeHeaders(headers, message, output);
output.write(LF);
writeBody(message, output);
output.write((byte)0);
}
return baos.toByteArray();
}
catch (IOException e) {
throw new StompConversionException("Failed to encode STOMP frame", e);
}
}
private boolean isHeartbeat(StompHeaderAccessor headers) {
return headers.getCommand() == null;
}
private void writeCommand(StompHeaderAccessor headers, DataOutputStream output) throws IOException {
output.write(headers.getCommand().toString().getBytes(UTF8_CHARSET));
output.write(LF);
}
private void writeHeaders(StompHeaderAccessor headers, Message<byte[]> message, DataOutputStream output)
throws IOException {
for (Entry<String, List<String>> entry : headers.toStompHeaderMap().entrySet()) {
byte[] key = getUtf8BytesEscapingIfNecessary(entry.getKey(), headers);
for (String value : entry.getValue()) {
output.write(key);
output.write(COLON);
output.write(getUtf8BytesEscapingIfNecessary(value, headers));
output.write(LF);
}
}
if (headers.getCommand() == StompCommand.SEND ||
headers.getCommand() == StompCommand.MESSAGE ||
headers.getCommand() == StompCommand.ERROR) {
output.write("content-length:".getBytes(UTF8_CHARSET));
output.write(Integer.toString(message.getPayload().length).getBytes(UTF8_CHARSET));
output.write(LF);
}
}
private void writeBody(Message<byte[]> message, DataOutputStream output) throws IOException {
output.write(message.getPayload());
}
private byte[] getUtf8BytesEscapingIfNecessary(String input, StompHeaderAccessor headers) {
if (headers.getCommand() != StompCommand.CONNECT && headers.getCommand() != StompCommand.CONNECTED) {
return escape(input).getBytes(UTF8_CHARSET);
}
else {
return input.getBytes(UTF8_CHARSET);
}
}
private String escape(String input) {
return input.replaceAll("\\\\", "\\\\\\\\")
.replaceAll(":", "\\\\c")
.replaceAll("\n", "\\\\n")
.replaceAll("\r", "\\\\r");
}
}

View File

@ -82,6 +82,8 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
public static final String STOMP_HEARTBEAT_HEADER = "heart-beat";
private static final long[] DEFAULT_HEARTBEAT = new long[] {0, 0};
// Other header names
@ -185,7 +187,7 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
result.put(STOMP_CONTENT_TYPE_HEADER, Arrays.asList(contentType.toString()));
}
if (getCommand().requiresSubscriptionId()) {
if (getCommand() != null && getCommand().requiresSubscriptionId()) {
String subscriptionId = getSubscriptionId();
if (subscriptionId != null) {
String name = StompCommand.MESSAGE.equals(getCommand()) ? STOMP_SUBSCRIPTION_HEADER : STOMP_ID_HEADER;
@ -252,7 +254,7 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {
public long[] getHeartbeat() {
String rawValue = getFirstNativeHeader(STOMP_HEARTBEAT_HEADER);
if (!StringUtils.hasText(rawValue)) {
return null;
return Arrays.copyOf(DEFAULT_HEARTBEAT, 2);
}
String[] rawValues = StringUtils.commaDelimitedListToStringArray(rawValue);
return new long[] { Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])};

View File

@ -1,231 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map.Entry;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* @author Gary Russell
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompMessageConverter {
private static final Charset STOMP_CHARSET = Charset.forName("UTF-8");
public static final byte LF = 0x0a;
public static final byte CR = 0x0d;
private static final byte COLON = ':';
/**
* @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String.
*/
public Message<?> toMessage(Object stompContent) {
byte[] byteContent = null;
if (stompContent instanceof String) {
byteContent = ((String) stompContent).getBytes(STOMP_CHARSET);
}
else if (stompContent instanceof byte[]){
byteContent = (byte[]) stompContent;
}
else {
throw new IllegalArgumentException(
"stompContent is neither String nor byte[]: " + stompContent.getClass());
}
int totalLength = byteContent.length;
if (byteContent[totalLength-1] == 0) {
totalLength--;
}
int payloadIndex = findIndexOfPayload(byteContent);
if (payloadIndex == 0) {
throw new StompConversionException("No command found");
}
String headerContent = new String(byteContent, 0, payloadIndex, STOMP_CHARSET);
Parser parser = new Parser(headerContent);
StompCommand command = StompCommand.valueOf(parser.nextToken(LF).trim());
Assert.notNull(command, "No command found");
MultiValueMap<String, String> headers = new LinkedMultiValueMap<String, String>();
while (parser.hasNext()) {
String header = parser.nextToken(COLON);
if (header != null) {
if (parser.hasNext()) {
String value = parser.nextToken(LF);
headers.add(header, value);
}
else {
throw new StompConversionException("Parse exception for " + headerContent);
}
}
}
byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers);
return MessageBuilder.withPayloadAndHeaders(payload, stompHeaders).build();
}
private int findIndexOfPayload(byte[] bytes) {
int i;
// ignore any leading EOL from the previous message
for (i = 0; i < bytes.length; i++) {
if (bytes[i] != '\n' && bytes[i] != '\r') {
break;
}
bytes[i] = ' ';
}
int index = 0;
for (; i < bytes.length - 1; i++) {
if (bytes[i] == LF && bytes[i+1] == LF) {
index = i + 2;
break;
}
if ((i < (bytes.length - 3)) &&
(bytes[i] == CR && bytes[i+1] == LF && bytes[i+2] == CR && bytes[i+3] == LF)) {
index = i + 4;
break;
}
}
if (i >= bytes.length) {
throw new StompConversionException("No end of headers found");
}
return index;
}
public byte[] fromMessage(Message<?> message) {
byte[] payload;
if (message.getPayload() instanceof byte[]) {
payload = (byte[]) message.getPayload();
}
else {
throw new IllegalArgumentException(
"stompContent is not byte[]: " + message.getPayload().getClass());
}
ByteArrayOutputStream out = new ByteArrayOutputStream();
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
try {
out.write(stompHeaders.getCommand().toString().getBytes("UTF-8"));
out.write(LF);
for (Entry<String, List<String>> entry : stompHeaders.toStompHeaderMap().entrySet()) {
String key = entry.getKey();
key = replaceAllOutbound(key);
for (String value : entry.getValue()) {
out.write(key.getBytes("UTF-8"));
out.write(COLON);
value = replaceAllOutbound(value);
out.write(value.getBytes("UTF-8"));
out.write(LF);
}
}
out.write(LF);
out.write(payload);
out.write(0);
return out.toByteArray();
}
catch (IOException e) {
throw new StompConversionException("Failed to serialize " + message, e);
}
}
private String replaceAllOutbound(String key) {
return key.replaceAll("\\\\", "\\\\")
.replaceAll(":", "\\\\c")
.replaceAll("\n", "\\\\n")
.replaceAll("\r", "\\\\r");
}
private class Parser {
private final String content;
private int offset;
public Parser(String content) {
this.content = content;
}
public boolean hasNext() {
return this.offset < this.content.length();
}
public String nextToken(byte delimiter) {
if (this.offset >= this.content.length()) {
return null;
}
int delimAt = this.content.indexOf(delimiter, this.offset);
if (delimAt == -1) {
if (this.offset == this.content.length() - 1 && delimiter == COLON &&
this.content.charAt(this.offset) == LF) {
this.offset++;
return null;
}
else if (this.offset == this.content.length() - 2 && delimiter == COLON &&
this.content.charAt(this.offset) == CR &&
this.content.charAt(this.offset + 1) == LF) {
this.offset += 2;
return null;
}
else {
throw new StompConversionException("No delimiter found at offset " + offset + " in " + this.content);
}
}
int escapeAt = this.content.indexOf('\\', this.offset);
String token = this.content.substring(this.offset, delimAt + 1);
this.offset += token.length();
if (escapeAt >= 0 && escapeAt < delimAt) {
char escaped = this.content.charAt(escapeAt + 1);
if (escaped == 'n' || escaped == 'c' || escaped == '\\') {
token = token.replaceAll("\\\\n", "\n")
.replaceAll("\\\\r", "\r")
.replaceAll("\\\\c", ":")
.replaceAll("\\\\\\\\", "\\\\");
}
else {
throw new StompConversionException("Invalid escape sequence \\" + escaped);
}
}
int length = token.length();
if (delimiter == LF && length > 1 && token.charAt(length - 2) == CR) {
return token.substring(0, length - 2);
}
else {
return token.substring(0, length - 1);
}
}
}
}

View File

@ -17,6 +17,7 @@
package org.springframework.messaging.simp.stomp;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.security.Principal;
import java.util.Arrays;
@ -63,10 +64,13 @@ public class StompProtocolHandler implements SubProtocolHandler {
private final Log logger = LogFactory.getLog(StompProtocolHandler.class);
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
private final StompDecoder stompDecoder = new StompDecoder();
private final StompEncoder stompEncoder = new StompEncoder();
private MutableUserQueueSuffixResolver queueSuffixResolver = new SimpleUserQueueSuffixResolver();
private volatile boolean handleConnect = false;
/**
* Configure a resolver to use to maintain queue suffixes for user
@ -83,6 +87,29 @@ public class StompProtocolHandler implements SubProtocolHandler {
return this.queueSuffixResolver;
}
/**
* Configures the handling of CONNECT frames. When {@code true}, CONNECT
* frames will be handled by this handler, and a CONNECTED response will be
* sent. When {@code false}, CONNECT frames will be forwarded for
* handling by another component.
*
* @param handleConnect {@code true} if connect frames should be handled
* by this handler, {@code false} otherwise.
*/
public void setHandleConnect(boolean handleConnect) {
this.handleConnect = handleConnect;
}
/**
* Returns whether or not this handler will handle CONNECT frames.
*
* @return Returns {@code true} if this handler will handle CONNECT frames,
* otherwise {@code false}.
*/
public boolean willHandleConnect() {
return this.handleConnect;
}
@Override
public List<String> getSupportedProtocols() {
return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
@ -98,7 +125,8 @@ public class StompProtocolHandler implements SubProtocolHandler {
try {
Assert.isInstanceOf(TextMessage.class, webSocketMessage);
String payload = ((TextMessage)webSocketMessage).getPayload();
message = this.stompMessageConverter.toMessage(payload);
ByteBuffer byteBuffer = ByteBuffer.wrap(payload.getBytes(Charset.forName("UTF-8")));
message = this.stompDecoder.decode(byteBuffer);
}
catch (Throwable error) {
logger.error("Failed to parse STOMP frame, WebSocket message payload: ", error);
@ -117,31 +145,38 @@ public class StompProtocolHandler implements SubProtocolHandler {
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
if (SimpMessageType.CONNECT.equals(headers.getMessageType())) {
if (this.handleConnect && SimpMessageType.CONNECT.equals(headers.getMessageType())) {
handleConnect(session, message);
}
outputChannel.send(message);
else {
outputChannel.send(message);
}
}
catch (Throwable t) {
logger.error("Terminating STOMP session due to failure to send message: ", t);
sendErrorMessage(session, t);
}
}
/**
* Handle STOMP messages going back out to WebSocket clients.
*/
@SuppressWarnings("unchecked")
@Override
public void handleMessageToClient(WebSocketSession session, Message<?> message) {
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setCommandIfNotSet(StompCommand.MESSAGE);
if (headers.getCommand() == null && SimpMessageType.MESSAGE == headers.getMessageType()) {
headers.setCommandIfNotSet(StompCommand.MESSAGE);
}
if (StompCommand.CONNECTED.equals(headers.getCommand())) {
// Ignore for now since we already sent it
return;
if (headers.getCommand() == StompCommand.CONNECTED) {
if (this.handleConnect) {
// Ignore since we already sent it
return;
} else {
augmentConnectedHeaders(headers, session);
}
}
if (StompCommand.MESSAGE.equals(headers.getCommand()) && (headers.getSubscriptionId() == null)) {
@ -156,7 +191,7 @@ public class StompProtocolHandler implements SubProtocolHandler {
try {
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message);
byte[] bytes = this.stompEncoder.encode((Message<byte[]>)message);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
catch (Throwable t) {
@ -193,30 +228,36 @@ public class StompProtocolHandler implements SubProtocolHandler {
}
connectedHeaders.setHeartbeat(0,0);
augmentConnectedHeaders(connectedHeaders, session);
// TODO: security
Message<byte[]> connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build();
String payload = new String(this.stompEncoder.encode(connectedMessage), Charset.forName("UTF-8"));
session.sendMessage(new TextMessage(payload));
}
private void augmentConnectedHeaders(StompHeaderAccessor headers, WebSocketSession session) {
Principal principal = session.getPrincipal();
if (principal != null) {
connectedHeaders.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
connectedHeaders.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId());
headers.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
headers.setNativeHeader(QUEUE_SUFFIX_HEADER, session.getId());
if (this.queueSuffixResolver != null) {
String suffix = session.getId();
this.queueSuffixResolver.addQueueSuffix(principal.getName(), session.getId(), suffix);
}
}
Message<?> connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build();
byte[] bytes = this.stompMessageConverter.fromMessage(connectedMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setMessage(error.getMessage());
Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message);
Message<byte[]> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
String payload = new String(this.stompEncoder.encode(message), Charset.forName("UTF-8"));
try {
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
session.sendMessage(new TextMessage(payload));
}
catch (Throwable t) {
// ignore

View File

@ -53,7 +53,7 @@ public class ServletStompEndpointRegistryTests {
this.webSocketHandler = new SubProtocolWebSocketHandler(channel);
this.queueSuffixResolver = new SimpleUserQueueSuffixResolver();
TaskScheduler taskScheduler = Mockito.mock(TaskScheduler.class);
this.registry = new ServletStompEndpointRegistry(webSocketHandler, queueSuffixResolver, taskScheduler);
this.registry = new ServletStompEndpointRegistry(webSocketHandler, queueSuffixResolver, taskScheduler, false);
}

View File

@ -30,7 +30,6 @@ import org.apache.commons.logging.LogFactory;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.messaging.Message;
@ -63,16 +62,14 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
private ExpectationMatchingEventPublisher eventPublisher;
private int port;
@Before
public void setUp() throws Exception {
int port = SocketUtils.findAvailableTcpPort(61613);
this.port = SocketUtils.findAvailableTcpPort(61613);
this.activeMQBroker = new BrokerService();
this.activeMQBroker.addConnector("stomp://localhost:" + port);
this.activeMQBroker.setStartAsync(false);
this.activeMQBroker.setDeleteAllMessagesOnStartup(true);
this.activeMQBroker.start();
createAndStartBroker();
this.responseChannel = new ExecutorSubscribableChannel();
this.responseHandler = new ExpectationMatchingMessageHandler();
@ -86,6 +83,14 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
this.relay.start();
}
private void createAndStartBroker() throws Exception {
this.activeMQBroker = new BrokerService();
this.activeMQBroker.addConnector("stomp://localhost:" + port);
this.activeMQBroker.setStartAsync(false);
this.activeMQBroker.setDeleteAllMessagesOnStartup(true);
this.activeMQBroker.start();
}
@After
public void tearDown() throws Exception {
try {
@ -102,22 +107,24 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
String sess1 = "sess1";
MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build();
this.relay.handleMessage(conn1.message);
this.responseHandler.expect(conn1);
String sess2 = "sess2";
MessageExchange conn2 = MessageExchangeBuilder.connect(sess2).build();
this.relay.handleMessage(conn2.message);
this.responseHandler.expect(conn2);
this.responseHandler.awaitAndAssert();
String subs1 = "subs1";
String destination = "/topic/test";
MessageExchange subscribe = MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build();
this.responseHandler.expect(subscribe);
this.relay.handleMessage(subscribe.message);
this.responseHandler.expect(subscribe);
this.responseHandler.awaitAndAssert();
MessageExchange send = MessageExchangeBuilder.send(destination, "foo").andExpectMessage(sess1, subs1).build();
this.responseHandler.reset();
this.responseHandler.expect(send);
this.relay.handleMessage(send.message);
@ -129,7 +136,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
stopBrokerAndAwait();
MessageExchange connect = MessageExchangeBuilder.connect("sess1").andExpectError().build();
MessageExchange connect = MessageExchangeBuilder.connectWithError("sess1").build();
this.responseHandler.expect(connect);
this.relay.handleMessage(connect.message);
@ -137,37 +144,31 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
}
@Test
public void brokerUnvailableErrorFrameOnSend() throws Exception {
public void brokerBecomingUnvailableTriggersErrorFrame() throws Exception {
String sess1 = "sess1";
MessageExchange connect = MessageExchangeBuilder.connect(sess1).build();
this.responseHandler.expect(connect);
this.relay.handleMessage(connect.message);
// TODO: expect CONNECTED
Thread.sleep(2000);
this.responseHandler.awaitAndAssert();
this.responseHandler.expect(MessageExchangeBuilder.error(sess1).build());
stopBrokerAndAwait();
MessageExchange subscribe = MessageExchangeBuilder.subscribe(sess1, "s1", "/topic/a").andExpectError().build();
this.responseHandler.expect(subscribe);
this.relay.handleMessage(subscribe.message);
this.responseHandler.awaitAndAssert();
}
@Test
public void brokerAvailabilityEvents() throws Exception {
// TODO: expect CONNECTED
Thread.sleep(2000);
this.eventPublisher.expect(true, false);
this.eventPublisher.expect(true);
this.eventPublisher.awaitAndAssert();
this.eventPublisher.expect(false);
stopBrokerAndAwait();
// TODO: remove when stop is detecteded
this.relay.handleMessage(MessageExchangeBuilder.connect("sess1").build().message);
this.eventPublisher.awaitAndAssert();
}
@ -176,37 +177,55 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
String sess1 = "sess1";
MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build();
this.responseHandler.expect(conn1);
this.relay.handleMessage(conn1.message);
this.responseHandler.awaitAndAssert();
String subs1 = "subs1";
String destination = "/topic/test";
MessageExchange subscribe = MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build();
MessageExchange subscribe =
MessageExchangeBuilder.subscribeWithReceipt(sess1, subs1, destination, "r1").build();
this.responseHandler.expect(subscribe);
this.relay.handleMessage(subscribe.message);
this.responseHandler.awaitAndAssert();
this.responseHandler.expect(MessageExchangeBuilder.error(sess1).build());
stopBrokerAndAwait();
// 1st message will see ERROR frame (broker shutdown is not but should be detected)
// 2nd message will be queued (a side effect of CONNECT/CONNECTED-buffering, likely to be removed)
// Finish this once the above changes are made.
this.responseHandler.awaitAndAssert();
this.eventPublisher.expect(true, false);
this.eventPublisher.awaitAndAssert();
this.eventPublisher.expect(true);
createAndStartBroker();
this.eventPublisher.awaitAndAssert();
// TODO The event publisher assertions show that the broker's back up and the system relay session
// has reconnected. We need to decide what we want the reconnect behaviour to be for client relay
// sessions and add further message sending and assertions as appropriate. At the moment any client
// sessions will be closed and an ERROR from will be sent.
}
@Test
public void disconnectClosesRelaySessionCleanly() throws Exception {
String sess1 = "sess1";
MessageExchange conn1 = MessageExchangeBuilder.connect(sess1).build();
this.responseHandler.expect(conn1);
this.relay.handleMessage(conn1.message);
this.responseHandler.awaitAndAssert();
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId(sess1);
this.relay.handleMessage(MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build());
/* MessageExchange send = MessageExchangeBuilder.send(destination, "foo").build();
this.responseHandler.reset();
this.relay.handleMessage(send.message);
Thread.sleep(2000);
this.activeMQBroker.start();
Thread.sleep(5000);
send = MessageExchangeBuilder.send(destination, "foo").andExpectMessage(sess1, subs1).build();
this.responseHandler.reset();
this.responseHandler.expect(send);
this.relay.handleMessage(send.message);
// Check that we have not received an ERROR as a result of the connection closing
this.responseHandler.awaitAndAssert();
*/
}
@ -234,58 +253,68 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
*/
private static class ExpectationMatchingMessageHandler implements MessageHandler {
private final Object monitor = new Object();
private final List<MessageExchange> expected;
private final List<MessageExchange> actual = new CopyOnWriteArrayList<>();
private final List<MessageExchange> actual = new ArrayList<>();
private final List<Message<?>> unexpected = new CopyOnWriteArrayList<>();
private CountDownLatch latch = new CountDownLatch(1);
private final List<Message<?>> unexpected = new ArrayList<>();
public ExpectationMatchingMessageHandler(MessageExchange... expected) {
this.expected = new CopyOnWriteArrayList<>(expected);
synchronized (this.monitor) {
this.expected = new CopyOnWriteArrayList<>(expected);
}
}
public void expect(MessageExchange... expected) {
this.expected.addAll(Arrays.asList(expected));
synchronized (this.monitor) {
this.expected.addAll(Arrays.asList(expected));
}
}
public void awaitAndAssert() throws InterruptedException {
boolean result = this.latch.await(5000, TimeUnit.MILLISECONDS);
assertTrue(getAsString(), result && this.unexpected.isEmpty());
}
public void reset() {
this.latch = new CountDownLatch(1);
this.expected.clear();
this.actual.clear();
this.unexpected.clear();
long endTime = System.currentTimeMillis() + 10000;
synchronized (this.monitor) {
while (!this.expected.isEmpty() && System.currentTimeMillis() < endTime) {
this.monitor.wait(500);
}
boolean result = this.expected.isEmpty();
assertTrue(getAsString(), result && this.unexpected.isEmpty());
}
}
@Override
public void handleMessage(Message<?> message) throws MessagingException {
for (MessageExchange exch : this.expected) {
if (exch.matchMessage(message)) {
if (exch.isDone()) {
this.expected.remove(exch);
this.actual.add(exch);
if (this.expected.isEmpty()) {
this.latch.countDown();
if (StompHeaderAccessor.wrap(message).getCommand() != null) {
synchronized(this.monitor) {
for (MessageExchange exch : this.expected) {
if (exch.matchMessage(message)) {
if (exch.isDone()) {
this.expected.remove(exch);
this.actual.add(exch);
if (this.expected.isEmpty()) {
this.monitor.notifyAll();
}
}
return;
}
}
return;
this.unexpected.add(message);
}
}
this.unexpected.add(message);
}
public String getAsString() {
StringBuilder sb = new StringBuilder("\n");
sb.append("INCOMPLETE:\n").append(this.expected).append("\n");
sb.append("COMPLETE:\n").append(this.actual).append("\n");
sb.append("UNMATCHED MESSAGES:\n").append(this.unexpected).append("\n");
synchronized (this.monitor) {
sb.append("INCOMPLETE:\n").append(this.expected).append("\n");
sb.append("COMPLETE:\n").append(this.actual).append("\n");
sb.append("UNMATCHED MESSAGES:\n").append(this.unexpected).append("\n");
}
return sb.toString();
}
}
@ -352,21 +381,28 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
this.headers = StompHeaderAccessor.wrap(message);
}
public static MessageExchangeBuilder error(String sessionId) {
return new MessageExchangeBuilder(null).andExpectError(sessionId);
}
public static MessageExchangeBuilder connect(String sessionId) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setSessionId(sessionId);
headers.setAcceptVersion("1.1,1.2");
Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
return new MessageExchangeBuilder(message);
MessageExchangeBuilder builder = new MessageExchangeBuilder(message);
builder.expected.add(new StompConnectedFrameMessageMatcher(sessionId));
return builder;
}
public static MessageExchangeBuilder subscribe(String sessionId, String subscriptionId, String destination) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
public static MessageExchangeBuilder connectWithError(String sessionId) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setSessionId(sessionId);
headers.setSubscriptionId(subscriptionId);
headers.setDestination(destination);
headers.setAcceptVersion("1.1,1.2");
Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
return new MessageExchangeBuilder(message);
MessageExchangeBuilder builder = new MessageExchangeBuilder(message);
return builder.andExpectError();
}
public static MessageExchangeBuilder subscribeWithReceipt(String sessionId, String subscriptionId,
@ -514,35 +550,48 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
}
}
private static class StompConnectedFrameMessageMatcher extends StompFrameMessageMatcher {
public StompConnectedFrameMessageMatcher(String sessionId) {
super(StompCommand.CONNECTED, sessionId);
}
}
private static class ExpectationMatchingEventPublisher implements ApplicationEventPublisher {
private final List<Boolean> expected = new CopyOnWriteArrayList<>();
private final List<Boolean> expected = new ArrayList<>();
private final List<Boolean> actual = new CopyOnWriteArrayList<>();
private final List<Boolean> actual = new ArrayList<>();
private CountDownLatch latch = new CountDownLatch(1);
private final Object monitor = new Object();
public void expect(Boolean... expected) {
this.expected.addAll(Arrays.asList(expected));
synchronized (this.monitor) {
this.expected.addAll(Arrays.asList(expected));
}
}
public void awaitAndAssert() throws InterruptedException {
if (this.expected.size() == this.actual.size()) {
synchronized(this.monitor) {
long endTime = System.currentTimeMillis() + 5000;
while (this.expected.size() != this.actual.size() && System.currentTimeMillis() < endTime) {
this.monitor.wait(500);
}
assertEquals(this.expected, this.actual);
}
else {
assertTrue("Expected=" + this.expected + ", actual=" + this.actual,
this.latch.await(5, TimeUnit.SECONDS));
}
}
@Override
public void publishEvent(ApplicationEvent event) {
if (event instanceof BrokerAvailabilityEvent) {
this.actual.add(((BrokerAvailabilityEvent) event).isBrokerAvailable());
if (this.actual.size() == this.expected.size()) {
this.latch.countDown();
synchronized(this.monitor) {
this.actual.add(((BrokerAvailabilityEvent) event).isBrokerAvailable());
if (this.actual.size() == this.expected.size()) {
this.monitor.notifyAll();
}
}
}
}

View File

@ -0,0 +1,212 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import reactor.function.Consumer;
import reactor.function.Function;
import reactor.io.Buffer;
import static org.junit.Assert.*;
/**
*
* @author awilkinson
*/
public class StompCodecTests {
private final ArgumentCapturingConsumer<Message<byte[]>> consumer = new ArgumentCapturingConsumer<Message<byte[]>>();
private final Function<Buffer, Message<byte[]>> decoder = new StompCodec().decoder(consumer);
@Test
public void decodeFrameWithCrLfEols() {
Message<byte[]> frame = decode("DISCONNECT\r\n\r\n\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.DISCONNECT, headers.getCommand());
assertEquals(0, headers.toStompHeaderMap().size());
assertEquals(0, frame.getPayload().length);
}
@Test
public void decodeFrameWithNoHeadersAndNoBody() {
Message<byte[]> frame = decode("DISCONNECT\n\n\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.DISCONNECT, headers.getCommand());
assertEquals(0, headers.toStompHeaderMap().size());
assertEquals(0, frame.getPayload().length);
}
@Test
public void decodeFrameWithNoBody() {
String accept = "accept-version:1.1\n";
String host = "host:github.org\n";
Message<byte[]> frame = decode("CONNECT\n" + accept + host + "\n\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.CONNECT, headers.getCommand());
assertEquals(2, headers.toStompHeaderMap().size());
assertEquals("1.1", headers.getFirstNativeHeader("accept-version"));
assertEquals("github.org", headers.getHost());
assertEquals(0, frame.getPayload().length);
}
@Test
public void decodeFrame() throws UnsupportedEncodingException {
Message<byte[]> frame = decode("SEND\ndestination:test\n\nThe body of the message\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.SEND, headers.getCommand());
assertEquals(1, headers.toStompHeaderMap().size());
assertEquals("test", headers.getDestination());
String bodyText = new String(frame.getPayload());
assertEquals("The body of the message", bodyText);
}
@Test
public void decodeFrameWithContentLength() {
Message<byte[]> frame = decode("SEND\ncontent-length:23\n\nThe body of the message\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.SEND, headers.getCommand());
assertEquals(1, headers.toStompHeaderMap().size());
assertEquals(Integer.valueOf(23), headers.getContentLength());
String bodyText = new String(frame.getPayload());
assertEquals("The body of the message", bodyText);
}
@Test
public void decodeFrameWithNullOctectsInTheBody() {
Message<byte[]> frame = decode("SEND\ncontent-length:23\n\nThe b\0dy \0f the message\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.SEND, headers.getCommand());
assertEquals(1, headers.toStompHeaderMap().size());
assertEquals(Integer.valueOf(23), headers.getContentLength());
String bodyText = new String(frame.getPayload());
assertEquals("The b\0dy \0f the message", bodyText);
}
@Test
public void decodeFrameWithEscapedHeaders() {
Message<byte[]> frame = decode("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.DISCONNECT, headers.getCommand());
assertEquals(1, headers.toStompHeaderMap().size());
assertEquals("alpha:bravo\r\n\\", headers.getFirstNativeHeader("a:\r\n\\b"));
}
@Test
public void decodeMultipleFramesFromSameBuffer() {
String frame1 = "SEND\ndestination:test\n\nThe body of the message\0";
String frame2 = "DISCONNECT\n\n\0";
Buffer buffer = Buffer.wrap(frame1 + frame2);
final List<Message<byte[]>> messages = new ArrayList<Message<byte[]>>();
new StompCodec().decoder(new Consumer<Message<byte[]>>() {
@Override
public void accept(Message<byte[]> message) {
messages.add(message);
}
}).apply(buffer);
assertEquals(2, messages.size());
assertEquals(StompCommand.SEND, StompHeaderAccessor.wrap(messages.get(0)).getCommand());
assertEquals(StompCommand.DISCONNECT, StompHeaderAccessor.wrap(messages.get(1)).getCommand());
}
@Test
public void encodeFrameWithNoHeadersAndNoBody() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
Message<byte[]> frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
assertEquals("DISCONNECT\n\n\0", new StompCodec().encoder().apply(frame).asString());
}
@Test
public void encodeFrameWithHeaders() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setAcceptVersion("1.2");
headers.setHost("github.org");
Message<byte[]> frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
String frameString = new StompCodec().encoder().apply(frame).asString();
assertTrue(frameString.equals("CONNECT\naccept-version:1.2\nhost:github.org\n\n\0") ||
frameString.equals("CONNECT\nhost:github.org\naccept-version:1.2\n\n\0"));
}
@Test
public void encodeFrameWithHeadersThatShouldBeEscaped() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.addNativeHeader("a:\r\n\\b", "alpha:bravo\r\n\\");
Message<byte[]> frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
assertEquals("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0", new StompCodec().encoder().apply(frame).asString());
}
@Test
public void encodeFrameWithHeadersBody() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.addNativeHeader("a", "alpha");
Message<byte[]> frame = MessageBuilder.withPayloadAndHeaders("Message body".getBytes(), headers).build();
assertEquals("SEND\na:alpha\ncontent-length:12\n\nMessage body\0", new StompCodec().encoder().apply(frame).asString());
}
private Message<byte[]> decode(String stompFrame) {
this.decoder.apply(Buffer.wrap(stompFrame));
return consumer.arguments.get(0);
}
private static final class ArgumentCapturingConsumer<T> implements Consumer<T> {
private final List<T> arguments = new ArrayList<T>();
@Override
public void accept(T t) {
arguments.add(t);
}
}
}

View File

@ -1,153 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.util.Collections;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.web.socket.TextMessage;
import static org.junit.Assert.*;
/**
* @author Gary Russell
* @author Rossen Stoyanchev
*/
public class StompMessageConverterTests {
private StompMessageConverter converter;
@Before
public void setup() {
this.converter = new StompMessageConverter();
}
@Test
public void connectFrame() throws Exception {
String accept = "accept-version:1.1";
String host = "host:github.org";
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT)
.headers(accept, host).build();
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(textMessage.getPayload());
assertEquals(0, message.getPayload().length);
MessageHeaders headers = message.getHeaders();
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
Map<String, Object> map = stompHeaders.toMap();
assertEquals(5, map.size());
assertNotNull(stompHeaders.getId());
assertNotNull(stompHeaders.getTimestamp());
assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getCommand());
assertNotNull(map.get(SimpMessageHeaderAccessor.NATIVE_HEADERS));
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getHost());
assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getCommand());
assertNotNull(headers.get(MessageHeaders.ID));
assertNotNull(headers.get(MessageHeaders.TIMESTAMP));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
@Test
public void connectWithEscapes() throws Exception {
String accept = "accept-version:1.1";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org";
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT)
.headers(accept, host).build();
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(textMessage.getPayload());
assertEquals(0, message.getPayload().length);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.toNativeHeaderMap().get("ho:\ns\rt").get(0));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
@Test
public void connectCR12() throws Exception {
String accept = "accept-version:1.2\n";
String host = "host:github.org\n";
String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getHost());
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
@Test
public void connectWithEscapesAndCR12() throws Exception {
String accept = "accept-version:1.1\n";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.toNativeHeaderMap().get("ho:\ns\rt").get(0));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
}

View File

@ -16,6 +16,7 @@
package org.springframework.messaging.simp.stomp;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashSet;
@ -60,7 +61,33 @@ public class StompProtocolHandlerTests {
}
@Test
public void handleConnect() {
public void connectedResponseIsSentWhenHandlingConnect() {
this.stompHandler.setHandleConnect(true);
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers(
"login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build();
this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel);
verifyNoMoreInteractions(this.channel);
// Check CONNECTED reply
assertEquals(1, this.session.getSentMessages().size());
textMessage = (TextMessage) this.session.getSentMessages().get(0);
Message<?> message = new StompDecoder().decode(ByteBuffer.wrap(textMessage.getPayload().getBytes()));
StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message);
assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand());
assertEquals("1.1", replyHeaders.getVersion());
assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat());
assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0));
assertEquals("s1", replyHeaders.getNativeHeader("queue-suffix").get(0));
}
@Test
public void connectIsForwardedWhenNotHandlingConnect() {
this.stompHandler.setHandleConnect(false);
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers(
"login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build();
@ -80,18 +107,7 @@ public class StompProtocolHandlerTests {
assertArrayEquals(new long[] {10000, 10000}, headers.getHeartbeat());
assertEquals(new HashSet<>(Arrays.asList("1.1","1.0")), headers.getAcceptVersion());
// Check CONNECTED reply
assertEquals(1, this.session.getSentMessages().size());
textMessage = (TextMessage) this.session.getSentMessages().get(0);
Message<?> message = new StompMessageConverter().toMessage(textMessage.getPayload());
StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message);
assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand());
assertEquals("1.1", replyHeaders.getVersion());
assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat());
assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0));
assertEquals("s1", replyHeaders.getNativeHeader("queue-suffix").get(0));
assertEquals(0, this.session.getSentMessages().size());
}
}