Add support for RSocket interface client

See gh-24456
This commit is contained in:
rstoyanchev 2022-09-05 16:54:30 +01:00
parent ae861a2b3e
commit 8423b2cab7
22 changed files with 1956 additions and 93 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2015 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -51,7 +51,11 @@ public @interface Payload {
* <p>This attribute may or may not be supported depending on whether the message being
* handled contains a non-primitive Object as its payload or is in serialized form and
* requires message conversion.
* <p>When processing STOMP over WebSocket messages this attribute is not supported.
* <p>This attribute is not supported for:
* <ul>
* <li>STOMP over WebSocket messages</li>
* <li>RSocket interface client</li>
* </ul>
* @since 4.2
*/
@AliasFor("value")

View File

@ -102,6 +102,12 @@ final class DefaultRSocketRequester implements RSocketRequester {
return this.metadataMimeType;
}
@Override
public RSocketStrategies strategies() {
return this.strategies;
}
@Override
public RequestSpec route(String route, Object... vars) {
return new DefaultRequestSpec(route, vars);

View File

@ -83,6 +83,11 @@ public interface RSocketRequester extends Disposable {
*/
MimeType metadataMimeType();
/**
* Return the configured {@link RSocketStrategies}.
*/
RSocketStrategies strategies();
/**
* Begin to specify a new request with the given route to a remote handler.
* <p>The route can be a template with placeholders, e.g.

View File

@ -0,0 +1,65 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import java.util.Collection;
import org.springframework.core.MethodParameter;
import org.springframework.lang.Nullable;
import org.springframework.messaging.handler.annotation.DestinationVariable;
/**
* {@link RSocketServiceArgumentResolver} for a
* {@link DestinationVariable @DestinationVariable} annotated argument.
*
* <p>The argument is treated as a single route variable, or in case of a
* Collection or an array, as multiple route variables.
*
* @author Rossen Stoyanchev
* @since 6.0
*/
public class DestinationVariableArgumentResolver implements RSocketServiceArgumentResolver {
@Override
public boolean resolve(
@Nullable Object argument, MethodParameter parameter, RSocketRequestValues.Builder requestValues) {
DestinationVariable annot = parameter.getParameterAnnotation(DestinationVariable.class);
if (annot == null) {
return false;
}
if (argument != null) {
if (argument instanceof Collection) {
((Collection<?>) argument).forEach(requestValues::addRouteVariable);
return true;
}
else if (argument.getClass().isArray()) {
for (Object variable : (Object[]) argument) {
requestValues.addRouteVariable(variable);
}
return true;
}
else {
requestValues.addRouteVariable(argument);
}
}
return true;
}
}

View File

@ -0,0 +1,61 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import org.springframework.core.MethodParameter;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
/**
* {@link RSocketServiceArgumentResolver} for metadata entries.
*
* <p>Supports a sequence of an {@link Object} parameter for the metadata value,
* followed by a {@link MimeType} parameter for the metadata mime type.
*
* <p>This should be ordered last to give other, more specific resolvers a
* chance to resolve the argument.
*
* @author Rossen Stoyanchev
* @since 6.0
*/
public class MetadataArgumentResolver implements RSocketServiceArgumentResolver {
@Override
public boolean resolve(
@Nullable Object argument, MethodParameter parameter, RSocketRequestValues.Builder requestValues) {
int index = parameter.getParameterIndex();
Class<?>[] paramTypes = parameter.getExecutable().getParameterTypes();
if (parameter.getParameterType().equals(MimeType.class)) {
Assert.notNull(argument, "MimeType parameter is required");
Assert.state(index > 0, "MimeType parameter should have preceding metadata object parameter");
requestValues.addMimeType((MimeType) argument);
return true;
}
if (paramTypes.length > (index + 1) && MimeType.class.equals(paramTypes[index + 1])) {
Assert.notNull(argument, "MimeType parameter is required");
requestValues.addMetadata(argument);
return true;
}
return false;
}
}

View File

@ -0,0 +1,78 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.lang.Nullable;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.util.Assert;
/**
* {@link RSocketServiceArgumentResolver} for {@link Payload @Payload}
* annotated arguments.
*
* @author Rossen Stoyanchev
* @since 6.0
*/
public class PayloadArgumentResolver implements RSocketServiceArgumentResolver {
private final ReactiveAdapterRegistry reactiveAdapterRegistry;
private final boolean useDefaultResolution;
public PayloadArgumentResolver(ReactiveAdapterRegistry reactiveAdapterRegistry, boolean useDefaultResolution) {
this.useDefaultResolution = useDefaultResolution;
Assert.notNull(reactiveAdapterRegistry, "ReactiveAdapterRegistry is required");
this.reactiveAdapterRegistry = reactiveAdapterRegistry;
}
@Override
public boolean resolve(
@Nullable Object argument, MethodParameter parameter, RSocketRequestValues.Builder requestValues) {
Payload annot = parameter.getParameterAnnotation(Payload.class);
if (annot == null && !this.useDefaultResolution) {
return false;
}
if (argument != null) {
ReactiveAdapter reactiveAdapter = this.reactiveAdapterRegistry.getAdapter(parameter.getParameterType());
if (reactiveAdapter == null) {
requestValues.setPayloadValue(argument);
}
else {
MethodParameter nestedParameter = parameter.nested();
String message = "Async type for @Payload should produce value(s)";
Assert.isTrue(nestedParameter.getNestedParameterType() != Void.class, message);
Assert.isTrue(!reactiveAdapter.isNoValue(), message);
requestValues.setPayload(
reactiveAdapter.toPublisher(argument),
ParameterizedTypeReference.forType(nestedParameter.getNestedGenericParameterType()));
}
}
return true;
}
}

View File

@ -0,0 +1,79 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* Annotation to declare a method on an RSocket service interface as an RSocket
* endpoint. The endpoint route is defined statically through the annotation
* attributes, and through the input method argument types.
*
* <p>Supported at the type level to express common attributes, to be inherited
* by all methods, such as a base route.
*
* <p>Supported method arguments:
* <table border="1">
* <tr>
* <th>Method Argument</th>
* <th>Description</th>
* <th>Resolver</th>
* </tr>
* <tr>
* <td>{@link org.springframework.messaging.handler.annotation.DestinationVariable @DestinationVariable}</td>
* <td>Add a route variable to expand into the route</td>
* <td>{@link DestinationVariableArgumentResolver}</td>
* </tr>
* <tr>
* <td>{@link org.springframework.messaging.handler.annotation.Payload @Payload}</td>
* <td>Set the input payload(s) for the request</td>
* <td>{@link PayloadArgumentResolver}</td>
* </tr>
* <tr>
* <td>{@link Object} argument followed by {@link org.springframework.util.MimeType} argument</td>
* <td>Add a metadata value</td>
* <td>{@link MetadataArgumentResolver}</td>
* </tr>
* <tr>
* <td>{@link org.springframework.util.MimeType} argument preceded by {@link Object} argument</td>
* <td>Specify the mime type for the preceding metadata value</td>
* <td>{@link MetadataArgumentResolver}</td>
* </tr>
* </table>
*
* @author Rossen Stoyanchev
* @since 6.0
*/
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RSocketExchange {
/**
* Destination-based mapping expressed by this annotation. This is either
* {@link org.springframework.util.AntPathMatcher AntPathMatcher} or
* {@link org.springframework.web.util.pattern.PathPattern PathPattern}
* based pattern, depending on which is configured, matched to the route of
* the stream request.
*/
String value() default "";
}

View File

@ -0,0 +1,268 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.reactivestreams.Publisher;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
import org.springframework.util.StringUtils;
/**
* Container for RSocket request values extracted from an
* {@link RSocketExchange @RSocketExchange}-annotated
* method and argument values passed to it. This is then used to define a request
* via {@link org.springframework.messaging.rsocket.RSocketRequester}.
*
* @author Rossen Stoyanchev
* @since 6.0
*/
public final class RSocketRequestValues {
@Nullable
private final String route;
private final Object[] routeVariables;
private final Map<Object, MimeType> metadata;
@Nullable
private final Object payloadValue;
@Nullable
private final Publisher<?> payload;
@Nullable
private final ParameterizedTypeReference<?> payloadElementType;
public RSocketRequestValues(
@Nullable String route, @Nullable List<Object> routeVariables, @Nullable MetadataHelper metadataHelper,
@Nullable Object payloadValue, @Nullable Publisher<?> payload,
@Nullable ParameterizedTypeReference<?> payloadElementType) {
this.route = route;
this.routeVariables = (routeVariables != null ? routeVariables.toArray() : new Object[0]);
this.metadata = (metadataHelper != null ? metadataHelper.toMap() : Collections.emptyMap());
this.payloadValue = payloadValue;
this.payload = payload;
this.payloadElementType = payloadElementType;
}
/**
* Return the route value for
* {@link org.springframework.messaging.rsocket.RSocketRequester#route(String, Object...) route}.
*/
@Nullable
public String getRoute() {
return this.route;
}
/**
* Return the route variables for
* {@link org.springframework.messaging.rsocket.RSocketRequester#route(String, Object...) route}.
*/
public Object[] getRouteVariables() {
return this.routeVariables;
}
/**
* Return the metadata entries for
* {@link org.springframework.messaging.rsocket.RSocketRequester.RequestSpec#metadata(Object, MimeType)}.
*/
public Map<Object, MimeType> getMetadata() {
return this.metadata;
}
/**
* Return the request payload as a value to be serialized, if set.
* <p>This is mutually exclusive with {@link #getPayload()}.
* Only one of the two or neither is set.
*/
@Nullable
public Object getPayloadValue() {
return this.payloadValue;
}
/**
* Return the request payload as a Publisher.
* <p>This is mutually exclusive with {@link #getPayloadValue()}.
* Only one of the two or neither is set.
*/
@Nullable
public Publisher<?> getPayload() {
return this.payload;
}
/**
* Return the element type for a {@linkplain #getPayload() Publisher payload}.
*/
@Nullable
public ParameterizedTypeReference<?> getPayloadElementType() {
return this.payloadElementType;
}
public static Builder builder(@Nullable String route) {
return new Builder(route);
}
/**
* Builder for {@link RSocketRequestValues}.
*/
public final static class Builder {
@Nullable
private String route;
@Nullable
private List<Object> routeVariables;
@Nullable
private MetadataHelper metadataHelper;
@Nullable
private Object payloadValue;
@Nullable
private Publisher<?> payload;
@Nullable
private ParameterizedTypeReference<?> payloadElementType;
Builder(@Nullable String route) {
this.route = (StringUtils.hasText(route) ? route : null);
}
/**
* Set the route for the request.
*/
public Builder setRoute(String route) {
this.route = route;
this.routeVariables = null;
return this;
}
/**
* Add a route variable.
*/
public Builder addRouteVariable(Object variable) {
this.routeVariables = (this.routeVariables != null ? this.routeVariables : new ArrayList<>());
this.routeVariables.add(variable);
return this;
}
/**
* Add a metadata entry.
* This must be followed by a corresponding call to {@link #addMimeType(MimeType)}.
*/
public Builder addMetadata(Object metadata) {
this.metadataHelper = (this.metadataHelper != null ? this.metadataHelper : new MetadataHelper());
this.metadataHelper.addMetadata(metadata);
return this;
}
/**
* Set the mime type for a metadata entry.
* This must be preceded by a call to {@link #addMetadata(Object)}.
*/
public Builder addMimeType(MimeType mimeType) {
this.metadataHelper = (this.metadataHelper != null ? this.metadataHelper : new MetadataHelper());
this.metadataHelper.addMimeType(mimeType);
return this;
}
/**
* Set the request payload as a concrete value to be serialized.
* <p>This is mutually exclusive with, and resets any previously set
* {@linkplain #setPayload(Publisher, ParameterizedTypeReference) payload Publisher}.
*/
public Builder setPayloadValue(Object payloadValue) {
this.payloadValue = payloadValue;
this.payload = null;
this.payloadElementType = null;
return this;
}
/**
* Set the request payload value to be serialized.
*/
public <T, P extends Publisher<T>> Builder setPayload(P payload, ParameterizedTypeReference<T> elementTye) {
this.payload = payload;
this.payloadElementType = elementTye;
this.payloadValue = null;
return this;
}
/**
* Build the {@link RSocketRequestValues} instance.
*/
public RSocketRequestValues build() {
return new RSocketRequestValues(
this.route, this.routeVariables, this.metadataHelper,
this.payloadValue, this.payload, this.payloadElementType);
}
}
/**
* Class that helps to collect a map of metadata entries as a series of calls
* to provide each metadata and mime type pair.
*/
private static class MetadataHelper {
private final List<Object> metadata = new ArrayList<>();
private final List<MimeType> mimeTypes = new ArrayList<>();
public void addMetadata(Object metadata) {
Assert.isTrue(this.metadata.size() == this.mimeTypes.size(), "Invalid state: " + this);
this.metadata.add(metadata);
}
public void addMimeType(MimeType mimeType) {
Assert.isTrue(this.metadata.size() == (this.mimeTypes.size() + 1), "Invalid state: " + this);
this.mimeTypes.add(mimeType);
}
public Map<Object, MimeType> toMap() {
Map<Object, MimeType> map = new LinkedHashMap<>(this.metadata.size());
for (int i = 0; i < this.metadata.size(); i++) {
map.put(this.metadata.get(i), this.mimeTypes.get(i));
}
return map;
}
@Override
public String toString() {
return "metadata=" + this.metadata + ", mimeTypes=" + this.mimeTypes;
}
}
}

View File

@ -0,0 +1,40 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import org.springframework.core.MethodParameter;
import org.springframework.lang.Nullable;
/**
* Resolve an argument from an {@link RSocketExchange @RSocketExchange}-annotated
* method to one or more RSocket request values.
*
* @author Rossen Stoyanchev
* @since 6.0
*/
public interface RSocketServiceArgumentResolver {
/**
* Resolve the argument value.
* @param argument the argument value
* @param parameter the method parameter for the argument
* @param requestValues builder to add RSocket request values to
* @return {@code true} if the argument was resolved, {@code false} otherwise
*/
boolean resolve(@Nullable Object argument, MethodParameter parameter, RSocketRequestValues.Builder requestValues);
}

View File

@ -0,0 +1,242 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import java.lang.reflect.Method;
import java.time.Duration;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.SynthesizingMethodParameter;
import org.springframework.lang.Nullable;
import org.springframework.messaging.rsocket.RSocketRequester;
import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
import org.springframework.util.StringUtils;
import org.springframework.util.StringValueResolver;
/**
* Implements the invocation of an {@link RSocketExchange @RSocketExchange}-annotated,
* {@link RSocketServiceProxyFactory#createClient(Class) RSocket service proxy} method
* by delegating to an {@link RSocketRequester} to perform actual requests.
*
* @author Rossen Stoyanchev
* @since 6.0
*/
final class RSocketServiceMethod {
private final Method method;
private final MethodParameter[] parameters;
private final List<RSocketServiceArgumentResolver> argumentResolvers;
@Nullable
private final String route;
private final Function<RSocketRequestValues, Object> responseFunction;
RSocketServiceMethod(
Method method, Class<?> containingClass, List<RSocketServiceArgumentResolver> argumentResolvers,
RSocketRequester rsocketRequester, @Nullable StringValueResolver embeddedValueResolver,
ReactiveAdapterRegistry reactiveRegistry, Duration blockTimeout) {
this.method = method;
this.parameters = initMethodParameters(method);
this.argumentResolvers = argumentResolvers;
this.route = initRoute(method, containingClass, rsocketRequester.strategies(), embeddedValueResolver);
this.responseFunction = initResponseFunction(
rsocketRequester, method, reactiveRegistry, blockTimeout);
}
private static MethodParameter[] initMethodParameters(Method method) {
int count = method.getParameterCount();
if (count == 0) {
return new MethodParameter[0];
}
DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
MethodParameter[] parameters = new MethodParameter[count];
for (int i = 0; i < count; i++) {
parameters[i] = new SynthesizingMethodParameter(method, i);
parameters[i].initParameterNameDiscovery(nameDiscoverer);
}
return parameters;
}
@Nullable
private static String initRoute(
Method method, Class<?> containingClass, RSocketStrategies strategies,
@Nullable StringValueResolver embeddedValueResolver) {
RSocketExchange annot1 = AnnotatedElementUtils.findMergedAnnotation(containingClass, RSocketExchange.class);
RSocketExchange annot2 = AnnotatedElementUtils.findMergedAnnotation(method, RSocketExchange.class);
Assert.notNull(annot2, "Expected RSocketExchange annotation");
String route1 = (annot1 != null ? annot1.value() : null);
String route2 = annot2.value();
if (embeddedValueResolver != null) {
route1 = (route1 != null ? embeddedValueResolver.resolveStringValue(route1) : null);
route2 = embeddedValueResolver.resolveStringValue(route2);
}
boolean hasRoute1 = StringUtils.hasText(route1);
boolean hasRoute2 = StringUtils.hasText(route2);
if (hasRoute1 && hasRoute2) {
return strategies.routeMatcher().combine(route1, route2);
}
if (!hasRoute1 && !hasRoute2) {
return null;
}
return (hasRoute2 ? route2 : route1);
}
private static Function<RSocketRequestValues, Object> initResponseFunction(
RSocketRequester requester, Method method,
ReactiveAdapterRegistry reactiveRegistry, Duration blockTimeout) {
MethodParameter returnParam = new MethodParameter(method, -1);
Class<?> returnType = returnParam.getParameterType();
ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType);
MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional());
Class<?> actualType = actualParam.getNestedParameterType();
Function<RSocketRequestValues, Publisher<?>> responseFunction;
if (actualType.equals(void.class) || actualType.equals(Void.class) ||
(reactiveAdapter != null && reactiveAdapter.isNoValue())) {
responseFunction = values -> {
RSocketRequester.RetrieveSpec retrieveSpec = initRequest(requester, values);
return (values.getPayload() == null && values.getPayloadValue() == null ?
((RSocketRequester.RequestSpec) retrieveSpec).sendMetadata() : retrieveSpec.send());
};
}
else if (reactiveAdapter == null) {
responseFunction = values -> initRequest(requester, values).retrieveMono(actualType);
}
else {
ParameterizedTypeReference<?> payloadType =
ParameterizedTypeReference.forType(actualParam.getNestedGenericParameterType());
responseFunction = values -> (
reactiveAdapter.isMultiValue() ?
initRequest(requester, values).retrieveFlux(payloadType) :
initRequest(requester, values).retrieveMono(payloadType));
}
boolean blockForOptional = returnType.equals(Optional.class);
return responseFunction.andThen(responsePublisher -> {
if (reactiveAdapter != null) {
return reactiveAdapter.fromPublisher(responsePublisher);
}
return (blockForOptional ?
((Mono<?>) responsePublisher).blockOptional(blockTimeout) :
((Mono<?>) responsePublisher).block(blockTimeout));
});
}
@SuppressWarnings("ReactiveStreamsUnusedPublisher")
private static RSocketRequester.RetrieveSpec initRequest(
RSocketRequester requester, RSocketRequestValues requestValues) {
RSocketRequester.RequestSpec spec;
String route = requestValues.getRoute();
Map<Object, MimeType> metadata = requestValues.getMetadata();
if (StringUtils.hasText(route)) {
spec = requester.route(route, requestValues.getRouteVariables());
for (Map.Entry<Object, MimeType> entry : metadata.entrySet()) {
spec.metadata(entry.getKey(), entry.getValue());
}
}
else {
Iterator<Map.Entry<Object, MimeType>> iterator = metadata.entrySet().iterator();
Assert.isTrue(iterator.hasNext(), "Neither route nor metadata provided");
Map.Entry<Object, MimeType> entry = iterator.next();
spec = requester.metadata(entry.getKey(), entry.getValue());
while (iterator.hasNext()) {
spec.metadata(entry.getKey(), entry.getValue());
}
}
if (requestValues.getPayloadValue() != null) {
spec.data(requestValues.getPayloadValue());
}
else if (requestValues.getPayload() != null) {
Assert.notNull(requestValues.getPayloadElementType(), "Publisher body element type is required");
spec.data(requestValues.getPayload(), requestValues.getPayloadElementType());
}
return spec;
}
public Method getMethod() {
return this.method;
}
@Nullable
public Object invoke(Object[] arguments) {
RSocketRequestValues.Builder requestValues = RSocketRequestValues.builder(this.route);
applyArguments(requestValues, arguments);
return this.responseFunction.apply(requestValues.build());
}
private void applyArguments(RSocketRequestValues.Builder requestValues, Object[] arguments) {
Assert.isTrue(arguments.length == this.parameters.length, "Method argument mismatch");
for (int i = 0; i < arguments.length; i++) {
Object value = arguments[i];
boolean resolved = false;
for (RSocketServiceArgumentResolver resolver : this.argumentResolvers) {
if (resolver.resolve(value, this.parameters[i], requestValues)) {
resolved = true;
break;
}
}
Assert.state(resolved, formatArgumentError(this.parameters[i], "No suitable resolver"));
}
}
@SuppressWarnings("SameParameterValue")
private static String formatArgumentError(MethodParameter param, String message) {
return "Could not resolve parameter [" + param.getParameterIndex() + "] in " +
param.getExecutable().toGenericString() + (StringUtils.hasText(message) ? ": " + message : "");
}
}

View File

@ -0,0 +1,217 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.springframework.aop.framework.ProxyFactory;
import org.springframework.aop.framework.ReflectiveMethodInvocation;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.EmbeddedValueResolverAware;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.lang.Nullable;
import org.springframework.messaging.rsocket.RSocketRequester;
import org.springframework.util.Assert;
import org.springframework.util.StringValueResolver;
/**
* Factory for creating a client proxy given an RSocket service interface with
* {@link RSocketExchange @RSocketExchange} methods.
*
* <p>This class is intended to be declared as a bean in Spring configuration.
*
* @author Rossen Stoyanchev
* @since 6.0
*/
public final class RSocketServiceProxyFactory implements InitializingBean, EmbeddedValueResolverAware {
private final RSocketRequester rsocketRequester;
@Nullable
private List<RSocketServiceArgumentResolver> customArgumentResolvers;
@Nullable
private List<RSocketServiceArgumentResolver> argumentResolvers;
@Nullable
private StringValueResolver embeddedValueResolver;
private ReactiveAdapterRegistry reactiveAdapterRegistry = ReactiveAdapterRegistry.getSharedInstance();
private Duration blockTimeout = Duration.ofSeconds(5);
/**
* Create an instance with the underlying RSocketRequester to perform requests with.
* @param rsocketRequester the requester to use
*/
public RSocketServiceProxyFactory(RSocketRequester rsocketRequester) {
Assert.notNull(rsocketRequester, "RSocketRequester is required");
this.rsocketRequester = rsocketRequester;
}
/**
* Register a custom argument resolver, invoked ahead of default resolvers.
* @param resolver the resolver to add
*/
public void addCustomArgumentResolver(RSocketServiceArgumentResolver resolver) {
if (this.customArgumentResolvers == null) {
this.customArgumentResolvers = new ArrayList<>();
}
this.customArgumentResolvers.add(resolver);
}
/**
* Set the custom argument resolvers to use, ahead of default resolvers.
* @param resolvers the resolvers to use
*/
public void setCustomArgumentResolvers(List<RSocketServiceArgumentResolver> resolvers) {
this.customArgumentResolvers = new ArrayList<>(resolvers);
}
/**
* Set the StringValueResolver to use for resolving placeholders and
* expressions in {@link RSocketExchange#value()}.
* @param resolver the resolver to use
*/
@Override
public void setEmbeddedValueResolver(StringValueResolver resolver) {
this.embeddedValueResolver = resolver;
}
/**
* Set the {@link ReactiveAdapterRegistry} to use to support different
* asynchronous types for RSocket service method return values.
* <p>By default this is {@link ReactiveAdapterRegistry#getSharedInstance()}.
*/
public void setReactiveAdapterRegistry(ReactiveAdapterRegistry registry) {
this.reactiveAdapterRegistry = registry;
}
/**
* Configure how long to wait for a response for an RSocket service method
* with a synchronous (blocking) method signature.
* <p>By default this is 5 seconds.
* @param blockTimeout the timeout value
*/
public void setBlockTimeout(Duration blockTimeout) {
this.blockTimeout = blockTimeout;
}
@Override
public void afterPropertiesSet() throws Exception {
this.argumentResolvers = initArgumentResolvers();
}
private List<RSocketServiceArgumentResolver> initArgumentResolvers() {
List<RSocketServiceArgumentResolver> resolvers = new ArrayList<>();
// Custom
if (this.customArgumentResolvers != null) {
resolvers.addAll(this.customArgumentResolvers);
}
// Annotation-based
resolvers.add(new PayloadArgumentResolver(this.reactiveAdapterRegistry, false));
resolvers.add(new DestinationVariableArgumentResolver());
// Type-based
resolvers.add(new MetadataArgumentResolver());
// Fallback
resolvers.add(new PayloadArgumentResolver(this.reactiveAdapterRegistry, true));
return resolvers;
}
/**
* Return a proxy that implements the given RSocket service interface to
* perform RSocket requests and retrieve responses through the configured
* {@link RSocketRequester}.
* @param serviceType the RSocket service to create a proxy for
* @param <S> the RSocket service type
* @return the created proxy
*/
public <S> S createClient(Class<S> serviceType) {
List<RSocketServiceMethod> serviceMethods =
MethodIntrospector.selectMethods(serviceType, this::isExchangeMethod).stream()
.map(method -> createRSocketServiceMethod(serviceType, method))
.toList();
return ProxyFactory.getProxy(serviceType, new ServiceMethodInterceptor(serviceMethods));
}
private boolean isExchangeMethod(Method method) {
return AnnotatedElementUtils.hasAnnotation(method, RSocketExchange.class);
}
private <S> RSocketServiceMethod createRSocketServiceMethod(Class<S> serviceType, Method method) {
Assert.notNull(this.argumentResolvers,
"No argument resolvers: afterPropertiesSet was not called");
return new RSocketServiceMethod(
method, serviceType, this.argumentResolvers, this.rsocketRequester,
this.embeddedValueResolver, this.reactiveAdapterRegistry, this.blockTimeout);
}
/**
* {@link MethodInterceptor} that invokes an {@link RSocketServiceMethod}.
*/
private static final class ServiceMethodInterceptor implements MethodInterceptor {
private final Map<Method, RSocketServiceMethod> serviceMethods;
private ServiceMethodInterceptor(List<RSocketServiceMethod> methods) {
this.serviceMethods = methods.stream()
.collect(Collectors.toMap(RSocketServiceMethod::getMethod, Function.identity()));
}
@Override
public Object invoke(MethodInvocation invocation) throws Throwable {
Method method = invocation.getMethod();
RSocketServiceMethod serviceMethod = this.serviceMethods.get(method);
if (serviceMethod != null) {
return serviceMethod.invoke(invocation.getArguments());
}
if (method.isDefault()) {
if (invocation instanceof ReflectiveMethodInvocation reflectiveMethodInvocation) {
Object proxy = reflectiveMethodInvocation.getProxy();
return InvocationHandler.invokeDefault(proxy, method, invocation.getArguments());
}
}
throw new IllegalStateException("Unexpected method invocation: " + method);
}
}
}

View File

@ -0,0 +1,11 @@
/**
* Annotations to declare an RSocket service contract with request methods along
* with a proxy factory backed by an
* {@link org.springframework.messaging.rsocket.RSocketRequester}.
*/
@NonNullApi
@NonNullFields
package org.springframework.messaging.rsocket.service;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@ -28,17 +28,14 @@ import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.Single;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.metadata.WellKnownMimeType;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.lang.Nullable;
import org.springframework.messaging.rsocket.RSocketRequester.RequestSpec;
import org.springframework.messaging.rsocket.RSocketRequester.RetrieveSpec;
import org.springframework.util.MimeType;
@ -246,73 +243,4 @@ public class DefaultRSocketRequesterTests {
return PayloadUtils.createPayload(DefaultDataBufferFactory.sharedInstance.wrap(bytes));
}
private static class TestRSocket implements RSocket {
private Mono<Payload> payloadMonoToReturn = Mono.empty();
private Flux<Payload> payloadFluxToReturn = Flux.empty();
@Nullable private volatile String savedMethodName;
@Nullable private volatile Payload savedPayload;
@Nullable private volatile Flux<Payload> savedPayloadFlux;
void setPayloadMonoToReturn(Mono<Payload> payloadMonoToReturn) {
this.payloadMonoToReturn = payloadMonoToReturn;
}
void setPayloadFluxToReturn(Flux<Payload> payloadFluxToReturn) {
this.payloadFluxToReturn = payloadFluxToReturn;
}
@Nullable
String getSavedMethodName() {
return this.savedMethodName;
}
@Nullable
Payload getSavedPayload() {
return this.savedPayload;
}
@Nullable
Flux<Payload> getSavedPayloadFlux() {
return this.savedPayloadFlux;
}
public void reset() {
this.savedMethodName = null;
this.savedPayload = null;
this.savedPayloadFlux = null;
}
@Override
public Mono<Void> fireAndForget(Payload payload) {
this.savedMethodName = "fireAndForget";
this.savedPayload = payload;
return Mono.empty();
}
@Override
public Mono<Payload> requestResponse(Payload payload) {
this.savedMethodName = "requestResponse";
this.savedPayload = payload;
return this.payloadMonoToReturn;
}
@Override
public Flux<Payload> requestStream(Payload payload) {
this.savedMethodName = "requestStream";
this.savedPayload = payload;
return this.payloadFluxToReturn;
}
@Override
public Flux<Payload> requestChannel(Publisher<Payload> publisher) {
this.savedMethodName = "requestChannel";
this.savedPayloadFlux = Flux.from(publisher);
return this.payloadFluxToReturn;
}
}
}

View File

@ -0,0 +1,101 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.lang.Nullable;
/**
* {@link RSocket} that saves the name of the invoked method and the input payload(s).
*/
public class TestRSocket implements RSocket {
private Mono<Payload> payloadMonoToReturn = Mono.empty();
private Flux<Payload> payloadFluxToReturn = Flux.empty();
@Nullable private volatile String savedMethodName;
@Nullable private volatile Payload savedPayload;
@Nullable private volatile Flux<Payload> savedPayloadFlux;
public void setPayloadMonoToReturn(Mono<Payload> payloadMonoToReturn) {
this.payloadMonoToReturn = payloadMonoToReturn;
}
public void setPayloadFluxToReturn(Flux<Payload> payloadFluxToReturn) {
this.payloadFluxToReturn = payloadFluxToReturn;
}
@Nullable
public String getSavedMethodName() {
return this.savedMethodName;
}
@Nullable
public Payload getSavedPayload() {
return this.savedPayload;
}
@Nullable
public Flux<Payload> getSavedPayloadFlux() {
return this.savedPayloadFlux;
}
public void reset() {
this.savedMethodName = null;
this.savedPayload = null;
this.savedPayloadFlux = null;
}
@Override
public Mono<Void> fireAndForget(Payload payload) {
this.savedMethodName = "fireAndForget";
this.savedPayload = payload;
return Mono.empty();
}
@Override
public Mono<Payload> requestResponse(Payload payload) {
this.savedMethodName = "requestResponse";
this.savedPayload = payload;
return this.payloadMonoToReturn;
}
@Override
public Flux<Payload> requestStream(Payload payload) {
this.savedMethodName = "requestStream";
this.savedPayload = payload;
return this.payloadFluxToReturn;
}
@Override
public Flux<Payload> requestChannel(Publisher<Payload> publisher) {
this.savedMethodName = "requestChannel";
this.savedPayloadFlux = Flux.from(publisher);
return this.payloadFluxToReturn;
}
}

