Support header filtering in web data binding

Closes gh-34039
This commit is contained in:
rstoyanchev 2024-12-11 16:06:56 +00:00
parent 70c326ed30
commit 8aeced9f80
4 changed files with 121 additions and 4 deletions

View File

@ -18,6 +18,8 @@ package org.springframework.web.reactive.result.method.annotation;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import reactor.core.publisher.Mono;
@ -41,12 +43,40 @@ import org.springframework.web.server.ServerWebExchange;
*/
public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder {
private static final Set<String> FILTERED_HEADER_NAMES = Set.of("Priority");
private Predicate<String> headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name);
public ExtendedWebExchangeDataBinder(@Nullable Object target, String objectName) {
super(target, objectName);
}
/**
* Add a Predicate that filters the header names to use for data binding.
* Multiple predicates are combined with {@code AND}.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void addHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = this.headerPredicate.and(headerPredicate);
}
/**
* Set the Predicate that filters the header names to use for data binding.
* <p>Note that this method resets any previous predicates that may have been
* set, including headers excluded by default such as the RFC 9218 defined
* "Priority" header.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void setHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = headerPredicate;
}
@Override
public Mono<Map<String, Object>> getValuesToBind(ServerWebExchange exchange) {
return super.getValuesToBind(exchange).doOnNext(map -> {
@ -56,10 +86,13 @@ public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder {
}
HttpHeaders headers = exchange.getRequest().getHeaders();
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
String name = entry.getKey();
if (!this.headerPredicate.test(entry.getKey())) {
continue;
}
List<String> values = entry.getValue();
if (!CollectionUtils.isEmpty(values)) {
// For constructor args with @BindParam mapped to the actual header name
String name = entry.getKey();
addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values));
// Also adapt to Java conventions for setters
name = StringUtils.uncapitalize(entry.getKey().replace("-", ""));

View File

@ -202,6 +202,24 @@ class InitBinderBindingContextTests {
assertThat(target.getAge()).isEqualTo(25);
}
@Test
void headerPredicate() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest.get("/path")
.header("Priority", "u1")
.header("Some-Int-Array", "1")
.header("Another-Int-Array", "1")
.build();
MockServerWebExchange exchange = MockServerWebExchange.from(request);
BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class);
ExtendedWebExchangeDataBinder binder = (ExtendedWebExchangeDataBinder) context.createDataBinder(exchange, null, "", null);
binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array"));
Map<String, Object> map = binder.getValuesToBind(exchange).block();
assertThat(map).containsExactlyInAnyOrderEntriesOf(Map.of("someIntArray", "1", "Some-Int-Array", "1"));
}
private BindingContext createBindingContext(String methodName, Class<?>... parameterTypes) throws Exception {
Object handler = new InitBinderHandler();
Method method = handler.getClass().getMethod(methodName, parameterTypes);

View File

@ -21,12 +21,14 @@ import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.ServletRequestDataBinder;
import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.servlet.HandlerMapping;
@ -51,6 +53,12 @@ import org.springframework.web.servlet.HandlerMapping;
*/
public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
private static final Set<String> FILTERED_HEADER_NAMES = Set.of("Priority");
private Predicate<String> headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name);
/**
* Create a new instance, with default object name.
* @param target the target object to bind onto (or {@code null}
@ -73,6 +81,29 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
}
/**
* Add a Predicate that filters the header names to use for data binding.
* Multiple predicates are combined with {@code AND}.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void addHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = this.headerPredicate.and(headerPredicate);
}
/**
* Set the Predicate that filters the header names to use for data binding.
* <p>Note that this method resets any previous predicates that may have been
* set, including headers excluded by default such as the RFC 9218 defined
* "Priority" header.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void setHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = headerPredicate;
}
@Override
protected ServletRequestValueResolver createValueResolver(ServletRequest request) {
return new ExtendedServletRequestValueResolver(request, this);
@ -93,7 +124,7 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
String name = names.nextElement();
Object value = getHeaderValue(httpRequest, name);
if (value != null) {
name = name.replace("-", "");
name = StringUtils.uncapitalize(name.replace("-", ""));
addValueIfNotPresent(mpvs, "Header", name, value);
}
}
@ -118,7 +149,11 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
}
@Nullable
private static Object getHeaderValue(HttpServletRequest request, String name) {
private Object getHeaderValue(HttpServletRequest request, String name) {
if (!this.headerPredicate.test(name)) {
return null;
}
Enumeration<String> valuesEnum = request.getHeaders(name);
if (!valuesEnum.hasMoreElements()) {
return null;
@ -141,7 +176,7 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
/**
* Resolver of values that looks up URI path variables.
*/
private static class ExtendedServletRequestValueResolver extends ServletRequestValueResolver {
private class ExtendedServletRequestValueResolver extends ServletRequestValueResolver {
ExtendedServletRequestValueResolver(ServletRequest request, WebDataBinder dataBinder) {
super(request, dataBinder);

View File

@ -18,9 +18,11 @@ package org.springframework.web.servlet.mvc.method.annotation;
import java.util.Map;
import jakarta.servlet.ServletRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.core.ResolvableType;
import org.springframework.web.bind.ServletRequestDataBinder;
@ -102,6 +104,22 @@ class ExtendedServletRequestDataBinderTests {
assertThat(target.getAge()).isEqualTo(25);
}
@Test
void headerPredicate() {
TestBinder binder = new TestBinder();
binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array"));
MutablePropertyValues mpvs = new MutablePropertyValues();
request.addHeader("Priority", "u1");
request.addHeader("Some-Int-Array", "1");
request.addHeader("Another-Int-Array", "1");
binder.addBindValues(mpvs, request);
assertThat(mpvs.size()).isEqualTo(1);
assertThat(mpvs.get("someIntArray")).isEqualTo("1");
}
@Test
void noUriTemplateVars() {
TestBean target = new TestBean();
@ -116,4 +134,17 @@ class ExtendedServletRequestDataBinderTests {
private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) {
}
private static class TestBinder extends ExtendedServletRequestDataBinder {
public TestBinder() {
super(null);
}
@Override
public void addBindValues(MutablePropertyValues mpvs, ServletRequest request) {
super.addBindValues(mpvs, request);
}
}
}