From d5330a070f7d29cc5469f4d22a032bb94b1538b1 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 28 May 2021 12:18:15 -0600 Subject: [PATCH] PayloadInterceptorRSocket retains all payloads Flux#skip discards its corresponding elements, meaning that they aren't intended for reuse. When using RSocket's ByteBufPayloads, this means that the bytes are releaseed back into RSocket's pool. Since the downstream request may still need the skipped payload, we should construct the publisher in a different way so as to avoid the preemptive release. Deferring Spring JavaFormat to clarify what changed. Closes gh-9345 --- .../core/PayloadInterceptorRSocket.java | 9 ++- .../core/PayloadInterceptorRSocketTests.java | 62 ++++++++++++++++++- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java index 3120cab77c..0fe8fd002d 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 the original author or authors. + * Copyright 2019-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -92,13 +92,16 @@ class PayloadInterceptorRSocket extends RSocketProxy { return Flux.from(payloads).switchOnFirst((signal, innerFlux) -> { Payload firstPayload = signal.get(); return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload).flatMapMany((context) -> innerFlux - .skip(1).flatMap((p) -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p)) - .transform((securedPayloads) -> Flux.concat(Flux.just(firstPayload), securedPayloads)) + .index().concatMap((tuple) -> justOrIntercept(tuple.getT1(), tuple.getT2())) .transform((securedPayloads) -> this.source.requestChannel(securedPayloads)) .subscriberContext(context)); }); } + private Mono justOrIntercept(Long index, Payload payload) { + return (index == 0) ? Mono.just(payload) : intercept(PayloadExchangeType.PAYLOAD, payload).thenReturn(payload); + } + @Override public Mono metadataPush(Payload payload) { return intercept(PayloadExchangeType.METADATA_PUSH, payload) diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java index 3a153b8c61..f92c3c8f6c 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 the original author or authors. + * Copyright 2019-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,10 +19,14 @@ package org.springframework.security.rsocket.core; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; import io.rsocket.util.RSocketProxy; import org.junit.Test; import org.junit.runner.RunWith; @@ -32,13 +36,17 @@ import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import reactor.test.publisher.PublisherProbe; import reactor.test.publisher.TestPublisher; +import reactor.util.context.Context; import org.springframework.http.MediaType; +import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.ReactiveSecurityContextHolder; @@ -56,6 +64,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; @@ -265,6 +274,57 @@ public class PayloadInterceptorRSocketTests { verify(this.delegate).requestChannel(any()); } + // gh-9345 + @Test + public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() { + ExecutorService executors = Executors.newSingleThreadExecutor(); + Payload payload = ByteBufPayload.create("data"); + Payload payloadTwo = ByteBufPayload.create("moredata"); + Payload payloadThree = ByteBufPayload.create("stillmoredata"); + Context ctx = Context.empty(); + Flux payloads = this.payloadResult.flux(); + given(this.interceptor.intercept(any(), any())).willReturn(Mono.empty()) + .willReturn(Mono.error(() -> new AccessDeniedException("Access Denied"))); + given(this.delegate.requestChannel(any())).willAnswer((invocation) -> { + Flux input = invocation.getArgument(0); + return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8) + .transform((data) -> Flux.create((emitter) -> { + Runnable run = () -> data.subscribe(new CoreSubscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(3); + } + + @Override + public void onNext(String s) { + emitter.next(s); + } + + @Override + public void onError(Throwable t) { + emitter.error(t); + } + + @Override + public void onComplete() { + emitter.complete(); + } + }); + executors.execute(run); + })).map(DefaultPayload::create)); + }); + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx); + StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release)) + .then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree)) + .assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8())) + .verifyError(AccessDeniedException.class); + verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo); + verify(this.delegate).requestChannel(any()); + } + @Test public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() { RuntimeException expected = new RuntimeException("Oops");