Merge branch '6.2.x'

This commit is contained in:
rstoyanchev 2024-12-11 16:22:47 +00:00
commit 2d5943352f
7 changed files with 351 additions and 108 deletions

View File

@ -18,19 +18,14 @@ package org.springframework.web.reactive;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.util.Collection; import java.util.Collection;
import java.util.List;
import java.util.Map; import java.util.Map;
import reactor.core.publisher.Mono;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.ui.Model; import org.springframework.ui.Model;
import org.springframework.util.CollectionUtils;
import org.springframework.validation.BindingResult; import org.springframework.validation.BindingResult;
import org.springframework.validation.DataBinder; import org.springframework.validation.DataBinder;
import org.springframework.validation.SmartValidator; import org.springframework.validation.SmartValidator;
@ -141,7 +136,7 @@ public class BindingContext {
public WebExchangeDataBinder createDataBinder( public WebExchangeDataBinder createDataBinder(
ServerWebExchange exchange, @Nullable Object target, String name, @Nullable ResolvableType targetType) { ServerWebExchange exchange, @Nullable Object target, String name, @Nullable ResolvableType targetType) {
WebExchangeDataBinder dataBinder = new ExtendedWebExchangeDataBinder(target, name); WebExchangeDataBinder dataBinder = createBinderInstance(target, name);
dataBinder.setNameResolver(new BindParamNameResolver()); dataBinder.setNameResolver(new BindParamNameResolver());
if (target == null && targetType != null) { if (target == null && targetType != null) {
@ -163,6 +158,18 @@ public class BindingContext {
return dataBinder; return dataBinder;
} }
/**
* Extension point to create the WebDataBinder instance.
* By default, this is {@code WebRequestDataBinder}.
* @param target the binding target or {@code null} for type conversion only
* @param name the binding target object name
* @return the created {@link WebExchangeDataBinder} instance
* @since 6.2.1
*/
protected WebExchangeDataBinder createBinderInstance(@Nullable Object target, String name) {
return new WebExchangeDataBinder(target, name);
}
/** /**
* Initialize the data binder instance for the given exchange. * Initialize the data binder instance for the given exchange.
* @throws ServerErrorException if {@code @InitBinder} method invocation fails * @throws ServerErrorException if {@code @InitBinder} method invocation fails
@ -200,51 +207,6 @@ public class BindingContext {
} }
/**
* Extended variant of {@link WebExchangeDataBinder}, adding path variables.
*/
private static class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder {
public ExtendedWebExchangeDataBinder(@Nullable Object target, String objectName) {
super(target, objectName);
}
@Override
public Mono<Map<String, Object>> getValuesToBind(ServerWebExchange exchange) {
return super.getValuesToBind(exchange).doOnNext(map -> {
Map<String, String> vars = exchange.getAttribute(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE);
if (!CollectionUtils.isEmpty(vars)) {
vars.forEach((key, value) -> addValueIfNotPresent(map, "URI variable", key, value));
}
HttpHeaders headers = exchange.getRequest().getHeaders();
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
List<String> values = entry.getValue();
if (!CollectionUtils.isEmpty(values)) {
String name = entry.getKey().replace("-", "");
addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values));
}
}
});
}
private static void addValueIfNotPresent(
Map<String, Object> map, String label, String name, @Nullable Object value) {
if (value != null) {
if (map.containsKey(name)) {
if (logger.isDebugEnabled()) {
logger.debug(label + " '" + name + "' overridden by request bind value.");
}
}
else {
map.put(name, value);
}
}
}
}
/** /**
* Excludes Bean Validation if the method parameter has {@code @Valid}. * Excludes Bean Validation if the method parameter has {@code @Valid}.
*/ */

View File

@ -0,0 +1,120 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.result.method.annotation;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.support.WebExchangeDataBinder;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.server.ServerWebExchange;
/**
* Extended variant of {@link WebExchangeDataBinder} that adds URI path variables
* and request headers to the bind values map.
*
* <p>Note: This class has existed since 5.0, but only as a private class within
* {@link org.springframework.web.reactive.BindingContext}.
*
* @author Rossen Stoyanchev
* @since 6.2.1
*/
public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder {
private static final Set<String> FILTERED_HEADER_NAMES = Set.of("Priority");
private Predicate<String> headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name);
public ExtendedWebExchangeDataBinder(@Nullable Object target, String objectName) {
super(target, objectName);
}
/**
* Add a Predicate that filters the header names to use for data binding.
* Multiple predicates are combined with {@code AND}.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void addHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = this.headerPredicate.and(headerPredicate);
}
/**
* Set the Predicate that filters the header names to use for data binding.
* <p>Note that this method resets any previous predicates that may have been
* set, including headers excluded by default such as the RFC 9218 defined
* "Priority" header.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void setHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = headerPredicate;
}
@Override
public Mono<Map<String, Object>> getValuesToBind(ServerWebExchange exchange) {
return super.getValuesToBind(exchange).doOnNext(map -> {
Map<String, String> vars = exchange.getAttribute(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE);
if (!CollectionUtils.isEmpty(vars)) {
vars.forEach((key, value) -> addValueIfNotPresent(map, "URI variable", key, value));
}
HttpHeaders headers = exchange.getRequest().getHeaders();
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
String name = entry.getKey();
if (!this.headerPredicate.test(entry.getKey())) {
continue;
}
List<String> values = entry.getValue();
if (!CollectionUtils.isEmpty(values)) {
// For constructor args with @BindParam mapped to the actual header name
addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values));
// Also adapt to Java conventions for setters
name = StringUtils.uncapitalize(entry.getKey().replace("-", ""));
addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values));
}
}
});
}
private static void addValueIfNotPresent(
Map<String, Object> map, String label, String name, @Nullable Object value) {
if (value != null) {
if (map.containsKey(name)) {
if (logger.isDebugEnabled()) {
logger.debug(label + " '" + name + "' overridden by request bind value.");
}
}
else {
map.put(name, value);
}
}
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2023 the original author or authors. * Copyright 2002-2024 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -71,6 +71,15 @@ class InitBinderBindingContext extends BindingContext {
} }
/**
* Returns an instance of {@link ExtendedWebExchangeDataBinder}.
* @since 6.2.1
*/
@Override
protected WebExchangeDataBinder createBinderInstance(@Nullable Object target, String name) {
return new ExtendedWebExchangeDataBinder(target, name);
}
@Override @Override
protected WebExchangeDataBinder initDataBinder(WebExchangeDataBinder dataBinder, ServerWebExchange exchange) { protected WebExchangeDataBinder initDataBinder(WebExchangeDataBinder dataBinder, ServerWebExchange exchange) {
this.binderMethods.stream() this.binderMethods.stream()

View File

@ -17,20 +17,16 @@
package org.springframework.web.reactive; package org.springframework.web.reactive;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Map;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.http.MediaType;
import org.springframework.validation.Errors; import org.springframework.validation.Errors;
import org.springframework.validation.SmartValidator; import org.springframework.validation.SmartValidator;
import org.springframework.validation.Validator; import org.springframework.validation.Validator;
import org.springframework.validation.beanvalidation.LocalValidatorFactoryBean; import org.springframework.validation.beanvalidation.LocalValidatorFactoryBean;
import org.springframework.web.bind.WebDataBinder; import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.bind.support.WebExchangeDataBinder;
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest; import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest;
import org.springframework.web.testfixture.server.MockServerWebExchange; import org.springframework.web.testfixture.server.MockServerWebExchange;
@ -68,54 +64,6 @@ class BindingContextTests {
assertThat(binder.getValidatorsToApply()).containsExactly(springValidator); assertThat(binder.getValidatorsToApply()).containsExactly(springValidator);
} }
@Test
void bindUriVariablesAndHeaders() {
MockServerHttpRequest request = MockServerHttpRequest.get("/path")
.header("Some-Int-Array", "1")
.header("Some-Int-Array", "2")
.build();
MockServerWebExchange exchange = MockServerWebExchange.from(request);
exchange.getAttributes().put(
HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Map.of("name", "John", "age", "25"));
TestBean target = new TestBean();
BindingContext bindingContext = new BindingContext(null);
WebExchangeDataBinder binder = bindingContext.createDataBinder(exchange, target, "testBean", null);
binder.bind(exchange).block();
assertThat(target.getName()).isEqualTo("John");
assertThat(target.getAge()).isEqualTo(25);
assertThat(target.getSomeIntArray()).containsExactly(1, 2);
}
@Test
void bindUriVarsAndHeadersAddedConditionally() {
MockServerHttpRequest request = MockServerHttpRequest.post("/path")
.header("name", "Johnny")
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
.body("name=John&age=25");
MockServerWebExchange exchange = MockServerWebExchange.from(request);
exchange.getAttributes().put(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, Map.of("age", "26"));
TestBean target = new TestBean();
BindingContext bindingContext = new BindingContext(null);
WebExchangeDataBinder binder = bindingContext.createDataBinder(exchange, target, "testBean", null);
binder.bind(exchange).block();
assertThat(target.getName()).isEqualTo("John");
assertThat(target.getAge()).isEqualTo(25);
}
@SuppressWarnings("unused") @SuppressWarnings("unused")
private void handleValidObject(@Valid Foo foo) { private void handleValidObject(@Valid Foo foo) {
} }

View File

@ -20,18 +20,25 @@ import java.lang.reflect.Method;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.convert.ConversionService; import org.springframework.core.convert.ConversionService;
import org.springframework.format.support.DefaultFormattingConversionService; import org.springframework.format.support.DefaultFormattingConversionService;
import org.springframework.http.MediaType;
import org.springframework.web.bind.WebDataBinder; import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.bind.annotation.BindParam;
import org.springframework.web.bind.annotation.InitBinder; import org.springframework.web.bind.annotation.InitBinder;
import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.support.ConfigurableWebBindingInitializer; import org.springframework.web.bind.support.ConfigurableWebBindingInitializer;
import org.springframework.web.bind.support.WebExchangeDataBinder;
import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.result.method.SyncHandlerMethodArgumentResolver; import org.springframework.web.reactive.result.method.SyncHandlerMethodArgumentResolver;
import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod; import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod;
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest; import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest;
@ -123,6 +130,95 @@ class InitBinderBindingContextTests {
assertThat(dataBinder.getDisallowedFields()[0]).isEqualToIgnoringCase("requestParam-22"); assertThat(dataBinder.getDisallowedFields()[0]).isEqualToIgnoringCase("requestParam-22");
} }
@Test
void bindUriVariablesAndHeadersViaSetters() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest.get("/path")
.header("Some-Int-Array", "1")
.header("Some-Int-Array", "2")
.build();
MockServerWebExchange exchange = MockServerWebExchange.from(request);
exchange.getAttributes().put(
HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Map.of("name", "John", "age", "25"));
TestBean target = new TestBean();
BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class);
WebExchangeDataBinder binder = context.createDataBinder(exchange, target, "testBean", null);
binder.bind(exchange).block();
assertThat(target.getName()).isEqualTo("John");
assertThat(target.getAge()).isEqualTo(25);
assertThat(target.getSomeIntArray()).containsExactly(1, 2);
}
@Test
void bindUriVariablesAndHeadersViaConstructor() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest.get("/path")
.header("Some-Int-Array", "1")
.header("Some-Int-Array", "2")
.build();
MockServerWebExchange exchange = MockServerWebExchange.from(request);
exchange.getAttributes().put(
HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Map.of("name", "John", "age", "25"));
BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class);
WebExchangeDataBinder binder = context.createDataBinder(exchange, null, "dataBean", null);
binder.setTargetType(ResolvableType.forClass(DataBean.class));
binder.construct(exchange).block();
DataBean bean = (DataBean) binder.getTarget();
assertThat(bean.name()).isEqualTo("John");
assertThat(bean.age()).isEqualTo(25);
assertThat(bean.someIntArray()).containsExactly(1, 2);
}
@Test
void bindUriVarsAndHeadersAddedConditionally() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest.post("/path")
.header("name", "Johnny")
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
.body("name=John&age=25");
MockServerWebExchange exchange = MockServerWebExchange.from(request);
exchange.getAttributes().put(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, Map.of("age", "26"));
TestBean target = new TestBean();
BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class);
WebExchangeDataBinder binder = context.createDataBinder(exchange, target, "testBean", null);
binder.bind(exchange).block();
assertThat(target.getName()).isEqualTo("John");
assertThat(target.getAge()).isEqualTo(25);
}
@Test
void headerPredicate() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest.get("/path")
.header("Priority", "u1")
.header("Some-Int-Array", "1")
.header("Another-Int-Array", "1")
.build();
MockServerWebExchange exchange = MockServerWebExchange.from(request);
BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class);
ExtendedWebExchangeDataBinder binder = (ExtendedWebExchangeDataBinder) context.createDataBinder(exchange, null, "", null);
binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array"));
Map<String, Object> map = binder.getValuesToBind(exchange).block();
assertThat(map).containsExactlyInAnyOrderEntriesOf(Map.of("someIntArray", "1", "Some-Int-Array", "1"));
}
private BindingContext createBindingContext(String methodName, Class<?>... parameterTypes) throws Exception { private BindingContext createBindingContext(String methodName, Class<?>... parameterTypes) throws Exception {
Object handler = new InitBinderHandler(); Object handler = new InitBinderHandler();
@ -161,4 +257,8 @@ class InitBinderBindingContextTests {
} }
} }
private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) {
}
} }

