From 2bc93710f32a91cc1be9850625d73c80e118b1ce Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Wed, 6 Sep 2023 12:53:52 +0100 Subject: [PATCH] Reactive support in MethodValidationInterceptor Closes gh-20781 --- spring-context/spring-context.gradle | 1 + .../MethodValidationAdapter.java | 7 + .../MethodValidationInterceptor.java | 98 ++++++++++++++ .../MethodValidationProxyReactorTests.java | 121 ++++++++++++++++++ 4 files changed, 227 insertions(+) create mode 100644 spring-context/src/test/java/org/springframework/validation/beanvalidation/MethodValidationProxyReactorTests.java diff --git a/spring-context/spring-context.gradle b/spring-context/spring-context.gradle index 4d09e07664..1efab5bc24 100644 --- a/spring-context/spring-context.gradle +++ b/spring-context/spring-context.gradle @@ -46,6 +46,7 @@ dependencies { testImplementation("org.awaitility:awaitility") testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-core") testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactor") + testImplementation("io.projectreactor:reactor-test") testImplementation("io.reactivex.rxjava3:rxjava") testImplementation('io.micrometer:context-propagation') testImplementation("io.micrometer:micrometer-observation-test") diff --git a/spring-context/src/main/java/org/springframework/validation/beanvalidation/MethodValidationAdapter.java b/spring-context/src/main/java/org/springframework/validation/beanvalidation/MethodValidationAdapter.java index 562b1bc0e3..a37f27eb2a 100644 --- a/spring-context/src/main/java/org/springframework/validation/beanvalidation/MethodValidationAdapter.java +++ b/spring-context/src/main/java/org/springframework/validation/beanvalidation/MethodValidationAdapter.java @@ -140,6 +140,13 @@ public class MethodValidationAdapter implements MethodValidator { } + /** + * Return the {@link SpringValidatorAdapter} configured for use. + */ + public Supplier getSpringValidatorAdapter() { + return this.validatorAdapter; + } + /** * Set the strategy to use to determine message codes for violations. *

Default is a DefaultMessageCodesResolver. diff --git a/spring-context/src/main/java/org/springframework/validation/beanvalidation/MethodValidationInterceptor.java b/spring-context/src/main/java/org/springframework/validation/beanvalidation/MethodValidationInterceptor.java index 06798651d8..ec37ff5aba 100644 --- a/spring-context/src/main/java/org/springframework/validation/beanvalidation/MethodValidationInterceptor.java +++ b/spring-context/src/main/java/org/springframework/validation/beanvalidation/MethodValidationInterceptor.java @@ -17,25 +17,39 @@ package org.springframework.validation.beanvalidation; import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.util.Collections; +import java.util.List; import java.util.Set; import java.util.function.Supplier; import jakarta.validation.ConstraintViolation; import jakarta.validation.ConstraintViolationException; +import jakarta.validation.Valid; import jakarta.validation.Validator; import jakarta.validation.ValidatorFactory; import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInvocation; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import org.springframework.aop.ProxyMethodInvocation; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.SmartFactoryBean; +import org.springframework.core.MethodParameter; +import org.springframework.core.ReactiveAdapter; +import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.annotation.AnnotationUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.validation.BeanPropertyBindingResult; +import org.springframework.validation.Errors; import org.springframework.validation.annotation.Validated; import org.springframework.validation.method.MethodValidationException; import org.springframework.validation.method.MethodValidationResult; +import org.springframework.validation.method.ParameterErrors; +import org.springframework.validation.method.ParameterValidationResult; /** * An AOP Alliance {@link MethodInterceptor} implementation that delegates to a @@ -65,6 +79,10 @@ import org.springframework.validation.method.MethodValidationResult; */ public class MethodValidationInterceptor implements MethodInterceptor { + private static final boolean REACTOR_PRESENT = + ClassUtils.isPresent("reactor.core.publisher.Mono", MethodValidationInterceptor.class.getClassLoader()); + + private final MethodValidationAdapter validationAdapter; private final boolean adaptViolations; @@ -135,6 +153,12 @@ public class MethodValidationInterceptor implements MethodInterceptor { Object[] arguments = invocation.getArguments(); Class[] groups = determineValidationGroups(invocation); + if (REACTOR_PRESENT) { + arguments = ReactorValidationHelper.insertAsyncValidation( + this.validationAdapter.getSpringValidatorAdapter(), this.adaptViolations, + target, method, arguments); + } + Set> violations; if (this.adaptViolations) { @@ -206,4 +230,78 @@ public class MethodValidationInterceptor implements MethodInterceptor { return this.validationAdapter.determineValidationGroups(target, invocation.getMethod()); } + + /** + * Helper class to decorate reactive arguments with async validation. + */ + private final static class ReactorValidationHelper { + + private static final ReactiveAdapterRegistry reactiveAdapterRegistry = + ReactiveAdapterRegistry.getSharedInstance(); + + + public static Object[] insertAsyncValidation( + Supplier validatorAdapterSupplier, boolean adaptViolations, + Object target, Method method, Object[] arguments) { + + for (int i = 0; i < method.getParameterCount(); i++) { + if (arguments[i] == null) { + continue; + } + Class parameterType = method.getParameterTypes()[i]; + ReactiveAdapter reactiveAdapter = reactiveAdapterRegistry.getAdapter(parameterType); + if (reactiveAdapter == null || reactiveAdapter.isNoValue()) { + continue; + } + Class[] groups = determineValidationGroups(method.getParameters()[i]); + if (groups == null) { + continue; + } + SpringValidatorAdapter validatorAdapter = validatorAdapterSupplier.get(); + MethodParameter param = new MethodParameter(method, i); + arguments[i] = (reactiveAdapter.isMultiValue() ? + Flux.from(reactiveAdapter.toPublisher(arguments[i])).doOnNext(value -> + validate(validatorAdapter, adaptViolations, target, method, param, value, groups)) : + Mono.from(reactiveAdapter.toPublisher(arguments[i])).doOnNext(value -> + validate(validatorAdapter, adaptViolations, target, method, param, value, groups))); + } + return arguments; + } + + @Nullable + private static Class[] determineValidationGroups(Parameter parameter) { + Validated validated = AnnotationUtils.findAnnotation(parameter, Validated.class); + if (validated != null) { + return validated.value(); + } + Valid valid = AnnotationUtils.findAnnotation(parameter, Valid.class); + if (valid != null) { + return new Class[0]; + } + return null; + } + + @SuppressWarnings("unchecked") + private static void validate( + SpringValidatorAdapter validatorAdapter, boolean adaptViolations, + Object target, Method method, MethodParameter parameter, Object argument, Class[] groups) { + + if (adaptViolations) { + Errors errors = new BeanPropertyBindingResult(argument, argument.getClass().getSimpleName()); + validatorAdapter.validate(argument, errors); + if (errors.hasErrors()) { + ParameterErrors paramErrors = new ParameterErrors(parameter, argument, errors, null, null, null); + List results = Collections.singletonList(paramErrors); + throw new MethodValidationException(MethodValidationResult.create(target, method, results)); + } + } + else { + Set> violations = validatorAdapter.validate((T) argument, groups); + if (!violations.isEmpty()) { + throw new ConstraintViolationException(violations); + } + } + } + } + } diff --git a/spring-context/src/test/java/org/springframework/validation/beanvalidation/MethodValidationProxyReactorTests.java b/spring-context/src/test/java/org/springframework/validation/beanvalidation/MethodValidationProxyReactorTests.java new file mode 100644 index 0000000000..92d1b2e81e --- /dev/null +++ b/spring-context/src/test/java/org/springframework/validation/beanvalidation/MethodValidationProxyReactorTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.validation.beanvalidation; + +import java.util.Set; + +import jakarta.validation.ConstraintViolation; +import jakarta.validation.ConstraintViolationException; +import jakarta.validation.Valid; +import jakarta.validation.Validation; +import jakarta.validation.Validator; +import jakarta.validation.constraints.Size; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.validation.method.MethodValidationException; +import org.springframework.validation.method.ParameterErrors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * + */ +public class MethodValidationProxyReactorTests { + + @Test + void validMonoArgument() { + MyService myService = initProxy(new MyService(), false); + Mono personMono = Mono.just(new Person("Faustino1234")); + + StepVerifier.create(myService.addPerson(personMono)) + .expectErrorSatisfies(t -> { + ConstraintViolationException ex = (ConstraintViolationException) t; + Set> violations = ex.getConstraintViolations(); + assertThat(violations).hasSize(1); + assertThat(violations.iterator().next().getMessage()).isEqualTo("size must be between 1 and 10"); + }) + .verify(); + } + + @Test + void validFluxArgument() { + MyService myService = initProxy(new MyService(), false); + Flux personFlux = Flux.just(new Person("Faust"), new Person("Faustino1234")); + + StepVerifier.create(myService.addPersons(personFlux)) + .expectErrorSatisfies(t -> { + ConstraintViolationException ex = (ConstraintViolationException) t; + Set> violations = ex.getConstraintViolations(); + assertThat(violations).hasSize(1); + assertThat(violations.iterator().next().getMessage()).isEqualTo("size must be between 1 and 10"); + }) + .verify(); + } + + @Test + void validMonoArgumentWithAdaptedViolations() { + MyService myService = initProxy(new MyService(), true); + Mono personMono = Mono.just(new Person("Faustino1234")); + + StepVerifier.create(myService.addPerson(personMono)) + .expectErrorSatisfies(t -> { + MethodValidationException ex = (MethodValidationException) t; + assertThat(ex.getAllValidationResults()).hasSize(1); + + ParameterErrors errors = ex.getBeanResults().get(0); + assertThat(errors.getErrorCount()).isEqualTo(1); + assertThat(errors.getFieldErrors().get(0).toString()).isEqualTo(""" + Field error in object 'Person' on field 'name': rejected value [Faustino1234]; \ + codes [Size.Person.name,Size.name,Size.java.lang.String,Size]; \ + arguments [org.springframework.context.support.DefaultMessageSourceResolvable: \ + codes [Person.name,name]; arguments []; default message [name],10,1]; \ + default message [size must be between 1 and 10]"""); + }) + .verify(); + } + + private static MyService initProxy(Object target, boolean adaptViolations) { + Validator validator = Validation.buildDefaultValidatorFactory().getValidator(); + MethodValidationInterceptor interceptor = new MethodValidationInterceptor(() -> validator, adaptViolations); + ProxyFactory factory = new ProxyFactory(target); + factory.addAdvice(interceptor); + return (MyService) factory.getProxy(); + } + + + @SuppressWarnings("unused") + static class MyService { + + public Mono addPerson(@Valid Mono personMono) { + return personMono.then(); + } + + public Mono addPersons(@Valid Flux personFlux) { + return personFlux.then(); + } + } + + + @SuppressWarnings("unused") + record Person(@Size(min = 1, max = 10) String name) { + } + +}