From fab0a5d504395cfd9e4435fe6491349200d28df7 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 26 Jul 2019 14:41:38 +0100 Subject: [PATCH] MetadataExtractor refactoring Remove RSocketStrategies argument from the contract to avoid having to pass them every time especially by application components, like an implementation of a Spring Security matcher. Decouple DefaultMetadataExtractor from RSocketStrategies in favor of a decoders property and an internal DataBufferFactory, which does not need to be the shared one as we're only wrapping ByteBufs. --- .../rsocket/DefaultMetadataExtractor.java | 162 ++++++++++++------ .../rsocket/DefaultRSocketStrategies.java | 22 ++- .../messaging/rsocket/MetadataExtractor.java | 3 +- .../messaging/rsocket/RSocketStrategies.java | 3 + .../annotation/support/MessagingRSocket.java | 3 +- .../support/RSocketMessageHandler.java | 4 + .../DefaultMetadataExtractorTests.java | 76 ++++++-- .../DefaultRSocketStrategiesTests.java | 35 ++++ .../support/RSocketMessageHandlerTests.java | 26 +++ 9 files changed, 261 insertions(+), 73 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java index ee296fd6e85..1509aa0b2ea 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java @@ -15,22 +15,26 @@ */ package org.springframework.messaging.rsocket; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.function.BiConsumer; import io.netty.buffer.ByteBuf; +import io.netty.buffer.PooledByteBufAllocator; import io.rsocket.Payload; import io.rsocket.metadata.CompositeMetadata; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; import org.springframework.core.codec.Decoder; -import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBuffer; import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.MimeType; /** @@ -47,15 +51,53 @@ import org.springframework.util.MimeType; */ public class DefaultMetadataExtractor implements MetadataExtractor { - private final Map> entryProcessors = new HashMap<>(); + private final List> decoders = new ArrayList<>(); + + private final Map> processors = new HashMap<>(); /** - * Default constructor with {@link RSocketStrategies}. + * Configure the decoders to use for de-serializing metadata entries. + *

By default this is not set. + *

