diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/SpringConfigurator.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/SpringConfigurator.java index e916c9ce50f..6991161ac9e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/SpringConfigurator.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/SpringConfigurator.java @@ -16,13 +16,14 @@ package org.springframework.web.socket.server.endpoint; -import java.util.Map; - import javax.websocket.server.ServerEndpoint; import javax.websocket.server.ServerEndpointConfig.Configurator; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.stereotype.Component; +import org.springframework.util.ClassUtils; import org.springframework.web.context.ContextLoader; import org.springframework.web.context.WebApplicationContext; @@ -58,25 +59,28 @@ public class SpringConfigurator extends Configurator { throw new IllegalStateException(message); } - Map beans = wac.getBeansOfType(endpointClass); - if (beans.isEmpty()) { + String beanName = ClassUtils.getShortNameAsProperty(endpointClass); + if (wac.containsBean(beanName)) { + T endpoint = wac.getBean(beanName, endpointClass); if (logger.isTraceEnabled()) { - logger.trace("Creating new @ServerEndpoint instance of type " + endpointClass); + logger.trace("Using @ServerEndpoint singleton " + endpoint); } - return wac.getAutowireCapableBeanFactory().createBean(endpointClass); + return endpoint; } - else if (beans.size() == 1) { + + Component annot = AnnotationUtils.findAnnotation(endpointClass, Component.class); + if ((annot != null) && wac.containsBean(annot.value())) { + T endpoint = wac.getBean(annot.value(), endpointClass); if (logger.isTraceEnabled()) { - logger.trace("Using @ServerEndpoint singleton " + beans.keySet().iterator().next()); + logger.trace("Using @ServerEndpoint singleton " + endpoint); } - return beans.values().iterator().next(); + return endpoint; } - else { - // Should not happen .. - String message = "Found more than one matching @ServerEndpoint beans of type " + endpointClass; - logger.error(message); - throw new IllegalStateException(message); + + if (logger.isTraceEnabled()) { + logger.trace("Creating new @ServerEndpoint instance of type " + endpointClass); } + return wac.getAutowireCapableBeanFactory().createBean(endpointClass); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/endpoint/SpringConfiguratorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/endpoint/SpringConfiguratorTests.java index 87f2bcf9f4f..62b69bf0396 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/endpoint/SpringConfiguratorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/endpoint/SpringConfiguratorTests.java @@ -25,8 +25,10 @@ import org.junit.Before; import org.junit.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; import org.springframework.mock.web.test.MockServletContext; +import org.springframework.stereotype.Component; import org.springframework.web.context.ContextLoader; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; @@ -40,6 +42,8 @@ public class SpringConfiguratorTests { private AnnotationConfigWebApplicationContext webAppContext; + private SpringConfigurator configurator; + @Before public void setup() { @@ -50,6 +54,8 @@ public class SpringConfiguratorTests { this.contextLoader = new ContextLoader(this.webAppContext); this.contextLoader.initWebApplicationContext(this.servletContext); + + this.configurator = new SpringConfigurator(); } @After @@ -59,29 +65,33 @@ public class SpringConfiguratorTests { @Test - public void getEndpointInstanceCreateBean() throws Exception { - - PerConnectionEchoEndpoint endpoint = new SpringConfigurator().getEndpointInstance(PerConnectionEchoEndpoint.class); - + public void getEndpointInstancePerConnection() throws Exception { + PerConnectionEchoEndpoint endpoint = this.configurator.getEndpointInstance(PerConnectionEchoEndpoint.class); assertNotNull(endpoint); } @Test - public void getEndpointInstanceUseBean() throws Exception { - - EchoEndpointBean expected = this.webAppContext.getBean(EchoEndpointBean.class); - EchoEndpointBean actual = new SpringConfigurator().getEndpointInstance(EchoEndpointBean.class); + public void getEndpointInstanceSingletonByType() throws Exception { + EchoEndpoint expected = this.webAppContext.getBean(EchoEndpoint.class); + EchoEndpoint actual = this.configurator.getEndpointInstance(EchoEndpoint.class); + assertSame(expected, actual); + } + @Test + public void getEndpointInstanceSingletonByComponentName() throws Exception { + AlternativeEchoEndpoint expected = this.webAppContext.getBean(AlternativeEchoEndpoint.class); + AlternativeEchoEndpoint actual = this.configurator.getEndpointInstance(AlternativeEchoEndpoint.class); assertSame(expected, actual); } @Configuration + @ComponentScan(basePackageClasses=SpringConfiguratorTests.class) static class Config { @Bean - public EchoEndpointBean echoEndpointBean() { - return new EchoEndpointBean(echoService()); + public EchoEndpoint echoEndpoint() { + return new EchoEndpoint(echoService()); } @Bean @@ -90,13 +100,29 @@ public class SpringConfiguratorTests { } } - private static class EchoEndpointBean extends Endpoint { + private static class EchoEndpoint extends Endpoint { @SuppressWarnings("unused") private final EchoService service; @Autowired - public EchoEndpointBean(EchoService service) { + public EchoEndpoint(EchoService service) { + this.service = service; + } + + @Override + public void onOpen(Session session, EndpointConfig config) { + } + } + + @Component("echoEndpoint") + private static class AlternativeEchoEndpoint extends Endpoint { + + @SuppressWarnings("unused") + private final EchoService service; + + @Autowired + public AlternativeEchoEndpoint(EchoService service) { this.service = service; }