diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/CompositeMessageCondition.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/CompositeMessageCondition.java
new file mode 100644
index 0000000000..23c61faaaa
--- /dev/null
+++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/CompositeMessageCondition.java
@@ -0,0 +1,160 @@
+/*
+ * Copyright 2002-2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.messaging.handler;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.springframework.messaging.Message;
+import org.springframework.util.Assert;
+
+/**
+ * Composite {@link MessageCondition} that delegates to other message conditions.
+ *
+ *
For {@link #combine} and {@link #compareTo} it is expected that the "other"
+ * composite contains the same number, type, and order of message conditions.
+ *
+ * @author Rossen Stoyanchev
+ * @since 5.2
+ */
+public class CompositeMessageCondition implements MessageCondition {
+
+ private final List> messageConditions;
+
+
+ public CompositeMessageCondition(MessageCondition>... messageConditions) {
+ this(Arrays.asList(messageConditions));
+ }
+
+ private CompositeMessageCondition(List> messageConditions) {
+ Assert.notEmpty(messageConditions, "No message conditions");
+ this.messageConditions = messageConditions;
+ }
+
+
+ public List> getMessageConditions() {
+ return this.messageConditions;
+ }
+
+ @SuppressWarnings("unchecked")
+ public > T getCondition(Class messageConditionType) {
+ for (MessageCondition> condition : this.messageConditions) {
+ if (messageConditionType.isAssignableFrom(condition.getClass())) {
+ return (T) condition;
+ }
+ }
+ throw new IllegalStateException("No condition of type: " + messageConditionType);
+ }
+
+
+ @Override
+ public CompositeMessageCondition combine(CompositeMessageCondition other) {
+ checkCompatible(other);
+ List> result = new ArrayList<>(this.messageConditions.size());
+ for (int i = 0; i < this.messageConditions.size(); i++) {
+ result.add(combine(getMessageConditions().get(i), other.getMessageConditions().get(i)));
+ }
+ return new CompositeMessageCondition(result);
+ }
+
+ @SuppressWarnings("unchecked")
+ private > T combine(MessageCondition> first, MessageCondition> second) {
+ return ((T) first).combine((T) second);
+ }
+
+ @Override
+ public CompositeMessageCondition getMatchingCondition(Message> message) {
+ List> result = new ArrayList<>(this.messageConditions.size());
+ for (MessageCondition> condition : this.messageConditions) {
+ MessageCondition> matchingCondition = (MessageCondition>) condition.getMatchingCondition(message);
+ if (matchingCondition == null) {
+ return null;
+ }
+ result.add(matchingCondition);
+ }
+ return new CompositeMessageCondition(result);
+ }
+
+ @Override
+ public int compareTo(CompositeMessageCondition other, Message> message) {
+ checkCompatible(other);
+ List> otherConditions = other.getMessageConditions();
+ for (int i = 0; i < this.messageConditions.size(); i++) {
+ int result = compare (this.messageConditions.get(i), otherConditions.get(i), message);
+ if (result != 0) {
+ return result;
+ }
+ }
+ return 0;
+ }
+
+ @SuppressWarnings("unchecked")
+ private > int compare(
+ MessageCondition> first, MessageCondition> second, Message> message) {
+
+ return ((T) first).compareTo((T) second, message);
+ }
+
+ private void checkCompatible(CompositeMessageCondition other) {
+ List> others = other.getMessageConditions();
+ for (int i = 0; i < this.messageConditions.size(); i++) {
+ if (i < others.size()) {
+ if (this.messageConditions.get(i).getClass().equals(others.get(i).getClass())) {
+ continue;
+ }
+ }
+ throw new IllegalArgumentException("Mismatched CompositeMessageCondition: " +
+ this.messageConditions + " vs " + others);
+ }
+ }
+
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) {
+ return true;
+ }
+ if (!(other instanceof CompositeMessageCondition)) {
+ return false;
+ }
+ CompositeMessageCondition otherComposite = (CompositeMessageCondition) other;
+ checkCompatible(otherComposite);
+ List> otherConditions = otherComposite.getMessageConditions();
+ for (int i = 0; i < this.messageConditions.size(); i++) {
+ if (!this.messageConditions.get(i).equals(otherConditions.get(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int hashCode = 0;
+ for (MessageCondition> condition : this.messageConditions) {
+ hashCode += condition.hashCode() * 31;
+ }
+ return hashCode;
+ }
+
+ @Override
+ public String toString() {
+ return this.messageConditions.stream().map(Object::toString).collect(Collectors.joining(",", "{", "}"));
+ }
+
+}
diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/reactive/MessageMappingMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/reactive/MessageMappingMessageHandler.java
new file mode 100644
index 0000000000..4d8d8ba9d6
--- /dev/null
+++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/reactive/MessageMappingMessageHandler.java
@@ -0,0 +1,246 @@
+/*
+ * Copyright 2002-2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.messaging.handler.annotation.support.reactive;
+
+import java.lang.reflect.AnnotatedElement;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Set;
+
+import org.springframework.context.EmbeddedValueResolverAware;
+import org.springframework.core.annotation.AnnotatedElementUtils;
+import org.springframework.core.codec.Decoder;
+import org.springframework.lang.Nullable;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.handler.CompositeMessageCondition;
+import org.springframework.messaging.handler.DestinationPatternsMessageCondition;
+import org.springframework.messaging.handler.annotation.MessageMapping;
+import org.springframework.messaging.handler.annotation.support.AnnotationExceptionHandlerMethodResolver;
+import org.springframework.messaging.handler.invocation.AbstractExceptionHandlerMethodResolver;
+import org.springframework.messaging.handler.invocation.reactive.AbstractEncoderMethodReturnValueHandler;
+import org.springframework.messaging.handler.invocation.reactive.AbstractMethodMessageHandler;
+import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver;
+import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler;
+import org.springframework.stereotype.Controller;
+import org.springframework.util.AntPathMatcher;
+import org.springframework.util.Assert;
+import org.springframework.util.PathMatcher;
+import org.springframework.util.StringValueResolver;
+import org.springframework.validation.Validator;
+
+/**
+ * Extension of {@link AbstractMethodMessageHandler} for
+ * {@link MessageMapping @MessageMapping} methods.
+ *
+ * The payload of incoming messages is decoded through
+ * {@link PayloadMethodArgumentResolver} using one of the configured
+ * {@link #setDecoders(List)} decoders.
+ *
+ *
The {@link #setEncoderReturnValueHandler encoderReturnValueHandler}
+ * property must be set to encode and handle return values from
+ * {@code @MessageMapping} methods.
+ *
+ * @author Rossen Stoyanchev
+ * @since 5.2
+ */
+public class MessageMappingMessageHandler extends AbstractMethodMessageHandler
+ implements EmbeddedValueResolverAware {
+
+ private PathMatcher pathMatcher = new AntPathMatcher();
+
+ private final List> decoders = new ArrayList<>();
+
+ @Nullable
+ private Validator validator;
+
+ @Nullable
+ private HandlerMethodReturnValueHandler encoderReturnValueHandler;
+
+ @Nullable
+ private StringValueResolver valueResolver;
+
+
+ /**
+ * Set the PathMatcher implementation to use for matching destinations
+ * against configured destination patterns.
+ * By default, {@link AntPathMatcher} is used.
+ */
+ public void setPathMatcher(PathMatcher pathMatcher) {
+ Assert.notNull(pathMatcher, "PathMatcher must not be null");
+ this.pathMatcher = pathMatcher;
+ }
+
+ /**
+ * Return the PathMatcher implementation to use for matching destinations.
+ */
+ public PathMatcher getPathMatcher() {
+ return this.pathMatcher;
+ }
+
+ /**
+ * Configure the decoders to user for incoming payloads.
+ */
+ public void setDecoders(List extends Decoder>> decoders) {
+ this.decoders.addAll(decoders);
+ }
+
+ /**
+ * Return the configured decoders.
+ */
+ public List extends Decoder>> getDecoders() {
+ return this.decoders;
+ }
+
+ /**
+ * Return the configured Validator instance.
+ */
+ @Nullable
+ public Validator getValidator() {
+ return this.validator;
+ }
+
+ /**
+ * Set the Validator instance used for validating {@code @Payload} arguments.
+ * @see org.springframework.validation.annotation.Validated
+ * @see PayloadMethodArgumentResolver
+ */
+ public void setValidator(@Nullable Validator validator) {
+ this.validator = validator;
+ }
+
+ /**
+ * Configure the return value handler that will encode response content.
+ * Consider extending {@link AbstractEncoderMethodReturnValueHandler} which
+ * provides the infrastructure to encode and all that's left is to somehow
+ * handle the encoded content, e.g. by wrapping as a message and passing it
+ * to something or sending it somewhere.
+ *
By default this is not configured in which case payload/content return
+ * values from {@code @MessageMapping} methods will remain unhandled.
+ * @param encoderReturnValueHandler the return value handler to use
+ * @see AbstractEncoderMethodReturnValueHandler
+ */
+ public void setEncoderReturnValueHandler(@Nullable HandlerMethodReturnValueHandler encoderReturnValueHandler) {
+ this.encoderReturnValueHandler = encoderReturnValueHandler;
+ }
+
+ /**
+ * Return the configured
+ * {@link #setEncoderReturnValueHandler encoderReturnValueHandler}.
+ */
+ @Nullable
+ public HandlerMethodReturnValueHandler getEncoderReturnValueHandler() {
+ return this.encoderReturnValueHandler;
+ }
+
+ @Override
+ public void setEmbeddedValueResolver(StringValueResolver resolver) {
+ this.valueResolver = resolver;
+ }
+
+
+ @Override
+ protected List extends HandlerMethodArgumentResolver> initArgumentResolvers() {
+ List resolvers = new ArrayList<>();
+
+ // Custom resolvers
+ resolvers.addAll(getArgumentResolverConfigurer().getCustomResolvers());
+
+ // Catch-all
+ resolvers.add(new PayloadMethodArgumentResolver(
+ this.decoders, this.validator, getReactiveAdapterRegistry(), true));
+
+ return resolvers;
+ }
+
+ @Override
+ protected List extends HandlerMethodReturnValueHandler> initReturnValueHandlers() {
+ List handlers = new ArrayList<>();
+ handlers.addAll(getReturnValueHandlerConfigurer().getCustomHandlers());
+ if (this.encoderReturnValueHandler != null) {
+ handlers.add(this.encoderReturnValueHandler);
+ }
+ return handlers;
+ }
+
+
+ @Override
+ protected boolean isHandler(Class> beanType) {
+ return AnnotatedElementUtils.hasAnnotation(beanType, Controller.class);
+ }
+
+ @Override
+ protected CompositeMessageCondition getMappingForMethod(Method method, Class> handlerType) {
+ CompositeMessageCondition methodCondition = getCondition(method);
+ if (methodCondition != null) {
+ CompositeMessageCondition typeCondition = getCondition(handlerType);
+ if (typeCondition != null) {
+ return typeCondition.combine(methodCondition);
+ }
+ }
+ return methodCondition;
+ }
+
+ @Nullable
+ private CompositeMessageCondition getCondition(AnnotatedElement element) {
+ MessageMapping annot = AnnotatedElementUtils.findMergedAnnotation(element, MessageMapping.class);
+ if (annot == null || annot.value().length == 0) {
+ return null;
+ }
+ String[] destinations = annot.value();
+ if (this.valueResolver != null) {
+ destinations = Arrays.stream(annot.value())
+ .map(s -> this.valueResolver.resolveStringValue(s))
+ .toArray(String[]::new);
+ }
+ return new CompositeMessageCondition(new DestinationPatternsMessageCondition(destinations, this.pathMatcher));
+ }
+
+ @Override
+ protected Set getDirectLookupMappings(CompositeMessageCondition mapping) {
+ Set result = new LinkedHashSet<>();
+ for (String pattern : mapping.getCondition(DestinationPatternsMessageCondition.class).getPatterns()) {
+ if (!this.pathMatcher.isPattern(pattern)) {
+ result.add(pattern);
+ }
+ }
+ return result;
+ }
+
+ @Override
+ protected String getDestination(Message> message) {
+ return (String) message.getHeaders().get(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER);
+ }
+
+ @Override
+ protected CompositeMessageCondition getMatchingMapping(CompositeMessageCondition mapping, Message> message) {
+ return mapping.getMatchingCondition(message);
+ }
+
+ @Override
+ protected Comparator getMappingComparator(Message> message) {
+ return (info1, info2) -> info1.compareTo(info2, message);
+ }
+
+ @Override
+ protected AbstractExceptionHandlerMethodResolver createExceptionMethodResolverFor(Class> beanType) {
+ return new AnnotationExceptionHandlerMethodResolver(beanType);
+ }
+
+}
diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageMappingInfo.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageMappingInfo.java
index 317c24919b..e6d6b55994 100644
--- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageMappingInfo.java
+++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageMappingInfo.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ package org.springframework.messaging.simp;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
+import org.springframework.messaging.handler.CompositeMessageCondition;
import org.springframework.messaging.handler.DestinationPatternsMessageCondition;
import org.springframework.messaging.handler.MessageCondition;
@@ -34,62 +35,44 @@ import org.springframework.messaging.handler.MessageCondition;
*/
public class SimpMessageMappingInfo implements MessageCondition {
- private final SimpMessageTypeMessageCondition messageTypeMessageCondition;
-
- private final DestinationPatternsMessageCondition destinationConditions;
+ private final CompositeMessageCondition delegate;
public SimpMessageMappingInfo(SimpMessageTypeMessageCondition messageTypeMessageCondition,
DestinationPatternsMessageCondition destinationConditions) {
- this.messageTypeMessageCondition = messageTypeMessageCondition;
- this.destinationConditions = destinationConditions;
+ this.delegate = new CompositeMessageCondition(messageTypeMessageCondition, destinationConditions);
+ }
+
+ private SimpMessageMappingInfo(CompositeMessageCondition delegate) {
+ this.delegate = delegate;
}
public SimpMessageTypeMessageCondition getMessageTypeMessageCondition() {
- return this.messageTypeMessageCondition;
+ return this.delegate.getCondition(SimpMessageTypeMessageCondition.class);
}
public DestinationPatternsMessageCondition getDestinationConditions() {
- return this.destinationConditions;
+ return this.delegate.getCondition(DestinationPatternsMessageCondition.class);
}
@Override
public SimpMessageMappingInfo combine(SimpMessageMappingInfo other) {
- SimpMessageTypeMessageCondition typeCond =
- this.getMessageTypeMessageCondition().combine(other.getMessageTypeMessageCondition());
- DestinationPatternsMessageCondition destCond =
- this.destinationConditions.combine(other.getDestinationConditions());
- return new SimpMessageMappingInfo(typeCond, destCond);
+ return new SimpMessageMappingInfo(this.delegate.combine(other.delegate));
}
@Override
@Nullable
public SimpMessageMappingInfo getMatchingCondition(Message> message) {
- SimpMessageTypeMessageCondition typeCond = this.messageTypeMessageCondition.getMatchingCondition(message);
- if (typeCond == null) {
- return null;
- }
- DestinationPatternsMessageCondition destCond = this.destinationConditions.getMatchingCondition(message);
- if (destCond == null) {
- return null;
- }
- return new SimpMessageMappingInfo(typeCond, destCond);
+ CompositeMessageCondition condition = this.delegate.getMatchingCondition(message);
+ return condition != null ? new SimpMessageMappingInfo(condition) : null;
}
@Override
public int compareTo(SimpMessageMappingInfo other, Message> message) {
- int result = this.messageTypeMessageCondition.compareTo(other.messageTypeMessageCondition, message);
- if (result != 0) {
- return result;
- }
- result = this.destinationConditions.compareTo(other.destinationConditions, message);
- if (result != 0) {
- return result;
- }
- return 0;
+ return this.delegate.compareTo(other.delegate, message);
}
@@ -101,19 +84,17 @@ public class SimpMessageMappingInfo implements MessageCondition> decoders = Collections.singletonList(StringDecoder.allMimeTypes());
+ List> encoders = Collections.singletonList(CharSequenceEncoder.allMimeTypes());
+
+ ReactiveAdapterRegistry registry = ReactiveAdapterRegistry.getSharedInstance();
+ this.returnValueHandler = new TestEncoderReturnValueHandler(encoders, registry);
+
+ PropertySource> source = new MapPropertySource("test", Collections.singletonMap("path", "path123"));
+
+ StaticApplicationContext context = new StaticApplicationContext();
+ context.getEnvironment().getPropertySources().addFirst(source);
+ context.registerSingleton("testController", TestController.class);
+ context.refresh();
+
+ MessageMappingMessageHandler messageHandler = new MessageMappingMessageHandler();
+ messageHandler.setApplicationContext(context);
+ messageHandler.setEmbeddedValueResolver(new EmbeddedValueResolver(context.getBeanFactory()));
+ messageHandler.setDecoders(decoders);
+ messageHandler.setEncoderReturnValueHandler(this.returnValueHandler);
+ messageHandler.afterPropertiesSet();
+
+ return messageHandler;
+ }
+
+ private Message> message(String destination, String... content) {
+ return new GenericMessage<>(
+ Flux.fromIterable(Arrays.stream(content).map(this::toDataBuffer).collect(Collectors.toList())),
+ Collections.singletonMap(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, destination));
+ }
+
+ private DataBuffer toDataBuffer(String payload) {
+ return bufferFactory.wrap(payload.getBytes(UTF_8));
+ }
+
+ private void verifyOutputContent(List expected) {
+ List buffers = this.returnValueHandler.getOutputContent();
+ assertNotNull("No output: no matching handler method?", buffers);
+ List actual = buffers.stream().map(buffer -> dumpString(buffer, UTF_8)).collect(Collectors.toList());
+ assertEquals(expected, actual);
+ }
+
+
+ @Controller
+ static class TestController {
+
+ @MessageMapping("/string")
+ String handleString(String payload) {
+ return payload + "::response";
+ }
+
+ @MessageMapping("/monoString")
+ Mono handleMonoString(Mono payload) {
+ return payload.map(s -> s + "::response").delayElement(Duration.ofMillis(10));
+ }
+
+ @MessageMapping("/fluxString")
+ Flux handleFluxString(Flux payload) {
+ return payload.map(s -> s + "::response").delayElements(Duration.ofMillis(10));
+ }
+
+ @MessageMapping("/${path}")
+ String handleWithPlaceholder(String payload) {
+ return payload + "::response";
+ }
+
+ @MessageMapping("/exception")
+ String handleAndThrow() {
+ throw new IllegalArgumentException("rejected");
+ }
+
+ @MessageMapping("/errorSignal")
+ Mono handleAndSignalError() {
+ return Mono.delay(Duration.ofMillis(10))
+ .flatMap(aLong -> Mono.error(new IllegalArgumentException("rejected")));
+ }
+
+ @MessageExceptionHandler
+ Mono handleException(IllegalArgumentException ex) {
+ return Mono.delay(Duration.ofMillis(10)).map(aLong -> ex.getMessage() + "::handled");
+ }
+ }
+
+
+ private static class TestEncoderReturnValueHandler extends AbstractEncoderMethodReturnValueHandler {
+
+ @Nullable
+ private volatile List outputContent;
+
+
+ TestEncoderReturnValueHandler(List> encoders, ReactiveAdapterRegistry registry) {
+ super(encoders, registry);
+ }
+
+
+ @Nullable
+ public List getOutputContent() {
+ return this.outputContent;
+ }
+
+ @Override
+ protected Mono handleEncodedContent(
+ Flux encodedContent, MethodParameter returnType, Message> message) {
+
+ return encodedContent.collectList().doOnNext(buffers -> this.outputContent = buffers).then();
+ }
+ }
+
+}