Handle STOMP messages to user destination in order

Closes gh-31395
This commit is contained in:
rstoyanchev 2023-10-11 12:14:39 +01:00
parent 9eb39e182e
commit 3277b0d6ac
9 changed files with 219 additions and 25 deletions

View File

@ -6,7 +6,7 @@ written to WebSocket sessions. As the channel is backed by a `ThreadPoolExecutor
are processed in different threads, and the resulting sequence received by the client may are processed in different threads, and the resulting sequence received by the client may
not match the exact order of publication. not match the exact order of publication.
If this is an issue, enable the `setPreservePublishOrder` flag, as the following example shows: To enable ordered publishing, set the `setPreservePublishOrder` flag as follows:
[source,java,indent=0,subs="verbatim,quotes"] [source,java,indent=0,subs="verbatim,quotes"]
---- ----
@ -47,5 +47,22 @@ When the flag is set, messages within the same client session are published to t
`clientOutboundChannel` one at a time, so that the order of publication is guaranteed. `clientOutboundChannel` one at a time, so that the order of publication is guaranteed.
Note that this incurs a small performance overhead, so you should enable it only if it is required. Note that this incurs a small performance overhead, so you should enable it only if it is required.
The same also applies to messages from the client, which are sent to the `clientInboundChannel`,
from where they are handled according to their destination prefix. As the channel is backed by
a `ThreadPoolExecutor`, messages are processed in different threads, and the resulting sequence
of handling may not match the exact order in which they were received.
To enable ordered publishing, set the `setPreserveReceiveOrder` flag as follows:
[source,java,indent=0,subs="verbatim,quotes"]
----
@Configuration
@EnableWebSocketMessageBroker
public class MyConfig implements WebSocketMessageBrokerConfigurer {
@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.setPreserveReceiveOrder(true);
}
}
----

View File

@ -157,7 +157,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
if (simpAccessor.isMutable()) { if (simpAccessor.isMutable()) {
simpAccessor.setDestination(destination); simpAccessor.setDestination(destination);
simpAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); simpAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
simpAccessor.setImmutable(); // ImmutableMessageChannelInterceptor will make it immutable
sendInternal(message); sendInternal(message);
return; return;
} }

View File

@ -159,6 +159,16 @@ public class OrderedMessageChannelDecorator implements MessageChannel {
} }
} }
/**
* Whether the channel has been {@link #configureInterceptor configured}
* with an interceptor for sequential handling.
* @since 6.1
*/
public static boolean supportsOrderedMessages(MessageChannel channel) {
return (channel instanceof ExecutorSubscribableChannel ch &&
ch.getInterceptors().stream().anyMatch(CallbackTaskInterceptor.class::isInstance));
}
/** /**
* Obtain the task to release the next message, if found. * Obtain the task to release the next message, if found.
*/ */

View File

@ -131,8 +131,9 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
} }
String user = parseResult.getUser(); String user = parseResult.getUser();
String sourceDest = parseResult.getSourceDestination(); String sourceDest = parseResult.getSourceDestination();
Set<String> sessionIds = parseResult.getSessionIds();
Set<String> targetSet = new HashSet<>(); Set<String> targetSet = new HashSet<>();
for (String sessionId : parseResult.getSessionIds()) { for (String sessionId : sessionIds) {
String actualDest = parseResult.getActualDestination(); String actualDest = parseResult.getActualDestination();
String targetDest = getTargetDestination(sourceDest, actualDest, sessionId, user); String targetDest = getTargetDestination(sourceDest, actualDest, sessionId, user);
if (targetDest != null) { if (targetDest != null) {
@ -140,7 +141,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
} }
} }
String subscribeDest = parseResult.getSubscribeDestination(); String subscribeDest = parseResult.getSubscribeDestination();
return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user); return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user, sessionIds);
} }
@Nullable @Nullable

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,13 +17,18 @@
package org.springframework.messaging.simp.user; package org.springframework.messaging.simp.user;
import java.util.Arrays; import java.util.Arrays;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.springframework.context.SmartLifecycle; import org.springframework.context.SmartLifecycle;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler; import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.MessagingException; import org.springframework.messaging.MessagingException;
@ -33,6 +38,7 @@ import org.springframework.messaging.simp.SimpLogging;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.support.MessageHeaderInitializer; import org.springframework.messaging.support.MessageHeaderInitializer;
@ -61,7 +67,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
private final UserDestinationResolver destinationResolver; private final UserDestinationResolver destinationResolver;
private final MessageSendingOperations<String> messagingTemplate; private final SendHelper sendHelper;
@Nullable @Nullable
private BroadcastHandler broadcastHandler; private BroadcastHandler broadcastHandler;
@ -91,7 +97,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
this.clientInboundChannel = clientInboundChannel; this.clientInboundChannel = clientInboundChannel;
this.brokerChannel = brokerChannel; this.brokerChannel = brokerChannel;
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel); this.sendHelper = new SendHelper(clientInboundChannel, brokerChannel);
this.destinationResolver = destinationResolver; this.destinationResolver = destinationResolver;
} }
@ -112,7 +118,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
*/ */
public void setBroadcastDestination(@Nullable String destination) { public void setBroadcastDestination(@Nullable String destination) {
this.broadcastHandler = (StringUtils.hasText(destination) ? this.broadcastHandler = (StringUtils.hasText(destination) ?
new BroadcastHandler(this.messagingTemplate, destination) : null); new BroadcastHandler(this.sendHelper.getMessagingTemplate(), destination) : null);
} }
/** /**
@ -128,7 +134,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
* broker channel. * broker channel.
*/ */
public MessageSendingOperations<String> getBrokerMessagingTemplate() { public MessageSendingOperations<String> getBrokerMessagingTemplate() {
return this.messagingTemplate; return this.sendHelper.getMessagingTemplate();
} }
/** /**
@ -193,6 +199,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
UserDestinationResult result = this.destinationResolver.resolveDestination(message); UserDestinationResult result = this.destinationResolver.resolveDestination(message);
if (result == null) { if (result == null) {
this.sendHelper.checkDisconnect(message);
return; return;
} }
@ -215,9 +222,8 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Translated " + result.getSourceDestination() + " -> " + result.getTargetDestinations()); logger.trace("Translated " + result.getSourceDestination() + " -> " + result.getTargetDestinations());
} }
for (String target : result.getTargetDestinations()) {
this.messagingTemplate.send(target, message); this.sendHelper.send(result, message);
}
} }
private void initHeaders(SimpMessageHeaderAccessor headerAccessor) { private void initHeaders(SimpMessageHeaderAccessor headerAccessor) {
@ -232,6 +238,63 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
} }
private static class SendHelper {
private final MessageChannel brokerChannel;
private final MessageSendingOperations<String> messagingTemplate;
@Nullable
private final Map<String, MessageSendingOperations<String>> orderedMessagingTemplates;
SendHelper(MessageChannel clientInboundChannel, MessageChannel brokerChannel) {
this.brokerChannel = brokerChannel;
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel);
if (OrderedMessageChannelDecorator.supportsOrderedMessages(clientInboundChannel)) {
this.orderedMessagingTemplates = new ConcurrentHashMap<>();
OrderedMessageChannelDecorator.configureInterceptor(brokerChannel, true);
}
else {
this.orderedMessagingTemplates = null;
}
}
public MessageSendingOperations<String> getMessagingTemplate() {
return this.messagingTemplate;
}
public void send(UserDestinationResult destinationResult, Message<?> message) throws MessagingException {
Set<String> sessionIds = destinationResult.getSessionIds();
Iterator<String> itr = (sessionIds != null ? sessionIds.iterator() : null);
for (String target : destinationResult.getTargetDestinations()) {
String sessionId = (itr != null ? itr.next() : null);
getTemplateToUse(sessionId).send(target, message);
}
}
private MessageSendingOperations<String> getTemplateToUse(@Nullable String sessionId) {
if (this.orderedMessagingTemplates != null && sessionId != null) {
return this.orderedMessagingTemplates.computeIfAbsent(sessionId, id ->
new SimpMessagingTemplate(new OrderedMessageChannelDecorator(this.brokerChannel, logger)));
}
return this.messagingTemplate;
}
public void checkDisconnect(Message<?> message) {
if (this.orderedMessagingTemplates != null) {
MessageHeaders headers = message.getHeaders();
if (SimpMessageHeaderAccessor.getMessageType(headers) == SimpMessageType.DISCONNECT) {
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
if (sessionId != null) {
this.orderedMessagingTemplates.remove(sessionId);
}
}
}
}
}
/** /**
* A handler that broadcasts locally unresolved messages to the broker and * A handler that broadcasts locally unresolved messages to the broker and
* also handles similar broadcasts received from the broker. * also handles similar broadcasts received from the broker.

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2017 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
package org.springframework.messaging.simp.user; package org.springframework.messaging.simp.user;
import java.util.Collections;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
@ -40,10 +41,23 @@ public class UserDestinationResult {
@Nullable @Nullable
private final String user; private final String user;
private final Set<String> sessionIds;
public UserDestinationResult(String sourceDestination, Set<String> targetDestinations, public UserDestinationResult(String sourceDestination, Set<String> targetDestinations,
String subscribeDestination, @Nullable String user) { String subscribeDestination, @Nullable String user) {
this(sourceDestination, targetDestinations, subscribeDestination, user, null);
}
/**
* Additional constructor with the session id for each targetDestination.
* @since 6.1
*/
public UserDestinationResult(
String sourceDestination, Set<String> targetDestinations,
String subscribeDestination, @Nullable String user, @Nullable Set<String> sessionIds) {
Assert.notNull(sourceDestination, "'sourceDestination' must not be null"); Assert.notNull(sourceDestination, "'sourceDestination' must not be null");
Assert.notNull(targetDestinations, "'targetDestinations' must not be null"); Assert.notNull(targetDestinations, "'targetDestinations' must not be null");
Assert.notNull(subscribeDestination, "'subscribeDestination' must not be null"); Assert.notNull(subscribeDestination, "'subscribeDestination' must not be null");
@ -52,6 +66,7 @@ public class UserDestinationResult {
this.targetDestinations = targetDestinations; this.targetDestinations = targetDestinations;
this.subscribeDestination = subscribeDestination; this.subscribeDestination = subscribeDestination;
this.user = user; this.user = user;
this.sessionIds = (sessionIds != null ? sessionIds : Collections.emptySet());
} }
@ -96,6 +111,13 @@ public class UserDestinationResult {
return this.user; return this.user;
} }
/**
* Return the session id for the targetDestination.
*/
@Nullable
public Set<String> getSessionIds() {
return this.sessionIds;
}
@Override @Override
public String toString() { public String toString() {

View File

@ -158,7 +158,6 @@ public class SimpMessagingTemplateTests {
Message<byte[]> message = messages.get(0); Message<byte[]> message = messages.get(0);
assertThat(message.getHeaders()).isSameAs(headers); assertThat(message.getHeaders()).isSameAs(headers);
assertThat(accessor.isMutable()).isFalse();
} }
@Test @Test
@ -190,7 +189,6 @@ public class SimpMessagingTemplateTests {
Message<byte[]> sentMessage = messages.get(0); Message<byte[]> sentMessage = messages.get(0);
assertThat(sentMessage).isSameAs(message); assertThat(sentMessage).isSameAs(message);
assertThat(accessor.isMutable()).isFalse();
} }
@Test @Test

View File

@ -24,6 +24,7 @@ import java.util.Map;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream; import java.util.stream.Stream;
import jakarta.servlet.Filter;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
@ -35,6 +36,7 @@ import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.context.Lifecycle; import org.springframework.context.Lifecycle;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.lang.Nullable;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.client.WebSocketClient; import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient; import org.springframework.web.socket.client.standard.StandardWebSocketClient;
@ -85,11 +87,18 @@ public abstract class AbstractWebSocketIntegrationTests {
protected AnnotationConfigWebApplicationContext wac; protected AnnotationConfigWebApplicationContext wac;
protected void setup(WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception { protected void setup(WebSocketTestServer server, WebSocketClient client, TestInfo info) throws Exception {
this.server = server; setup(server, null, client, info);
this.webSocketClient = webSocketClient; }
logger.debug("Setting up '" + testInfo.getTestMethod().get().getName() + "', client=" + protected void setup(
WebSocketTestServer server, @Nullable Filter filter, WebSocketClient client, TestInfo info)
throws Exception {
this.server = server;
this.webSocketClient = client;
logger.debug("Setting up '" + info.getTestMethod().get().getName() + "', client=" +
this.webSocketClient.getClass().getSimpleName() + ", server=" + this.webSocketClient.getClass().getSimpleName() + ", server=" +
this.server.getClass().getSimpleName()); this.server.getClass().getSimpleName());
@ -102,7 +111,12 @@ public abstract class AbstractWebSocketIntegrationTests {
} }
this.server.setup(); this.server.setup();
if (filter != null) {
this.server.deployConfig(this.wac, filter);
}
else {
this.server.deployConfig(this.wac); this.server.deployConfig(this.wac);
}
this.server.start(); this.server.start();
this.wac.setServletContext(this.server.getServletContext()); this.wac.setServletContext(this.server.getServletContext());

View File

@ -16,10 +16,12 @@
package org.springframework.web.socket.messaging; package org.springframework.web.socket.messaging;
import java.io.IOException;
import java.lang.annotation.Retention; import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -27,6 +29,13 @@ import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.api.TestInfo;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@ -165,6 +174,38 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
} }
} }
@ParameterizedWebSocketTest // gh-21798
void sendMessageToUserAndReceiveInOrder(
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
UserFilter userFilter = new UserFilter(() -> "joe");
super.setup(server, userFilter, webSocketClient, testInfo);
List<TextMessage> messages = new ArrayList<>();
messages.add(create(StompCommand.CONNECT).headers("accept-version:1.1").build());
messages.add(create(StompCommand.SUBSCRIBE).headers("id:subs1", "destination:/user/queue/foo").build());
int count = 1000;
for (int i = 0; i < count; i++) {
String dest = "destination:/user/joe/queue/foo";
messages.add(create(StompCommand.SEND).headers(dest).body(String.valueOf(i)).build());
}
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(count, messages);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
for (int i = 0; i < count; i++) {
TextMessage message = clientHandler.actual.get(i);
ByteBuffer buffer = ByteBuffer.wrap(message.asBytes());
byte[] bytes = new StompDecoder().decode(buffer).get(0).getPayload();
assertThat(new String(bytes, StandardCharsets.UTF_8)).isEqualTo(String.valueOf(i));
}
}
}
@ParameterizedWebSocketTest // SPR-11648 @ParameterizedWebSocketTest // SPR-11648
void sendSubscribeToControllerAndReceiveReply( void sendSubscribeToControllerAndReceiveReply(
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception { WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
@ -278,6 +319,30 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
} }
private static class UserFilter implements Filter {
private final Principal user;
private UserFilter(Principal user) {
this.user = user;
}
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
request = new HttpServletRequestWrapper((HttpServletRequest) request) {
@Override
public Principal getUserPrincipal() {
return user;
}
};
chain.doFilter(request, response);
}
}
@IntegrationTestController @IntegrationTestController
static class ScopedBeanController { static class ScopedBeanController {
@ -335,14 +400,17 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@Override @Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception { public void afterConnectionEstablished(WebSocketSession session) throws Exception {
for (TextMessage message : this.messagesToSend) { session.sendMessage(this.messagesToSend.get(0));
session.sendMessage(message);
}
} }
@Override @Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) { protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
if (!message.getPayload().startsWith("CONNECTED")) { if (message.getPayload().startsWith("CONNECTED")) {
for (int i = 1; i < this.messagesToSend.size(); i++) {
session.sendMessage(this.messagesToSend.get(i));
}
}
else {
this.actual.add(message); this.actual.add(message);
this.latch.countDown(); this.latch.countDown();
} }
@ -371,6 +439,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
configurer.setApplicationDestinationPrefixes("/app"); configurer.setApplicationDestinationPrefixes("/app");
configurer.setPreservePublishOrder(true); configurer.setPreservePublishOrder(true);
configurer.enableSimpleBroker("/topic", "/queue").setSelectorHeaderName("selector"); configurer.enableSimpleBroker("/topic", "/queue").setSelectorHeaderName("selector");
configurer.configureBrokerChannel().taskExecutor();
} }
@Bean @Bean