diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java
index da3e7f41856..67d50fc6eab 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java
@@ -113,7 +113,12 @@ public class OrderedMessageChannelDecorator implements MessageChannel {
}
}
+ /**
+ * Remove the message from the top of the queue, but only if it matches,
+ * i.e. hasn't been removed already.
+ */
private boolean removeMessage(Message> message) {
+ // Remove only if not removed already
Message> next = this.messages.peek();
if (next == message) {
this.messages.remove();
@@ -181,7 +186,7 @@ public class OrderedMessageChannelDecorator implements MessageChannel {
@Override
public void run() {
if (this.handledCount == null || this.handledCount.addAndGet(1) == subscriberCount) {
- if (OrderedMessageChannelDecorator.this.removeMessage(message)) {
+ if (OrderedMessageChannelDecorator.this.removeMessage(this.message)) {
sendNextMessage();
}
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompEndpointRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompEndpointRegistry.java
index 8748d1e2f84..78866719ce3 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompEndpointRegistry.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompEndpointRegistry.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -52,4 +52,18 @@ public interface StompEndpointRegistry {
*/
WebMvcStompEndpointRegistry setErrorHandler(StompSubProtocolErrorHandler errorHandler);
+ /**
+ * Whether to handle client messages sequentially in the order in which
+ * they were received.
+ *
By default messages sent to the {@code "clientInboundChannel"} may
+ * be handled in parallel and not in the same order as they were received
+ * because the channel is backed by a ThreadPoolExecutor that in turn does
+ * not guarantee processing in order.
+ *
When this flag is set to {@code true} messages within the same session
+ * will be sent to the {@code "clientInboundChannel"} one at a time in
+ * order to preserve the order in which they were received.
+ * @since 6.1
+ */
+ WebMvcStompEndpointRegistry setPreserveReceiveOrder(boolean preserveReceiveOrder);
+
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java
index 7b3dd2291e1..0265ae8db53 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java
@@ -142,6 +142,15 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry {
return this;
}
+ public WebMvcStompEndpointRegistry setPreserveReceiveOrder(boolean preserveReceiveOrder) {
+ this.stompHandler.setPreserveReceiveOrder(preserveReceiveOrder);
+ return this;
+ }
+
+ protected boolean isPreserveReceiveOrder() {
+ return this.stompHandler.isPreserveReceiveOrder();
+ }
+
protected void setApplicationContext(ApplicationContext applicationContext) {
this.stompHandler.setApplicationEventPublisher(applicationContext);
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java
index 2b84d011afa..aad0ba3db39 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java
@@ -28,6 +28,7 @@ import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.SimpSessionScope;
import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler;
import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler;
+import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
import org.springframework.messaging.simp.config.AbstractMessageBrokerConfiguration;
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import org.springframework.messaging.simp.user.SimpUserRegistry;
@@ -80,7 +81,8 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
@Bean
public HandlerMapping stompWebSocketHandlerMapping(
- WebSocketHandler subProtocolWebSocketHandler, TaskScheduler messageBrokerTaskScheduler) {
+ WebSocketHandler subProtocolWebSocketHandler, TaskScheduler messageBrokerTaskScheduler,
+ AbstractSubscribableChannel clientInboundChannel) {
WebSocketHandler handler = decorateWebSocketHandler(subProtocolWebSocketHandler);
WebMvcStompEndpointRegistry registry =
@@ -90,6 +92,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
registry.setApplicationContext(applicationContext);
}
registerStompEndpoints(registry);
+ OrderedMessageChannelDecorator.configureInterceptor(clientInboundChannel, registry.isPreserveReceiveOrder());
return registry.getHandlerMapping();
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
index 165e8605951..ec6df9b246f 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
@@ -108,6 +108,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
@Nullable
private MessageHeaderInitializer headerInitializer;
+ private boolean preserveReceiveOrder;
+
+ private final Map messageChannels = new ConcurrentHashMap<>();
+
private final Map stompAuthentications = new ConcurrentHashMap<>();
@Nullable
@@ -193,6 +197,30 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
return this.headerInitializer;
}
+ /**
+ * Whether client messages must be handled in the order received.
+ * By default messages sent to the {@code "clientInboundChannel"} may
+ * not be handled in the same order because the channel is backed by a
+ * ThreadPoolExecutor that in turn does not guarantee processing in order.
+ *
When this flag is set to {@code true} messages within the same session
+ * will be sent to the {@code "clientInboundChannel"} one at a time to
+ * preserve the order in which they were received.
+ * @param preserveReceiveOrder whether to publish in order
+ * @since 6.1
+ */
+ public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
+ this.preserveReceiveOrder = preserveReceiveOrder;
+ }
+
+ /**
+ * Whether the handler is configured to handle inbound messages in the
+ * order in which they were received.
+ * @since 6.1
+ */
+ public boolean isPreserveReceiveOrder() {
+ return this.preserveReceiveOrder;
+ }
+
@Override
public List getSupportedProtocols() {
return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
@@ -268,6 +296,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
return;
}
+ MessageChannel channelToUse =
+ (this.messageChannels.computeIfAbsent(session.getId(),
+ id -> this.preserveReceiveOrder ?
+ new OrderedMessageChannelDecorator(outputChannel, logger) :
+ outputChannel));
+
for (Message message : messages) {
StompHeaderAccessor headerAccessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
@@ -307,7 +341,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
try {
SimpAttributesContextHolder.setAttributesFromMessage(message);
- sent = outputChannel.send(message);
+ sent = channelToUse.send(message);
if (sent) {
if (this.eventPublisher != null) {
@@ -652,6 +686,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
outputChannel.send(message);
}
finally {
+ this.messageChannels.remove(session.getId());
this.stompAuthentications.remove(session.getId());
SimpAttributesContextHolder.resetAttributes();
simpAttributes.sessionCompleted();
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java
index ac283f3c9c4..839bafc72c8 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java
@@ -18,6 +18,10 @@ package org.springframework.web.socket.messaging;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
@@ -28,25 +32,22 @@ import org.junit.jupiter.api.TestInfo;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
-import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Scope;
import org.springframework.context.annotation.ScopedProxyMode;
-import org.springframework.core.task.TaskExecutor;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.annotation.SubscribeMapping;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompCommand;
-import org.springframework.messaging.support.AbstractSubscribableChannel;
-import org.springframework.messaging.support.ExecutorSubscribableChannel;
+import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.stereotype.Controller;
import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.WebSocketTestServer;
import org.springframework.web.socket.client.WebSocketClient;
-import org.springframework.web.socket.config.annotation.DelegatingWebSocketMessageBrokerConfiguration;
+import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.handler.TextWebSocketHandler;
@@ -69,7 +70,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@Override
protected Class>[] getAnnotatedConfigClasses() {
- return new Class>[] {TestMessageBrokerConfiguration.class, TestMessageBrokerConfigurer.class};
+ return new Class>[] {TestMessageBrokerConfigurer.class};
}
@@ -100,7 +101,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m2 = create(StompCommand.SEND)
.headers("destination:/app/increment").body("5").build();
- TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
+ TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
@@ -121,17 +122,49 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destination, selector).build();
TextMessage m2 = create(StompCommand.SEND).headers(destination, "foo:bar").body("5").build();
- TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
+ TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
- String payload = clientHandler.actual.get(1).getPayload();
+ String payload = clientHandler.actual.get(0).getPayload();
assertThat(payload).as("Expected STOMP MESSAGE, got " + payload).startsWith("MESSAGE\n");
}
}
+ @ParameterizedWebSocketTest // gh-21798
+ void sendMessageToBrokerAndReceiveInOrder(
+ WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
+
+ super.setup(server, webSocketClient, testInfo);
+
+ String destination = "destination:/topic/foo";
+
+ List messages = new ArrayList<>();
+ messages.add(create(StompCommand.CONNECT).headers("accept-version:1.1").build());
+ messages.add(create(StompCommand.SUBSCRIBE).headers("id:subs1", destination).build());
+
+ int count = 1000;
+ for (int i = 0; i < count; i++) {
+ messages.add(create(StompCommand.SEND).headers(destination).body(String.valueOf(i)).build());
+ }
+
+ TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(count, messages);
+
+ try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
+ assertThat(session).isNotNull();
+ assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
+
+ for (int i = 0; i < count; i++) {
+ TextMessage message = clientHandler.actual.get(i);
+ ByteBuffer buffer = ByteBuffer.wrap(message.asBytes());
+ byte[] bytes = new StompDecoder().decode(buffer).get(0).getPayload();
+ assertThat(new String(bytes, StandardCharsets.UTF_8)).isEqualTo(String.valueOf(i));
+ }
+ }
+ }
+
@ParameterizedWebSocketTest // SPR-11648
void sendSubscribeToControllerAndReceiveReply(
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
@@ -142,12 +175,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
String destHeader = "destination:/app/number";
TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destHeader).build();
- TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1);
+ TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
- String payload = clientHandler.actual.get(1).getPayload();
+ String payload = clientHandler.actual.get(0).getPayload();
assertThat(payload).as("Expected STOMP destination=/app/number, got " + payload).contains(destHeader);
assertThat(payload).as("Expected STOMP Payload=42, got " + payload).contains("42");
}
@@ -164,12 +197,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destHeader).build();
TextMessage m2 = create(StompCommand.SEND).headers("destination:/app/exception").build();
- TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
+ TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
- String payload = clientHandler.actual.get(1).getPayload();
+ String payload = clientHandler.actual.get(0).getPayload();
assertThat(payload).startsWith("MESSAGE\n");
assertThat(payload).contains("destination:/user/queue/error\n");
assertThat(payload).endsWith("Got error: Bad input\0");
@@ -188,12 +221,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m2 = create(StompCommand.SEND)
.headers("destination:/app/scopedBeanValue").build();
- TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
+ TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
- String payload = clientHandler.actual.get(1).getPayload();
+ String payload = clientHandler.actual.get(0).getPayload();
assertThat(payload).startsWith("MESSAGE\n");
assertThat(payload).contains("destination:/topic/scopedBeanValue\n");
assertThat(payload).endsWith("55\0");
@@ -285,18 +318,19 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
private static class TestClientWebSocketHandler extends TextWebSocketHandler {
- private final TextMessage[] messagesToSend;
-
- private final int expected;
+ private final List messagesToSend;
private final List actual = new CopyOnWriteArrayList<>();
private final CountDownLatch latch;
- public TestClientWebSocketHandler(int expectedNumberOfMessages, TextMessage... messagesToSend) {
+ TestClientWebSocketHandler(int expectedNumberOfMessages, TextMessage... messagesToSend) {
+ this(expectedNumberOfMessages, Arrays.asList(messagesToSend));
+ }
+
+ TestClientWebSocketHandler(int expectedNumberOfMessages, List messagesToSend) {
this.messagesToSend = messagesToSend;
- this.expected = expectedNumberOfMessages;
- this.latch = new CountDownLatch(this.expected);
+ this.latch = new CountDownLatch(expectedNumberOfMessages);
}
@Override
@@ -307,18 +341,20 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
}
@Override
- protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
- this.actual.add(message);
- this.latch.countDown();
+ protected void handleTextMessage(WebSocketSession session, TextMessage message) {
+ if (!message.getPayload().startsWith("CONNECTED")) {
+ this.actual.add(message);
+ this.latch.countDown();
+ }
}
}
- @Configuration
@ComponentScan(
basePackageClasses = StompWebSocketIntegrationTests.class,
useDefaultFilters = false,
includeFilters = @ComponentScan.Filter(IntegrationTestController.class))
+ @EnableWebSocketMessageBroker
static class TestMessageBrokerConfigurer implements WebSocketMessageBrokerConfigurer {
@Autowired
@@ -326,12 +362,14 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
+ registry.setPreserveReceiveOrder(true);
registry.addEndpoint("/ws").setHandshakeHandler(this.handshakeHandler);
}
@Override
public void configureMessageBroker(MessageBrokerRegistry configurer) {
configurer.setApplicationDestinationPrefixes("/app");
+ configurer.setPreservePublishOrder(true);
configurer.enableSimpleBroker("/topic", "/queue").setSelectorHeaderName("selector");
}
@@ -342,21 +380,4 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
}
}
-
- @Configuration
- static class TestMessageBrokerConfiguration extends DelegatingWebSocketMessageBrokerConfiguration {
-
- @Override
- @Bean
- public AbstractSubscribableChannel clientInboundChannel(TaskExecutor clientInboundChannelExecutor) {
- return new ExecutorSubscribableChannel(); // synchronous
- }
-
- @Override
- @Bean
- public AbstractSubscribableChannel clientOutboundChannel(TaskExecutor clientOutboundChannelExecutor) {
- return new ExecutorSubscribableChannel(); // synchronous
- }
- }
-
}