Add ChannelInterceptor to spring-messaging module

Issue: SPR-10866
This commit is contained in:
Rossen Stoyanchev 2013-08-28 14:41:20 -04:00
parent 467a6b9fa7
commit 4b2847d9d1
12 changed files with 844 additions and 61 deletions

View File

@ -68,7 +68,6 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
*/
protected SimpMessageHeaderAccessor(Message<?> message) {
super(message);
Assert.notNull(message, "message is required");
}

View File

@ -107,7 +107,10 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor {
return result;
}
protected List<String> getNativeHeader(String headerName) {
/**
* Return all values for the specified native header or {@code null}.
*/
public List<String> getNativeHeader(String headerName) {
if (this.nativeHeaders.containsKey(headerName)) {
return this.nativeHeaders.get(headerName);
}
@ -117,23 +120,28 @@ public class NativeMessageHeaderAccessor extends MessageHeaderAccessor {
return null;
}
/**
* Return the first value for the specified native header of {@code null}.
*/
public String getFirstNativeHeader(String headerName) {
List<String> values = getNativeHeader(headerName);
return CollectionUtils.isEmpty(values) ? null : values.get(0);
}
/**
* Set the value for the given header name. If the provided value is {@code null} the
* header will be removed.
* Set the specified native header value.
*/
protected void putNativeHeader(String name, List<String> value) {
public void setNativeHeader(String name, String value) {
if (!ObjectUtils.nullSafeEquals(value, getHeader(name))) {
this.nativeHeaders.put(name, value);
this.nativeHeaders.set(name, value);
}
}
public void setNativeHeader(String name, String value) {
this.nativeHeaders.set(name, value);
/**
* Add the specified native header value.
*/
public void addNativeHeader(String name, String value) {
this.nativeHeaders.add(name, value);
}
}

View File

@ -0,0 +1,130 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.support.channel;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.BeanNameAware;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.MessagingException;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
/**
* Abstract base class for {@link MessageChannel} implementations.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractMessageChannel implements MessageChannel, BeanNameAware {
protected Log logger = LogFactory.getLog(getClass());
private String beanName;
private final ChannelInterceptorChain interceptorChain = new ChannelInterceptorChain();
public AbstractMessageChannel() {
this.beanName = getClass().getSimpleName() + "@" + ObjectUtils.getIdentityHexString(this);
}
/**
* {@inheritDoc}
* <p>Used primarily for logging purposes.
*/
@Override
public void setBeanName(String name) {
this.beanName = name;
}
/**
* @return the name for this channel.
*/
public String getBeanName() {
return this.beanName;
}
/**
* Set the list of channel interceptors. This will clear any existing interceptors.
*/
public void setInterceptors(List<ChannelInterceptor> interceptors) {
this.interceptorChain.set(interceptors);
}
/**
* Add a channel interceptor to the end of the list.
*/
public void addInterceptor(ChannelInterceptor interceptor) {
this.interceptorChain.add(interceptor);
}
/**
* Return a read-only list of the configured interceptors.
*/
public List<ChannelInterceptor> getInterceptors() {
return this.interceptorChain.getInterceptors();
}
/**
* Exposes the interceptor list for subclasses.
*/
protected ChannelInterceptorChain getInterceptorChain() {
return this.interceptorChain;
}
@Override
public final boolean send(Message<?> message) {
return send(message, INDEFINITE_TIMEOUT);
}
@Override
public final boolean send(Message<?> message, long timeout) {
Assert.notNull(message, "Message must not be null");
if (logger.isTraceEnabled()) {
logger.trace("[" + this.beanName + "] sending message " + message);
}
message = this.interceptorChain.preSend(message, this);
if (message == null) {
return false;
}
try {
boolean sent = sendInternal(message, timeout);
this.interceptorChain.postSend(message, this, sent);
return sent;
}
catch (Exception e) {
if (e instanceof MessagingException) {
throw (MessagingException) e;
}
throw new MessageDeliveryException(message,
"Failed to send message to channel '" + this.getBeanName() + "'", e);
}
}
protected abstract boolean sendInternal(Message<?> message, long timeout);
}

View File

@ -16,14 +16,8 @@
package org.springframework.messaging.support.channel;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.BeanNameAware;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
/**
@ -32,57 +26,17 @@ import org.springframework.util.ObjectUtils;
* @author Rossen Stoyanchev
* @since 4.0
*/
public abstract class AbstractSubscribableChannel implements SubscribableChannel, BeanNameAware {
public abstract class AbstractSubscribableChannel extends AbstractMessageChannel implements SubscribableChannel {
protected Log logger = LogFactory.getLog(getClass());
private String beanName;
public AbstractSubscribableChannel() {
this.beanName = getClass().getSimpleName() + "@" + ObjectUtils.getIdentityHexString(this);
}
/**
* {@inheritDoc}
* <p>Used primarily for logging purposes.
*/
@Override
public void setBeanName(String name) {
this.beanName = name;
}
/**
* @return the name for this channel.
*/
public String getBeanName() {
return this.beanName;
}
@Override
public final boolean send(Message<?> message) {
return send(message, INDEFINITE_TIMEOUT);
}
@Override
public final boolean send(Message<?> message, long timeout) {
Assert.notNull(message, "Message must not be null");
if (logger.isTraceEnabled()) {
logger.trace("[" + this.beanName + "] sending message " + message);
}
return sendInternal(message, timeout);
}
protected abstract boolean sendInternal(Message<?> message, long timeout);
@Override
public final boolean subscribe(MessageHandler handler) {
if (hasSubscription(handler)) {
logger.warn("[" + this.beanName + "] handler already subscribed " + handler);
logger.warn("[" + getBeanName() + "] handler already subscribed " + handler);
return false;
}
if (logger.isDebugEnabled()) {
logger.debug("[" + this.beanName + "] subscribing " + handler);
logger.debug("[" + getBeanName() + "] subscribing " + handler);
}
return subscribeInternal(handler);
}
@ -94,7 +48,7 @@ public abstract class AbstractSubscribableChannel implements SubscribableChannel
@Override
public final boolean unsubscribe(MessageHandler handler) {
if (logger.isDebugEnabled()) {
logger.debug("[" + this.beanName + "] unsubscribing " + handler);
logger.debug("[" + getBeanName() + "] unsubscribing " + handler);
}
return unsubscribeInternal(handler);
}

View File

@ -0,0 +1,60 @@
/*
* Copyright 2002-2010 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.support.channel;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
/**
* Interface for interceptors that are able to view and/or modify the
* {@link Message Messages} being sent-to and/or received-from a
* {@link MessageChannel}.
*
* @author Mark Fisher
* @since 4.0
*/
public interface ChannelInterceptor {
/**
* Invoked before the Message is actually sent to the channel.
* This allows for modification of the Message if necessary.
* If this method returns <code>null</code>, then the actual
* send invocation will not occur.
*/
Message<?> preSend(Message<?> message, MessageChannel channel);
/**
* Invoked immediately after the send invocation. The boolean
* value argument represents the return value of that invocation.
*/
void postSend(Message<?> message, MessageChannel channel, boolean sent);
/**
* Invoked as soon as receive is called and before a Message is
* actually retrieved. If the return value is 'false', then no
* Message will be retrieved. This only applies to PollableChannels.
*/
boolean preReceive(MessageChannel channel);
/**
* Invoked immediately after a Message has been retrieved but before
* it is returned to the caller. The Message may be modified if
* necessary. This only applies to PollableChannels.
*/
Message<?> postReceive(Message<?> message, MessageChannel channel);
}

View File

@ -0,0 +1,46 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.support.channel;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
/**
* A {@link ChannelInterceptor} with empty method implementations as a convenience.
*
* @author Mark Fisher
* @since 4.0
*/
public class ChannelInterceptorAdapter implements ChannelInterceptor {
public Message<?> preSend(Message<?> message, MessageChannel channel) {
return message;
}
public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
}
public boolean preReceive(MessageChannel channel) {
return true;
}
public Message<?> postReceive(Message<?> message, MessageChannel channel) {
return message;
}
}

View File

@ -0,0 +1,110 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.support.channel;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
/**
* A convenience wrapper class for invoking a list of {@link ChannelInterceptor}s.
*
* @author Mark Fisher
* @author Rossen Stoyanchev
* @since 4.0
*/
class ChannelInterceptorChain {
private static final Log logger = LogFactory.getLog(ChannelInterceptorChain.class);
private final List<ChannelInterceptor> interceptors = new CopyOnWriteArrayList<ChannelInterceptor>();
public boolean set(List<ChannelInterceptor> interceptors) {
synchronized (this.interceptors) {
this.interceptors.clear();
return this.interceptors.addAll(interceptors);
}
}
public boolean add(ChannelInterceptor interceptor) {
return this.interceptors.add(interceptor);
}
public List<ChannelInterceptor> getInterceptors() {
return Collections.unmodifiableList(this.interceptors);
}
public Message<?> preSend(Message<?> message, MessageChannel channel) {
if (logger.isTraceEnabled()) {
logger.trace("preSend on channel '" + channel + "', message: " + message);
}
for (ChannelInterceptor interceptor : this.interceptors) {
message = interceptor.preSend(message, channel);
if (message == null) {
return null;
}
}
return message;
}
public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
if (logger.isTraceEnabled()) {
logger.trace("postSend (sent=" + sent + ") on channel '" + channel + "', message: " + message);
}
for (ChannelInterceptor interceptor : this.interceptors) {
interceptor.postSend(message, channel, sent);
}
}
public boolean preReceive(MessageChannel channel) {
if (logger.isTraceEnabled()) {
logger.trace("preReceive on channel '" + channel + "'");
}
for (ChannelInterceptor interceptor : this.interceptors) {
if (!interceptor.preReceive(channel)) {
return false;
}
}
return true;
}
public Message<?> postReceive(Message<?> message, MessageChannel channel) {
if (message != null && logger.isTraceEnabled()) {
logger.trace("postReceive on channel '" + channel + "', message: " + message);
}
else if (logger.isTraceEnabled()) {
logger.trace("postReceive on channel '" + channel + "', message is null");
}
for (ChannelInterceptor interceptor : this.interceptors) {
message = interceptor.postReceive(message, channel);
if (message == null) {
return null;
}
}
return message;
}
}

View File

@ -16,8 +16,14 @@
package org.springframework.messaging.simp.stomp;
import java.util.List;
import java.util.Map;
import org.junit.Test;
import org.springframework.http.MediaType;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import static org.junit.Assert.*;
@ -32,7 +38,8 @@ public class StompHeaderAccessorTests {
@Test
public void testStompCommandSet() {
public void createWithCommand() {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
assertEquals(StompCommand.CONNECTED, accessor.getCommand());
@ -40,4 +47,111 @@ public class StompHeaderAccessorTests {
assertEquals(StompCommand.CONNECTED, accessor.getCommand());
}
@Test
public void createWithSubscribeNativeHeaders() {
MultiValueMap<String, String> extHeaders = new LinkedMultiValueMap<>();
extHeaders.add(StompHeaderAccessor.STOMP_ID_HEADER, "s1");
extHeaders.add(StompHeaderAccessor.STOMP_DESTINATION_HEADER, "/d");
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE, extHeaders);
assertEquals(StompCommand.SUBSCRIBE, headers.getCommand());
assertEquals(SimpMessageType.SUBSCRIBE, headers.getMessageType());
assertEquals("/d", headers.getDestination());
assertEquals("s1", headers.getSubscriptionId());
}
@Test
public void createWithUnubscribeNativeHeaders() {
MultiValueMap<String, String> extHeaders = new LinkedMultiValueMap<>();
extHeaders.add(StompHeaderAccessor.STOMP_ID_HEADER, "s1");
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.UNSUBSCRIBE, extHeaders);
assertEquals(StompCommand.UNSUBSCRIBE, headers.getCommand());
assertEquals(SimpMessageType.UNSUBSCRIBE, headers.getMessageType());
assertEquals("s1", headers.getSubscriptionId());
}
@Test
public void createWithMessageFrameNativeHeaders() {
MultiValueMap<String, String> extHeaders = new LinkedMultiValueMap<>();
extHeaders.add(StompHeaderAccessor.DESTINATION_HEADER, "/d");
extHeaders.add(StompHeaderAccessor.STOMP_SUBSCRIPTION_HEADER, "s1");
extHeaders.add(StompHeaderAccessor.STOMP_CONTENT_TYPE_HEADER, "application/json");
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE, extHeaders);
assertEquals(StompCommand.MESSAGE, headers.getCommand());
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
assertEquals("s1", headers.getSubscriptionId());
}
@Test
public void toNativeHeadersSubscribe() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
headers.setSubscriptionId("s1");
headers.setDestination("/d");
Map<String, List<String>> actual = headers.toNativeHeaderMap();
assertEquals(2, actual.size());
assertEquals("s1", actual.get(StompHeaderAccessor.STOMP_ID_HEADER).get(0));
assertEquals("/d", actual.get(StompHeaderAccessor.STOMP_DESTINATION_HEADER).get(0));
}
@Test
public void toNativeHeadersUnsubscribe() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.UNSUBSCRIBE);
headers.setSubscriptionId("s1");
Map<String, List<String>> actual = headers.toNativeHeaderMap();
assertEquals(1, actual.size());
assertEquals("s1", actual.get(StompHeaderAccessor.STOMP_ID_HEADER).get(0));
}
@Test
public void toNativeHeadersMessageFrame() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
headers.setSubscriptionId("s1");
headers.setDestination("/d");
headers.setContentType(MediaType.APPLICATION_JSON);
Map<String, List<String>> actual = headers.toNativeHeaderMap();
assertEquals(4, actual.size());
assertEquals("s1", actual.get(StompHeaderAccessor.STOMP_SUBSCRIPTION_HEADER).get(0));
assertEquals("/d", actual.get(StompHeaderAccessor.STOMP_DESTINATION_HEADER).get(0));
assertEquals("application/json", actual.get(StompHeaderAccessor.STOMP_CONTENT_TYPE_HEADER).get(0));
assertNotNull("message-id was not created", actual.get(StompHeaderAccessor.STOMP_MESSAGE_ID_HEADER).get(0));
}
@Test
public void modifyCustomNativeHeader() {
MultiValueMap<String, String> extHeaders = new LinkedMultiValueMap<>();
extHeaders.add(StompHeaderAccessor.STOMP_ID_HEADER, "s1");
extHeaders.add(StompHeaderAccessor.STOMP_DESTINATION_HEADER, "/d");
extHeaders.add("accountId", "ABC123");
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE, extHeaders);
String accountId = headers.getFirstNativeHeader("accountId");
headers.setNativeHeader("accountId", accountId.toLowerCase());
Map<String, List<String>> actual = headers.toNativeHeaderMap();
assertEquals(3, actual.size());
assertEquals("s1", actual.get(StompHeaderAccessor.STOMP_ID_HEADER).get(0));
assertEquals("/d", actual.get(StompHeaderAccessor.STOMP_DESTINATION_HEADER).get(0));
assertNotNull("abc123", actual.get("accountId").get(0));
}
}

