From 687955a70412530a465f180f4b36c3a4e9b274bb Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 23 Oct 2014 15:21:09 -0400 Subject: [PATCH] Add ImmutableMessageChannelInterceptor This change adds a ChannelInterceptor that flips the immutable flag on messages being sent. This allows components sending messages to leave the message mutable for interceptors to further apply modifications before the message is sent (and exposed to concurrency). The interceptor is automatically added with the STOMP/WebSocket Java and XML config and the StompSubProtocolHandler leaves parsed incoming messages mutable so they can be further modified before being sent. Issue: SPR-12321 --- .../AbstractMessageBrokerConfiguration.java | 4 ++ .../ImmutableMessageChannelInterceptor.java | 45 +++++++++++++++ .../MessageBrokerConfigurationTests.java | 6 +- .../MessageBrokerBeanDefinitionParser.java | 8 ++- .../messaging/StompSubProtocolHandler.java | 25 ++++++++- ...essageBrokerBeanDefinitionParserTests.java | 24 ++++---- ...essageBrokerConfigurationSupportTests.java | 56 +++++++++++++++---- .../StompSubProtocolHandlerTests.java | 51 +++++++++++++++++ 8 files changed, 192 insertions(+), 27 deletions(-) create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/support/ImmutableMessageChannelInterceptor.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java index 4c738bdbd4b..d40b6e4a40a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java @@ -45,6 +45,7 @@ import org.springframework.messaging.simp.user.UserDestinationResolver; import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.AbstractSubscribableChannel; import org.springframework.messaging.support.ExecutorSubscribableChannel; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.util.ClassUtils; import org.springframework.util.MimeTypeUtils; @@ -118,6 +119,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC if (this.clientInboundChannelRegistration == null) { ChannelRegistration registration = new ChannelRegistration(); configureClientInboundChannel(registration); + registration.setInterceptors(new ImmutableMessageChannelInterceptor()); this.clientInboundChannelRegistration = registration; } return this.clientInboundChannelRegistration; @@ -152,6 +154,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC if (this.clientOutboundChannelRegistration == null) { ChannelRegistration registration = new ChannelRegistration(); configureClientOutboundChannel(registration); + registration.setInterceptors(new ImmutableMessageChannelInterceptor()); this.clientOutboundChannelRegistration = registration; } return this.clientOutboundChannelRegistration; @@ -169,6 +172,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC ChannelRegistration reg = getBrokerRegistry().getBrokerChannelRegistration(); ExecutorSubscribableChannel channel = reg.hasTaskExecutor() ? new ExecutorSubscribableChannel(brokerChannelExecutor()) : new ExecutorSubscribableChannel(); + reg.setInterceptors(new ImmutableMessageChannelInterceptor()); channel.setInterceptors(reg.getInterceptors()); return channel; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/ImmutableMessageChannelInterceptor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/ImmutableMessageChannelInterceptor.java new file mode 100644 index 00000000000..c2921857fd6 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/ImmutableMessageChannelInterceptor.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; + +/** + * A simpler interceptor that calls {@link MessageHeaderAccessor#setImmutable()} + * on the headers of messages passed through the preSend method. + * + *

When configured as the last interceptor in a chain, it allows the component + * sending the message to leave headers mutable for interceptors to modify prior + * to the message actually being sent and exposed to concurrent access. + * + * @author Rossen Stoyanchev + * @since 4.1.2 + */ +public class ImmutableMessageChannelInterceptor extends ChannelInterceptorAdapter { + + + @Override + public Message preSend(Message message, MessageChannel channel) { + MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class); + if (accessor != null && accessor.isMutable()) { + accessor.setImmutable(); + } + return message; + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java index 5e2aa0b586b..bd45517c774 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java @@ -111,7 +111,7 @@ public class MessageBrokerConfigurationTests { AbstractSubscribableChannel channel = this.customContext.getBean( "clientInboundChannel", AbstractSubscribableChannel.class); - assertEquals(2, channel.getInterceptors().size()); + assertEquals(3, channel.getInterceptors().size()); CustomThreadPoolTaskExecutor taskExecutor = this.customContext.getBean( "clientInboundChannelExecutor", CustomThreadPoolTaskExecutor.class); @@ -178,7 +178,7 @@ public class MessageBrokerConfigurationTests { AbstractSubscribableChannel channel = this.customContext.getBean( "clientOutboundChannel", AbstractSubscribableChannel.class); - assertEquals(2, channel.getInterceptors().size()); + assertEquals(3, channel.getInterceptors().size()); ThreadPoolTaskExecutor taskExecutor = this.customContext.getBean( "clientOutboundChannelExecutor", ThreadPoolTaskExecutor.class); @@ -257,7 +257,7 @@ public class MessageBrokerConfigurationTests { AbstractSubscribableChannel channel = this.customContext.getBean( "brokerChannel", AbstractSubscribableChannel.class); - assertEquals(3, channel.getInterceptors().size()); + assertEquals(4, channel.getInterceptors().size()); ThreadPoolTaskExecutor taskExecutor = this.customContext.getBean( "brokerChannelExecutor", ThreadPoolTaskExecutor.class); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java index 22b55cdf562..0c06e0be850 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.w3c.dom.Element; import org.springframework.beans.MutablePropertyValues; @@ -201,11 +202,14 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { argValues.addIndexedArgumentValue(0, new RuntimeBeanReference(executorName)); } RootBeanDefinition channelDef = new RootBeanDefinition(ExecutorSubscribableChannel.class, argValues, null); + ManagedList interceptors = new ManagedList(); if (element != null) { Element interceptorsElement = DomUtils.getChildElementByTagName(element, "interceptors"); - ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); - channelDef.getPropertyValues().add("interceptors", interceptors); + interceptors.addAll(WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context)); } + interceptors.add(new ImmutableMessageChannelInterceptor()); + channelDef.getPropertyValues().add("interceptors", interceptors); + registerBeanDefByName(name, channelDef, context, source); return new RuntimeBeanReference(name); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 6fe70058839..82663c1a64a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -45,6 +45,9 @@ import org.springframework.messaging.simp.stomp.StompEncoder; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.user.DestinationUserNameProvider; import org.springframework.messaging.simp.user.UserSessionRegistry; +import org.springframework.messaging.support.AbstractMessageChannel; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.messaging.support.MessageHeaderInitializer; @@ -99,6 +102,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE private MessageHeaderInitializer headerInitializer; + private Boolean immutableMessageInterceptorPresent; + private ApplicationEventPublisher eventPublisher; private final Stats stats = new Stats(); @@ -234,7 +239,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE headerAccessor.setSessionId(session.getId()); headerAccessor.setSessionAttributes(session.getAttributes()); headerAccessor.setUser(session.getPrincipal()); - headerAccessor.setImmutable(); + if (!detectImmutableMessageInterceptor(outputChannel)) { + headerAccessor.setImmutable(); + } if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) { this.stats.incrementConnectCount(); @@ -271,6 +278,22 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } } + private boolean detectImmutableMessageInterceptor(MessageChannel channel) { + if (this.immutableMessageInterceptorPresent != null) { + return this.immutableMessageInterceptorPresent; + } + if (channel instanceof AbstractMessageChannel) { + for (ChannelInterceptor interceptor : ((AbstractMessageChannel) channel).getInterceptors()) { + if (interceptor instanceof ImmutableMessageChannelInterceptor) { + this.immutableMessageInterceptorPresent = true; + return true; + } + } + } + this.immutableMessageInterceptorPresent = false; + return false; + } + private void publishEvent(ApplicationEvent event) { try { this.eventPublisher.publishEvent(event); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 88904ea2c99..d8fa79e8d08 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -50,6 +50,8 @@ import org.springframework.messaging.simp.user.UserDestinationMessageHandler; import org.springframework.messaging.simp.user.UserDestinationResolver; import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.AbstractSubscribableChannel; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.MimeTypeUtils; @@ -179,15 +181,15 @@ public class MessageBrokerBeanDefinitionParserTests { List> subscriberTypes = Arrays.>asList(SimpAnnotationMethodMessageHandler.class, UserDestinationMessageHandler.class, SimpleBrokerMessageHandler.class); - testChannel("clientInboundChannel", subscriberTypes, 1); + testChannel("clientInboundChannel", subscriberTypes, 2); testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); - testChannel("clientOutboundChannel", subscriberTypes, 0); + testChannel("clientOutboundChannel", subscriberTypes, 1); testExecutor("clientOutboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); subscriberTypes = Arrays.>asList(SimpleBrokerMessageHandler.class, UserDestinationMessageHandler.class); - testChannel("brokerChannel", subscriberTypes, 0); + testChannel("brokerChannel", subscriberTypes, 1); try { this.appContext.getBean("brokerChannelExecutor", ThreadPoolTaskExecutor.class); fail("expected exception"); @@ -247,16 +249,16 @@ public class MessageBrokerBeanDefinitionParserTests { List> subscriberTypes = Arrays.>asList(SimpAnnotationMethodMessageHandler.class, UserDestinationMessageHandler.class, StompBrokerRelayMessageHandler.class); - testChannel("clientInboundChannel", subscriberTypes, 1); + testChannel("clientInboundChannel", subscriberTypes, 2); testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); - testChannel("clientOutboundChannel", subscriberTypes, 0); + testChannel("clientOutboundChannel", subscriberTypes, 1); testExecutor("clientOutboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); subscriberTypes = Arrays.>asList( StompBrokerRelayMessageHandler.class, UserDestinationMessageHandler.class); - testChannel("brokerChannel", subscriberTypes, 0); + testChannel("brokerChannel", subscriberTypes, 1); try { this.appContext.getBean("brokerChannelExecutor", ThreadPoolTaskExecutor.class); fail("expected exception"); @@ -320,18 +322,18 @@ public class MessageBrokerBeanDefinitionParserTests { Arrays.>asList(SimpAnnotationMethodMessageHandler.class, UserDestinationMessageHandler.class, SimpleBrokerMessageHandler.class); - testChannel("clientInboundChannel", subscriberTypes, 2); + testChannel("clientInboundChannel", subscriberTypes, 3); testExecutor("clientInboundChannel", 100, 200, 600); subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); - testChannel("clientOutboundChannel", subscriberTypes, 2); + testChannel("clientOutboundChannel", subscriberTypes, 3); testExecutor("clientOutboundChannel", 101, 201, 601); subscriberTypes = Arrays.>asList(SimpleBrokerMessageHandler.class, UserDestinationMessageHandler.class); - testChannel("brokerChannel", subscriberTypes, 0); + testChannel("brokerChannel", subscriberTypes, 1); testExecutor("brokerChannel", 102, 202, 602); } @@ -397,7 +399,9 @@ public class MessageBrokerBeanDefinitionParserTests { assertTrue(channel.hasSubscription(subscriber)); } - assertEquals(interceptorCount, channel.getInterceptors().size()); + List interceptors = channel.getInterceptors(); + assertEquals(interceptorCount, interceptors.size()); + assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass()); } private void testExecutor(String channelName, int corePoolSize, int maxPoolSize, int keepAliveSeconds) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java index 66a23e1de42..74fd23c02c3 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java @@ -16,14 +16,19 @@ package org.springframework.web.socket.config.annotation; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + import java.util.ArrayList; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ScheduledThreadPoolExecutor; import org.junit.Test; - import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; @@ -34,10 +39,14 @@ import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SubscribeMapping; +import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.simp.user.UserDestinationMessageHandler; import org.springframework.messaging.support.AbstractSubscribableChannel; +import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ExecutorSubscribableChannel; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.stereotype.Controller; import org.springframework.web.servlet.HandlerMapping; @@ -55,8 +64,6 @@ import org.springframework.web.socket.messaging.SubProtocolHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; -import static org.junit.Assert.*; - /** * Test fixture for * {@link org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurationSupport}. @@ -83,6 +90,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests { TestChannel channel = config.getBean("clientInboundChannel", TestChannel.class); SubProtocolWebSocketHandler webSocketHandler = config.getBean(SubProtocolWebSocketHandler.class); + List interceptors = channel.getInterceptors(); + assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass()); + WebSocketSession session = new TestWebSocketSession("s1"); webSocketHandler.afterConnectionEstablished(session); @@ -90,22 +100,40 @@ public class WebSocketMessageBrokerConfigurationSupportTests { webSocketHandler.handleMessage(session, textMessage); Message message = channel.messages.get(0); - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - - assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); - assertEquals("/foo", headers.getDestination()); + StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + assertNotNull(accessor); + assertFalse(accessor.isMutable()); + assertEquals(SimpMessageType.MESSAGE, accessor.getMessageType()); + assertEquals("/foo", accessor.getDestination()); } @Test - public void clientOutboundChannelChannel() { + public void clientOutboundChannel() { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); TestChannel channel = config.getBean("clientOutboundChannel", TestChannel.class); Set handlers = channel.getSubscribers(); + List interceptors = channel.getInterceptors(); + assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass()); + assertEquals(1, handlers.size()); assertTrue(handlers.iterator().next() instanceof SubProtocolWebSocketHandler); } + @Test + public void brokerChannel() { + ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); + TestChannel channel = config.getBean("brokerChannel", TestChannel.class); + Iterator handlers = channel.getSubscribers().iterator(); + + List interceptors = channel.getInterceptors(); + assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass()); + + assertEquals(SimpleBrokerMessageHandler.class, handlers.next().getClass()); + assertEquals(UserDestinationMessageHandler.class, handlers.next().getClass()); + assertFalse(handlers.hasNext()); + } + @Test public void webSocketTransportOptions() { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); @@ -216,18 +244,24 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @Override @Bean public AbstractSubscribableChannel clientInboundChannel() { - return new TestChannel(); + TestChannel channel = new TestChannel(); + channel.setInterceptors(super.clientInboundChannel().getInterceptors()); + return channel; } @Override @Bean public AbstractSubscribableChannel clientOutboundChannel() { - return new TestChannel(); + TestChannel channel = new TestChannel(); + channel.setInterceptors(super.clientOutboundChannel().getInterceptors()); + return channel; } @Override public AbstractSubscribableChannel brokerChannel() { - return new TestChannel(); + TestChannel channel = new TestChannel(); + channel.setInterceptors(super.brokerChannel().getInterceptors()); + return channel; } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index fbe71560705..a850f644d34 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Before; import org.junit.Test; @@ -32,6 +33,9 @@ import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEventPublisher; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.SimpAttributes; import org.springframework.messaging.simp.SimpAttributesContextHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -44,7 +48,12 @@ import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; import org.springframework.messaging.simp.user.DestinationUserNameProvider; import org.springframework.messaging.simp.user.UserSessionRegistry; +import org.springframework.messaging.support.AbstractMessageChannel; +import org.springframework.messaging.support.ChannelInterceptorAdapter; +import org.springframework.messaging.support.ExecutorSubscribableChannel; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; @@ -288,6 +297,48 @@ public class StompSubProtocolHandlerTests { assertEquals(0, this.session.getSentMessages().size()); } + @Test + public void handleMessageFromClientWithImmutableMessageInterceptor() { + AtomicReference mutable = new AtomicReference<>(); + ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(); + channel.addInterceptor(new ChannelInterceptorAdapter() { + @Override + public Message preSend(Message message, MessageChannel channel) { + mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable()); + return message; + } + }); + channel.addInterceptor(new ImmutableMessageChannelInterceptor()); + + StompSubProtocolHandler handler = new StompSubProtocolHandler(); + handler.afterSessionStarted(this.session, channel); + + TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build(); + handler.handleMessageFromClient(this.session, message, channel); + assertNotNull(mutable.get()); + assertTrue(mutable.get()); + } + + @Test + public void handleMessageFromClientWithoutImmutableMessageInterceptor() { + AtomicReference mutable = new AtomicReference<>(); + ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(); + channel.addInterceptor(new ChannelInterceptorAdapter() { + @Override + public Message preSend(Message message, MessageChannel channel) { + mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable()); + return message; + } + }); + + StompSubProtocolHandler handler = new StompSubProtocolHandler(); + handler.afterSessionStarted(this.session, channel); + + TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build(); + handler.handleMessageFromClient(this.session, message, channel); + assertNotNull(mutable.get()); + assertFalse(mutable.get()); + } @Test public void handleMessageFromClientInvalidStompCommand() {