diff --git a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java index bd163f1180..e4df6edbe9 100644 --- a/web/src/main/java/org/springframework/security/web/FilterChainProxy.java +++ b/web/src/main/java/org/springframework/security/web/FilterChainProxy.java @@ -17,6 +17,7 @@ package org.springframework.security.web; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.firewall.DefaultHttpFirewall; import org.springframework.security.web.firewall.FirewalledRequest; import org.springframework.security.web.firewall.HttpFirewall; @@ -150,6 +151,16 @@ public class FilterChainProxy extends GenericFilterBean { public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { + try { + doFilterInternal(request, response, chain); + } finally { + // SEC-1950 + SecurityContextHolder.clearContext(); + } + } + + private void doFilterInternal(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { FirewalledRequest fwRequest = firewall.getFirewalledRequest((HttpServletRequest) request); HttpServletResponse fwResponse = firewall.getFirewalledResponse((HttpServletResponse) response); diff --git a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java index b564f6d51f..0ac2606081 100644 --- a/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java +++ b/web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java @@ -3,18 +3,22 @@ package org.springframework.security.web; import static org.junit.Assert.*; import static org.mockito.Mockito.*; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.firewall.FirewalledRequest; import org.springframework.security.web.firewall.HttpFirewall; import org.springframework.security.web.util.RequestMatcher; import javax.servlet.Filter; import javax.servlet.FilterChain; +import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; @@ -55,6 +59,11 @@ public class FilterChainProxyTests { chain = mock(FilterChain.class); } + @After + public void teardown() { + SecurityContextHolder.clearContext(); + } + @Test public void toStringCallSucceeds() throws Exception { fcp.afterPropertiesSet(); @@ -155,4 +164,37 @@ public class FilterChainProxyTests { verify(firstFwr).reset(); verify(fwr).reset(); } + + @Test + public void doFilterClearsSecurityContextHolder() throws Exception { + when(matcher.matches(any(HttpServletRequest.class))).thenReturn(true); + doAnswer(new Answer() { + public Object answer(InvocationOnMock inv) throws Throwable { + SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("username", "password")); + return null; + } + }).when(filter).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); + + fcp.doFilter(request, response, chain); + + assertNull(SecurityContextHolder.getContext().getAuthentication()); + } + + @Test + public void doFilterClearsSecurityContextHolderWithException() throws Exception { + when(matcher.matches(any(HttpServletRequest.class))).thenReturn(true); + doAnswer(new Answer() { + public Object answer(InvocationOnMock inv) throws Throwable { + SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("username", "password")); + throw new ServletException("oops"); + } + }).when(filter).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class)); + + try { + fcp.doFilter(request, response, chain); + fail("Expected Exception"); + }catch(ServletException success) {} + + assertNull(SecurityContextHolder.getContext().getAuthentication()); + } }