View File

@ -0,0 +1,78 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.support;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.springframework.messaging.MessageHeaders;
import static org.junit.Assert.*;
/**
* Test fixture for {@link MessageHeaderAccessor}.
*
* @author Rossen Stoyanchev
*/
public class MessageHeaderAccessorTests {
@Test
public void empty() {
MessageHeaderAccessor headers = new MessageHeaderAccessor();
assertEquals(Collections.emptyMap(), headers.toMap());
}
@Test
public void wrapMessage() {
Map<String, Object> original = new HashMap<>();
original.put("foo", "bar");
original.put("bar", "baz");
GenericMessage<String> message = new GenericMessage<>("p", original);
MessageHeaderAccessor headers = new MessageHeaderAccessor(message);
Map<String, Object> actual = headers.toMap();
assertEquals(4, actual.size());
assertNotNull(actual.get(MessageHeaders.ID));
assertNotNull(actual.get(MessageHeaders.TIMESTAMP));
assertEquals("bar", actual.get("foo"));
assertEquals("baz", actual.get("bar"));
}
@Test
public void wrapMessageAndModifyHeaders() {
Map<String, Object> original = new HashMap<>();
original.put("foo", "bar");
original.put("bar", "baz");
GenericMessage<String> message = new GenericMessage<>("p", original);
MessageHeaderAccessor headers = new MessageHeaderAccessor(message);
headers.setHeader("foo", "BAR");
Map<String, Object> actual = headers.toMap();
assertEquals(4, actual.size());
assertNotNull(actual.get(MessageHeaders.ID));
assertNotNull(actual.get(MessageHeaders.TIMESTAMP));
assertEquals("BAR", actual.get("foo"));
assertEquals("baz", actual.get("bar"));
}
}

