Make sure ErrorPageFilter is only applied once per request
Fixes gh-1257
This commit is contained in:
parent
0c52817c88
commit
4a33ab5577
|
|
@ -38,6 +38,7 @@ import org.springframework.boot.context.embedded.ErrorPage;
|
||||||
import org.springframework.core.Ordered;
|
import org.springframework.core.Ordered;
|
||||||
import org.springframework.core.annotation.Order;
|
import org.springframework.core.annotation.Order;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
import org.springframework.web.filter.OncePerRequestFilter;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A special {@link AbstractConfigurableEmbeddedServletContainer} for non-embedded
|
* A special {@link AbstractConfigurableEmbeddedServletContainer} for non-embedded
|
||||||
|
|
@ -77,20 +78,27 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple
|
||||||
|
|
||||||
private final Map<Class<?>, Class<?>> subtypes = new HashMap<Class<?>, Class<?>>();
|
private final Map<Class<?>, Class<?>> subtypes = new HashMap<Class<?>, Class<?>>();
|
||||||
|
|
||||||
|
private final OncePerRequestFilter delegate = new OncePerRequestFilter(
|
||||||
|
) {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void doFilterInternal(HttpServletRequest request,
|
||||||
|
HttpServletResponse response, FilterChain chain)
|
||||||
|
throws ServletException, IOException {
|
||||||
|
ErrorPageFilter.this.doFilter(request, response, chain);
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void init(FilterConfig filterConfig) throws ServletException {
|
public void init(FilterConfig filterConfig) throws ServletException {
|
||||||
|
delegate.init(filterConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doFilter(ServletRequest request, ServletResponse response,
|
public void doFilter(ServletRequest request, ServletResponse response,
|
||||||
FilterChain chain) throws IOException, ServletException {
|
FilterChain chain) throws IOException, ServletException {
|
||||||
if (request instanceof HttpServletRequest
|
delegate.doFilter(request, response, chain);
|
||||||
&& response instanceof HttpServletResponse) {
|
|
||||||
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
chain.doFilter(request, response);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void doFilter(HttpServletRequest request, HttpServletResponse response,
|
private void doFilter(HttpServletRequest request, HttpServletResponse response,
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,11 @@
|
||||||
|
|
||||||
package org.springframework.boot.context.web;
|
package org.springframework.boot.context.web;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.junit.Assert.assertNotNull;
|
||||||
|
import static org.junit.Assert.assertThat;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
import javax.servlet.RequestDispatcher;
|
import javax.servlet.RequestDispatcher;
|
||||||
|
|
@ -29,13 +34,10 @@ import org.junit.Test;
|
||||||
import org.springframework.boot.context.embedded.ErrorPage;
|
import org.springframework.boot.context.embedded.ErrorPage;
|
||||||
import org.springframework.http.HttpStatus;
|
import org.springframework.http.HttpStatus;
|
||||||
import org.springframework.mock.web.MockFilterChain;
|
import org.springframework.mock.web.MockFilterChain;
|
||||||
|
import org.springframework.mock.web.MockFilterConfig;
|
||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
import org.springframework.mock.web.MockHttpServletResponse;
|
import org.springframework.mock.web.MockHttpServletResponse;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
|
||||||
import static org.junit.Assert.assertThat;
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests for {@link ErrorPageFilter}.
|
* Tests for {@link ErrorPageFilter}.
|
||||||
*
|
*
|
||||||
|
|
@ -97,6 +99,21 @@ public class ErrorPageFilterTests {
|
||||||
equalTo(400));
|
equalTo(400));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void oncePerRequest() throws Exception {
|
||||||
|
this.chain = new MockFilterChain() {
|
||||||
|
@Override
|
||||||
|
public void doFilter(ServletRequest request, ServletResponse response)
|
||||||
|
throws IOException, ServletException {
|
||||||
|
((HttpServletResponse) response).sendError(400, "BAD");
|
||||||
|
assertNotNull(request.getAttribute("FILTER.FILTERED"));
|
||||||
|
super.doFilter(request, response);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
filter.init(new MockFilterConfig("FILTER"));
|
||||||
|
this.filter.doFilter(this.request, this.response, this.chain);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void globalError() throws Exception {
|
public void globalError() throws Exception {
|
||||||
this.filter.addErrorPages(new ErrorPage("/error"));
|
this.filter.addErrorPages(new ErrorPage("/error"));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue