diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java index b2d58cd7a3..d2463f2cae 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java @@ -93,7 +93,8 @@ class ReactiveTypeHandler { private final ContentNegotiationManager contentNegotiationManager; - private final ContextSnapshotFactory contextSnapshotFactory; + @Nullable + private final Object contextSnapshotHelper; public ReactiveTypeHandler() { @@ -102,7 +103,7 @@ class ReactiveTypeHandler { ReactiveTypeHandler( ReactiveAdapterRegistry registry, TaskExecutor executor, ContentNegotiationManager manager, - @Nullable ContextSnapshotFactory contextSnapshotFactory) { + @Nullable Object contextSnapshotFactory) { Assert.notNull(registry, "ReactiveAdapterRegistry is required"); Assert.notNull(executor, "TaskExecutor is required"); @@ -110,8 +111,15 @@ class ReactiveTypeHandler { this.adapterRegistry = registry; this.taskExecutor = executor; this.contentNegotiationManager = manager; - this.contextSnapshotFactory = (contextSnapshotFactory != null ? - contextSnapshotFactory : ContextSnapshotFactory.builder().build()); + this.contextSnapshotHelper = initContextSnapshotHelper(contextSnapshotFactory); + } + + @Nullable + private static Object initContextSnapshotHelper(@Nullable Object snapshotFactory) { + if (isContextPropagationPresent) { + return new ContextSnapshotHelper((ContextSnapshotFactory) snapshotFactory); + } + return null; } @@ -140,8 +148,10 @@ class ReactiveTypeHandler { TaskDecorator taskDecorator = null; if (isContextPropagationPresent) { - returnValue = ContextSnapshotHelper.writeReactorContext(returnValue, this.contextSnapshotFactory); - taskDecorator = ContextSnapshotHelper.getTaskDecorator(this.contextSnapshotFactory); + ContextSnapshotHelper helper = (ContextSnapshotHelper) this.contextSnapshotHelper; + Assert.notNull(helper, "No ContextSnapshotHelper"); + returnValue = helper.writeReactorContext(returnValue); + taskDecorator = helper.getTaskDecorator(); } ResolvableType elementType = ResolvableType.forMethodParameter(returnType).getGeneric(); @@ -534,16 +544,22 @@ class ReactiveTypeHandler { } - private static class ContextSnapshotHelper { + private static final class ContextSnapshotHelper { + + private final ContextSnapshotFactory snapshotFactory; + + private ContextSnapshotHelper(@Nullable ContextSnapshotFactory factory) { + this.snapshotFactory = (factory != null ? factory : ContextSnapshotFactory.builder().build()); + } @SuppressWarnings("ReactiveStreamsUnusedPublisher") - public static Object writeReactorContext(Object returnValue, ContextSnapshotFactory snapshotFactory) { + public Object writeReactorContext(Object returnValue) { if (Mono.class.isAssignableFrom(returnValue.getClass())) { - ContextSnapshot snapshot = snapshotFactory.captureAll(); + ContextSnapshot snapshot = this.snapshotFactory.captureAll(); return ((Mono) returnValue).contextWrite(snapshot::updateContext); } else if (Flux.class.isAssignableFrom(returnValue.getClass())) { - ContextSnapshot snapshot = snapshotFactory.captureAll(); + ContextSnapshot snapshot = this.snapshotFactory.captureAll(); return ((Flux) returnValue).contextWrite(snapshot::updateContext); } else { @@ -551,8 +567,8 @@ class ReactiveTypeHandler { } } - public static TaskDecorator getTaskDecorator(ContextSnapshotFactory snapshotFactory) { - return new ContextPropagatingTaskDecorator(snapshotFactory); + public TaskDecorator getTaskDecorator() { + return new ContextPropagatingTaskDecorator(this.snapshotFactory); } }