Refactor approach to working with STOMP headers

This commit is contained in:
Rossen Stoyanchev 2013-06-11 01:52:32 -04:00
parent 547167e8b4
commit d26b9d60e5
21 changed files with 762 additions and 969 deletions

View File

@ -37,8 +37,9 @@ import org.apache.commons.logging.LogFactory;
* The headers for a {@link Message}.<br>
* IMPORTANT: MessageHeaders are immutable. Any mutating operation (e.g., put(..), putAll(..) etc.)
* will result in {@link UnsupportedOperationException}
*
* <p>
* TODO: update javadoc
* TODO: update below instructions
*
* <p>To create MessageHeaders instance use fluent MessageBuilder API
* <pre>
@ -76,16 +77,7 @@ public class MessageHeaders implements Map<String, Object>, Serializable {
public static final String TIMESTAMP = "timestamp";
public static final String REPLY_CHANNEL = "replyChannel";
public static final String ERROR_CHANNEL = "errorChannel";
public static final String CONTENT_TYPE = "content-type";
// DESTINATION ?
public static final List<String> HEADER_NAMES =
Arrays.asList(ID, TIMESTAMP, REPLY_CHANNEL, ERROR_CHANNEL, CONTENT_TYPE);
public static final List<String> HEADER_NAMES = Arrays.asList(ID, TIMESTAMP);
private final Map<String, Object> headers;
@ -111,14 +103,6 @@ public class MessageHeaders implements Map<String, Object>, Serializable {
return this.get(TIMESTAMP, Long.class);
}
public Object getReplyChannel() {
return this.get(REPLY_CHANNEL);
}
public Object getErrorChannel() {
return this.get(ERROR_CHANNEL);
}
@SuppressWarnings("unchecked")
public <T> T get(Object key, Class<T> type) {
Object value = this.headers.get(key);

View File

@ -0,0 +1,163 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.springframework.http.MediaType;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.CollectionUtils;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class PubSubHeaders {
private static final String DESTINATIONS = "destinations";
private static final String CONTENT_TYPE = "contentType";
private static final String MESSAGE_TYPE = "messageType";
private static final String SUBSCRIPTION_ID = "subscriptionId";
private static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType";
private static final String SESSION_ID = "sessionId";
private static final String RAW_HEADERS = "rawHeaders";
private final Map<String, Object> messageHeaders;
private final Map<String, String> rawHeaders;
/**
* Constructor for building new headers.
*
* @param messageType the message type
* @param protocolMessageType the protocol-specific message type or command
*/
public PubSubHeaders(MessageType messageType, Object protocolMessageType) {
this.messageHeaders = new HashMap<String, Object>();
this.messageHeaders.put(MESSAGE_TYPE, messageType);
if (protocolMessageType != null) {
this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
}
this.rawHeaders = new HashMap<String, String>();
this.messageHeaders.put(RAW_HEADERS, this.rawHeaders);
}
public PubSubHeaders() {
this(MessageType.MESSAGE, null);
}
/**
* Constructor for access to existing {@link MessageHeaders}.
*
* @param messageHeaders
*/
@SuppressWarnings("unchecked")
public PubSubHeaders(MessageHeaders messageHeaders, boolean readOnly) {
this.messageHeaders = readOnly ? messageHeaders : new HashMap<String, Object>(messageHeaders);
this.rawHeaders = this.messageHeaders.containsKey(RAW_HEADERS) ?
(Map<String, String>) messageHeaders.get(RAW_HEADERS) : new HashMap<String, String>();
if (this.messageHeaders.get(MESSAGE_TYPE) == null) {
this.messageHeaders.put(MESSAGE_TYPE, MessageType.MESSAGE);
}
}
public Map<String, Object> getMessageHeaders() {
return this.messageHeaders;
}
public Map<String, String> getRawHeaders() {
return this.rawHeaders;
}
public MessageType getMessageType() {
return (MessageType) this.messageHeaders.get(MESSAGE_TYPE);
}
public void setProtocolMessageType(Object protocolMessageType) {
this.messageHeaders.put(PROTOCOL_MESSAGE_TYPE, protocolMessageType);
}
public Object getProtocolMessageType() {
return this.messageHeaders.get(PROTOCOL_MESSAGE_TYPE);
}
public void setDestination(String destination) {
this.messageHeaders.put(DESTINATIONS, Arrays.asList(destination));
}
public String getDestination() {
@SuppressWarnings("unchecked")
List<String> destination = (List<String>) messageHeaders.get(DESTINATIONS);
return CollectionUtils.isEmpty(destination) ? null : destination.get(0);
}
@SuppressWarnings("unchecked")
public List<String> getDestinations() {
return (List<String>) messageHeaders.get(DESTINATIONS);
}
public void setDestinations(List<String> destinations) {
if (destinations != null) {
this.messageHeaders.put(DESTINATIONS, destinations);
}
}
public MediaType getContentType() {
return (MediaType) this.messageHeaders.get(CONTENT_TYPE);
}
public void setContentType(MediaType mediaType) {
if (mediaType != null) {
this.messageHeaders.put(CONTENT_TYPE, mediaType);
}
}
public String getSubscriptionId() {
return (String) this.messageHeaders.get(SUBSCRIPTION_ID);
}
public void setSubscriptionId(String subscriptionId) {
this.messageHeaders.put(SUBSCRIPTION_ID, subscriptionId);
}
public String getSessionId() {
return (String) this.messageHeaders.get(SESSION_ID);
}
public void setSessionId(String sessionId) {
this.messageHeaders.put(SESSION_ID, sessionId);
}
}

View File

@ -16,6 +16,9 @@
package org.springframework.web.messaging.event;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.Reactor;
import reactor.fn.Consumer;
import reactor.fn.Event;
@ -28,6 +31,8 @@ import reactor.fn.selector.ObjectSelector;
*/
public class ReactorEventBus implements EventBus {
private static Log logger = LogFactory.getLog(ReactorEventBus.class);
private final Reactor reactor;
@ -37,6 +42,9 @@ public class ReactorEventBus implements EventBus {
@Override
public void send(String key, Object data) {
if (logger.isTraceEnabled()) {
logger.trace("Sending notification key=" + key + ", data=" + data);
}
this.reactor.notify(key, Event.wrap(data));
}

View File

@ -27,6 +27,7 @@ import org.springframework.util.AntPathMatcher;
import org.springframework.util.Assert;
import org.springframework.util.PathMatcher;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.event.EventConsumer;
@ -37,11 +38,14 @@ import org.springframework.web.messaging.event.EventConsumer;
*/
public abstract class AbstractMessageService {
public static final String MESSAGE_KEY = "messageKey";
protected final Log logger = LogFactory.getLog(getClass());
public static final String CLIENT_TO_SERVER_MESSAGE_KEY = "clientToServerMessageKey";
public static final String CLIENT_CONNECTION_CLOSED_KEY = "clientConnectionClosed";
protected final Log logger = LogFactory.getLog(getClass());
public static final String SERVER_TO_CLIENT_MESSAGE_KEY = "serverToClientMessageKey";
private final EventBus eventBus;
@ -58,7 +62,7 @@ public abstract class AbstractMessageService {
Assert.notNull(reactor, "reactor is required");
this.eventBus = reactor;
this.eventBus.registerConsumer(MESSAGE_KEY, new EventConsumer<Message<?>>() {
this.eventBus.registerConsumer(CLIENT_TO_SERVER_MESSAGE_KEY, new EventConsumer<Message<?>>() {
@Override
public void accept(Message<?> message) {
@ -124,7 +128,8 @@ public abstract class AbstractMessageService {
}
private boolean isAllowedDestination(Message<?> message) {
String destination = (String) message.getHeaders().get("destination");
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
String destination = headers.getDestination();
if (destination == null) {
return true;
}

View File

@ -17,14 +17,13 @@
package org.springframework.web.messaging.service;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.event.EventBus;
@ -61,26 +60,21 @@ public class PubSubMessageService extends AbstractMessageService {
logger.debug("Message received: " + message);
}
Map<String, Object> headers = new HashMap<String, Object>();
headers.put("destination", message.getHeaders().get("destination"));
MediaType contentType = (MediaType) message.getHeaders().get("content-type");
headers.put("content-type", contentType);
try {
// Convert to byte[] payload before the fan-out
byte[] payload = payloadConverter.convertToPayload(message.getPayload(), contentType);
message = new GenericMessage<byte[]>(payload, headers);
PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
byte[] payload = payloadConverter.convertToPayload(message.getPayload(), inHeaders.getContentType());
message = new GenericMessage<byte[]>(payload, message.getHeaders());
getEventBus().send(getPublishKey(message), message);
getEventBus().send(getPublishKey(inHeaders.getDestination()), message);
}
catch (Exception ex) {
logger.error("Failed to publish " + message, ex);
}
}
private String getPublishKey(Message<?> message) {
return "destination:" + (String) message.getHeaders().get("destination");
private String getPublishKey(String destination) {
return "destination:" + destination;
}
@Override
@ -88,12 +82,20 @@ public class PubSubMessageService extends AbstractMessageService {
if (logger.isDebugEnabled()) {
logger.debug("Subscribe " + message);
}
final String replyKey = (String) message.getHeaders().getReplyChannel();
EventRegistration registration = getEventBus().registerConsumer(getPublishKey(message),
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
final String subscriptionId = headers.getSubscriptionId();
EventRegistration registration = getEventBus().registerConsumer(getPublishKey(headers.getDestination()),
new EventConsumer<Message<?>>() {
@Override
public void accept(Message<?> message) {
getEventBus().send(replyKey, message);
PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
PubSubHeaders outHeaders = new PubSubHeaders();
outHeaders.setDestinations(inHeaders.getDestinations());
outHeaders.setContentType(inHeaders.getContentType());
outHeaders.setSubscriptionId(subscriptionId);
Object payload = message.getPayload();
message = new GenericMessage<Object>(payload, outHeaders.getMessageHeaders());
getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message);
}
});

View File

@ -35,6 +35,7 @@ import org.springframework.messaging.annotation.MessageMapping;
import org.springframework.stereotype.Controller;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils.MethodFilter;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.annotation.SubscribeEvent;
import org.springframework.web.messaging.annotation.UnsubscribeEvent;
import org.springframework.web.messaging.converter.MessageConverter;
@ -166,7 +167,8 @@ public class AnnotationMessageService extends AbstractMessageService implements
private void handleMessage(final Message<?> message, Map<MappingInfo, HandlerMethod> handlerMethods) {
String destination = (String) message.getHeaders().get("destination");
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), true);
String destination = headers.getDestination();
HandlerMethod match = getHandlerMethod(destination, handlerMethods);
if (match == null) {

View File

@ -16,16 +16,11 @@
package org.springframework.web.messaging.service.method;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.service.AbstractMessageService;
@ -38,8 +33,6 @@ import reactor.util.Assert;
*/
public class MessageChannelArgumentResolver implements ArgumentResolver {
private static Log logger = LogFactory.getLog(MessageChannelArgumentResolver.class);
private final EventBus eventBus;
@ -56,24 +49,18 @@ public class MessageChannelArgumentResolver implements ArgumentResolver {
@Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
final String sessionId = (String) message.getHeaders().get("sessionId");
final String sessionId = new PubSubHeaders(message.getHeaders(), true).getSessionId();
return new MessageChannel() {
@Override
public boolean send(Message<?> message) {
Map<String, Object> headers = new HashMap<String, Object>(message.getHeaders());
headers.put("messageType", MessageType.MESSAGE);
headers.put("sessionId", sessionId);
message = new GenericMessage<Object>(message.getPayload(), headers);
PubSubHeaders headers = new PubSubHeaders(message.getHeaders(), false);
headers.setSessionId(sessionId);
message = new GenericMessage<Object>(message.getPayload(), headers.getMessageHeaders());
if (logger.isTraceEnabled()) {
logger.trace("Sending notification: " + message);
}
String key = AbstractMessageService.MESSAGE_KEY;
MessageChannelArgumentResolver.this.eventBus.send(key, message);
eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message);
return true;
}

View File

@ -16,11 +16,12 @@
package org.springframework.web.messaging.service.method;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.service.AbstractMessageService;
import reactor.util.Assert;
@ -31,8 +32,6 @@ import reactor.util.Assert;
*/
public class MessageReturnValueHandler implements ReturnValueHandler {
private static Log logger = LogFactory.getLog(MessageReturnValueHandler.class);
private final EventBus eventBus;
@ -67,13 +66,17 @@ public class MessageReturnValueHandler implements ReturnValueHandler {
return;
}
String replyTo = (String) message.getHeaders().getReplyChannel();
Assert.notNull(replyTo, "Cannot reply to: " + message);
PubSubHeaders inHeaders = new PubSubHeaders(message.getHeaders(), true);
String sessionId = inHeaders.getSessionId();
String subscriptionId = inHeaders.getSubscriptionId();
Assert.notNull(subscriptionId, "No subscription id: " + message);
if (logger.isTraceEnabled()) {
logger.trace("Sending notification: " + message);
}
this.eventBus.send(replyTo, returnMessage);
PubSubHeaders outHeaders = new PubSubHeaders(returnMessage.getHeaders(), false);
outHeaders.setSessionId(sessionId);
outHeaders.setSubscriptionId(subscriptionId);
returnMessage = new GenericMessage<Object>(returnMessage.getPayload(), outHeaders.getMessageHeaders());
this.eventBus.send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, returnMessage);
}
}

View File

@ -16,6 +16,11 @@
package org.springframework.web.messaging.stomp;
import java.util.HashMap;
import java.util.Map;
import org.springframework.web.messaging.MessageType;
/**
*
@ -43,4 +48,21 @@ public enum StompCommand {
RECEIPT,
ERROR;
private static Map<StompCommand, MessageType> commandToMessageType = new HashMap<StompCommand, MessageType>();
static {
commandToMessageType.put(StompCommand.CONNECT, MessageType.CONNECT);
commandToMessageType.put(StompCommand.STOMP, MessageType.CONNECT);
commandToMessageType.put(StompCommand.SEND, MessageType.MESSAGE);
commandToMessageType.put(StompCommand.SUBSCRIBE, MessageType.SUBSCRIBE);
commandToMessageType.put(StompCommand.UNSUBSCRIBE, MessageType.UNSUBSCRIBE);
commandToMessageType.put(StompCommand.DISCONNECT, MessageType.DISCONNECT);
}
public MessageType getMessageType() {
MessageType messageType = commandToMessageType.get(this);
return (messageType != null) ? messageType : MessageType.OTHER;
}
}

View File

@ -20,17 +20,16 @@ import org.springframework.core.NestedRuntimeException;
/**
* @author Gary Russell
* @since 4.0
*
*/
@SuppressWarnings("serial")
public class StompException extends NestedRuntimeException {
public class StompConversionException extends NestedRuntimeException {
public StompException(String msg, Throwable cause) {
public StompConversionException(String msg, Throwable cause) {
super(msg, cause);
}
public StompException(String msg) {
public StompConversionException(String msg) {
super(msg);
}

View File

@ -16,42 +16,32 @@
package org.springframework.web.messaging.stomp;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.StringUtils;
import org.springframework.web.messaging.PubSubHeaders;
import reactor.util.Assert;
/**
* STOMP adapter for {@link MessageHeaders}.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompHeaders implements MultiValueMap<String, String>, Serializable {
public class StompHeaders extends PubSubHeaders {
private static final long serialVersionUID = 1L;
// TODO: separate client from server headers so they can't be mixed
// Client
private static final String ID = "id";
private static final String HOST = "host";
private static final String ACCEPT_VERSION = "accept-version";
// Server
private static final String MESSAGE_ID = "message-id";
private static final String RECEIPT_ID = "receipt-id";
@ -62,8 +52,6 @@ public class StompHeaders implements MultiValueMap<String, String>, Serializable
private static final String MESSAGE = "message";
// Client and Server
private static final String ACK = "ack";
private static final String DESTINATION = "destination";
@ -75,96 +63,63 @@ public class StompHeaders implements MultiValueMap<String, String>, Serializable
private static final String HEARTBEAT = "heart-beat";
public static final List<String> STANDARD_HEADER_NAMES =
Arrays.asList(ID, HOST, ACCEPT_VERSION, MESSAGE_ID, RECEIPT_ID, SUBSCRIPTION,
VERSION, MESSAGE, ACK, DESTINATION, CONTENT_LENGTH, CONTENT_TYPE, HEARTBEAT);
private final Map<String, List<String>> headers;
/**
* Private constructor that can create read-only {@code StompHeaders} instances.
* Constructor for building new headers.
*
* @param command the STOMP command
*/
private StompHeaders(Map<String, List<String>> headers, boolean readOnly) {
Assert.notNull(headers, "'headers' must not be null");
if (readOnly) {
Map<String, List<String>> map = new LinkedHashMap<String, List<String>>(headers.size());
for (Entry<String, List<String>> entry : headers.entrySet()) {
List<String> values = Collections.unmodifiableList(entry.getValue());
map.put(entry.getKey(), values);
}
this.headers = Collections.unmodifiableMap(map);
}
else {
this.headers = headers;
}
public StompHeaders(StompCommand command) {
super(command.getMessageType(), command);
}
/**
* Constructs a new, empty instance of the {@code StompHeaders} object.
* Constructor for access to existing {@link MessageHeaders}.
*
* @param messageHeaders the existing message headers
* @param readOnly whether the resulting instance will be used for read-only access,
* if {@code true}, then set methods will throw exceptions; if {@code false}
* they will work.
*/
public StompHeaders() {
this(new LinkedHashMap<String, List<String>>(4), false);
public StompHeaders(MessageHeaders messageHeaders, boolean readOnly) {
super(messageHeaders, readOnly);
}
/**
* Returns {@code StompHeaders} object that can only be read, not written to.
*/
public static StompHeaders readOnlyStompHeaders(StompHeaders headers) {
return new StompHeaders(headers, true);
@Override
public StompCommand getProtocolMessageType() {
return (StompCommand) super.getProtocolMessageType();
}
public StompCommand getStompCommand() {
return (StompCommand) super.getProtocolMessageType();
}
public Set<String> getAcceptVersion() {
String rawValue = getFirst(ACCEPT_VERSION);
String rawValue = getRawHeaders().get(ACCEPT_VERSION);
return (rawValue != null) ? StringUtils.commaDelimitedListToSet(rawValue) : Collections.<String>emptySet();
}
public void setAcceptVersion(String acceptVersion) {
set(ACCEPT_VERSION, acceptVersion);
}
public String getVersion() {
return getFirst(VERSION);
}
public void setVersion(String version) {
set(VERSION, version);
}
public String getDestination() {
return getFirst(DESTINATION);
getRawHeaders().put(ACCEPT_VERSION, acceptVersion);
}
@Override
public void setDestination(String destination) {
set(DESTINATION, destination);
}
public MediaType getContentType() {
String contentType = getFirst(CONTENT_TYPE);
return StringUtils.hasText(contentType) ? MediaType.valueOf(contentType) : null;
}
public void setContentType(MediaType mediaType) {
if (mediaType != null) {
set(CONTENT_TYPE, mediaType.toString());
}
else {
remove(CONTENT_TYPE);
if (destination != null) {
super.setDestination(destination);
getRawHeaders().put(DESTINATION, destination);
}
}
public Integer getContentLength() {
String contentLength = getFirst(CONTENT_LENGTH);
return StringUtils.hasText(contentLength) ? new Integer(contentLength) : null;
}
public void setContentLength(int contentLength) {
set(CONTENT_LENGTH, String.valueOf(contentLength));
@Override
public void setDestinations(List<String> destinations) {
if (destinations != null) {
super.setDestinations(destinations);
getRawHeaders().put(DESTINATION, destinations.get(0));
}
}
public long[] getHeartbeat() {
String rawValue = getFirst(HEARTBEAT);
String rawValue = getRawHeaders().get(HEARTBEAT);
if (!StringUtils.hasText(rawValue)) {
return null;
}
@ -173,172 +128,102 @@ public class StompHeaders implements MultiValueMap<String, String>, Serializable
return new long[] { Long.valueOf(rawValues[0]), Long.valueOf(rawValues[1])};
}
public void setContentType(MediaType mediaType) {
if (mediaType != null) {
super.setContentType(mediaType);
getRawHeaders().put(CONTENT_TYPE, mediaType.toString());
}
}
public Integer getContentLength() {
String contentLength = getRawHeaders().get(CONTENT_LENGTH);
return StringUtils.hasText(contentLength) ? new Integer(contentLength) : null;
}
public void setContentLength(int contentLength) {
getRawHeaders().put(CONTENT_LENGTH, String.valueOf(contentLength));
}
@Override
public String getSubscriptionId() {
return StompCommand.SUBSCRIBE.equals(getStompCommand()) ? getRawHeaders().get(ID) : null;
}
@Override
public void setSubscriptionId(String subscriptionId) {
Assert.isTrue(StompCommand.MESSAGE.equals(getStompCommand()),
"\"subscription\" can only be set on a STOMP MESSAGE frame");
super.setSubscriptionId(subscriptionId);
getRawHeaders().put(SUBSCRIPTION, subscriptionId);
}
public void setHeartbeat(long cx, long cy) {
set(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
}
public String getId() {
return getFirst(ID);
}
public void setId(String id) {
set(ID, id);
}
public String getMessageId() {
return getFirst(MESSAGE_ID);
}
public void setMessageId(String id) {
set(MESSAGE_ID, id);
}
public String getSubscription() {
return getFirst(SUBSCRIPTION);
}
public void setSubscription(String id) {
set(SUBSCRIPTION, id);
getRawHeaders().put(HEARTBEAT, StringUtils.arrayToCommaDelimitedString(new Object[] {cx, cy}));
}
public String getMessage() {
return getFirst(MESSAGE);
return getRawHeaders().get(MESSAGE);
}
public void setMessage(String id) {
set(MESSAGE, id);
public void setMessage(String content) {
getRawHeaders().put(MESSAGE, content);
}
public String getMessageId() {
return getRawHeaders().get(MESSAGE_ID);
}
public void setMessageId(String id) {
getRawHeaders().put(MESSAGE_ID, id);
}
public String getVersion() {
return getRawHeaders().get(VERSION);
}
public void setVersion(String version) {
getRawHeaders().put(VERSION, version);
}
// MultiValueMap methods
/**
* Return the first header value for the given header name, if any.
* @param headerName the header name
* @return the first header value; or {@code null}
* Update generic message headers from raw headers. This method only needs to be
* invoked when raw headers are added via {@link #getRawHeaders()}.
*/
public String getFirst(String headerName) {
List<String> headerValues = headers.get(headerName);
return headerValues != null ? headerValues.get(0) : null;
public void updateMessageHeaders() {
String destination = getRawHeaders().get(DESTINATION);
if (destination != null) {
setDestination(destination);
}
String contentType = getRawHeaders().get(CONTENT_TYPE);
if (contentType != null) {
setContentType(MediaType.parseMediaType(contentType));
}
if (StompCommand.SUBSCRIBE.equals(getStompCommand())) {
if (getRawHeaders().get(ID) != null) {
super.setSubscriptionId(getRawHeaders().get(ID));
}
}
}
/**
* Add the given, single header value under the given name.
* @param headerName the header name
* @param headerValue the header value
* @throws UnsupportedOperationException if adding headers is not supported
* @see #put(String, List)
* @see #set(String, String)
* Update raw headers from generic message headers. This method only needs to be
* invoked if creating {@link StompHeaders} from {@link MessageHeaders} that never
* contained raw headers.
*/
public void add(String headerName, String headerValue) {
List<String> headerValues = headers.get(headerName);
if (headerValues == null) {
headerValues = new LinkedList<String>();
this.headers.put(headerName, headerValues);
public void updateRawHeaders() {
String destination = getDestination();
if (destination != null) {
getRawHeaders().put(DESTINATION, destination);
}
headerValues.add(headerValue);
}
/**
* Set the given, single header value under the given name.
* @param headerName the header name
* @param headerValue the header value
* @throws UnsupportedOperationException if adding headers is not supported
* @see #put(String, List)
* @see #add(String, String)
*/
public void set(String headerName, String headerValue) {
List<String> headerValues = new LinkedList<String>();
headerValues.add(headerValue);
headers.put(headerName, headerValues);
}
public void setAll(Map<String, String> values) {
for (Entry<String, String> entry : values.entrySet()) {
set(entry.getKey(), entry.getValue());
MediaType contentType = getContentType();
if (contentType != null) {
getRawHeaders().put(CONTENT_TYPE, contentType.toString());
}
}
public Map<String, String> toSingleValueMap() {
LinkedHashMap<String, String> singleValueMap = new LinkedHashMap<String,String>(this.headers.size());
for (Entry<String, List<String>> entry : headers.entrySet()) {
singleValueMap.put(entry.getKey(), entry.getValue().get(0));
String subscriptionId = getSubscriptionId();
if (subscriptionId != null) {
getRawHeaders().put(SUBSCRIPTION, subscriptionId);
}
return singleValueMap;
}
// Map implementation
public int size() {
return this.headers.size();
}
public boolean isEmpty() {
return this.headers.isEmpty();
}
public boolean containsKey(Object key) {
return this.headers.containsKey(key);
}
public boolean containsValue(Object value) {
return this.headers.containsValue(value);
}
public List<String> get(Object key) {
return this.headers.get(key);
}
public List<String> put(String key, List<String> value) {
return this.headers.put(key, value);
}
public List<String> remove(Object key) {
return this.headers.remove(key);
}
public void putAll(Map<? extends String, ? extends List<String>> m) {
this.headers.putAll(m);
}
public void clear() {
this.headers.clear();
}
public Set<String> keySet() {
return this.headers.keySet();
}
public Collection<List<String>> values() {
return this.headers.values();
}
public Set<Entry<String, List<String>>> entrySet() {
return this.headers.entrySet();
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof StompHeaders)) {
return false;
}
StompHeaders otherHeaders = (StompHeaders) other;
return this.headers.equals(otherHeaders.headers);
}
@Override
public int hashCode() {
return this.headers.hashCode();
}
@Override
public String toString() {
return this.headers.toString();
}
}

View File

@ -1,78 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.stomp;
import java.nio.charset.Charset;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompMessage {
public static final Charset CHARSET = Charset.forName("UTF-8");
private final StompCommand command;
private final StompHeaders headers;
private final byte[] payload;
private String sessionId;
public StompMessage(StompCommand command, StompHeaders headers, byte[] payload) {
this.command = command;
this.headers = (headers != null) ? headers : new StompHeaders();
this.payload = payload;
}
/**
* Constructor for empty payload message.
*/
public StompMessage(StompCommand command, StompHeaders headers) {
this(command, headers, new byte[0]);
}
public StompCommand getCommand() {
return this.command;
}
public StompHeaders getHeaders() {
return this.headers;
}
public byte[] getPayload() {
return this.payload;
}
public void setSessionId(String sessionId) {
this.sessionId = sessionId;
}
public String getSessionId() {
return this.sessionId;
}
@Override
public String toString() {
return "StompMessage [" + command + ", headers=" + this.headers + ", payload=" + new String(this.payload) + "]";
}
}

View File

@ -1,41 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.stomp;
import java.io.IOException;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public interface StompSession {
String getId();
/**
* TODO...
* <p>
* If the message is a STOMP ERROR message, the session will also be closed.
*/
void sendMessage(StompMessage message) throws IOException;
/**
* Register a task to be invoked if the underlying connection is closed.
*/
void registerConnectionClosedTask(Runnable task);
}

View File

@ -15,14 +15,14 @@
*/
package org.springframework.web.messaging.stomp.socket;
import java.nio.charset.Charset;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.util.Assert;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.stomp.StompMessage;
import org.springframework.web.messaging.stomp.StompSession;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.support.StompMessageConverter;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
@ -36,57 +36,65 @@ import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter;
*/
public abstract class AbstractStompWebSocketHandler extends TextWebSocketHandlerAdapter {
private final StompMessageConverter messageConverter = new StompMessageConverter();
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
private final Map<String, WebSocketStompSession> sessions = new ConcurrentHashMap<String, WebSocketStompSession>();
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
WebSocketStompSession stompSession = new WebSocketStompSession(session, this.messageConverter);
this.sessions.put(session.getId(), stompSession);
public StompMessageConverter getStompMessageConverter() {
return this.stompMessageConverter;
}
protected WebSocketSession getWebSocketSession(String sessionId) {
return this.sessions.get(sessionId);
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) {
StompSession stompSession = this.sessions.get(session.getId());
Assert.notNull(stompSession, "No STOMP session for WebSocket session id=" + session.getId());
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.sessions.put(session.getId(), session);
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) {
try {
StompMessage stompMessage = this.messageConverter.toStompMessage(message.getPayload());
stompMessage.setSessionId(stompSession.getId());
String payload = textMessage.getPayload();
Message<byte[]> message = this.stompMessageConverter.toMessage(payload, session.getId());
// TODO: validate size limits
// http://stomp.github.io/stomp-specification-1.2.html#Size_Limits
handleStompMessage(stompSession, stompMessage);
handleStompMessage(session, message);
// TODO: send RECEIPT message if incoming message has "receipt" header
// http://stomp.github.io/stomp-specification-1.2.html#Header_receipt
}
catch (Throwable error) {
StompHeaders headers = new StompHeaders();
headers.setMessage(error.getMessage());
StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers);
try {
stompSession.sendMessage(errorMessage);
}
catch (Throwable t) {
// ignore
}
sendErrorMessage(session, error);
}
}
protected abstract void handleStompMessage(StompSession stompSession, StompMessage stompMessage);
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR);
stompHeaders.setMessage(error.getMessage());
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.getMessageHeaders());
byte[] bytes = this.stompMessageConverter.fromMessage(errorMessage);
try {
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
catch (Throwable t) {
// ignore
}
}
protected abstract void handleStompMessage(WebSocketSession session, Message<byte[]> message);
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
WebSocketStompSession stompSession = this.sessions.remove(session.getId());
if (stompSession != null) {
stompSession.handleConnectionClosed();
}
this.sessions.remove(session.getId());
}
}

View File

@ -16,32 +16,28 @@
package org.springframework.web.messaging.stomp.socket;
import java.io.IOException;
import java.util.ArrayList;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.CollectionUtils;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.event.EventConsumer;
import org.springframework.web.messaging.event.EventRegistration;
import org.springframework.web.messaging.service.AbstractMessageService;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompException;
import org.springframework.web.messaging.stomp.StompConversionException;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.stomp.StompMessage;
import org.springframework.web.messaging.stomp.StompSession;
import org.springframework.web.messaging.stomp.support.StompHeaderMapper;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* @author Gary Russell
@ -57,14 +53,59 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
private MessageConverter payloadConverter = new CompositeMessageConverter(null);
private final StompHeaderMapper headerMapper = new StompHeaderMapper();
private Map<String, List<EventRegistration>> registrationsBySession =
new ConcurrentHashMap<String, List<EventRegistration>>();
public DefaultStompWebSocketHandler(EventBus eventBus) {
this.eventBus = eventBus;
this.eventBus.registerConsumer(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY,
new EventConsumer<Message<?>>() {
@Override
public void accept(Message<?> message) {
StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false);
if (stompHeaders.getProtocolMessageType() == null) {
stompHeaders.setProtocolMessageType(StompCommand.MESSAGE);
}
if (StompCommand.CONNECTED.equals(stompHeaders.getStompCommand())) {
// Ignore for now since we already sent it
return;
}
String sessionId = stompHeaders.getSessionId();
WebSocketSession session = getWebSocketSession(sessionId);
byte[] payload;
try {
MediaType contentType = stompHeaders.getContentType();
payload = payloadConverter.convertToPayload(message.getPayload(), contentType);
}
catch (Exception e) {
logger.error("Failed to send " + message, e);
return;
}
try {
Map<String, Object> messageHeaders = stompHeaders.getMessageHeaders();
Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, messageHeaders);
byte[] bytes = getStompMessageConverter().fromMessage(byteMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
catch (Throwable t) {
sendErrorMessage(session, t);
}
finally {
if (StompCommand.ERROR.equals(stompHeaders.getStompCommand())) {
try {
session.close(CloseStatus.PROTOCOL_ERROR);
}
catch (IOException e) {
}
}
}
}
});
}
@ -72,252 +113,83 @@ public class DefaultStompWebSocketHandler extends AbstractStompWebSocketHandler
this.payloadConverter = new CompositeMessageConverter(converters);
}
public void handleStompMessage(final StompSession session, StompMessage stompMessage) {
public void handleStompMessage(final WebSocketSession session, Message<byte[]> message) {
if (logger.isTraceEnabled()) {
logger.trace("Processing: " + stompMessage);
logger.trace("Processing: " + message);
}
try {
MessageType messageType = MessageType.OTHER;
String replyKey = null;
StompCommand command = stompMessage.getCommand();
if (StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command)) {
session.registerConnectionClosedTask(new ConnectionClosedTask(session));
messageType = MessageType.CONNECT;
replyKey = handleConnect(session, stompMessage);
StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), true);
MessageType messageType = stompHeaders.getMessageType();
if (MessageType.CONNECT.equals(messageType)) {
handleConnect(session, message);
}
else if (StompCommand.SEND.equals(command)) {
messageType = MessageType.MESSAGE;
handleSend(session, stompMessage);
else if (MessageType.MESSAGE.equals(messageType)) {
handleMessage(message);
}
else if (StompCommand.SUBSCRIBE.equals(command)) {
messageType = MessageType.SUBSCRIBE;
replyKey = handleSubscribe(session, stompMessage);
else if (MessageType.SUBSCRIBE.equals(messageType)) {
handleSubscribe(message);
}
else if (StompCommand.UNSUBSCRIBE.equals(command)) {
messageType = MessageType.UNSUBSCRIBE;
handleUnsubscribe(session, stompMessage);
else if (MessageType.UNSUBSCRIBE.equals(messageType)) {
handleUnsubscribe(message);
}
else if (StompCommand.DISCONNECT.equals(command)) {
messageType = MessageType.DISCONNECT;
handleDisconnect(session, stompMessage);
else if (MessageType.DISCONNECT.equals(messageType)) {
handleDisconnect(message);
}
else {
sendErrorMessage(session, "Invalid STOMP command " + command);
return;
}
Map<String, Object> messageHeaders = this.headerMapper.toMessageHeaders(stompMessage.getHeaders());
messageHeaders.put("messageType", messageType);
if (replyKey != null) {
messageHeaders.put(MessageHeaders.REPLY_CHANNEL, replyKey);
}
messageHeaders.put("stompCommand", command);
messageHeaders.put("sessionId", session.getId());
Message<byte[]> genericMessage = new GenericMessage<byte[]>(stompMessage.getPayload(), messageHeaders);
if (logger.isTraceEnabled()) {
logger.trace("Sending notification: " + genericMessage);
}
this.eventBus.send(AbstractMessageService.MESSAGE_KEY, genericMessage);
this.eventBus.send(AbstractMessageService.CLIENT_TO_SERVER_MESSAGE_KEY, message);
}
catch (Throwable t) {
handleError(session, t);
logger.error("Terminating STOMP session due to failure to send message: ", t);
sendErrorMessage(session, t);
}
}
private void handleError(final StompSession session, Throwable t) {
logger.error("Terminating STOMP session due to failure to send message: ", t);
sendErrorMessage(session, t.getMessage());
if (removeSubscriptions(session)) {
// TODO: send error event including exception info
}
}
protected void handleConnect(final WebSocketSession session, Message<byte[]> message) throws IOException {
private void sendErrorMessage(StompSession session, String errorText) {
StompHeaders headers = new StompHeaders();
headers.setMessage(errorText);
StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers);
try {
session.sendMessage(errorMessage);
}
catch (Throwable t) {
// ignore
}
}
StompHeaders connectStompHeaders = new StompHeaders(message.getHeaders(), true);
StompHeaders connectedStompHeaders = new StompHeaders(StompCommand.CONNECTED);
protected String handleConnect(final StompSession session, StompMessage stompMessage) throws IOException {
StompHeaders headers = new StompHeaders();
Set<String> acceptVersions = stompMessage.getHeaders().getAcceptVersion();
Set<String> acceptVersions = connectStompHeaders.getAcceptVersion();
if (acceptVersions.contains("1.2")) {
headers.setVersion("1.2");
connectedStompHeaders.setAcceptVersion("1.2");
}
else if (acceptVersions.contains("1.1")) {
headers.setVersion("1.1");
connectedStompHeaders.setAcceptVersion("1.1");
}
else if (acceptVersions.isEmpty()) {
// 1.0
}
else {
throw new StompException("Unsupported version '" + acceptVersions + "'");
throw new StompConversionException("Unsupported version '" + acceptVersions + "'");
}
headers.setHeartbeat(0,0); // TODO
headers.setId(session.getId());
connectedStompHeaders.setHeartbeat(0,0); // TODO
// TODO: security
session.sendMessage(new StompMessage(StompCommand.CONNECTED, headers));
String replyKey = "relay-message" + session.getId();
EventRegistration registration = this.eventBus.registerConsumer(replyKey,
new EventConsumer<StompMessage>() {
@Override
public void accept(StompMessage message) {
try {
if (StompCommand.CONNECTED.equals(message.getCommand())) {
// TODO: skip for now (we already sent CONNECTED)
return;
}
if (logger.isTraceEnabled()) {
logger.trace("Relaying back to client: " + message);
}
session.sendMessage(message);
}
catch (Throwable t) {
handleError(session, t);
}
}
});
addRegistration(session, registration);
return replyKey;
Message<byte[]> connectedMessage = new GenericMessage<byte[]>(new byte[0], connectedStompHeaders.getMessageHeaders());
byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
protected String handleSubscribe(final StompSession session, StompMessage message) {
final String subscriptionId = message.getHeaders().getId();
String replyKey = getSubscriptionReplyKey(session, subscriptionId);
// TODO: extract and remember "ack" mode
// http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE_ack_Header
if (logger.isTraceEnabled()) {
logger.trace("Adding subscription, key=" + replyKey);
}
EventRegistration registration = this.eventBus.registerConsumer(replyKey, new EventConsumer<Message<?>>() {
@Override
public void accept(Message<?> replyMessage) {
StompHeaders headers = new StompHeaders();
headers.setSubscription(subscriptionId);
headerMapper.fromMessageHeaders(replyMessage.getHeaders(), headers);
byte[] payload;
try {
MediaType contentType = headers.getContentType();
payload = payloadConverter.convertToPayload(replyMessage.getPayload(), contentType);
}
catch (Exception e) {
logger.error("Failed to send " + replyMessage, e);
return;
}
try {
StompMessage stompMessage = new StompMessage(StompCommand.MESSAGE, headers, payload);
session.sendMessage(stompMessage);
}
catch (Throwable t) {
handleError(session, t);
}
}
});
addRegistration(session, registration);
return replyKey;
protected void handleSubscribe(Message<byte[]> message) {
// TODO: need a way to communicate back if subscription was successfully created or
// not in which case an ERROR should be sent back and close the connection
// http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE
}
private String getSubscriptionReplyKey(StompSession session, String subscriptionId) {
return StompCommand.SUBSCRIBE + ":" + session.getId() + ":" + subscriptionId;
protected void handleUnsubscribe(Message<byte[]> message) {
}
private void addRegistration(StompSession session, EventRegistration registration) {
String sessionId = session.getId();
List<EventRegistration> list = this.registrationsBySession.get(sessionId);
if (list == null) {
list = new ArrayList<EventRegistration>();
this.registrationsBySession.put(sessionId, list);
}
list.add(registration);
protected void handleMessage(Message<byte[]> stompMessage) {
}
protected void handleUnsubscribe(StompSession session, StompMessage message) {
cancelRegistration(session, message.getHeaders().getId());
protected void handleDisconnect(Message<byte[]> stompMessage) {
}
private void cancelRegistration(StompSession session, String subscriptionId) {
String key = getSubscriptionReplyKey(session, subscriptionId);
List<EventRegistration> list = this.registrationsBySession.get(session.getId());
for (EventRegistration registration : list) {
if (registration.getRegistrationKey().equals(key)) {
if (logger.isDebugEnabled()) {
logger.debug("Cancelling subscription, key=" + key);
}
list.remove(registration);
registration.cancel();
}
}
}
protected void handleSend(StompSession session, StompMessage stompMessage) {
}
protected void handleDisconnect(StompSession session, StompMessage stompMessage) {
removeSubscriptions(session);
}
private boolean removeSubscriptions(StompSession session) {
String sessionId = session.getId();
List<EventRegistration> registrations = this.registrationsBySession.remove(sessionId);
if (CollectionUtils.isEmpty(registrations)) {
return false;
}
if (logger.isTraceEnabled()) {
logger.trace("Cancelling " + registrations.size() + " subscriptions for session=" + sessionId);
}
for (EventRegistration registration : registrations) {
registration.cancel();
}
return true;
}
private final class ConnectionClosedTask implements Runnable {
private final StompSession session;
private ConnectionClosedTask(StompSession session) {
this.session = session;
}
@Override
public void run() {
removeSubscriptions(session);
if (logger.isTraceEnabled()) {
logger.trace("Sending notification for closed connection: " + session.getId());
}
eventBus.send(AbstractMessageService.CLIENT_CONNECTION_CLOSED_KEY, session.getId());
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
eventBus.send(AbstractMessageService.CLIENT_CONNECTION_CLOSED_KEY, session.getId());
}
}

View File

@ -1,92 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.stomp.socket;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.springframework.util.Assert;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompMessage;
import org.springframework.web.messaging.stomp.StompSession;
import org.springframework.web.messaging.stomp.support.StompMessageConverter;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public class WebSocketStompSession implements StompSession {
private final String id;
private WebSocketSession webSocketSession;
private final StompMessageConverter messageConverter;
private final List<Runnable> connectionClosedTasks = new ArrayList<Runnable>();
public WebSocketStompSession(WebSocketSession webSocketSession, StompMessageConverter messageConverter) {
Assert.notNull(webSocketSession, "webSocketSession is required");
this.id = webSocketSession.getId();
this.webSocketSession = webSocketSession;
this.messageConverter = messageConverter;
}
@Override
public String getId() {
return this.id;
}
@Override
public void sendMessage(StompMessage message) throws IOException {
Assert.notNull(this.webSocketSession, "Cannot send message without active session");
try {
byte[] bytes = this.messageConverter.fromStompMessage(message);
this.webSocketSession.sendMessage(new TextMessage(new String(bytes, StompMessage.CHARSET)));
}
finally {
if (StompCommand.ERROR.equals(message.getCommand())) {
this.webSocketSession.close(CloseStatus.PROTOCOL_ERROR);
this.webSocketSession = null;
}
}
}
public void registerConnectionClosedTask(Runnable task) {
this.connectionClosedTasks.add(task);
}
public void handleConnectionClosed() {
for (Runnable task : this.connectionClosedTasks) {
try {
task.run();
}
catch (Throwable t) {
// ignore
}
}
}
}

View File

@ -31,6 +31,7 @@ import javax.net.SocketFactory;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.MediaType;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
@ -38,7 +39,6 @@ import org.springframework.web.messaging.event.EventBus;
import org.springframework.web.messaging.service.AbstractMessageService;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.stomp.StompMessage;
import reactor.util.Assert;
@ -57,8 +57,6 @@ public class RelayStompService extends AbstractMessageService {
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
private final StompHeaderMapper stompHeaderMapper = new StompHeaderMapper();
public RelayStompService(EventBus eventBus, TaskExecutor executor) {
super(eventBus);
@ -84,8 +82,7 @@ public class RelayStompService extends AbstractMessageService {
forwardMessage(message, StompCommand.CONNECT);
String replyTo = (String) message.getHeaders().getReplyChannel();
RelayReadTask readTask = new RelayReadTask(sessionId, replyTo, session);
RelayReadTask readTask = new RelayReadTask(sessionId, session);
this.taskExecutor.execute(readTask);
}
catch (Throwable t) {
@ -96,23 +93,25 @@ public class RelayStompService extends AbstractMessageService {
private void forwardMessage(Message<?> message, StompCommand command) {
String sessionId = (String) message.getHeaders().get("sessionId");
StompHeaders stompHeaders = new StompHeaders(message.getHeaders(), false);
String sessionId = stompHeaders.getSessionId();
RelaySession session = RelayStompService.this.relaySessions.get(sessionId);
Assert.notNull(session, "RelaySession not found");
try {
StompHeaders stompHeaders = new StompHeaders();
this.stompHeaderMapper.fromMessageHeaders(message.getHeaders(), stompHeaders);
if (stompHeaders.getProtocolMessageType() == null) {
stompHeaders.setProtocolMessageType(StompCommand.SEND);
}
MediaType contentType = stompHeaders.getContentType();
byte[] payload = this.payloadConverter.convertToPayload(message.getPayload(), contentType);
StompMessage stompMessage = new StompMessage(command, stompHeaders, payload);
Message<byte[]> byteMessage = new GenericMessage<byte[]>(payload, stompHeaders.getMessageHeaders());
if (logger.isTraceEnabled()) {
logger.trace("Forwarding: " + stompMessage);
logger.trace("Forwarding: " + byteMessage);
}
byte[] bytesToWrite = this.stompMessageConverter.fromStompMessage(stompMessage);
session.getOutputStream().write(bytesToWrite);
byte[] bytes = this.stompMessageConverter.fromMessage(byteMessage);
session.getOutputStream().write(bytes);
session.getOutputStream().flush();
}
catch (Exception ex) {
@ -200,13 +199,12 @@ public class RelayStompService extends AbstractMessageService {
private final class RelayReadTask implements Runnable {
private final String stompSessionId;
private final String replyTo;
private final String sessionId;
private final RelaySession session;
private RelayReadTask(String stompSessionId, String replyTo, RelaySession session) {
this.stompSessionId = stompSessionId;
this.replyTo = replyTo;
private RelayReadTask(String sessionId, RelaySession session) {
this.sessionId = sessionId;
this.session = session;
}
@ -221,28 +219,28 @@ public class RelayStompService extends AbstractMessageService {
}
else if (b == 0x00) {
byte[] bytes = out.toByteArray();
StompMessage message = RelayStompService.this.stompMessageConverter.toStompMessage(bytes);
getEventBus().send(this.replyTo, message);
Message<byte[]> message = stompMessageConverter.toMessage(bytes, sessionId);
getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, message);
out.reset();
}
else {
out.write(b);
}
}
logger.debug("Socket closed, STOMP session=" + stompSessionId);
logger.debug("Socket closed, STOMP session=" + sessionId);
sendErrorMessage("Lost connection");
}
catch (IOException e) {
logger.error("Socket error: " + e.getMessage());
clearRelaySession(stompSessionId);
clearRelaySession(sessionId);
}
}
private void sendErrorMessage(String message) {
StompHeaders headers = new StompHeaders();
headers.setMessage(message);
StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers);
getEventBus().send(this.replyTo, errorMessage);
StompHeaders stompHeaders = new StompHeaders(StompCommand.ERROR);
stompHeaders.setMessage(message);
Message<byte[]> errorMessage = new GenericMessage<byte[]>(new byte[0], stompHeaders.getMessageHeaders());
getEventBus().send(AbstractMessageService.SERVER_TO_CLIENT_MESSAGE_KEY, errorMessage);
}
}

View File

@ -1,102 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.stomp.support;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.MediaType;
import org.springframework.messaging.MessageHeaders;
import org.springframework.web.messaging.stomp.StompHeaders;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompHeaderMapper {
private static Log logger = LogFactory.getLog(StompHeaderMapper.class);
private static final String[][] stompHeaderNames;
static {
stompHeaderNames = new String[2][StompHeaders.STANDARD_HEADER_NAMES.size()];
for (int i=0 ; i < StompHeaders.STANDARD_HEADER_NAMES.size(); i++) {
stompHeaderNames[0][i] = StompHeaders.STANDARD_HEADER_NAMES.get(i);
stompHeaderNames[1][i] = "stomp." + StompHeaders.STANDARD_HEADER_NAMES.get(i);
}
}
public Map<String, Object> toMessageHeaders(StompHeaders stompHeaders) {
Map<String, Object> headers = new HashMap<String, Object>();
// prefixed STOMP headers
for (int i=0; i < stompHeaderNames[0].length; i++) {
String header = stompHeaderNames[0][i];
if (stompHeaders.containsKey(header)) {
String prefixedHeader = stompHeaderNames[1][i];
headers.put(prefixedHeader, stompHeaders.getFirst(header));
}
}
// for generic use (not-prefixed)
if (stompHeaders.getDestination() != null) {
headers.put("destination", stompHeaders.getDestination());
}
if (stompHeaders.getContentType() != null) {
headers.put("content-type", stompHeaders.getContentType());
}
return headers;
}
public void fromMessageHeaders(MessageHeaders messageHeaders, StompHeaders stompHeaders) {
// prefixed STOMP headers
for (int i=0; i < stompHeaderNames[0].length; i++) {
String prefixedHeader = stompHeaderNames[1][i];
if (messageHeaders.containsKey(prefixedHeader)) {
String header = stompHeaderNames[0][i];
stompHeaders.add(header, (String) messageHeaders.get(prefixedHeader));
}
}
// generic (not prefixed)
String destination = (String) messageHeaders.get("destination");
if (destination != null) {
stompHeaders.setDestination(destination);
}
Object contentType = messageHeaders.get("content-type");
if (contentType != null) {
if (contentType instanceof String) {
stompHeaders.setContentType(MediaType.valueOf((String) contentType));
}
else if (contentType instanceof MediaType) {
stompHeaders.setContentType((MediaType) contentType);
}
else {
logger.warn("Invalid contentType class: " + contentType.getClass());
}
}
}
}

View File

@ -17,96 +17,149 @@ package org.springframework.web.messaging.stomp.support;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.List;
import java.nio.charset.Charset;
import java.util.Map;
import java.util.Map.Entry;
import org.springframework.messaging.GenericMessage;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompException;
import org.springframework.web.messaging.stomp.StompConversionException;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.stomp.StompMessage;
/**
* @author Gary Russell
* @author Rossen Stoyanchev
* @since 4.0
*
*/
public class StompMessageConverter {
private static final Charset STOMP_CHARSET = Charset.forName("UTF-8");
public static final byte LF = 0x0a;
public static final byte CR = 0x0d;
private static final byte COLON = ':';
/**
* @param bytes a complete STOMP message (without the trailing 0x00).
* @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String.
*/
public StompMessage toStompMessage(Object stomp) {
Assert.state(stomp instanceof String || stomp instanceof byte[], "'stomp' must be String or byte[]");
byte[] stompBytes = null;
if (stomp instanceof String) {
stompBytes = ((String) stomp).getBytes(StompMessage.CHARSET);
public Message<byte[]> toMessage(Object stompContent, String sessionId) {
byte[] byteContent = null;
if (stompContent instanceof String) {
byteContent = ((String) stompContent).getBytes(STOMP_CHARSET);
}
else if (stompContent instanceof byte[]){
byteContent = (byte[]) stompContent;
}
else {
stompBytes = (byte[]) stomp;
throw new IllegalArgumentException(
"stompContent is neither String nor byte[]: " + stompContent.getClass());
}
int totalLength = stompBytes.length;
if (stompBytes[totalLength-1] == 0) {
int totalLength = byteContent.length;
if (byteContent[totalLength-1] == 0) {
totalLength--;
}
int payloadIndex = findPayloadStart(stompBytes);
int payloadIndex = findIndexOfPayload(byteContent);
if (payloadIndex == 0) {
throw new StompException("No command found");
throw new StompConversionException("No command found");
}
String headerString = new String(stompBytes, 0, payloadIndex, StompMessage.CHARSET);
Parser parser = new Parser(headerString);
StompHeaders headers = new StompHeaders();
String headerContent = new String(byteContent, 0, payloadIndex, STOMP_CHARSET);
Parser parser = new Parser(headerContent);
// TODO: validate command and whether a payload is allowed
StompCommand command = StompCommand.valueOf(parser.nextToken(LF).trim());
Assert.notNull(command, "No command found");
StompHeaders stompHeaders = new StompHeaders(command);
stompHeaders.setSessionId(sessionId);
while (parser.hasNext()) {
String header = parser.nextToken(COLON);
if (header != null) {
if (parser.hasNext()) {
String value = parser.nextToken(LF);
headers.add(header, value);
stompHeaders.getRawHeaders().put(header, value);
}
else {
throw new StompException("Parse exception for " + headerString);
throw new StompConversionException("Parse exception for " + headerContent);
}
}
}
byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(stompBytes, payloadIndex, payload, 0, totalLength - payloadIndex);
return new StompMessage(command, headers, payload);
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
stompHeaders.updateMessageHeaders();
return createMessage(command, stompHeaders.getMessageHeaders(), payload);
}
public byte[] fromStompMessage(StompMessage message) {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
StompHeaders headers = message.getHeaders();
StompCommand command = message.getCommand();
private int findIndexOfPayload(byte[] bytes) {
int i;
// ignore any leading EOL from the previous message
for (i = 0; i < bytes.length; i++) {
if (bytes[i] != '\n' && bytes[i] != '\r') {
break;
}
bytes[i] = ' ';
}
int index = 0;
for (; i < bytes.length - 1; i++) {
if (bytes[i] == LF && bytes[i+1] == LF) {
index = i + 2;
break;
}
if ((i < (bytes.length - 3)) &&
(bytes[i] == CR && bytes[i+1] == LF && bytes[i+2] == CR && bytes[i+3] == LF)) {
index = i + 4;
break;
}
}
if (i >= bytes.length) {
throw new StompConversionException("No end of headers found");
}
return index;
}
protected Message<byte[]> createMessage(StompCommand command, Map<String, Object> headers, byte[] payload) {
return new GenericMessage<byte[]>(payload, headers);
}
public byte[] fromMessage(Message<byte[]> message) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, false);
stompHeaders.updateRawHeaders();
try {
outputStream.write(command.toString().getBytes("UTF-8"));
outputStream.write(LF);
for (Entry<String, List<String>> entry : headers.entrySet()) {
out.write(stompHeaders.getStompCommand().toString().getBytes("UTF-8"));
out.write(LF);
for (Entry<String, String> entry : stompHeaders.getRawHeaders().entrySet()) {
String key = entry.getKey();
key = replaceAllOutbound(key);
for (String value : entry.getValue()) {
outputStream.write(key.getBytes("UTF-8"));
outputStream.write(COLON);
value = replaceAllOutbound(value);
outputStream.write(value.getBytes("UTF-8"));
outputStream.write(LF);
}
String value = entry.getValue();
out.write(key.getBytes("UTF-8"));
out.write(COLON);
value = replaceAllOutbound(value);
out.write(value.getBytes("UTF-8"));
out.write(LF);
}
outputStream.write(LF);
outputStream.write(message.getPayload());
outputStream.write(0);
return outputStream.toByteArray();
out.write(LF);
out.write(message.getPayload());
out.write(0);
return out.toByteArray();
}
catch (IOException e) {
throw new StompException("Failed to serialize " + message, e);
throw new StompConversionException("Failed to serialize " + message, e);
}
}
@ -117,33 +170,6 @@ public class StompMessageConverter {
.replaceAll("\r", "\\\\r");
}
private int findPayloadStart(byte[] bytes) {
int i;
// ignore any leading EOL from the previous message
for (i = 0; i < bytes.length; i++) {
if (bytes[i] != '\n' && bytes[i] != '\r' ) {
break;
}
bytes[i] = ' ';
}
int payloadOffset = 0;
for (; i < bytes.length - 1; i++) {
if ((bytes[i] == LF && bytes[i+1] == LF)) {
payloadOffset = i + 2;
break;
}
if (i < bytes.length - 3 &&
(bytes[i] == CR && bytes[i+1] == LF &&
bytes[i+2] == CR && bytes[i+3] == LF)) {
payloadOffset = i + 4;
break;
}
}
if (i >= bytes.length) {
throw new StompException("No end of headers found");
}
return payloadOffset;
}
private class Parser {
@ -177,7 +203,7 @@ public class StompMessageConverter {
return null;
}
else {
throw new StompException("No delimiter found at offset " + offset + " in " + this.content);
throw new StompConversionException("No delimiter found at offset " + offset + " in " + this.content);
}
}
int escapeAt = this.content.indexOf('\\', this.offset);
@ -192,7 +218,7 @@ public class StompMessageConverter {
.replaceAll("\\\\\\\\", "\\\\");
}
else {
throw new StompException("Invalid escape sequence \\" + escaped);
throw new StompConversionException("Invalid escape sequence \\" + escaped);
}
}
int length = token.length();

View File

@ -14,25 +14,29 @@
* limitations under the License.
*/
package org.springframework.web.messaging.stomp.socket;
package org.springframework.web.messaging.support;
import org.springframework.web.messaging.stomp.StompMessage;
import java.util.Map;
import org.springframework.messaging.GenericMessage;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public interface StompMessageInterceptor {
public class DestinationMessage<T> extends GenericMessage<T> {
boolean handleConnect(StompMessage message);
boolean handleSubscribe(StompMessage message);
public DestinationMessage(T payload, Map<String, Object> headers) {
super(payload, headers);
}
boolean handleUnsubscribe(StompMessage message);
public DestinationMessage(T payload) {
super(payload);
}
StompMessage handleSend(StompMessage message);
void handleDisconnect();
}

View File

@ -0,0 +1,138 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.messaging.stomp.support;
import java.util.Collections;
import org.junit.Before;
import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.stomp.StompCommand;
import static org.junit.Assert.*;
/**
* @author Gary Russell
* @author Rossen Stoyanchev
*/
public class StompMessageConverterTests {
private StompMessageConverter converter;
@Before
public void setup() {
this.converter = new StompMessageConverter();
}
@Test
public void connectFrame() throws Exception {
String accept = "accept-version:1.1\n";
String host = "host:github.org\n";
String frame = "\n\n\nCONNECT\n" + accept + host + "\n";
Message<byte[]> message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
assertEquals(6, stompHeaders.getMessageHeaders().size());
assertEquals(MessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand());
assertEquals("session-123", stompHeaders.getSessionId());
assertNotNull(messageHeaders.get(MessageHeaders.ID));
assertNotNull(messageHeaders.get(MessageHeaders.TIMESTAMP));
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getRawHeaders().get("host"));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
@Test
public void connectWithEscapes() throws Exception {
String accept = "accept-version:1.1\n";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String frame = "CONNECT\n" + accept + host + "\n";
Message<byte[]> message = this.converter.toMessage(frame.getBytes("UTF-8"), "session-123");
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt"));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
@Test
public void connectCR12() throws Exception {
String accept = "accept-version:1.2\n";
String host = "host:github.org\n";
String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
Message<byte[]> message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getRawHeaders().get("host"));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
@Test
public void connectWithEscapesAndCR12() throws Exception {
String accept = "accept-version:1.1\n";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
Message<byte[]> message = this.converter.toMessage(test.getBytes("UTF-8"), "session-123");
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = new StompHeaders(messageHeaders, true);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getRawHeaders().get("ho:\ns\rt"));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
}