diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/MessageHolder.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/ReplyTo.java similarity index 54% rename from spring-messaging/src/main/java/org/springframework/messaging/simp/MessageHolder.java rename to spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/ReplyTo.java index 31c265108d0..38fc1542262 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/MessageHolder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/ReplyTo.java @@ -14,34 +14,28 @@ * limitations under the License. */ -package org.springframework.messaging.simp; +package org.springframework.messaging.handler.annotation; -import org.springframework.core.NamedThreadLocal; -import org.springframework.messaging.Message; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; -// TODO: remove? - /** * @author Rossen Stoyanchev * @since 4.0 */ -public class MessageHolder { - - private static final NamedThreadLocal> messageHolder = - new NamedThreadLocal>("Current message"); +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface ReplyTo { - public static void setMessage(Message message) { - messageHolder.set(message); - } - - public static Message getMessage() { - return messageHolder.get(); - } - - public static void reset() { - messageHolder.remove(); - } + /** + * The destination value for the reply. + */ + String value(); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/method/InvalidMessageMethodParameterException.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/method/MissingSessionUserException.java similarity index 58% rename from spring-messaging/src/main/java/org/springframework/messaging/handler/method/InvalidMessageMethodParameterException.java rename to spring-messaging/src/main/java/org/springframework/messaging/handler/method/MissingSessionUserException.java index 11b9aaebab7..0e13f563206 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/method/InvalidMessageMethodParameterException.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/method/MissingSessionUserException.java @@ -16,7 +16,6 @@ package org.springframework.messaging.handler.method; -import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessagingException; @@ -26,29 +25,13 @@ import org.springframework.messaging.MessagingException; * @author Rossen Stoyanchev * @since 4.0 */ -public class InvalidMessageMethodParameterException extends MessagingException { +public class MissingSessionUserException extends MessagingException { private static final long serialVersionUID = -6905878930083523161L; - private final MethodParameter parameter; - - public InvalidMessageMethodParameterException(Message message, String description, - MethodParameter parameter, Throwable cause) { - super(message, description, cause); - this.parameter = parameter; - } - - public InvalidMessageMethodParameterException(Message message, String description, - MethodParameter parameter) { - - super(message, description); - this.parameter = parameter; - } - - - public MethodParameter getParameter() { - return this.parameter; + public MissingSessionUserException(Message message) { + super(message, "No \"user\" header in message"); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/MessageSendingReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/DefaultMessageReturnValueHandler.java similarity index 59% rename from spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/MessageSendingReturnValueHandler.java rename to spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/DefaultMessageReturnValueHandler.java index be786231bbf..38c3dba0fe8 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/MessageSendingReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/DefaultMessageReturnValueHandler.java @@ -16,10 +16,14 @@ package org.springframework.messaging.simp.annotation.support; +import java.security.Principal; + import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.handler.annotation.ReplyTo; import org.springframework.messaging.handler.method.MessageReturnValueHandler; +import org.springframework.messaging.handler.method.MissingSessionUserException; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.support.MessageBuilder; import org.springframework.messaging.support.converter.MessageConverter; @@ -27,19 +31,32 @@ import org.springframework.util.Assert; /** + * Expects return values to be either a {@link Message} or the payload of a message to be + * converted and sent on a {@link MessageChannel}. + * + *