View File

@ -0,0 +1,99 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.handler.annotation.DestinationVariable;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Unit tests for {@link DestinationVariableArgumentResolver}.
* @author Rossen Stoyanchev
*/
public class DestinationVariableArgumentResolverTests extends RSocketServiceArgumentResolverTestSupport {
@Override
protected RSocketServiceArgumentResolver initResolver() {
return new DestinationVariableArgumentResolver();
}
@Test
void variable() {
String value = "foo";
boolean resolved = execute(value, initMethodParameter(Service.class, "execute", 0));
assertThat(resolved).isTrue();
assertThat(getRequestValues().getRouteVariables()).containsExactly(value);
}
@Test
void variableList() {
List<String> values = Arrays.asList("foo", "bar", "baz");
boolean resolved = execute(values, initMethodParameter(Service.class, "execute", 0));
assertThat(resolved).isTrue();
assertThat(getRequestValues().getRouteVariables()).containsExactlyElementsOf(values);
}
@Test
void variableArray() {
String[] values = new String[] {"foo", "bar", "baz"};
boolean resolved = execute(values, initMethodParameter(Service.class, "execute", 0));
assertThat(resolved).isTrue();
assertThat(getRequestValues().getRouteVariables()).containsExactlyElementsOf(Arrays.asList(values));
}
@Test
void notRequestBody() {
MethodParameter parameter = initMethodParameter(Service.class, "executeNotAnnotated", 0);
boolean resolved = execute("value", parameter);
assertThat(resolved).isFalse();
}
@Test
void ignoreNull() {
boolean resolved = execute(null, initMethodParameter(Service.class, "execute", 0));
assertThat(resolved).isTrue();
assertThat(getRequestValues().getPayloadValue()).isNull();
assertThat(getRequestValues().getPayload()).isNull();
assertThat(getRequestValues().getPayloadElementType()).isNull();
}
@SuppressWarnings("unused")
private interface Service {
void execute(@DestinationVariable String variable);
void executeList(@DestinationVariable List<String> variables);
void executeArray(@DestinationVariable String[] variables);
void executeNotAnnotated(String variable);
}
}

