diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java index 4c738bdbd4b..d40b6e4a40a 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java @@ -45,6 +45,7 @@ import org.springframework.messaging.simp.user.UserDestinationResolver; import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.AbstractSubscribableChannel; import org.springframework.messaging.support.ExecutorSubscribableChannel; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.util.ClassUtils; import org.springframework.util.MimeTypeUtils; @@ -118,6 +119,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC if (this.clientInboundChannelRegistration == null) { ChannelRegistration registration = new ChannelRegistration(); configureClientInboundChannel(registration); + registration.setInterceptors(new ImmutableMessageChannelInterceptor()); this.clientInboundChannelRegistration = registration; } return this.clientInboundChannelRegistration; @@ -152,6 +154,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC if (this.clientOutboundChannelRegistration == null) { ChannelRegistration registration = new ChannelRegistration(); configureClientOutboundChannel(registration); + registration.setInterceptors(new ImmutableMessageChannelInterceptor()); this.clientOutboundChannelRegistration = registration; } return this.clientOutboundChannelRegistration; @@ -169,6 +172,7 @@ public abstract class AbstractMessageBrokerConfiguration implements ApplicationC ChannelRegistration reg = getBrokerRegistry().getBrokerChannelRegistration(); ExecutorSubscribableChannel channel = reg.hasTaskExecutor() ? new ExecutorSubscribableChannel(brokerChannelExecutor()) : new ExecutorSubscribableChannel(); + reg.setInterceptors(new ImmutableMessageChannelInterceptor()); channel.setInterceptors(reg.getInterceptors()); return channel; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/ImmutableMessageChannelInterceptor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/ImmutableMessageChannelInterceptor.java new file mode 100644 index 00000000000..c2921857fd6 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/ImmutableMessageChannelInterceptor.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.support; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; + +/** + * A simpler interceptor that calls {@link MessageHeaderAccessor#setImmutable()} + * on the headers of messages passed through the preSend method. + * + *

When configured as the last interceptor in a chain, it allows the component + * sending the message to leave headers mutable for interceptors to modify prior + * to the message actually being sent and exposed to concurrent access. + * + * @author Rossen Stoyanchev + * @since 4.1.2 + */ +public class ImmutableMessageChannelInterceptor extends ChannelInterceptorAdapter { + + + @Override + public Message preSend(Message message, MessageChannel channel) { + MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class); + if (accessor != null && accessor.isMutable()) { + accessor.setImmutable(); + } + return message; + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java index 5e2aa0b586b..bd45517c774 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/config/MessageBrokerConfigurationTests.java @@ -111,7 +111,7 @@ public class MessageBrokerConfigurationTests { AbstractSubscribableChannel channel = this.customContext.getBean( "clientInboundChannel", AbstractSubscribableChannel.class); - assertEquals(2, channel.getInterceptors().size()); + assertEquals(3, channel.getInterceptors().size()); CustomThreadPoolTaskExecutor taskExecutor = this.customContext.getBean( "clientInboundChannelExecutor", CustomThreadPoolTaskExecutor.class); @@ -178,7 +178,7 @@ public class MessageBrokerConfigurationTests { AbstractSubscribableChannel channel = this.customContext.getBean( "clientOutboundChannel", AbstractSubscribableChannel.class); - assertEquals(2, channel.getInterceptors().size()); + assertEquals(3, channel.getInterceptors().size()); ThreadPoolTaskExecutor taskExecutor = this.customContext.getBean( "clientOutboundChannelExecutor", ThreadPoolTaskExecutor.class); @@ -257,7 +257,7 @@ public class MessageBrokerConfigurationTests { AbstractSubscribableChannel channel = this.customContext.getBean( "brokerChannel", AbstractSubscribableChannel.class); - assertEquals(3, channel.getInterceptors().size()); + assertEquals(4, channel.getInterceptors().size()); ThreadPoolTaskExecutor taskExecutor = this.customContext.getBean( "brokerChannelExecutor", ThreadPoolTaskExecutor.class); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java index 22b55cdf562..0c06e0be850 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.w3c.dom.Element; import org.springframework.beans.MutablePropertyValues; @@ -201,11 +202,14 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { argValues.addIndexedArgumentValue(0, new RuntimeBeanReference(executorName)); } RootBeanDefinition channelDef = new RootBeanDefinition(ExecutorSubscribableChannel.class, argValues, null); + ManagedList interceptors = new ManagedList(); if (element != null) { Element interceptorsElement = DomUtils.getChildElementByTagName(element, "interceptors"); - ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); - channelDef.getPropertyValues().add("interceptors", interceptors); + interceptors.addAll(WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context)); } + interceptors.add(new ImmutableMessageChannelInterceptor()); + channelDef.getPropertyValues().add("interceptors", interceptors); + registerBeanDefByName(name, channelDef, context, source); return new RuntimeBeanReference(name); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 6fe70058839..82663c1a64a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -45,6 +45,9 @@ import org.springframework.messaging.simp.stomp.StompEncoder; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.user.DestinationUserNameProvider; import org.springframework.messaging.simp.user.UserSessionRegistry; +import org.springframework.messaging.support.AbstractMessageChannel; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.messaging.support.MessageHeaderInitializer; @@ -99,6 +102,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE private MessageHeaderInitializer headerInitializer; + private Boolean immutableMessageInterceptorPresent; + private ApplicationEventPublisher eventPublisher; private final Stats stats = new Stats(); @@ -234,7 +239,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE headerAccessor.setSessionId(session.getId()); headerAccessor.setSessionAttributes(session.getAttributes()); headerAccessor.setUser(session.getPrincipal()); - headerAccessor.setImmutable(); + if (!detectImmutableMessageInterceptor(outputChannel)) { + headerAccessor.setImmutable(); + } if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) { this.stats.incrementConnectCount(); @@ -271,6 +278,22 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } } + private boolean detectImmutableMessageInterceptor(MessageChannel channel) { + if (this.immutableMessageInterceptorPresent != null) { + return this.immutableMessageInterceptorPresent; + } + if (channel instanceof AbstractMessageChannel) { + for (ChannelInterceptor interceptor : ((AbstractMessageChannel) channel).getInterceptors()) { + if (interceptor instanceof ImmutableMessageChannelInterceptor) { + this.immutableMessageInterceptorPresent = true; + return true; + } + } + } + this.immutableMessageInterceptorPresent = false; + return false; + } + private void publishEvent(ApplicationEvent event) { try { this.eventPublisher.publishEvent(event); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 88904ea2c99..d8fa79e8d08 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -50,6 +50,8 @@ import org.springframework.messaging.simp.user.UserDestinationMessageHandler; import org.springframework.messaging.simp.user.UserDestinationResolver; import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.AbstractSubscribableChannel; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.MimeTypeUtils; @@ -179,15 +181,15 @@ public class MessageBrokerBeanDefinitionParserTests { List> subscriberTypes = Arrays.>asList(SimpAnnotationMethodMessageHandler.class, UserDestinationMessageHandler.class, SimpleBrokerMessageHandler.class); - testChannel("clientInboundChannel", subscriberTypes, 1); + testChannel("clientInboundChannel", subscriberTypes, 2); testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); - testChannel("clientOutboundChannel", subscriberTypes, 0); + testChannel("clientOutboundChannel", subscriberTypes, 1); testExecutor("clientOutboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); subscriberTypes = Arrays.>asList(SimpleBrokerMessageHandler.class, UserDestinationMessageHandler.class); - testChannel("brokerChannel", subscriberTypes, 0); + testChannel("brokerChannel", subscriberTypes, 1); try { this.appContext.getBean("brokerChannelExecutor", ThreadPoolTaskExecutor.class); fail("expected exception"); @@ -247,16 +249,16 @@ public class MessageBrokerBeanDefinitionParserTests { List> subscriberTypes = Arrays.>asList(SimpAnnotationMethodMessageHandler.class, UserDestinationMessageHandler.class, StompBrokerRelayMessageHandler.class); - testChannel("clientInboundChannel", subscriberTypes, 1); + testChannel("clientInboundChannel", subscriberTypes, 2); testExecutor("clientInboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); - testChannel("clientOutboundChannel", subscriberTypes, 0); + testChannel("clientOutboundChannel", subscriberTypes, 1); testExecutor("clientOutboundChannel", Runtime.getRuntime().availableProcessors() * 2, Integer.MAX_VALUE, 60); subscriberTypes = Arrays.>asList( StompBrokerRelayMessageHandler.class, UserDestinationMessageHandler.class); - testChannel("brokerChannel", subscriberTypes, 0); + testChannel("brokerChannel", subscriberTypes, 1); try { this.appContext.getBean("brokerChannelExecutor", ThreadPoolTaskExecutor.class); fail("expected exception"); @@ -320,18 +322,18 @@ public class MessageBrokerBeanDefinitionParserTests { Arrays.>asList(SimpAnnotationMethodMessageHandler.class, UserDestinationMessageHandler.class, SimpleBrokerMessageHandler.class); - testChannel("clientInboundChannel", subscriberTypes, 2); + testChannel("clientInboundChannel", subscriberTypes, 3); testExecutor("clientInboundChannel", 100, 200, 600); subscriberTypes = Arrays.>asList(SubProtocolWebSocketHandler.class); - testChannel("clientOutboundChannel", subscriberTypes, 2); + testChannel("clientOutboundChannel", subscriberTypes, 3); testExecutor("clientOutboundChannel", 101, 201, 601); subscriberTypes = Arrays.>asList(SimpleBrokerMessageHandler.class, UserDestinationMessageHandler.class); - testChannel("brokerChannel", subscriberTypes, 0); + testChannel("brokerChannel", subscriberTypes, 1); testExecutor("brokerChannel", 102, 202, 602); } @@ -397,7 +399,9 @@ public class MessageBrokerBeanDefinitionParserTests { assertTrue(channel.hasSubscription(subscriber)); } - assertEquals(interceptorCount, channel.getInterceptors().size()); + List interceptors = channel.getInterceptors(); + assertEquals(interceptorCount, interceptors.size()); + assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass()); } private void testExecutor(String channelName, int corePoolSize, int maxPoolSize, int keepAliveSeconds) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java index 66a23e1de42..74fd23c02c3 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java @@ -16,14 +16,19 @@ package org.springframework.web.socket.config.annotation; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + import java.util.ArrayList; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ScheduledThreadPoolExecutor; import org.junit.Test; - import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; @@ -34,10 +39,14 @@ import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SubscribeMapping; +import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.simp.user.UserDestinationMessageHandler; import org.springframework.messaging.support.AbstractSubscribableChannel; +import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ExecutorSubscribableChannel; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.stereotype.Controller; import org.springframework.web.servlet.HandlerMapping; @@ -55,8 +64,6 @@ import org.springframework.web.socket.messaging.SubProtocolHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; -import static org.junit.Assert.*; - /** * Test fixture for * {@link org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurationSupport}. @@ -83,6 +90,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests { TestChannel channel = config.getBean("clientInboundChannel", TestChannel.class); SubProtocolWebSocketHandler webSocketHandler = config.getBean(SubProtocolWebSocketHandler.class); + List interceptors = channel.getInterceptors(); + assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass()); + WebSocketSession session = new TestWebSocketSession("s1"); webSocketHandler.afterConnectionEstablished(session); @@ -90,22 +100,40 @@ public class WebSocketMessageBrokerConfigurationSupportTests { webSocketHandler.handleMessage(session, textMessage); Message message = channel.messages.get(0); - StompHeaderAccessor headers = StompHeaderAccessor.wrap(message); - - assertEquals(SimpMessageType.MESSAGE, headers.getMessageType()); - assertEquals("/foo", headers.getDestination()); + StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + assertNotNull(accessor); + assertFalse(accessor.isMutable()); + assertEquals(SimpMessageType.MESSAGE, accessor.getMessageType()); + assertEquals("/foo", accessor.getDestination()); } @Test - public void clientOutboundChannelChannel() { + public void clientOutboundChannel() { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); TestChannel channel = config.getBean("clientOutboundChannel", TestChannel.class); Set handlers = channel.getSubscribers(); + List interceptors = channel.getInterceptors(); + assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass()); + assertEquals(1, handlers.size()); assertTrue(handlers.iterator().next() instanceof SubProtocolWebSocketHandler); } + @Test + public void brokerChannel() { + ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); + TestChannel channel = config.getBean("brokerChannel", TestChannel.class); + Iterator handlers = channel.getSubscribers().iterator(); + + List interceptors = channel.getInterceptors(); + assertEquals(ImmutableMessageChannelInterceptor.class, interceptors.get(interceptors.size()-1).getClass()); + + assertEquals(SimpleBrokerMessageHandler.class, handlers.next().getClass()); + assertEquals(UserDestinationMessageHandler.class, handlers.next().getClass()); + assertFalse(handlers.hasNext()); + } + @Test public void webSocketTransportOptions() { ApplicationContext config = createConfig(TestChannelConfig.class, TestConfigurer.class); @@ -216,18 +244,24 @@ public class WebSocketMessageBrokerConfigurationSupportTests { @Override @Bean public AbstractSubscribableChannel clientInboundChannel() { - return new TestChannel(); + TestChannel channel = new TestChannel(); + channel.setInterceptors(super.clientInboundChannel().getInterceptors()); + return channel; } @Override @Bean public AbstractSubscribableChannel clientOutboundChannel() { - return new TestChannel(); + TestChannel channel = new TestChannel(); + channel.setInterceptors(super.clientOutboundChannel().getInterceptors()); + return channel; } @Override public AbstractSubscribableChannel brokerChannel() { - return new TestChannel(); + TestChannel channel = new TestChannel(); + channel.setInterceptors(super.brokerChannel().getInterceptors()); + return channel; } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index fbe71560705..a850f644d34 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Before; import org.junit.Test; @@ -32,6 +33,9 @@ import org.springframework.context.ApplicationEvent; import org.springframework.context.ApplicationEventPublisher; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; +import org.springframework.messaging.SubscribableChannel; import org.springframework.messaging.simp.SimpAttributes; import org.springframework.messaging.simp.SimpAttributesContextHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -44,7 +48,12 @@ import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; import org.springframework.messaging.simp.user.DestinationUserNameProvider; import org.springframework.messaging.simp.user.UserSessionRegistry; +import org.springframework.messaging.support.AbstractMessageChannel; +import org.springframework.messaging.support.ChannelInterceptorAdapter; +import org.springframework.messaging.support.ExecutorSubscribableChannel; +import org.springframework.messaging.support.ImmutableMessageChannelInterceptor; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; @@ -288,6 +297,48 @@ public class StompSubProtocolHandlerTests { assertEquals(0, this.session.getSentMessages().size()); } + @Test + public void handleMessageFromClientWithImmutableMessageInterceptor() { + AtomicReference mutable = new AtomicReference<>(); + ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(); + channel.addInterceptor(new ChannelInterceptorAdapter() { + @Override + public Message preSend(Message message, MessageChannel channel) { + mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable()); + return message; + } + }); + channel.addInterceptor(new ImmutableMessageChannelInterceptor()); + + StompSubProtocolHandler handler = new StompSubProtocolHandler(); + handler.afterSessionStarted(this.session, channel); + + TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build(); + handler.handleMessageFromClient(this.session, message, channel); + assertNotNull(mutable.get()); + assertTrue(mutable.get()); + } + + @Test + public void handleMessageFromClientWithoutImmutableMessageInterceptor() { + AtomicReference mutable = new AtomicReference<>(); + ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(); + channel.addInterceptor(new ChannelInterceptorAdapter() { + @Override + public Message preSend(Message message, MessageChannel channel) { + mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable()); + return message; + } + }); + + StompSubProtocolHandler handler = new StompSubProtocolHandler(); + handler.afterSessionStarted(this.session, channel); + + TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build(); + handler.handleMessageFromClient(this.session, message, channel); + assertNotNull(mutable.get()); + assertFalse(mutable.get()); + } @Test public void handleMessageFromClientInvalidStompCommand() {