DefaultWebFilterChain is a top-level, public class

Issue: SPR-15348
This commit is contained in:
Rossen Stoyanchev 2017-03-16 13:44:55 -04:00
parent ab7db413c6
commit 37592ea07c
3 changed files with 113 additions and 70 deletions

View File

@ -0,0 +1,64 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.server.handler;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import reactor.core.publisher.Mono;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebHandler;
/**
* Default implementation of {@link WebFilterChain}.
*
* @author Rossen Stoyanchev
* @since 5.0
*/
public class DefaultWebFilterChain implements WebFilterChain {
private final List<WebFilter> filters;
private final WebHandler handler;
private volatile int index;
public DefaultWebFilterChain(WebHandler handler, WebFilter... filters) {
Assert.notNull(handler, "WebHandler is required");
this.filters = ObjectUtils.isEmpty(filters) ? Collections.emptyList() : Arrays.asList(filters);
this.handler = handler;
}
@Override
public Mono<Void> filter(ServerWebExchange exchange) {
if (this.index < this.filters.size()) {
WebFilter filter = this.filters.get(this.index++);
return filter.filter(exchange, this);
}
else {
return this.handler.handle(exchange);
}
}
}

View File

@ -16,14 +16,15 @@
package org.springframework.web.server.handler;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import reactor.core.publisher.Mono;
import org.springframework.util.CollectionUtils;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebHandler;
/**
@ -35,7 +36,7 @@ import org.springframework.web.server.WebHandler;
*/
public class FilteringWebHandler extends WebHandlerDecorator {
private final List<WebFilter> filters;
private final WebFilter[] filters;
/**
@ -44,41 +45,24 @@ public class FilteringWebHandler extends WebHandlerDecorator {
*/
public FilteringWebHandler(WebHandler webHandler, List<WebFilter> filters) {
super(webHandler);
this.filters = Collections.unmodifiableList(filters);
this.filters = !CollectionUtils.isEmpty(filters) ?
filters.toArray(new WebFilter[filters.size()]) : new WebFilter[0];
}
/**
* Return read-only list of the configured filters.
* Return a read-only list of the configured filters.
*/
public List<WebFilter> getFilters() {
return this.filters;
return Arrays.asList(this.filters);
}
@Override
public Mono<Void> handle(ServerWebExchange exchange) {
if (this.filters.isEmpty()) {
return super.handle(exchange);
}
return new DefaultWebFilterChain().filter(exchange);
}
private class DefaultWebFilterChain implements WebFilterChain {
private int index;
@Override
public Mono<Void> filter(ServerWebExchange exchange) {
if (this.index < filters.size()) {
WebFilter filter = filters.get(this.index++);
return filter.filter(exchange, this);
}
else {
return getDelegate().handle(exchange);
}
}
return this.filters.length != 0 ?
new DefaultWebFilterChain(getDelegate(), this.filters).filter(exchange) :
super.handle(exchange);
}
}

View File

@ -16,18 +16,16 @@
package org.springframework.web.server.handler;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.web.server.ServerWebExchange;
@ -49,89 +47,86 @@ public class FilteringWebHandlerTests {
private static Log logger = LogFactory.getLog(FilteringWebHandlerTests.class);
private MockServerHttpRequest request;
private MockServerHttpResponse response;
@Before
public void setUp() throws Exception {
this.request = MockServerHttpRequest.get("http://localhost").build();
this.response = new MockServerHttpResponse();
}
@Test
public void multipleFilters() throws Exception {
StubWebHandler webHandler = new StubWebHandler();
TestFilter filter1 = new TestFilter();
TestFilter filter2 = new TestFilter();
TestFilter filter3 = new TestFilter();
HttpHandler httpHandler = createHttpHandler(webHandler, filter1, filter2, filter3);
httpHandler.handle(this.request, this.response).block();
StubWebHandler targetHandler = new StubWebHandler();
new FilteringWebHandler(targetHandler, Arrays.asList(filter1, filter2, filter3))
.handle(MockServerHttpRequest.get("/").toExchange())
.block(Duration.ZERO);
assertTrue(filter1.invoked());
assertTrue(filter2.invoked());
assertTrue(filter3.invoked());
assertTrue(webHandler.invoked());
assertTrue(targetHandler.invoked());
}
@Test
public void zeroFilters() throws Exception {
StubWebHandler webHandler = new StubWebHandler();
HttpHandler httpHandler = createHttpHandler(webHandler);
httpHandler.handle(this.request, this.response).block();
assertTrue(webHandler.invoked());
StubWebHandler targetHandler = new StubWebHandler();
new FilteringWebHandler(targetHandler, Collections.emptyList())
.handle(MockServerHttpRequest.get("/").toExchange())
.block(Duration.ZERO);
assertTrue(targetHandler.invoked());
}
@Test
public void shortcircuitFilter() throws Exception {
StubWebHandler webHandler = new StubWebHandler();
TestFilter filter1 = new TestFilter();
ShortcircuitingFilter filter2 = new ShortcircuitingFilter();
TestFilter filter3 = new TestFilter();
HttpHandler httpHandler = createHttpHandler(webHandler, filter1, filter2, filter3);
httpHandler.handle(this.request, this.response).block();
StubWebHandler targetHandler = new StubWebHandler();
new FilteringWebHandler(targetHandler, Arrays.asList(filter1, filter2, filter3))
.handle(MockServerHttpRequest.get("/").toExchange())
.block(Duration.ZERO);
assertTrue(filter1.invoked());
assertTrue(filter2.invoked());
assertFalse(filter3.invoked());
assertFalse(webHandler.invoked());
assertFalse(targetHandler.invoked());
}
@Test
public void asyncFilter() throws Exception {
StubWebHandler webHandler = new StubWebHandler();
AsyncFilter filter = new AsyncFilter();
HttpHandler httpHandler = createHttpHandler(webHandler, filter);
httpHandler.handle(this.request, this.response).block();
StubWebHandler targetHandler = new StubWebHandler();
new FilteringWebHandler(targetHandler, Collections.singletonList(filter))
.handle(MockServerHttpRequest.get("/").toExchange())
.block(Duration.ZERO);
assertTrue(filter.invoked());
assertTrue(webHandler.invoked());
assertTrue(targetHandler.invoked());
}
@Test
public void handleErrorFromFilter() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest.get("/").build();
MockServerHttpResponse response = new MockServerHttpResponse();
TestExceptionHandler exceptionHandler = new TestExceptionHandler();
List<ExceptionFilter> filters = Collections.singletonList(new ExceptionFilter());
List<WebExceptionHandler> exceptionHandlers = Collections.singletonList(exceptionHandler);
WebHttpHandlerBuilder.webHandler(new StubWebHandler())
.filters(filters).exceptionHandlers(exceptionHandlers).build()
.handle(this.request, this.response)
.filters(Collections.singletonList(new ExceptionFilter()))
.exceptionHandlers(Collections.singletonList(exceptionHandler)).build()
.handle(request, response)
.block();
assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, this.response.getStatusCode());
Throwable savedException = exceptionHandler.ex;
assertNotNull(savedException);
assertEquals("boo", savedException.getMessage());
}
private HttpHandler createHttpHandler(StubWebHandler webHandler, WebFilter... filters) {
return WebHttpHandlerBuilder.webHandler(webHandler).filters(Arrays.asList(filters)).build();
assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, response.getStatusCode());
assertNotNull(exceptionHandler.ex);
assertEquals("boo", exceptionHandler.ex.getMessage());
}