View File

@ -0,0 +1,127 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.support;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import static org.junit.Assert.*;
/**
* Test fixture for {@link NativeMessageHeaderAccessor}.
*
* @author Rossen Stoyanchev
*/
public class NativeMessageHeaderAccessorTests {
@Test
public void originalNativeHeaders() {
MultiValueMap<String, String> original = new LinkedMultiValueMap<>();
original.add("foo", "bar");
original.add("bar", "baz");
NativeMessageHeaderAccessor headers = new NativeMessageHeaderAccessor(original);
Map<String, Object> actual = headers.toMap();
assertEquals(1, actual.size());
assertNotNull(actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS));
assertEquals(original, actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS));
}
@Test
public void wrapMessage() {
MultiValueMap<String, String> originalNativeHeaders = new LinkedMultiValueMap<>();
originalNativeHeaders.add("foo", "bar");
originalNativeHeaders.add("bar", "baz");
Map<String, Object> original = new HashMap<String, Object>();
original.put("a", "b");
original.put(NativeMessageHeaderAccessor.NATIVE_HEADERS, originalNativeHeaders);
GenericMessage<String> message = new GenericMessage<>("p", original);
NativeMessageHeaderAccessor headers = new NativeMessageHeaderAccessor(message);
Map<String, Object> actual = headers.toMap();
assertEquals(4, actual.size());
assertNotNull(actual.get(MessageHeaders.ID));
assertNotNull(actual.get(MessageHeaders.TIMESTAMP));
assertEquals("b", actual.get("a"));
assertNotNull(actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS));
assertEquals(originalNativeHeaders, actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS));
}
@Test
public void wrapNullMessage() {
NativeMessageHeaderAccessor headers = new NativeMessageHeaderAccessor((Message<?>) null);
Map<String, Object> actual = headers.toMap();
assertEquals(1, actual.size());
@SuppressWarnings("unchecked")
Map<String, List<String>> actualNativeHeaders =
(Map<String, List<String>>) actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS);
assertEquals(Collections.emptyMap(), actualNativeHeaders);
}
@Test
public void wrapMessageAndModifyHeaders() {
MultiValueMap<String, String> originalNativeHeaders = new LinkedMultiValueMap<>();
originalNativeHeaders.add("foo", "bar");
originalNativeHeaders.add("bar", "baz");
Map<String, Object> original = new HashMap<String, Object>();
original.put("a", "b");
original.put(NativeMessageHeaderAccessor.NATIVE_HEADERS, originalNativeHeaders);
GenericMessage<String> message = new GenericMessage<>("p", original);
NativeMessageHeaderAccessor headers = new NativeMessageHeaderAccessor(message);
headers.setHeader("a", "B");
headers.setNativeHeader("foo", "BAR");
Map<String, Object> actual = headers.toMap();
assertEquals(4, actual.size());
assertNotNull(actual.get(MessageHeaders.ID));
assertNotNull(actual.get(MessageHeaders.TIMESTAMP));
assertEquals("B", actual.get("a"));
@SuppressWarnings("unchecked")
Map<String, List<String>> actualNativeHeaders =
(Map<String, List<String>>) actual.get(NativeMessageHeaderAccessor.NATIVE_HEADERS);
assertNotNull(actualNativeHeaders);
assertEquals(Arrays.asList("BAR"), actualNativeHeaders.get("foo"));
assertEquals(Arrays.asList("baz"), actualNativeHeaders.get("bar"));
}
}