View File

@ -0,0 +1,71 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import java.util.LinkedHashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.springframework.core.MethodParameter;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Unit tests for {@link MetadataArgumentResolver}.
* @author Rossen Stoyanchev
*/
public class MetadataArgumentResolverTests extends RSocketServiceArgumentResolverTestSupport {
@Override
protected RSocketServiceArgumentResolver initResolver() {
return new MetadataArgumentResolver();
}
@Test
void metadata() {
MethodParameter param1 = initMethodParameter(Service.class, "execute", 0);
MethodParameter param2 = initMethodParameter(Service.class, "execute", 1);
MethodParameter param3 = initMethodParameter(Service.class, "execute", 2);
MethodParameter param4 = initMethodParameter(Service.class, "execute", 3);
assertThat(execute("foo", param1)).isTrue();
assertThat(execute(MimeTypeUtils.APPLICATION_JSON, param2)).isTrue();
assertThat(execute("bar", param3)).isTrue();
assertThat(execute(MimeTypeUtils.APPLICATION_XML, param4)).isTrue();
Map<Object, MimeType> expected = new LinkedHashMap<>();
expected.put("foo", MimeTypeUtils.APPLICATION_JSON);
expected.put("bar", MimeTypeUtils.APPLICATION_XML);
assertThat(getRequestValues().getMetadata()).containsExactlyEntriesOf(expected);
}
@SuppressWarnings("unused")
private interface Service {
void execute(String metadata1, MimeType mimeType1, String metadata2, MimeType mimeType2);
void executeNotAnnotated(String foo, String bar);
}
}

