diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index 1273592937..ad0419ee4b 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -47,6 +47,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; import reactor.core.publisher.SynchronousSink; +import reactor.util.context.Context; import org.springframework.core.io.Resource; import org.springframework.lang.Nullable; @@ -1057,6 +1058,12 @@ public abstract class DataBufferUtils { protected void hookOnComplete() { this.sink.complete(); } + + @Override + public Context currentContext() { + return this.sink.currentContext(); + } + } @@ -1148,6 +1155,12 @@ public abstract class DataBufferUtils { this.sink.next(dataBuffer); this.dataBuffer.set(null); } + + @Override + public Context currentContext() { + return this.sink.currentContext(); + } + } } diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java index 8615551b31..ba85a1575b 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java @@ -24,6 +24,7 @@ import java.nio.channels.AsynchronousFileChannel; import java.nio.channels.CompletionHandler; import java.nio.channels.FileChannel; import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SeekableByteChannel; import java.nio.channels.WritableByteChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Files; @@ -43,6 +44,7 @@ import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import reactor.util.context.Context; import org.springframework.core.io.ByteArrayResource; import org.springframework.core.io.ClassPathResource; @@ -940,6 +942,53 @@ class DataBufferUtilsTests extends AbstractDataBufferAllocatingTests { release(foo); } + @ParameterizedDataBufferAllocatingTest + void propagateContextByteChannel(String displayName, DataBufferFactory bufferFactory) throws IOException { + Path path = Paths.get(this.resource.getURI()); + try (SeekableByteChannel out = Files.newByteChannel(this.tempFile, StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING)) { + Flux result = DataBufferUtils.read(path, bufferFactory, 1024, StandardOpenOption.READ) + .transformDeferredContextual((f, ctx) -> { + assertThat(ctx.getOrDefault("key", "EMPTY")).isEqualTo("TEST"); + return f; + }) + .transform(f -> DataBufferUtils.write(f, out)) + .transformDeferredContextual((f, ctx) -> { + assertThat(ctx.getOrDefault("key", "EMPTY")).isEqualTo("TEST"); + return f; + }) + .contextWrite(Context.of("key", "TEST")); + + StepVerifier.create(result) + .consumeNextWith(DataBufferUtils::release) + .verifyComplete(); + + + } + } + + @ParameterizedDataBufferAllocatingTest + void propagateContextAsynchronousFileChannel(String displayName, DataBufferFactory bufferFactory) throws IOException { + Path path = Paths.get(this.resource.getURI()); + try (AsynchronousFileChannel out = AsynchronousFileChannel.open(this.tempFile, StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING)) { + Flux result = DataBufferUtils.read(path, bufferFactory, 1024, StandardOpenOption.READ) + .transformDeferredContextual((f, ctx) -> { + assertThat(ctx.getOrDefault("key", "EMPTY")).isEqualTo("TEST"); + return f; + }) + .transform(f -> DataBufferUtils.write(f, out)) + .transformDeferredContextual((f, ctx) -> { + assertThat(ctx.getOrDefault("key", "EMPTY")).isEqualTo("TEST"); + return f; + }) + .contextWrite(Context.of("key", "TEST")); + + StepVerifier.create(result) + .consumeNextWith(DataBufferUtils::release) + .verifyComplete(); + + + } + } private static class ZeroDemandSubscriber extends BaseSubscriber { diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartParser.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartParser.java index 3d9ab7f2b3..d2057f53d6 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartParser.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartParser.java @@ -29,6 +29,7 @@ import org.reactivestreams.Subscription; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; +import reactor.util.context.Context; import org.springframework.core.codec.DecodingException; import org.springframework.core.io.buffer.DataBuffer; @@ -98,6 +99,11 @@ final class MultipartParser extends BaseSubscriber { }); } + @Override + public Context currentContext() { + return this.sink.currentContext(); + } + @Override protected void hookOnSubscribe(Subscription subscription) { requestBuffer(); diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/PartGenerator.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/PartGenerator.java index 9de34009d4..32a923d8a7 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/PartGenerator.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/PartGenerator.java @@ -42,6 +42,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; +import reactor.util.context.Context; import org.springframework.core.codec.DecodingException; import org.springframework.core.io.buffer.DataBuffer; @@ -113,6 +114,11 @@ final class PartGenerator extends BaseSubscriber { }); } + @Override + public Context currentContext() { + return this.sink.currentContext(); + } + @Override protected void hookOnSubscribe(Subscription subscription) { requestToken();