diff --git a/spring-context/src/main/java/org/springframework/cache/interceptor/CacheAspectSupport.java b/spring-context/src/main/java/org/springframework/cache/interceptor/CacheAspectSupport.java index 57ad772ca7..df2ad3f56a 100644 --- a/spring-context/src/main/java/org/springframework/cache/interceptor/CacheAspectSupport.java +++ b/spring-context/src/main/java/org/springframework/cache/interceptor/CacheAspectSupport.java @@ -26,12 +26,12 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; +import reactor.core.observability.DefaultSignalListener; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -90,6 +90,7 @@ import org.springframework.util.function.SupplierUtils; * @author Sam Brannen * @author Stephane Nicoll * @author Sebastien Deleuze + * @author Simon Baslé * @since 3.1 */ public abstract class CacheAspectSupport extends AbstractCacheInvoker @@ -1036,32 +1037,45 @@ public abstract class CacheAspectSupport extends AbstractCacheInvoker /** - * Reactive Streams Subscriber collection for collecting a List to cache. + * Reactor stateful SignalListener for collecting a List to cache. */ - private class CachePutListSubscriber implements Subscriber { + private class CachePutSignalListener extends DefaultSignalListener { - private final CachePutRequest request; + private final AtomicReference request; private final List cacheValue = new ArrayList<>(); - public CachePutListSubscriber(CachePutRequest request) { - this.request = request; + public CachePutSignalListener(CachePutRequest request) { + this.request = new AtomicReference<>(request); } @Override - public void onSubscribe(Subscription s) { - s.request(Integer.MAX_VALUE); - } - @Override - public void onNext(Object o) { + public void doOnNext(Object o) { this.cacheValue.add(o); } + @Override - public void onError(Throwable t) { + public void doOnComplete() { + CachePutRequest r = this.request.get(); + if (this.request.compareAndSet(r, null)) { + r.performCachePut(this.cacheValue); + } } + @Override - public void onComplete() { - this.request.performCachePut(this.cacheValue); + public void doOnCancel() { + // Note: we don't use doFinally as we want to propagate the signal after cache put, not before + CachePutRequest r = this.request.get(); + if (this.request.compareAndSet(r, null)) { + r.performCachePut(this.cacheValue); + } + } + + @Override + public void doOnError(Throwable error) { + if (this.request.getAndSet(null) != null) { + this.cacheValue.clear(); + } } } @@ -1145,9 +1159,8 @@ public abstract class CacheAspectSupport extends AbstractCacheInvoker ReactiveAdapter adapter = (result != null ? this.registry.getAdapter(result.getClass()) : null); if (adapter != null) { if (adapter.isMultiValue()) { - Flux source = Flux.from(adapter.toPublisher(result)); - source.subscribe(new CachePutListSubscriber(request)); - return adapter.fromPublisher(source); + return adapter.fromPublisher(Flux.from(adapter.toPublisher(result)) + .tap(() -> new CachePutSignalListener(request))); } else { return adapter.fromPublisher(Mono.from(adapter.toPublisher(result)) diff --git a/spring-context/src/test/java/org/springframework/cache/annotation/ReactiveCachingTests.java b/spring-context/src/test/java/org/springframework/cache/annotation/ReactiveCachingTests.java index 09d2869b73..7300ee59d4 100644 --- a/spring-context/src/test/java/org/springframework/cache/annotation/ReactiveCachingTests.java +++ b/spring-context/src/test/java/org/springframework/cache/annotation/ReactiveCachingTests.java @@ -61,7 +61,7 @@ class ReactiveCachingTests { Long r3 = service.cacheFuture(key).join(); assertThat(r1).isNotNull(); - assertThat(r1).isSameAs(r2).isSameAs(r3); + assertThat(r1).as("cacheFuture").isSameAs(r2).isSameAs(r3); key = new Object(); @@ -70,7 +70,7 @@ class ReactiveCachingTests { r3 = service.cacheMono(key).block(); assertThat(r1).isNotNull(); - assertThat(r1).isSameAs(r2).isSameAs(r3); + assertThat(r1).as("cacheMono").isSameAs(r2).isSameAs(r3); key = new Object(); @@ -79,7 +79,7 @@ class ReactiveCachingTests { r3 = service.cacheFlux(key).blockFirst(); assertThat(r1).isNotNull(); - assertThat(r1).isSameAs(r2).isSameAs(r3); + assertThat(r1).as("cacheFlux blockFirst").isSameAs(r2).isSameAs(r3); key = new Object(); @@ -88,7 +88,7 @@ class ReactiveCachingTests { List l3 = service.cacheFlux(key).collectList().block(); assertThat(l1).isNotNull(); - assertThat(l1).isEqualTo(l2).isEqualTo(l3); + assertThat(l1).as("cacheFlux collectList").isEqualTo(l2).isEqualTo(l3); key = new Object(); @@ -97,7 +97,7 @@ class ReactiveCachingTests { r3 = service.cacheMono(key).block(); assertThat(r1).isNotNull(); - assertThat(r1).isSameAs(r2).isSameAs(r3); + assertThat(r1).as("cacheMono common key").isSameAs(r2).isSameAs(r3); // Same key as for Mono, reusing its cached value @@ -106,12 +106,11 @@ class ReactiveCachingTests { r3 = service.cacheFlux(key).blockFirst(); assertThat(r1).isNotNull(); - assertThat(r1).isSameAs(r2).isSameAs(r3); + assertThat(r1).as("cacheFlux blockFirst common key").isSameAs(r2).isSameAs(r3); ctx.close(); } - @CacheConfig(cacheNames = "first") static class ReactiveCacheableService { @@ -124,12 +123,16 @@ class ReactiveCachingTests { @Cacheable Mono cacheMono(Object arg) { - return Mono.just(this.counter.getAndIncrement()); + // here counter not only reflects invocations of cacheMono but subscriptions to + // the returned Mono as well. See https://github.com/spring-projects/spring-framework/issues/32370 + return Mono.defer(() -> Mono.just(this.counter.getAndIncrement())); } @Cacheable Flux cacheFlux(Object arg) { - return Flux.just(this.counter.getAndIncrement(), 0L); + // here counter not only reflects invocations of cacheFlux but subscriptions to + // the returned Flux as well. See https://github.com/spring-projects/spring-framework/issues/32370 + return Flux.defer(() -> Flux.just(this.counter.getAndIncrement(), 0L)); } }