View File

@ -0,0 +1,129 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Single;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.messaging.handler.annotation.Payload;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
/**
* Unit tests for {@link PayloadArgumentResolver}.
* @author Rossen Stoyanchev
*/
public class PayloadArgumentResolverTests extends RSocketServiceArgumentResolverTestSupport {
@Override
protected RSocketServiceArgumentResolver initResolver() {
return new PayloadArgumentResolver(ReactiveAdapterRegistry.getSharedInstance(), false);
}
@Test
void stringPayload() {
String payload = "payloadValue";
boolean resolved = execute(payload, initMethodParameter(Service.class, "execute", 0));
assertThat(resolved).isTrue();
assertThat(getRequestValues().getPayloadValue()).isEqualTo(payload);
assertThat(getRequestValues().getPayload()).isNull();
}
@Test
void monoPayload() {
Mono<String> payloadMono = Mono.just("payloadValue");
boolean resolved = execute(payloadMono, initMethodParameter(Service.class, "executeMono", 0));
assertThat(resolved).isTrue();
assertThat(getRequestValues().getPayloadValue()).isNull();
assertThat(getRequestValues().getPayload()).isSameAs(payloadMono);
assertThat(getRequestValues().getPayloadElementType()).isEqualTo(new ParameterizedTypeReference<String>() {});
}
@Test
@SuppressWarnings("unchecked")
void singlePayload() {
boolean resolved = execute(Single.just("bodyValue"), initMethodParameter(Service.class, "executeSingle", 0));
assertThat(resolved).isTrue();
assertThat(getRequestValues().getPayloadValue()).isNull();
assertThat(getRequestValues().getPayloadElementType()).isEqualTo(new ParameterizedTypeReference<String>() {});
Publisher<?> payload = getRequestValues().getPayload();
assertThat(payload).isNotNull();
assertThat(((Mono<String>) payload).block()).isEqualTo("bodyValue");
}
@Test
void monoVoid() {
assertThatIllegalArgumentException()
.isThrownBy(() -> execute(Mono.empty(), initMethodParameter(Service.class, "executeMonoVoid", 0)))
.withMessage("Async type for @Payload should produce value(s)");
}
@Test
void completable() {
assertThatIllegalArgumentException()
.isThrownBy(() -> execute(Completable.complete(), initMethodParameter(Service.class, "executeCompletable", 0)))
.withMessage("Async type for @Payload should produce value(s)");
}
@Test
void notRequestBody() {
MethodParameter parameter = initMethodParameter(Service.class, "executeNotAnnotated", 0);
boolean resolved = execute("value", parameter);
assertThat(resolved).isFalse();
}
@Test
void ignoreNull() {
boolean resolved = execute(null, initMethodParameter(Service.class, "execute", 0));
assertThat(resolved).isTrue();
assertThat(getRequestValues().getPayloadValue()).isNull();
assertThat(getRequestValues().getPayload()).isNull();
assertThat(getRequestValues().getPayloadElementType()).isNull();
}
@SuppressWarnings("unused")
private interface Service {
void execute(@Payload String body);
void executeMono(@Payload Mono<String> body);
void executeSingle(@Payload Single<String> body);
void executeMonoVoid(@Payload Mono<Void> body);
void executeCompletable(@Payload Completable body);
void executeNotAnnotated(String body);
}
}

