From 4a33ab557721cc3d25583e246bf4b5d9173c15c8 Mon Sep 17 00:00:00 2001 From: Dave Syer Date: Thu, 17 Jul 2014 14:20:29 +0100 Subject: [PATCH] Make sure ErrorPageFilter is only applied once per request Fixes gh-1257 --- .../boot/context/web/ErrorPageFilter.java | 22 ++++++++++------ .../context/web/ErrorPageFilterTests.java | 25 ++++++++++++++++--- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java index b202922cd56..3a794b33612 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/web/ErrorPageFilter.java @@ -38,6 +38,7 @@ import org.springframework.boot.context.embedded.ErrorPage; import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; import org.springframework.stereotype.Component; +import org.springframework.web.filter.OncePerRequestFilter; /** * A special {@link AbstractConfigurableEmbeddedServletContainer} for non-embedded @@ -76,21 +77,28 @@ class ErrorPageFilter extends AbstractConfigurableEmbeddedServletContainer imple private final Map, String> exceptions = new HashMap, String>(); private final Map, Class> subtypes = new HashMap, Class>(); + + private final OncePerRequestFilter delegate = new OncePerRequestFilter( + ) { + + @Override + protected void doFilterInternal(HttpServletRequest request, + HttpServletResponse response, FilterChain chain) + throws ServletException, IOException { + ErrorPageFilter.this.doFilter(request, response, chain); + } + + }; @Override public void init(FilterConfig filterConfig) throws ServletException { + delegate.init(filterConfig); } @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - if (request instanceof HttpServletRequest - && response instanceof HttpServletResponse) { - doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain); - } - else { - chain.doFilter(request, response); - } + delegate.doFilter(request, response, chain); } private void doFilter(HttpServletRequest request, HttpServletResponse response, diff --git a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java index a96a71d32f7..07419ccb714 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/web/ErrorPageFilterTests.java @@ -16,6 +16,11 @@ package org.springframework.boot.context.web; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + import java.io.IOException; import javax.servlet.RequestDispatcher; @@ -29,13 +34,10 @@ import org.junit.Test; import org.springframework.boot.context.embedded.ErrorPage; import org.springframework.http.HttpStatus; import org.springframework.mock.web.MockFilterChain; +import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import static org.hamcrest.Matchers.equalTo; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; - /** * Tests for {@link ErrorPageFilter}. * @@ -97,6 +99,21 @@ public class ErrorPageFilterTests { equalTo(400)); } + @Test + public void oncePerRequest() throws Exception { + this.chain = new MockFilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + ((HttpServletResponse) response).sendError(400, "BAD"); + assertNotNull(request.getAttribute("FILTER.FILTERED")); + super.doFilter(request, response); + } + }; + filter.init(new MockFilterConfig("FILTER")); + this.filter.doFilter(this.request, this.response, this.chain); + } + @Test public void globalError() throws Exception { this.filter.addErrorPages(new ErrorPage("/error"));