Refactor PubSubHeaders, StompHeaders, MessageBuilder

Rename to PubSubHeaderAccessor and StompHeaderAccessor
Move the renamed classes to support packages

Remove fromPayloadAndHeaders from MessageBuilder, just use
withPayload(..).copyHeaders(..) instead.
This commit is contained in:
Rossen Stoyanchev 2013-06-19 11:30:01 -04:00
parent 811bb1b0c9
commit 5cfc59d76d
14 changed files with 111 additions and 131 deletions

View File

@ -93,18 +93,6 @@ public final class MessageBuilder<T> {
return builder;
}
/**
* Create a builder for a new {@link Message} instance with the provided payload and
* headers.
*
* @param payload the payload for the new message
* @param headers the headers to use
*/
public static <T> MessageBuilder<T> fromPayloadAndHeaders(T payload, Map<String, Object> headers) {
MessageBuilder<T> builder = new MessageBuilder<T>(payload, headers);
return builder;
}
/**
* Create a builder for a new {@link Message} instance with the provided payload.
*

View File

@ -30,7 +30,7 @@ import org.springframework.util.AntPathMatcher;
import org.springframework.util.CollectionUtils;
import org.springframework.util.PathMatcher;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor;
/**
@ -81,7 +81,7 @@ public abstract class AbstractPubSubMessageHandler<M extends Message> implements
protected boolean isDestinationAllowed(M message) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message);
String destination = headers.getDestination();
if (destination == null) {
@ -117,7 +117,7 @@ public abstract class AbstractPubSubMessageHandler<M extends Message> implements
@Override
public final void handleMessage(M message) throws MessagingException {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message);
MessageType messageType = headers.getMessageType();
if (!canHandle(message, messageType)) {

View File

@ -29,9 +29,9 @@ import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubChannelRegistry;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor;
import reactor.core.Reactor;
import reactor.fn.Consumer;
@ -79,7 +79,7 @@ public class ReactorPubSubMessageHandler<M extends Message> extends AbstractPubS
logger.debug("Subscribe " + message);
}
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message);
String subscriptionId = headers.getSubscriptionId();
BroadcastingConsumer consumer = new BroadcastingConsumer(subscriptionId);
@ -108,10 +108,10 @@ public class ReactorPubSubMessageHandler<M extends Message> extends AbstractPubS
try {
// Convert to byte[] payload before the fan-out
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message);
byte[] payload = payloadConverter.convertToPayload(message.getPayload(), headers.getContentType());
@SuppressWarnings("unchecked")
M m = (M) MessageBuilder.fromPayloadAndHeaders(payload, message.getHeaders()).build();
M m = (M) MessageBuilder.withPayload(payload).copyHeaders(message.getHeaders()).build();
this.reactor.notify(getPublishKey(headers.getDestination()), Event.wrap(m));
}
@ -122,7 +122,7 @@ public class ReactorPubSubMessageHandler<M extends Message> extends AbstractPubS
@Override
public void handleDisconnect(M message) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message);
removeSubscriptions(headers.getSessionId());
}
@ -151,12 +151,11 @@ public class ReactorPubSubMessageHandler<M extends Message> extends AbstractPubS
Message<?> sentMessage = event.getData();
PubSubHeaders clientHeaders = PubSubHeaders.fromMessageHeaders(sentMessage.getHeaders());
clientHeaders.setSubscriptionId(this.subscriptionId);
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(sentMessage);
headers.setSubscriptionId(this.subscriptionId);
@SuppressWarnings("unchecked")
M clientMessage = (M) MessageBuilder.fromPayloadAndHeaders(sentMessage.getPayload(),
clientHeaders.toMessageHeaders()).build();
M clientMessage = (M) MessageBuilder.withPayload(sentMessage.getPayload()).copyHeaders(headers.toHeaders()).build();
clientChannel.send(clientMessage);
}

View File

@ -39,11 +39,11 @@ import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils.MethodFilter;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubChannelRegistry;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.annotation.SubscribeEvent;
import org.springframework.web.messaging.annotation.UnsubscribeEvent;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.service.AbstractPubSubMessageHandler;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.method.HandlerMethodSelector;
@ -182,7 +182,7 @@ public class AnnotationPubSubMessageHandler<M extends Message> extends AbstractP
private void handleMessageInternal(final M message, Map<MappingInfo, HandlerMethod> handlerMethods) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message);
String destination = headers.getDestination();
HandlerMethod match = getHandlerMethod(destination, handlerMethods);

View File

@ -21,11 +21,11 @@ import java.util.List;
import org.springframework.core.MethodParameter;
import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.annotation.MessageBody;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConversionException;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor;
/**
@ -53,7 +53,7 @@ public class MessageBodyArgumentResolver<M extends Message> implements ArgumentR
Object arg = null;
MessageBody annot = parameter.getParameterAnnotation(MessageBody.class);
MediaType contentType = (MediaType) message.getHeaders().get(PubSubHeaders.CONTENT_TYPE);
MediaType contentType = (MediaType) message.getHeaders().get(PubSubHeaderAccesssor.CONTENT_TYPE);
if (annot == null || annot.required()) {
Class<?> sourceType = message.getPayload().getClass();

View File

@ -20,7 +20,7 @@ import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.util.Assert;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor;
import org.springframework.web.messaging.support.SessionMessageChannel;
@ -47,7 +47,7 @@ public class MessageChannelArgumentResolver<M extends Message> implements Argume
@Override
public Object resolveArgument(MethodParameter parameter, M message) throws Exception {
Assert.notNull(this.messageBrokerChannel, "messageBrokerChannel is required");
final String sessionId = PubSubHeaders.fromMessageHeaders(message.getHeaders()).getSessionId();
final String sessionId = PubSubHeaderAccesssor.wrap(message).getSessionId();
return new SessionMessageChannel<M>(this.messageBrokerChannel, sessionId);
}

View File

@ -21,7 +21,7 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor;
/**
@ -75,13 +75,13 @@ public class MessageReturnValueHandler<M extends Message> implements ReturnValue
protected M updateReturnMessage(M returnMessage, M message) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message);
String sessionId = headers.getSessionId();
String subscriptionId = headers.getSubscriptionId();
Assert.notNull(subscriptionId, "No subscription id: " + message);
PubSubHeaders returnHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders());
PubSubHeaderAccesssor returnHeaders = PubSubHeaderAccesssor.wrap(returnMessage);
returnHeaders.setSessionId(sessionId);
returnHeaders.setSubscriptionId(subscriptionId);
@ -89,13 +89,12 @@ public class MessageReturnValueHandler<M extends Message> implements ReturnValue
returnHeaders.setDestination(headers.getDestination());
}
Object payload = returnMessage.getPayload();
return createMessage(returnHeaders, payload);
return createMessage(returnHeaders, returnMessage.getPayload());
}
@SuppressWarnings("unchecked")
private M createMessage(PubSubHeaders returnHeaders, Object payload) {
return (M) MessageBuilder.fromPayloadAndHeaders(payload, returnHeaders.toMessageHeaders()).build();
private M createMessage(PubSubHeaderAccesssor returnHeaders, Object payload) {
return (M) MessageBuilder.withPayload(payload).copyHeaders(returnHeaders.toHeaders()).build();
}
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package org.springframework.web.messaging.stomp;
package org.springframework.web.messaging.stomp.support;
import java.util.Collections;
import java.util.HashMap;
@ -25,13 +25,13 @@ import java.util.concurrent.atomic.AtomicLong;
import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor;
/**
@ -39,13 +39,13 @@ import org.springframework.web.messaging.PubSubHeaders;
* STOMP-specific headers of an existing message.
* <p>
* Use one of the static factory method in this class, then call getters and setters, and
* at the end if necessary call {@link #toMessageHeaders()} to obtain the updated headers
* at the end if necessary call {@link #toHeaders()} to obtain the updated headers
* or call {@link #toStompMessageHeaders()} to obtain only the STOMP-specific headers.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompHeaders extends PubSubHeaders {
public class StompHeaderAccessor extends PubSubHeaderAccesssor {
public static final String STOMP_ID = "id";
@ -88,7 +88,7 @@ public class StompHeaders extends PubSubHeaders {
* A constructor for creating new STOMP message headers.
* This constructor is private. See factory methods in this sub-classes.
*/
private StompHeaders(StompCommand command, Map<String, List<String>> externalSourceHeaders) {
private StompHeaderAccessor(StompCommand command, Map<String, List<String>> externalSourceHeaders) {
super(command.getMessageType(), command, externalSourceHeaders);
this.headers = new HashMap<String, String>(4);
updateMessageHeaders();
@ -118,32 +118,32 @@ public class StompHeaders extends PubSubHeaders {
* constructor is protected. See factory methods in this class.
*/
@SuppressWarnings("unchecked")
private StompHeaders(MessageHeaders messageHeaders) {
super(messageHeaders);
this.headers = (messageHeaders.get(STOMP_HEADERS) != null) ?
(Map<String, String>) messageHeaders.get(STOMP_HEADERS) : new HashMap<String, String>(4);
private StompHeaderAccessor(Message<?> message) {
super(message);
this.headers = (message.getHeaders() .get(STOMP_HEADERS) != null) ?
(Map<String, String>) message.getHeaders().get(STOMP_HEADERS) : new HashMap<String, String>(4);
}
/**
* Create {@link StompHeaders} for a new {@link Message}.
* Create {@link StompHeaderAccessor} for a new {@link Message}.
*/
public static StompHeaders create(StompCommand command) {
return new StompHeaders(command, null);
public static StompHeaderAccessor create(StompCommand command) {
return new StompHeaderAccessor(command, null);
}
/**
* Create {@link StompHeaders} from the headers of an existing {@link Message}.
* Create {@link StompHeaderAccessor} from parsed STOP frame content.
*/
public static StompHeaders fromMessageHeaders(MessageHeaders messageHeaders) {
return new StompHeaders(messageHeaders);
public static StompHeaderAccessor create(StompCommand command, Map<String, List<String>> headers) {
return new StompHeaderAccessor(command, headers);
}
/**
* Create {@link StompHeaders} from parsed STOP frame content.
* Create {@link StompHeaderAccessor} from the headers of an existing {@link Message}.
*/
public static StompHeaders fromParsedFrame(StompCommand command, Map<String, List<String>> headers) {
return new StompHeaders(command, headers);
public static StompHeaderAccessor wrap(Message<?> message) {
return new StompHeaderAccessor(message);
}
@ -152,8 +152,8 @@ public class StompHeaders extends PubSubHeaders {
* updates made via setters.
*/
@Override
public Map<String, Object> toMessageHeaders() {
Map<String, Object> result = super.toMessageHeaders();
public Map<String, Object> toHeaders() {
Map<String, Object> result = super.toHeaders();
if (isModified()) {
result.put(STOMP_HEADERS, this.headers);
}

View File

@ -22,14 +22,12 @@ import java.util.List;
import java.util.Map.Entry;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompConversionException;
import org.springframework.web.messaging.stomp.StompHeaders;
/**
@ -96,7 +94,7 @@ public class StompMessageConverter<M extends Message> {
}
}
StompHeaders stompHeaders = StompHeaders.fromParsedFrame(command, headers);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers);
stompHeaders.setSessionId(sessionId);
byte[] payload = new byte[totalLength - payloadIndex];
@ -106,8 +104,8 @@ public class StompMessageConverter<M extends Message> {
}
@SuppressWarnings("unchecked")
private M createMessage(StompHeaders stompHeaders, byte[] payload) {
return (M) MessageBuilder.fromPayloadAndHeaders(payload, stompHeaders.toMessageHeaders()).build();
private M createMessage(StompHeaderAccessor stompHeaders, byte[] payload) {
return (M) MessageBuilder.withPayload(payload).copyHeaders(stompHeaders.toHeaders()).build();
}
private int findIndexOfPayload(byte[] bytes) {
@ -149,8 +147,7 @@ public class StompMessageConverter<M extends Message> {
}
ByteArrayOutputStream out = new ByteArrayOutputStream();
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
try {
out.write(stompHeaders.getStompCommand().toString().getBytes("UTF-8"));

View File

@ -33,12 +33,11 @@ import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubChannelRegistry;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.service.AbstractPubSubMessageHandler;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor;
import reactor.core.Environment;
import reactor.core.Promise;
@ -97,7 +96,7 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@Override
public void handleConnect(M message) {
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
String sessionId = stompHeaders.getSessionId();
if (sessionId == null) {
logger.error("No sessionId in message " + message);
@ -124,7 +123,7 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@Override
public void handleDisconnect(M message) {
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
if (stompHeaders.getStompCommand() != null) {
forwardMessage(message, StompCommand.DISCONNECT);
}
@ -137,14 +136,14 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
@Override
public void handleOther(M message) {
StompCommand command = (StompCommand) message.getHeaders().get(PubSubHeaders.PROTOCOL_MESSAGE_TYPE);
StompCommand command = (StompCommand) message.getHeaders().get(PubSubHeaderAccesssor.PROTOCOL_MESSAGE_TYPE);
Assert.notNull(command, "Expected STOMP command: " + message.getHeaders());
forwardMessage(message, command);
}
private void forwardMessage(M message, StompCommand command) {
StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setStompCommandIfNotSet(command);
String sessionId = headers.getSessionId();
@ -174,9 +173,10 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
private final Object monitor = new Object();
private boolean isConnected = false;
private volatile boolean isConnected = false;
public RelaySession(final M message, final StompHeaders stompHeaders) {
public RelaySession(final M message, final StompHeaderAccessor stompHeaders) {
Assert.notNull(message, "message is required");
Assert.notNull(stompHeaders, "stompHeaders is required");
@ -222,7 +222,7 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
logger.trace("Reading message " + message);
}
StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (StompCommand.CONNECTED == headers.getStompCommand()) {
synchronized(this.monitor) {
this.isConnected = true;
@ -240,15 +240,15 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
}
private void sendError(String sessionId, String errorText) {
StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR);
stompHeaders.setSessionId(sessionId);
stompHeaders.setMessage(errorText);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setSessionId(sessionId);
headers.setMessage(errorText);
@SuppressWarnings("unchecked")
M errorMessage = (M) MessageBuilder.fromPayloadAndHeaders(new byte[0], stompHeaders.toMessageHeaders()).build();
M errorMessage = (M) MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toHeaders()).build();
clientChannel.send(errorMessage);
}
public void forward(M message, StompHeaders headers) {
public void forward(M message, StompHeaderAccessor headers) {
if (!this.isConnected) {
synchronized(this.monitor) {
@ -277,21 +277,21 @@ public class StompRelayPubSubMessageHandler<M extends Message> extends AbstractP
List<M> messages = new ArrayList<M>();
this.messageQueue.drainTo(messages);
for (Message<?> message : messages) {
StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
if (!forwardInternal(message, headers, connection)) {
return;
}
}
}
private boolean forwardInternal(Message<?> message, StompHeaders headers, TcpConnection<String, String> connection) {
private boolean forwardInternal(Message<?> message, StompHeaderAccessor headers, TcpConnection<String, String> connection) {
try {
headers.setStompCommandIfNotSet(StompCommand.SEND);
MediaType contentType = headers.getContentType();
byte[] payload = payloadConverter.convertToPayload(message.getPayload(), contentType);
@SuppressWarnings("unchecked")
M byteMessage = (M) MessageBuilder.fromPayloadAndHeaders(payload, headers.toMessageHeaders()).build();
M byteMessage = (M) MessageBuilder.withPayload(payload).copyHeaders(headers.toHeaders()).build();
if (logger.isTraceEnabled()) {
logger.trace("Forwarding message " + byteMessage);

View File

@ -31,12 +31,11 @@ import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.PubSubChannelRegistry;
import org.springframework.web.messaging.PubSubHeaders;
import org.springframework.web.messaging.converter.CompositeMessageConverter;
import org.springframework.web.messaging.converter.MessageConverter;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompConversionException;
import org.springframework.web.messaging.stomp.StompHeaders;
import org.springframework.web.messaging.support.PubSubHeaderAccesssor;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
@ -107,7 +106,7 @@ public class StompWebSocketHandler<M extends Message> extends TextWebSocketHandl
}
try {
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
MessageType messageType = stompHeaders.getMessageType();
if (MessageType.CONNECT.equals(messageType)) {
handleConnect(session, message);
@ -142,8 +141,8 @@ public class StompWebSocketHandler<M extends Message> extends TextWebSocketHandl
protected void handleConnect(final WebSocketSession session, M message) throws IOException {
StompHeaders connectHeaders = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaders connectedHeaders = StompHeaders.create(StompCommand.CONNECTED);
StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(message);
StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
Set<String> acceptVersions = connectHeaders.getAcceptVersion();
if (acceptVersions.contains("1.2")) {
@ -163,8 +162,8 @@ public class StompWebSocketHandler<M extends Message> extends TextWebSocketHandl
// TODO: security
@SuppressWarnings("unchecked")
M connectedMessage = (M) MessageBuilder.fromPayloadAndHeaders(EMPTY_PAYLOAD,
connectedHeaders.toMessageHeaders()).build();
M connectedMessage = (M) MessageBuilder.withPayload(EMPTY_PAYLOAD).copyHeaders(
connectedHeaders.toHeaders()).build();
byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
@ -186,11 +185,11 @@ public class StompWebSocketHandler<M extends Message> extends TextWebSocketHandl
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaders headers = StompHeaders.create(StompCommand.ERROR);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setMessage(error.getMessage());
@SuppressWarnings("unchecked")
M message = (M) MessageBuilder.fromPayloadAndHeaders(EMPTY_PAYLOAD, headers.toMessageHeaders()).build();
M message = (M) MessageBuilder.withPayload(EMPTY_PAYLOAD).copyHeaders(headers.toHeaders()).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message);
try {
@ -204,10 +203,10 @@ public class StompWebSocketHandler<M extends Message> extends TextWebSocketHandl
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
this.sessions.remove(session.getId());
PubSubHeaders headers = PubSubHeaders.create(MessageType.DISCONNECT);
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.create(MessageType.DISCONNECT);
headers.setSessionId(session.getId());
@SuppressWarnings("unchecked")
M message = (M) MessageBuilder.fromPayloadAndHeaders(new byte[0], headers.toMessageHeaders()).build();
M message = (M) MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toHeaders()).build();
this.outputChannel.send(message);
}
@ -217,7 +216,7 @@ public class StompWebSocketHandler<M extends Message> extends TextWebSocketHandl
@Override
public void handleMessage(M message) {
StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders());
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setStompCommandIfNotSet(StompCommand.MESSAGE);
if (StompCommand.CONNECTED.equals(headers.getStompCommand())) {
@ -246,7 +245,7 @@ public class StompWebSocketHandler<M extends Message> extends TextWebSocketHandl
try {
@SuppressWarnings("unchecked")
M byteMessage = (M) MessageBuilder.fromPayloadAndHeaders(payload, headers.toMessageHeaders()).build();
M byteMessage = (M) MessageBuilder.withPayload(payload).copyHeaders(headers.toHeaders()).build();
byte[] bytes = getStompMessageConverter().fromMessage(byteMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package org.springframework.web.messaging;
package org.springframework.web.messaging.support;
import java.util.Arrays;
import java.util.Collections;
@ -30,6 +30,7 @@ import org.springframework.messaging.MessageHeaders;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.web.messaging.MessageType;
/**
@ -42,12 +43,12 @@ import org.springframework.util.LinkedMultiValueMap;
* and/or modify headers of an existing message.
* <p>
* Use one of the static factory method in this class, then call getters and setters, and
* at the end if necessary call {@link #toMessageHeaders()} to obtain the updated headers.
* at the end if necessary call {@link #toHeaders()} to obtain the updated headers.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class PubSubHeaders {
public class PubSubHeaderAccesssor {
protected Log logger = LogFactory.getLog(getClass());
@ -85,7 +86,7 @@ public class PubSubHeaders {
* A constructor for creating new message headers.
* This constructor is protected. See factory methods in this and sub-classes.
*/
protected PubSubHeaders(MessageType messageType, Object protocolMessageType,
protected PubSubHeaderAccesssor(MessageType messageType, Object protocolMessageType,
Map<String, List<String>> externalSourceHeaders) {
this.originalHeaders = null;
@ -111,33 +112,34 @@ public class PubSubHeaders {
* constructor is protected. See factory methods in this and sub-classes.
*/
@SuppressWarnings("unchecked")
protected PubSubHeaders(MessageHeaders originalHeaders) {
Assert.notNull(originalHeaders, "originalHeaders is required");
this.originalHeaders = originalHeaders;
this.externalSourceHeaders = (originalHeaders.get(EXTERNAL_SOURCE_HEADERS) != null) ?
(Map<String, List<String>>) originalHeaders.get(EXTERNAL_SOURCE_HEADERS) : emptyMultiValueMap;
protected PubSubHeaderAccesssor(Message<?> message) {
Assert.notNull(message, "message is required");
this.originalHeaders = message.getHeaders();
this.externalSourceHeaders = (this.originalHeaders.get(EXTERNAL_SOURCE_HEADERS) != null) ?
(Map<String, List<String>>) this.originalHeaders.get(EXTERNAL_SOURCE_HEADERS) : emptyMultiValueMap;
}
/**
* Create {@link PubSubHeaders} for a new {@link Message}.
* Create {@link PubSubHeaderAccesssor} for a new {@link Message} with
* {@link MessageType#MESSAGE}.
*/
public static PubSubHeaders create() {
return new PubSubHeaders(MessageType.MESSAGE, null, null);
public static PubSubHeaderAccesssor create() {
return new PubSubHeaderAccesssor(MessageType.MESSAGE, null, null);
}
/**
* Create {@link PubSubHeaders} for a new {@link Message} of a specific type.
* Create {@link PubSubHeaderAccesssor} for a new {@link Message} of a specific type.
*/
public static PubSubHeaders create(MessageType messageType) {
return new PubSubHeaders(messageType, null, null);
public static PubSubHeaderAccesssor create(MessageType messageType) {
return new PubSubHeaderAccesssor(messageType, null, null);
}
/**
* Create {@link PubSubHeaders} from existing message headers.
* Create {@link PubSubHeaderAccesssor} from the headers of an existing message.
*/
public static PubSubHeaders fromMessageHeaders(MessageHeaders originalHeaders) {
return new PubSubHeaders(originalHeaders);
public static PubSubHeaderAccesssor wrap(Message<?> message) {
return new PubSubHeaderAccesssor(message);
}
@ -145,7 +147,7 @@ public class PubSubHeaders {
* Return the original, wrapped headers (i.e. unmodified) or a new Map including any
* updates made via setters.
*/
public Map<String, Object> toMessageHeaders() {
public Map<String, Object> toHeaders() {
if (!isModified()) {
return this.originalHeaders;
}

View File

@ -19,7 +19,6 @@ package org.springframework.web.messaging.support;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.messaging.PubSubHeaders;
import reactor.util.Assert;
@ -50,10 +49,11 @@ public class SessionMessageChannel<M extends Message> implements MessageChannel<
@Override
public boolean send(M message, long timeout) {
PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders());
PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message);
headers.setSessionId(this.sessionId);
Object payload = message.getPayload();
@SuppressWarnings("unchecked")
M messageToSend = (M) MessageBuilder.fromPayloadAndHeaders(message.getPayload(), headers.toMessageHeaders()).build();
M messageToSend = (M) MessageBuilder.withPayload(payload).copyHeaders(headers.toHeaders()).build();
this.delegate.send(messageToSend);
return true;
}

View File

@ -23,7 +23,6 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.web.messaging.MessageType;
import org.springframework.web.messaging.stomp.StompCommand;
import org.springframework.web.messaging.stomp.StompHeaders;
import static org.junit.Assert.*;
@ -33,12 +32,12 @@ import static org.junit.Assert.*;
*/
public class StompMessageConverterTests {
private StompMessageConverter converter;
private StompMessageConverter<Message<byte[]>> converter;
@Before
public void setup() {
this.converter = new StompMessageConverter();
this.converter = new StompMessageConverter<Message<byte[]>>();
}
@Test
@ -51,9 +50,9 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
assertEquals(7, stompHeaders.toMessageHeaders().size());
MessageHeaders headers = message.getHeaders();
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
assertEquals(7, stompHeaders.toHeaders().size());
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getHost());
@ -61,8 +60,8 @@ public class StompMessageConverterTests {
assertEquals(MessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand());
assertEquals("session-123", stompHeaders.getSessionId());
assertNotNull(messageHeaders.get(MessageHeaders.ID));
assertNotNull(messageHeaders.get(MessageHeaders.TIMESTAMP));
assertNotNull(headers.get(MessageHeaders.ID));
assertNotNull(headers.get(MessageHeaders.TIMESTAMP));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
@ -81,8 +80,7 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getExternalSourceHeaders().get("ho:\ns\rt").get(0));
@ -103,8 +101,7 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getHost());
@ -125,8 +122,7 @@ public class StompMessageConverterTests {
assertEquals(0, message.getPayload().length);
MessageHeaders messageHeaders = message.getHeaders();
StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getExternalSourceHeaders().get("ho:\ns\rt").get(0));