Handle STOMP messages to user destination in order

Closes gh-31395
This commit is contained in:
rstoyanchev 2023-10-11 12:14:39 +01:00
parent 9eb39e182e
commit 3277b0d6ac
9 changed files with 219 additions and 25 deletions

View File

@ -6,7 +6,7 @@ written to WebSocket sessions. As the channel is backed by a `ThreadPoolExecutor
are processed in different threads, and the resulting sequence received by the client may
not match the exact order of publication.
If this is an issue, enable the `setPreservePublishOrder` flag, as the following example shows:
To enable ordered publishing, set the `setPreservePublishOrder` flag as follows:
[source,java,indent=0,subs="verbatim,quotes"]
----
@ -47,5 +47,22 @@ When the flag is set, messages within the same client session are published to t
`clientOutboundChannel` one at a time, so that the order of publication is guaranteed.
Note that this incurs a small performance overhead, so you should enable it only if it is required.
The same also applies to messages from the client, which are sent to the `clientInboundChannel`,
from where they are handled according to their destination prefix. As the channel is backed by
a `ThreadPoolExecutor`, messages are processed in different threads, and the resulting sequence
of handling may not match the exact order in which they were received.
To enable ordered publishing, set the `setPreserveReceiveOrder` flag as follows:
[source,java,indent=0,subs="verbatim,quotes"]
----
@Configuration
@EnableWebSocketMessageBroker
public class MyConfig implements WebSocketMessageBrokerConfigurer {
@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.setPreserveReceiveOrder(true);
}
}
----

View File

@ -157,7 +157,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
if (simpAccessor.isMutable()) {
simpAccessor.setDestination(destination);
simpAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
simpAccessor.setImmutable();
// ImmutableMessageChannelInterceptor will make it immutable
sendInternal(message);
return;
}

View File

@ -159,6 +159,16 @@ public class OrderedMessageChannelDecorator implements MessageChannel {
}
}
/**
* Whether the channel has been {@link #configureInterceptor configured}
* with an interceptor for sequential handling.
* @since 6.1
*/
public static boolean supportsOrderedMessages(MessageChannel channel) {
return (channel instanceof ExecutorSubscribableChannel ch &&
ch.getInterceptors().stream().anyMatch(CallbackTaskInterceptor.class::isInstance));
}
/**
* Obtain the task to release the next message, if found.
*/

View File

@ -131,8 +131,9 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
}
String user = parseResult.getUser();
String sourceDest = parseResult.getSourceDestination();
Set<String> sessionIds = parseResult.getSessionIds();
Set<String> targetSet = new HashSet<>();
for (String sessionId : parseResult.getSessionIds()) {
for (String sessionId : sessionIds) {
String actualDest = parseResult.getActualDestination();
String targetDest = getTargetDestination(sourceDest, actualDest, sessionId, user);
if (targetDest != null) {
@ -140,7 +141,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
}
}
String subscribeDest = parseResult.getSubscribeDestination();
return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user);
return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user, sessionIds);
}
@Nullable

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,13 +17,18 @@
package org.springframework.messaging.simp.user;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.springframework.context.SmartLifecycle;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.MessagingException;
@ -33,6 +38,7 @@ import org.springframework.messaging.simp.SimpLogging;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.support.MessageHeaderInitializer;
@ -61,7 +67,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
private final UserDestinationResolver destinationResolver;
private final MessageSendingOperations<String> messagingTemplate;
private final SendHelper sendHelper;
@Nullable
private BroadcastHandler broadcastHandler;
@ -91,7 +97,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
this.clientInboundChannel = clientInboundChannel;
this.brokerChannel = brokerChannel;
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel);
this.sendHelper = new SendHelper(clientInboundChannel, brokerChannel);
this.destinationResolver = destinationResolver;
}
@ -112,7 +118,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
*/
public void setBroadcastDestination(@Nullable String destination) {
this.broadcastHandler = (StringUtils.hasText(destination) ?
new BroadcastHandler(this.messagingTemplate, destination) : null);
new BroadcastHandler(this.sendHelper.getMessagingTemplate(), destination) : null);
}
/**
@ -128,7 +134,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
* broker channel.
*/
public MessageSendingOperations<String> getBrokerMessagingTemplate() {
return this.messagingTemplate;
return this.sendHelper.getMessagingTemplate();
}
/**
@ -193,6 +199,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
UserDestinationResult result = this.destinationResolver.resolveDestination(message);
if (result == null) {
this.sendHelper.checkDisconnect(message);
return;
}
@ -215,9 +222,8 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
if (logger.isTraceEnabled()) {
logger.trace("Translated " + result.getSourceDestination() + " -> " + result.getTargetDestinations());
}
for (String target : result.getTargetDestinations()) {
this.messagingTemplate.send(target, message);
}
this.sendHelper.send(result, message);
}
private void initHeaders(SimpMessageHeaderAccessor headerAccessor) {
@ -232,6 +238,63 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
}
private static class SendHelper {
private final MessageChannel brokerChannel;
private final MessageSendingOperations<String> messagingTemplate;
@Nullable
private final Map<String, MessageSendingOperations<String>> orderedMessagingTemplates;
SendHelper(MessageChannel clientInboundChannel, MessageChannel brokerChannel) {
this.brokerChannel = brokerChannel;
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel);
if (OrderedMessageChannelDecorator.supportsOrderedMessages(clientInboundChannel)) {
this.orderedMessagingTemplates = new ConcurrentHashMap<>();
OrderedMessageChannelDecorator.configureInterceptor(brokerChannel, true);
}
else {
this.orderedMessagingTemplates = null;
}
}
public MessageSendingOperations<String> getMessagingTemplate() {
return this.messagingTemplate;
}
public void send(UserDestinationResult destinationResult, Message<?> message) throws MessagingException {
Set<String> sessionIds = destinationResult.getSessionIds();
Iterator<String> itr = (sessionIds != null ? sessionIds.iterator() : null);
for (String target : destinationResult.getTargetDestinations()) {
String sessionId = (itr != null ? itr.next() : null);
getTemplateToUse(sessionId).send(target, message);
}
}
private MessageSendingOperations<String> getTemplateToUse(@Nullable String sessionId) {
if (this.orderedMessagingTemplates != null && sessionId != null) {
return this.orderedMessagingTemplates.computeIfAbsent(sessionId, id ->
new SimpMessagingTemplate(new OrderedMessageChannelDecorator(this.brokerChannel, logger)));
}
return this.messagingTemplate;
}
public void checkDisconnect(Message<?> message) {
if (this.orderedMessagingTemplates != null) {
MessageHeaders headers = message.getHeaders();
if (SimpMessageHeaderAccessor.getMessageType(headers) == SimpMessageType.DISCONNECT) {
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
if (sessionId != null) {
this.orderedMessagingTemplates.remove(sessionId);
}
}
}
}
}
/**
* A handler that broadcasts locally unresolved messages to the broker and
* also handles similar broadcasts received from the broker.

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
package org.springframework.messaging.simp.user;
import java.util.Collections;
import java.util.Set;
import org.springframework.lang.Nullable;
@ -40,10 +41,23 @@ public class UserDestinationResult {
@Nullable
private final String user;
private final Set<String> sessionIds;
public UserDestinationResult(String sourceDestination, Set<String> targetDestinations,
String subscribeDestination, @Nullable String user) {
this(sourceDestination, targetDestinations, subscribeDestination, user, null);
}
/**
* Additional constructor with the session id for each targetDestination.
* @since 6.1
*/
public UserDestinationResult(
String sourceDestination, Set<String> targetDestinations,
String subscribeDestination, @Nullable String user, @Nullable Set<String> sessionIds) {
Assert.notNull(sourceDestination, "'sourceDestination' must not be null");
Assert.notNull(targetDestinations, "'targetDestinations' must not be null");
Assert.notNull(subscribeDestination, "'subscribeDestination' must not be null");
@ -52,6 +66,7 @@ public class UserDestinationResult {
this.targetDestinations = targetDestinations;
this.subscribeDestination = subscribeDestination;
this.user = user;
this.sessionIds = (sessionIds != null ? sessionIds : Collections.emptySet());
}
@ -96,6 +111,13 @@ public class UserDestinationResult {
return this.user;
}
/**
* Return the session id for the targetDestination.
*/
@Nullable
public Set<String> getSessionIds() {
return this.sessionIds;
}
@Override
public String toString() {

View File

@ -158,7 +158,6 @@ public class SimpMessagingTemplateTests {
Message<byte[]> message = messages.get(0);
assertThat(message.getHeaders()).isSameAs(headers);
assertThat(accessor.isMutable()).isFalse();
}
@Test
@ -190,7 +189,6 @@ public class SimpMessagingTemplateTests {
Message<byte[]> sentMessage = messages.get(0);
assertThat(sentMessage).isSameAs(message);
assertThat(accessor.isMutable()).isFalse();
}
@Test

View File

@ -24,6 +24,7 @@ import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import jakarta.servlet.Filter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.AfterEach;
@ -35,6 +36,7 @@ import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.context.Lifecycle;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.lang.Nullable;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
@ -85,11 +87,18 @@ public abstract class AbstractWebSocketIntegrationTests {
protected AnnotationConfigWebApplicationContext wac;
protected void setup(WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
this.server = server;
this.webSocketClient = webSocketClient;
protected void setup(WebSocketTestServer server, WebSocketClient client, TestInfo info) throws Exception {
setup(server, null, client, info);
}
logger.debug("Setting up '" + testInfo.getTestMethod().get().getName() + "', client=" +
protected void setup(
WebSocketTestServer server, @Nullable Filter filter, WebSocketClient client, TestInfo info)
throws Exception {
this.server = server;
this.webSocketClient = client;
logger.debug("Setting up '" + info.getTestMethod().get().getName() + "', client=" +
this.webSocketClient.getClass().getSimpleName() + ", server=" +
this.server.getClass().getSimpleName());
@ -102,7 +111,12 @@ public abstract class AbstractWebSocketIntegrationTests {
}
this.server.setup();
this.server.deployConfig(this.wac);
if (filter != null) {
this.server.deployConfig(this.wac, filter);
}
else {
this.server.deployConfig(this.wac);
}
this.server.start();
this.wac.setServletContext(this.server.getServletContext());

View File

@ -16,10 +16,12 @@
package org.springframework.web.socket.messaging;
import java.io.IOException;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@ -27,6 +29,13 @@ import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import org.junit.jupiter.api.TestInfo;
import org.springframework.beans.factory.annotation.Autowired;
@ -165,6 +174,38 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
}
}
@ParameterizedWebSocketTest // gh-21798
void sendMessageToUserAndReceiveInOrder(
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
UserFilter userFilter = new UserFilter(() -> "joe");
super.setup(server, userFilter, webSocketClient, testInfo);
List<TextMessage> messages = new ArrayList<>();
messages.add(create(StompCommand.CONNECT).headers("accept-version:1.1").build());
messages.add(create(StompCommand.SUBSCRIBE).headers("id:subs1", "destination:/user/queue/foo").build());
int count = 1000;
for (int i = 0; i < count; i++) {
String dest = "destination:/user/joe/queue/foo";
messages.add(create(StompCommand.SEND).headers(dest).body(String.valueOf(i)).build());
}
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(count, messages);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
for (int i = 0; i < count; i++) {
TextMessage message = clientHandler.actual.get(i);
ByteBuffer buffer = ByteBuffer.wrap(message.asBytes());
byte[] bytes = new StompDecoder().decode(buffer).get(0).getPayload();
assertThat(new String(bytes, StandardCharsets.UTF_8)).isEqualTo(String.valueOf(i));
}
}
}
@ParameterizedWebSocketTest // SPR-11648
void sendSubscribeToControllerAndReceiveReply(
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
@ -278,6 +319,30 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
}
private static class UserFilter implements Filter {
private final Principal user;
private UserFilter(Principal user) {
this.user = user;
}
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
request = new HttpServletRequestWrapper((HttpServletRequest) request) {
@Override
public Principal getUserPrincipal() {
return user;
}
};
chain.doFilter(request, response);
}
}
@IntegrationTestController
static class ScopedBeanController {
@ -335,14 +400,17 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
for (TextMessage message : this.messagesToSend) {
session.sendMessage(message);
}
session.sendMessage(this.messagesToSend.get(0));
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) {
if (!message.getPayload().startsWith("CONNECTED")) {
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
if (message.getPayload().startsWith("CONNECTED")) {
for (int i = 1; i < this.messagesToSend.size(); i++) {
session.sendMessage(this.messagesToSend.get(i));
}
}
else {
this.actual.add(message);
this.latch.countDown();
}
@ -371,6 +439,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
configurer.setApplicationDestinationPrefixes("/app");
configurer.setPreservePublishOrder(true);
configurer.enableSimpleBroker("/topic", "/queue").setSelectorHeaderName("selector");
configurer.configureBrokerChannel().taskExecutor();
}
@Bean