diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/CompositeMessageCondition.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/CompositeMessageCondition.java new file mode 100644 index 0000000000..23c61faaaa --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/CompositeMessageCondition.java @@ -0,0 +1,160 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.handler; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.springframework.messaging.Message; +import org.springframework.util.Assert; + +/** + * Composite {@link MessageCondition} that delegates to other message conditions. + * + *

For {@link #combine} and {@link #compareTo} it is expected that the "other" + * composite contains the same number, type, and order of message conditions. + * + * @author Rossen Stoyanchev + * @since 5.2 + */ +public class CompositeMessageCondition implements MessageCondition { + + private final List> messageConditions; + + + public CompositeMessageCondition(MessageCondition... messageConditions) { + this(Arrays.asList(messageConditions)); + } + + private CompositeMessageCondition(List> messageConditions) { + Assert.notEmpty(messageConditions, "No message conditions"); + this.messageConditions = messageConditions; + } + + + public List> getMessageConditions() { + return this.messageConditions; + } + + @SuppressWarnings("unchecked") + public > T getCondition(Class messageConditionType) { + for (MessageCondition condition : this.messageConditions) { + if (messageConditionType.isAssignableFrom(condition.getClass())) { + return (T) condition; + } + } + throw new IllegalStateException("No condition of type: " + messageConditionType); + } + + + @Override + public CompositeMessageCondition combine(CompositeMessageCondition other) { + checkCompatible(other); + List> result = new ArrayList<>(this.messageConditions.size()); + for (int i = 0; i < this.messageConditions.size(); i++) { + result.add(combine(getMessageConditions().get(i), other.getMessageConditions().get(i))); + } + return new CompositeMessageCondition(result); + } + + @SuppressWarnings("unchecked") + private > T combine(MessageCondition first, MessageCondition second) { + return ((T) first).combine((T) second); + } + + @Override + public CompositeMessageCondition getMatchingCondition(Message message) { + List> result = new ArrayList<>(this.messageConditions.size()); + for (MessageCondition condition : this.messageConditions) { + MessageCondition matchingCondition = (MessageCondition) condition.getMatchingCondition(message); + if (matchingCondition == null) { + return null; + } + result.add(matchingCondition); + } + return new CompositeMessageCondition(result); + } + + @Override + public int compareTo(CompositeMessageCondition other, Message message) { + checkCompatible(other); + List> otherConditions = other.getMessageConditions(); + for (int i = 0; i < this.messageConditions.size(); i++) { + int result = compare (this.messageConditions.get(i), otherConditions.get(i), message); + if (result != 0) { + return result; + } + } + return 0; + } + + @SuppressWarnings("unchecked") + private > int compare( + MessageCondition first, MessageCondition second, Message message) { + + return ((T) first).compareTo((T) second, message); + } + + private void checkCompatible(CompositeMessageCondition other) { + List> others = other.getMessageConditions(); + for (int i = 0; i < this.messageConditions.size(); i++) { + if (i < others.size()) { + if (this.messageConditions.get(i).getClass().equals(others.get(i).getClass())) { + continue; + } + } + throw new IllegalArgumentException("Mismatched CompositeMessageCondition: " + + this.messageConditions + " vs " + others); + } + } + + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof CompositeMessageCondition)) { + return false; + } + CompositeMessageCondition otherComposite = (CompositeMessageCondition) other; + checkCompatible(otherComposite); + List> otherConditions = otherComposite.getMessageConditions(); + for (int i = 0; i < this.messageConditions.size(); i++) { + if (!this.messageConditions.get(i).equals(otherConditions.get(i))) { + return false; + } + } + return true; + } + + @Override + public int hashCode() { + int hashCode = 0; + for (MessageCondition condition : this.messageConditions) { + hashCode += condition.hashCode() * 31; + } + return hashCode; + } + + @Override + public String toString() { + return this.messageConditions.stream().map(Object::toString).collect(Collectors.joining(",", "{", "}")); + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/reactive/MessageMappingMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/reactive/MessageMappingMessageHandler.java new file mode 100644 index 0000000000..4d8d8ba9d6 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/reactive/MessageMappingMessageHandler.java @@ -0,0 +1,246 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.messaging.handler.annotation.support.reactive; + +import java.lang.reflect.AnnotatedElement; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +import org.springframework.context.EmbeddedValueResolverAware; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.core.codec.Decoder; +import org.springframework.lang.Nullable; +import org.springframework.messaging.Message; +import org.springframework.messaging.handler.CompositeMessageCondition; +import org.springframework.messaging.handler.DestinationPatternsMessageCondition; +import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.handler.annotation.support.AnnotationExceptionHandlerMethodResolver; +import org.springframework.messaging.handler.invocation.AbstractExceptionHandlerMethodResolver; +import org.springframework.messaging.handler.invocation.reactive.AbstractEncoderMethodReturnValueHandler; +import org.springframework.messaging.handler.invocation.reactive.AbstractMethodMessageHandler; +import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver; +import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; +import org.springframework.stereotype.Controller; +import org.springframework.util.AntPathMatcher; +import org.springframework.util.Assert; +import org.springframework.util.PathMatcher; +import org.springframework.util.StringValueResolver; +import org.springframework.validation.Validator; + +/** + * Extension of {@link AbstractMethodMessageHandler} for + * {@link MessageMapping @MessageMapping} methods. + * + *

The payload of incoming messages is decoded through + * {@link PayloadMethodArgumentResolver} using one of the configured + * {@link #setDecoders(List)} decoders. + * + *

The {@link #setEncoderReturnValueHandler encoderReturnValueHandler} + * property must be set to encode and handle return values from + * {@code @MessageMapping} methods. + * + * @author Rossen Stoyanchev + * @since 5.2 + */ +public class MessageMappingMessageHandler extends AbstractMethodMessageHandler + implements EmbeddedValueResolverAware { + + private PathMatcher pathMatcher = new AntPathMatcher(); + + private final List> decoders = new ArrayList<>(); + + @Nullable + private Validator validator; + + @Nullable + private HandlerMethodReturnValueHandler encoderReturnValueHandler; + + @Nullable + private StringValueResolver valueResolver; + + + /** + * Set the PathMatcher implementation to use for matching destinations + * against configured destination patterns. + *

By default, {@link AntPathMatcher} is used. + */ + public void setPathMatcher(PathMatcher pathMatcher) { + Assert.notNull(pathMatcher, "PathMatcher must not be null"); + this.pathMatcher = pathMatcher; + } + + /** + * Return the PathMatcher implementation to use for matching destinations. + */ + public PathMatcher getPathMatcher() { + return this.pathMatcher; + } + + /** + * Configure the decoders to user for incoming payloads. + */ + public void setDecoders(List> decoders) { + this.decoders.addAll(decoders); + } + + /** + * Return the configured decoders. + */ + public List> getDecoders() { + return this.decoders; + } + + /** + * Return the configured Validator instance. + */ + @Nullable + public Validator getValidator() { + return this.validator; + } + + /** + * Set the Validator instance used for validating {@code @Payload} arguments. + * @see org.springframework.validation.annotation.Validated + * @see PayloadMethodArgumentResolver + */ + public void setValidator(@Nullable Validator validator) { + this.validator = validator; + } + + /** + * Configure the return value handler that will encode response content. + * Consider extending {@link AbstractEncoderMethodReturnValueHandler} which + * provides the infrastructure to encode and all that's left is to somehow + * handle the encoded content, e.g. by wrapping as a message and passing it + * to something or sending it somewhere. + *

By default this is not configured in which case payload/content return + * values from {@code @MessageMapping} methods will remain unhandled. + * @param encoderReturnValueHandler the return value handler to use + * @see AbstractEncoderMethodReturnValueHandler + */ + public void setEncoderReturnValueHandler(@Nullable HandlerMethodReturnValueHandler encoderReturnValueHandler) { + this.encoderReturnValueHandler = encoderReturnValueHandler; + } + + /** + * Return the configured + * {@link #setEncoderReturnValueHandler encoderReturnValueHandler}. + */ + @Nullable + public HandlerMethodReturnValueHandler getEncoderReturnValueHandler() { + return this.encoderReturnValueHandler; + } + + @Override + public void setEmbeddedValueResolver(StringValueResolver resolver) { + this.valueResolver = resolver; + } + + + @Override + protected List initArgumentResolvers() { + List resolvers = new ArrayList<>(); + + // Custom resolvers + resolvers.addAll(getArgumentResolverConfigurer().getCustomResolvers()); + + // Catch-all + resolvers.add(new PayloadMethodArgumentResolver( + this.decoders, this.validator, getReactiveAdapterRegistry(), true)); + + return resolvers; + } + + @Override + protected List initReturnValueHandlers() { + List handlers = new ArrayList<>(); + handlers.addAll(getReturnValueHandlerConfigurer().getCustomHandlers()); + if (this.encoderReturnValueHandler != null) { + handlers.add(this.encoderReturnValueHandler); + } + return handlers; + } + + + @Override + protected boolean isHandler(Class beanType) { + return AnnotatedElementUtils.hasAnnotation(beanType, Controller.class); + } + + @Override + protected CompositeMessageCondition getMappingForMethod(Method method, Class handlerType) { + CompositeMessageCondition methodCondition = getCondition(method); + if (methodCondition != null) { + CompositeMessageCondition typeCondition = getCondition(handlerType); + if (typeCondition != null) { + return typeCondition.combine(methodCondition); + } + } + return methodCondition; + } + + @Nullable + private CompositeMessageCondition getCondition(AnnotatedElement element) { + MessageMapping annot = AnnotatedElementUtils.findMergedAnnotation(element, MessageMapping.class); + if (annot == null || annot.value().length == 0) { + return null; + } + String[] destinations = annot.value(); + if (this.valueResolver != null) { + destinations = Arrays.stream(annot.value()) + .map(s -> this.valueResolver.resolveStringValue(s)) + .toArray(String[]::new); + } + return new CompositeMessageCondition(new DestinationPatternsMessageCondition(destinations, this.pathMatcher)); + } + + @Override + protected Set getDirectLookupMappings(CompositeMessageCondition mapping) { + Set result = new LinkedHashSet<>(); + for (String pattern : mapping.getCondition(DestinationPatternsMessageCondition.class).getPatterns()) { + if (!this.pathMatcher.isPattern(pattern)) { + result.add(pattern); + } + } + return result; + } + + @Override + protected String getDestination(Message message) { + return (String) message.getHeaders().get(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER); + } + + @Override + protected CompositeMessageCondition getMatchingMapping(CompositeMessageCondition mapping, Message message) { + return mapping.getMatchingCondition(message); + } + + @Override + protected Comparator getMappingComparator(Message message) { + return (info1, info2) -> info1.compareTo(info2, message); + } + + @Override + protected AbstractExceptionHandlerMethodResolver createExceptionMethodResolverFor(Class beanType) { + return new AnnotationExceptionHandlerMethodResolver(beanType); + } + +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageMappingInfo.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageMappingInfo.java index 317c24919b..e6d6b55994 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageMappingInfo.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageMappingInfo.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.messaging.simp; import org.springframework.lang.Nullable; import org.springframework.messaging.Message; +import org.springframework.messaging.handler.CompositeMessageCondition; import org.springframework.messaging.handler.DestinationPatternsMessageCondition; import org.springframework.messaging.handler.MessageCondition; @@ -34,62 +35,44 @@ import org.springframework.messaging.handler.MessageCondition; */ public class SimpMessageMappingInfo implements MessageCondition { - private final SimpMessageTypeMessageCondition messageTypeMessageCondition; - - private final DestinationPatternsMessageCondition destinationConditions; + private final CompositeMessageCondition delegate; public SimpMessageMappingInfo(SimpMessageTypeMessageCondition messageTypeMessageCondition, DestinationPatternsMessageCondition destinationConditions) { - this.messageTypeMessageCondition = messageTypeMessageCondition; - this.destinationConditions = destinationConditions; + this.delegate = new CompositeMessageCondition(messageTypeMessageCondition, destinationConditions); + } + + private SimpMessageMappingInfo(CompositeMessageCondition delegate) { + this.delegate = delegate; } public SimpMessageTypeMessageCondition getMessageTypeMessageCondition() { - return this.messageTypeMessageCondition; + return this.delegate.getCondition(SimpMessageTypeMessageCondition.class); } public DestinationPatternsMessageCondition getDestinationConditions() { - return this.destinationConditions; + return this.delegate.getCondition(DestinationPatternsMessageCondition.class); } @Override public SimpMessageMappingInfo combine(SimpMessageMappingInfo other) { - SimpMessageTypeMessageCondition typeCond = - this.getMessageTypeMessageCondition().combine(other.getMessageTypeMessageCondition()); - DestinationPatternsMessageCondition destCond = - this.destinationConditions.combine(other.getDestinationConditions()); - return new SimpMessageMappingInfo(typeCond, destCond); + return new SimpMessageMappingInfo(this.delegate.combine(other.delegate)); } @Override @Nullable public SimpMessageMappingInfo getMatchingCondition(Message message) { - SimpMessageTypeMessageCondition typeCond = this.messageTypeMessageCondition.getMatchingCondition(message); - if (typeCond == null) { - return null; - } - DestinationPatternsMessageCondition destCond = this.destinationConditions.getMatchingCondition(message); - if (destCond == null) { - return null; - } - return new SimpMessageMappingInfo(typeCond, destCond); + CompositeMessageCondition condition = this.delegate.getMatchingCondition(message); + return condition != null ? new SimpMessageMappingInfo(condition) : null; } @Override public int compareTo(SimpMessageMappingInfo other, Message message) { - int result = this.messageTypeMessageCondition.compareTo(other.messageTypeMessageCondition, message); - if (result != 0) { - return result; - } - result = this.destinationConditions.compareTo(other.destinationConditions, message); - if (result != 0) { - return result; - } - return 0; + return this.delegate.compareTo(other.delegate, message); } @@ -101,19 +84,17 @@ public class SimpMessageMappingInfo implements MessageCondition> decoders = Collections.singletonList(StringDecoder.allMimeTypes()); + List> encoders = Collections.singletonList(CharSequenceEncoder.allMimeTypes()); + + ReactiveAdapterRegistry registry = ReactiveAdapterRegistry.getSharedInstance(); + this.returnValueHandler = new TestEncoderReturnValueHandler(encoders, registry); + + PropertySource source = new MapPropertySource("test", Collections.singletonMap("path", "path123")); + + StaticApplicationContext context = new StaticApplicationContext(); + context.getEnvironment().getPropertySources().addFirst(source); + context.registerSingleton("testController", TestController.class); + context.refresh(); + + MessageMappingMessageHandler messageHandler = new MessageMappingMessageHandler(); + messageHandler.setApplicationContext(context); + messageHandler.setEmbeddedValueResolver(new EmbeddedValueResolver(context.getBeanFactory())); + messageHandler.setDecoders(decoders); + messageHandler.setEncoderReturnValueHandler(this.returnValueHandler); + messageHandler.afterPropertiesSet(); + + return messageHandler; + } + + private Message message(String destination, String... content) { + return new GenericMessage<>( + Flux.fromIterable(Arrays.stream(content).map(this::toDataBuffer).collect(Collectors.toList())), + Collections.singletonMap(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, destination)); + } + + private DataBuffer toDataBuffer(String payload) { + return bufferFactory.wrap(payload.getBytes(UTF_8)); + } + + private void verifyOutputContent(List expected) { + List buffers = this.returnValueHandler.getOutputContent(); + assertNotNull("No output: no matching handler method?", buffers); + List actual = buffers.stream().map(buffer -> dumpString(buffer, UTF_8)).collect(Collectors.toList()); + assertEquals(expected, actual); + } + + + @Controller + static class TestController { + + @MessageMapping("/string") + String handleString(String payload) { + return payload + "::response"; + } + + @MessageMapping("/monoString") + Mono handleMonoString(Mono payload) { + return payload.map(s -> s + "::response").delayElement(Duration.ofMillis(10)); + } + + @MessageMapping("/fluxString") + Flux handleFluxString(Flux payload) { + return payload.map(s -> s + "::response").delayElements(Duration.ofMillis(10)); + } + + @MessageMapping("/${path}") + String handleWithPlaceholder(String payload) { + return payload + "::response"; + } + + @MessageMapping("/exception") + String handleAndThrow() { + throw new IllegalArgumentException("rejected"); + } + + @MessageMapping("/errorSignal") + Mono handleAndSignalError() { + return Mono.delay(Duration.ofMillis(10)) + .flatMap(aLong -> Mono.error(new IllegalArgumentException("rejected"))); + } + + @MessageExceptionHandler + Mono handleException(IllegalArgumentException ex) { + return Mono.delay(Duration.ofMillis(10)).map(aLong -> ex.getMessage() + "::handled"); + } + } + + + private static class TestEncoderReturnValueHandler extends AbstractEncoderMethodReturnValueHandler { + + @Nullable + private volatile List outputContent; + + + TestEncoderReturnValueHandler(List> encoders, ReactiveAdapterRegistry registry) { + super(encoders, registry); + } + + + @Nullable + public List getOutputContent() { + return this.outputContent; + } + + @Override + protected Mono handleEncodedContent( + Flux encodedContent, MethodParameter returnType, Message message) { + + return encodedContent.collectList().doOnNext(buffers -> this.outputContent = buffers).then(); + } + } + +}