View File

@ -21,12 +21,14 @@ import java.util.Enumeration;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.Predicate;
import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.MutablePropertyValues;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.ServletRequestDataBinder; import org.springframework.web.bind.ServletRequestDataBinder;
import org.springframework.web.bind.WebDataBinder; import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.HandlerMapping;
@ -51,6 +53,12 @@ import org.springframework.web.servlet.HandlerMapping;
*/ */
public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder { public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
private static final Set<String> FILTERED_HEADER_NAMES = Set.of("Priority");
private Predicate<String> headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name);
/** /**
* Create a new instance, with default object name. * Create a new instance, with default object name.
* @param target the target object to bind onto (or {@code null} * @param target the target object to bind onto (or {@code null}
@ -73,6 +81,29 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
} }
/**
* Add a Predicate that filters the header names to use for data binding.
* Multiple predicates are combined with {@code AND}.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void addHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = this.headerPredicate.and(headerPredicate);
}
/**
* Set the Predicate that filters the header names to use for data binding.
* <p>Note that this method resets any previous predicates that may have been
* set, including headers excluded by default such as the RFC 9218 defined
* "Priority" header.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void setHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = headerPredicate;
}
@Override @Override
protected ServletRequestValueResolver createValueResolver(ServletRequest request) { protected ServletRequestValueResolver createValueResolver(ServletRequest request) {
return new ExtendedServletRequestValueResolver(request, this); return new ExtendedServletRequestValueResolver(request, this);
@ -93,7 +124,7 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
String name = names.nextElement(); String name = names.nextElement();
Object value = getHeaderValue(httpRequest, name); Object value = getHeaderValue(httpRequest, name);
if (value != null) { if (value != null) {
name = name.replace("-", ""); name = StringUtils.uncapitalize(name.replace("-", ""));
addValueIfNotPresent(mpvs, "Header", name, value); addValueIfNotPresent(mpvs, "Header", name, value);
} }
} }
@ -118,7 +149,11 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
} }
@Nullable @Nullable
private static Object getHeaderValue(HttpServletRequest request, String name) { private Object getHeaderValue(HttpServletRequest request, String name) {
if (!this.headerPredicate.test(name)) {
return null;
}
Enumeration<String> valuesEnum = request.getHeaders(name); Enumeration<String> valuesEnum = request.getHeaders(name);
if (!valuesEnum.hasMoreElements()) { if (!valuesEnum.hasMoreElements()) {
return null; return null;
@ -141,7 +176,7 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
/** /**
* Resolver of values that looks up URI path variables. * Resolver of values that looks up URI path variables.
*/ */
private static class ExtendedServletRequestValueResolver extends ServletRequestValueResolver { private class ExtendedServletRequestValueResolver extends ServletRequestValueResolver {
ExtendedServletRequestValueResolver(ServletRequest request, WebDataBinder dataBinder) { ExtendedServletRequestValueResolver(ServletRequest request, WebDataBinder dataBinder) {
super(request, dataBinder); super(request, dataBinder);
@ -156,6 +191,9 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
if (uriVars != null) { if (uriVars != null) {
value = uriVars.get(name); value = uriVars.get(name);
} }
if (value == null && getRequest() instanceof HttpServletRequest httpServletRequest) {
value = getHeaderValue(httpServletRequest, name);
}
} }
return value; return value;
} }
@ -167,6 +205,13 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
if (uriVars != null) { if (uriVars != null) {
set.addAll(uriVars.keySet()); set.addAll(uriVars.keySet());
} }
if (request instanceof HttpServletRequest httpServletRequest) {
Enumeration<String> enumeration = httpServletRequest.getHeaderNames();
while (enumeration.hasMoreElements()) {
String headerName = enumeration.nextElement();
set.add(headerName.replaceAll("-", ""));
}
}
return set; return set;
} }
} }