When this extractor is passed into {@link RSocketStrategies.Builder} or + * {@link org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler + * RSocketMessageHandler}, the decoders may be left not set, and they will + * be initialized from the decoders already configured there. */ - public DefaultMetadataExtractor() { - // TODO: remove when rsocket-core API available - metadataToExtract(MetadataExtractor.ROUTING, String.class, ROUTE_KEY); + public void setDecoders(List> decoders) { + this.decoders.clear(); + if (!decoders.isEmpty()) { + this.decoders.addAll(decoders); + updateProcessors(); + } + } + + @SuppressWarnings("unchecked") + private void updateProcessors() { + for (MetadataProcessor info : this.processors.values()) { + Decoder decoder = decoderFor(info.mimeType(), info.targetType()); + Assert.isTrue(decoder != null, "No decoder for " + info); + info = ((MetadataProcessor) info).setDecoder(decoder); + this.processors.put(info.mimeType().toString(), info); + } + } + + @Nullable + @SuppressWarnings("unchecked") + private Decoder decoderFor(MimeType mimeType, ResolvableType type) { + for (Decoder decoder : this.decoders) { + if (decoder.canDecode(type, mimeType)) { + return (Decoder) decoder; + } + } + return null; + } + + /** + * Return the {@link #setDecoders(List) configured} decoders. + */ + public List> getDecoders() { + return this.decoders; } @@ -97,11 +139,9 @@ public class DefaultMetadataExtractor implements MetadataExtractor { * @param the target value type */ public void metadataToExtract( - MimeType mimeType, Class targetType, - BiConsumer> mapper) { + MimeType mimeType, Class targetType, BiConsumer> mapper) { - EntryProcessor spec = new EntryProcessor<>(mimeType, targetType, mapper); - this.entryProcessors.put(mimeType.toString(), spec); + metadataToExtract(mimeType, mapper, ResolvableType.forClass(targetType)); } /** @@ -117,45 +157,52 @@ public class DefaultMetadataExtractor implements MetadataExtractor { MimeType mimeType, ParameterizedTypeReference targetType, BiConsumer> mapper) { - EntryProcessor spec = new EntryProcessor<>(mimeType, targetType, mapper); - this.entryProcessors.put(mimeType.toString(), spec); + metadataToExtract(mimeType, mapper, ResolvableType.forType(targetType)); + } + + private void metadataToExtract( + MimeType mimeType, BiConsumer> mapper, ResolvableType elementType) { + + Decoder decoder = decoderFor(mimeType, elementType); + Assert.isTrue(this.decoders.isEmpty() || decoder != null, () -> "No decoder for " + mimeType); + MetadataProcessor info = new MetadataProcessor<>(mimeType, elementType, mapper, decoder); + this.processors.put(mimeType.toString(), info); } @Override - public Map extract(Payload payload, MimeType metadataMimeType, RSocketStrategies strategies) { + public Map extract(Payload payload, MimeType metadataMimeType) { Map result = new HashMap<>(); if (metadataMimeType.equals(COMPOSITE_METADATA)) { for (CompositeMetadata.Entry entry : new CompositeMetadata(payload.metadata(), false)) { - processEntry(entry.getContent(), entry.getMimeType(), result, strategies); + processEntry(entry.getContent(), entry.getMimeType(), result); } } else { - processEntry(payload.metadata(), metadataMimeType.toString(), result, strategies); + processEntry(payload.metadata(), metadataMimeType.toString(), result); } return result; } - private void processEntry(ByteBuf content, - @Nullable String mimeType, Map result, RSocketStrategies strategies) { - - EntryProcessor entryProcessor = this.entryProcessors.get(mimeType); - if (entryProcessor != null) { - content.retain(); - entryProcessor.process(content, result, strategies); + @SuppressWarnings("unchecked") + private void processEntry(ByteBuf content, @Nullable String mimeType, Map result) { + MetadataProcessor info = (MetadataProcessor) this.processors.get(mimeType); + if (info != null) { + info.process(content, result); return; } if (MetadataExtractor.ROUTING.toString().equals(mimeType)) { // TODO: use rsocket-core API when available + result.put(MetadataExtractor.ROUTE_KEY, content.toString(StandardCharsets.UTF_8)); } } - /** - * Helps to decode a metadata entry and add the resulting value to the - * output map. - */ - private class EntryProcessor { + private static class MetadataProcessor { + + private final static NettyDataBufferFactory bufferFactory = + new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT); + private final MimeType mimeType; @@ -163,41 +210,54 @@ public class DefaultMetadataExtractor implements MetadataExtractor { private final BiConsumer> accumulator; + @Nullable + private final Decoder decoder; - public EntryProcessor( - MimeType mimeType, Class targetType, - BiConsumer> accumulator) { - this(mimeType, ResolvableType.forClass(targetType), accumulator); - } - - public EntryProcessor( - MimeType mimeType, ParameterizedTypeReference targetType, - BiConsumer> accumulator) { - - this(mimeType, ResolvableType.forType(targetType), accumulator); - } - - private EntryProcessor( - MimeType mimeType, ResolvableType targetType, - BiConsumer> accumulator) { + MetadataProcessor(MimeType mimeType, ResolvableType targetType, + BiConsumer> accumulator, @Nullable Decoder decoder) { this.mimeType = mimeType; this.targetType = targetType; this.accumulator = accumulator; + this.decoder = decoder; + } + + MetadataProcessor(MetadataProcessor other, Decoder decoder) { + this.mimeType = other.mimeType; + this.targetType = other.targetType; + this.accumulator = other.accumulator; + this.decoder = decoder; } - public void process(ByteBuf byteBuf, Map result, RSocketStrategies strategies) { - DataBufferFactory factory = strategies.dataBufferFactory(); - DataBuffer buffer = factory instanceof NettyDataBufferFactory ? - ((NettyDataBufferFactory) factory).wrap(byteBuf) : - factory.wrap(byteBuf.nioBuffer()); + public MimeType mimeType() { + return this.mimeType; + } - Decoder decoder = strategies.decoder(this.targetType, this.mimeType); - T value = decoder.decode(buffer, this.targetType, this.mimeType, Collections.emptyMap()); + public ResolvableType targetType() { + return this.targetType; + } + + public MetadataProcessor setDecoder(Decoder decoder) { + return this.decoder != decoder ? new MetadataProcessor<>(this, decoder) : this; + } + + + public void process(ByteBuf content, Map result) { + if (this.decoder == null) { + throw new IllegalStateException("No decoder for " + this); + } + NettyDataBuffer dataBuffer = bufferFactory.wrap(content.retain()); + T value = this.decoder.decode(dataBuffer, this.targetType, this.mimeType, Collections.emptyMap()); this.accumulator.accept(value, result); } + + + @Override + public String toString() { + return "MetadataProcessor mimeType=" + this.mimeType + ", targetType=" + this.targetType; + } } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketStrategies.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketStrategies.java index e5bb9c262e1..62115717dd6 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketStrategies.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketStrategies.java @@ -209,7 +209,7 @@ final class DefaultRSocketStrategies implements RSocketStrategies { return new DefaultRSocketStrategies( this.encoders, this.decoders, this.routeMatcher != null ? this.routeMatcher : initRouteMatcher(), - this.metadataExtractor != null ? this.metadataExtractor : initMetadataExtractor(), + getOrInitMetadataExtractor(), this.bufferFactory != null ? this.bufferFactory : initBufferFactory(), this.adapterRegistry != null ? this.adapterRegistry : initReactiveAdapterRegistry()); } @@ -220,10 +220,22 @@ final class DefaultRSocketStrategies implements RSocketStrategies { return new SimpleRouteMatcher(pathMatcher); } - private MetadataExtractor initMetadataExtractor() { - DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); - extractor.metadataToExtract(MimeTypeUtils.TEXT_PLAIN, String.class, MetadataExtractor.ROUTE_KEY); - return extractor; + private MetadataExtractor getOrInitMetadataExtractor() { + if (this.metadataExtractor != null) { + if (this.metadataExtractor instanceof DefaultMetadataExtractor) { + DefaultMetadataExtractor extractor = (DefaultMetadataExtractor) this.metadataExtractor; + if (extractor.getDecoders().isEmpty()) { + extractor.setDecoders(this.decoders); + } + } + return this.metadataExtractor; + } + else { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); + extractor.setDecoders(this.decoders); + extractor.metadataToExtract(MimeTypeUtils.TEXT_PLAIN, String.class, MetadataExtractor.ROUTE_KEY); + return extractor; + } } private DataBufferFactory initBufferFactory() { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java index 86e86a9bc36..580b702da57 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java @@ -58,9 +58,8 @@ public interface MetadataExtractor { * @param payload the payload whose metadata should be read * @param metadataMimeType the mime type of the metadata; this is what was * specified by the client at the start of the RSocket connection. - * @param strategies for access to codecs and a DataBufferFactory * @return a map of 0 or more decoded metadata values with assigned names */ - Map extract(Payload payload, MimeType metadataMimeType, RSocketStrategies strategies); + Map extract(Payload payload, MimeType metadataMimeType); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketStrategies.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketStrategies.java index c54daa99051..fa8cadcf4bb 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketStrategies.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketStrategies.java @@ -190,6 +190,9 @@ public interface RSocketStrategies { *

By default this is {@link DefaultMetadataExtractor} extracting a * route from {@code "message/x.rsocket.routing.v0"} or * {@code "text/plain"} metadata entries. + *

If the extractor is a {@code DefaultMetadataExtractor}, its + * {@code decoders} property will be set, if not already set, to the + * {@link #decoder(Decoder[]) decoders} configured here. */ Builder metadataExtractor(@Nullable MetadataExtractor metadataExtractor); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java index d205e597f79..10bd75c733c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/MessagingRSocket.java @@ -192,8 +192,7 @@ class MessagingRSocket extends AbstractRSocket { MessageHeaderAccessor headers = new MessageHeaderAccessor(); headers.setLeaveMutable(true); - Map metadataValues = - this.metadataExtractor.extract(payload, this.metadataMimeType, this.strategies); + Map metadataValues = this.metadataExtractor.extract(payload, this.metadataMimeType); metadataValues.putIfAbsent(MetadataExtractor.ROUTE_KEY, ""); for (Map.Entry entry : metadataValues.entrySet()) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java index 52392bdb28f..1d422fff402 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java @@ -174,6 +174,9 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { * other metadata. *

By default this is {@link DefaultMetadataExtractor} extracting a * route from {@code "message/x.rsocket.routing.v0"} or {@code "text/plain"}. + *

If the extractor is a {@code DefaultMetadataExtractor}, its + * {@code decoders} property will be set, if not already set, to the + * {@link #setDecoders(List)} configured here. * @param extractor the extractor to use */ public void setMetadataExtractor(MetadataExtractor extractor) { @@ -238,6 +241,7 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { if (getMetadataExtractor() == null) { DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); + extractor.setDecoders(getDecoders()); extractor.metadataToExtract(MimeTypeUtils.TEXT_PLAIN, String.class, MetadataExtractor.ROUTE_KEY); setMetadataExtractor(extractor); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java index a8af5523406..bad58c8ac8e 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java @@ -16,6 +16,7 @@ package org.springframework.messaging.rsocket; import java.time.Duration; +import java.util.Collections; import java.util.Map; import io.netty.buffer.PooledByteBufAllocator; @@ -28,13 +29,15 @@ import org.mockito.ArgumentCaptor; import org.mockito.BDDMockito; import reactor.core.publisher.Mono; -import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.ByteArrayDecoder; import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.util.Assert; import org.springframework.util.MimeType; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.springframework.messaging.rsocket.MetadataExtractor.COMPOSITE_METADATA; import static org.springframework.messaging.rsocket.MetadataExtractor.ROUTE_KEY; import static org.springframework.messaging.rsocket.MetadataExtractor.ROUTING; @@ -61,8 +64,6 @@ public class DefaultMetadataExtractorTests { @Before public void setUp() { this.strategies = RSocketStrategies.builder() - .decoder(StringDecoder.allMimeTypes()) - .encoder(CharSequenceEncoder.allMimeTypes()) .dataBufferFactory(new LeakAwareNettyDataBufferFactory(PooledByteBufAllocator.DEFAULT)) .build(); @@ -71,6 +72,7 @@ public class DefaultMetadataExtractorTests { BDDMockito.when(this.rsocket.fireAndForget(captor.capture())).thenReturn(Mono.empty()); this.extractor = new DefaultMetadataExtractor(); + this.extractor.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes())); } @After @@ -82,7 +84,6 @@ public class DefaultMetadataExtractorTests { @Test public void compositeMetadataWithDefaultSettings() { - requester(COMPOSITE_METADATA).route("toA") .metadata("text data", TEXT_PLAIN) .metadata("html data", TEXT_HTML) @@ -91,7 +92,7 @@ public class DefaultMetadataExtractorTests { .send().block(); Payload payload = this.captor.getValue(); - Map result = this.extractor.extract(payload, COMPOSITE_METADATA, this.strategies); + Map result = this.extractor.extract(payload, COMPOSITE_METADATA); payload.release(); assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); @@ -99,7 +100,6 @@ public class DefaultMetadataExtractorTests { @Test public void compositeMetadataWithMimeTypeRegistrations() { - this.extractor.metadataToExtract(TEXT_PLAIN, String.class, "text-entry"); this.extractor.metadataToExtract(TEXT_HTML, String.class, "html-entry"); this.extractor.metadataToExtract(TEXT_XML, String.class, "xml-entry"); @@ -113,7 +113,7 @@ public class DefaultMetadataExtractorTests { .block(); Payload payload = this.captor.getValue(); - Map result = this.extractor.extract(payload, COMPOSITE_METADATA, this.strategies); + Map result = this.extractor.extract(payload, COMPOSITE_METADATA); payload.release(); assertThat(result).hasSize(4) @@ -125,10 +125,9 @@ public class DefaultMetadataExtractorTests { @Test public void route() { - requester(ROUTING).route("toA").data("data").send().block(); Payload payload = this.captor.getValue(); - Map result = this.extractor.extract(payload, ROUTING, this.strategies); + Map result = this.extractor.extract(payload, ROUTING); payload.release(); assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); @@ -136,12 +135,11 @@ public class DefaultMetadataExtractorTests { @Test public void routeAsText() { - this.extractor.metadataToExtract(TEXT_PLAIN, String.class, ROUTE_KEY); requester(TEXT_PLAIN).route("toA").data("data").send().block(); Payload payload = this.captor.getValue(); - Map result = this.extractor.extract(payload, TEXT_PLAIN, this.strategies); + Map result = this.extractor.extract(payload, TEXT_PLAIN); payload.release(); assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); @@ -149,7 +147,6 @@ public class DefaultMetadataExtractorTests { @Test public void routeWithCustomFormatting() { - this.extractor.metadataToExtract(TEXT_PLAIN, String.class, (text, result) -> { String[] items = text.split(":"); Assert.isTrue(items.length == 2, "Expected two items"); @@ -159,7 +156,7 @@ public class DefaultMetadataExtractorTests { requester(TEXT_PLAIN).metadata("toA:text data", null).data("data").send().block(); Payload payload = this.captor.getValue(); - Map result = this.extractor.extract(payload, TEXT_PLAIN, this.strategies); + Map result = this.extractor.extract(payload, TEXT_PLAIN); payload.release(); assertThat(result).hasSize(2) @@ -167,6 +164,59 @@ public class DefaultMetadataExtractorTests { .containsEntry("entry1", "text data"); } + @Test + public void addMetadataToExtractBeforeDecoders() { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); + extractor.metadataToExtract(TEXT_PLAIN, String.class, "key"); + extractor.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes())); + + requester(TEXT_PLAIN).metadata("meta entry", null).data("data").send().block(); + Payload payload = this.captor.getValue(); + Map result = extractor.extract(payload, TEXT_PLAIN); + payload.release(); + + assertThat(result).hasSize(1).containsEntry("key", "meta entry"); + } + + @Test + public void noDecoderExceptionWhenSettingDecoders() { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); + extractor.metadataToExtract(TEXT_PLAIN, String.class, "key"); + + assertThatIllegalArgumentException() + .isThrownBy(() -> extractor.setDecoders(Collections.singletonList(new ByteArrayDecoder()))) + .withMessage("No decoder for MetadataProcessor mimeType=text/plain, targetType=java.lang.String"); + } + + @Test + public void noDecoderExceptionWhenRegisteringMetadataToExtract() { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); + extractor.setDecoders(Collections.singletonList(new ByteArrayDecoder())); + + assertThatIllegalArgumentException() + .isThrownBy(() -> extractor.metadataToExtract(TEXT_PLAIN, String.class, "key")) + .withMessage("No decoder for text/plain"); + } + + @Test + public void decodersNotSet() { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); + extractor.metadataToExtract(TEXT_PLAIN, String.class, "key"); + + assertThatIllegalStateException() + .isThrownBy(() -> { + requester(TEXT_PLAIN).metadata("meta entry", null).data("data").send().block(); + Payload payload = this.captor.getValue(); + try { + extractor.extract(payload, TEXT_PLAIN); + } + finally { + payload.release(); + } + }) + .withMessage("No decoder for MetadataProcessor mimeType=text/plain, targetType=java.lang.String"); + } + private RSocketRequester requester(MimeType metadataMimeType) { return RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, metadataMimeType, this.strategies); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketStrategiesTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketStrategiesTests.java index cf783b1ce60..7b256140285 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketStrategiesTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketStrategiesTests.java @@ -15,6 +15,8 @@ */ package org.springframework.messaging.rsocket; +import java.util.Collections; + import org.junit.Test; import org.springframework.core.ReactiveAdapterRegistry; @@ -87,6 +89,39 @@ public class DefaultRSocketStrategiesTests { assertThat(strategies.reactiveAdapterRegistry()).isSameAs(registry); } + @Test + public void metadataExtractorInitializedWithDecoders() { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); + + RSocketStrategies strategies = RSocketStrategies.builder() + .decoders(decoders -> { + decoders.clear(); + decoders.add(new ByteArrayDecoder()); + decoders.add(new ByteBufferDecoder()); + }) + .metadataExtractor(extractor) + .build(); + + assertThat(((DefaultMetadataExtractor) strategies.metadataExtractor()).getDecoders()).hasSize(2); + } + + @Test + public void metadataExtractorWithExplicitlySetDecoders() { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); + extractor.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes())); + + RSocketStrategies strategies = RSocketStrategies.builder() + .decoders(decoders -> { + decoders.clear(); + decoders.add(new ByteArrayDecoder()); + decoders.add(new ByteBufferDecoder()); + }) + .metadataExtractor(extractor) + .build(); + + assertThat(((DefaultMetadataExtractor) strategies.metadataExtractor()).getDecoders()).hasSize(1); + } + @Test public void copyConstructor() { RSocketStrategies strategies1 = RSocketStrategies.create(); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java index 90cf9ca3f3c..25264b568a5 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java @@ -15,6 +15,7 @@ */ package org.springframework.messaging.rsocket.annotation.support; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -158,6 +159,31 @@ public class RSocketMessageHandlerTests { assertThat(strategies.reactiveAdapterRegistry()).isSameAs(handler.getReactiveAdapterRegistry()); } + @Test + public void metadataExtractorInitializedWithDecoders() { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); + + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setDecoders(Arrays.asList(new ByteArrayDecoder(), new ByteBufferDecoder())); + handler.setMetadataExtractor(extractor); + handler.afterPropertiesSet(); + + assertThat(((DefaultMetadataExtractor) handler.getMetadataExtractor()).getDecoders()).hasSize(2); + } + + @Test + public void metadataExtractorWithExplicitlySetDecoders() { + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); + extractor.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes())); + + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setDecoders(Arrays.asList(new ByteArrayDecoder(), new ByteBufferDecoder())); + handler.setMetadataExtractor(extractor); + handler.afterPropertiesSet(); + + assertThat(((DefaultMetadataExtractor) handler.getMetadataExtractor()).getDecoders()).hasSize(1); + } + @Test public void mappings() { testMapping(new SimpleController(), "path");