diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java index fc20f17a42..da46443edc 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ChannelSendOperator.java @@ -181,7 +181,15 @@ public class ChannelSendOperator extends Mono implements Scannable { else if (this.state == State.NEW) { this.item = item; this.state = State.FIRST_SIGNAL_RECEIVED; - writeFunction.apply(this).subscribe(this.writeCompletionBarrier); + Publisher result; + try { + result = writeFunction.apply(this); + } + catch (Throwable ex) { + this.writeCompletionBarrier.onError(ex); + return; + } + result.subscribe(this.writeCompletionBarrier); } else { if (this.subscription != null) { @@ -230,7 +238,15 @@ public class ChannelSendOperator extends Mono implements Scannable { else if (this.state == State.NEW) { this.completed = true; this.state = State.FIRST_SIGNAL_RECEIVED; - writeFunction.apply(this).subscribe(this.writeCompletionBarrier); + Publisher result; + try { + result = writeFunction.apply(this); + } + catch (Throwable ex) { + this.writeCompletionBarrier.onError(ex); + return; + } + result.subscribe(this.writeCompletionBarrier); } else { this.completed = true; diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java index 266a7f14f1..0a81e34d39 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ChannelSendOperatorTests.java @@ -92,7 +92,7 @@ public class ChannelSendOperatorTests { @Test - public void writeMultipleItems() throws Exception { + public void writeMultipleItems() { List items = Arrays.asList("one", "two", "three"); Mono completion = Flux.fromIterable(items).as(this::sendOperator); Signal signal = completion.materialize().block(); @@ -108,7 +108,7 @@ public class ChannelSendOperatorTests { } @Test - public void errorAfterMultipleItems() throws Exception { + public void errorAfterMultipleItems() { IllegalStateException error = new IllegalStateException("boo"); Flux publisher = Flux.generate(() -> 0, (idx , subscriber) -> { int i = ++idx; @@ -213,6 +213,25 @@ public class ChannelSendOperatorTests { bufferFactory.checkForLeaks(); } + @Test // gh-23175 + public void errorInWriteFunction() { + + StepVerifier + .create(new ChannelSendOperator<>(Mono.just("one"), p -> { + throw new IllegalStateException("boo"); + })) + .expectErrorMessage("boo") + .verify(Duration.ofMillis(5000)); + + StepVerifier + .create(new ChannelSendOperator<>(Mono.empty(), p -> { + throw new IllegalStateException("boo"); + })) + .expectErrorMessage("boo") + .verify(Duration.ofMillis(5000)); + } + + private Mono sendOperator(Publisher source){ return new ChannelSendOperator<>(source, writer::send); }