diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java index b36a6fd83e..1ce015f181 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java @@ -116,7 +116,7 @@ public class DefaultMetadataExtractor implements MetadataExtractor, MetadataExtr } } else { - extractEntry(payload.metadata(), metadataMimeType.toString(), result); + extractEntry(payload.metadata().slice(), metadataMimeType.toString(), result); } return result; } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java index cd1052a4ae..c12a15160d 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java @@ -15,6 +15,7 @@ */ package org.springframework.messaging.rsocket; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Collections; import java.util.Map; @@ -26,10 +27,14 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.AbstractDataBufferDecoder; import org.springframework.core.codec.ByteArrayDecoder; import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; @@ -156,6 +161,24 @@ public class DefaultMetadataExtractorTests { .containsEntry("entry1", "text data"); } + @Test + public void nonCompositeMetadataCanBeReadTwice() { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(new TestDecoder()); + extractor.metadataToExtract(TEXT_PLAIN, String.class, "name"); + + MetadataEncoder encoder = new MetadataEncoder(TEXT_PLAIN, this.strategies).metadata("value", null); + DataBuffer metadata = encoder.encode(); + Payload payload = createPayload(metadata); + + Map result = extractor.extract(payload, TEXT_PLAIN); + assertThat(result).hasSize(1).containsEntry("name", "value"); + + result = extractor.extract(payload, TEXT_PLAIN); + assertThat(result).hasSize(1).containsEntry("name", "value"); + + payload.release(); + } + @Test public void noDecoder() { DefaultMetadataExtractor extractor = @@ -172,4 +195,25 @@ public class DefaultMetadataExtractorTests { return PayloadUtils.createPayload(this.strategies.dataBufferFactory().allocateBuffer(), metadata); } + + /** + * Like StringDecoder but consumes the reader index in order to prove that + * extraction uses a slice and can be read twice. + */ + private static class TestDecoder extends AbstractDataBufferDecoder { + + public TestDecoder() { + super(TEXT_PLAIN); + } + + @Override + public String decode(DataBuffer dataBuffer, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + byte[] bytes = new byte[dataBuffer.readableByteCount()]; + dataBuffer.read(bytes); + DataBufferUtils.release(dataBuffer); + return new String(bytes, StandardCharsets.UTF_8); + } + } }