Polishing

See gh-21798
This commit is contained in:
Rossen Stoyanchev 2023-10-10 10:51:33 +01:00 committed by rstoyanchev
parent a205eab618
commit 9eb39e182e
6 changed files with 68 additions and 49 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -273,7 +273,7 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
}
/**
* Create an instance from the payload and headers of the given Message.
* Create an instance by copying the headers of a Message.
*/
public static SimpMessageHeaderAccessor wrap(Message<?> message) {
return new SimpMessageHeaderAccessor(message);

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -130,18 +130,17 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
return null;
}
String user = parseResult.getUser();
String sourceDestination = parseResult.getSourceDestination();
String sourceDest = parseResult.getSourceDestination();
Set<String> targetSet = new HashSet<>();
for (String sessionId : parseResult.getSessionIds()) {
String actualDestination = parseResult.getActualDestination();
String targetDestination = getTargetDestination(
sourceDestination, actualDestination, sessionId, user);
if (targetDestination != null) {
targetSet.add(targetDestination);
String actualDest = parseResult.getActualDestination();
String targetDest = getTargetDestination(sourceDest, actualDest, sessionId, user);
if (targetDest != null) {
targetSet.add(targetDest);
}
}
String subscribeDestination = parseResult.getSubscribeDestination();
return new UserDestinationResult(sourceDestination, targetSet, subscribeDestination, user);
String subscribeDest = parseResult.getSubscribeDestination();
return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user);
}
@Nullable
@ -283,22 +282,37 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver {
this.user = user;
}
/**
* The destination from the source message, e.g. "/user/{user}/queue/position-updates".
*/
public String getSourceDestination() {
return this.sourceDestination;
}
/**
* The actual destination, without any user prefix, e.g. "/queue/position-updates".
*/
public String getActualDestination() {
return this.actualDestination;
}
/**
* The user destination as it would be on a subscription, "/user/queue/position-updates".
*/
public String getSubscribeDestination() {
return this.subscribeDestination;
}
/**
* The session id or id's for the user.
*/
public Set<String> getSessionIds() {
return this.sessionIds;
}
/**
* The name of the user associated with the session.
*/
@Nullable
public String getUser() {
return this.user;

View File

@ -43,9 +43,9 @@ import org.springframework.util.StringUtils;
/**
* {@code MessageHandler} with support for "user" destinations.
*
* <p>Listens for messages with "user" destinations, translates their destination
* to actual target destinations unique to the active session(s) of a user, and
* then sends the resolved messages to the broker channel to be delivered.
* <p>Listen for messages with "user" destinations, translate the destination to
* a target destination that's unique to the active user session(s), and send
* to the broker channel for delivery.
*
* @author Rossen Stoyanchev
* @since 4.0
@ -75,24 +75,24 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
/**
* Create an instance with the given client and broker channels subscribing
* to handle messages from each and then sending any resolved messages to the
* broker channel.
* Create an instance with the given client and broker channels to subscribe to,
* and then send resolved messages to the broker channel.
* @param clientInboundChannel messages received from clients.
* @param brokerChannel messages sent to the broker.
* @param resolver the resolver for "user" destinations.
* @param destinationResolver the resolver for "user" destinations.
*/
public UserDestinationMessageHandler(SubscribableChannel clientInboundChannel,
SubscribableChannel brokerChannel, UserDestinationResolver resolver) {
public UserDestinationMessageHandler(
SubscribableChannel clientInboundChannel, SubscribableChannel brokerChannel,
UserDestinationResolver destinationResolver) {
Assert.notNull(clientInboundChannel, "'clientInChannel' must not be null");
Assert.notNull(brokerChannel, "'brokerChannel' must not be null");
Assert.notNull(resolver, "resolver must not be null");
Assert.notNull(destinationResolver, "resolver must not be null");
this.clientInboundChannel = clientInboundChannel;
this.brokerChannel = brokerChannel;
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel);
this.destinationResolver = resolver;
this.destinationResolver = destinationResolver;
}
@ -182,16 +182,16 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
@Override
public void handleMessage(Message<?> message) throws MessagingException {
Message<?> messageToUse = message;
public void handleMessage(Message<?> sourceMessage) throws MessagingException {
Message<?> message = sourceMessage;
if (this.broadcastHandler != null) {
messageToUse = this.broadcastHandler.preHandle(message);
if (messageToUse == null) {
message = this.broadcastHandler.preHandle(sourceMessage);
if (message == null) {
return;
}
}
UserDestinationResult result = this.destinationResolver.resolveDestination(messageToUse);
UserDestinationResult result = this.destinationResolver.resolveDestination(message);
if (result == null) {
return;
}
@ -201,22 +201,22 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec
logger.trace("No active sessions for user destination: " + result.getSourceDestination());
}
if (this.broadcastHandler != null) {
this.broadcastHandler.handleUnresolved(messageToUse);
this.broadcastHandler.handleUnresolved(message);
}
return;
}
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(messageToUse);
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(message);
initHeaders(accessor);
accessor.setNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION, result.getSubscribeDestination());
accessor.setLeaveMutable(true);
messageToUse = MessageBuilder.createMessage(messageToUse.getPayload(), accessor.getMessageHeaders());
message = MessageBuilder.createMessage(message.getPayload(), accessor.getMessageHeaders());
if (logger.isTraceEnabled()) {
logger.trace("Translated " + result.getSourceDestination() + " -> " + result.getTargetDestinations());
}
for (String target : result.getTargetDestinations()) {
this.messagingTemplate.send(target, messageToUse);
this.messagingTemplate.send(target, message);
}
}

