Add ImmutableMessageChannelInterceptor

This change adds a ChannelInterceptor that flips the immutable flag on
messages being sent. This allows components sending messages to leave
the message mutable for interceptors to further apply modifications
before the message is sent (and exposed to concurrency).

The interceptor is automatically added with the STOMP/WebSocket Java
and XML config and the StompSubProtocolHandler leaves parsed incoming
messages mutable so they can be further modified before being sent.

Issue: SPR-12321
This commit is contained in:
Rossen Stoyanchev 2014-10-23 15:21:09 -04:00
parent d5bf6713ed
commit 687955a704
8 changed files with 192 additions and 27 deletions

View File

@ -45,6 +45,7 @@ import org.springframework.messaging.simp.user.UserDestinationResolver;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.ClassUtils;
import org.springframework.util.MimeTypeUtils;
@ -118,6 +119,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
if (this.clientInboundChannelRegistration == null) {
ChannelRegistration registration = new ChannelRegistration();
configureClientInboundChannel(registration);
registration.setInterceptors(new ImmutableMessageChannelInterceptor());
this.clientInboundChannelRegistration = registration;
}
return this.clientInboundChannelRegistration;
@ -152,6 +154,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
if (this.clientOutboundChannelRegistration == null) {
ChannelRegistration registration = new ChannelRegistration();
configureClientOutboundChannel(registration);
registration.setInterceptors(new ImmutableMessageChannelInterceptor());
this.clientOutboundChannelRegistration = registration;
}
return this.clientOutboundChannelRegistration;
@ -169,6 +172,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC
ChannelRegistration reg = getBrokerRegistry().getBrokerChannelRegistration();
ExecutorSubscribableChannel channel = reg.hasTaskExecutor() ?
new ExecutorSubscribableChannel(brokerChannelExecutor()) : new ExecutorSubscribableChannel();
reg.setInterceptors(new ImmutableMessageChannelInterceptor());
channel.setInterceptors(reg.getInterceptors());
return channel;
}

View File

@ -0,0 +1,45 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.support;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
/**
* A simpler interceptor that calls {@link MessageHeaderAccessor#setImmutable()}
* on the headers of messages passed through the preSend method.
*
* <p>When configured as the last interceptor in a chain, it allows the component
* sending the message to leave headers mutable for interceptors to modify prior
* to the message actually being sent and exposed to concurrent access.
*
* @author Rossen Stoyanchev
* @since 4.1.2
*/
public class ImmutableMessageChannelInterceptor extends ChannelInterceptorAdapter {
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
if (accessor != null && accessor.isMutable()) {
accessor.setImmutable();
}
return message;
}
}

View File

@ -111,7 +111,7 @@ public class MessageBrokerConfigurationTests {
AbstractSubscribableChannel channel = this.customContext.getBean(
"clientInboundChannel", AbstractSubscribableChannel.class);
assertEquals(2, channel.getInterceptors().size());
assertEquals(3, channel.getInterceptors().size());
CustomThreadPoolTaskExecutor taskExecutor = this.customContext.getBean(
"clientInboundChannelExecutor", CustomThreadPoolTaskExecutor.class);
@ -178,7 +178,7 @@ public class MessageBrokerConfigurationTests {
AbstractSubscribableChannel channel = this.customContext.getBean(
"clientOutboundChannel", AbstractSubscribableChannel.class);
assertEquals(2, channel.getInterceptors().size());
assertEquals(3, channel.getInterceptors().size());
ThreadPoolTaskExecutor taskExecutor = this.customContext.getBean(
"clientOutboundChannelExecutor", ThreadPoolTaskExecutor.class);
@ -257,7 +257,7 @@ public class MessageBrokerConfigurationTests {
AbstractSubscribableChannel channel = this.customContext.getBean(
"brokerChannel", AbstractSubscribableChannel.class);
assertEquals(3, channel.getInterceptors().size());
assertEquals(4, channel.getInterceptors().size());
ThreadPoolTaskExecutor taskExecutor = this.customContext.getBean(
"brokerChannelExecutor", ThreadPoolTaskExecutor.class);

View File

@ -21,6 +21,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.w3c.dom.Element;
import org.springframework.beans.MutablePropertyValues;
@ -201,11 +202,14 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
argValues.addIndexedArgumentValue(0, new RuntimeBeanReference(executorName));
}
RootBeanDefinition channelDef = new RootBeanDefinition(ExecutorSubscribableChannel.class, argValues, null);
ManagedList<? super Object> interceptors = new ManagedList<Object>();
if (element != null) {
Element interceptorsElement = DomUtils.getChildElementByTagName(element, "interceptors");
ManagedList<?> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
channelDef.getPropertyValues().add("interceptors", interceptors);
interceptors.addAll(WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context));
}
interceptors.add(new ImmutableMessageChannelInterceptor());
channelDef.getPropertyValues().add("interceptors", interceptors);
registerBeanDefByName(name, channelDef, context, source);
return new RuntimeBeanReference(name);
}

View File

