Merge pull request #30028 from yuzawa-san:request-predicate-commit

* gh-30028:
  Polishing external contribution
  Improve attribute handling in RequestPredicates
This commit is contained in:
Arjen Poutsma 2023-09-12 12:36:23 +02:00
commit 9df735b3ab
3 changed files with 300 additions and 109 deletions

View File

@ -21,7 +21,6 @@ import java.net.URI;
import java.security.Principal; import java.security.Principal;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
@ -296,11 +295,6 @@ public abstract class RequestPredicates {
} }
} }
private static void restoreAttributes(ServerRequest request, Map<String, Object> attributes) {
request.attributes().clear();
request.attributes().putAll(attributes);
}
private static Map<String, String> mergePathVariables(Map<String, String> oldVariables, private static Map<String, String> mergePathVariables(Map<String, String> oldVariables,
Map<String, String> newVariables) { Map<String, String> newVariables) {
@ -432,13 +426,94 @@ public abstract class RequestPredicates {
} }
/**
* Extension of {@code RequestPredicate} that can modify the {@code ServerRequest}.
*/
static abstract class RequestModifyingPredicate implements RequestPredicate {
public static RequestModifyingPredicate of(RequestPredicate requestPredicate) {
if (requestPredicate instanceof RequestModifyingPredicate modifyingPredicate) {
return modifyingPredicate;
}
else {
return new RequestModifyingPredicate() {
@Override
protected Result testInternal(ServerRequest request) {
return Result.of(requestPredicate.test(request));
}
};
}
}
@Override
public final boolean test(ServerRequest request) {
Result result = testInternal(request);
boolean value = result.value();
if (value) {
result.modify(request);
}
return value;
}
protected abstract Result testInternal(ServerRequest request);
protected static final class Result {
private static final Result TRUE = new Result(true, null);
private static final Result FALSE = new Result(false, null);
private final boolean value;
@Nullable
private final Consumer<ServerRequest> modify;
private Result(boolean value, @Nullable Consumer<ServerRequest> modify) {
this.value = value;
this.modify = modify;
}
public static Result of(boolean value) {
return of(value, null);
}
public static Result of(boolean value, @Nullable Consumer<ServerRequest> commit) {
if (commit == null) {
return value ? TRUE : FALSE;
}
else {
return new Result(value, commit);
}
}
public boolean value() {
return this.value;
}
public void modify(ServerRequest request) {
if (this.modify != null) {
this.modify.accept(request);
}
}
}
}
private static class HttpMethodPredicate implements RequestPredicate { private static class HttpMethodPredicate implements RequestPredicate {
private final Set<HttpMethod> httpMethods; private final Set<HttpMethod> httpMethods;
public HttpMethodPredicate(HttpMethod httpMethod) { public HttpMethodPredicate(HttpMethod httpMethod) {
Assert.notNull(httpMethod, "HttpMethod must not be null"); Assert.notNull(httpMethod, "HttpMethod must not be null");
this.httpMethods = Collections.singleton(httpMethod); this.httpMethods = Set.of(httpMethod);
} }
public HttpMethodPredicate(HttpMethod... httpMethods) { public HttpMethodPredicate(HttpMethod... httpMethods) {
@ -482,39 +557,41 @@ public abstract class RequestPredicates {
} }
private static class PathPatternPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { private static class PathPatternPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private PathPattern pattern; private PathPattern pattern;
public PathPatternPredicate(PathPattern pattern) { public PathPatternPredicate(PathPattern pattern) {
Assert.notNull(pattern, "'pattern' must not be null"); Assert.notNull(pattern, "'pattern' must not be null");
this.pattern = pattern; this.pattern = pattern;
} }
@Override @Override
public boolean test(ServerRequest request) { protected Result testInternal(ServerRequest request) {
PathContainer pathContainer = request.requestPath().pathWithinApplication(); PathContainer pathContainer = request.requestPath().pathWithinApplication();
PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(pathContainer); PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(pathContainer);
traceMatch("Pattern", this.pattern.getPatternString(), request.path(), info != null); traceMatch("Pattern", this.pattern.getPatternString(), request.path(), info != null);
if (info != null) { if (info != null) {
mergeAttributes(request, info.getUriVariables(), this.pattern); return Result.of(true, serverRequest -> mergeAttributes(serverRequest, info.getUriVariables()));
return true;
} }
else { else {
return false; return Result.of(false);
} }
} }
private static void mergeAttributes(ServerRequest request, Map<String, String> variables, private void mergeAttributes(ServerRequest request, Map<String, String> variables) {
PathPattern pattern) { Map<String, Object> attributes = request.attributes();
Map<String, String> pathVariables = mergePathVariables(request.pathVariables(), variables); Map<String, String> pathVariables = mergePathVariables(request.pathVariables(), variables);
request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, attributes.put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Collections.unmodifiableMap(pathVariables)); Collections.unmodifiableMap(pathVariables));
pattern = mergePatterns( PathPattern pattern = mergePatterns(
(PathPattern) request.attributes().get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE), (PathPattern) attributes.get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE),
pattern); this.pattern);
request.attributes().put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern); attributes.put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern);
} }
@Override @Override
@ -756,28 +833,42 @@ public abstract class RequestPredicates {
* {@link RequestPredicate} for where both {@code left} and {@code right} predicates * {@link RequestPredicate} for where both {@code left} and {@code right} predicates
* must match. * must match.
*/ */
static class AndRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { static class AndRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate left; private final RequestPredicate left;
private final RequestModifyingPredicate leftModifying;
private final RequestPredicate right; private final RequestPredicate right;
private final RequestModifyingPredicate rightModifying;
public AndRequestPredicate(RequestPredicate left, RequestPredicate right) { public AndRequestPredicate(RequestPredicate left, RequestPredicate right) {
Assert.notNull(left, "Left RequestPredicate must not be null"); Assert.notNull(left, "Left RequestPredicate must not be null");
Assert.notNull(right, "Right RequestPredicate must not be null"); Assert.notNull(right, "Right RequestPredicate must not be null");
this.left = left; this.left = left;
this.leftModifying = of(left);
this.right = right; this.right = right;
this.rightModifying = of(right);
} }
@Override
public boolean test(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
if (this.left.test(request) && this.right.test(request)) { @Override
return true; protected Result testInternal(ServerRequest request) {
Result leftResult = this.leftModifying.testInternal(request);
if (!leftResult.value()) {
return leftResult;
} }
restoreAttributes(request, oldAttributes); Result rightResult = this.rightModifying.testInternal(request);
return false; if (!rightResult.value()) {
return rightResult;
}
return Result.of(true, serverRequest -> {
leftResult.modify(serverRequest);
rightResult.modify(serverRequest);
});
} }
@Override @Override
@ -796,11 +887,11 @@ public abstract class RequestPredicates {
@Override @Override
public void changeParser(PathPatternParser parser) { public void changeParser(PathPatternParser parser) {
if (this.left instanceof ChangePathPatternParserVisitor.Target leftTarget) { if (this.left instanceof ChangePathPatternParserVisitor.Target target) {
leftTarget.changeParser(parser); target.changeParser(parser);
} }
if (this.right instanceof ChangePathPatternParserVisitor.Target rightTarget) { if (this.right instanceof ChangePathPatternParserVisitor.Target target) {
rightTarget.changeParser(parser); target.changeParser(parser);
} }
} }
@ -814,23 +905,25 @@ public abstract class RequestPredicates {
/** /**
* {@link RequestPredicate} that negates a delegate predicate. * {@link RequestPredicate} that negates a delegate predicate.
*/ */
static class NegateRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { static class NegateRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate delegate; private final RequestPredicate delegate;
private final RequestModifyingPredicate delegateModifying;
public NegateRequestPredicate(RequestPredicate delegate) { public NegateRequestPredicate(RequestPredicate delegate) {
Assert.notNull(delegate, "Delegate must not be null"); Assert.notNull(delegate, "Delegate must not be null");
this.delegate = delegate; this.delegate = delegate;
this.delegateModifying = of(delegate);
} }
@Override @Override
public boolean test(ServerRequest request) { protected Result testInternal(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes()); Result result = this.delegateModifying.testInternal(request);
boolean result = !this.delegate.test(request); return Result.of(!result.value(), result::modify);
if (!result) {
restoreAttributes(request, oldAttributes);
}
return result;
} }
@Override @Override
@ -858,34 +951,36 @@ public abstract class RequestPredicates {
* {@link RequestPredicate} where either {@code left} or {@code right} predicates * {@link RequestPredicate} where either {@code left} or {@code right} predicates
* may match. * may match.
*/ */
static class OrRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { static class OrRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate left; private final RequestPredicate left;
private final RequestModifyingPredicate leftModifying;
private final RequestPredicate right; private final RequestPredicate right;
private final RequestModifyingPredicate rightModifying;
public OrRequestPredicate(RequestPredicate left, RequestPredicate right) { public OrRequestPredicate(RequestPredicate left, RequestPredicate right) {
Assert.notNull(left, "Left RequestPredicate must not be null"); Assert.notNull(left, "Left RequestPredicate must not be null");
Assert.notNull(right, "Right RequestPredicate must not be null"); Assert.notNull(right, "Right RequestPredicate must not be null");
this.left = left; this.left = left;
this.leftModifying = of(left);
this.right = right; this.right = right;
this.rightModifying = of(right);
} }
@Override @Override
public boolean test(ServerRequest request) { protected Result testInternal(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes()); Result leftResult = this.leftModifying.testInternal(request);
if (leftResult.value()) {
if (this.left.test(request)) { return leftResult;
return true;
} }
else { else {
restoreAttributes(request, oldAttributes); return this.rightModifying.testInternal(request);
if (this.right.test(request)) {
return true;
}
} }
restoreAttributes(request, oldAttributes);
return false;
} }
@Override @Override
@ -910,11 +1005,11 @@ public abstract class RequestPredicates {
@Override @Override
public void changeParser(PathPatternParser parser) { public void changeParser(PathPatternParser parser) {
if (this.left instanceof ChangePathPatternParserVisitor.Target leftTarget) { if (this.left instanceof ChangePathPatternParserVisitor.Target target) {
leftTarget.changeParser(parser); target.changeParser(parser);
} }
if (this.right instanceof ChangePathPatternParserVisitor.Target rightTarget) { if (this.right instanceof ChangePathPatternParserVisitor.Target target) {
rightTarget.changeParser(parser); target.changeParser(parser);
} }
} }

View File

@ -182,24 +182,25 @@ public class RequestPredicateAttributesTests {
} }
private static class AddAttributePredicate implements RequestPredicate { private static class AddAttributePredicate extends RequestPredicates.RequestModifyingPredicate {
private boolean result; private final boolean result;
private final String key; private final String key;
private final String value; private final String value;
private AddAttributePredicate(boolean result, String key, String value) {
public AddAttributePredicate(boolean result, String key, String value) {
this.result = result; this.result = result;
this.key = key; this.key = key;
this.value = value; this.value = value;
} }
@Override @Override
public boolean test(ServerRequest request) { protected Result testInternal(ServerRequest request) {
request.attributes().put(key, value); return Result.of(this.result, serverRequest -> serverRequest.attributes().put(this.key, this.value));
return this.result;
} }
} }

View File

@ -23,7 +23,6 @@ import java.security.Principal;
import java.time.Instant; import java.time.Instant;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
@ -295,11 +294,6 @@ public abstract class RequestPredicates {
} }
} }
private static void restoreAttributes(ServerRequest request, Map<String, Object> attributes) {
request.attributes().clear();
request.attributes().putAll(attributes);
}
private static Map<String, String> mergePathVariables(Map<String, String> oldVariables, private static Map<String, String> mergePathVariables(Map<String, String> oldVariables,
Map<String, String> newVariables) { Map<String, String> newVariables) {
@ -431,6 +425,87 @@ public abstract class RequestPredicates {
} }
/**
* Extension of {@code RequestPredicate} that can modify the {@code ServerRequest}.
*/
private static abstract class RequestModifyingPredicate implements RequestPredicate {
public static RequestModifyingPredicate of(RequestPredicate requestPredicate) {
if (requestPredicate instanceof RequestModifyingPredicate modifyingPredicate) {
return modifyingPredicate;
}
else {
return new RequestModifyingPredicate() {
@Override
protected Result testInternal(ServerRequest request) {
return Result.of(requestPredicate.test(request));
}
};
}
}
@Override
public final boolean test(ServerRequest request) {
Result result = testInternal(request);
boolean value = result.value();
if (value) {
result.modify(request);
}
return value;
}
protected abstract Result testInternal(ServerRequest request);
protected static final class Result {
private static final Result TRUE = new Result(true, null);
private static final Result FALSE = new Result(false, null);
private final boolean value;
@Nullable
private final Consumer<ServerRequest> modify;
private Result(boolean value, @Nullable Consumer<ServerRequest> modify) {
this.value = value;
this.modify = modify;
}
public static Result of(boolean value) {
return of(value, null);
}
public static Result of(boolean value, @Nullable Consumer<ServerRequest> commit) {
if (commit == null) {
return value ? TRUE : FALSE;
}
else {
return new Result(value, commit);
}
}
public boolean value() {
return this.value;
}
public void modify(ServerRequest request) {
if (this.modify != null) {
this.modify.accept(request);
}
}
}
}
private static class HttpMethodPredicate implements RequestPredicate { private static class HttpMethodPredicate implements RequestPredicate {
private final Set<HttpMethod> httpMethods; private final Set<HttpMethod> httpMethods;
@ -481,39 +556,41 @@ public abstract class RequestPredicates {
} }
private static class PathPatternPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { private static class PathPatternPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private PathPattern pattern; private PathPattern pattern;
public PathPatternPredicate(PathPattern pattern) { public PathPatternPredicate(PathPattern pattern) {
Assert.notNull(pattern, "'pattern' must not be null"); Assert.notNull(pattern, "'pattern' must not be null");
this.pattern = pattern; this.pattern = pattern;
} }
@Override @Override
public boolean test(ServerRequest request) { protected Result testInternal(ServerRequest request) {
PathContainer pathContainer = request.requestPath().pathWithinApplication(); PathContainer pathContainer = request.requestPath().pathWithinApplication();
PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(pathContainer); PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(pathContainer);
traceMatch("Pattern", this.pattern.getPatternString(), request.path(), info != null); traceMatch("Pattern", this.pattern.getPatternString(), request.path(), info != null);
if (info != null) { if (info != null) {
mergeAttributes(request, info.getUriVariables(), this.pattern); return Result.of(true, serverRequest -> mergeAttributes(serverRequest, info.getUriVariables()));
return true;
} }
else { else {
return false; return Result.of(false);
} }
} }
private static void mergeAttributes(ServerRequest request, Map<String, String> variables, private void mergeAttributes(ServerRequest request, Map<String, String> variables) {
PathPattern pattern) { Map<String, Object> attributes = request.attributes();
Map<String, String> pathVariables = mergePathVariables(request.pathVariables(), variables); Map<String, String> pathVariables = mergePathVariables(request.pathVariables(), variables);
request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, attributes.put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Collections.unmodifiableMap(pathVariables)); Collections.unmodifiableMap(pathVariables));
pattern = mergePatterns( PathPattern pattern = mergePatterns(
(PathPattern) request.attributes().get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE), (PathPattern) attributes.get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE),
pattern); this.pattern);
request.attributes().put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern); attributes.put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern);
} }
@Override @Override
@ -755,28 +832,42 @@ public abstract class RequestPredicates {
* {@link RequestPredicate} for where both {@code left} and {@code right} predicates * {@link RequestPredicate} for where both {@code left} and {@code right} predicates
* must match. * must match.
*/ */
static class AndRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { static class AndRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate left; private final RequestPredicate left;
private final RequestModifyingPredicate leftModifying;
private final RequestPredicate right; private final RequestPredicate right;
private final RequestModifyingPredicate rightModifying;
public AndRequestPredicate(RequestPredicate left, RequestPredicate right) { public AndRequestPredicate(RequestPredicate left, RequestPredicate right) {
Assert.notNull(left, "Left RequestPredicate must not be null"); Assert.notNull(left, "Left RequestPredicate must not be null");
Assert.notNull(right, "Right RequestPredicate must not be null"); Assert.notNull(right, "Right RequestPredicate must not be null");
this.left = left; this.left = left;
this.leftModifying = of(left);
this.right = right; this.right = right;
this.rightModifying = of(right);
} }
@Override
public boolean test(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
if (this.left.test(request) && this.right.test(request)) { @Override
return true; protected Result testInternal(ServerRequest request) {
Result leftResult = this.leftModifying.testInternal(request);
if (!leftResult.value()) {
return leftResult;
} }
restoreAttributes(request, oldAttributes); Result rightResult = this.rightModifying.testInternal(request);
return false; if (!rightResult.value()) {
return rightResult;
}
return Result.of(true, serverRequest -> {
leftResult.modify(serverRequest);
rightResult.modify(serverRequest);
});
} }
@Override @Override
@ -813,23 +904,25 @@ public abstract class RequestPredicates {
/** /**
* {@link RequestPredicate} that negates a delegate predicate. * {@link RequestPredicate} that negates a delegate predicate.
*/ */
static class NegateRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { static class NegateRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate delegate; private final RequestPredicate delegate;
private final RequestModifyingPredicate delegateModifying;
public NegateRequestPredicate(RequestPredicate delegate) { public NegateRequestPredicate(RequestPredicate delegate) {
Assert.notNull(delegate, "Delegate must not be null"); Assert.notNull(delegate, "Delegate must not be null");
this.delegate = delegate; this.delegate = delegate;
this.delegateModifying = of(delegate);
} }
@Override @Override
public boolean test(ServerRequest request) { protected Result testInternal(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes()); Result result = this.delegateModifying.testInternal(request);
boolean result = !this.delegate.test(request); return Result.of(!result.value(), result::modify);
if (!result) {
restoreAttributes(request, oldAttributes);
}
return result;
} }
@Override @Override
@ -857,34 +950,36 @@ public abstract class RequestPredicates {
* {@link RequestPredicate} where either {@code left} or {@code right} predicates * {@link RequestPredicate} where either {@code left} or {@code right} predicates
* may match. * may match.
*/ */
static class OrRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { static class OrRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate left; private final RequestPredicate left;
private final RequestModifyingPredicate leftModifying;
private final RequestPredicate right; private final RequestPredicate right;
private final RequestModifyingPredicate rightModifying;
public OrRequestPredicate(RequestPredicate left, RequestPredicate right) { public OrRequestPredicate(RequestPredicate left, RequestPredicate right) {
Assert.notNull(left, "Left RequestPredicate must not be null"); Assert.notNull(left, "Left RequestPredicate must not be null");
Assert.notNull(right, "Right RequestPredicate must not be null"); Assert.notNull(right, "Right RequestPredicate must not be null");
this.left = left; this.left = left;
this.leftModifying = of(left);
this.right = right; this.right = right;
this.rightModifying = of(right);
} }
@Override @Override
public boolean test(ServerRequest request) { protected Result testInternal(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes()); Result leftResult = this.leftModifying.testInternal(request);
if (leftResult.value()) {
if (this.left.test(request)) { return leftResult;
return true;
} }
else { else {
restoreAttributes(request, oldAttributes); return this.rightModifying.testInternal(request);
if (this.right.test(request)) {
return true;
}
} }
restoreAttributes(request, oldAttributes);
return false;
} }
@Override @Override