Validation support for @RequestBody with @Validated

This commit is contained in:
Rossen Stoyanchev 2016-06-10 15:52:29 -04:00
parent 0a2c3c3744
commit 2f8baac4e0
2 changed files with 112 additions and 14 deletions

View File

@ -16,6 +16,7 @@
package org.springframework.web.reactive.result.method.annotation; package org.springframework.web.reactive.result.method.annotation;
import java.lang.annotation.Annotation;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -23,16 +24,25 @@ import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.core.Conventions;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.convert.ConversionService; import org.springframework.core.convert.ConversionService;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.converter.reactive.HttpMessageConverter; import org.springframework.http.converter.reactive.HttpMessageConverter;
import org.springframework.ui.ModelMap; import org.springframework.ui.ModelMap;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.validation.BeanPropertyBindingResult;
import org.springframework.validation.Errors;
import org.springframework.validation.SmartValidator;
import org.springframework.validation.Validator;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.reactive.result.method.HandlerMethodArgumentResolver; import org.springframework.web.reactive.result.method.HandlerMethodArgumentResolver;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebInputException;
import org.springframework.web.server.UnsupportedMediaTypeStatusException; import org.springframework.web.server.UnsupportedMediaTypeStatusException;
/** /**
@ -50,6 +60,8 @@ public class RequestBodyArgumentResolver implements HandlerMethodArgumentResolve
private final ConversionService conversionService; private final ConversionService conversionService;
private final Validator validator;
private final List<MediaType> supportedMediaTypes; private final List<MediaType> supportedMediaTypes;
@ -61,10 +73,23 @@ public class RequestBodyArgumentResolver implements HandlerMethodArgumentResolve
public RequestBodyArgumentResolver(List<HttpMessageConverter<?>> converters, public RequestBodyArgumentResolver(List<HttpMessageConverter<?>> converters,
ConversionService service) { ConversionService service) {
this(converters, service, null);
}
/**
* Constructor with message converters and a ConversionService.
* @param converters converters for reading the request body with
* @param service for converting to other reactive types from Flux and Mono
* @param validator validator to validate decoded objects with
*/
public RequestBodyArgumentResolver(List<HttpMessageConverter<?>> converters,
ConversionService service, Validator validator) {
Assert.notEmpty(converters, "At least one message converter is required."); Assert.notEmpty(converters, "At least one message converter is required.");
Assert.notNull(service, "'conversionService' is required."); Assert.notNull(service, "'conversionService' is required.");
this.messageConverters = converters; this.messageConverters = converters;
this.conversionService = service; this.conversionService = service;
this.validator = validator;
this.supportedMediaTypes = converters.stream() this.supportedMediaTypes = converters.stream()
.flatMap(converter -> converter.getReadableMediaTypes().stream()) .flatMap(converter -> converter.getReadableMediaTypes().stream())
.collect(Collectors.toList()); .collect(Collectors.toList());
@ -107,6 +132,11 @@ public class RequestBodyArgumentResolver implements HandlerMethodArgumentResolve
for (HttpMessageConverter<?> converter : getMessageConverters()) { for (HttpMessageConverter<?> converter : getMessageConverters()) {
if (converter.canRead(elementType, mediaType)) { if (converter.canRead(elementType, mediaType)) {
Flux<?> elementFlux = converter.read(elementType, exchange.getRequest()); Flux<?> elementFlux = converter.read(elementType, exchange.getRequest());
if (this.validator != null) {
elementFlux= applyValidationIfApplicable(elementFlux, parameter);
}
if (Mono.class.equals(type.getRawClass())) { if (Mono.class.equals(type.getRawClass())) {
return Mono.just(Mono.from(elementFlux)); return Mono.just(Mono.from(elementFlux));
} }
@ -130,4 +160,37 @@ public class RequestBodyArgumentResolver implements HandlerMethodArgumentResolve
getConversionService().canConvert(Publisher.class, type.getRawClass())); getConversionService().canConvert(Publisher.class, type.getRawClass()));
} }
protected Flux<?> applyValidationIfApplicable(Flux<?> elementFlux, MethodParameter methodParam) {
Annotation[] annotations = methodParam.getParameterAnnotations();
for (Annotation ann : annotations) {
Validated validAnnot = AnnotationUtils.getAnnotation(ann, Validated.class);
if (validAnnot != null || ann.annotationType().getSimpleName().startsWith("Valid")) {
Object hints = (validAnnot != null ? validAnnot.value() : AnnotationUtils.getValue(ann));
Object[] validationHints = (hints instanceof Object[] ? (Object[]) hints : new Object[] {hints});
return elementFlux.map(element -> {
validate(element, validationHints, methodParam);
return element;
});
}
}
return elementFlux;
}
/**
* TODO: replace with use of DataBinder
*/
private void validate(Object target, Object[] validationHints, MethodParameter methodParam) {
String name = Conventions.getVariableNameForParameter(methodParam);
Errors errors = new BeanPropertyBindingResult(target, name);
if (!ObjectUtils.isEmpty(validationHints) && this.validator instanceof SmartValidator) {
((SmartValidator) this.validator).validate(target, errors, validationHints);
}
else if (this.validator != null) {
this.validator.validate(target, errors);
}
if (errors.hasErrors()) {
throw new ServerWebInputException("Validation failed", methodParam);
}
}
} }

