From 4e1c0c682600a093fc76b14d26bcf3707a937229 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 21 Feb 2019 18:08:30 -0500 Subject: [PATCH] @MessageExceptionHandler supports error signal Before this change if a controller method returned a Publisher whose first signal was an error, the error signal would not be delegated to a @MessageExceptionHandler method as expected. To make this work for now we use a package private local copy of the ChannelSendOperator from spring-web. See gh-21987 --- ...stractEncoderMethodReturnValueHandler.java | 3 +- .../reactive/ChannelSendOperator.java | 410 ++++++++++++++++++ .../reactive/MethodMessageHandlerTests.java | 15 +- .../reactive/TestReturnValueHandler.java | 8 + ...RSocketClientToServerIntegrationTests.java | 36 +- 5 files changed, 463 insertions(+), 9 deletions(-) create mode 100644 spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/ChannelSendOperator.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java index aa5916a4147..5c185ce45fe 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java @@ -112,7 +112,8 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler Flux encodedContent = encodeContent( returnValue, returnType, bufferFactory, mimeType, Collections.emptyMap()); - return handleEncodedContent(encodedContent, returnType, message); + return new ChannelSendOperator<>(encodedContent, publisher -> + handleEncodedContent(Flux.from(publisher), returnType, message)); } @SuppressWarnings("unchecked") diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/ChannelSendOperator.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/ChannelSendOperator.java new file mode 100644 index 00000000000..b89f5f96fe7 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/ChannelSendOperator.java @@ -0,0 +1,410 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.handler.invocation.reactive; + +import java.util.function.Function; + +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * ---------------------- + *

NOTE: This class was copied from + * {@code org.springframework.http.server.reactive.ChannelSendOperator} and is + * identical to it. It's used for the same purpose, i.e. the ability to switch to + * alternate handling via annotated exception handler methods if the output + * publisher starts with an error. + *

----------------------
+ * + *

