diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java index 8709ec114c5..7ec3c279e0f 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java @@ -68,7 +68,6 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { */ protected SimpMessageHeaderAccessor(Message message) { super(message); - Assert.notNull(message, "message is required"); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java index b5d147ef3ad..54ed13c3e8b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/NativeMessageHeaderAccessor.java @@ -107,7 +107,10 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { return result; } - protected List getNativeHeader(String headerName) { + /** + * Return all values for the specified native header or {@code null}. + */ + public List getNativeHeader(String headerName) { if (this.nativeHeaders.containsKey(headerName)) { return this.nativeHeaders.get(headerName); } @@ -117,23 +120,28 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor { return null; } + /** + * Return the first value for the specified native header of {@code null}. + */ public String getFirstNativeHeader(String headerName) { List values = getNativeHeader(headerName); return CollectionUtils.isEmpty(values) ? null : values.get(0); } /** - * Set the value for the given header name. If the provided value is {@code null} the - * header will be removed. + * Set the specified native header value. */ - protected void putNativeHeader(String name, List value) { + public void setNativeHeader(String name, String value) { if (!ObjectUtils.nullSafeEquals(value, getHeader(name))) { - this.nativeHeaders.put(name, value); + this.nativeHeaders.set(name, value); } } - public void setNativeHeader(String name, String value) { - this.nativeHeaders.set(name, value); + /** + * Add the specified native header value. + */ + public void addNativeHeader(String name, String value) { + this.nativeHeaders.add(name, value); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractMessageChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractMessageChannel.java new file mode 100644 index 00000000000..b446981363b --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractMessageChannel.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support.channel; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.beans.factory.BeanNameAware; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageDeliveryException; +import org.springframework.messaging.MessagingException; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; + + +/** + * Abstract base class for {@link MessageChannel} implementations. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public abstract class AbstractMessageChannel implements MessageChannel, BeanNameAware { + + protected Log logger = LogFactory.getLog(getClass()); + + private String beanName; + + private final ChannelInterceptorChain interceptorChain = new ChannelInterceptorChain(); + + + public AbstractMessageChannel() { + this.beanName = getClass().getSimpleName() + "@" + ObjectUtils.getIdentityHexString(this); + } + + /** + * {@inheritDoc} + *

Used primarily for logging purposes. + */ + @Override + public void setBeanName(String name) { + this.beanName = name; + } + + /** + * @return the name for this channel. + */ + public String getBeanName() { + return this.beanName; + } + + /** + * Set the list of channel interceptors. This will clear any existing interceptors. + */ + public void setInterceptors(List interceptors) { + this.interceptorChain.set(interceptors); + } + + /** + * Add a channel interceptor to the end of the list. + */ + public void addInterceptor(ChannelInterceptor interceptor) { + this.interceptorChain.add(interceptor); + } + + /** + * Return a read-only list of the configured interceptors. + */ + public List getInterceptors() { + return this.interceptorChain.getInterceptors(); + } + + /** + * Exposes the interceptor list for subclasses. + */ + protected ChannelInterceptorChain getInterceptorChain() { + return this.interceptorChain; + } + + + @Override + public final boolean send(Message message) { + return send(message, INDEFINITE_TIMEOUT); + } + + @Override + public final boolean send(Message message, long timeout) { + + Assert.notNull(message, "Message must not be null"); + if (logger.isTraceEnabled()) { + logger.trace("[" + this.beanName + "] sending message " + message); + } + + message = this.interceptorChain.preSend(message, this); + if (message == null) { + return false; + } + + try { + boolean sent = sendInternal(message, timeout); + this.interceptorChain.postSend(message, this, sent); + return sent; + } + catch (Exception e) { + if (e instanceof MessagingException) { + throw (MessagingException) e; + } + throw new MessageDeliveryException(message, + "Failed to send message to channel '" + this.getBeanName() + "'", e); + } + } + + protected abstract boolean sendInternal(Message message, long timeout); + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java index b7a3800d856..9740e149713 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/AbstractSubscribableChannel.java @@ -16,14 +16,8 @@ package org.springframework.messaging.support.channel; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.beans.factory.BeanNameAware; -import org.springframework.messaging.Message; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.SubscribableChannel; -import org.springframework.util.Assert; -import org.springframework.util.ObjectUtils; /** @@ -32,57 +26,17 @@ import org.springframework.util.ObjectUtils; * @author Rossen Stoyanchev * @since 4.0 */ -public abstract class AbstractSubscribableChannel implements SubscribableChannel, BeanNameAware { +public abstract class AbstractSubscribableChannel extends AbstractMessageChannel implements SubscribableChannel { - protected Log logger = LogFactory.getLog(getClass()); - - private String beanName; - - - public AbstractSubscribableChannel() { - this.beanName = getClass().getSimpleName() + "@" + ObjectUtils.getIdentityHexString(this); - } - - /** - * {@inheritDoc} - *

Used primarily for logging purposes. - */ - @Override - public void setBeanName(String name) { - this.beanName = name; - } - - /** - * @return the name for this channel. - */ - public String getBeanName() { - return this.beanName; - } - - @Override - public final boolean send(Message message) { - return send(message, INDEFINITE_TIMEOUT); - } - - @Override - public final boolean send(Message message, long timeout) { - Assert.notNull(message, "Message must not be null"); - if (logger.isTraceEnabled()) { - logger.trace("[" + this.beanName + "] sending message " + message); - } - return sendInternal(message, timeout); - } - - protected abstract boolean sendInternal(Message message, long timeout); @Override public final boolean subscribe(MessageHandler handler) { if (hasSubscription(handler)) { - logger.warn("[" + this.beanName + "] handler already subscribed " + handler); + logger.warn("[" + getBeanName() + "] handler already subscribed " + handler); return false; } if (logger.isDebugEnabled()) { - logger.debug("[" + this.beanName + "] subscribing " + handler); + logger.debug("[" + getBeanName() + "] subscribing " + handler); } return subscribeInternal(handler); } @@ -94,7 +48,7 @@ public abstract class AbstractSubscribableChannel implements SubscribableChannel @Override public final boolean unsubscribe(MessageHandler handler) { if (logger.isDebugEnabled()) { - logger.debug("[" + this.beanName + "] unsubscribing " + handler); + logger.debug("[" + getBeanName() + "] unsubscribing " + handler); } return unsubscribeInternal(handler); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ChannelInterceptor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ChannelInterceptor.java new file mode 100644 index 00000000000..8c7c26b7bc7 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ChannelInterceptor.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2010 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support.channel; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; + +/** + * Interface for interceptors that are able to view and/or modify the + * {@link Message Messages} being sent-to and/or received-from a + * {@link MessageChannel}. + * + * @author Mark Fisher + * @since 4.0 + */ +public interface ChannelInterceptor { + + /** + * Invoked before the Message is actually sent to the channel. + * This allows for modification of the Message if necessary. + * If this method returns null, then the actual + * send invocation will not occur. + */ + Message preSend(Message message, MessageChannel channel); + + /** + * Invoked immediately after the send invocation. The boolean + * value argument represents the return value of that invocation. + */ + void postSend(Message message, MessageChannel channel, boolean sent); + + /** + * Invoked as soon as receive is called and before a Message is + * actually retrieved. If the return value is 'false', then no + * Message will be retrieved. This only applies to PollableChannels. + */ + boolean preReceive(MessageChannel channel); + + /** + * Invoked immediately after a Message has been retrieved but before + * it is returned to the caller. The Message may be modified if + * necessary. This only applies to PollableChannels. + */ + Message postReceive(Message message, MessageChannel channel); + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ChannelInterceptorAdapter.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ChannelInterceptorAdapter.java new file mode 100644 index 00000000000..342e3ce2e8f --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ChannelInterceptorAdapter.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support.channel; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; + +/** + * A {@link ChannelInterceptor} with empty method implementations as a convenience. + * + * @author Mark Fisher + * @since 4.0 + */ +public class ChannelInterceptorAdapter implements ChannelInterceptor { + + + public Message preSend(Message message, MessageChannel channel) { + return message; + } + + public void postSend(Message message, MessageChannel channel, boolean sent) { + } + + public boolean preReceive(MessageChannel channel) { + return true; + } + + public Message postReceive(Message message, MessageChannel channel) { + return message; + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ChannelInterceptorChain.java b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ChannelInterceptorChain.java new file mode 100644 index 00000000000..fe62dfb15e4 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/channel/ChannelInterceptorChain.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support.channel; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; + + +/** + * A convenience wrapper class for invoking a list of {@link ChannelInterceptor}s. + * + * @author Mark Fisher + * @author Rossen Stoyanchev + * @since 4.0 + */ +class ChannelInterceptorChain { + + private static final Log logger = LogFactory.getLog(ChannelInterceptorChain.class); + + private final List interceptors = new CopyOnWriteArrayList(); + + + public boolean set(List interceptors) { + synchronized (this.interceptors) { + this.interceptors.clear(); + return this.interceptors.addAll(interceptors); + } + } + + public boolean add(ChannelInterceptor interceptor) { + return this.interceptors.add(interceptor); + } + + public List getInterceptors() { + return Collections.unmodifiableList(this.interceptors); + } + + + public Message preSend(Message message, MessageChannel channel) { + if (logger.isTraceEnabled()) { + logger.trace("preSend on channel '" + channel + "', message: " + message); + } + for (ChannelInterceptor interceptor : this.interceptors) { + message = interceptor.preSend(message, channel); + if (message == null) { + return null; + } + } + return message; + } + + public void postSend(Message message, MessageChannel channel, boolean sent) { + if (logger.isTraceEnabled()) { + logger.trace("postSend (sent=" + sent + ") on channel '" + channel + "', message: " + message); + } + for (ChannelInterceptor interceptor : this.interceptors) { + interceptor.postSend(message, channel, sent); + } + } + + public boolean preReceive(MessageChannel channel) { + if (logger.isTraceEnabled()) { + logger.trace("preReceive on channel '" + channel + "'"); + } + for (ChannelInterceptor interceptor : this.interceptors) { + if (!interceptor.preReceive(channel)) { + return false; + } + } + return true; + } + + public Message postReceive(Message message, MessageChannel channel) { + if (message != null && logger.isTraceEnabled()) { + logger.trace("postReceive on channel '" + channel + "', message: " + message); + } + else if (logger.isTraceEnabled()) { + logger.trace("postReceive on channel '" + channel + "', message is null"); + } + for (ChannelInterceptor interceptor : this.interceptors) { + message = interceptor.postReceive(message, channel); + if (message == null) { + return null; + } + } + return message; + } + + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompHeaderAccessorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompHeaderAccessorTests.java index 89f459f034c..fda1aa3dffd 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompHeaderAccessorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompHeaderAccessorTests.java @@ -16,8 +16,14 @@ package org.springframework.messaging.simp.stomp; +import java.util.List; +import java.util.Map; + import org.junit.Test; +import org.springframework.http.MediaType; +import org.springframework.messaging.simp.SimpMessageType; import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import static org.junit.Assert.*; @@ -32,7 +38,8 @@ public class StompHeaderAccessorTests { @Test - public void testStompCommandSet() { + public void createWithCommand() { + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED); assertEquals(StompCommand.CONNECTED, accessor.getCommand()); @@ -40,4 +47,111 @@ public class StompHeaderAccessorTests { assertEquals(StompCommand.CONNECTED, accessor.getCommand()); } + @Test + public void createWithSubscribeNativeHeaders() { + + MultiValueMap extHeaders = new LinkedMultiValueMap<>(); + extHeaders.add(StompHeaderAccessor.STOMP_ID_HEADER, "s1"); + extHeaders.add(StompHeaderAccessor.STOMP_DESTINATION_HEADER, "/d"); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE, extHeaders); + + assertEquals(StompCommand.SUBSCRIBE, headers.getCommand()); + assertEquals(SimpMessageType.SUBSCRIBE, headers.getMessageType()); + assertEquals("/d", headers.getDestination()); + assertEquals("s1", headers.getSubscriptionId()); + } + + @Test + public void createWithUnubscribeNativeHeaders() { + + MultiValueMap extHeaders = new LinkedMultiValueMap<>(); + extHeaders.add(StompHeaderAccessor.STOMP_ID_HEADER, "s1"); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.UNSUBSCRIBE, extHeaders); + + assertEquals(StompCommand.UNSUBSCRIBE, headers.getCommand()); + assertEquals(SimpMessageType.UNSUBSCRIBE, headers.getMessageType()); + assertEquals("s1", headers.getSubscriptionId()); + } + + @Test + public void createWithMessageFrameNativeHeaders() { + + MultiValueMap extHeaders = new LinkedMultiValueMap<>(); + extHeaders.add(StompHeaderAccessor.DESTINATION_HEADER, "/d"); + extHeaders.add(StompHeaderAccessor.STOMP_SUBSCRIPTION_HEADER, "s1"); + extHeaders.add(StompHeaderAccessor.STOMP_CONTENT_TYPE_HEADER, "application/json"); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE, extHeaders); + + assertEquals(StompCommand.MESSAGE, headers.getCommand()); + assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); + assertEquals("s1", headers.getSubscriptionId()); + } + + @Test + public void toNativeHeadersSubscribe() { + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE); + headers.setSubscriptionId("s1"); + headers.setDestination("/d"); + + Map> actual = headers.toNativeHeaderMap(); + + assertEquals(2, actual.size()); + assertEquals("s1", actual.get(StompHeaderAccessor.STOMP_ID_HEADER).get(0)); + assertEquals("/d", actual.get(StompHeaderAccessor.STOMP_DESTINATION_HEADER).get(0)); + } + + @Test + public void toNativeHeadersUnsubscribe() { + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.UNSUBSCRIBE); + headers.setSubscriptionId("s1"); + + Map> actual = headers.toNativeHeaderMap(); + + assertEquals(1, actual.size()); + assertEquals("s1", actual.get(StompHeaderAccessor.STOMP_ID_HEADER).get(0)); + } + + @Test + public void toNativeHeadersMessageFrame() { + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE); + headers.setSubscriptionId("s1"); + headers.setDestination("/d"); + headers.setContentType(MediaType.APPLICATION_JSON); + + Map> actual = headers.toNativeHeaderMap(); + + assertEquals(4, actual.size()); + assertEquals("s1", actual.get(StompHeaderAccessor.STOMP_SUBSCRIPTION_HEADER).get(0)); + assertEquals("/d", actual.get(StompHeaderAccessor.STOMP_DESTINATION_HEADER).get(0)); + assertEquals("application/json", actual.get(StompHeaderAccessor.STOMP_CONTENT_TYPE_HEADER).get(0)); + assertNotNull("message-id was not created", actual.get(StompHeaderAccessor.STOMP_MESSAGE_ID_HEADER).get(0)); + } + + @Test + public void modifyCustomNativeHeader() { + + MultiValueMap extHeaders = new LinkedMultiValueMap<>(); + extHeaders.add(StompHeaderAccessor.STOMP_ID_HEADER, "s1"); + extHeaders.add(StompHeaderAccessor.STOMP_DESTINATION_HEADER, "/d"); + extHeaders.add("accountId", "ABC123"); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE, extHeaders); + String accountId = headers.getFirstNativeHeader("accountId"); + headers.setNativeHeader("accountId", accountId.toLowerCase()); + + Map> actual = headers.toNativeHeaderMap(); + assertEquals(3, actual.size()); + + assertEquals("s1", actual.get(StompHeaderAccessor.STOMP_ID_HEADER).get(0)); + assertEquals("/d", actual.get(StompHeaderAccessor.STOMP_DESTINATION_HEADER).get(0)); + assertNotNull("abc123", actual.get("accountId").get(0)); + } + + } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/MessageHeaderAccessorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/MessageHeaderAccessorTests.java new file mode 100644 index 00000000000..f3f6591a10e --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/MessageHeaderAccessorTests.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; +import org.springframework.messaging.MessageHeaders; + +import static org.junit.Assert.*; + + +/** + * Test fixture for {@link MessageHeaderAccessor}. + * + * @author Rossen Stoyanchev + */ +public class MessageHeaderAccessorTests { + + + @Test + public void empty() { + MessageHeaderAccessor headers = new MessageHeaderAccessor(); + assertEquals(Collections.emptyMap(), headers.toMap()); + } + + @Test + public void wrapMessage() { + Map original = new HashMap<>(); + original.put("foo", "bar"); + original.put("bar", "baz"); + GenericMessage message = new GenericMessage<>("p", original); + + MessageHeaderAccessor headers = new MessageHeaderAccessor(message); + Map actual = headers.toMap(); + + assertEquals(4, actual.size()); + assertNotNull(actual.get(MessageHeaders.ID)); + assertNotNull(actual.get(MessageHeaders.TIMESTAMP)); + assertEquals("bar", actual.get("foo")); + assertEquals("baz", actual.get("bar")); + } + + @Test + public void wrapMessageAndModifyHeaders() { + Map original = new HashMap<>(); + original.put("foo", "bar"); + original.put("bar", "baz"); + GenericMessage message = new GenericMessage<>("p", original); + + MessageHeaderAccessor headers = new MessageHeaderAccessor(message); + headers.setHeader("foo", "BAR"); + Map actual = headers.toMap(); + + assertEquals(4, actual.size()); + assertNotNull(actual.get(MessageHeaders.ID)); + assertNotNull(actual.get(MessageHeaders.TIMESTAMP)); + assertEquals("BAR", actual.get("foo")); + assertEquals("baz", actual.get("bar")); + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java new file mode 100644 index 00000000000..20bf9f25359 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/NativeMessageHeaderAccessorTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.*; + + +/** + * Test fixture for {@link NativeMessageHeaderAccessor}. + * + * @author Rossen Stoyanchev + */ +public class NativeMessageHeaderAccessorTests { + + + @Test + public void originalNativeHeaders() { + MultiValueMap original = new LinkedMultiValueMap<>(); + original.add("foo", "bar"); + original.add("bar", "baz"); + + NativeMessageHeaderAccessor headers = new NativeMessageHeaderAccessor(original); + Map actual = headers.toMap(); + + assertEquals(1, actual.size()); + assertNotNull(actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS)); + assertEquals(original, actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS)); + } + + @Test + public void wrapMessage() { + + MultiValueMap originalNativeHeaders = new LinkedMultiValueMap<>(); + originalNativeHeaders.add("foo", "bar"); + originalNativeHeaders.add("bar", "baz"); + + Map original = new HashMap(); + original.put("a", "b"); + original.put(NativeMessageHeaderAccessor.NATIVE_HEADERS, originalNativeHeaders); + + GenericMessage message = new GenericMessage<>("p", original); + + NativeMessageHeaderAccessor headers = new NativeMessageHeaderAccessor(message); + Map actual = headers.toMap(); + + assertEquals(4, actual.size()); + assertNotNull(actual.get(MessageHeaders.ID)); + assertNotNull(actual.get(MessageHeaders.TIMESTAMP)); + assertEquals("b", actual.get("a")); + assertNotNull(actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS)); + assertEquals(originalNativeHeaders, actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS)); + } + + @Test + public void wrapNullMessage() { + NativeMessageHeaderAccessor headers = new NativeMessageHeaderAccessor((Message) null); + Map actual = headers.toMap(); + + assertEquals(1, actual.size()); + + @SuppressWarnings("unchecked") + Map> actualNativeHeaders = + (Map>) actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS); + + assertEquals(Collections.emptyMap(), actualNativeHeaders); + } + + @Test + public void wrapMessageAndModifyHeaders() { + + MultiValueMap originalNativeHeaders = new LinkedMultiValueMap<>(); + originalNativeHeaders.add("foo", "bar"); + originalNativeHeaders.add("bar", "baz"); + + Map original = new HashMap(); + original.put("a", "b"); + original.put(NativeMessageHeaderAccessor.NATIVE_HEADERS, originalNativeHeaders); + + GenericMessage message = new GenericMessage<>("p", original); + + NativeMessageHeaderAccessor headers = new NativeMessageHeaderAccessor(message); + headers.setHeader("a", "B"); + headers.setNativeHeader("foo", "BAR"); + + Map actual = headers.toMap(); + + assertEquals(4, actual.size()); + assertNotNull(actual.get(MessageHeaders.ID)); + assertNotNull(actual.get(MessageHeaders.TIMESTAMP)); + assertEquals("B", actual.get("a")); + + @SuppressWarnings("unchecked") + Map> actualNativeHeaders = + (Map>) actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS); + + assertNotNull(actualNativeHeaders); + assertEquals(Arrays.asList("BAR"), actualNativeHeaders.get("foo")); + assertEquals(Arrays.asList("baz"), actualNativeHeaders.get("bar")); + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/channel/ChannelInterceptorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/channel/ChannelInterceptorTests.java new file mode 100644 index 00000000000..1a68361fa90 --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/channel/ChannelInterceptorTests.java @@ -0,0 +1,156 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support.channel; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; +import org.springframework.messaging.support.MessageBuilder; + +import static org.junit.Assert.*; + +/** + * Test fixture for the use of {@link ChannelInterceptor}s. + * @author Rossen Stoyanchev + */ +public class ChannelInterceptorTests { + + private ExecutorSubscribableChannel channel; + + private TestMessageHandler messageHandler; + + + @Before + public void setup() { + this.channel = new ExecutorSubscribableChannel(); + this.messageHandler = new TestMessageHandler(); + this.channel.subscribe(this.messageHandler); + } + + + @Test + public void preSendInterceptorReturningModifiedMessage() { + + this.channel.addInterceptor(new PreSendReturnsMessageInterceptor()); + this.channel.send(MessageBuilder.withPayload("test").build()); + + assertEquals(1, this.messageHandler.messages.size()); + Message result = this.messageHandler.messages.get(0); + + assertNotNull(result); + assertEquals("test", result.getPayload()); + assertEquals(1, result.getHeaders().get(PreSendReturnsMessageInterceptor.class.getSimpleName())); + } + + @Test + public void preSendInterceptorReturningNull() { + + PreSendReturnsNullInterceptor interceptor = new PreSendReturnsNullInterceptor(); + this.channel.addInterceptor(interceptor); + Message message = MessageBuilder.withPayload("test").build(); + this.channel.send(message); + + assertEquals(1, interceptor.counter.get()); + assertEquals(0, this.messageHandler.messages.size()); + } + + @Test + public void postSendInterceptorMessageWasSent() { + final AtomicBoolean invoked = new AtomicBoolean(false); + this.channel.addInterceptor(new ChannelInterceptorAdapter() { + @Override + public void postSend(Message message, MessageChannel channel, boolean sent) { + assertNotNull(message); + assertNotNull(channel); + assertSame(ChannelInterceptorTests.this.channel, channel); + assertTrue(sent); + invoked.set(true); + } + }); + this.channel.send(MessageBuilder.withPayload("test").build()); + assertTrue(invoked.get()); + } + + @Test + public void postSendInterceptorMessageWasNotSent() { + final AbstractMessageChannel testChannel = new AbstractMessageChannel() { + @Override + protected boolean sendInternal(Message message, long timeout) { + return false; + } + }; + final AtomicBoolean invoked = new AtomicBoolean(false); + testChannel.addInterceptor(new ChannelInterceptorAdapter() { + @Override + public void postSend(Message message, MessageChannel channel, boolean sent) { + assertNotNull(message); + assertNotNull(channel); + assertSame(testChannel, channel); + assertFalse(sent); + invoked.set(true); + } + }); + testChannel.send(MessageBuilder.withPayload("test").build()); + assertTrue(invoked.get()); + } + + + private static class TestMessageHandler implements MessageHandler { + + private List> messages = new ArrayList>(); + + @Override + public void handleMessage(Message message) throws MessagingException { + this.messages.add(message); + } + } + + private static class PreSendReturnsMessageInterceptor extends ChannelInterceptorAdapter { + + private AtomicInteger counter = new AtomicInteger(); + + private String foo; + + @Override + public Message preSend(Message message, MessageChannel channel) { + assertNotNull(message); + return MessageBuilder.fromMessage(message).setHeader( + this.getClass().getSimpleName(), counter.incrementAndGet()).build(); + } + } + + private static class PreSendReturnsNullInterceptor extends ChannelInterceptorAdapter { + + private AtomicInteger counter = new AtomicInteger(); + + @Override + public Message preSend(Message message, MessageChannel channel) { + assertNotNull(message); + counter.incrementAndGet(); + return null; + } + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java index 4779d1c4da1..5db559e7d0e 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/channel/PublishSubscibeChannelTests.java @@ -26,6 +26,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.core.task.TaskExecutor; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageDeliveryException; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessagingException; import org.springframework.messaging.support.MessageBuilder; @@ -117,8 +118,8 @@ public class PublishSubscibeChannelTests { try { this.channel.send(message); } - catch(RuntimeException actualException) { - assertThat(actualException, equalTo(ex)); + catch(MessageDeliveryException actualException) { + assertThat((RuntimeException) actualException.getCause(), equalTo(ex)); } verifyZeroInteractions(secondHandler); }