View File

@ -27,7 +27,6 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import javax.xml.bind.annotation.XmlRootElement; import javax.xml.bind.annotation.XmlRootElement;
import org.junit.Before; import org.junit.Before;
@ -61,8 +60,12 @@ import org.springframework.http.server.reactive.MockServerHttpResponse;
import org.springframework.ui.ExtendedModelMap; import org.springframework.ui.ExtendedModelMap;
import org.springframework.ui.ModelMap; import org.springframework.ui.ModelMap;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import org.springframework.validation.Errors;
import org.springframework.validation.Validator;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebInputException;
import org.springframework.web.server.UnsupportedMediaTypeStatusException; import org.springframework.web.server.UnsupportedMediaTypeStatusException;
import org.springframework.web.server.adapter.DefaultServerWebExchange; import org.springframework.web.server.adapter.DefaultServerWebExchange;
import org.springframework.web.server.session.DefaultWebSessionManager; import org.springframework.web.server.session.DefaultWebSessionManager;
@ -120,28 +123,28 @@ public class RequestBodyArgumentResolverTests {
@Test @SuppressWarnings("unchecked") @Test @SuppressWarnings("unchecked")
public void monoTestBean() throws Exception { public void monoTestBean() throws Exception {
String body = "{\"bar\":\"b1\",\"foo\":\"f1\"}"; String body = "{\"bar\":\"b1\",\"foo\":\"f1\"}";
Mono<TestBean> mono = (Mono<TestBean>) resolve("monoTestBean", Mono.class, body); Mono<TestBean> mono = (Mono<TestBean>) resolveValue("monoTestBean", Mono.class, body);
assertEquals(new TestBean("f1", "b1"), mono.block()); assertEquals(new TestBean("f1", "b1"), mono.block());
} }
@Test @SuppressWarnings("unchecked") @Test @SuppressWarnings("unchecked")
public void fluxTestBean() throws Exception { public void fluxTestBean() throws Exception {
String body = "[{\"bar\":\"b1\",\"foo\":\"f1\"},{\"bar\":\"b2\",\"foo\":\"f2\"}]"; String body = "[{\"bar\":\"b1\",\"foo\":\"f1\"},{\"bar\":\"b2\",\"foo\":\"f2\"}]";
Flux<TestBean> flux = (Flux<TestBean>) resolve("fluxTestBean", Flux.class, body); Flux<TestBean> flux = (Flux<TestBean>) resolveValue("fluxTestBean", Flux.class, body);
assertEquals(Arrays.asList(new TestBean("f1", "b1"), new TestBean("f2", "b2")), flux.collectList().block()); assertEquals(Arrays.asList(new TestBean("f1", "b1"), new TestBean("f2", "b2")), flux.collectList().block());
} }
@Test @SuppressWarnings("unchecked") @Test @SuppressWarnings("unchecked")
public void singleTestBean() throws Exception { public void singleTestBean() throws Exception {
String body = "{\"bar\":\"b1\",\"foo\":\"f1\"}"; String body = "{\"bar\":\"b1\",\"foo\":\"f1\"}";
Single<TestBean> single = (Single<TestBean>) resolve("singleTestBean", Single.class, body); Single<TestBean> single = (Single<TestBean>) resolveValue("singleTestBean", Single.class, body);
assertEquals(new TestBean("f1", "b1"), single.toBlocking().value()); assertEquals(new TestBean("f1", "b1"), single.toBlocking().value());
} }
@Test @SuppressWarnings("unchecked") @Test @SuppressWarnings("unchecked")
public void observableTestBean() throws Exception { public void observableTestBean() throws Exception {
String body = "[{\"bar\":\"b1\",\"foo\":\"f1\"},{\"bar\":\"b2\",\"foo\":\"f2\"}]"; String body = "[{\"bar\":\"b1\",\"foo\":\"f1\"},{\"bar\":\"b2\",\"foo\":\"f2\"}]";
Observable<?> observable = (Observable<?>) resolve("observableTestBean", Observable.class, body); Observable<?> observable = (Observable<?>) resolveValue("observableTestBean", Observable.class, body);
assertEquals(Arrays.asList(new TestBean("f1", "b1"), new TestBean("f2", "b2")), assertEquals(Arrays.asList(new TestBean("f1", "b1"), new TestBean("f2", "b2")),
observable.toList().toBlocking().first()); observable.toList().toBlocking().first());
} }
@ -149,13 +152,13 @@ public class RequestBodyArgumentResolverTests {
@Test @SuppressWarnings("unchecked") @Test @SuppressWarnings("unchecked")
public void futureTestBean() throws Exception { public void futureTestBean() throws Exception {
String body = "{\"bar\":\"b1\",\"foo\":\"f1\"}"; String body = "{\"bar\":\"b1\",\"foo\":\"f1\"}";
assertEquals(new TestBean("f1", "b1"), resolve("futureTestBean", CompletableFuture.class, body).get()); assertEquals(new TestBean("f1", "b1"), resolveValue("futureTestBean", CompletableFuture.class, body).get());
} }
@Test @Test
public void testBean() throws Exception { public void testBean() throws Exception {
String body = "{\"bar\":\"b1\",\"foo\":\"f1\"}"; String body = "{\"bar\":\"b1\",\"foo\":\"f1\"}";
assertEquals(new TestBean("f1", "b1"), resolve("testBean", TestBean.class, body)); assertEquals(new TestBean("f1", "b1"), resolveValue("testBean", TestBean.class, body));
} }
@Test @Test
@ -164,7 +167,7 @@ public class RequestBodyArgumentResolverTests {
Map<String, String> map = new HashMap<>(); Map<String, String> map = new HashMap<>();
map.put("foo", "f1"); map.put("foo", "f1");
map.put("bar", "b1"); map.put("bar", "b1");
assertEquals(map, resolve("map", Map.class, body)); assertEquals(map, resolveValue("map", Map.class, body));
} }
// TODO: @Ignore // TODO: @Ignore
@ -174,7 +177,7 @@ public class RequestBodyArgumentResolverTests {
public void list() throws Exception { public void list() throws Exception {
String body = "[{\"bar\":\"b1\",\"foo\":\"f1\"},{\"bar\":\"b2\",\"foo\":\"f2\"}]"; String body = "[{\"bar\":\"b1\",\"foo\":\"f1\"},{\"bar\":\"b2\",\"foo\":\"f2\"}]";
assertEquals(Arrays.asList(new TestBean("f1", "b1"), new TestBean("f2", "b2")), assertEquals(Arrays.asList(new TestBean("f1", "b1"), new TestBean("f2", "b2")),
resolve("list", List.class, body)); resolveValue("list", List.class, body));
} }
@Test @Test
@ -182,12 +185,28 @@ public class RequestBodyArgumentResolverTests {
public void array() throws Exception { public void array() throws Exception {
String body = "[{\"bar\":\"b1\",\"foo\":\"f1\"},{\"bar\":\"b2\",\"foo\":\"f2\"}]"; String body = "[{\"bar\":\"b1\",\"foo\":\"f1\"},{\"bar\":\"b2\",\"foo\":\"f2\"}]";
assertArrayEquals(new TestBean[] {new TestBean("f1", "b1"), new TestBean("f2", "b2")}, assertArrayEquals(new TestBean[] {new TestBean("f1", "b1"), new TestBean("f2", "b2")},
resolve("array", TestBean[].class, body)); resolveValue("array", TestBean[].class, body));
}
@Test @SuppressWarnings("unchecked")
public void validateMonoTestBean() throws Exception {
String body = "{\"bar\":\"b1\"}";
Mono<TestBean> mono = (Mono<TestBean>) resolveValue("monoTestBean", Mono.class, body);
TestSubscriber.subscribe(mono).assertNoValues().assertError(ServerWebInputException.class);
}
@Test @SuppressWarnings("unchecked")
public void validateFluxTestBean() throws Exception {
String body = "[{\"bar\":\"b1\",\"foo\":\"f1\"},{\"bar\":\"b2\"}]";
Flux<TestBean> flux = (Flux<TestBean>) resolveValue("fluxTestBean", Flux.class, body);
TestSubscriber.subscribe(flux).assertValues(new TestBean("f1", "b1"))
.assertError(ServerWebInputException.class);
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private <T> T resolve(String paramName, Class<T> valueType, String body) { private <T> T resolveValue(String paramName, Class<T> valueType, String body) {
this.request.getHeaders().setContentType(MediaType.APPLICATION_JSON); this.request.getHeaders().setContentType(MediaType.APPLICATION_JSON);
this.request.writeWith(Flux.just(dataBuffer(body))); this.request.writeWith(Flux.just(dataBuffer(body)));
Mono<Object> result = this.resolver.resolveArgument(parameter(paramName), this.model, this.exchange); Mono<Object> result = this.resolver.resolveArgument(parameter(paramName), this.model, this.exchange);
@ -204,7 +223,7 @@ public class RequestBodyArgumentResolverTests {
GenericConversionService service = new GenericConversionService(); GenericConversionService service = new GenericConversionService();
service.addConverter(new ReactiveStreamsToCompletableFutureConverter()); service.addConverter(new ReactiveStreamsToCompletableFutureConverter());
service.addConverter(new ReactiveStreamsToRxJava1Converter()); service.addConverter(new ReactiveStreamsToRxJava1Converter());
return new RequestBodyArgumentResolver(converters, service); return new RequestBodyArgumentResolver(converters, service, new TestBeanValidator());
} }
@SuppressWarnings("ConfusingArgumentToVarargsMethod") @SuppressWarnings("ConfusingArgumentToVarargsMethod")
@ -230,8 +249,8 @@ public class RequestBodyArgumentResolverTests {
@SuppressWarnings("unused") @SuppressWarnings("unused")
void handle( void handle(
@RequestBody Mono<TestBean> monoTestBean, @Validated @RequestBody Mono<TestBean> monoTestBean,
@RequestBody Flux<TestBean> fluxTestBean, @Validated @RequestBody Flux<TestBean> fluxTestBean,
@RequestBody Single<TestBean> singleTestBean, @RequestBody Single<TestBean> singleTestBean,
@RequestBody Observable<TestBean> observableTestBean, @RequestBody Observable<TestBean> observableTestBean,
@RequestBody CompletableFuture<TestBean> futureTestBean, @RequestBody CompletableFuture<TestBean> futureTestBean,
@ -297,4 +316,20 @@ public class RequestBodyArgumentResolverTests {
return "TestBean[foo='" + this.foo + "\'" + ", bar='" + this.bar + "\']"; return "TestBean[foo='" + this.foo + "\'" + ", bar='" + this.bar + "\']";
} }
} }
static class TestBeanValidator implements Validator {
@Override
public boolean supports(Class<?> clazz) {
return clazz.equals(TestBean.class);
}
@Override
public void validate(Object target, Errors errors) {
TestBean testBean = (TestBean) target;
if (testBean.getFoo() == null) {
errors.rejectValue("foo", "nullValue");
}
}
}
} }