View File

@ -54,11 +54,11 @@ class StompBrokerRelayMessageHandlerTests {
private StompBrokerRelayMessageHandler brokerRelay;
private StubMessageChannel outboundChannel = new StubMessageChannel();
private final StubMessageChannel outboundChannel = new StubMessageChannel();
private StubTcpOperations tcpClient = new StubTcpOperations();
private final StubTcpOperations tcpClient = new StubTcpOperations();
private ArgumentCaptor<Runnable> messageCountTaskCaptor = ArgumentCaptor.forClass(Runnable.class);
private final ArgumentCaptor<Runnable> messageCountTaskCaptor = ArgumentCaptor.forClass(Runnable.class);
@BeforeEach

View File

@ -23,6 +23,7 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.springframework.core.testfixture.security.TestPrincipal;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.StubMessageChannel;
import org.springframework.messaging.SubscribableChannel;
@ -50,7 +51,8 @@ class UserDestinationMessageHandlerTests {
private final SubscribableChannel brokerChannel = mock();
private final UserDestinationMessageHandler handler = new UserDestinationMessageHandler(new StubMessageChannel(), this.brokerChannel, new DefaultUserDestinationResolver(this.registry));
private final UserDestinationMessageHandler handler = new UserDestinationMessageHandler(
new StubMessageChannel(), this.brokerChannel, new DefaultUserDestinationResolver(this.registry));
@Test
@ -184,7 +186,9 @@ class UserDestinationMessageHandlerTests {
}
private Message<?> createWith(SimpMessageType type, String user, String sessionId, String destination) {
private Message<?> createWith(
SimpMessageType type, @Nullable String user, @Nullable String sessionId, @Nullable String destination) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type);
if (destination != null) {
headers.setDestination(destination);

View File

@ -108,9 +108,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
@Nullable
private MessageHeaderInitializer headerInitializer;
private boolean preserveReceiveOrder;
private final Map<String, MessageChannel> messageChannels = new ConcurrentHashMap<>();
@Nullable
private Map<String, MessageChannel> orderedHandlingMessageChannels;
private final Map<String, Principal> stompAuthentications = new ConcurrentHashMap<>();
@ -209,7 +208,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
* @since 6.1
*/
public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
this.preserveReceiveOrder = preserveReceiveOrder;
this.orderedHandlingMessageChannels = (preserveReceiveOrder ? new ConcurrentHashMap<>() : null);
}
/**
@ -218,7 +217,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
* @since 6.1
*/
public boolean isPreserveReceiveOrder() {
return this.preserveReceiveOrder;
return (this.orderedHandlingMessageChannels != null);
}
@Override
@ -253,7 +252,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
*/
@Override
public void handleMessageFromClient(WebSocketSession session,
WebSocketMessage<?> webSocketMessage, MessageChannel outputChannel) {
WebSocketMessage<?> webSocketMessage, MessageChannel targetChannel) {
List<Message<byte[]>> messages;
try {
@ -296,11 +295,11 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
return;
}
MessageChannel channelToUse =
(this.messageChannels.computeIfAbsent(session.getId(),
id -> this.preserveReceiveOrder ?
new OrderedMessageChannelDecorator(outputChannel, logger) :
outputChannel));
MessageChannel channelToUse = targetChannel;
if (this.orderedHandlingMessageChannels != null) {
channelToUse = this.orderedHandlingMessageChannels.computeIfAbsent(
session.getId(), id -> new OrderedMessageChannelDecorator(targetChannel, logger));
}
for (Message<byte[]> message : messages) {
StompHeaderAccessor headerAccessor =
@ -324,7 +323,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
});
}
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
if (!detectImmutableMessageInterceptor(outputChannel)) {
if (!detectImmutableMessageInterceptor(targetChannel)) {
headerAccessor.setImmutable();
}
@ -686,7 +685,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
outputChannel.send(message);
}
finally {
this.messageChannels.remove(session.getId());
if (this.orderedHandlingMessageChannels != null) {
this.orderedHandlingMessageChannels.remove(session.getId());
}
this.stompAuthentications.remove(session.getId());
SimpAttributesContextHolder.resetAttributes();
simpAttributes.sessionCompleted();