diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java index 284ea7e0d53..152a9554fe8 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java @@ -21,7 +21,6 @@ import java.net.URI; import java.security.Principal; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; @@ -296,11 +295,6 @@ public abstract class RequestPredicates { } } - private static void restoreAttributes(ServerRequest request, Map attributes) { - request.attributes().clear(); - request.attributes().putAll(attributes); - } - private static Map mergePathVariables(Map oldVariables, Map newVariables) { @@ -432,13 +426,94 @@ public abstract class RequestPredicates { } + /** + * Extension of {@code RequestPredicate} that can modify the {@code ServerRequest}. + */ + static abstract class RequestModifyingPredicate implements RequestPredicate { + + + public static RequestModifyingPredicate of(RequestPredicate requestPredicate) { + if (requestPredicate instanceof RequestModifyingPredicate modifyingPredicate) { + return modifyingPredicate; + } + else { + return new RequestModifyingPredicate() { + @Override + protected Result testInternal(ServerRequest request) { + return Result.of(requestPredicate.test(request)); + } + }; + } + } + + + @Override + public final boolean test(ServerRequest request) { + Result result = testInternal(request); + boolean value = result.value(); + if (value) { + result.modify(request); + } + return value; + } + + protected abstract Result testInternal(ServerRequest request); + + + protected static final class Result { + + private static final Result TRUE = new Result(true, null); + + private static final Result FALSE = new Result(false, null); + + + private final boolean value; + + @Nullable + private final Consumer modify; + + + private Result(boolean value, @Nullable Consumer modify) { + this.value = value; + this.modify = modify; + } + + + public static Result of(boolean value) { + return of(value, null); + } + + public static Result of(boolean value, @Nullable Consumer commit) { + if (commit == null) { + return value ? TRUE : FALSE; + } + else { + return new Result(value, commit); + } + } + + + public boolean value() { + return this.value; + } + + public void modify(ServerRequest request) { + if (this.modify != null) { + this.modify.accept(request); + } + } + } + + } + + private static class HttpMethodPredicate implements RequestPredicate { private final Set httpMethods; public HttpMethodPredicate(HttpMethod httpMethod) { Assert.notNull(httpMethod, "HttpMethod must not be null"); - this.httpMethods = Collections.singleton(httpMethod); + this.httpMethods = Set.of(httpMethod); } public HttpMethodPredicate(HttpMethod... httpMethods) { @@ -482,39 +557,41 @@ public abstract class RequestPredicates { } - private static class PathPatternPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { + private static class PathPatternPredicate extends RequestModifyingPredicate + implements ChangePathPatternParserVisitor.Target { private PathPattern pattern; + public PathPatternPredicate(PathPattern pattern) { Assert.notNull(pattern, "'pattern' must not be null"); this.pattern = pattern; } + @Override - public boolean test(ServerRequest request) { + protected Result testInternal(ServerRequest request) { PathContainer pathContainer = request.requestPath().pathWithinApplication(); PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(pathContainer); traceMatch("Pattern", this.pattern.getPatternString(), request.path(), info != null); if (info != null) { - mergeAttributes(request, info.getUriVariables(), this.pattern); - return true; + return Result.of(true, serverRequest -> mergeAttributes(serverRequest, info.getUriVariables())); } else { - return false; + return Result.of(false); } } - private static void mergeAttributes(ServerRequest request, Map variables, - PathPattern pattern) { + private void mergeAttributes(ServerRequest request, Map variables) { + Map attributes = request.attributes(); Map pathVariables = mergePathVariables(request.pathVariables(), variables); - request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, - Collections.unmodifiableMap(pathVariables)); + attributes.put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, + Collections.unmodifiableMap(pathVariables)); - pattern = mergePatterns( - (PathPattern) request.attributes().get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE), - pattern); - request.attributes().put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern); + PathPattern pattern = mergePatterns( + (PathPattern) attributes.get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE), + this.pattern); + attributes.put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern); } @Override @@ -756,28 +833,42 @@ public abstract class RequestPredicates { * {@link RequestPredicate} for where both {@code left} and {@code right} predicates * must match. */ - static class AndRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { + static class AndRequestPredicate extends RequestModifyingPredicate + implements ChangePathPatternParserVisitor.Target { private final RequestPredicate left; + private final RequestModifyingPredicate leftModifying; + private final RequestPredicate right; + private final RequestModifyingPredicate rightModifying; + + public AndRequestPredicate(RequestPredicate left, RequestPredicate right) { Assert.notNull(left, "Left RequestPredicate must not be null"); Assert.notNull(right, "Right RequestPredicate must not be null"); this.left = left; + this.leftModifying = of(left); this.right = right; + this.rightModifying = of(right); } - @Override - public boolean test(ServerRequest request) { - Map oldAttributes = new HashMap<>(request.attributes()); - if (this.left.test(request) && this.right.test(request)) { - return true; + @Override + protected Result testInternal(ServerRequest request) { + Result leftResult = this.leftModifying.testInternal(request); + if (!leftResult.value()) { + return leftResult; } - restoreAttributes(request, oldAttributes); - return false; + Result rightResult = this.rightModifying.testInternal(request); + if (!rightResult.value()) { + return rightResult; + } + return Result.of(true, serverRequest -> { + leftResult.modify(serverRequest); + rightResult.modify(serverRequest); + }); } @Override @@ -796,11 +887,11 @@ public abstract class RequestPredicates { @Override public void changeParser(PathPatternParser parser) { - if (this.left instanceof ChangePathPatternParserVisitor.Target leftTarget) { - leftTarget.changeParser(parser); + if (this.left instanceof ChangePathPatternParserVisitor.Target target) { + target.changeParser(parser); } - if (this.right instanceof ChangePathPatternParserVisitor.Target rightTarget) { - rightTarget.changeParser(parser); + if (this.right instanceof ChangePathPatternParserVisitor.Target target) { + target.changeParser(parser); } } @@ -814,23 +905,25 @@ public abstract class RequestPredicates { /** * {@link RequestPredicate} that negates a delegate predicate. */ - static class NegateRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { + static class NegateRequestPredicate extends RequestModifyingPredicate + implements ChangePathPatternParserVisitor.Target { private final RequestPredicate delegate; + private final RequestModifyingPredicate delegateModifying; + + public NegateRequestPredicate(RequestPredicate delegate) { Assert.notNull(delegate, "Delegate must not be null"); this.delegate = delegate; + this.delegateModifying = of(delegate); } + @Override - public boolean test(ServerRequest request) { - Map oldAttributes = new HashMap<>(request.attributes()); - boolean result = !this.delegate.test(request); - if (!result) { - restoreAttributes(request, oldAttributes); - } - return result; + protected Result testInternal(ServerRequest request) { + Result result = this.delegateModifying.testInternal(request); + return Result.of(!result.value(), result::modify); } @Override @@ -858,34 +951,36 @@ public abstract class RequestPredicates { * {@link RequestPredicate} where either {@code left} or {@code right} predicates * may match. */ - static class OrRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { + static class OrRequestPredicate extends RequestModifyingPredicate + implements ChangePathPatternParserVisitor.Target { private final RequestPredicate left; + private final RequestModifyingPredicate leftModifying; + private final RequestPredicate right; + private final RequestModifyingPredicate rightModifying; + + public OrRequestPredicate(RequestPredicate left, RequestPredicate right) { Assert.notNull(left, "Left RequestPredicate must not be null"); Assert.notNull(right, "Right RequestPredicate must not be null"); this.left = left; + this.leftModifying = of(left); this.right = right; + this.rightModifying = of(right); } @Override - public boolean test(ServerRequest request) { - Map oldAttributes = new HashMap<>(request.attributes()); - - if (this.left.test(request)) { - return true; + protected Result testInternal(ServerRequest request) { + Result leftResult = this.leftModifying.testInternal(request); + if (leftResult.value()) { + return leftResult; } else { - restoreAttributes(request, oldAttributes); - if (this.right.test(request)) { - return true; - } + return this.rightModifying.testInternal(request); } - restoreAttributes(request, oldAttributes); - return false; } @Override @@ -910,11 +1005,11 @@ public abstract class RequestPredicates { @Override public void changeParser(PathPatternParser parser) { - if (this.left instanceof ChangePathPatternParserVisitor.Target leftTarget) { - leftTarget.changeParser(parser); + if (this.left instanceof ChangePathPatternParserVisitor.Target target) { + target.changeParser(parser); } - if (this.right instanceof ChangePathPatternParserVisitor.Target rightTarget) { - rightTarget.changeParser(parser); + if (this.right instanceof ChangePathPatternParserVisitor.Target target) { + target.changeParser(parser); } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicateAttributesTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicateAttributesTests.java index f20e3678a4a..faef1bc9091 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicateAttributesTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicateAttributesTests.java @@ -182,24 +182,25 @@ public class RequestPredicateAttributesTests { } - private static class AddAttributePredicate implements RequestPredicate { + private static class AddAttributePredicate extends RequestPredicates.RequestModifyingPredicate { - private boolean result; + private final boolean result; private final String key; private final String value; - private AddAttributePredicate(boolean result, String key, String value) { + + public AddAttributePredicate(boolean result, String key, String value) { this.result = result; this.key = key; this.value = value; } + @Override - public boolean test(ServerRequest request) { - request.attributes().put(key, value); - return this.result; + protected Result testInternal(ServerRequest request) { + return Result.of(this.result, serverRequest -> serverRequest.attributes().put(this.key, this.value)); } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java index fae4532aa66..906dbcce748 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java @@ -23,7 +23,6 @@ import java.security.Principal; import java.time.Instant; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; @@ -295,11 +294,6 @@ public abstract class RequestPredicates { } } - private static void restoreAttributes(ServerRequest request, Map attributes) { - request.attributes().clear(); - request.attributes().putAll(attributes); - } - private static Map mergePathVariables(Map oldVariables, Map newVariables) { @@ -431,6 +425,87 @@ public abstract class RequestPredicates { } + /** + * Extension of {@code RequestPredicate} that can modify the {@code ServerRequest}. + */ + private static abstract class RequestModifyingPredicate implements RequestPredicate { + + + public static RequestModifyingPredicate of(RequestPredicate requestPredicate) { + if (requestPredicate instanceof RequestModifyingPredicate modifyingPredicate) { + return modifyingPredicate; + } + else { + return new RequestModifyingPredicate() { + @Override + protected Result testInternal(ServerRequest request) { + return Result.of(requestPredicate.test(request)); + } + }; + } + } + + + @Override + public final boolean test(ServerRequest request) { + Result result = testInternal(request); + boolean value = result.value(); + if (value) { + result.modify(request); + } + return value; + } + + protected abstract Result testInternal(ServerRequest request); + + + protected static final class Result { + + private static final Result TRUE = new Result(true, null); + + private static final Result FALSE = new Result(false, null); + + + private final boolean value; + + @Nullable + private final Consumer modify; + + + private Result(boolean value, @Nullable Consumer modify) { + this.value = value; + this.modify = modify; + } + + + public static Result of(boolean value) { + return of(value, null); + } + + public static Result of(boolean value, @Nullable Consumer commit) { + if (commit == null) { + return value ? TRUE : FALSE; + } + else { + return new Result(value, commit); + } + } + + + public boolean value() { + return this.value; + } + + public void modify(ServerRequest request) { + if (this.modify != null) { + this.modify.accept(request); + } + } + } + + } + + private static class HttpMethodPredicate implements RequestPredicate { private final Set httpMethods; @@ -481,39 +556,41 @@ public abstract class RequestPredicates { } - private static class PathPatternPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { + private static class PathPatternPredicate extends RequestModifyingPredicate + implements ChangePathPatternParserVisitor.Target { private PathPattern pattern; + public PathPatternPredicate(PathPattern pattern) { Assert.notNull(pattern, "'pattern' must not be null"); this.pattern = pattern; } + @Override - public boolean test(ServerRequest request) { + protected Result testInternal(ServerRequest request) { PathContainer pathContainer = request.requestPath().pathWithinApplication(); PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(pathContainer); traceMatch("Pattern", this.pattern.getPatternString(), request.path(), info != null); if (info != null) { - mergeAttributes(request, info.getUriVariables(), this.pattern); - return true; + return Result.of(true, serverRequest -> mergeAttributes(serverRequest, info.getUriVariables())); } else { - return false; + return Result.of(false); } } - private static void mergeAttributes(ServerRequest request, Map variables, - PathPattern pattern) { + private void mergeAttributes(ServerRequest request, Map variables) { + Map attributes = request.attributes(); Map pathVariables = mergePathVariables(request.pathVariables(), variables); - request.attributes().put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, - Collections.unmodifiableMap(pathVariables)); + attributes.put(RouterFunctions.URI_TEMPLATE_VARIABLES_ATTRIBUTE, + Collections.unmodifiableMap(pathVariables)); - pattern = mergePatterns( - (PathPattern) request.attributes().get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE), - pattern); - request.attributes().put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern); + PathPattern pattern = mergePatterns( + (PathPattern) attributes.get(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE), + this.pattern); + attributes.put(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE, pattern); } @Override @@ -755,28 +832,42 @@ public abstract class RequestPredicates { * {@link RequestPredicate} for where both {@code left} and {@code right} predicates * must match. */ - static class AndRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { + static class AndRequestPredicate extends RequestModifyingPredicate + implements ChangePathPatternParserVisitor.Target { private final RequestPredicate left; + private final RequestModifyingPredicate leftModifying; + private final RequestPredicate right; + private final RequestModifyingPredicate rightModifying; + + public AndRequestPredicate(RequestPredicate left, RequestPredicate right) { Assert.notNull(left, "Left RequestPredicate must not be null"); Assert.notNull(right, "Right RequestPredicate must not be null"); this.left = left; + this.leftModifying = of(left); this.right = right; + this.rightModifying = of(right); } - @Override - public boolean test(ServerRequest request) { - Map oldAttributes = new HashMap<>(request.attributes()); - if (this.left.test(request) && this.right.test(request)) { - return true; + @Override + protected Result testInternal(ServerRequest request) { + Result leftResult = this.leftModifying.testInternal(request); + if (!leftResult.value()) { + return leftResult; } - restoreAttributes(request, oldAttributes); - return false; + Result rightResult = this.rightModifying.testInternal(request); + if (!rightResult.value()) { + return rightResult; + } + return Result.of(true, serverRequest -> { + leftResult.modify(serverRequest); + rightResult.modify(serverRequest); + }); } @Override @@ -813,23 +904,25 @@ public abstract class RequestPredicates { /** * {@link RequestPredicate} that negates a delegate predicate. */ - static class NegateRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { + static class NegateRequestPredicate extends RequestModifyingPredicate + implements ChangePathPatternParserVisitor.Target { private final RequestPredicate delegate; + private final RequestModifyingPredicate delegateModifying; + + public NegateRequestPredicate(RequestPredicate delegate) { Assert.notNull(delegate, "Delegate must not be null"); this.delegate = delegate; + this.delegateModifying = of(delegate); } + @Override - public boolean test(ServerRequest request) { - Map oldAttributes = new HashMap<>(request.attributes()); - boolean result = !this.delegate.test(request); - if (!result) { - restoreAttributes(request, oldAttributes); - } - return result; + protected Result testInternal(ServerRequest request) { + Result result = this.delegateModifying.testInternal(request); + return Result.of(!result.value(), result::modify); } @Override @@ -857,34 +950,36 @@ public abstract class RequestPredicates { * {@link RequestPredicate} where either {@code left} or {@code right} predicates * may match. */ - static class OrRequestPredicate implements RequestPredicate, ChangePathPatternParserVisitor.Target { + static class OrRequestPredicate extends RequestModifyingPredicate + implements ChangePathPatternParserVisitor.Target { private final RequestPredicate left; + private final RequestModifyingPredicate leftModifying; + private final RequestPredicate right; + private final RequestModifyingPredicate rightModifying; + + public OrRequestPredicate(RequestPredicate left, RequestPredicate right) { Assert.notNull(left, "Left RequestPredicate must not be null"); Assert.notNull(right, "Right RequestPredicate must not be null"); this.left = left; + this.leftModifying = of(left); this.right = right; + this.rightModifying = of(right); } @Override - public boolean test(ServerRequest request) { - Map oldAttributes = new HashMap<>(request.attributes()); - - if (this.left.test(request)) { - return true; + protected Result testInternal(ServerRequest request) { + Result leftResult = this.leftModifying.testInternal(request); + if (leftResult.value()) { + return leftResult; } else { - restoreAttributes(request, oldAttributes); - if (this.right.test(request)) { - return true; - } + return this.rightModifying.testInternal(request); } - restoreAttributes(request, oldAttributes); - return false; } @Override