View File

@ -0,0 +1,95 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.util.MimeTypeUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
/**
* Unit tests for {@link RSocketRequestValues}.
* @author Rossen Stoyanchev
*/
public class RSocketRequestValuesTests {
@Test
void route() {
String myRoute = "myRoute";
RSocketRequestValues values = RSocketRequestValues.builder(myRoute).build();
assertThat(values.getRoute()).isEqualTo(myRoute);
}
@Test
void routeOverride() {
RSocketRequestValues values = RSocketRequestValues.builder("route1").setRoute("route2").build();
assertThat(values.getRoute()).isEqualTo("route2");
}
@Test
void payloadValue() {
String payload = "myValue";
RSocketRequestValues values = RSocketRequestValues.builder(null).setPayloadValue(payload).build();
assertThat(values.getPayloadValue()).isEqualTo(payload);
assertThat(values.getPayload()).isNull();
}
@Test
void payloadPublisher() {
Mono<String> payloadMono = Mono.just( "myValue");
RSocketRequestValues values = RSocketRequestValues.builder(null)
.setPayload(payloadMono, new ParameterizedTypeReference<>() { })
.build();
assertThat(values.getPayloadValue()).isNull();
assertThat(values.getPayload()).isSameAs(payloadMono);
}
@Test
void metadata() {
RSocketRequestValues values = RSocketRequestValues.builder(null)
.addMetadata("myMetadata1").addMimeType(MimeTypeUtils.TEXT_PLAIN)
.addMetadata("myMetadata2").addMimeType(MimeTypeUtils.TEXT_HTML)
.build();
assertThat(values.getMetadata())
.hasSize(2)
.containsEntry("myMetadata1", MimeTypeUtils.TEXT_PLAIN)
.containsEntry("myMetadata2", MimeTypeUtils.TEXT_HTML);
}
@Test
void metadataInvalidEntry() {
// MimeType without metadata
assertThatIllegalArgumentException()
.isThrownBy(() -> RSocketRequestValues.builder(null).addMimeType(MimeTypeUtils.TEXT_PLAIN));
// Metadata without MimeType
assertThatIllegalArgumentException()
.isThrownBy(() -> RSocketRequestValues.builder(null)
.addMetadata("metadata1")
.addMetadata("metadata2"));
}
}

