diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java index 45e8b8c0031..76fc45bea51 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java @@ -26,6 +26,7 @@ import reactor.core.Scannable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; +import reactor.util.context.Context; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -75,7 +76,13 @@ public class ChannelSendOperator extends Mono implements Scannable { @SuppressWarnings("deprecation") - private class WriteWithBarrier extends SubscriberAdapter implements Publisher { + private final class WriteWithBarrier + implements Publisher, CoreSubscriber, Subscription { + + private final CoreSubscriber subscriber; + + @Nullable + private Subscription subscription; /** * We've at at least one emission, we've called the write function, the write @@ -102,62 +109,26 @@ public class ChannelSendOperator extends Mono implements Scannable { @Nullable private Subscriber writeSubscriber; - public WriteWithBarrier(Subscriber subscriber) { - super(subscriber); + WriteWithBarrier(CoreSubscriber subscriber) { + this.subscriber = subscriber; } @Override - protected void doOnSubscribe(Subscription subscription) { - super.doOnSubscribe(subscription); - super.upstream().request(1); // bypass doRequest - } - - @Override - public void doNext(T item) { - if (this.readyToWrite) { - obtainWriteSubscriber().onNext(item); - return; - } - synchronized (this) { - if (this.readyToWrite) { - obtainWriteSubscriber().onNext(item); - } - else if (this.beforeFirstEmission) { - this.item = item; - this.beforeFirstEmission = false; - writeFunction.apply(this).subscribe(new DownstreamBridge(downstream())); - } - else { - if (this.subscription != null) { - this.subscription.cancel(); - } - downstream().onError(new IllegalStateException("Unexpected item.")); - } + public void cancel() { + Subscription s = this.subscription; + if (s != null) { + this.subscription = null; + s.cancel(); } } @Override - public void doError(Throwable ex) { - if (this.readyToWrite) { - obtainWriteSubscriber().onError(ex); - return; - } - synchronized (this) { - if (this.readyToWrite) { - obtainWriteSubscriber().onError(ex); - } - else if (this.beforeFirstEmission) { - this.beforeFirstEmission = false; - downstream().onError(ex); - } - else { - this.error = ex; - } - } + public Context currentContext() { + return subscriber.currentContext(); } @Override - public void doComplete() { + public final void onComplete() { if (this.readyToWrite) { obtainWriteSubscriber().onComplete(); return; @@ -169,7 +140,7 @@ public class ChannelSendOperator extends Mono implements Scannable { else if (this.beforeFirstEmission) { this.completed = true; this.beforeFirstEmission = false; - writeFunction.apply(this).subscribe(new DownstreamBridge(downstream())); + writeFunction.apply(this).subscribe(new DownstreamBridge(subscriber)); } else { this.completed = true; @@ -177,6 +148,60 @@ public class ChannelSendOperator extends Mono implements Scannable { } } + @Override + public final void onError(Throwable ex) { + if (this.readyToWrite) { + obtainWriteSubscriber().onError(ex); + return; + } + synchronized (this) { + if (this.readyToWrite) { + obtainWriteSubscriber().onError(ex); + } + else if (this.beforeFirstEmission) { + this.beforeFirstEmission = false; + subscriber.onError(ex); + } + else { + this.error = ex; + } + } + } + + @Override + public final void onNext(T item) { + if (this.readyToWrite) { + obtainWriteSubscriber().onNext(item); + return; + } + //FIXME revisit in case of reentrant sync deadlock + synchronized (this) { + if (this.readyToWrite) { + obtainWriteSubscriber().onNext(item); + } + else if (this.beforeFirstEmission) { + this.item = item; + this.beforeFirstEmission = false; + writeFunction.apply(this).subscribe(new DownstreamBridge(subscriber)); + } + else { + if (this.subscription != null) { + this.subscription.cancel(); + } + subscriber.onError(new IllegalStateException("Unexpected item.")); + } + } + } + + @Override + public final void onSubscribe(Subscription s) { + if (Operators.validate(this.subscription, s)) { + this.subscription = s; + this.subscriber.onSubscribe(this); + s.request(1); // bypass doRequest + } + } + @Override public void subscribe(Subscriber writeSubscriber) { synchronized (this) { @@ -212,9 +237,13 @@ public class ChannelSendOperator extends Mono implements Scannable { } @Override - protected void doRequest(long n) { + public void request(long n) { + Subscription s = this.subscription; + if (s == null) { + return; + } if (readyToWrite) { - super.doRequest(n); + s.request(n); return; } synchronized (this) { @@ -227,9 +256,9 @@ public class ChannelSendOperator extends Mono implements Scannable { if (n == 0) { return; } - super.doRequest(n); } } + s.request(n); } private Subscriber obtainWriteSubscriber() { @@ -239,139 +268,11 @@ public class ChannelSendOperator extends Mono implements Scannable { } - // TODO Remove this copy of Reactor 3.0.x Operators.SubscriberAdapter - private static class SubscriberAdapter implements Subscriber, Subscription { + private class DownstreamBridge implements CoreSubscriber { - protected final Subscriber subscriber; + private final CoreSubscriber downstream; - @Nullable - protected Subscription subscription; - - public SubscriberAdapter(Subscriber subscriber) { - this.subscriber = subscriber; - } - - public Subscriber downstream() { - return this.subscriber; - } - - @Override - public final void cancel() { - try { - doCancel(); - } - catch (Throwable throwable) { - doOnSubscriberError(Operators.onOperatorError(this.subscription, throwable)); - } - } - - @Override - public final void onComplete() { - try { - doComplete(); - } - catch (Throwable throwable) { - doOnSubscriberError(Operators.onOperatorError(throwable)); - } - } - - @Override - public final void onError(Throwable t) { - doError(t); - } - - @Override - public final void onNext(I i) { - try { - doNext(i); - } - catch (Throwable throwable) { - doOnSubscriberError(Operators.onOperatorError(this.subscription, throwable, i)); - } - } - - @Override - public final void onSubscribe(Subscription s) { - if (Operators.validate(this.subscription, s)) { - try { - this.subscription = s; - doOnSubscribe(s); - } - catch (Throwable throwable) { - doOnSubscriberError(Operators.onOperatorError(s, throwable)); - } - } - } - - @Override - public final void request(long n) { - try { - Operators.checkRequest(n); - doRequest(n); - } - catch (Throwable throwable) { - doCancel(); - doOnSubscriberError(Operators.onOperatorError(throwable)); - } - } - - @Override - public String toString() { - return getClass().getSimpleName(); - } - - /** - * Hook for further processing of onSubscribe's Subscription. - * @param subscription the subscription to optionally process - */ - protected void doOnSubscribe(Subscription subscription) { - this.subscriber.onSubscribe(this); - } - - public Subscription upstream() { - Assert.state(this.subscription != null, "No subscription"); - return this.subscription; - } - - @SuppressWarnings("unchecked") - protected void doNext(I i) { - this.subscriber.onNext((O) i); - } - - protected void doError(Throwable throwable) { - this.subscriber.onError(throwable); - } - - protected void doOnSubscriberError(Throwable throwable){ - this.subscriber.onError(throwable); - } - - protected void doComplete() { - this.subscriber.onComplete(); - } - - protected void doRequest(long n) { - Subscription s = this.subscription; - if (s != null) { - s.request(n); - } - } - - protected void doCancel() { - Subscription s = this.subscription; - if (s != null) { - this.subscription = null; - s.cancel(); - } - } - } - - - private class DownstreamBridge implements Subscriber { - - private final Subscriber downstream; - - public DownstreamBridge(Subscriber downstream) { + public DownstreamBridge(CoreSubscriber downstream) { this.downstream = downstream; } @@ -393,6 +294,11 @@ public class ChannelSendOperator extends Mono implements Scannable { public void onComplete() { this.downstream.onComplete(); } + + @Override + public Context currentContext() { + return downstream.currentContext(); + } } }