diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java index da3e7f41856..67d50fc6eab 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java @@ -113,7 +113,12 @@ public class OrderedMessageChannelDecorator implements MessageChannel { } } + /** + * Remove the message from the top of the queue, but only if it matches, + * i.e. hasn't been removed already. + */ private boolean removeMessage(Message message) { + // Remove only if not removed already Message next = this.messages.peek(); if (next == message) { this.messages.remove(); @@ -181,7 +186,7 @@ public class OrderedMessageChannelDecorator implements MessageChannel { @Override public void run() { if (this.handledCount == null || this.handledCount.addAndGet(1) == subscriberCount) { - if (OrderedMessageChannelDecorator.this.removeMessage(message)) { + if (OrderedMessageChannelDecorator.this.removeMessage(this.message)) { sendNextMessage(); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompEndpointRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompEndpointRegistry.java index 8748d1e2f84..78866719ce3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompEndpointRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompEndpointRegistry.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,4 +52,18 @@ public interface StompEndpointRegistry { */ WebMvcStompEndpointRegistry setErrorHandler(StompSubProtocolErrorHandler errorHandler); + /** + * Whether to handle client messages sequentially in the order in which + * they were received. + *

By default messages sent to the {@code "clientInboundChannel"} may + * be handled in parallel and not in the same order as they were received + * because the channel is backed by a ThreadPoolExecutor that in turn does + * not guarantee processing in order. + *

When this flag is set to {@code true} messages within the same session + * will be sent to the {@code "clientInboundChannel"} one at a time in + * order to preserve the order in which they were received. + * @since 6.1 + */ + WebMvcStompEndpointRegistry setPreserveReceiveOrder(boolean preserveReceiveOrder); + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java index 7b3dd2291e1..0265ae8db53 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java @@ -142,6 +142,15 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry { return this; } + public WebMvcStompEndpointRegistry setPreserveReceiveOrder(boolean preserveReceiveOrder) { + this.stompHandler.setPreserveReceiveOrder(preserveReceiveOrder); + return this; + } + + protected boolean isPreserveReceiveOrder() { + return this.stompHandler.isPreserveReceiveOrder(); + } + protected void setApplicationContext(ApplicationContext applicationContext) { this.stompHandler.setApplicationEventPublisher(applicationContext); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java index 2b84d011afa..aad0ba3db39 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java @@ -28,6 +28,7 @@ import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.SimpSessionScope; import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler; +import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator; import org.springframework.messaging.simp.config.AbstractMessageBrokerConfiguration; import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler; import org.springframework.messaging.simp.user.SimpUserRegistry; @@ -80,7 +81,8 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac @Bean public HandlerMapping stompWebSocketHandlerMapping( - WebSocketHandler subProtocolWebSocketHandler, TaskScheduler messageBrokerTaskScheduler) { + WebSocketHandler subProtocolWebSocketHandler, TaskScheduler messageBrokerTaskScheduler, + AbstractSubscribableChannel clientInboundChannel) { WebSocketHandler handler = decorateWebSocketHandler(subProtocolWebSocketHandler); WebMvcStompEndpointRegistry registry = @@ -90,6 +92,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac registry.setApplicationContext(applicationContext); } registerStompEndpoints(registry); + OrderedMessageChannelDecorator.configureInterceptor(clientInboundChannel, registry.isPreserveReceiveOrder()); return registry.getHandlerMapping(); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 165e8605951..ec6df9b246f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -108,6 +108,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @Nullable private MessageHeaderInitializer headerInitializer; + private boolean preserveReceiveOrder; + + private final Map messageChannels = new ConcurrentHashMap<>(); + private final Map stompAuthentications = new ConcurrentHashMap<>(); @Nullable @@ -193,6 +197,30 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE return this.headerInitializer; } + /** + * Whether client messages must be handled in the order received. + *

By default messages sent to the {@code "clientInboundChannel"} may + * not be handled in the same order because the channel is backed by a + * ThreadPoolExecutor that in turn does not guarantee processing in order. + *

When this flag is set to {@code true} messages within the same session + * will be sent to the {@code "clientInboundChannel"} one at a time to + * preserve the order in which they were received. + * @param preserveReceiveOrder whether to publish in order + * @since 6.1 + */ + public void setPreserveReceiveOrder(boolean preserveReceiveOrder) { + this.preserveReceiveOrder = preserveReceiveOrder; + } + + /** + * Whether the handler is configured to handle inbound messages in the + * order in which they were received. + * @since 6.1 + */ + public boolean isPreserveReceiveOrder() { + return this.preserveReceiveOrder; + } + @Override public List getSupportedProtocols() { return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"); @@ -268,6 +296,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE return; } + MessageChannel channelToUse = + (this.messageChannels.computeIfAbsent(session.getId(), + id -> this.preserveReceiveOrder ? + new OrderedMessageChannelDecorator(outputChannel, logger) : + outputChannel)); + for (Message message : messages) { StompHeaderAccessor headerAccessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); @@ -307,7 +341,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE try { SimpAttributesContextHolder.setAttributesFromMessage(message); - sent = outputChannel.send(message); + sent = channelToUse.send(message); if (sent) { if (this.eventPublisher != null) { @@ -652,6 +686,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE outputChannel.send(message); } finally { + this.messageChannels.remove(session.getId()); this.stompAuthentications.remove(session.getId()); SimpAttributesContextHolder.resetAttributes(); simpAttributes.sessionCompleted(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java index ac283f3c9c4..839bafc72c8 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java @@ -18,6 +18,10 @@ package org.springframework.web.socket.messaging; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; @@ -28,25 +32,22 @@ import org.junit.jupiter.api.TestInfo; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; -import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Scope; import org.springframework.context.annotation.ScopedProxyMode; -import org.springframework.core.task.TaskExecutor; import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.annotation.SubscribeMapping; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.messaging.simp.stomp.StompCommand; -import org.springframework.messaging.support.AbstractSubscribableChannel; -import org.springframework.messaging.support.ExecutorSubscribableChannel; +import org.springframework.messaging.simp.stomp.StompDecoder; import org.springframework.stereotype.Controller; import org.springframework.web.socket.AbstractWebSocketIntegrationTests; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketTestServer; import org.springframework.web.socket.client.WebSocketClient; -import org.springframework.web.socket.config.annotation.DelegatingWebSocketMessageBrokerConfiguration; +import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; import org.springframework.web.socket.handler.TextWebSocketHandler; @@ -69,7 +70,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @Override protected Class[] getAnnotatedConfigClasses() { - return new Class[] {TestMessageBrokerConfiguration.class, TestMessageBrokerConfigurer.class}; + return new Class[] {TestMessageBrokerConfigurer.class}; } @@ -100,7 +101,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { TextMessage m2 = create(StompCommand.SEND) .headers("destination:/app/increment").body("5").build(); - TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2); + TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2); try (WebSocketSession session = execute(clientHandler, "/ws").get()) { assertThat(session).isNotNull(); @@ -121,17 +122,49 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destination, selector).build(); TextMessage m2 = create(StompCommand.SEND).headers(destination, "foo:bar").body("5").build(); - TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2); + TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2); try (WebSocketSession session = execute(clientHandler, "/ws").get()) { assertThat(session).isNotNull(); assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue(); - String payload = clientHandler.actual.get(1).getPayload(); + String payload = clientHandler.actual.get(0).getPayload(); assertThat(payload).as("Expected STOMP MESSAGE, got " + payload).startsWith("MESSAGE\n"); } } + @ParameterizedWebSocketTest // gh-21798 + void sendMessageToBrokerAndReceiveInOrder( + WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception { + + super.setup(server, webSocketClient, testInfo); + + String destination = "destination:/topic/foo"; + + List messages = new ArrayList<>(); + messages.add(create(StompCommand.CONNECT).headers("accept-version:1.1").build()); + messages.add(create(StompCommand.SUBSCRIBE).headers("id:subs1", destination).build()); + + int count = 1000; + for (int i = 0; i < count; i++) { + messages.add(create(StompCommand.SEND).headers(destination).body(String.valueOf(i)).build()); + } + + TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(count, messages); + + try (WebSocketSession session = execute(clientHandler, "/ws").get()) { + assertThat(session).isNotNull(); + assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue(); + + for (int i = 0; i < count; i++) { + TextMessage message = clientHandler.actual.get(i); + ByteBuffer buffer = ByteBuffer.wrap(message.asBytes()); + byte[] bytes = new StompDecoder().decode(buffer).get(0).getPayload(); + assertThat(new String(bytes, StandardCharsets.UTF_8)).isEqualTo(String.valueOf(i)); + } + } + } + @ParameterizedWebSocketTest // SPR-11648 void sendSubscribeToControllerAndReceiveReply( WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception { @@ -142,12 +175,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { String destHeader = "destination:/app/number"; TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destHeader).build(); - TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1); + TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1); try (WebSocketSession session = execute(clientHandler, "/ws").get()) { assertThat(session).isNotNull(); assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue(); - String payload = clientHandler.actual.get(1).getPayload(); + String payload = clientHandler.actual.get(0).getPayload(); assertThat(payload).as("Expected STOMP destination=/app/number, got " + payload).contains(destHeader); assertThat(payload).as("Expected STOMP Payload=42, got " + payload).contains("42"); } @@ -164,12 +197,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destHeader).build(); TextMessage m2 = create(StompCommand.SEND).headers("destination:/app/exception").build(); - TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2); + TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2); try (WebSocketSession session = execute(clientHandler, "/ws").get()) { assertThat(session).isNotNull(); assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue(); - String payload = clientHandler.actual.get(1).getPayload(); + String payload = clientHandler.actual.get(0).getPayload(); assertThat(payload).startsWith("MESSAGE\n"); assertThat(payload).contains("destination:/user/queue/error\n"); assertThat(payload).endsWith("Got error: Bad input\0"); @@ -188,12 +221,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { TextMessage m2 = create(StompCommand.SEND) .headers("destination:/app/scopedBeanValue").build(); - TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2); + TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2); try (WebSocketSession session = execute(clientHandler, "/ws").get()) { assertThat(session).isNotNull(); assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue(); - String payload = clientHandler.actual.get(1).getPayload(); + String payload = clientHandler.actual.get(0).getPayload(); assertThat(payload).startsWith("MESSAGE\n"); assertThat(payload).contains("destination:/topic/scopedBeanValue\n"); assertThat(payload).endsWith("55\0"); @@ -285,18 +318,19 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { private static class TestClientWebSocketHandler extends TextWebSocketHandler { - private final TextMessage[] messagesToSend; - - private final int expected; + private final List messagesToSend; private final List actual = new CopyOnWriteArrayList<>(); private final CountDownLatch latch; - public TestClientWebSocketHandler(int expectedNumberOfMessages, TextMessage... messagesToSend) { + TestClientWebSocketHandler(int expectedNumberOfMessages, TextMessage... messagesToSend) { + this(expectedNumberOfMessages, Arrays.asList(messagesToSend)); + } + + TestClientWebSocketHandler(int expectedNumberOfMessages, List messagesToSend) { this.messagesToSend = messagesToSend; - this.expected = expectedNumberOfMessages; - this.latch = new CountDownLatch(this.expected); + this.latch = new CountDownLatch(expectedNumberOfMessages); } @Override @@ -307,18 +341,20 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { } @Override - protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { - this.actual.add(message); - this.latch.countDown(); + protected void handleTextMessage(WebSocketSession session, TextMessage message) { + if (!message.getPayload().startsWith("CONNECTED")) { + this.actual.add(message); + this.latch.countDown(); + } } } - @Configuration @ComponentScan( basePackageClasses = StompWebSocketIntegrationTests.class, useDefaultFilters = false, includeFilters = @ComponentScan.Filter(IntegrationTestController.class)) + @EnableWebSocketMessageBroker static class TestMessageBrokerConfigurer implements WebSocketMessageBrokerConfigurer { @Autowired @@ -326,12 +362,14 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @Override public void registerStompEndpoints(StompEndpointRegistry registry) { + registry.setPreserveReceiveOrder(true); registry.addEndpoint("/ws").setHandshakeHandler(this.handshakeHandler); } @Override public void configureMessageBroker(MessageBrokerRegistry configurer) { configurer.setApplicationDestinationPrefixes("/app"); + configurer.setPreservePublishOrder(true); configurer.enableSimpleBroker("/topic", "/queue").setSelectorHeaderName("selector"); } @@ -342,21 +380,4 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { } } - - @Configuration - static class TestMessageBrokerConfiguration extends DelegatingWebSocketMessageBrokerConfiguration { - - @Override - @Bean - public AbstractSubscribableChannel clientInboundChannel(TaskExecutor clientInboundChannelExecutor) { - return new ExecutorSubscribableChannel(); // synchronous - } - - @Override - @Bean - public AbstractSubscribableChannel clientOutboundChannel(TaskExecutor clientOutboundChannelExecutor) { - return new ExecutorSubscribableChannel(); // synchronous - } - } - }