diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java index 2fbc393190..d97657e389 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.messaging.simp.annotation.support; import java.lang.annotation.Annotation; import java.security.Principal; +import java.util.Collections; import java.util.Map; import org.springframework.core.MethodParameter; @@ -153,11 +154,10 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH MessageHeaders headers = message.getHeaders(); String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); - PlaceholderResolver varResolver = initVarResolver(headers); - Object annotation = findAnnotation(returnType); + DestinationHelper destinationHelper = getDestinationHelper(headers, returnType); - if (annotation instanceof SendToUser) { - SendToUser sendToUser = (SendToUser) annotation; + SendToUser sendToUser = destinationHelper.getSendToUser(); + if (sendToUser != null) { boolean broadcast = sendToUser.broadcast(); String user = getUserName(message, headers); if (user == null) { @@ -169,7 +169,7 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH } String[] destinations = getTargetDestinations(sendToUser, message, this.defaultUserDestinationPrefix); for (String destination : destinations) { - destination = this.placeholderHelper.replacePlaceholders(destination, varResolver); + destination = destinationHelper.expandTemplateVars(destination); if (broadcast) { this.messagingTemplate.convertAndSendToUser( user, destination, returnValue, createHeaders(null, returnType)); @@ -180,51 +180,33 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH } } } - else { - SendTo sendTo = (SendTo) annotation; // possibly null + + SendTo sendTo = destinationHelper.getSendTo(); + if (sendTo != null || sendToUser == null) { String[] destinations = getTargetDestinations(sendTo, message, this.defaultDestinationPrefix); for (String destination : destinations) { - destination = this.placeholderHelper.replacePlaceholders(destination, varResolver); + destination = destinationHelper.expandTemplateVars(destination); this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, returnType)); } } } - @Nullable - private Object findAnnotation(MethodParameter returnType) { - Annotation[] anns = new Annotation[4]; - anns[0] = AnnotatedElementUtils.findMergedAnnotation(returnType.getExecutable(), SendToUser.class); - anns[1] = AnnotatedElementUtils.findMergedAnnotation(returnType.getExecutable(), SendTo.class); - anns[2] = AnnotatedElementUtils.findMergedAnnotation(returnType.getDeclaringClass(), SendToUser.class); - anns[3] = AnnotatedElementUtils.findMergedAnnotation(returnType.getDeclaringClass(), SendTo.class); + private DestinationHelper getDestinationHelper(MessageHeaders headers, MethodParameter returnType) { - if (anns[0] != null && !ObjectUtils.isEmpty(((SendToUser) anns[0]).value())) { - return anns[0]; - } - if (anns[1] != null && !ObjectUtils.isEmpty(((SendTo) anns[1]).value())) { - return anns[1]; - } - if (anns[2] != null && !ObjectUtils.isEmpty(((SendToUser) anns[2]).value())) { - return anns[2]; - } - if (anns[3] != null && !ObjectUtils.isEmpty(((SendTo) anns[3]).value())) { - return anns[3]; + SendToUser m1 = AnnotatedElementUtils.findMergedAnnotation(returnType.getExecutable(), SendToUser.class); + SendTo m2 = AnnotatedElementUtils.findMergedAnnotation(returnType.getExecutable(), SendTo.class); + if ((m1 != null && !ObjectUtils.isEmpty(m1.value())) || (m2 != null && !ObjectUtils.isEmpty(m2.value()))) { + return new DestinationHelper(headers, m1, m2); } - for (int i=0; i < 4; i++) { - if (anns[i] != null) { - return anns[i]; - } + SendToUser c1 = AnnotatedElementUtils.findMergedAnnotation(returnType.getDeclaringClass(), SendToUser.class); + SendTo c2 = AnnotatedElementUtils.findMergedAnnotation(returnType.getDeclaringClass(), SendTo.class); + if ((c1 != null && !ObjectUtils.isEmpty(c1.value())) || (c2 != null && !ObjectUtils.isEmpty(c2.value()))) { + return new DestinationHelper(headers, c1, c2); } - return null; - } - - @SuppressWarnings("unchecked") - private PlaceholderResolver initVarResolver(MessageHeaders headers) { - String name = DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER; - Map vars = (Map) headers.get(name); - return new DestinationVariablePlaceholderResolver(vars); + return m1 != null || m2 != null ? + new DestinationHelper(headers, m1, m2) : new DestinationHelper(headers, c1, c2); } @Nullable @@ -275,20 +257,43 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH } - private static class DestinationVariablePlaceholderResolver implements PlaceholderResolver { + private class DestinationHelper { + + private final PlaceholderResolver placeholderResolver; @Nullable - private final Map vars; + private final SendTo sendTo; - public DestinationVariablePlaceholderResolver(@Nullable Map vars) { - this.vars = vars; + @Nullable + private final SendToUser sendToUser; + + + public DestinationHelper(MessageHeaders headers, @Nullable SendToUser sendToUser, @Nullable SendTo sendTo) { + Map variables = getTemplateVariables(headers); + this.placeholderResolver = variables::get; + this.sendTo = sendTo; + this.sendToUser = sendToUser; + } + + @SuppressWarnings("unchecked") + private Map getTemplateVariables(MessageHeaders headers) { + String name = DestinationVariableMethodArgumentResolver.DESTINATION_TEMPLATE_VARIABLES_HEADER; + return (Map) headers.getOrDefault(name, Collections.emptyMap()); } - @Override @Nullable - public String resolvePlaceholder(String placeholderName) { - return (this.vars != null ? this.vars.get(placeholderName) : null); + public SendTo getSendTo() { + return this.sendTo; + } + + @Nullable + public SendToUser getSendToUser() { + return this.sendToUser; + } + + + public String expandTemplateVars(String destination) { + return placeholderHelper.replacePlaceholders(destination, this.placeholderResolver); } } - -} +} \ No newline at end of file diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java index abaa7a0e2b..98a4e756b9 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java @@ -87,6 +87,7 @@ public class SendToMethodReturnValueHandlerTests { private MethodParameter sendToWithPlaceholdersReturnType = param("handleAndSendToWithPlaceholders"); private MethodParameter sendToUserReturnType = param("handleAndSendToUser"); private MethodParameter sendToUserInSessionReturnType = param("handleAndSendToUserInSession"); + private MethodParameter sendToSendToUserReturnType = param("handleAndSendToAndSendToUser"); private MethodParameter sendToUserDefaultDestReturnType = param("handleAndSendToUserDefaultDest"); private MethodParameter sendToUserInSessionDefaultDestReturnType = param("handleAndSendToUserDefaultDestInSession"); private MethodParameter jsonViewReturnType = param("handleAndSendToJsonView"); @@ -355,6 +356,38 @@ public class SendToMethodReturnValueHandlerTests { assertEquals("/user/" + user.getName() + "/dest2", accessor.getDestination()); } + @Test + public void sendToAndSendToUser() throws Exception { + given(this.messageChannel.send(any(Message.class))).willReturn(true); + + String sessionId = "sess1"; + TestUser user = new TestUser(); + Message inputMessage = createMessage(sessionId, "sub1", null, null, user); + this.handler.handleReturnValue(PAYLOAD, this.sendToSendToUserReturnType, inputMessage); + + verify(this.messageChannel, times(4)).send(this.messageCaptor.capture()); + + SimpMessageHeaderAccessor accessor = getCapturedAccessor(0); + assertNull(accessor.getSessionId()); + assertNull(accessor.getSubscriptionId()); + assertEquals("/user/" + user.getName() + "/dest1", accessor.getDestination()); + + accessor = getCapturedAccessor(1); + assertNull(accessor.getSessionId()); + assertNull(accessor.getSubscriptionId()); + assertEquals("/user/" + user.getName() + "/dest2", accessor.getDestination()); + + accessor = getCapturedAccessor(2); + assertEquals("sess1", accessor.getSessionId()); + assertNull(accessor.getSubscriptionId()); + assertEquals("/dest1", accessor.getDestination()); + + accessor = getCapturedAccessor(3); + assertEquals("sess1", accessor.getSessionId()); + assertNull(accessor.getSubscriptionId()); + assertEquals("/dest2", accessor.getDestination()); + } + @Test // SPR-12170 public void sendToWithDestinationPlaceholders() throws Exception { given(this.messageChannel.send(any(Message.class))).willReturn(true); @@ -577,6 +610,13 @@ public class SendToMethodReturnValueHandlerTests { return PAYLOAD; } + @SendTo({"/dest1", "/dest2"}) + @SendToUser({"/dest1", "/dest2"}) + @SuppressWarnings("unused") + String handleAndSendToAndSendToUser() { + return PAYLOAD; + } + @JsonView(MyJacksonView1.class) @SuppressWarnings("unused") JacksonViewBean handleAndSendToJsonView() {