View File

@ -0,0 +1,60 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import java.lang.reflect.Method;
import org.springframework.core.MethodParameter;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;
/**
* Base class for {@link RSocketServiceArgumentResolver} test fixtures.
* @author Rossen Stoyanchev
*/
public abstract class RSocketServiceArgumentResolverTestSupport {
@Nullable
private RSocketServiceArgumentResolver resolver;
private final RSocketRequestValues.Builder requestValuesBuilder = RSocketRequestValues.builder(null);
@Nullable
private RSocketRequestValues requestValues;
protected RSocketServiceArgumentResolverTestSupport() {
this.resolver = initResolver();
}
protected abstract RSocketServiceArgumentResolver initResolver();
protected static MethodParameter initMethodParameter(Class<?> serviceClass, String methodName, int index) {
Method method = ClassUtils.getMethod(serviceClass, methodName, (Class<?>[]) null);
return new MethodParameter(method, index);
}
protected boolean execute(Object payload, MethodParameter parameter) {
return this.resolver.resolve(payload, parameter, this.requestValuesBuilder);
}
protected RSocketRequestValues getRequestValues() {
this.requestValues = (this.requestValues != null ? this.requestValues : this.requestValuesBuilder.build());
return this.requestValues;
}
}

View File

@ -0,0 +1,162 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import java.time.Duration;
import io.rsocket.SocketAcceptor;
import io.rsocket.core.RSocketServer;
import io.rsocket.metadata.WellKnownMimeType;
import io.rsocket.transport.netty.server.CloseableChannel;
import io.rsocket.transport.netty.server.TcpServerTransport;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.rsocket.RSocketRequester;
import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
import org.springframework.stereotype.Controller;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
/**
* Integration tests with RSocket Service client.
*
* @author Rossen Stoyanchev
*/
public class RSocketServiceIntegrationTests {
private static CloseableChannel server;
private static RSocketRequester requester;
private static Service serviceProxy;
@BeforeAll
@SuppressWarnings("ConstantConditions")
public static void setupOnce() throws Exception {
MimeType metadataMimeType = MimeTypeUtils.parseMimeType(
WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString());
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(ServerConfig.class);
RSocketMessageHandler messageHandler = context.getBean(RSocketMessageHandler.class);
SocketAcceptor responder = messageHandler.responder();
server = RSocketServer.create(responder)
.bind(TcpServerTransport.create("localhost", 7000))
.block();
requester = RSocketRequester.builder()
.metadataMimeType(metadataMimeType)
.rsocketStrategies(context.getBean(RSocketStrategies.class))
.tcp("localhost", 7000);
RSocketServiceProxyFactory proxyFactory = new RSocketServiceProxyFactory(requester);
proxyFactory.afterPropertiesSet();
serviceProxy = proxyFactory.createClient(Service.class);
}
@AfterAll
public static void tearDownOnce() {
requester.rsocketClient().dispose();
server.dispose();
}
@Test
public void echoAsync() {
Flux<String> result = Flux.range(1, 3).concatMap(i -> serviceProxy.echoAsync("Hello " + i));
StepVerifier.create(result)
.expectNext("Hello 1 async").expectNext("Hello 2 async").expectNext("Hello 3 async")
.expectComplete()
.verify(Duration.ofSeconds(5));
}
@Test
public void echoStream() {
Flux<String> result = serviceProxy.echoStream("Hello");
StepVerifier.create(result)
.expectNext("Hello 0").expectNextCount(6).expectNext("Hello 7")
.thenCancel()
.verify(Duration.ofSeconds(5));
}
@Controller
interface Service {
@RSocketExchange("echo-async")
Mono<String> echoAsync(String payload);
@RSocketExchange("echo-stream")
Flux<String> echoStream(String payload);
}
@Controller
static class ServerController {
@MessageMapping("echo-async")
Mono<String> echoAsync(String payload) {
return Mono.delay(Duration.ofMillis(10)).map(aLong -> payload + " async");
}
@MessageMapping("echo-stream")
Flux<String> echoStream(String payload) {
return Flux.interval(Duration.ofMillis(10)).map(aLong -> payload + " " + aLong);
}
}
@Configuration
static class ServerConfig {
@Bean
public ServerController controller() {
return new ServerController();
}
@Bean
public RSocketMessageHandler messageHandler(RSocketStrategies rsocketStrategies) {
RSocketMessageHandler handler = new RSocketMessageHandler();
handler.setRSocketStrategies(rsocketStrategies);
return handler;
}
@Bean
public RSocketStrategies rsocketStrategies() {
return RSocketStrategies.create();
}
}
}