@ -45,6 +45,9 @@ import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.AbstractMessageChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.support.MessageHeaderInitializer;
@ -99,6 +102,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
private MessageHeaderInitializer headerInitializer;
private Boolean immutableMessageInterceptorPresent;
private ApplicationEventPublisher eventPublisher;
private final Stats stats = new Stats();
@ -234,7 +239,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
headerAccessor.setSessionId(session.getId());
headerAccessor.setSessionAttributes(session.getAttributes());
headerAccessor.setUser(session.getPrincipal());
headerAccessor.setImmutable();
if (!detectImmutableMessageInterceptor(outputChannel)) {
headerAccessor.setImmutable();
}
if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
this.stats.incrementConnectCount();
@ -271,6 +278,22 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
}
}
private boolean detectImmutableMessageInterceptor(MessageChannel channel) {
if (this.immutableMessageInterceptorPresent != null) {
return this.immutableMessageInterceptorPresent;
}
if (channel instanceof AbstractMessageChannel) {
for (ChannelInterceptor interceptor : ((AbstractMessageChannel) channel).getInterceptors()) {
if (interceptor instanceof ImmutableMessageChannelInterceptor) {
this.immutableMessageInterceptorPresent = true;
return true;
}
}
}
this.immutableMessageInterceptorPresent = false;
return false;
}
private void publishEvent(ApplicationEvent event) {
try {
this.eventPublisher.publishEvent(event);

View File

@ -50,6 +50,8 @@ import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserDestinationResolver;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.MimeTypeUtils;
@ -179,15 +181,15 @@ public class MessageBrokerBeanDefinitionParserTests {
List<Class<? extends MessageHandler>> subscriberTypes =
Arrays.<Class<? extends MessageHandler>>asList(SimpAnnotationMethodMessageHandler.class,
UserDestinationMessageHandler.class, SimpleBrokerMessageHandler.class);
testChannel("clientInboundChannel", subscriberTypes, 1);
testChannel("clientInboundChannel", subscriberTypes, 2);
testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60);
subscriberTypes = Arrays.<Class<? extends MessageHandler>>asList(SubProtocolWebSocketHandler.class);
testChannel("clientOutboundChannel", subscriberTypes, 0);
testChannel("clientOutboundChannel", subscriberTypes, 1);
testExecutor("clientOutboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60);
subscriberTypes = Arrays.<Class<? extends MessageHandler>>asList(SimpleBrokerMessageHandler.class, UserDestinationMessageHandler.class);
testChannel("brokerChannel", subscriberTypes, 0);
testChannel("brokerChannel", subscriberTypes, 1);
try {
this.appContext.getBean("brokerChannelExecutor", ThreadPoolTaskExecutor.class);
fail("expected exception");
@ -247,16 +249,16 @@ public class MessageBrokerBeanDefinitionParserTests {
List<Class<? extends MessageHandler>> subscriberTypes =
Arrays.<Class<? extends MessageHandler>>asList(SimpAnnotationMethodMessageHandler.class,
UserDestinationMessageHandler.class, StompBrokerRelayMessageHandler.class);
testChannel("clientInboundChannel", subscriberTypes, 1);
testChannel("clientInboundChannel", subscriberTypes, 2);
testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60);
subscriberTypes = Arrays.<Class<? extends MessageHandler>>asList(SubProtocolWebSocketHandler.class);
testChannel("clientOutboundChannel", subscriberTypes, 0);
testChannel("clientOutboundChannel", subscriberTypes, 1);
testExecutor("clientOutboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60);
subscriberTypes = Arrays.<Class<? extends MessageHandler>>asList(
StompBrokerRelayMessageHandler.class, UserDestinationMessageHandler.class);
testChannel("brokerChannel", subscriberTypes, 0);
testChannel("brokerChannel", subscriberTypes, 1);
try {
this.appContext.getBean("brokerChannelExecutor", ThreadPoolTaskExecutor.class);
fail("expected exception");
@ -320,18 +322,18 @@ public class MessageBrokerBeanDefinitionParserTests {
Arrays.<Class<? extends MessageHandler>>asList(SimpAnnotationMethodMessageHandler.class,
UserDestinationMessageHandler.class, SimpleBrokerMessageHandler.class);
testChannel("clientInboundChannel", subscriberTypes, 2);
testChannel("clientInboundChannel", subscriberTypes, 3);
testExecutor("clientInboundChannel", 100, 200, 600);
subscriberTypes = Arrays.<Class<? extends MessageHandler>>asList(SubProtocolWebSocketHandler.class);
testChannel("clientOutboundChannel", subscriberTypes, 2);
testChannel("clientOutboundChannel", subscriberTypes, 3);
testExecutor("clientOutboundChannel", 101, 201, 601);
subscriberTypes = Arrays.<Class<? extends MessageHandler>>asList(SimpleBrokerMessageHandler.class,
UserDestinationMessageHandler.class);
testChannel("brokerChannel", subscriberTypes, 0);
testChannel("brokerChannel", subscriberTypes, 1);
testExecutor("brokerChannel", 102, 202, 602);
}
@ -397,7 +399,9 @@ public class MessageBrokerBeanDefinitionParserTests {
assertTrue(channel.hasSubscription(subscriber));
}
assertEquals(interceptorCount, channel.getInterceptors().size());
List<ChannelInterceptor> interceptors = channel.getInterceptors();
assertEquals(interceptorCount, interceptors.size());
assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass());
}
private void testExecutor(String channelName, int corePoolSize, int maxPoolSize, int keepAliveSeconds) {

View File

@ -16,14 +16,19 @@
package org.springframework.web.socket.config.annotation;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import org.junit.Test;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
@ -34,10 +39,14 @@ import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.SendTo;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.annotation.SubscribeMapping;
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.stereotype.Controller;
import org.springframework.web.servlet.HandlerMapping;
@ -55,8 +64,6 @@ import org.springframework.web.socket.messaging.SubProtocolHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import static org.junit.Assert.*;
/**
* Test fixture for
* {@link org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurationSupport}.
@ -83,6 +90,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
TestChannel channel = config.getBean("clientInboundChannel", TestChannel.class);
SubProtocolWebSocketHandler webSocketHandler = config.getBean(SubProtocolWebSocketHandler.class);
List<ChannelInterceptor> interceptors = channel.getInterceptors();
assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass());
WebSocketSession session = new TestWebSocketSession("s1");
webSocketHandler.afterConnectionEstablished(session);
@ -90,22 +100,40 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
webSocketHandler.handleMessage(session, textMessage);
Message<?> message = channel.messages.get(0);
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
assertEquals("/foo", headers.getDestination());
StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
assertNotNull(accessor);
assertFalse(accessor.isMutable());
assertEquals(SimpMessageType.MESSAGE, accessor.getMessageType());
assertEquals("/foo", accessor.getDestination());
}
@Test
public void clientOutboundChannelChannel() {
public void clientOutboundChannel() {
ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class);
TestChannel channel = config.getBean("clientOutboundChannel", TestChannel.class);
Set<MessageHandler> handlers = channel.getSubscribers();
List<ChannelInterceptor> interceptors = channel.getInterceptors();
assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass());
assertEquals(1, handlers.size());
assertTrue(handlers.iterator().next() instanceof SubProtocolWebSocketHandler);
}
@Test
public void brokerChannel() {
ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class);
TestChannel channel = config.getBean("brokerChannel", TestChannel.class);
Iterator<MessageHandler> handlers = channel.getSubscribers().iterator();
List<ChannelInterceptor> interceptors = channel.getInterceptors();
assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass());
assertEquals(SimpleBrokerMessageHandler.class, handlers.next().getClass());
assertEquals(UserDestinationMessageHandler.class, handlers.next().getClass());
assertFalse(handlers.hasNext());
}
@Test
public void webSocketTransportOptions() {
ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class);
@ -216,18 +244,24 @@ public class WebSocketMessageBrokerConfigurationSupportTests {
@Override
@Bean
public AbstractSubscribableChannel clientInboundChannel() {
return new TestChannel();
TestChannel channel = new TestChannel();
channel.setInterceptors(super.clientInboundChannel().getInterceptors());
return channel;
}
@Override
@Bean
public AbstractSubscribableChannel clientOutboundChannel() {
return new TestChannel();
TestChannel channel = new TestChannel();
channel.setInterceptors(super.clientOutboundChannel().getInterceptors());
return channel;
}
@Override
public AbstractSubscribableChannel brokerChannel() {
return new TestChannel();
TestChannel channel = new TestChannel();
channel.setInterceptors(super.brokerChannel().getInterceptors());
return channel;
}
}

View File

@ -22,6 +22,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Before;
import org.junit.Test;
@ -32,6 +33,9 @@ import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.simp.SimpAttributes;
import org.springframework.messaging.simp.SimpAttributesContextHolder;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
@ -44,7 +48,12 @@ import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.AbstractMessageChannel;
import org.springframework.messaging.support.ChannelInterceptorAdapter;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
@ -288,6 +297,48 @@ public class StompSubProtocolHandlerTests {
assertEquals(0, this.session.getSentMessages().size());
}
@Test
public void handleMessageFromClientWithImmutableMessageInterceptor() {
AtomicReference<Boolean> mutable = new AtomicReference<>();
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
channel.addInterceptor(new ChannelInterceptorAdapter() {
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable());
return message;
}
});
channel.addInterceptor(new ImmutableMessageChannelInterceptor());
StompSubProtocolHandler handler = new StompSubProtocolHandler();
handler.afterSessionStarted(this.session, channel);
TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build();
handler.handleMessageFromClient(this.session, message, channel);
assertNotNull(mutable.get());
assertTrue(mutable.get());
}
@Test
public void handleMessageFromClientWithoutImmutableMessageInterceptor() {
AtomicReference<Boolean> mutable = new AtomicReference<>();
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
channel.addInterceptor(new ChannelInterceptorAdapter() {
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable());
return message;
}
});
StompSubProtocolHandler handler = new StompSubProtocolHandler();
handler.afterSessionStarted(this.session, channel);
TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build();
handler.handleMessageFromClient(this.session, message, channel);
assertNotNull(mutable.get());
assertFalse(mutable.get());
}
@Test
public void handleMessageFromClientInvalidStompCommand() {