Handle STOMP messages to user destination in order
Closes gh-31395
This commit is contained in:
parent
9eb39e182e
commit
3277b0d6ac
|
@ -6,7 +6,7 @@ written to WebSocket sessions. As the channel is backed by a `ThreadPoolExecutor
|
|||
are processed in different threads, and the resulting sequence received by the client may
|
||||
not match the exact order of publication.
|
||||
|
||||
If this is an issue, enable the `setPreservePublishOrder` flag, as the following example shows:
|
||||
To enable ordered publishing, set the `setPreservePublishOrder` flag as follows:
|
||||
|
||||
[source,java,indent=0,subs="verbatim,quotes"]
|
||||
----
|
||||
|
@ -47,5 +47,22 @@ When the flag is set, messages within the same client session are published to t
|
|||
`clientOutboundChannel` one at a time, so that the order of publication is guaranteed.
|
||||
Note that this incurs a small performance overhead, so you should enable it only if it is required.
|
||||
|
||||
The same also applies to messages from the client, which are sent to the `clientInboundChannel`,
|
||||
from where they are handled according to their destination prefix. As the channel is backed by
|
||||
a `ThreadPoolExecutor`, messages are processed in different threads, and the resulting sequence
|
||||
of handling may not match the exact order in which they were received.
|
||||
|
||||
To enable ordered publishing, set the `setPreserveReceiveOrder` flag as follows:
|
||||
|
||||
[source,java,indent=0,subs="verbatim,quotes"]
|
||||
----
|
||||
@Configuration
|
||||
@EnableWebSocketMessageBroker
|
||||
public class MyConfig implements WebSocketMessageBrokerConfigurer {
|
||||
|
||||
@Override
|
||||
public void registerStompEndpoints(StompEndpointRegistry registry) {
|
||||
registry.setPreserveReceiveOrder(true);
|
||||
}
|
||||
}
|
||||
----
|
||||
|
|
|
@ -157,7 +157,7 @@ public class SimpMessagingTemplate extends AbstractMessageSendingTemplate<String
|
|||
if (simpAccessor.isMutable()) {
|
||||
simpAccessor.setDestination(destination);
|
||||
simpAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
|
||||
simpAccessor.setImmutable();
|
||||
// ImmutableMessageChannelInterceptor will make it immutable
|
||||
sendInternal(message);
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -159,6 +159,16 @@ public class OrderedMessageChannelDecorator implements MessageChannel {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the channel has been {@link #configureInterceptor configured}
|
||||
* with an interceptor for sequential handling.
|
||||
* @since 6.1
|
||||
*/
|
||||
public static boolean supportsOrderedMessages(MessageChannel channel) {
|
||||
return (channel instanceof ExecutorSubscribableChannel ch &&
|
||||
ch.getInterceptors().stream().anyMatch(CallbackTaskInterceptor.class::isInstance));
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain the task to release the next message, if found.
|
||||
*/
|
||||
|
|
|
@ -131,8 +131,9 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
|
|||
}
|
||||
String user = parseResult.getUser();
|
||||
String sourceDest = parseResult.getSourceDestination();
|
||||
Set<String> sessionIds = parseResult.getSessionIds();
|
||||
Set<String> targetSet = new HashSet<>();
|
||||
for (String sessionId : parseResult.getSessionIds()) {
|
||||
for (String sessionId : sessionIds) {
|
||||
String actualDest = parseResult.getActualDestination();
|
||||
String targetDest = getTargetDestination(sourceDest, actualDest, sessionId, user);
|
||||
if (targetDest != null) {
|
||||
|
@ -140,7 +141,7 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
|
|||
}
|
||||
}
|
||||
String subscribeDest = parseResult.getSubscribeDestination();
|
||||
return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user);
|
||||
return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user, sessionIds);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2021 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,13 +17,18 @@
|
|||
package org.springframework.messaging.simp.user;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
|
||||
import org.springframework.context.SmartLifecycle;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageChannel;
|
||||
import org.springframework.messaging.MessageHandler;
|
||||
import org.springframework.messaging.MessageHeaders;
|
||||
import org.springframework.messaging.MessagingException;
|
||||
|
@ -33,6 +38,7 @@ import org.springframework.messaging.simp.SimpLogging;
|
|||
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
|
||||
import org.springframework.messaging.simp.SimpMessageType;
|
||||
import org.springframework.messaging.simp.SimpMessagingTemplate;
|
||||
import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
import org.springframework.messaging.support.MessageHeaderAccessor;
|
||||
import org.springframework.messaging.support.MessageHeaderInitializer;
|
||||
|
@ -61,7 +67,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
|
|||
|
||||
private final UserDestinationResolver destinationResolver;
|
||||
|
||||
private final MessageSendingOperations<String> messagingTemplate;
|
||||
private final SendHelper sendHelper;
|
||||
|
||||
@Nullable
|
||||
private BroadcastHandler broadcastHandler;
|
||||
|
@ -91,7 +97,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
|
|||
|
||||
this.clientInboundChannel = clientInboundChannel;
|
||||
this.brokerChannel = brokerChannel;
|
||||
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel);
|
||||
this.sendHelper = new SendHelper(clientInboundChannel, brokerChannel);
|
||||
this.destinationResolver = destinationResolver;
|
||||
}
|
||||
|
||||
|
@ -112,7 +118,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
|
|||
*/
|
||||
public void setBroadcastDestination(@Nullable String destination) {
|
||||
this.broadcastHandler = (StringUtils.hasText(destination) ?
|
||||
new BroadcastHandler(this.messagingTemplate, destination) : null);
|
||||
new BroadcastHandler(this.sendHelper.getMessagingTemplate(), destination) : null);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -128,7 +134,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
|
|||
* broker channel.
|
||||
*/
|
||||
public MessageSendingOperations<String> getBrokerMessagingTemplate() {
|
||||
return this.messagingTemplate;
|
||||
return this.sendHelper.getMessagingTemplate();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -193,6 +199,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
|
|||
|
||||
UserDestinationResult result = this.destinationResolver.resolveDestination(message);
|
||||
if (result == null) {
|
||||
this.sendHelper.checkDisconnect(message);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -215,9 +222,8 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
|
|||
if (logger.isTraceEnabled()) {
|
||||
logger.trace("Translated " + result.getSourceDestination() + " -> " + result.getTargetDestinations());
|
||||
}
|
||||
for (String target : result.getTargetDestinations()) {
|
||||
this.messagingTemplate.send(target, message);
|
||||
}
|
||||
|
||||
this.sendHelper.send(result, message);
|
||||
}
|
||||
|
||||
private void initHeaders(SimpMessageHeaderAccessor headerAccessor) {
|
||||
|
@ -232,6 +238,63 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
|
|||
}
|
||||
|
||||
|
||||
private static class SendHelper {
|
||||
|
||||
private final MessageChannel brokerChannel;
|
||||
|
||||
private final MessageSendingOperations<String> messagingTemplate;
|
||||
|
||||
@Nullable
|
||||
private final Map<String, MessageSendingOperations<String>> orderedMessagingTemplates;
|
||||
|
||||
SendHelper(MessageChannel clientInboundChannel, MessageChannel brokerChannel) {
|
||||
this.brokerChannel = brokerChannel;
|
||||
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel);
|
||||
if (OrderedMessageChannelDecorator.supportsOrderedMessages(clientInboundChannel)) {
|
||||
this.orderedMessagingTemplates = new ConcurrentHashMap<>();
|
||||
OrderedMessageChannelDecorator.configureInterceptor(brokerChannel, true);
|
||||
}
|
||||
else {
|
||||
this.orderedMessagingTemplates = null;
|
||||
}
|
||||
}
|
||||
|
||||
public MessageSendingOperations<String> getMessagingTemplate() {
|
||||
return this.messagingTemplate;
|
||||
}
|
||||
|
||||
public void send(UserDestinationResult destinationResult, Message<?> message) throws MessagingException {
|
||||
Set<String> sessionIds = destinationResult.getSessionIds();
|
||||
Iterator<String> itr = (sessionIds != null ? sessionIds.iterator() : null);
|
||||
|
||||
for (String target : destinationResult.getTargetDestinations()) {
|
||||
String sessionId = (itr != null ? itr.next() : null);
|
||||
getTemplateToUse(sessionId).send(target, message);
|
||||
}
|
||||
}
|
||||
|
||||
private MessageSendingOperations<String> getTemplateToUse(@Nullable String sessionId) {
|
||||
if (this.orderedMessagingTemplates != null && sessionId != null) {
|
||||
return this.orderedMessagingTemplates.computeIfAbsent(sessionId, id ->
|
||||
new SimpMessagingTemplate(new OrderedMessageChannelDecorator(this.brokerChannel, logger)));
|
||||
}
|
||||
return this.messagingTemplate;
|
||||
}
|
||||
|
||||
public void checkDisconnect(Message<?> message) {
|
||||
if (this.orderedMessagingTemplates != null) {
|
||||
MessageHeaders headers = message.getHeaders();
|
||||
if (SimpMessageHeaderAccessor.getMessageType(headers) == SimpMessageType.DISCONNECT) {
|
||||
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
|
||||
if (sessionId != null) {
|
||||
this.orderedMessagingTemplates.remove(sessionId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A handler that broadcasts locally unresolved messages to the broker and
|
||||
* also handles similar broadcasts received from the broker.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.springframework.messaging.simp.user;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Set;
|
||||
|
||||
import org.springframework.lang.Nullable;
|
||||
|
@ -40,10 +41,23 @@ public class UserDestinationResult {
|
|||
@Nullable
|
||||
private final String user;
|
||||
|
||||
private final Set<String> sessionIds;
|
||||
|
||||
|
||||
public UserDestinationResult(String sourceDestination, Set<String> targetDestinations,
|
||||
String subscribeDestination, @Nullable String user) {
|
||||
|
||||
this(sourceDestination, targetDestinations, subscribeDestination, user, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Additional constructor with the session id for each targetDestination.
|
||||
* @since 6.1
|
||||
*/
|
||||
public UserDestinationResult(
|
||||
String sourceDestination, Set<String> targetDestinations,
|
||||
String subscribeDestination, @Nullable String user, @Nullable Set<String> sessionIds) {
|
||||
|
||||
Assert.notNull(sourceDestination, "'sourceDestination' must not be null");
|
||||
Assert.notNull(targetDestinations, "'targetDestinations' must not be null");
|
||||
Assert.notNull(subscribeDestination, "'subscribeDestination' must not be null");
|
||||
|
@ -52,6 +66,7 @@ public class UserDestinationResult {
|
|||
this.targetDestinations = targetDestinations;
|
||||
this.subscribeDestination = subscribeDestination;
|
||||
this.user = user;
|
||||
this.sessionIds = (sessionIds != null ? sessionIds : Collections.emptySet());
|
||||
}
|
||||
|
||||
|
||||
|
@ -96,6 +111,13 @@ public class UserDestinationResult {
|
|||
return this.user;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the session id for the targetDestination.
|
||||
*/
|
||||
@Nullable
|
||||
public Set<String> getSessionIds() {
|
||||
return this.sessionIds;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
|
|
|
@ -158,7 +158,6 @@ public class SimpMessagingTemplateTests {
|
|||
Message<byte[]> message = messages.get(0);
|
||||
|
||||
assertThat(message.getHeaders()).isSameAs(headers);
|
||||
assertThat(accessor.isMutable()).isFalse();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -190,7 +189,6 @@ public class SimpMessagingTemplateTests {
|
|||
Message<byte[]> sentMessage = messages.get(0);
|
||||
|
||||
assertThat(sentMessage).isSameAs(message);
|
||||
assertThat(accessor.isMutable()).isFalse();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.Map;
|
|||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import jakarta.servlet.Filter;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
|
@ -35,6 +36,7 @@ import org.junit.jupiter.params.provider.MethodSource;
|
|||
import org.springframework.context.Lifecycle;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
|
||||
import org.springframework.web.socket.client.WebSocketClient;
|
||||
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
|
||||
|
@ -85,11 +87,18 @@ public abstract class AbstractWebSocketIntegrationTests {
|
|||
protected AnnotationConfigWebApplicationContext wac;
|
||||
|
||||
|
||||
protected void setup(WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
|
||||
this.server = server;
|
||||
this.webSocketClient = webSocketClient;
|
||||
protected void setup(WebSocketTestServer server, WebSocketClient client, TestInfo info) throws Exception {
|
||||
setup(server, null, client, info);
|
||||
}
|
||||
|
||||
logger.debug("Setting up '" + testInfo.getTestMethod().get().getName() + "', client=" +
|
||||
protected void setup(
|
||||
WebSocketTestServer server, @Nullable Filter filter, WebSocketClient client, TestInfo info)
|
||||
throws Exception {
|
||||
|
||||
this.server = server;
|
||||
this.webSocketClient = client;
|
||||
|
||||
logger.debug("Setting up '" + info.getTestMethod().get().getName() + "', client=" +
|
||||
this.webSocketClient.getClass().getSimpleName() + ", server=" +
|
||||
this.server.getClass().getSimpleName());
|
||||
|
||||
|
@ -102,7 +111,12 @@ public abstract class AbstractWebSocketIntegrationTests {
|
|||
}
|
||||
|
||||
this.server.setup();
|
||||
this.server.deployConfig(this.wac);
|
||||
if (filter != null) {
|
||||
this.server.deployConfig(this.wac, filter);
|
||||
}
|
||||
else {
|
||||
this.server.deployConfig(this.wac);
|
||||
}
|
||||
this.server.start();
|
||||
|
||||
this.wac.setServletContext(this.server.getServletContext());
|
||||
|
|
|
@ -16,10 +16,12 @@
|
|||
|
||||
package org.springframework.web.socket.messaging;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.Principal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
@ -27,6 +29,13 @@ import java.util.concurrent.CopyOnWriteArrayList;
|
|||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import jakarta.servlet.Filter;
|
||||
import jakarta.servlet.FilterChain;
|
||||
import jakarta.servlet.ServletException;
|
||||
import jakarta.servlet.ServletRequest;
|
||||
import jakarta.servlet.ServletResponse;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletRequestWrapper;
|
||||
import org.junit.jupiter.api.TestInfo;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
|
@ -165,6 +174,38 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedWebSocketTest // gh-21798
|
||||
void sendMessageToUserAndReceiveInOrder(
|
||||
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
|
||||
|
||||
UserFilter userFilter = new UserFilter(() -> "joe");
|
||||
super.setup(server, userFilter, webSocketClient, testInfo);
|
||||
|
||||
List<TextMessage> messages = new ArrayList<>();
|
||||
messages.add(create(StompCommand.CONNECT).headers("accept-version:1.1").build());
|
||||
messages.add(create(StompCommand.SUBSCRIBE).headers("id:subs1", "destination:/user/queue/foo").build());
|
||||
|
||||
int count = 1000;
|
||||
for (int i = 0; i < count; i++) {
|
||||
String dest = "destination:/user/joe/queue/foo";
|
||||
messages.add(create(StompCommand.SEND).headers(dest).body(String.valueOf(i)).build());
|
||||
}
|
||||
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(count, messages);
|
||||
|
||||
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
|
||||
assertThat(session).isNotNull();
|
||||
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
|
||||
|
||||
for (int i = 0; i < count; i++) {
|
||||
TextMessage message = clientHandler.actual.get(i);
|
||||
ByteBuffer buffer = ByteBuffer.wrap(message.asBytes());
|
||||
byte[] bytes = new StompDecoder().decode(buffer).get(0).getPayload();
|
||||
assertThat(new String(bytes, StandardCharsets.UTF_8)).isEqualTo(String.valueOf(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedWebSocketTest // SPR-11648
|
||||
void sendSubscribeToControllerAndReceiveReply(
|
||||
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
|
||||
|
@ -278,6 +319,30 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
}
|
||||
|
||||
|
||||
private static class UserFilter implements Filter {
|
||||
|
||||
private final Principal user;
|
||||
|
||||
private UserFilter(Principal user) {
|
||||
this.user = user;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
|
||||
throws IOException, ServletException {
|
||||
|
||||
request = new HttpServletRequestWrapper((HttpServletRequest) request) {
|
||||
@Override
|
||||
public Principal getUserPrincipal() {
|
||||
return user;
|
||||
}
|
||||
};
|
||||
|
||||
chain.doFilter(request, response);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@IntegrationTestController
|
||||
static class ScopedBeanController {
|
||||
|
||||
|
@ -335,14 +400,17 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
|
||||
@Override
|
||||
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
|
||||
for (TextMessage message : this.messagesToSend) {
|
||||
session.sendMessage(message);
|
||||
}
|
||||
session.sendMessage(this.messagesToSend.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void handleTextMessage(WebSocketSession session, TextMessage message) {
|
||||
if (!message.getPayload().startsWith("CONNECTED")) {
|
||||
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
|
||||
if (message.getPayload().startsWith("CONNECTED")) {
|
||||
for (int i = 1; i < this.messagesToSend.size(); i++) {
|
||||
session.sendMessage(this.messagesToSend.get(i));
|
||||
}
|
||||
}
|
||||
else {
|
||||
this.actual.add(message);
|
||||
this.latch.countDown();
|
||||
}
|
||||
|
@ -371,6 +439,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
configurer.setApplicationDestinationPrefixes("/app");
|
||||
configurer.setPreservePublishOrder(true);
|
||||
configurer.enableSimpleBroker("/topic", "/queue").setSelectorHeaderName("selector");
|
||||
configurer.configureBrokerChannel().taskExecutor();
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
|
Loading…
Reference in New Issue