This {@link MessageReturnValueHandler} should be ordered last as it supports all + * return value types. + * * @author Rossen Stoyanchev * @since 4.0 */ -public class MessageSendingReturnValueHandler implements MessageReturnValueHandler { +public class DefaultMessageReturnValueHandler implements MessageReturnValueHandler { + + private MessageChannel inboundChannel; private MessageChannel outboundChannel; private final MessageConverter converter; - public MessageSendingReturnValueHandler(MessageChannel outboundChannel, MessageConverter converter) { + public DefaultMessageReturnValueHandler(MessageChannel inboundChannel, MessageChannel outboundChannel, + MessageConverter converter) { + + Assert.notNull(inboundChannel, "inboundChannel is required"); Assert.notNull(outboundChannel, "outboundChannel is required"); Assert.notNull(converter, "converter is required"); + + this.inboundChannel = inboundChannel; this.outboundChannel = outboundChannel; this.converter = converter; } @@ -60,6 +77,7 @@ public class MessageSendingReturnValueHandler implements MessageReturnValueHandl } SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(message); + Message returnMessage = (returnValue instanceof Message) ? (Message) returnValue : null; Object returnPayload = (returnMessage != null) ? returnMessage.getPayload() : returnValue; @@ -68,14 +86,43 @@ public class MessageSendingReturnValueHandler implements MessageReturnValueHandl returnHeaders.setSessionId(inputHeaders.getSessionId()); returnHeaders.setSubscriptionId(inputHeaders.getSubscriptionId()); - if (returnHeaders.getDestination() == null) { - returnHeaders.setDestination(inputHeaders.getDestination()); - } + + String destination = getDestination(message, returnType, inputHeaders, returnHeaders); + returnHeaders.setDestination(destination); returnMessage = this.converter.toMessage(returnPayload); returnMessage = MessageBuilder.fromMessage(returnMessage).copyHeaders(returnHeaders.toMap()).build(); - this.outboundChannel.send(returnMessage); + if (destination.startsWith("/user/")) { + this.inboundChannel.send(returnMessage); + } + else { + this.outboundChannel.send(returnMessage); + } } + protected String getDestination(Message inputMessage, MethodParameter returnType, + SimpMessageHeaderAccessor inputHeaders, SimpMessageHeaderAccessor returnHeaders) { + + ReplyTo annot = returnType.getMethodAnnotation(ReplyTo.class); + + if (returnHeaders.getDestination() != null) { + return returnHeaders.getDestination(); + } + else if (annot != null) { + Principal user = inputHeaders.getUser(); + if (user == null) { + throw new MissingSessionUserException(inputMessage); + } + return "/user/" + user.getName() + annot.value(); + } + else if (inputHeaders.getDestination() != null) { + return inputHeaders.getDestination(); + } + else { + return null; + } + + } + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMessageArgumentResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMessageArgumentResolver.java index 7d1e17db5ef..3c3ed041ee6 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMessageArgumentResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/PrincipalMessageArgumentResolver.java @@ -20,8 +20,8 @@ import java.security.Principal; import org.springframework.core.MethodParameter; import org.springframework.messaging.Message; -import org.springframework.messaging.handler.method.InvalidMessageMethodParameterException; import org.springframework.messaging.handler.method.MessageArgumentResolver; +import org.springframework.messaging.handler.method.MissingSessionUserException; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -43,7 +43,7 @@ public class PrincipalMessageArgumentResolver implements MessageArgumentResolver SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); Principal user = headers.getUser(); if (user == null) { - throw new InvalidMessageMethodParameterException(message, "User not available", parameter); + throw new MissingSessionUserException(message); } return user; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationMethodMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationMethodMessageHandler.java index fa71006fb8d..df8b83a48ca 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationMethodMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/AnnotationMethodMessageHandler.java @@ -43,12 +43,11 @@ import org.springframework.messaging.handler.annotation.support.MessageException import org.springframework.messaging.handler.method.InvocableMessageHandlerMethod; import org.springframework.messaging.handler.method.MessageArgumentResolverComposite; import org.springframework.messaging.handler.method.MessageReturnValueHandlerComposite; -import org.springframework.messaging.simp.MessageHolder; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SubscribeEvent; import org.springframework.messaging.simp.annotation.UnsubscribeEvent; -import org.springframework.messaging.simp.annotation.support.MessageSendingReturnValueHandler; +import org.springframework.messaging.simp.annotation.support.DefaultMessageReturnValueHandler; import org.springframework.messaging.simp.annotation.support.PrincipalMessageArgumentResolver; import org.springframework.messaging.support.converter.MessageConverter; import org.springframework.stereotype.Controller; @@ -67,6 +66,8 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati private static final Log logger = LogFactory.getLog(AnnotationMethodMessageHandler.class); + private final MessageChannel inboundChannel; + private final MessageChannel outboundChannel; private MessageConverter messageConverter; @@ -91,8 +92,10 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati * @param inboundChannel a channel for processing incoming messages from clients * @param outboundChannel a channel for messages going out to clients */ - public AnnotationMethodMessageHandler(MessageChannel outboundChannel) { + public AnnotationMethodMessageHandler(MessageChannel inboundChannel, MessageChannel outboundChannel) { + Assert.notNull(inboundChannel, "inboundChannel is required"); Assert.notNull(outboundChannel, "outboundChannel is required"); + this.inboundChannel = inboundChannel; this.outboundChannel = outboundChannel; } @@ -116,8 +119,8 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati this.argumentResolvers.addResolver(new PrincipalMessageArgumentResolver()); this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverter)); - this.returnValueHandlers.addHandler( - new MessageSendingReturnValueHandler(this.outboundChannel, this.messageConverter)); + this.returnValueHandlers.addHandler(new DefaultMessageReturnValueHandler( + this.inboundChannel, this.outboundChannel, this.messageConverter)); } protected void initHandlerMethods() { @@ -215,16 +218,13 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati invocableHandlerMethod.setMessageMethodArgumentResolvers(this.argumentResolvers); try { - MessageHolder.setMessage(message); - - Object value = invocableHandlerMethod.invoke(message); + Object returnValue = invocableHandlerMethod.invoke(message); MethodParameter returnType = handlerMethod.getReturnType(); if (void.class.equals(returnType.getParameterType())) { return; } - - this.returnValueHandlers.handleReturnValue(value, returnType, message); + this.returnValueHandlers.handleReturnValue(returnValue, returnType, message); } catch (Exception ex) { invokeExceptionHandler(message, handlerMethod, ex); @@ -233,14 +233,11 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati // TODO ex.printStackTrace(); } - finally { - MessageHolder.reset(); - } } private void invokeExceptionHandler(Message message, HandlerMethod handlerMethod, Exception ex) { - InvocableMessageHandlerMethod invocableHandlerMethod; + InvocableMessageHandlerMethod exceptionHandlerMethod; Class beanType = handlerMethod.getBeanType(); MessageExceptionHandlerMethodResolver resolver = this.exceptionHandlerCache.get(beanType); if (resolver == null) { @@ -254,11 +251,17 @@ public class AnnotationMethodMessageHandler implements MessageHandler, Applicati return; } - invocableHandlerMethod = new InvocableMessageHandlerMethod(handlerMethod.getBean(), method); - invocableHandlerMethod.setMessageMethodArgumentResolvers(this.argumentResolvers); + exceptionHandlerMethod = new InvocableMessageHandlerMethod(handlerMethod.getBean(), method); + exceptionHandlerMethod.setMessageMethodArgumentResolvers(this.argumentResolvers); try { - invocableHandlerMethod.invoke(message, ex); + Object returnValue = exceptionHandlerMethod.invoke(message, ex); + + MethodParameter returnType = exceptionHandlerMethod.getReturnType(); + if (void.class.equals(returnType.getParameterType())) { + return; + } + this.returnValueHandlers.handleReturnValue(returnValue, returnType, message); } catch (Throwable t) { logger.error("Error while handling exception", t); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/InMemoryUserSessionResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleUserSessionResolver.java similarity index 95% rename from spring-messaging/src/main/java/org/springframework/messaging/simp/handler/InMemoryUserSessionResolver.java rename to spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleUserSessionResolver.java index c24acea17bf..f1a6fd88415 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/InMemoryUserSessionResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/SimpleUserSessionResolver.java @@ -27,7 +27,7 @@ import java.util.concurrent.CopyOnWriteArraySet; * @author Rossen Stoyanchev * @since 4.0 */ -public class InMemoryUserSessionResolver implements UserSessionResolver, UserSessionStore { +public class SimpleUserSessionResolver implements UserSessionResolver, UserSessionStore { // userId -> sessionId's private final Map> userSessionIds = new ConcurrentHashMap>(); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandler.java index 340e022a9e9..6cc6ce001cf 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/handler/UserDestinationMessageHandler.java @@ -47,7 +47,7 @@ public class UserDestinationMessageHandler implements MessageHandler { private String prefix = "/user/"; - private UserSessionResolver userSessionResolver = new InMemoryUserSessionResolver(); + private UserSessionResolver userSessionResolver = new SimpleUserSessionResolver(); public UserDestinationMessageHandler(MessageSendingOperations messagingTemplate) {