parent
4195e6906c
commit
a205eab618
|
|
@ -113,7 +113,12 @@ public class OrderedMessageChannelDecorator implements MessageChannel {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove the message from the top of the queue, but only if it matches,
|
||||
* i.e. hasn't been removed already.
|
||||
*/
|
||||
private boolean removeMessage(Message<?> message) {
|
||||
// Remove only if not removed already
|
||||
Message<?> next = this.messages.peek();
|
||||
if (next == message) {
|
||||
this.messages.remove();
|
||||
|
|
@ -181,7 +186,7 @@ public class OrderedMessageChannelDecorator implements MessageChannel {
|
|||
@Override
|
||||
public void run() {
|
||||
if (this.handledCount == null || this.handledCount.addAndGet(1) == subscriberCount) {
|
||||
if (OrderedMessageChannelDecorator.this.removeMessage(message)) {
|
||||
if (OrderedMessageChannelDecorator.this.removeMessage(this.message)) {
|
||||
sendNextMessage();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
* Copyright 2002-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
@ -52,4 +52,18 @@ public interface StompEndpointRegistry {
|
|||
*/
|
||||
WebMvcStompEndpointRegistry setErrorHandler(StompSubProtocolErrorHandler errorHandler);
|
||||
|
||||
/**
|
||||
* Whether to handle client messages sequentially in the order in which
|
||||
* they were received.
|
||||
* <p>By default messages sent to the {@code "clientInboundChannel"} may
|
||||
* be handled in parallel and not in the same order as they were received
|
||||
* because the channel is backed by a ThreadPoolExecutor that in turn does
|
||||
* not guarantee processing in order.
|
||||
* <p>When this flag is set to {@code true} messages within the same session
|
||||
* will be sent to the {@code "clientInboundChannel"} one at a time in
|
||||
* order to preserve the order in which they were received.
|
||||
* @since 6.1
|
||||
*/
|
||||
WebMvcStompEndpointRegistry setPreserveReceiveOrder(boolean preserveReceiveOrder);
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -142,6 +142,15 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry {
|
|||
return this;
|
||||
}
|
||||
|
||||
public WebMvcStompEndpointRegistry setPreserveReceiveOrder(boolean preserveReceiveOrder) {
|
||||
this.stompHandler.setPreserveReceiveOrder(preserveReceiveOrder);
|
||||
return this;
|
||||
}
|
||||
|
||||
protected boolean isPreserveReceiveOrder() {
|
||||
return this.stompHandler.isPreserveReceiveOrder();
|
||||
}
|
||||
|
||||
protected void setApplicationContext(ApplicationContext applicationContext) {
|
||||
this.stompHandler.setApplicationEventPublisher(applicationContext);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import org.springframework.messaging.simp.SimpMessagingTemplate;
|
|||
import org.springframework.messaging.simp.SimpSessionScope;
|
||||
import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler;
|
||||
import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler;
|
||||
import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
|
||||
import org.springframework.messaging.simp.config.AbstractMessageBrokerConfiguration;
|
||||
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
|
||||
import org.springframework.messaging.simp.user.SimpUserRegistry;
|
||||
|
|
@ -80,7 +81,8 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
|
|||
|
||||
@Bean
|
||||
public HandlerMapping stompWebSocketHandlerMapping(
|
||||
WebSocketHandler subProtocolWebSocketHandler, TaskScheduler messageBrokerTaskScheduler) {
|
||||
WebSocketHandler subProtocolWebSocketHandler, TaskScheduler messageBrokerTaskScheduler,
|
||||
AbstractSubscribableChannel clientInboundChannel) {
|
||||
|
||||
WebSocketHandler handler = decorateWebSocketHandler(subProtocolWebSocketHandler);
|
||||
WebMvcStompEndpointRegistry registry =
|
||||
|
|
@ -90,6 +92,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
|
|||
registry.setApplicationContext(applicationContext);
|
||||
}
|
||||
registerStompEndpoints(registry);
|
||||
OrderedMessageChannelDecorator.configureInterceptor(clientInboundChannel, registry.isPreserveReceiveOrder());
|
||||
return registry.getHandlerMapping();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -108,6 +108,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
|
|||
@Nullable
|
||||
private MessageHeaderInitializer headerInitializer;
|
||||
|
||||
private boolean preserveReceiveOrder;
|
||||
|
||||
private final Map<String, MessageChannel> messageChannels = new ConcurrentHashMap<>();
|
||||
|
||||
private final Map<String, Principal> stompAuthentications = new ConcurrentHashMap<>();
|
||||
|
||||
@Nullable
|
||||
|
|
@ -193,6 +197,30 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
|
|||
return this.headerInitializer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether client messages must be handled in the order received.
|
||||
* <p>By default messages sent to the {@code "clientInboundChannel"} may
|
||||
* not be handled in the same order because the channel is backed by a
|
||||
* ThreadPoolExecutor that in turn does not guarantee processing in order.
|
||||
* <p>When this flag is set to {@code true} messages within the same session
|
||||
* will be sent to the {@code "clientInboundChannel"} one at a time to
|
||||
* preserve the order in which they were received.
|
||||
* @param preserveReceiveOrder whether to publish in order
|
||||
* @since 6.1
|
||||
*/
|
||||
public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
|
||||
this.preserveReceiveOrder = preserveReceiveOrder;
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the handler is configured to handle inbound messages in the
|
||||
* order in which they were received.
|
||||
* @since 6.1
|
||||
*/
|
||||
public boolean isPreserveReceiveOrder() {
|
||||
return this.preserveReceiveOrder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getSupportedProtocols() {
|
||||
return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
|
||||
|
|
@ -268,6 +296,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
|
|||
return;
|
||||
}
|
||||
|
||||
MessageChannel channelToUse =
|
||||
(this.messageChannels.computeIfAbsent(session.getId(),
|
||||
id -> this.preserveReceiveOrder ?
|
||||
new OrderedMessageChannelDecorator(outputChannel, logger) :
|
||||
outputChannel));
|
||||
|
||||
for (Message<byte[]> message : messages) {
|
||||
StompHeaderAccessor headerAccessor =
|
||||
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
|
||||
|
|
@ -307,7 +341,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
|
|||
|
||||
try {
|
||||
SimpAttributesContextHolder.setAttributesFromMessage(message);
|
||||
sent = outputChannel.send(message);
|
||||
sent = channelToUse.send(message);
|
||||
|
||||
if (sent) {
|
||||
if (this.eventPublisher != null) {
|
||||
|
|
@ -652,6 +686,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
|
|||
outputChannel.send(message);
|
||||
}
|
||||
finally {
|
||||
this.messageChannels.remove(session.getId());
|
||||
this.stompAuthentications.remove(session.getId());
|
||||
SimpAttributesContextHolder.resetAttributes();
|
||||
simpAttributes.sessionCompleted();
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ package org.springframework.web.socket.messaging;
|
|||
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
|
@ -28,25 +32,22 @@ import org.junit.jupiter.api.TestInfo;
|
|||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.ComponentScan;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.context.annotation.Scope;
|
||||
import org.springframework.context.annotation.ScopedProxyMode;
|
||||
import org.springframework.core.task.TaskExecutor;
|
||||
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
|
||||
import org.springframework.messaging.handler.annotation.MessageMapping;
|
||||
import org.springframework.messaging.simp.annotation.SendToUser;
|
||||
import org.springframework.messaging.simp.annotation.SubscribeMapping;
|
||||
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
|
||||
import org.springframework.messaging.simp.stomp.StompCommand;
|
||||
import org.springframework.messaging.support.AbstractSubscribableChannel;
|
||||
import org.springframework.messaging.support.ExecutorSubscribableChannel;
|
||||
import org.springframework.messaging.simp.stomp.StompDecoder;
|
||||
import org.springframework.stereotype.Controller;
|
||||
import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
|
||||
import org.springframework.web.socket.TextMessage;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.WebSocketTestServer;
|
||||
import org.springframework.web.socket.client.WebSocketClient;
|
||||
import org.springframework.web.socket.config.annotation.DelegatingWebSocketMessageBrokerConfiguration;
|
||||
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
|
||||
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
|
||||
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
|
||||
import org.springframework.web.socket.handler.TextWebSocketHandler;
|
||||
|
|
@ -69,7 +70,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
|
||||
@Override
|
||||
protected Class<?>[] getAnnotatedConfigClasses() {
|
||||
return new Class<?>[] {TestMessageBrokerConfiguration.class, TestMessageBrokerConfigurer.class};
|
||||
return new Class<?>[] {TestMessageBrokerConfigurer.class};
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -100,7 +101,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
TextMessage m2 = create(StompCommand.SEND)
|
||||
.headers("destination:/app/increment").body("5").build();
|
||||
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
|
||||
|
||||
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
|
||||
assertThat(session).isNotNull();
|
||||
|
|
@ -121,17 +122,49 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destination, selector).build();
|
||||
TextMessage m2 = create(StompCommand.SEND).headers(destination, "foo:bar").body("5").build();
|
||||
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
|
||||
|
||||
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
|
||||
assertThat(session).isNotNull();
|
||||
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
|
||||
|
||||
String payload = clientHandler.actual.get(1).getPayload();
|
||||
String payload = clientHandler.actual.get(0).getPayload();
|
||||
assertThat(payload).as("Expected STOMP MESSAGE, got " + payload).startsWith("MESSAGE\n");
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedWebSocketTest // gh-21798
|
||||
void sendMessageToBrokerAndReceiveInOrder(
|
||||
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
|
||||
|
||||
super.setup(server, webSocketClient, testInfo);
|
||||
|
||||
String destination = "destination:/topic/foo";
|
||||
|
||||
List<TextMessage> messages = new ArrayList<>();
|
||||
messages.add(create(StompCommand.CONNECT).headers("accept-version:1.1").build());
|
||||
messages.add(create(StompCommand.SUBSCRIBE).headers("id:subs1", destination).build());
|
||||
|
||||
int count = 1000;
|
||||
for (int i = 0; i < count; i++) {
|
||||
messages.add(create(StompCommand.SEND).headers(destination).body(String.valueOf(i)).build());
|
||||
}
|
||||
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(count, messages);
|
||||
|
||||
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
|
||||
assertThat(session).isNotNull();
|
||||
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
|
||||
|
||||
for (int i = 0; i < count; i++) {
|
||||
TextMessage message = clientHandler.actual.get(i);
|
||||
ByteBuffer buffer = ByteBuffer.wrap(message.asBytes());
|
||||
byte[] bytes = new StompDecoder().decode(buffer).get(0).getPayload();
|
||||
assertThat(new String(bytes, StandardCharsets.UTF_8)).isEqualTo(String.valueOf(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedWebSocketTest // SPR-11648
|
||||
void sendSubscribeToControllerAndReceiveReply(
|
||||
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
|
||||
|
|
@ -142,12 +175,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
String destHeader = "destination:/app/number";
|
||||
TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destHeader).build();
|
||||
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1);
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1);
|
||||
|
||||
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
|
||||
assertThat(session).isNotNull();
|
||||
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
|
||||
String payload = clientHandler.actual.get(1).getPayload();
|
||||
String payload = clientHandler.actual.get(0).getPayload();
|
||||
assertThat(payload).as("Expected STOMP destination=/app/number, got " + payload).contains(destHeader);
|
||||
assertThat(payload).as("Expected STOMP Payload=42, got " + payload).contains("42");
|
||||
}
|
||||
|
|
@ -164,12 +197,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destHeader).build();
|
||||
TextMessage m2 = create(StompCommand.SEND).headers("destination:/app/exception").build();
|
||||
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
|
||||
|
||||
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
|
||||
assertThat(session).isNotNull();
|
||||
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
|
||||
String payload = clientHandler.actual.get(1).getPayload();
|
||||
String payload = clientHandler.actual.get(0).getPayload();
|
||||
assertThat(payload).startsWith("MESSAGE\n");
|
||||
assertThat(payload).contains("destination:/user/queue/error\n");
|
||||
assertThat(payload).endsWith("Got error: Bad input\0");
|
||||
|
|
@ -188,12 +221,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
TextMessage m2 = create(StompCommand.SEND)
|
||||
.headers("destination:/app/scopedBeanValue").build();
|
||||
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
|
||||
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
|
||||
|
||||
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
|
||||
assertThat(session).isNotNull();
|
||||
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
|
||||
String payload = clientHandler.actual.get(1).getPayload();
|
||||
String payload = clientHandler.actual.get(0).getPayload();
|
||||
assertThat(payload).startsWith("MESSAGE\n");
|
||||
assertThat(payload).contains("destination:/topic/scopedBeanValue\n");
|
||||
assertThat(payload).endsWith("55\0");
|
||||
|
|
@ -285,18 +318,19 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
|
||||
private static class TestClientWebSocketHandler extends TextWebSocketHandler {
|
||||
|
||||
private final TextMessage[] messagesToSend;
|
||||
|
||||
private final int expected;
|
||||
private final List<TextMessage> messagesToSend;
|
||||
|
||||
private final List<TextMessage> actual = new CopyOnWriteArrayList<>();
|
||||
|
||||
private final CountDownLatch latch;
|
||||
|
||||
public TestClientWebSocketHandler(int expectedNumberOfMessages, TextMessage... messagesToSend) {
|
||||
TestClientWebSocketHandler(int expectedNumberOfMessages, TextMessage... messagesToSend) {
|
||||
this(expectedNumberOfMessages, Arrays.asList(messagesToSend));
|
||||
}
|
||||
|
||||
TestClientWebSocketHandler(int expectedNumberOfMessages, List<TextMessage> messagesToSend) {
|
||||
this.messagesToSend = messagesToSend;
|
||||
this.expected = expectedNumberOfMessages;
|
||||
this.latch = new CountDownLatch(this.expected);
|
||||
this.latch = new CountDownLatch(expectedNumberOfMessages);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -307,18 +341,20 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
|
||||
this.actual.add(message);
|
||||
this.latch.countDown();
|
||||
protected void handleTextMessage(WebSocketSession session, TextMessage message) {
|
||||
if (!message.getPayload().startsWith("CONNECTED")) {
|
||||
this.actual.add(message);
|
||||
this.latch.countDown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Configuration
|
||||
@ComponentScan(
|
||||
basePackageClasses = StompWebSocketIntegrationTests.class,
|
||||
useDefaultFilters = false,
|
||||
includeFilters = @ComponentScan.Filter(IntegrationTestController.class))
|
||||
@EnableWebSocketMessageBroker
|
||||
static class TestMessageBrokerConfigurer implements WebSocketMessageBrokerConfigurer {
|
||||
|
||||
@Autowired
|
||||
|
|
@ -326,12 +362,14 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
|
||||
@Override
|
||||
public void registerStompEndpoints(StompEndpointRegistry registry) {
|
||||
registry.setPreserveReceiveOrder(true);
|
||||
registry.addEndpoint("/ws").setHandshakeHandler(this.handshakeHandler);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void configureMessageBroker(MessageBrokerRegistry configurer) {
|
||||
configurer.setApplicationDestinationPrefixes("/app");
|
||||
configurer.setPreservePublishOrder(true);
|
||||
configurer.enableSimpleBroker("/topic", "/queue").setSelectorHeaderName("selector");
|
||||
}
|
||||
|
||||
|
|
@ -342,21 +380,4 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
@Configuration
|
||||
static class TestMessageBrokerConfiguration extends DelegatingWebSocketMessageBrokerConfiguration {
|
||||
|
||||
@Override
|
||||
@Bean
|
||||
public AbstractSubscribableChannel clientInboundChannel(TaskExecutor clientInboundChannelExecutor) {
|
||||
return new ExecutorSubscribableChannel(); // synchronous
|
||||
}
|
||||
|
||||
@Override
|
||||
@Bean
|
||||
public AbstractSubscribableChannel clientOutboundChannel(TaskExecutor clientOutboundChannelExecutor) {
|
||||
return new ExecutorSubscribableChannel(); // synchronous
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue