Polishing and minor refactoring in UrlHandlerFilter

See gh-32830
This commit is contained in:
rstoyanchev 2024-07-01 15:50:02 +01:00
parent fd3bf5b352
commit 80d1d50478
2 changed files with 138 additions and 157 deletions

View File

@ -17,7 +17,7 @@
package org.springframework.web.filter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@ -46,13 +46,13 @@ import org.springframework.web.util.pattern.PathPatternParser;
* {@code Filter} that can be configured to trim trailing slashes, and either
* send a redirect or wrap the request and continue processing.
*
* <p>Use the static {@link #trimTrailingSlash(String...)} method to begin to
* <p>Use the static {@link #trailingSlashHandler(String...)} method to begin to
* configure and build an instance. For example:
*
* <pre>
* UrlHandlerFilter filter = UrlHandlerFilter
* .trimTrailingSlash("/path1/**").andRedirect(HttpStatus.PERMANENT_REDIRECT)
* .trimTrailingSlash("/path2/**").andHandleRequest()
* .trailingSlashHandler("/path1/**").redirect(HttpStatus.PERMANENT_REDIRECT)
* .trailingSlashHandler("/path2/**").wrapRequest()
* .build();
* </pre>
*
@ -67,10 +67,10 @@ public final class UrlHandlerFilter extends OncePerRequestFilter {
private static final Log logger = LogFactory.getLog(UrlHandlerFilter.class);
private final Map<PathPattern, Handler> handlers;
private final Map<PathPattern, UrlHandler> handlers;
private UrlHandlerFilter(Map<PathPattern, Handler> handlers) {
private UrlHandlerFilter(Map<PathPattern, UrlHandler> handlers) {
this.handlers = new LinkedHashMap<>(handlers);
}
@ -95,9 +95,9 @@ public final class UrlHandlerFilter extends OncePerRequestFilter {
if (path == null) {
path = ServletRequestPathUtils.parseAndCache(request);
}
for (Map.Entry<PathPattern, Handler> entry : this.handlers.entrySet()) {
Handler handler = entry.getValue();
if (entry.getKey().matches(path) && handler.shouldHandle(request)) {
for (Map.Entry<PathPattern, UrlHandler> entry : this.handlers.entrySet()) {
UrlHandler handler = entry.getValue();
if (entry.getKey().matches(path) && handler.canHandle(request)) {
handler.handle(request, response, chain);
return;
}
@ -114,66 +114,67 @@ public final class UrlHandlerFilter extends OncePerRequestFilter {
/**
* Begin to configure and build a {@link UrlHandlerFilter} by adding a
* trailing slash handler for the specified paths. For more details, see
* {@link Builder#trimTrailingSlash(String...)}.
* @param pathPatterns the URL patterns to which trimming applies.
* The pattern itself does not need to end with a trailing slash.
* @return a spec to continue with configuring the handler
* Create a builder for a {@link UrlHandlerFilter} by adding a handler for
* URL's with a trailing slash.
* @param pathPatterns path patterns to map the handler to, e.g.
* <code>"/path/&#42;"</code>, <code>"/path/&#42;&#42;"</code>,
* <code>"/path/foo/"</code>.
* @return a spec to configure the trailing slash handler with
* @see Builder#trailingSlashHandler(String...)
*/
public static TrailingSlashHandlerSpec trimTrailingSlash(String... pathPatterns) {
return new DefaultBuilder().trimTrailingSlash(pathPatterns);
public static Builder.TrailingSlashSpec trailingSlashHandler(String... pathPatterns) {
return new DefaultBuilder().trailingSlashHandler(pathPatterns);
}
/**
* Builder to configure and build a {@link UrlHandlerFilter}.
* Builder for {@link UrlHandlerFilter}.
*/
public interface Builder {
/**
* An entry point to configure a trim trailing slash handler.
* @param pathPatterns the URL patterns to which trimming applies.
* The pattern itself does not need to end with a trailing slash.
* @return a spec to continue with configuring the handler
* Add a handler for URL's with a trailing slash.
* @param pathPatterns path patterns to map the handler to, e.g.
* <code>"/path/&#42;"</code>, <code>"/path/&#42;&#42;"</code>,
* <code>"/path/foo/"</code>.
* @return a spec to configure the handler with
*/
TrailingSlashHandlerSpec trimTrailingSlash(String... pathPatterns);
TrailingSlashSpec trailingSlashHandler(String... pathPatterns);
/**
* Build the {@link UrlHandlerFilter} instance.
*/
UrlHandlerFilter build();
}
/**
* A spec to configure a trailing slash handler.
*/
public interface TrailingSlashHandlerSpec {
/**
* A callback to intercept requests with a trailing slash.
* @param consumer callback to be invoked for requests with a trailing slash
* @return the same spec instance
* A spec to configure a trailing slash handler.
*/
TrailingSlashHandlerSpec intercept(Consumer<HttpServletRequest> consumer);
interface TrailingSlashSpec {
/**
* Handle by sending a redirect with the given HTTP status and a location
* with the trailing slash trimmed.
* @param status the status to use
* @return to go back to the main {@link Builder} and either add more
* handlers or build the {@code Filter} instance.
*/
Builder andRedirect(HttpStatus status);
/**
* Intercept requests with a trailing slash. The callback is invoked
* just before the configured trailing slash handler.
*/
TrailingSlashSpec intercept(Consumer<HttpServletRequest> consumer);
/**
* Handle by wrapping the request with the trimmed trailing slash and
* delegating to the rest of the filter chain.
* @return to go back to the main {@link Builder} and either add more
* handlers or build the {@code Filter} instance.
*/
Builder andHandleRequest();
/**
* Handle requests by sending a redirect to the same URL but the
* trailing slash trimmed.
* @param status the redirect status to use
* @return the top level {@link Builder}, which allows adding more
* handlers and then building the Filter instance.
*/
Builder redirect(HttpStatus status);
/**
* Handle the request by wrapping it in order to trim the trailing
* slash, and delegating to the rest of the filter chain.
* @return the top level {@link Builder}, which allows adding more
* handlers and then building the Filter instance.
*/
Builder wrapRequest();
}
}
@ -184,28 +185,16 @@ public final class UrlHandlerFilter extends OncePerRequestFilter {
private final PathPatternParser patternParser = new PathPatternParser();
private final Map<PathPattern, Handler> handlers = new LinkedHashMap<>();
private final Map<PathPattern, UrlHandler> handlers = new LinkedHashMap<>();
@Override
public TrailingSlashHandlerSpec trimTrailingSlash(String... pathPatterns) {
return new DefaultTrailingSlashHandlerSpec(this, parseTrailingSlashPatterns(pathPatterns));
public TrailingSlashSpec trailingSlashHandler(String... patterns) {
return new DefaultTrailingSlashSpec(patterns);
}
public void addHandler(List<PathPattern> pathPatterns, Handler handler) {
for (PathPattern pattern : pathPatterns) {
this.handlers.put(pattern, handler);
}
}
private List<PathPattern> parseTrailingSlashPatterns(String... patternValues) {
List<PathPattern> patterns = new ArrayList<>(patternValues.length);
for (String s : patternValues) {
if (!s.endsWith("**") && s.charAt(s.length() - 1) != '/') {
s += "/";
}
patterns.add(this.patternParser.parse(s));
}
return patterns;
private DefaultBuilder addHandler(List<PathPattern> pathPatterns, UrlHandler handler) {
pathPatterns.forEach(pattern -> this.handlers.put(pattern, handler));
return this;
}
@Override
@ -213,80 +202,72 @@ public final class UrlHandlerFilter extends OncePerRequestFilter {
return new UrlHandlerFilter(this.handlers);
}
}
private final class DefaultTrailingSlashSpec implements TrailingSlashSpec {
private static final Predicate<HttpServletRequest> trailingSlashPredicate =
request -> request.getRequestURI().endsWith("/");
/**
* Default {@link TrailingSlashHandlerSpec} implementation.
*/
private static final class DefaultTrailingSlashHandlerSpec implements TrailingSlashHandlerSpec {
private static final Predicate<HttpServletRequest> trailingSlashPredicate =
request -> request.getRequestURI().endsWith("/");
private static final Function<String, String> trimTralingSlashFunction = path -> {
int index = (StringUtils.hasLength(path) ? path.lastIndexOf('/') : -1);
return (index != -1 ? path.substring(0, index) : path);
};
private final DefaultBuilder parent;
private final List<PathPattern> pathPatterns;
@Nullable
private Consumer<HttpServletRequest> interceptors;
private DefaultTrailingSlashHandlerSpec(DefaultBuilder parent, List<PathPattern> pathPatterns) {
this.parent = parent;
this.pathPatterns = pathPatterns;
}
@Override
public TrailingSlashHandlerSpec intercept(Consumer<HttpServletRequest> interceptor) {
this.interceptors = (this.interceptors != null ? this.interceptors.andThen(interceptor) : interceptor);
return this;
}
@Override
public Builder andRedirect(HttpStatus status) {
return addHandler(new RedirectPathHandler(
trailingSlashPredicate, trimTralingSlashFunction, status, initInterceptor()));
}
@Override
public Builder andHandleRequest() {
return addHandler(new RequestWrappingPathHandler(
trailingSlashPredicate, trimTralingSlashFunction, initInterceptor()));
}
private Consumer<HttpServletRequest> initInterceptor() {
if (this.interceptors != null) {
return this.interceptors;
}
return request -> {
if (logger.isTraceEnabled()) {
logger.trace("Trimmed trailing slash: " +
request.getMethod() + " " + request.getRequestURI());
}
private static final Function<String, String> tralingSlashTrimFunction = path -> {
int index = (StringUtils.hasLength(path) ? path.lastIndexOf('/') : -1);
return (index != -1 ? path.substring(0, index) : path);
};
}
private DefaultBuilder addHandler(Handler handler) {
this.parent.addHandler(this.pathPatterns, handler);
return this.parent;
private final List<PathPattern> pathPatterns;
@Nullable
private Consumer<HttpServletRequest> requestConsumer;
private DefaultTrailingSlashSpec(String[] patterns) {
this.pathPatterns = Arrays.stream(patterns)
.map(pattern -> pattern.endsWith("**") || pattern.endsWith("/") ? pattern : pattern + "/")
.map(patternParser::parse)
.toList();
}
@Override
public TrailingSlashSpec intercept(Consumer<HttpServletRequest> consumer) {
this.requestConsumer = (this.requestConsumer != null ?
this.requestConsumer.andThen(consumer) : consumer);
return this;
}
@Override
public Builder redirect(HttpStatus status) {
return DefaultBuilder.this.addHandler(
this.pathPatterns, new RedirectUrlHandler(
trailingSlashPredicate, tralingSlashTrimFunction, status, initRequestConsumer()));
}
@Override
public Builder wrapRequest() {
return DefaultBuilder.this.addHandler(
this.pathPatterns, new RequestWrappingUrlHandler(
trailingSlashPredicate, tralingSlashTrimFunction, initRequestConsumer()));
}
private Consumer<HttpServletRequest> initRequestConsumer() {
return this.requestConsumer != null ? this.requestConsumer :
(request -> {
if (logger.isTraceEnabled()) {
logger.trace("Trimmed trailing slash: " +
request.getMethod() + " " + request.getRequestURI());
}
});
}
}
}
/**
* Internal handler for {@link UrlHandlerFilter} to delegate to.
* Internal handler to encapsulate different ways to handle a request.
*/
private interface Handler {
private interface UrlHandler {
/**
* Whether the handler handles the given request.
*/
boolean shouldHandle(HttpServletRequest request);
boolean canHandle(HttpServletRequest request);
/**
* Handle the request, possibly delegating to the rest of the filter chain.
@ -297,23 +278,23 @@ public final class UrlHandlerFilter extends OncePerRequestFilter {
/**
* Base class for handlers that modify the URL path.
* Base class for {@code UrlHandler} implementations.
*/
private abstract static class AbstractPathHandler implements Handler {
private abstract static class AbstractUrlHandler implements UrlHandler {
private final Predicate<HttpServletRequest> pathPredicate;
private final Predicate<HttpServletRequest> requestPredicate;
private final Function<String, String> pathFunction;
private final Consumer<HttpServletRequest> interceptor;
private final Consumer<HttpServletRequest> requestConsumer;
AbstractPathHandler(
Predicate<HttpServletRequest> pathPredicate, Function<String, String> pathFunction,
Consumer<HttpServletRequest> interceptor) {
AbstractUrlHandler(
Predicate<HttpServletRequest> requestPredicate, Function<String, String> pathFunction,
Consumer<HttpServletRequest> requestConsumer) {
this.pathPredicate = pathPredicate;
this.requestPredicate = requestPredicate;
this.pathFunction = pathFunction;
this.interceptor = interceptor;
this.requestConsumer = requestConsumer;
}
protected Function<String, String> getPathFunction() {
@ -321,15 +302,15 @@ public final class UrlHandlerFilter extends OncePerRequestFilter {
}
@Override
public boolean shouldHandle(HttpServletRequest request) {
return this.pathPredicate.test(request);
public boolean canHandle(HttpServletRequest request) {
return this.requestPredicate.test(request);
}
@Override
public void handle(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws ServletException, IOException {
this.interceptor.accept(request);
this.requestConsumer.accept(request);
handleInternal(request, response, chain);
}
@ -342,11 +323,11 @@ public final class UrlHandlerFilter extends OncePerRequestFilter {
/**
* Path handler that sends a redirect.
*/
private static final class RedirectPathHandler extends AbstractPathHandler {
private static final class RedirectUrlHandler extends AbstractUrlHandler {
private final HttpStatus httpStatus;
RedirectPathHandler(
RedirectUrlHandler(
Predicate<HttpServletRequest> pathPredicate, Function<String, String> pathFunction,
HttpStatus httpStatus, Consumer<HttpServletRequest> interceptor) {
@ -371,9 +352,9 @@ public final class UrlHandlerFilter extends OncePerRequestFilter {
/**
* Path handler that wraps the request and continues processing.
*/
private static final class RequestWrappingPathHandler extends AbstractPathHandler {
private static final class RequestWrappingUrlHandler extends AbstractUrlHandler {
RequestWrappingPathHandler(
RequestWrappingUrlHandler(
Predicate<HttpServletRequest> pathPredicate, Function<String, String> pathFunction,
Consumer<HttpServletRequest> interceptor) {

View File

@ -40,16 +40,16 @@ import static org.assertj.core.api.Assertions.assertThat;
public class UrlHandlerFilterTests {
@Test
void trimTrailingSlashAndHandle() throws Exception {
testTrimTrailingSlashAndHandle("/path/**", "/path/123", null);
testTrimTrailingSlashAndHandle("/path/*", "/path", "/123");
testTrimTrailingSlashAndHandle("/path/*", "", "/path/123");
void trailingSlashWithRequestWrapping() throws Exception {
testTrailingSlashWithRequestWrapping("/path/**", "/path/123", null);
testTrailingSlashWithRequestWrapping("/path/*", "/path", "/123");
testTrailingSlashWithRequestWrapping("/path/*", "", "/path/123");
}
void testTrimTrailingSlashAndHandle(
void testTrailingSlashWithRequestWrapping(
String pattern, String servletPath, @Nullable String pathInfo) throws Exception {
UrlHandlerFilter filter = UrlHandlerFilter.trimTrailingSlash(pattern).andHandleRequest().build();
UrlHandlerFilter filter = UrlHandlerFilter.trailingSlashHandler(pattern).wrapRequest().build();
boolean hasPathInfo = StringUtils.hasLength(pathInfo);
String requestURI = servletPath + (hasPathInfo ? pathInfo : "");
@ -70,15 +70,15 @@ public class UrlHandlerFilterTests {
}
@Test
void noTrailingSlashNoHandling() throws Exception {
testNoTrailingSlashNoHandling("/path/**", "/path/123");
testNoTrailingSlashNoHandling("/path/*", "/path/123");
void noTrailingSlashWithRequestWrapping() throws Exception {
testNoTrailingSlashWithRequestWrapping("/path/**", "/path/123");
testNoTrailingSlashWithRequestWrapping("/path/*", "/path/123");
}
private static void testNoTrailingSlashNoHandling(
private static void testNoTrailingSlashWithRequestWrapping(
String pattern, String requestURI) throws ServletException, IOException {
UrlHandlerFilter filter = UrlHandlerFilter.trimTrailingSlash(pattern).andHandleRequest().build();
UrlHandlerFilter filter = UrlHandlerFilter.trailingSlashHandler(pattern).wrapRequest().build();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestURI);
MockFilterChain chain = new MockFilterChain();
@ -89,9 +89,9 @@ public class UrlHandlerFilterTests {
}
@Test
void trimTrailingSlashAndRedirect() throws Exception {
void trailingSlashHandlerWithRedirect() throws Exception {
HttpStatus status = HttpStatus.PERMANENT_REDIRECT;
UrlHandlerFilter filter = UrlHandlerFilter.trimTrailingSlash("/path/*").andRedirect(status).build();
UrlHandlerFilter filter = UrlHandlerFilter.trailingSlashHandler("/path/*").redirect(status).build();
String path = "/path/123";
MockHttpServletResponse response = new MockHttpServletResponse();
@ -106,9 +106,9 @@ public class UrlHandlerFilterTests {
}
@Test
void noTrailingSlashNoRedirect() throws Exception {
void noTrailingSlashWithRedirect() throws Exception {
HttpStatus status = HttpStatus.PERMANENT_REDIRECT;
UrlHandlerFilter filter = UrlHandlerFilter.trimTrailingSlash("/path/*").andRedirect(status).build();
UrlHandlerFilter filter = UrlHandlerFilter.trailingSlashHandler("/path/*").redirect(status).build();
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/path/123");
MockHttpServletResponse response = new MockHttpServletResponse();