Merge pull request #30028 from yuzawa-san:request-predicate-commit

* gh-30028:
  Polishing external contribution
  Improve attribute handling in RequestPredicates
This commit is contained in:
Arjen Poutsma 2023-09-12 12:36:23 +02:00
commit 9df735b3ab
3 changed files with 300 additions and 109 deletions

View File

@ -21,7 +21,6 @@ import java.net.URI;
import java.security.Principal;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
@ -296,11 +295,6 @@ public abstract class RequestPredicates {
}
}
private static void restoreAttributes(ServerRequest request, Map<String, Object> attributes) {
request.attributes().clear();
request.attributes().putAll(attributes);
}
private static Map<String, String> mergePathVariables(Map<String, String> oldVariables,
Map<String, String> newVariables) {
@ -432,13 +426,94 @@ public abstract class RequestPredicates {
}
/**
* Extension of {@code RequestPredicate} that can modify the {@code ServerRequest}.
*/
static abstract class RequestModifyingPredicate implements RequestPredicate {
public static RequestModifyingPredicate of(RequestPredicate requestPredicate) {
if (requestPredicate instanceof RequestModifyingPredicate modifyingPredicate) {
return modifyingPredicate;
}
else {
return new RequestModifyingPredicate() {
@Override
protected Result testInternal(ServerRequest request) {
return Result.of(requestPredicate.test(request));
}
};
}
}
@Override
public final boolean test(ServerRequest request) {
Result result = testInternal(request);
boolean value = result.value();
if (value) {
result.modify(request);
}
return value;
}
protected abstract Result testInternal(ServerRequest request);
protected static final class Result {
private static final Result TRUE = new Result(true, null);
private static final Result FALSE = new Result(false, null);
private final boolean value;
@Nullable
private final Consumer<ServerRequest> modify;
private Result(boolean value, @Nullable Consumer<ServerRequest> modify) {
this.value = value;
this.modify = modify;
}
public static Result of(boolean value) {
return of(value, null);
}
public static Result of(boolean value, @Nullable Consumer<ServerRequest> commit) {
if (commit == null) {
return value ? TRUE : FALSE;
}
else {
return new Result(value, commit);
}
}
public boolean value() {
return this.value;
}
public void modify(ServerRequest request) {
if (this.modify != null) {
this.modify.accept(request);
}
}
}
}
private static class HttpMethodPredicate implements RequestPredicate {
private final Set<HttpMethod> httpMethods;
public HttpMethodPredicate(HttpMethod httpMethod) {
Assert.notNull(httpMethod, "HttpMethod must not be null");
this.httpMethods = Collections.singleton(httpMethod);
this.httpMethods = Set.of(httpMethod);
}
public HttpMethodPredicate(HttpMethod... httpMethods) {
@ -482,39 +557,41 @@ public abstract class RequestPredicates {
}
private static class PathPatternPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
private static class PathPatternPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private PathPattern pattern;
public PathPatternPredicate(PathPattern pattern) {
Assert.notNull(pattern, "'pattern' must not be null");
this.pattern = pattern;
}
@Override
public boolean test(ServerRequest request) {
protected Result testInternal(ServerRequest request) {
PathContainer pathContainer = request.requestPath().pathWithinApplication();
PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(pathContainer);
traceMatch("Pattern", this.pattern.getPatternString(), request.path(), info != null);
if (info != null) {
mergeAttributes(request, info.getUriVariables(), this.pattern);
return true;
return Result.of(true, serverRequest -> mergeAttributes(serverRequest, info.getUriVariables()));
}
else {
return false;
return Result.of(false);
}
}
private static void mergeAttributes(ServerRequest request, Map<String, String> variables,
PathPattern pattern) {
private void mergeAttributes(ServerRequest request, Map<String, String> variables) {
Map<String, Object> attributes = request.attributes();
Map<String, String> pathVariables = mergePathVariables(request.pathVariables(), variables);
request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Collections.unmodifiableMap(pathVariables));
attributes.put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Collections.unmodifiableMap(pathVariables));
pattern = mergePatterns(
(PathPattern) request.attributes().get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE),
pattern);
request.attributes().put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern);
PathPattern pattern = mergePatterns(
(PathPattern) attributes.get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE),
this.pattern);
attributes.put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern);
}
@Override
@ -756,28 +833,42 @@ public abstract class RequestPredicates {
* {@link RequestPredicate} for where both {@code left} and {@code right} predicates
* must match.
*/
static class AndRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
static class AndRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate left;
private final RequestModifyingPredicate leftModifying;
private final RequestPredicate right;
private final RequestModifyingPredicate rightModifying;
public AndRequestPredicate(RequestPredicate left, RequestPredicate right) {
Assert.notNull(left, "Left RequestPredicate must not be null");
Assert.notNull(right, "Right RequestPredicate must not be null");
this.left = left;
this.leftModifying = of(left);
this.right = right;
this.rightModifying = of(right);
}
@Override
public boolean test(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
if (this.left.test(request) && this.right.test(request)) {
return true;
@Override
protected Result testInternal(ServerRequest request) {
Result leftResult = this.leftModifying.testInternal(request);
if (!leftResult.value()) {
return leftResult;
}
restoreAttributes(request, oldAttributes);
return false;
Result rightResult = this.rightModifying.testInternal(request);
if (!rightResult.value()) {
return rightResult;
}
return Result.of(true, serverRequest -> {
leftResult.modify(serverRequest);
rightResult.modify(serverRequest);
});
}
@Override
@ -796,11 +887,11 @@ public abstract class RequestPredicates {
@Override
public void changeParser(PathPatternParser parser) {
if (this.left instanceof ChangePathPatternParserVisitor.Target leftTarget) {
leftTarget.changeParser(parser);
if (this.left instanceof ChangePathPatternParserVisitor.Target target) {
target.changeParser(parser);
}
if (this.right instanceof ChangePathPatternParserVisitor.Target rightTarget) {
rightTarget.changeParser(parser);
if (this.right instanceof ChangePathPatternParserVisitor.Target target) {
target.changeParser(parser);
}
}
@ -814,23 +905,25 @@ public abstract class RequestPredicates {
/**
* {@link RequestPredicate} that negates a delegate predicate.
*/
static class NegateRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
static class NegateRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate delegate;
private final RequestModifyingPredicate delegateModifying;
public NegateRequestPredicate(RequestPredicate delegate) {
Assert.notNull(delegate, "Delegate must not be null");
this.delegate = delegate;
this.delegateModifying = of(delegate);
}
@Override
public boolean test(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
boolean result = !this.delegate.test(request);
if (!result) {
restoreAttributes(request, oldAttributes);
}
return result;
protected Result testInternal(ServerRequest request) {
Result result = this.delegateModifying.testInternal(request);
return Result.of(!result.value(), result::modify);
}
@Override
@ -858,34 +951,36 @@ public abstract class RequestPredicates {
* {@link RequestPredicate} where either {@code left} or {@code right} predicates
* may match.
*/
static class OrRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
static class OrRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate left;
private final RequestModifyingPredicate leftModifying;
private final RequestPredicate right;
private final RequestModifyingPredicate rightModifying;
public OrRequestPredicate(RequestPredicate left, RequestPredicate right) {
Assert.notNull(left, "Left RequestPredicate must not be null");
Assert.notNull(right, "Right RequestPredicate must not be null");
this.left = left;
this.leftModifying = of(left);
this.right = right;
this.rightModifying = of(right);
}
@Override
public boolean test(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
if (this.left.test(request)) {
return true;
protected Result testInternal(ServerRequest request) {
Result leftResult = this.leftModifying.testInternal(request);
if (leftResult.value()) {
return leftResult;
}
else {
restoreAttributes(request, oldAttributes);
if (this.right.test(request)) {
return true;
}
return this.rightModifying.testInternal(request);
}
restoreAttributes(request, oldAttributes);
return false;
}
@Override
@ -910,11 +1005,11 @@ public abstract class RequestPredicates {
@Override
public void changeParser(PathPatternParser parser) {
if (this.left instanceof ChangePathPatternParserVisitor.Target leftTarget) {
leftTarget.changeParser(parser);
if (this.left instanceof ChangePathPatternParserVisitor.Target target) {
target.changeParser(parser);
}
if (this.right instanceof ChangePathPatternParserVisitor.Target rightTarget) {
rightTarget.changeParser(parser);
if (this.right instanceof ChangePathPatternParserVisitor.Target target) {
target.changeParser(parser);
}
}

View File

@ -182,24 +182,25 @@ public class RequestPredicateAttributesTests {
}
private static class AddAttributePredicate implements RequestPredicate {
private static class AddAttributePredicate extends RequestPredicates.RequestModifyingPredicate {
private boolean result;
private final boolean result;
private final String key;
private final String value;
private AddAttributePredicate(boolean result, String key, String value) {
public AddAttributePredicate(boolean result, String key, String value) {
this.result = result;
this.key = key;
this.value = value;
}
@Override
public boolean test(ServerRequest request) {
request.attributes().put(key, value);
return this.result;
protected Result testInternal(ServerRequest request) {
return Result.of(this.result, serverRequest -> serverRequest.attributes().put(this.key, this.value));
}
}

View File

@ -23,7 +23,6 @@ import java.security.Principal;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
@ -295,11 +294,6 @@ public abstract class RequestPredicates {
}
}
private static void restoreAttributes(ServerRequest request, Map<String, Object> attributes) {
request.attributes().clear();
request.attributes().putAll(attributes);
}
private static Map<String, String> mergePathVariables(Map<String, String> oldVariables,
Map<String, String> newVariables) {
@ -431,6 +425,87 @@ public abstract class RequestPredicates {
}
/**
* Extension of {@code RequestPredicate} that can modify the {@code ServerRequest}.
*/
private static abstract class RequestModifyingPredicate implements RequestPredicate {
public static RequestModifyingPredicate of(RequestPredicate requestPredicate) {
if (requestPredicate instanceof RequestModifyingPredicate modifyingPredicate) {
return modifyingPredicate;
}
else {
return new RequestModifyingPredicate() {
@Override
protected Result testInternal(ServerRequest request) {
return Result.of(requestPredicate.test(request));
}
};
}
}
@Override
public final boolean test(ServerRequest request) {
Result result = testInternal(request);
boolean value = result.value();
if (value) {
result.modify(request);
}
return value;
}
protected abstract Result testInternal(ServerRequest request);
protected static final class Result {
private static final Result TRUE = new Result(true, null);
private static final Result FALSE = new Result(false, null);
private final boolean value;
@Nullable
private final Consumer<ServerRequest> modify;
private Result(boolean value, @Nullable Consumer<ServerRequest> modify) {
this.value = value;
this.modify = modify;
}
public static Result of(boolean value) {
return of(value, null);
}
public static Result of(boolean value, @Nullable Consumer<ServerRequest> commit) {
if (commit == null) {
return value ? TRUE : FALSE;
}
else {
return new Result(value, commit);
}
}
public boolean value() {
return this.value;
}
public void modify(ServerRequest request) {
if (this.modify != null) {
this.modify.accept(request);
}
}
}
}
private static class HttpMethodPredicate implements RequestPredicate {
private final Set<HttpMethod> httpMethods;
@ -481,39 +556,41 @@ public abstract class RequestPredicates {
}
private static class PathPatternPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
private static class PathPatternPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private PathPattern pattern;
public PathPatternPredicate(PathPattern pattern) {
Assert.notNull(pattern, "'pattern' must not be null");
this.pattern = pattern;
}
@Override
public boolean test(ServerRequest request) {
protected Result testInternal(ServerRequest request) {
PathContainer pathContainer = request.requestPath().pathWithinApplication();
PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(pathContainer);
traceMatch("Pattern", this.pattern.getPatternString(), request.path(), info != null);
if (info != null) {
mergeAttributes(request, info.getUriVariables(), this.pattern);
return true;
return Result.of(true, serverRequest -> mergeAttributes(serverRequest, info.getUriVariables()));
}
else {
return false;
return Result.of(false);
}
}
private static void mergeAttributes(ServerRequest request, Map<String, String> variables,
PathPattern pattern) {
private void mergeAttributes(ServerRequest request, Map<String, String> variables) {
Map<String, Object> attributes = request.attributes();
Map<String, String> pathVariables = mergePathVariables(request.pathVariables(), variables);
request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Collections.unmodifiableMap(pathVariables));
attributes.put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE,
Collections.unmodifiableMap(pathVariables));
pattern = mergePatterns(
(PathPattern) request.attributes().get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE),
pattern);
request.attributes().put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern);
PathPattern pattern = mergePatterns(
(PathPattern) attributes.get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE),
this.pattern);
attributes.put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern);
}
@Override
@ -755,28 +832,42 @@ public abstract class RequestPredicates {
* {@link RequestPredicate} for where both {@code left} and {@code right} predicates
* must match.
*/
static class AndRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
static class AndRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate left;
private final RequestModifyingPredicate leftModifying;
private final RequestPredicate right;
private final RequestModifyingPredicate rightModifying;
public AndRequestPredicate(RequestPredicate left, RequestPredicate right) {
Assert.notNull(left, "Left RequestPredicate must not be null");
Assert.notNull(right, "Right RequestPredicate must not be null");
this.left = left;
this.leftModifying = of(left);
this.right = right;
this.rightModifying = of(right);
}
@Override
public boolean test(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
if (this.left.test(request) && this.right.test(request)) {
return true;
@Override
protected Result testInternal(ServerRequest request) {
Result leftResult = this.leftModifying.testInternal(request);
if (!leftResult.value()) {
return leftResult;
}
restoreAttributes(request, oldAttributes);
return false;
Result rightResult = this.rightModifying.testInternal(request);
if (!rightResult.value()) {
return rightResult;
}
return Result.of(true, serverRequest -> {
leftResult.modify(serverRequest);
rightResult.modify(serverRequest);
});
}
@Override
@ -813,23 +904,25 @@ public abstract class RequestPredicates {
/**
* {@link RequestPredicate} that negates a delegate predicate.
*/
static class NegateRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
static class NegateRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate delegate;
private final RequestModifyingPredicate delegateModifying;
public NegateRequestPredicate(RequestPredicate delegate) {
Assert.notNull(delegate, "Delegate must not be null");
this.delegate = delegate;
this.delegateModifying = of(delegate);
}
@Override
public boolean test(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
boolean result = !this.delegate.test(request);
if (!result) {
restoreAttributes(request, oldAttributes);
}
return result;
protected Result testInternal(ServerRequest request) {
Result result = this.delegateModifying.testInternal(request);
return Result.of(!result.value(), result::modify);
}
@Override
@ -857,34 +950,36 @@ public abstract class RequestPredicates {
* {@link RequestPredicate} where either {@code left} or {@code right} predicates
* may match.
*/
static class OrRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target {
static class OrRequestPredicate extends RequestModifyingPredicate
implements ChangePathPatternParserVisitor.Target {
private final RequestPredicate left;
private final RequestModifyingPredicate leftModifying;
private final RequestPredicate right;
private final RequestModifyingPredicate rightModifying;
public OrRequestPredicate(RequestPredicate left, RequestPredicate right) {
Assert.notNull(left, "Left RequestPredicate must not be null");
Assert.notNull(right, "Right RequestPredicate must not be null");
this.left = left;
this.leftModifying = of(left);
this.right = right;
this.rightModifying = of(right);
}
@Override
public boolean test(ServerRequest request) {
Map<String, Object> oldAttributes = new HashMap<>(request.attributes());
if (this.left.test(request)) {
return true;
protected Result testInternal(ServerRequest request) {
Result leftResult = this.leftModifying.testInternal(request);
if (leftResult.value()) {
return leftResult;
}
else {
restoreAttributes(request, oldAttributes);
if (this.right.test(request)) {
return true;
}
return this.rightModifying.testInternal(request);
}
restoreAttributes(request, oldAttributes);
return false;
}
@Override