diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java index 62c462f935..11d1a889cb 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java @@ -254,7 +254,6 @@ final class DefaultRSocketRequester implements RSocketRequester { dataBuffer instanceof NettyDataBuffer ? ((NettyDataBuffer) dataBuffer).getNativeBuffer() : Unpooled.wrappedBuffer(dataBuffer.asByteBuffer())); - }); return asDataBuffer(metadata); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessageHandlerAcceptor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessageHandlerAcceptor.java deleted file mode 100644 index 4a92e69552..0000000000 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessageHandlerAcceptor.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.messaging.rsocket; - -import java.util.function.BiFunction; -import java.util.function.Function; - -import io.rsocket.ConnectionSetupPayload; -import io.rsocket.RSocket; -import io.rsocket.SocketAcceptor; -import reactor.core.publisher.Mono; - -import org.springframework.lang.Nullable; -import org.springframework.messaging.Message; -import org.springframework.util.Assert; -import org.springframework.util.MimeType; -import org.springframework.util.MimeTypeUtils; -import org.springframework.util.StringUtils; - -/** - * Extension of {@link RSocketMessageHandler} that can be plugged directly into - * RSocket to receive connections either on the - * {@link io.rsocket.RSocketFactory.ClientRSocketFactory#acceptor(Function) client} or on the - * {@link io.rsocket.RSocketFactory.ServerRSocketFactory#acceptor(SocketAcceptor) server} - * side. Requests are handled by delegating to the "super" {@link #handleMessage(Message)}. - * - * @author Rossen Stoyanchev - * @since 5.2 - */ -public final class MessageHandlerAcceptor extends RSocketMessageHandler - implements SocketAcceptor, BiFunction { - - @Nullable - private MimeType defaultDataMimeType; - - private MimeType defaultMetadataMimeType = DefaultRSocketRequester.COMPOSITE_METADATA; - - - /** - * Configure the default content type to use for data payloads if the - * {@code SETUP} frame did not specify one. - *

By default this is not set. - * @param mimeType the MimeType to use - */ - public void setDefaultDataMimeType(@Nullable MimeType mimeType) { - this.defaultDataMimeType = mimeType; - } - - /** - * Configure the default {@code MimeType} for payload data if the - * {@code SETUP} frame did not specify one. - *

By default this is set to {@code "message/x.rsocket.composite-metadata.v0"} - * @param mimeType the MimeType to use - */ - public void setDefaultMetadataMimeType(MimeType mimeType) { - Assert.notNull(mimeType, "'metadataMimeType' is required"); - this.defaultMetadataMimeType = mimeType; - } - - - @Override - public Mono accept(ConnectionSetupPayload setupPayload, RSocket sendingRSocket) { - MessagingRSocket rsocket = createRSocket(setupPayload, sendingRSocket); - - // Allow handling of the ConnectionSetupPayload via @MessageMapping methods. - // However, if the handling is to make requests to the client, it's expected - // it will do so decoupled from the handling, e.g. via .subscribe(). - return rsocket.handleConnectionSetupPayload(setupPayload).then(Mono.just(rsocket)); - } - - @Override - public RSocket apply(ConnectionSetupPayload setupPayload, RSocket sendingRSocket) { - return createRSocket(setupPayload, sendingRSocket); - } - - private MessagingRSocket createRSocket(ConnectionSetupPayload setupPayload, RSocket rsocket) { - - String s = setupPayload.dataMimeType(); - MimeType dataMimeType = StringUtils.hasText(s) ? MimeTypeUtils.parseMimeType(s) : this.defaultDataMimeType; - Assert.notNull(dataMimeType, "No `dataMimeType` in ConnectionSetupPayload and no default value"); - - s = setupPayload.metadataMimeType(); - MimeType metaMimeType = StringUtils.hasText(s) ? MimeTypeUtils.parseMimeType(s) : this.defaultMetadataMimeType; - Assert.notNull(dataMimeType, "No `metadataMimeType` in ConnectionSetupPayload and no default value"); - - RSocketRequester requester = RSocketRequester.wrap( - rsocket, dataMimeType, metaMimeType, getRSocketStrategies()); - - return new MessagingRSocket(this, getRouteMatcher(), requester, - dataMimeType, metaMimeType, getRSocketStrategies().dataBufferFactory()); - } - -} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketMessageHandler.java index 69ec6ffdc0..eeff53fc8f 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketMessageHandler.java @@ -18,6 +18,13 @@ package org.springframework.messaging.rsocket; import java.util.ArrayList; import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Function; + +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import reactor.core.publisher.Mono; import org.springframework.core.codec.Decoder; import org.springframework.core.codec.Encoder; @@ -27,15 +34,19 @@ import org.springframework.messaging.MessageDeliveryException; import org.springframework.messaging.handler.annotation.reactive.MessageMappingMessageHandler; import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; import org.springframework.util.RouteMatcher; import org.springframework.util.StringUtils; /** - * RSocket-specific extension of {@link MessageMappingMessageHandler}. - * - *

The configured {@link #setEncoders(List) encoders} are used to encode the - * return values from handler methods, with the help of - * {@link RSocketPayloadReturnValueHandler}. + * Extension of {@link MessageMappingMessageHandler} to use as an RSocket + * responder by handling incoming streams via {@code @MessageMapping} annotated + * methods. + *

Use {@link #clientAcceptor()} and {@link #serverAcceptor()} to obtain + * {@link io.rsocket.RSocketFactory.ClientRSocketFactory#acceptor(Function) client} or + * {@link io.rsocket.RSocketFactory.ServerRSocketFactory#acceptor(SocketAcceptor) server} + * side adapters. * * @author Rossen Stoyanchev * @since 5.2 @@ -47,6 +58,11 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { @Nullable private RSocketStrategies rsocketStrategies; + @Nullable + private MimeType defaultDataMimeType; + + private MimeType defaultMetadataMimeType = DefaultRSocketRequester.COMPOSITE_METADATA; + /** * Configure the encoders to use for encoding handler method return values. @@ -95,6 +111,27 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { return this.rsocketStrategies; } + /** + * Configure the default content type to use for data payloads if the + * {@code SETUP} frame did not specify one. + *

By default this is not set. + * @param mimeType the MimeType to use + */ + public void setDefaultDataMimeType(@Nullable MimeType mimeType) { + this.defaultDataMimeType = mimeType; + } + + /** + * Configure the default {@code MimeType} for payload data if the + * {@code SETUP} frame did not specify one. + *

By default this is set to {@code "message/x.rsocket.composite-metadata.v0"} + * @param mimeType the MimeType to use + */ + public void setDefaultMetadataMimeType(MimeType mimeType) { + Assert.notNull(mimeType, "'metadataMimeType' is required"); + this.defaultMetadataMimeType = mimeType; + } + @Override public void afterPropertiesSet() { @@ -124,4 +161,49 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { } } + /** + * Return an adapter for a + * {@link io.rsocket.RSocketFactory.ServerRSocketFactory#acceptor(SocketAcceptor) + * server acceptor}. The adapter implements a responding {@link RSocket} by + * wrapping {@code Payload} data and metadata as {@link Message} and + * delegating to this {@link RSocketMessageHandler} to handle and reply. + */ + public SocketAcceptor serverAcceptor() { + return (setupPayload, sendingRSocket) -> { + MessagingRSocket rsocket = createRSocket(setupPayload, sendingRSocket); + + // Allow handling of the ConnectionSetupPayload via @MessageMapping methods. + // However, if the handling is to make requests to the client, it's expected + // it will do so decoupled from the handling, e.g. via .subscribe(). + return rsocket.handleConnectionSetupPayload(setupPayload).then(Mono.just(rsocket)); + }; + } + + /** + * Return an adapter for a + * {@link io.rsocket.RSocketFactory.ClientRSocketFactory#acceptor(BiFunction) + * client acceptor}. The adapter implements a responding {@link RSocket} by + * wrapping {@code Payload} data and metadata as {@link Message} and + * delegating to this {@link RSocketMessageHandler} to handle and reply. + */ + public BiFunction clientAcceptor() { + return this::createRSocket; + } + + private MessagingRSocket createRSocket(ConnectionSetupPayload setupPayload, RSocket rsocket) { + String s = setupPayload.dataMimeType(); + MimeType dataMimeType = StringUtils.hasText(s) ? MimeTypeUtils.parseMimeType(s) : this.defaultDataMimeType; + Assert.notNull(dataMimeType, "No `dataMimeType` in ConnectionSetupPayload and no default value"); + + s = setupPayload.metadataMimeType(); + MimeType metaMimeType = StringUtils.hasText(s) ? MimeTypeUtils.parseMimeType(s) : this.defaultMetadataMimeType; + Assert.notNull(dataMimeType, "No `metadataMimeType` in ConnectionSetupPayload and no default value"); + + RSocketRequester requester = RSocketRequester.wrap( + rsocket, dataMimeType, metaMimeType, getRSocketStrategies()); + + return new MessagingRSocket(this, getRouteMatcher(), requester, + dataMimeType, metaMimeType, getRSocketStrategies().dataBufferFactory()); + } + } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java index 416f4d5fa9..231f4a2f44 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketBufferLeakTests.java @@ -87,7 +87,7 @@ public class RSocketBufferLeakTests { server = RSocketFactory.receive() .frameDecoder(PayloadDecoder.ZERO_COPY) .addServerPlugin(payloadInterceptor) // intercept responding - .acceptor(context.getBean(MessageHandlerAcceptor.class)) + .acceptor(context.getBean(RSocketMessageHandler.class).serverAcceptor()) .transport(TcpServerTransport.create("localhost", 7000)) .start() .block(); @@ -214,10 +214,10 @@ public class RSocketBufferLeakTests { } @Bean - public MessageHandlerAcceptor messageHandlerAcceptor() { - MessageHandlerAcceptor acceptor = new MessageHandlerAcceptor(); - acceptor.setRSocketStrategies(rsocketStrategies()); - return acceptor; + public RSocketMessageHandler messageHandler() { + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setRSocketStrategies(rsocketStrategies()); + return handler; } @Bean diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java index ef1bd615c8..219d3da2f3 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketClientToServerIntegrationTests.java @@ -67,7 +67,7 @@ public class RSocketClientToServerIntegrationTests { server = RSocketFactory.receive() .addServerPlugin(interceptor) .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(context.getBean(MessageHandlerAcceptor.class)) + .acceptor(context.getBean(RSocketMessageHandler.class).serverAcceptor()) .transport(TcpServerTransport.create("localhost", 7000)) .start() .block(); @@ -257,10 +257,10 @@ public class RSocketClientToServerIntegrationTests { } @Bean - public MessageHandlerAcceptor messageHandlerAcceptor() { - MessageHandlerAcceptor acceptor = new MessageHandlerAcceptor(); - acceptor.setRSocketStrategies(rsocketStrategies()); - return acceptor; + public RSocketMessageHandler messageHandler() { + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setRSocketStrategies(rsocketStrategies()); + return handler; } @Bean diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketServerToClientIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketServerToClientIntegrationTests.java index 6e07cf5752..60b1806711 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketServerToClientIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketServerToClientIntegrationTests.java @@ -65,7 +65,7 @@ public class RSocketServerToClientIntegrationTests { server = RSocketFactory.receive() .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(context.getBean("serverAcceptor", MessageHandlerAcceptor.class)) + .acceptor(context.getBean("serverMessageHandler", RSocketMessageHandler.class).serverAcceptor()) .transport(TcpServerTransport.create("localhost", 7000)) .start() .block(); @@ -110,7 +110,7 @@ public class RSocketServerToClientIntegrationTests { .dataMimeType("text/plain") .setupPayload(DefaultPayload.create("", destination)) .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(context.getBean("clientAcceptor", MessageHandlerAcceptor.class)) + .acceptor(context.getBean("clientMessageHandler", RSocketMessageHandler.class).clientAcceptor()) .transport(TcpClientTransport.create("localhost", 7000)) .start() .block(); @@ -260,17 +260,16 @@ public class RSocketServerToClientIntegrationTests { } @Bean - public MessageHandlerAcceptor clientAcceptor() { - MessageHandlerAcceptor acceptor = new MessageHandlerAcceptor(); - acceptor.setHandlers(Collections.singletonList(clientHandler())); - acceptor.setAutoDetectDisabled(); - acceptor.setRSocketStrategies(rsocketStrategies()); - return acceptor; + public RSocketMessageHandler clientMessageHandler() { + RSocketMessageHandler handler = new RSocketMessageHandler(); + handler.setHandlers(Collections.singletonList(clientHandler())); + handler.setRSocketStrategies(rsocketStrategies()); + return handler; } @Bean - public MessageHandlerAcceptor serverAcceptor() { - MessageHandlerAcceptor handler = new MessageHandlerAcceptor(); + public RSocketMessageHandler serverMessageHandler() { + RSocketMessageHandler handler = new RSocketMessageHandler(); handler.setRSocketStrategies(rsocketStrategies()); return handler; } diff --git a/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt b/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt index fcb4a15821..7d66833fa9 100644 --- a/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt +++ b/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/RSocketClientToServerCoroutinesIntegrationTests.kt @@ -16,8 +16,6 @@ package org.springframework.messaging.rsocket -import java.time.Duration - import io.netty.buffer.PooledByteBufAllocator import io.rsocket.RSocketFactory import io.rsocket.frame.decoder.PayloadDecoder @@ -31,9 +29,6 @@ import kotlinx.coroutines.flow.map import org.junit.AfterClass import org.junit.BeforeClass import org.junit.Test -import reactor.core.publisher.Flux -import reactor.test.StepVerifier - import org.springframework.context.annotation.AnnotationConfigApplicationContext import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Configuration @@ -43,6 +38,9 @@ import org.springframework.core.io.buffer.NettyDataBufferFactory import org.springframework.messaging.handler.annotation.MessageExceptionHandler import org.springframework.messaging.handler.annotation.MessageMapping import org.springframework.stereotype.Controller +import reactor.core.publisher.Flux +import reactor.test.StepVerifier +import java.time.Duration /** * Coroutines server-side handling of RSocket requests. @@ -167,10 +165,10 @@ class RSocketClientToServerCoroutinesIntegrationTests { } @Bean - open fun messageHandlerAcceptor(): MessageHandlerAcceptor { - val acceptor = MessageHandlerAcceptor() - acceptor.rSocketStrategies = rsocketStrategies() - return acceptor + open fun messageHandler(): RSocketMessageHandler { + val handler = RSocketMessageHandler() + handler.rSocketStrategies = rsocketStrategies() + return handler } @Bean @@ -202,7 +200,7 @@ class RSocketClientToServerCoroutinesIntegrationTests { server = RSocketFactory.receive() .addServerPlugin(interceptor) .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(context.getBean(MessageHandlerAcceptor::class.java)) + .acceptor(context.getBean(RSocketMessageHandler::class.java).serverAcceptor()) .transport(TcpServerTransport.create("localhost", 7000)) .start() .block()!!