View File

@ -18,11 +18,16 @@ package org.springframework.web.servlet.mvc.method.annotation;
import java.util.Map; import java.util.Map;
import jakarta.servlet.ServletRequest;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.core.ResolvableType;
import org.springframework.web.bind.ServletRequestDataBinder; import org.springframework.web.bind.ServletRequestDataBinder;
import org.springframework.web.bind.annotation.BindParam;
import org.springframework.web.bind.support.BindParamNameResolver;
import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
@ -45,7 +50,7 @@ class ExtendedServletRequestDataBinderTests {
@Test @Test
void createBinder() { void createBinderViaSetters() {
request.setAttribute( request.setAttribute(
HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Map.of("name", "John", "age", "25")); Map.of("name", "John", "age", "25"));
@ -62,6 +67,27 @@ class ExtendedServletRequestDataBinderTests {
assertThat(target.getSomeIntArray()).containsExactly(1, 2); assertThat(target.getSomeIntArray()).containsExactly(1, 2);
} }
@Test
void createBinderViaConstructor() {
request.setAttribute(
HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Map.of("name", "John", "age", "25"));
request.addHeader("Some-Int-Array", "1");
request.addHeader("Some-Int-Array", "2");
ServletRequestDataBinder binder = new ExtendedServletRequestDataBinder(null);
binder.setTargetType(ResolvableType.forClass(DataBean.class));
binder.setNameResolver(new BindParamNameResolver());
binder.construct(request);
DataBean bean = (DataBean) binder.getTarget();
assertThat(bean.name()).isEqualTo("John");
assertThat(bean.age()).isEqualTo(25);
assertThat(bean.someIntArray()).containsExactly(1, 2);
}
@Test @Test
void uriVarsAndHeadersAddedConditionally() { void uriVarsAndHeadersAddedConditionally() {
request.addParameter("name", "John"); request.addParameter("name", "John");
@ -78,6 +104,22 @@ class ExtendedServletRequestDataBinderTests {
assertThat(target.getAge()).isEqualTo(25); assertThat(target.getAge()).isEqualTo(25);
} }
@Test
void headerPredicate() {
TestBinder binder = new TestBinder();
binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array"));
MutablePropertyValues mpvs = new MutablePropertyValues();
request.addHeader("Priority", "u1");
request.addHeader("Some-Int-Array", "1");
request.addHeader("Another-Int-Array", "1");
binder.addBindValues(mpvs, request);
assertThat(mpvs.size()).isEqualTo(1);
assertThat(mpvs.get("someIntArray")).isEqualTo("1");
}
@Test @Test
void noUriTemplateVars() { void noUriTemplateVars() {
TestBean target = new TestBean(); TestBean target = new TestBean();
@ -88,4 +130,21 @@ class ExtendedServletRequestDataBinderTests {
assertThat(target.getAge()).isEqualTo(0); assertThat(target.getAge()).isEqualTo(0);
} }
private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) {
}
private static class TestBinder extends ExtendedServletRequestDataBinder {
public TestBinder() {
super(null);
}
@Override
public void addBindValues(MutablePropertyValues mpvs, ServletRequest request) {
super.addBindValues(mpvs, request);
}
}
} }