View File

@ -0,0 +1,156 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.support.channel;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Before;
import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.support.MessageBuilder;
import static org.junit.Assert.*;
/**
* Test fixture for the use of {@link ChannelInterceptor}s.
* @author Rossen Stoyanchev
*/
public class ChannelInterceptorTests {
private ExecutorSubscribableChannel channel;
private TestMessageHandler messageHandler;
@Before
public void setup() {
this.channel = new ExecutorSubscribableChannel();
this.messageHandler = new TestMessageHandler();
this.channel.subscribe(this.messageHandler);
}
@Test
public void preSendInterceptorReturningModifiedMessage() {
this.channel.addInterceptor(new PreSendReturnsMessageInterceptor());
this.channel.send(MessageBuilder.withPayload("test").build());
assertEquals(1, this.messageHandler.messages.size());
Message<?> result = this.messageHandler.messages.get(0);
assertNotNull(result);
assertEquals("test", result.getPayload());
assertEquals(1, result.getHeaders().get(PreSendReturnsMessageInterceptor.class.getSimpleName()));
}
@Test
public void preSendInterceptorReturningNull() {
PreSendReturnsNullInterceptor interceptor = new PreSendReturnsNullInterceptor();
this.channel.addInterceptor(interceptor);
Message<?> message = MessageBuilder.withPayload("test").build();
this.channel.send(message);
assertEquals(1, interceptor.counter.get());
assertEquals(0, this.messageHandler.messages.size());
}
@Test
public void postSendInterceptorMessageWasSent() {
final AtomicBoolean invoked = new AtomicBoolean(false);
this.channel.addInterceptor(new ChannelInterceptorAdapter() {
@Override
public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
assertNotNull(message);
assertNotNull(channel);
assertSame(ChannelInterceptorTests.this.channel, channel);
assertTrue(sent);
invoked.set(true);
}
});
this.channel.send(MessageBuilder.withPayload("test").build());
assertTrue(invoked.get());
}
@Test
public void postSendInterceptorMessageWasNotSent() {
final AbstractMessageChannel testChannel = new AbstractMessageChannel() {
@Override
protected boolean sendInternal(Message<?> message, long timeout) {
return false;
}
};
final AtomicBoolean invoked = new AtomicBoolean(false);
testChannel.addInterceptor(new ChannelInterceptorAdapter() {
@Override
public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
assertNotNull(message);
assertNotNull(channel);
assertSame(testChannel, channel);
assertFalse(sent);
invoked.set(true);
}
});
testChannel.send(MessageBuilder.withPayload("test").build());
assertTrue(invoked.get());
}
private static class TestMessageHandler implements MessageHandler {
private List<Message<?>> messages = new ArrayList<Message<?>>();
@Override
public void handleMessage(Message<?> message) throws MessagingException {
this.messages.add(message);
}
}
private static class PreSendReturnsMessageInterceptor extends ChannelInterceptorAdapter {
private AtomicInteger counter = new AtomicInteger();
private String foo;
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
assertNotNull(message);
return MessageBuilder.fromMessage(message).setHeader(
this.getClass().getSimpleName(), counter.incrementAndGet()).build();
}
}
private static class PreSendReturnsNullInterceptor extends ChannelInterceptorAdapter {
private AtomicInteger counter = new AtomicInteger();
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
assertNotNull(message);
counter.incrementAndGet();
return null;
}
}
}

View File

@ -26,6 +26,7 @@ import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.core.task.TaskExecutor;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.support.MessageBuilder;
@ -117,8 +118,8 @@ public class PublishSubscibeChannelTests {
try {
this.channel.send(message);
}
catch(RuntimeException actualException) {
assertThat(actualException, equalTo(ex));
catch(MessageDeliveryException actualException) {
assertThat((RuntimeException) actualException.getCause(), equalTo(ex));
}
verifyZeroInteractions(secondHandler);
}