From 5cfc59d76deb1ece83f536ab8bac794f37a18606 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 19 Jun 2013 11:30:01 -0400 Subject: [PATCH] Refactor PubSubHeaders, StompHeaders, MessageBuilder Rename to PubSubHeaderAccessor and StompHeaderAccessor Move the renamed classes to support packages Remove fromPayloadAndHeaders from MessageBuilder, just use withPayload(..).copyHeaders(..) instead. --- .../messaging/support/MessageBuilder.java | 12 ------ .../service/AbstractPubSubMessageHandler.java | 6 +-- .../service/ReactorPubSubMessageHandler.java | 17 ++++---- .../AnnotationPubSubMessageHandler.java | 4 +- .../method/MessageBodyArgumentResolver.java | 4 +- .../MessageChannelArgumentResolver.java | 4 +- .../method/MessageReturnValueHandler.java | 13 +++--- .../StompHeaderAccessor.java} | 42 +++++++++---------- .../stomp/support/StompMessageConverter.java | 11 ++--- .../StompRelayPubSubMessageHandler.java | 34 +++++++-------- .../stomp/support/StompWebSocketHandler.java | 25 ++++++----- .../PubSubHeaderAccesssor.java} | 40 +++++++++--------- .../support/SessionMessageChannel.java | 6 +-- .../support/StompMessageConverterTests.java | 24 +++++------ 14 files changed, 111 insertions(+), 131 deletions(-) rename spring-websocket/src/main/java/org/springframework/web/messaging/stomp/{StompHeaders.java => support/StompHeaderAccessor.java} (86%) rename spring-websocket/src/main/java/org/springframework/web/messaging/{PubSubHeaders.java => support/PubSubHeaderAccesssor.java} (83%) diff --git a/spring-context/src/main/java/org/springframework/messaging/support/MessageBuilder.java b/spring-context/src/main/java/org/springframework/messaging/support/MessageBuilder.java index 22cf40cf283..96d3859f6e8 100644 --- a/spring-context/src/main/java/org/springframework/messaging/support/MessageBuilder.java +++ b/spring-context/src/main/java/org/springframework/messaging/support/MessageBuilder.java @@ -93,18 +93,6 @@ public final class MessageBuilder { return builder; } - /** - * Create a builder for a new {@link Message} instance with the provided payload and - * headers. - * - * @param payload the payload for the new message - * @param headers the headers to use - */ - public static MessageBuilder fromPayloadAndHeaders(T payload, Map headers) { - MessageBuilder builder = new MessageBuilder(payload, headers); - return builder; - } - /** * Create a builder for a new {@link Message} instance with the provided payload. * diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java index ca4c4939b09..85e761f0ac3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/AbstractPubSubMessageHandler.java @@ -30,7 +30,7 @@ import org.springframework.util.AntPathMatcher; import org.springframework.util.CollectionUtils; import org.springframework.util.PathMatcher; import org.springframework.web.messaging.MessageType; -import org.springframework.web.messaging.PubSubHeaders; +import org.springframework.web.messaging.support.PubSubHeaderAccesssor; /** @@ -81,7 +81,7 @@ public abstract class AbstractPubSubMessageHandler implements protected boolean isDestinationAllowed(M message) { - PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message); String destination = headers.getDestination(); if (destination == null) { @@ -117,7 +117,7 @@ public abstract class AbstractPubSubMessageHandler implements @Override public final void handleMessage(M message) throws MessagingException { - PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message); MessageType messageType = headers.getMessageType(); if (!canHandle(message, messageType)) { diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java index e80f12bc330..89e82e697cd 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/ReactorPubSubMessageHandler.java @@ -29,9 +29,9 @@ import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.PubSubChannelRegistry; -import org.springframework.web.messaging.PubSubHeaders; import org.springframework.web.messaging.converter.CompositeMessageConverter; import org.springframework.web.messaging.converter.MessageConverter; +import org.springframework.web.messaging.support.PubSubHeaderAccesssor; import reactor.core.Reactor; import reactor.fn.Consumer; @@ -79,7 +79,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubS logger.debug("Subscribe " + message); } - PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message); String subscriptionId = headers.getSubscriptionId(); BroadcastingConsumer consumer = new BroadcastingConsumer(subscriptionId); @@ -108,10 +108,10 @@ public class ReactorPubSubMessageHandler extends AbstractPubS try { // Convert to byte[] payload before the fan-out - PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message); byte[] payload = payloadConverter.convertToPayload(message.getPayload(), headers.getContentType()); @SuppressWarnings("unchecked") - M m = (M) MessageBuilder.fromPayloadAndHeaders(payload, message.getHeaders()).build(); + M m = (M) MessageBuilder.withPayload(payload).copyHeaders(message.getHeaders()).build(); this.reactor.notify(getPublishKey(headers.getDestination()), Event.wrap(m)); } @@ -122,7 +122,7 @@ public class ReactorPubSubMessageHandler extends AbstractPubS @Override public void handleDisconnect(M message) { - PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message); removeSubscriptions(headers.getSessionId()); } @@ -151,12 +151,11 @@ public class ReactorPubSubMessageHandler extends AbstractPubS Message sentMessage = event.getData(); - PubSubHeaders clientHeaders = PubSubHeaders.fromMessageHeaders(sentMessage.getHeaders()); - clientHeaders.setSubscriptionId(this.subscriptionId); + PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(sentMessage); + headers.setSubscriptionId(this.subscriptionId); @SuppressWarnings("unchecked") - M clientMessage = (M) MessageBuilder.fromPayloadAndHeaders(sentMessage.getPayload(), - clientHeaders.toMessageHeaders()).build(); + M clientMessage = (M) MessageBuilder.withPayload(sentMessage.getPayload()).copyHeaders(headers.toHeaders()).build(); clientChannel.send(clientMessage); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java index a702dc578bd..e0b59661a00 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/AnnotationPubSubMessageHandler.java @@ -39,11 +39,11 @@ import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils.MethodFilter; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.PubSubChannelRegistry; -import org.springframework.web.messaging.PubSubHeaders; import org.springframework.web.messaging.annotation.SubscribeEvent; import org.springframework.web.messaging.annotation.UnsubscribeEvent; import org.springframework.web.messaging.converter.MessageConverter; import org.springframework.web.messaging.service.AbstractPubSubMessageHandler; +import org.springframework.web.messaging.support.PubSubHeaderAccesssor; import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.HandlerMethodSelector; @@ -182,7 +182,7 @@ public class AnnotationPubSubMessageHandler extends AbstractP private void handleMessageInternal(final M message, Map handlerMethods) { - PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message); String destination = headers.getDestination(); HandlerMethod match = getHandlerMethod(destination, handlerMethods); diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageBodyArgumentResolver.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageBodyArgumentResolver.java index 4b280e238be..1a87b503128 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageBodyArgumentResolver.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageBodyArgumentResolver.java @@ -21,11 +21,11 @@ import java.util.List; import org.springframework.core.MethodParameter; import org.springframework.http.MediaType; import org.springframework.messaging.Message; -import org.springframework.web.messaging.PubSubHeaders; import org.springframework.web.messaging.annotation.MessageBody; import org.springframework.web.messaging.converter.CompositeMessageConverter; import org.springframework.web.messaging.converter.MessageConversionException; import org.springframework.web.messaging.converter.MessageConverter; +import org.springframework.web.messaging.support.PubSubHeaderAccesssor; /** @@ -53,7 +53,7 @@ public class MessageBodyArgumentResolver implements ArgumentR Object arg = null; MessageBody annot = parameter.getParameterAnnotation(MessageBody.class); - MediaType contentType = (MediaType) message.getHeaders().get(PubSubHeaders.CONTENT_TYPE); + MediaType contentType = (MediaType) message.getHeaders().get(PubSubHeaderAccesssor.CONTENT_TYPE); if (annot == null || annot.required()) { Class sourceType = message.getPayload().getClass(); diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java index 726236694ce..4429421320d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageChannelArgumentResolver.java @@ -20,7 +20,7 @@ import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.util.Assert; -import org.springframework.web.messaging.PubSubHeaders; +import org.springframework.web.messaging.support.PubSubHeaderAccesssor; import org.springframework.web.messaging.support.SessionMessageChannel; @@ -47,7 +47,7 @@ public class MessageChannelArgumentResolver implements Argume @Override public Object resolveArgument(MethodParameter parameter, M message) throws Exception { Assert.notNull(this.messageBrokerChannel, "messageBrokerChannel is required"); - final String sessionId = PubSubHeaders.fromMessageHeaders(message.getHeaders()).getSessionId(); + final String sessionId = PubSubHeaderAccesssor.wrap(message).getSessionId(); return new SessionMessageChannel(this.messageBrokerChannel, sessionId); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java index f9aab71517b..0d525864789 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/service/method/MessageReturnValueHandler.java @@ -21,7 +21,7 @@ import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; -import org.springframework.web.messaging.PubSubHeaders; +import org.springframework.web.messaging.support.PubSubHeaderAccesssor; /** @@ -75,13 +75,13 @@ public class MessageReturnValueHandler implements ReturnValue protected M updateReturnMessage(M returnMessage, M message) { - PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message); String sessionId = headers.getSessionId(); String subscriptionId = headers.getSubscriptionId(); Assert.notNull(subscriptionId, "No subscription id: " + message); - PubSubHeaders returnHeaders = PubSubHeaders.fromMessageHeaders(returnMessage.getHeaders()); + PubSubHeaderAccesssor returnHeaders = PubSubHeaderAccesssor.wrap(returnMessage); returnHeaders.setSessionId(sessionId); returnHeaders.setSubscriptionId(subscriptionId); @@ -89,13 +89,12 @@ public class MessageReturnValueHandler implements ReturnValue returnHeaders.setDestination(headers.getDestination()); } - Object payload = returnMessage.getPayload(); - return createMessage(returnHeaders, payload); + return createMessage(returnHeaders, returnMessage.getPayload()); } @SuppressWarnings("unchecked") - private M createMessage(PubSubHeaders returnHeaders, Object payload) { - return (M) MessageBuilder.fromPayloadAndHeaders(payload, returnHeaders.toMessageHeaders()).build(); + private M createMessage(PubSubHeaderAccesssor returnHeaders, Object payload) { + return (M) MessageBuilder.withPayload(payload).copyHeaders(returnHeaders.toHeaders()).build(); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompHeaderAccessor.java similarity index 86% rename from spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java rename to spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompHeaderAccessor.java index 3adcc1318bb..f47d98fa174 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/StompHeaders.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompHeaderAccessor.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.web.messaging.stomp; +package org.springframework.web.messaging.stomp.support; import java.util.Collections; import java.util.HashMap; @@ -25,13 +25,13 @@ import java.util.concurrent.atomic.AtomicLong; import org.springframework.http.MediaType; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageHeaders; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; -import org.springframework.web.messaging.PubSubHeaders; +import org.springframework.web.messaging.stomp.StompCommand; +import org.springframework.web.messaging.support.PubSubHeaderAccesssor; /** @@ -39,13 +39,13 @@ import org.springframework.web.messaging.PubSubHeaders; * STOMP-specific headers of an existing message. *

* Use one of the static factory method in this class, then call getters and setters, and - * at the end if necessary call {@link #toMessageHeaders()} to obtain the updated headers + * at the end if necessary call {@link #toHeaders()} to obtain the updated headers * or call {@link #toStompMessageHeaders()} to obtain only the STOMP-specific headers. * * @author Rossen Stoyanchev * @since 4.0 */ -public class StompHeaders extends PubSubHeaders { +public class StompHeaderAccessor extends PubSubHeaderAccesssor { public static final String STOMP_ID = "id"; @@ -88,7 +88,7 @@ public class StompHeaders extends PubSubHeaders { * A constructor for creating new STOMP message headers. * This constructor is private. See factory methods in this sub-classes. */ - private StompHeaders(StompCommand command, Map> externalSourceHeaders) { + private StompHeaderAccessor(StompCommand command, Map> externalSourceHeaders) { super(command.getMessageType(), command, externalSourceHeaders); this.headers = new HashMap(4); updateMessageHeaders(); @@ -118,32 +118,32 @@ public class StompHeaders extends PubSubHeaders { * constructor is protected. See factory methods in this class. */ @SuppressWarnings("unchecked") - private StompHeaders(MessageHeaders messageHeaders) { - super(messageHeaders); - this.headers = (messageHeaders.get(STOMP_HEADERS) != null) ? - (Map) messageHeaders.get(STOMP_HEADERS) : new HashMap(4); + private StompHeaderAccessor(Message message) { + super(message); + this.headers = (message.getHeaders() .get(STOMP_HEADERS) != null) ? + (Map) message.getHeaders().get(STOMP_HEADERS) : new HashMap(4); } /** - * Create {@link StompHeaders} for a new {@link Message}. + * Create {@link StompHeaderAccessor} for a new {@link Message}. */ - public static StompHeaders create(StompCommand command) { - return new StompHeaders(command, null); + public static StompHeaderAccessor create(StompCommand command) { + return new StompHeaderAccessor(command, null); } /** - * Create {@link StompHeaders} from the headers of an existing {@link Message}. + * Create {@link StompHeaderAccessor} from parsed STOP frame content. */ - public static StompHeaders fromMessageHeaders(MessageHeaders messageHeaders) { - return new StompHeaders(messageHeaders); + public static StompHeaderAccessor create(StompCommand command, Map> headers) { + return new StompHeaderAccessor(command, headers); } /** - * Create {@link StompHeaders} from parsed STOP frame content. + * Create {@link StompHeaderAccessor} from the headers of an existing {@link Message}. */ - public static StompHeaders fromParsedFrame(StompCommand command, Map> headers) { - return new StompHeaders(command, headers); + public static StompHeaderAccessor wrap(Message message) { + return new StompHeaderAccessor(message); } @@ -152,8 +152,8 @@ public class StompHeaders extends PubSubHeaders { * updates made via setters. */ @Override - public Map toMessageHeaders() { - Map result = super.toMessageHeaders(); + public Map toHeaders() { + Map result = super.toHeaders(); if (isModified()) { result.put(STOMP_HEADERS, this.headers); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java index b3cbbaf9843..f212d6735fa 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompMessageConverter.java @@ -22,14 +22,12 @@ import java.util.List; import java.util.Map.Entry; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.messaging.stomp.StompCommand; import org.springframework.web.messaging.stomp.StompConversionException; -import org.springframework.web.messaging.stomp.StompHeaders; /** @@ -96,7 +94,7 @@ public class StompMessageConverter { } } - StompHeaders stompHeaders = StompHeaders.fromParsedFrame(command, headers); + StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers); stompHeaders.setSessionId(sessionId); byte[] payload = new byte[totalLength - payloadIndex]; @@ -106,8 +104,8 @@ public class StompMessageConverter { } @SuppressWarnings("unchecked") - private M createMessage(StompHeaders stompHeaders, byte[] payload) { - return (M) MessageBuilder.fromPayloadAndHeaders(payload, stompHeaders.toMessageHeaders()).build(); + private M createMessage(StompHeaderAccessor stompHeaders, byte[] payload) { + return (M) MessageBuilder.withPayload(payload).copyHeaders(stompHeaders.toHeaders()).build(); } private int findIndexOfPayload(byte[] bytes) { @@ -149,8 +147,7 @@ public class StompMessageConverter { } ByteArrayOutputStream out = new ByteArrayOutputStream(); - MessageHeaders messageHeaders = message.getHeaders(); - StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders); + StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); try { out.write(stompHeaders.getStompCommand().toString().getBytes("UTF-8")); diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java index 85d272f4905..1a9fcadbece 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompRelayPubSubMessageHandler.java @@ -33,12 +33,11 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.PubSubChannelRegistry; -import org.springframework.web.messaging.PubSubHeaders; import org.springframework.web.messaging.converter.CompositeMessageConverter; import org.springframework.web.messaging.converter.MessageConverter; import org.springframework.web.messaging.service.AbstractPubSubMessageHandler; import org.springframework.web.messaging.stomp.StompCommand; -import org.springframework.web.messaging.stomp.StompHeaders; +import org.springframework.web.messaging.support.PubSubHeaderAccesssor; import reactor.core.Environment; import reactor.core.Promise; @@ -97,7 +96,7 @@ public class StompRelayPubSubMessageHandler extends AbstractP @Override public void handleConnect(M message) { - StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); + StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); String sessionId = stompHeaders.getSessionId(); if (sessionId == null) { logger.error("No sessionId in message " + message); @@ -124,7 +123,7 @@ public class StompRelayPubSubMessageHandler extends AbstractP @Override public void handleDisconnect(M message) { - StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); + StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); if (stompHeaders.getStompCommand() != null) { forwardMessage(message, StompCommand.DISCONNECT); } @@ -137,14 +136,14 @@ public class StompRelayPubSubMessageHandler extends AbstractP @Override public void handleOther(M message) { - StompCommand command = (StompCommand) message.getHeaders().get(PubSubHeaders.PROTOCOL_MESSAGE_TYPE); + StompCommand command = (StompCommand) message.getHeaders().get(PubSubHeaderAccesssor.PROTOCOL_MESSAGE_TYPE); Assert.notNull(command, "Expected STOMP command: " + message.getHeaders()); forwardMessage(message, command); } private void forwardMessage(M message, StompCommand command) { - StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders()); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); headers.setStompCommandIfNotSet(command); String sessionId = headers.getSessionId(); @@ -174,9 +173,10 @@ public class StompRelayPubSubMessageHandler extends AbstractP private final Object monitor = new Object(); - private boolean isConnected = false; + private volatile boolean isConnected = false; - public RelaySession(final M message, final StompHeaders stompHeaders) { + + public RelaySession(final M message, final StompHeaderAccessor stompHeaders) { Assert.notNull(message, "message is required"); Assert.notNull(stompHeaders, "stompHeaders is required"); @@ -222,7 +222,7 @@ public class StompRelayPubSubMessageHandler extends AbstractP logger.trace("Reading message " + message); } - StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders()); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); if (StompCommand.CONNECTED == headers.getStompCommand()) { synchronized(this.monitor) { this.isConnected = true; @@ -240,15 +240,15 @@ public class StompRelayPubSubMessageHandler extends AbstractP } private void sendError(String sessionId, String errorText) { - StompHeaders stompHeaders = StompHeaders.create(StompCommand.ERROR); - stompHeaders.setSessionId(sessionId); - stompHeaders.setMessage(errorText); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); + headers.setSessionId(sessionId); + headers.setMessage(errorText); @SuppressWarnings("unchecked") - M errorMessage = (M) MessageBuilder.fromPayloadAndHeaders(new byte[0], stompHeaders.toMessageHeaders()).build(); + M errorMessage = (M) MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toHeaders()).build(); clientChannel.send(errorMessage); } - public void forward(M message, StompHeaders headers) { + public void forward(M message, StompHeaderAccessor headers) { if (!this.isConnected) { synchronized(this.monitor) { @@ -277,21 +277,21 @@ public class StompRelayPubSubMessageHandler extends AbstractP List messages = new ArrayList(); this.messageQueue.drainTo(messages); for (Message message : messages) { - StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders()); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); if (!forwardInternal(message, headers, connection)) { return; } } } - private boolean forwardInternal(Message message, StompHeaders headers, TcpConnection connection) { + private boolean forwardInternal(Message message, StompHeaderAccessor headers, TcpConnection connection) { try { headers.setStompCommandIfNotSet(StompCommand.SEND); MediaType contentType = headers.getContentType(); byte[] payload = payloadConverter.convertToPayload(message.getPayload(), contentType); @SuppressWarnings("unchecked") - M byteMessage = (M) MessageBuilder.fromPayloadAndHeaders(payload, headers.toMessageHeaders()).build(); + M byteMessage = (M) MessageBuilder.withPayload(payload).copyHeaders(headers.toHeaders()).build(); if (logger.isTraceEnabled()) { logger.trace("Forwarding message " + byteMessage); diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java index ecd3089676c..8600ec8a274 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/stomp/support/StompWebSocketHandler.java @@ -31,12 +31,11 @@ import org.springframework.messaging.MessageHandler; import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.PubSubChannelRegistry; -import org.springframework.web.messaging.PubSubHeaders; import org.springframework.web.messaging.converter.CompositeMessageConverter; import org.springframework.web.messaging.converter.MessageConverter; import org.springframework.web.messaging.stomp.StompCommand; import org.springframework.web.messaging.stomp.StompConversionException; -import org.springframework.web.messaging.stomp.StompHeaders; +import org.springframework.web.messaging.support.PubSubHeaderAccesssor; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; @@ -107,7 +106,7 @@ public class StompWebSocketHandler extends TextWebSocketHandl } try { - StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); + StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); MessageType messageType = stompHeaders.getMessageType(); if (MessageType.CONNECT.equals(messageType)) { handleConnect(session, message); @@ -142,8 +141,8 @@ public class StompWebSocketHandler extends TextWebSocketHandl protected void handleConnect(final WebSocketSession session, M message) throws IOException { - StompHeaders connectHeaders = StompHeaders.fromMessageHeaders(message.getHeaders()); - StompHeaders connectedHeaders = StompHeaders.create(StompCommand.CONNECTED); + StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(message); + StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED); Set acceptVersions = connectHeaders.getAcceptVersion(); if (acceptVersions.contains("1.2")) { @@ -163,8 +162,8 @@ public class StompWebSocketHandler extends TextWebSocketHandl // TODO: security @SuppressWarnings("unchecked") - M connectedMessage = (M) MessageBuilder.fromPayloadAndHeaders(EMPTY_PAYLOAD, - connectedHeaders.toMessageHeaders()).build(); + M connectedMessage = (M) MessageBuilder.withPayload(EMPTY_PAYLOAD).copyHeaders( + connectedHeaders.toHeaders()).build(); byte[] bytes = getStompMessageConverter().fromMessage(connectedMessage); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } @@ -186,11 +185,11 @@ public class StompWebSocketHandler extends TextWebSocketHandl protected void sendErrorMessage(WebSocketSession session, Throwable error) { - StompHeaders headers = StompHeaders.create(StompCommand.ERROR); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR); headers.setMessage(error.getMessage()); @SuppressWarnings("unchecked") - M message = (M) MessageBuilder.fromPayloadAndHeaders(EMPTY_PAYLOAD, headers.toMessageHeaders()).build(); + M message = (M) MessageBuilder.withPayload(EMPTY_PAYLOAD).copyHeaders(headers.toHeaders()).build(); byte[] bytes = this.stompMessageConverter.fromMessage(message); try { @@ -204,10 +203,10 @@ public class StompWebSocketHandler extends TextWebSocketHandl @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { this.sessions.remove(session.getId()); - PubSubHeaders headers = PubSubHeaders.create(MessageType.DISCONNECT); + PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.create(MessageType.DISCONNECT); headers.setSessionId(session.getId()); @SuppressWarnings("unchecked") - M message = (M) MessageBuilder.fromPayloadAndHeaders(new byte[0], headers.toMessageHeaders()).build(); + M message = (M) MessageBuilder.withPayload(new byte[0]).copyHeaders(headers.toHeaders()).build(); this.outputChannel.send(message); } @@ -217,7 +216,7 @@ public class StompWebSocketHandler extends TextWebSocketHandl @Override public void handleMessage(M message) { - StompHeaders headers = StompHeaders.fromMessageHeaders(message.getHeaders()); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); headers.setStompCommandIfNotSet(StompCommand.MESSAGE); if (StompCommand.CONNECTED.equals(headers.getStompCommand())) { @@ -246,7 +245,7 @@ public class StompWebSocketHandler extends TextWebSocketHandl try { @SuppressWarnings("unchecked") - M byteMessage = (M) MessageBuilder.fromPayloadAndHeaders(payload, headers.toMessageHeaders()).build(); + M byteMessage = (M) MessageBuilder.withPayload(payload).copyHeaders(headers.toHeaders()).build(); byte[] bytes = getStompMessageConverter().fromMessage(byteMessage); session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8")))); } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubHeaderAccesssor.java similarity index 83% rename from spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java rename to spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubHeaderAccesssor.java index dfbfda9b326..a05329df2e1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/PubSubHeaders.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/PubSubHeaderAccesssor.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.web.messaging; +package org.springframework.web.messaging.support; import java.util.Arrays; import java.util.Collections; @@ -30,6 +30,7 @@ import org.springframework.messaging.MessageHeaders; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; +import org.springframework.web.messaging.MessageType; /** @@ -42,12 +43,12 @@ import org.springframework.util.LinkedMultiValueMap; * and/or modify headers of an existing message. *

* Use one of the static factory method in this class, then call getters and setters, and - * at the end if necessary call {@link #toMessageHeaders()} to obtain the updated headers. + * at the end if necessary call {@link #toHeaders()} to obtain the updated headers. * * @author Rossen Stoyanchev * @since 4.0 */ -public class PubSubHeaders { +public class PubSubHeaderAccesssor { protected Log logger = LogFactory.getLog(getClass()); @@ -85,7 +86,7 @@ public class PubSubHeaders { * A constructor for creating new message headers. * This constructor is protected. See factory methods in this and sub-classes. */ - protected PubSubHeaders(MessageType messageType, Object protocolMessageType, + protected PubSubHeaderAccesssor(MessageType messageType, Object protocolMessageType, Map> externalSourceHeaders) { this.originalHeaders = null; @@ -111,33 +112,34 @@ public class PubSubHeaders { * constructor is protected. See factory methods in this and sub-classes. */ @SuppressWarnings("unchecked") - protected PubSubHeaders(MessageHeaders originalHeaders) { - Assert.notNull(originalHeaders, "originalHeaders is required"); - this.originalHeaders = originalHeaders; - this.externalSourceHeaders = (originalHeaders.get(EXTERNAL_SOURCE_HEADERS) != null) ? - (Map>) originalHeaders.get(EXTERNAL_SOURCE_HEADERS) : emptyMultiValueMap; + protected PubSubHeaderAccesssor(Message message) { + Assert.notNull(message, "message is required"); + this.originalHeaders = message.getHeaders(); + this.externalSourceHeaders = (this.originalHeaders.get(EXTERNAL_SOURCE_HEADERS) != null) ? + (Map>) this.originalHeaders.get(EXTERNAL_SOURCE_HEADERS) : emptyMultiValueMap; } /** - * Create {@link PubSubHeaders} for a new {@link Message}. + * Create {@link PubSubHeaderAccesssor} for a new {@link Message} with + * {@link MessageType#MESSAGE}. */ - public static PubSubHeaders create() { - return new PubSubHeaders(MessageType.MESSAGE, null, null); + public static PubSubHeaderAccesssor create() { + return new PubSubHeaderAccesssor(MessageType.MESSAGE, null, null); } /** - * Create {@link PubSubHeaders} for a new {@link Message} of a specific type. + * Create {@link PubSubHeaderAccesssor} for a new {@link Message} of a specific type. */ - public static PubSubHeaders create(MessageType messageType) { - return new PubSubHeaders(messageType, null, null); + public static PubSubHeaderAccesssor create(MessageType messageType) { + return new PubSubHeaderAccesssor(messageType, null, null); } /** - * Create {@link PubSubHeaders} from existing message headers. + * Create {@link PubSubHeaderAccesssor} from the headers of an existing message. */ - public static PubSubHeaders fromMessageHeaders(MessageHeaders originalHeaders) { - return new PubSubHeaders(originalHeaders); + public static PubSubHeaderAccesssor wrap(Message message) { + return new PubSubHeaderAccesssor(message); } @@ -145,7 +147,7 @@ public class PubSubHeaders { * Return the original, wrapped headers (i.e. unmodified) or a new Map including any * updates made via setters. */ - public Map toMessageHeaders() { + public Map toHeaders() { if (!isModified()) { return this.originalHeaders; } diff --git a/spring-websocket/src/main/java/org/springframework/web/messaging/support/SessionMessageChannel.java b/spring-websocket/src/main/java/org/springframework/web/messaging/support/SessionMessageChannel.java index c466c031b88..6ae8fc59f70 100644 --- a/spring-websocket/src/main/java/org/springframework/web/messaging/support/SessionMessageChannel.java +++ b/spring-websocket/src/main/java/org/springframework/web/messaging/support/SessionMessageChannel.java @@ -19,7 +19,6 @@ package org.springframework.web.messaging.support; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.support.MessageBuilder; -import org.springframework.web.messaging.PubSubHeaders; import reactor.util.Assert; @@ -50,10 +49,11 @@ public class SessionMessageChannel implements MessageChannel< @Override public boolean send(M message, long timeout) { - PubSubHeaders headers = PubSubHeaders.fromMessageHeaders(message.getHeaders()); + PubSubHeaderAccesssor headers = PubSubHeaderAccesssor.wrap(message); headers.setSessionId(this.sessionId); + Object payload = message.getPayload(); @SuppressWarnings("unchecked") - M messageToSend = (M) MessageBuilder.fromPayloadAndHeaders(message.getPayload(), headers.toMessageHeaders()).build(); + M messageToSend = (M) MessageBuilder.withPayload(payload).copyHeaders(headers.toHeaders()).build(); this.delegate.send(messageToSend); return true; } diff --git a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java index e2eaed097f5..c2f88e9db49 100644 --- a/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/messaging/stomp/support/StompMessageConverterTests.java @@ -23,7 +23,6 @@ import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; import org.springframework.web.messaging.MessageType; import org.springframework.web.messaging.stomp.StompCommand; -import org.springframework.web.messaging.stomp.StompHeaders; import static org.junit.Assert.*; @@ -33,12 +32,12 @@ import static org.junit.Assert.*; */ public class StompMessageConverterTests { - private StompMessageConverter converter; + private StompMessageConverter> converter; @Before public void setup() { - this.converter = new StompMessageConverter(); + this.converter = new StompMessageConverter>(); } @Test @@ -51,9 +50,9 @@ public class StompMessageConverterTests { assertEquals(0, message.getPayload().length); - MessageHeaders messageHeaders = message.getHeaders(); - StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders); - assertEquals(7, stompHeaders.toMessageHeaders().size()); + MessageHeaders headers = message.getHeaders(); + StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); + assertEquals(7, stompHeaders.toHeaders().size()); assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); assertEquals("github.org", stompHeaders.getHost()); @@ -61,8 +60,8 @@ public class StompMessageConverterTests { assertEquals(MessageType.CONNECT, stompHeaders.getMessageType()); assertEquals(StompCommand.CONNECT, stompHeaders.getStompCommand()); assertEquals("session-123", stompHeaders.getSessionId()); - assertNotNull(messageHeaders.get(MessageHeaders.ID)); - assertNotNull(messageHeaders.get(MessageHeaders.TIMESTAMP)); + assertNotNull(headers.get(MessageHeaders.ID)); + assertNotNull(headers.get(MessageHeaders.TIMESTAMP)); String convertedBack = new String(this.converter.fromMessage(message), "UTF-8"); @@ -81,8 +80,7 @@ public class StompMessageConverterTests { assertEquals(0, message.getPayload().length); - MessageHeaders messageHeaders = message.getHeaders(); - StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders); + StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getExternalSourceHeaders().get("ho:\ns\rt").get(0)); @@ -103,8 +101,7 @@ public class StompMessageConverterTests { assertEquals(0, message.getPayload().length); - MessageHeaders messageHeaders = message.getHeaders(); - StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders); + StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion()); assertEquals("github.org", stompHeaders.getHost()); @@ -125,8 +122,7 @@ public class StompMessageConverterTests { assertEquals(0, message.getPayload().length); - MessageHeaders messageHeaders = message.getHeaders(); - StompHeaders stompHeaders = StompHeaders.fromMessageHeaders(messageHeaders); + StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message); assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion()); assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.getExternalSourceHeaders().get("ho:\ns\rt").get(0));