Support header filtering in web data binding

Closes gh-34039
This commit is contained in:
rstoyanchev 2024-12-11 16:06:56 +00:00
parent 70c326ed30
commit 8aeced9f80
4 changed files with 121 additions and 4 deletions

View File

@ -18,6 +18,8 @@ package org.springframework.web.reactive.result.method.annotation;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -41,12 +43,40 @@ import org.springframework.web.server.ServerWebExchange;
*/ */
public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder { public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder {
private static final Set<String> FILTERED_HEADER_NAMES = Set.of("Priority");
private Predicate<String> headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name);
public ExtendedWebExchangeDataBinder(@Nullable Object target, String objectName) { public ExtendedWebExchangeDataBinder(@Nullable Object target, String objectName) {
super(target, objectName); super(target, objectName);
} }
/**
* Add a Predicate that filters the header names to use for data binding.
* Multiple predicates are combined with {@code AND}.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void addHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = this.headerPredicate.and(headerPredicate);
}
/**
* Set the Predicate that filters the header names to use for data binding.
* <p>Note that this method resets any previous predicates that may have been
* set, including headers excluded by default such as the RFC 9218 defined
* "Priority" header.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void setHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = headerPredicate;
}
@Override @Override
public Mono<Map<String, Object>> getValuesToBind(ServerWebExchange exchange) { public Mono<Map<String, Object>> getValuesToBind(ServerWebExchange exchange) {
return super.getValuesToBind(exchange).doOnNext(map -> { return super.getValuesToBind(exchange).doOnNext(map -> {
@ -56,10 +86,13 @@ public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder {
} }
HttpHeaders headers = exchange.getRequest().getHeaders(); HttpHeaders headers = exchange.getRequest().getHeaders();
for (Map.Entry<String, List<String>> entry : headers.entrySet()) { for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
String name = entry.getKey();
if (!this.headerPredicate.test(entry.getKey())) {
continue;
}
List<String> values = entry.getValue(); List<String> values = entry.getValue();
if (!CollectionUtils.isEmpty(values)) { if (!CollectionUtils.isEmpty(values)) {
// For constructor args with @BindParam mapped to the actual header name // For constructor args with @BindParam mapped to the actual header name
String name = entry.getKey();
addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values)); addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values));
// Also adapt to Java conventions for setters // Also adapt to Java conventions for setters
name = StringUtils.uncapitalize(entry.getKey().replace("-", "")); name = StringUtils.uncapitalize(entry.getKey().replace("-", ""));

View File

@ -202,6 +202,24 @@ class InitBinderBindingContextTests {
assertThat(target.getAge()).isEqualTo(25); assertThat(target.getAge()).isEqualTo(25);
} }
@Test
void headerPredicate() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest.get("/path")
.header("Priority", "u1")
.header("Some-Int-Array", "1")
.header("Another-Int-Array", "1")
.build();
MockServerWebExchange exchange = MockServerWebExchange.from(request);
BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class);
ExtendedWebExchangeDataBinder binder = (ExtendedWebExchangeDataBinder) context.createDataBinder(exchange, null, "", null);
binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array"));
Map<String, Object> map = binder.getValuesToBind(exchange).block();
assertThat(map).containsExactlyInAnyOrderEntriesOf(Map.of("someIntArray", "1", "Some-Int-Array", "1"));
}
private BindingContext createBindingContext(String methodName, Class<?>... parameterTypes) throws Exception { private BindingContext createBindingContext(String methodName, Class<?>... parameterTypes) throws Exception {
Object handler = new InitBinderHandler(); Object handler = new InitBinderHandler();
Method method = handler.getClass().getMethod(methodName, parameterTypes); Method method = handler.getClass().getMethod(methodName, parameterTypes);

View File

@ -21,12 +21,14 @@ import java.util.Enumeration;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.Predicate;
import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.MutablePropertyValues;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.ServletRequestDataBinder; import org.springframework.web.bind.ServletRequestDataBinder;
import org.springframework.web.bind.WebDataBinder; import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.HandlerMapping;
@ -51,6 +53,12 @@ import org.springframework.web.servlet.HandlerMapping;
*/ */
public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder { public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
private static final Set<String> FILTERED_HEADER_NAMES = Set.of("Priority");
private Predicate<String> headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name);
/** /**
* Create a new instance, with default object name. * Create a new instance, with default object name.
* @param target the target object to bind onto (or {@code null} * @param target the target object to bind onto (or {@code null}
@ -73,6 +81,29 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
} }
/**
* Add a Predicate that filters the header names to use for data binding.
* Multiple predicates are combined with {@code AND}.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void addHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = this.headerPredicate.and(headerPredicate);
}
/**
* Set the Predicate that filters the header names to use for data binding.
* <p>Note that this method resets any previous predicates that may have been
* set, including headers excluded by default such as the RFC 9218 defined
* "Priority" header.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void setHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = headerPredicate;
}
@Override @Override
protected ServletRequestValueResolver createValueResolver(ServletRequest request) { protected ServletRequestValueResolver createValueResolver(ServletRequest request) {
return new ExtendedServletRequestValueResolver(request, this); return new ExtendedServletRequestValueResolver(request, this);
@ -93,7 +124,7 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
String name = names.nextElement(); String name = names.nextElement();
Object value = getHeaderValue(httpRequest, name); Object value = getHeaderValue(httpRequest, name);
if (value != null) { if (value != null) {
name = name.replace("-", ""); name = StringUtils.uncapitalize(name.replace("-", ""));
addValueIfNotPresent(mpvs, "Header", name, value); addValueIfNotPresent(mpvs, "Header", name, value);
} }
} }
@ -118,7 +149,11 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
} }
@Nullable @Nullable
private static Object getHeaderValue(HttpServletRequest request, String name) { private Object getHeaderValue(HttpServletRequest request, String name) {
if (!this.headerPredicate.test(name)) {
return null;
}
Enumeration<String> valuesEnum = request.getHeaders(name); Enumeration<String> valuesEnum = request.getHeaders(name);
if (!valuesEnum.hasMoreElements()) { if (!valuesEnum.hasMoreElements()) {
return null; return null;
@ -141,7 +176,7 @@ public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {
/** /**
* Resolver of values that looks up URI path variables. * Resolver of values that looks up URI path variables.
*/ */
private static class ExtendedServletRequestValueResolver extends ServletRequestValueResolver { private class ExtendedServletRequestValueResolver extends ServletRequestValueResolver {
ExtendedServletRequestValueResolver(ServletRequest request, WebDataBinder dataBinder) { ExtendedServletRequestValueResolver(ServletRequest request, WebDataBinder dataBinder) {
super(request, dataBinder); super(request, dataBinder);

View File

@ -18,9 +18,11 @@ package org.springframework.web.servlet.mvc.method.annotation;
import java.util.Map; import java.util.Map;
import jakarta.servlet.ServletRequest;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.web.bind.ServletRequestDataBinder; import org.springframework.web.bind.ServletRequestDataBinder;
@ -102,6 +104,22 @@ class ExtendedServletRequestDataBinderTests {
assertThat(target.getAge()).isEqualTo(25); assertThat(target.getAge()).isEqualTo(25);
} }
@Test
void headerPredicate() {
TestBinder binder = new TestBinder();
binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array"));
MutablePropertyValues mpvs = new MutablePropertyValues();
request.addHeader("Priority", "u1");
request.addHeader("Some-Int-Array", "1");
request.addHeader("Another-Int-Array", "1");
binder.addBindValues(mpvs, request);
assertThat(mpvs.size()).isEqualTo(1);
assertThat(mpvs.get("someIntArray")).isEqualTo("1");
}
@Test @Test
void noUriTemplateVars() { void noUriTemplateVars() {
TestBean target = new TestBean(); TestBean target = new TestBean();
@ -116,4 +134,17 @@ class ExtendedServletRequestDataBinderTests {
private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) { private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) {
} }
private static class TestBinder extends ExtendedServletRequestDataBinder {
public TestBinder() {
super(null);
}
@Override
public void addBindValues(MutablePropertyValues mpvs, ServletRequest request) {
super.addBindValues(mpvs, request);
}
}
} }