View File

@ -0,0 +1,150 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.rsocket.service;
import java.time.Duration;
import java.util.List;
import io.rsocket.util.DefaultPayload;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.messaging.rsocket.RSocketRequester;
import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.messaging.rsocket.TestRSocket;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.util.MimeTypeUtils.TEXT_PLAIN;
/**
* Tests for {@link RSocketServiceMethod} that create an interface client with
* an {@link RSocketRequester} delegating to a {@link TestRSocket}.
*
* @author Rossen Stoyanchev
*/
public class RSocketServiceMethodTests {
private TestRSocket rsocket;
private RSocketServiceProxyFactory proxyFactory;
@BeforeEach
public void setUp() throws Exception {
this.rsocket = new TestRSocket();
RSocketRequester requester = RSocketRequester.wrap(this.rsocket, TEXT_PLAIN, TEXT_PLAIN, RSocketStrategies.create());
this.proxyFactory = new RSocketServiceProxyFactory(requester);
this.proxyFactory.afterPropertiesSet();
}
@Test
void fireAndForget() {
ReactorService service = this.proxyFactory.createClient(ReactorService.class);
String payload = "p1";
service.fireAndForget(Mono.just(payload)).block(Duration.ofSeconds(5));
assertThat(this.rsocket.getSavedMethodName()).isEqualTo("fireAndForget");
assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("ff");
assertThat(this.rsocket.getSavedPayload().getDataUtf8()).isEqualTo(payload);
}
@Test
void requestResponse() {
ReactorService service = this.proxyFactory.createClient(ReactorService.class);
String payload1 = "p1";
String payload2 = "p2";
this.rsocket.setPayloadMonoToReturn(
Mono.just(DefaultPayload.create(payload2)));
String response = service.requestResponse(Mono.just(payload1)).block(Duration.ofSeconds(5));
assertThat(response).isEqualTo(payload2);
assertThat(this.rsocket.getSavedMethodName()).isEqualTo("requestResponse");
assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("rr");
assertThat(this.rsocket.getSavedPayload().getDataUtf8()).isEqualTo(payload1);
}
@Test
void requestStream() {
ReactorService service = this.proxyFactory.createClient(ReactorService.class);
String payload1 = "p1";
String payload2 = "p2";
String payload3 = "p3";
this.rsocket.setPayloadFluxToReturn(
Flux.just(DefaultPayload.create(payload2), DefaultPayload.create(payload3)));
List<String> response = service.requestStream(Mono.just(payload1))
.collectList()
.block(Duration.ofSeconds(5));
assertThat(response).containsExactly(payload2, payload3);
assertThat(this.rsocket.getSavedMethodName()).isEqualTo("requestStream");
assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("rs");
assertThat(this.rsocket.getSavedPayload().getDataUtf8()).isEqualTo(payload1);
}
@Test
void requestChannel() {
ReactorService service = this.proxyFactory.createClient(ReactorService.class);
String payload1 = "p1";
String payload2 = "p2";
String payload3 = "p3";
String payload4 = "p4";
this.rsocket.setPayloadFluxToReturn(
Flux.just(DefaultPayload.create(payload3), DefaultPayload.create(payload4)));
List<String> response = service.requestChannel(Flux.just(payload1, payload2))
.collectList()
.block(Duration.ofSeconds(5));
assertThat(response).containsExactly(payload3, payload4);
assertThat(this.rsocket.getSavedMethodName()).isEqualTo("requestChannel");
List<String> savedPayloads = this.rsocket.getSavedPayloadFlux()
.map(io.rsocket.Payload::getDataUtf8)
.collectList()
.block(Duration.ofSeconds(5));
assertThat(savedPayloads).containsExactly("p1", "p2");
}
private interface ReactorService {
@RSocketExchange("ff")
Mono<Void> fireAndForget(@Payload Mono<String> input);
@RSocketExchange("rr")
Mono<String> requestResponse(@Payload Mono<String> input);
@RSocketExchange("rs")
Flux<String> requestStream(@Payload Mono<String> input);
@RSocketExchange("rc")
Flux<String> requestChannel(@Payload Flux<String> input);
}
}

View File

@ -55,31 +55,23 @@ public class RequestBodyArgumentResolver implements HttpServiceArgumentResolver
if (argument != null) {
ReactiveAdapter reactiveAdapter = this.reactiveAdapterRegistry.getAdapter(parameter.getParameterType());
if (reactiveAdapter != null) {
setBody(argument, parameter, reactiveAdapter, requestValues);
if (reactiveAdapter == null) {
requestValues.setBodyValue(argument);
}
else {
requestValues.setBodyValue(argument);
MethodParameter nestedParameter = parameter.nested();
String message = "Async type for @RequestBody should produce value(s)";
Assert.isTrue(!reactiveAdapter.isNoValue(), message);
Assert.isTrue(nestedParameter.getNestedParameterType() != Void.class, message);
requestValues.setBody(
reactiveAdapter.toPublisher(argument),
ParameterizedTypeReference.forType(nestedParameter.getNestedGenericParameterType()));
}
}
return true;
}
private <E> void setBody(
Object argument, MethodParameter parameter, ReactiveAdapter reactiveAdapter,
HttpRequestValues.Builder requestValues) {
String message = "Async type for @RequestBody should produce value(s)";
Assert.isTrue(!reactiveAdapter.isNoValue(), message);
parameter = parameter.nested();
Class<?> elementClass = parameter.getNestedParameterType();
Assert.isTrue(elementClass != Void.class, message);
ParameterizedTypeReference<E> typeRef = ParameterizedTypeReference.forType(parameter.getNestedGenericParameterType());
Publisher<E> publisher = reactiveAdapter.toPublisher(argument);
requestValues.setBody(publisher, typeRef);
}
}