diff --git a/config/src/test/groovy/org/springframework/security/config/message/MessagesConfigTests.groovy b/config/src/test/groovy/org/springframework/security/config/message/MessagesConfigTests.groovy index c1c2da88f8..a0c2f5d4e7 100644 --- a/config/src/test/groovy/org/springframework/security/config/message/MessagesConfigTests.groovy +++ b/config/src/test/groovy/org/springframework/security/config/message/MessagesConfigTests.groovy @@ -36,7 +36,7 @@ import org.springframework.security.core.context.SecurityContextHolder * @author Rob Winch */ class MessagesConfigTests extends AbstractXmlConfigTests { - Authentication messageUser + Authentication messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER') def cleanup() { SecurityContextHolder.clearContext() @@ -61,6 +61,21 @@ class MessagesConfigTests extends AbstractXmlConfigTests { clientInboundChannel.send(message('/permitAll')) } + def 'anonymous authentication supported'() { + setup: + messages { + 'message-interceptor'(pattern:'/permitAll',access:'permitAll') + 'message-interceptor'(pattern:'/denyAll',access:'denyAll') + } + messageUser = null + + when: 'message is sent to the permitAll endpoint with no user' + clientInboundChannel.send(message('/permitAll')) + + then: 'access is granted' + noExceptionThrown() + } + def 'messages with no id automatically adds Authentication argument resolver'() { setup: def id = 'authenticationController' @@ -198,12 +213,13 @@ class MessagesConfigTests extends AbstractXmlConfigTests { } def message(String destination) { - messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER') SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create() headers.sessionId = '123' headers.sessionAttributes = [:] headers.destination = destination - headers.user = messageUser + if(messageUser != null) { + headers.user = messageUser + } new GenericMessage("hi",headers.messageHeaders) } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java new file mode 100644 index 0000000000..ed3b6866e3 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java @@ -0,0 +1,265 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.security.config.annotation.web.socket; + + +import org.junit.After; +import org.junit.Before; + +import org.junit.Test; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.core.MethodParameter; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageDeliveryException; +import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.simp.config.MessageBrokerRegistry; +import org.springframework.messaging.support.GenericMessage; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockServletConfig; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry; +import org.springframework.security.core.annotation.AuthenticationPrincipal; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.MissingCsrfTokenException; +import org.springframework.stereotype.Controller; +import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.web.HttpRequestHandler; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.HandlerMapping; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; +import org.springframework.web.socket.config.annotation.StompEndpointRegistry; +import org.springframework.web.socket.server.HandshakeFailureException; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; +import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler; +import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession; + +import javax.servlet.http.HttpServletRequest; +import java.util.HashMap; +import java.util.Map; + +import static org.fest.assertions.Assertions.assertThat; +import static org.junit.Assert.fail; + +public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { + AnnotationConfigWebApplicationContext context; + + TestingAuthenticationToken messageUser; + + CsrfToken token; + + String sessionAttr; + + @Before + public void setup() { + token = new DefaultCsrfToken("header", "param", "token"); + sessionAttr = "sessionAttr"; + messageUser = new TestingAuthenticationToken("user","pass","ROLE_USER"); + } + + @After + public void cleanup() { + if(context != null) { + context.close(); + } + } + + @Test + public void simpleRegistryMappings() { + loadConfig(SockJsSecurityConfig.class); + + clientInboundChannel().send(message("/permitAll")); + + try { + clientInboundChannel().send(message("/denyAll")); + fail("Expected Exception"); + } catch(MessageDeliveryException expected) { + assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class); + } + } + + @Test + public void annonymousSupported() { + loadConfig(SockJsSecurityConfig.class); + + messageUser = null; + clientInboundChannel().send(message("/permitAll")); + } + + @Test + public void addsAuthenticationPrincipalResolver() throws InterruptedException { + loadConfig(SockJsSecurityConfig.class); + + MessageChannel messageChannel = clientInboundChannel(); + Message message = message("/permitAll/authentication"); + messageChannel.send(message); + + assertThat(context.getBean(MyController.class).authenticationPrincipal).isEqualTo((String) messageUser.getPrincipal()); + } + + private MockHttpServletRequest sockjsHttpRequest(String mapping) { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod("GET"); + request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket"); + request.setRequestURI(mapping + "/289/tpyx6mde/websocket"); + request.getSession().setAttribute(sessionAttr,"sessionValue"); + + request.setAttribute(CsrfToken.class.getName(), token); + return request; + } + + private Message message(String destination) { + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); + return message(headers, destination); + } + + private Message message(SimpMessageHeaderAccessor headers, String destination) { + headers.setSessionId("123"); + headers.setSessionAttributes(new HashMap()); + if(destination != null) { + headers.setDestination(destination); + } + if(messageUser != null) { + headers.setUser(messageUser); + } + return new GenericMessage("hi",headers.getMessageHeaders()); + } + + private MessageChannel clientInboundChannel() { + return context.getBean("clientInboundChannel", MessageChannel.class); + } + + private void loadConfig(Class... configs) { + context = new AnnotationConfigWebApplicationContext(); + context.register(configs); + context.setServletConfig(new MockServletConfig()); + context.refresh(); + } + + + @Controller + static class MyController { + + String authenticationPrincipal; + MyCustomArgument myCustomArgument; + + + @MessageMapping("/authentication") + public void authentication(@AuthenticationPrincipal String un) { + this.authenticationPrincipal = un; + } + + @MessageMapping("/myCustom") + public void myCustom(MyCustomArgument myCustomArgument) { + this.myCustomArgument = myCustomArgument; + } + } + + static class MyCustomArgument { + MyCustomArgument(String notDefaultConstr) {} + } + + static class MyCustomArgumentResolver implements HandlerMethodArgumentResolver { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return parameter.getParameterType().isAssignableFrom(MyCustomArgument.class); + } + + @Override + public Object resolveArgument(MethodParameter parameter, Message message) throws Exception { + return new MyCustomArgument(""); + } + } + + static class TestHandshakeHandler implements HandshakeHandler { + Map attributes; + + public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws HandshakeFailureException { + this.attributes = attributes; + if(wsHandler instanceof SockJsWebSocketHandler) { + // work around SPR-12716 + SockJsWebSocketHandler sockJs = (SockJsWebSocketHandler) wsHandler; + WebSocketServerSockJsSession session = (WebSocketServerSockJsSession) ReflectionTestUtils.getField(sockJs, "sockJsSession"); + this.attributes = session.getAttributes(); + } + return true; + } + } + + @Configuration + @EnableWebSocketMessageBroker + @Import(SyncExecutorConfig.class) + static class SockJsSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { + + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry + .addEndpoint("/other") + .setHandshakeHandler(testHandshakeHandler()) + .withSockJS() + .setInterceptors(new HttpSessionHandshakeInterceptor()); + + registry + .addEndpoint("/chat") + .setHandshakeHandler(testHandshakeHandler()) + .withSockJS() + .setInterceptors(new HttpSessionHandshakeInterceptor()); + } + + @Override + protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { + messages + .simpDestMatchers("/permitAll/**").permitAll() + .anyMessage().denyAll(); + } + + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + registry.enableSimpleBroker("/queue/", "/topic/"); + registry.setApplicationDestinationPrefixes("/permitAll", "/denyAll"); + } + + @Bean + public MyController myController() { + return new MyController(); + } + + @Bean + public TestHandshakeHandler testHandshakeHandler() { + return new TestHandshakeHandler(); + } + } + + @Configuration + static class SyncExecutorConfig { + @Bean + public static SyncExecutorSubscribableChannelPostProcessor postProcessor() { + return new SyncExecutorSubscribableChannelPostProcessor(); + } + } +} \ No newline at end of file diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/SyncExecutorSubscribableChannelPostProcessor.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/SyncExecutorSubscribableChannelPostProcessor.java new file mode 100644 index 0000000000..958874e245 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/SyncExecutorSubscribableChannelPostProcessor.java @@ -0,0 +1,42 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.security.config.annotation.web.socket; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.config.BeanPostProcessor; +import org.springframework.messaging.support.ExecutorSubscribableChannel; + +/** + * @author Rob Winch + */ +public class SyncExecutorSubscribableChannelPostProcessor implements BeanPostProcessor { + + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { + if(bean instanceof ExecutorSubscribableChannel) { + ExecutorSubscribableChannel original = (ExecutorSubscribableChannel) bean; + ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(); + channel.setInterceptors(original.getInterceptors()); + return channel; + } + return bean; + } + + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { + return bean; + } +} diff --git a/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java index 6387a2ce47..293b4cf2ef 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java @@ -15,13 +15,17 @@ */ package org.springframework.security.messaging.context; +import java.util.Stack; + import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.support.ChannelInterceptorAdapter; import org.springframework.messaging.support.ExecutorChannelInterceptor; +import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.util.Assert; @@ -37,10 +41,12 @@ import org.springframework.util.Assert; */ public final class SecurityContextChannelInterceptor extends ChannelInterceptorAdapter implements ExecutorChannelInterceptor { private final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext(); - private static final ThreadLocal ORIGINAL_CONTEXT = new ThreadLocal(); + private static final ThreadLocal> ORIGINAL_CONTEXT = new ThreadLocal>(); private final String authenticationHeaderName; + private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + /** * Creates a new instance using the header of the name {@link SimpMessageHeaderAccessor#USER_HEADER}. */ @@ -57,6 +63,21 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA Assert.notNull(authenticationHeaderName, "authenticationHeaderName cannot be null"); this.authenticationHeaderName = authenticationHeaderName; } + + /** + * Allows setting the Authentication used for anonymous authentication. Default is: + * + *
+     * new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
+     * 
+ * + * @param authentication the Authentication used for anonymous authentication. Cannot be null. + */ + public void setAnonymousAuthentication(Authentication authentication) { + Assert.notNull(authentication, "authentication cannot be null"); + this.anonymous = authentication; + } + @Override public Message preSend(Message message, MessageChannel channel) { setup(message); @@ -79,25 +100,42 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA private void setup(Message message) { SecurityContext currentContext = SecurityContextHolder.getContext(); - ORIGINAL_CONTEXT.set(currentContext); + + Stack contextStack = ORIGINAL_CONTEXT.get(); + if(contextStack == null) { + contextStack = new Stack(); + ORIGINAL_CONTEXT.set(contextStack); + } + contextStack.push(currentContext); Object user = message.getHeaders().get(authenticationHeaderName); - if(!(user instanceof Authentication)) { - return; + + Authentication authentication; + if((user instanceof Authentication)) { + authentication = (Authentication) user; + } else { + authentication = this.anonymous; } - Authentication authentication = (Authentication) user; SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authentication); SecurityContextHolder.setContext(context); } private void cleanup() { - SecurityContext originalContext = ORIGINAL_CONTEXT.get(); - ORIGINAL_CONTEXT.remove(); + Stack contextStack = ORIGINAL_CONTEXT.get(); + + if(contextStack == null || contextStack.isEmpty()) { + SecurityContextHolder.clearContext(); + ORIGINAL_CONTEXT.remove(); + return; + } + + SecurityContext originalContext = contextStack.pop(); try { if(EMPTY_CONTEXT.equals(originalContext)) { SecurityContextHolder.clearContext(); + ORIGINAL_CONTEXT.remove(); } else { SecurityContextHolder.setContext(originalContext); } diff --git a/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java index 8459534374..01973d43fa 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java @@ -10,8 +10,11 @@ import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageHandler; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; @@ -35,10 +38,13 @@ public class SecurityContextChannelInterceptorTests { SecurityContextChannelInterceptor interceptor; + AnonymousAuthenticationToken expectedAnonymous; + @Before public void setup() { authentication = new TestingAuthenticationToken("user","pass", "ROLE_USER"); messageBuilder = MessageBuilder.withPayload("payload"); + expectedAnonymous = new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); interceptor = new SecurityContextChannelInterceptor(); } @@ -73,20 +79,45 @@ public class SecurityContextChannelInterceptorTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(authentication); } + @Test(expected = IllegalArgumentException.class) + public void setAnonymousAuthenticationNull() { + interceptor.setAnonymousAuthentication(null); + } + + @Test + public void preSendUsesCustomAnonymous() throws Exception { + expectedAnonymous = new AnonymousAuthenticationToken("customKey", "customAnonymous", AuthorityUtils.createAuthorityList("ROLE_CUSTOM")); + interceptor.setAnonymousAuthentication(expectedAnonymous); + + interceptor.preSend(messageBuilder.build(), channel); + + assertAnonymous(); + } + + // SEC-2845 @Test public void preSendUserNotAuthentication() throws Exception { messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, principal); interceptor.preSend(messageBuilder.build(), channel); - assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + assertAnonymous(); } + // SEC-2845 @Test public void preSendUserNotSet() throws Exception { interceptor.preSend(messageBuilder.build(), channel); - assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + assertAnonymous(); + } + + // SEC-2845 + @Test + public void preSendUserNotSetCustomAnonymous() throws Exception { + interceptor.preSend(messageBuilder.build(), channel); + + assertAnonymous(); } @Test @@ -114,20 +145,22 @@ public class SecurityContextChannelInterceptorTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(authentication); } + // SEC-2845 @Test public void beforeHandleUserNotAuthentication() throws Exception { messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, principal); interceptor.beforeHandle(messageBuilder.build(), channel, handler); - assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + assertAnonymous(); } + // SEC-2845 @Test public void beforeHandleUserNotSet() throws Exception { interceptor.beforeHandle(messageBuilder.build(), channel, handler); - assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + assertAnonymous(); } @@ -147,6 +180,7 @@ public class SecurityContextChannelInterceptorTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); } + // SEC-2829 @Test public void restoresOriginalContext() throws Exception { TestingAuthenticationToken original = new TestingAuthenticationToken("original", "original", "ROLE_USER"); @@ -161,4 +195,47 @@ public class SecurityContextChannelInterceptorTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(original); } + + /** + * If a user sends a message when processing another message + * + * @throws Exception + */ + @Test + public void restoresOriginalContextNestedThreeDeep() throws Exception { + AnonymousAuthenticationToken anonymous = new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_USER")); + + TestingAuthenticationToken origional = new TestingAuthenticationToken("original", "origional", "ROLE_USER"); + SecurityContextHolder.getContext().setAuthentication(origional); + + messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, authentication); + interceptor.beforeHandle(messageBuilder.build(), channel, handler); + + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(authentication); + + // start send message + messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, null); + interceptor.beforeHandle(messageBuilder.build(), channel, handler); + + assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo(anonymous.getName()); + + interceptor.afterMessageHandled(messageBuilder.build(), channel, handler, null); + + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(authentication); + // end send message + + interceptor.afterMessageHandled(messageBuilder.build(), channel, handler, null); + + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(origional); + } + + private void assertAnonymous() { + Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication(); + assertThat(currentAuthentication).isInstanceOf(AnonymousAuthenticationToken.class); + + AnonymousAuthenticationToken anonymous = (AnonymousAuthenticationToken) currentAuthentication; + assertThat(anonymous.getName()).isEqualTo(expectedAnonymous.getName()); + assertThat(anonymous.getAuthorities()).containsOnly(expectedAnonymous.getAuthorities().toArray()); + assertThat(anonymous.getKeyHash()).isEqualTo(expectedAnonymous.getKeyHash()); + } } \ No newline at end of file