diff --git a/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java b/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java index 2352887b1f..8e7adf3c67 100644 --- a/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java +++ b/cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java @@ -42,6 +42,8 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -205,6 +207,8 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil private AuthenticationFailureHandler proxyFailureHandler = new SimpleUrlAuthenticationFailureHandler(); + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + public CasAuthenticationFilter() { super("/login/cas"); setAuthenticationFailureHandler(new SimpleUrlAuthenticationFailureHandler()); @@ -223,6 +227,7 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authResult); SecurityContextHolder.setContext(context); + this.securityContextRepository.saveContext(context, request, response); if (this.eventPublisher != null) { this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass())); } @@ -274,6 +279,18 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil return result; } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + /** * Sets the {@link AuthenticationFailureHandler} for proxy requests. * @param proxyFailureHandler diff --git a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java index fab4d2ed1d..c1338cc5ae 100644 --- a/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java +++ b/cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java @@ -21,6 +21,7 @@ import javax.servlet.FilterChain; import org.jasig.cas.client.proxy.ProxyGrantingTicketStorage; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -32,12 +33,15 @@ import org.springframework.security.cas.ServiceProperties; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.context.SecurityContextRepository; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -182,6 +186,38 @@ public class CasAuthenticationFilterTests { verify(successHandler).onAuthenticationSuccess(request, response, authentication); } + @Test + public void testSecurityContextHolder() throws Exception { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + AuthenticationManager manager = mock(AuthenticationManager.class); + Authentication authentication = new TestingAuthenticationToken("un", "pwd", "ROLE_USER"); + given(manager.authenticate(any(Authentication.class))).willReturn(authentication); + ServiceProperties serviceProperties = new ServiceProperties(); + serviceProperties.setAuthenticateAllArtifacts(true); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setParameter("ticket", "ST-1-123"); + request.setServletPath("/authenticate"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = mock(FilterChain.class); + CasAuthenticationFilter filter = new CasAuthenticationFilter(); + filter.setServiceProperties(serviceProperties); + filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class)); + filter.setAuthenticationManager(manager); + filter.setSecurityContextRepository(securityContextRepository); + filter.afterPropertiesSet(); + filter.doFilter(request, response, chain); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull() + .withFailMessage("Authentication should not be null"); + verify(chain).doFilter(request, response); + // validate for when the filterProcessUrl matches + filter.setFilterProcessesUrl(request.getServletPath()); + SecurityContextHolder.clearContext(); + filter.doFilter(request, response, chain); + ArgumentCaptor contextArg = ArgumentCaptor.forClass(SecurityContext.class); + verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response)); + assertThat(contextArg.getValue().getAuthentication().getPrincipal()).isEqualTo(authentication.getName()); + } + // SEC-1592 @Test public void testChainNotInvokedForProxyReceptor() throws Exception {