Given a write function that accepts a source {@code Publisher} to write + * with and returns {@code Publisher} for the result, this operator helps + * to defer the invocation of the write function, until we know if the source + * publisher will begin publishing without an error. If the first emission is + * an error, the write function is bypassed, and the error is sent directly + * through the result publisher. Otherwise the write function is invoked. + * + * @author Rossen Stoyanchev + * @author Stephane Maldini + * @since 5.2 + * @param the type of element signaled + */ +class ChannelSendOperator extends Mono implements Scannable { + + private final Function, Publisher> writeFunction; + + private final Flux source; + + + public ChannelSendOperator(Publisher source, Function, Publisher> writeFunction) { + this.source = Flux.from(source); + this.writeFunction = writeFunction; + } + + + @Override + @Nullable + @SuppressWarnings("rawtypes") + public Object scanUnsafe(Attr key) { + if (key == Attr.PREFETCH) { + return Integer.MAX_VALUE; + } + if (key == Attr.PARENT) { + return this.source; + } + return null; + } + + @Override + public void subscribe(CoreSubscriber actual) { + this.source.subscribe(new WriteBarrier(actual)); + } + + + private enum State { + + /** No emissions from the upstream source yet. */ + NEW, + + /** + * At least one signal of any kind has been received; we're ready to + * call the write function and proceed with actual writing. + */ + FIRST_SIGNAL_RECEIVED, + + /** + * The write subscriber has subscribed and requested; we're going to + * emit the cached signals. + */ + EMITTING_CACHED_SIGNALS, + + /** + * The write subscriber has subscribed, and cached signals have been + * emitted to it; we're ready to switch to a simple pass-through mode + * for all remaining signals. + **/ + READY_TO_WRITE + + } + + + /** + * A barrier inserted between the write source and the write subscriber + * (i.e. the HTTP server adapter) that pre-fetches and waits for the first + * signal before deciding whether to hook in to the write subscriber. + * + *

Acts as: + *

+ * + *

Also uses {@link WriteCompletionBarrier} to communicate completion + * and detect cancel signals from the completion subscriber. + */ + private class WriteBarrier implements CoreSubscriber, Subscription, Publisher { + + /* Bridges signals to and from the completionSubscriber */ + private final WriteCompletionBarrier writeCompletionBarrier; + + /* Upstream write source subscription */ + @Nullable + private Subscription subscription; + + /** Cached data item before readyToWrite. */ + @Nullable + private T item; + + /** Cached error signal before readyToWrite. */ + @Nullable + private Throwable error; + + /** Cached onComplete signal before readyToWrite. */ + private boolean completed = false; + + /** Recursive demand while emitting cached signals. */ + private long demandBeforeReadyToWrite; + + /** Current state. */ + private State state = State.NEW; + + /** The actual writeSubscriber from the HTTP server adapter. */ + @Nullable + private Subscriber writeSubscriber; + + + WriteBarrier(CoreSubscriber completionSubscriber) { + this.writeCompletionBarrier = new WriteCompletionBarrier(completionSubscriber, this); + } + + + // Subscriber methods (we're the subscriber to the write source).. + + @Override + public final void onSubscribe(Subscription s) { + if (Operators.validate(this.subscription, s)) { + this.subscription = s; + this.writeCompletionBarrier.connect(); + s.request(1); + } + } + + @Override + public final void onNext(T item) { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onNext(item); + return; + } + //FIXME revisit in case of reentrant sync deadlock + synchronized (this) { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onNext(item); + } + else if (this.state == State.NEW) { + this.item = item; + this.state = State.FIRST_SIGNAL_RECEIVED; + writeFunction.apply(this).subscribe(this.writeCompletionBarrier); + } + else { + if (this.subscription != null) { + this.subscription.cancel(); + } + this.writeCompletionBarrier.onError(new IllegalStateException("Unexpected item.")); + } + } + } + + private Subscriber requiredWriteSubscriber() { + Assert.state(this.writeSubscriber != null, "No write subscriber"); + return this.writeSubscriber; + } + + @Override + public final void onError(Throwable ex) { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onError(ex); + return; + } + synchronized (this) { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onError(ex); + } + else if (this.state == State.NEW) { + this.state = State.FIRST_SIGNAL_RECEIVED; + this.writeCompletionBarrier.onError(ex); + } + else { + this.error = ex; + } + } + } + + @Override + public final void onComplete() { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onComplete(); + return; + } + synchronized (this) { + if (this.state == State.READY_TO_WRITE) { + requiredWriteSubscriber().onComplete(); + } + else if (this.state == State.NEW) { + this.completed = true; + this.state = State.FIRST_SIGNAL_RECEIVED; + writeFunction.apply(this).subscribe(this.writeCompletionBarrier); + } + else { + this.completed = true; + } + } + } + + @Override + public Context currentContext() { + return this.writeCompletionBarrier.currentContext(); + } + + + // Subscription methods (we're the Subscription to the writeSubscriber).. + + @Override + public void request(long n) { + Subscription s = this.subscription; + if (s == null) { + return; + } + if (this.state == State.READY_TO_WRITE) { + s.request(n); + return; + } + synchronized (this) { + if (this.writeSubscriber != null) { + if (this.state == State.EMITTING_CACHED_SIGNALS) { + this.demandBeforeReadyToWrite = n; + return; + } + try { + this.state = State.EMITTING_CACHED_SIGNALS; + if (emitCachedSignals()) { + return; + } + n = n + this.demandBeforeReadyToWrite - 1; + if (n == 0) { + return; + } + } + finally { + this.state = State.READY_TO_WRITE; + } + } + } + s.request(n); + } + + private boolean emitCachedSignals() { + if (this.item != null) { + requiredWriteSubscriber().onNext(this.item); + } + if (this.error != null) { + requiredWriteSubscriber().onError(this.error); + return true; + } + if (this.completed) { + requiredWriteSubscriber().onComplete(); + return true; + } + return false; + } + + @Override + public void cancel() { + Subscription s = this.subscription; + if (s != null) { + this.subscription = null; + s.cancel(); + } + } + + + // Publisher methods (we're the Publisher to the writeSubscriber).. + + @Override + public void subscribe(Subscriber writeSubscriber) { + synchronized (this) { + Assert.state(this.writeSubscriber == null, "Only one write subscriber supported"); + this.writeSubscriber = writeSubscriber; + if (this.error != null || this.completed) { + this.writeSubscriber.onSubscribe(Operators.emptySubscription()); + emitCachedSignals(); + } + else { + this.writeSubscriber.onSubscribe(this); + } + } + } + } + + + /** + * We need an extra barrier between the WriteBarrier itself and the actual + * completion subscriber. + * + *

The completionSubscriber is subscribed initially to the WriteBarrier. + * Later after the first signal is received, we need one more subscriber + * instance (per spec can only subscribe once) to subscribe to the write + * function and switch to delegating completion signals from it. + */ + private class WriteCompletionBarrier implements CoreSubscriber, Subscription { + + /* Downstream write completion subscriber */ + private final CoreSubscriber completionSubscriber; + + private final WriteBarrier writeBarrier; + + @Nullable + private Subscription subscription; + + + public WriteCompletionBarrier(CoreSubscriber subscriber, WriteBarrier writeBarrier) { + this.completionSubscriber = subscriber; + this.writeBarrier = writeBarrier; + } + + + /** + * Connect the underlying completion subscriber to this barrier in order + * to track cancel signals and pass them on to the write barrier. + */ + public void connect() { + this.completionSubscriber.onSubscribe(this); + } + + // Subscriber methods (we're the subscriber to the write function).. + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void aVoid) { + } + + @Override + public void onError(Throwable ex) { + this.completionSubscriber.onError(ex); + } + + @Override + public void onComplete() { + this.completionSubscriber.onComplete(); + } + + @Override + public Context currentContext() { + return this.completionSubscriber.currentContext(); + } + + + @Override + public void request(long n) { + // Ignore: we don't produce data + } + + @Override + public void cancel() { + this.writeBarrier.cancel(); + Subscription subscription = this.subscription; + if (subscription != null) { + subscription.cancel(); + } + } + } + +} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/MethodMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/MethodMessageHandlerTests.java index bcd7b6eec0b..82a448e14f3 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/MethodMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/MethodMessageHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import java.util.function.Consumer; import org.hamcrest.Matchers; import org.junit.Test; +import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -63,7 +64,7 @@ public class MethodMessageHandlerTests { assertEquals(5, mappings.keySet().size()); assertThat(mappings.keySet(), Matchers.containsInAnyOrder( - "/handleMessage", "/handleMessageWithArgument", "/handleMessageAndThrow", + "/handleMessage", "/handleMessageWithArgument", "/handleMessageWithError", "/handleMessageMatch1", "/handleMessageMatch2")); } @@ -80,7 +81,7 @@ public class MethodMessageHandlerTests { handler.handleMessage(message).block(Duration.ofSeconds(5)); - StepVerifier.create((Mono) handler.getLastReturnValue()) + StepVerifier.create((Publisher) handler.getLastReturnValue()) .expectNext("handleMessageMatch1") .verifyComplete(); } @@ -100,7 +101,7 @@ public class MethodMessageHandlerTests { handler.handleMessage(message).block(Duration.ofSeconds(5)); - StepVerifier.create((Mono) handler.getLastReturnValue()) + StepVerifier.create((Publisher) handler.getLastReturnValue()) .expectNext("handleMessageWithArgument,payload=foo") .verifyComplete(); } @@ -111,11 +112,11 @@ public class MethodMessageHandlerTests { TestMethodMessageHandler handler = initMethodMessageHandler(TestController.class); Message message = new GenericMessage<>("body", Collections.singletonMap( - DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, "/handleMessageAndThrow")); + DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, "/handleMessageWithError")); handler.handleMessage(message).block(Duration.ofSeconds(5)); - StepVerifier.create((Mono) handler.getLastReturnValue()) + StepVerifier.create((Publisher) handler.getLastReturnValue()) .expectNext("handleIllegalStateException,ex=rejected") .verifyComplete(); } @@ -153,7 +154,7 @@ public class MethodMessageHandlerTests { return delay("handleMessageWithArgument,payload=" + payload); } - public Mono handleMessageAndThrow() { + public Mono handleMessageWithError() { return Mono.delay(Duration.ofMillis(10)) .flatMap(aLong -> Mono.error(new IllegalStateException("rejected"))); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/TestReturnValueHandler.java b/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/TestReturnValueHandler.java index 449cec194b6..1133c13d73e 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/TestReturnValueHandler.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/handler/invocation/reactive/TestReturnValueHandler.java @@ -15,6 +15,7 @@ */ package org.springframework.messaging.handler.invocation.reactive; +import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; import org.springframework.core.MethodParameter; @@ -43,7 +44,14 @@ public class TestReturnValueHandler implements HandlerMethodReturnValueHandler { } @Override + @SuppressWarnings("unchecked") public Mono handleReturnValue(@Nullable Object value, MethodParameter returnType, Message message) { + return value instanceof Publisher ? + new ChannelSendOperator((Publisher) value, this::saveValue) : + saveValue(value); + } + + private Mono saveValue(@Nullable Object value) { this.lastReturnValue = value; return Mono.empty(); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java index b6e0f8f38c6..ff13b07100f 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java @@ -35,9 +35,9 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.core.codec.StringDecoder; -import org.springframework.messaging.MessageDeliveryException; import org.springframework.messaging.ReactiveMessageChannel; import org.springframework.messaging.ReactiveSubscribableChannel; +import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.support.DefaultReactiveMessageChannel; import org.springframework.stereotype.Controller; @@ -170,6 +170,26 @@ public class RSocketClientToServerIntegrationTests { .verifyComplete(); } + @Test + public void handleWithThrownException() { + + Mono result = requester.route("thrown-exception").data("a").retrieveMono(String.class); + + StepVerifier.create(result) + .expectNext("Invalid input error handled") + .verifyComplete(); + } + + @Test + public void handleWithErrorSignal() { + + Mono result = requester.route("error-signal").data("a").retrieveMono(String.class); + + StepVerifier.create(result) + .expectNext("Invalid input error handled") + .verifyComplete(); + } + @Test public void noMatchingRoute() { Mono result = requester.route("invalid").data("anything").retrieveMono(String.class); @@ -208,6 +228,20 @@ public class RSocketClientToServerIntegrationTests { return payloads.delayElements(Duration.ofMillis(10)).map(payload -> payload + " async"); } + @MessageMapping("thrown-exception") + Mono handleAndThrow(String payload) { + throw new IllegalArgumentException("Invalid input error"); + } + + @MessageMapping("error-signal") + Mono handleAndReturnError(String payload) { + return Mono.error(new IllegalArgumentException("Invalid input error")); + } + + @MessageExceptionHandler + Mono handleException(IllegalArgumentException ex) { + return Mono.delay(Duration.ofMillis(10)).map(aLong -> ex.getMessage() + " handled"); + } }