diff --git a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java index 9fdd983608..f9904abdcc 100644 --- a/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/websocket/WebSocketMessageBrokerSecurityBeanDefinitionParser.java @@ -255,8 +255,8 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements for (String beanName : beanNames) { BeanDefinition bd = registry.getBeanDefinition(beanName); String beanClassName = bd.getBeanClassName(); - if (beanClassName.equals(SimpAnnotationMethodMessageHandler.class - .getName()) || beanClassName.equals(WEB_SOCKET_AMMH_CLASS_NAME)) { + if (SimpAnnotationMethodMessageHandler.class.getName().equals(beanClassName) || + WEB_SOCKET_AMMH_CLASS_NAME.equals(beanClassName)) { PropertyValue current = bd.getPropertyValues().getPropertyValue( CUSTOM_ARG_RESOLVERS_PROP); ManagedList argResolvers = new ManagedList(); @@ -275,16 +275,16 @@ public final class WebSocketMessageBrokerSecurityBeanDefinitionParser implements } } } - else if (beanClassName - .equals("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler")) { + else if ("org.springframework.web.socket.server.support.WebSocketHttpRequestHandler" + .equals(beanClassName)) { addCsrfTokenHandshakeInterceptor(bd); } - else if (beanClassName - .equals("org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService")) { + else if ("org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService" + .equals(beanClassName)) { addCsrfTokenHandshakeInterceptor(bd); } - else if (beanClassName - .equals("org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService")) { + else if ("org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService" + .equals(beanClassName)) { addCsrfTokenHandshakeInterceptor(bd); } } diff --git a/config/src/test/java/org/springframework/security/config/websocket/MessageSecurityPostProcessorTest.java b/config/src/test/java/org/springframework/security/config/websocket/MessageSecurityPostProcessorTest.java new file mode 100644 index 0000000000..2988cbc4a6 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/websocket/MessageSecurityPostProcessorTest.java @@ -0,0 +1,34 @@ +/* + * Copyright 2004, 2005, 2006 Acegi Technology Pty Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.config.websocket; + +import org.junit.Test; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.GenericBeanDefinition; +import org.springframework.beans.factory.support.SimpleBeanDefinitionRegistry; + +public class MessageSecurityPostProcessorTest { + + private WebSocketMessageBrokerSecurityBeanDefinitionParser.MessageSecurityPostProcessor postProcessor = + new WebSocketMessageBrokerSecurityBeanDefinitionParser.MessageSecurityPostProcessor("id", false); + + @Test + public void handlesBeansWithoutClass() { + BeanDefinitionRegistry registry = new SimpleBeanDefinitionRegistry(); + registry.registerBeanDefinition("beanWithoutClass", new GenericBeanDefinition()); + postProcessor.postProcessBeanDefinitionRegistry(registry); + } +}