diff --git a/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java b/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java index 942ac1857c5..d692453338e 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2013 the original author or authors. + * Copyright 2012-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,12 @@ package org.springframework.boot.context.embedded; import java.util.Collection; +import java.util.Collections; import java.util.EventListener; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; import javax.servlet.Filter; import javax.servlet.Servlet; @@ -29,6 +34,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.config.Scope; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextException; import org.springframework.core.io.Resource; @@ -77,6 +83,9 @@ import org.springframework.web.context.support.WebApplicationContextUtils; */ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext { + private static final Log logger = LogFactory + .getLog(EmbeddedWebApplicationContext.class); + /** * Constant value for the DispatcherServlet bean name. A Servlet bean with this name * is deemed to be the "main" servlet and is automatically given a mapping of "/" by @@ -194,18 +203,26 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext return new ServletContextInitializer() { @Override public void onStartup(ServletContext servletContext) throws ServletException { - prepareEmbeddedWebApplicationContext(servletContext); - WebApplicationContextUtils.registerWebApplicationScopes(getBeanFactory(), - getServletContext()); - WebApplicationContextUtils.registerEnvironmentBeans(getBeanFactory(), - getServletContext()); - for (ServletContextInitializer beans : getServletContextInitializerBeans()) { - beans.onStartup(servletContext); - } + selfInitialize(servletContext); } }; } + private void selfInitialize(ServletContext servletContext) throws ServletException { + prepareEmbeddedWebApplicationContext(servletContext); + ConfigurableListableBeanFactory beanFactory = getBeanFactory(); + ExistingWebApplicationScopes existingScopes = new ExistingWebApplicationScopes( + beanFactory); + WebApplicationContextUtils.registerWebApplicationScopes(beanFactory, + getServletContext()); + existingScopes.restore(); + WebApplicationContextUtils.registerEnvironmentBeans(beanFactory, + getServletContext()); + for (ServletContextInitializer beans : getServletContextInitializerBeans()) { + beans.onStartup(servletContext); + } + } + /** * Returns {@link ServletContextInitializer}s that should be used with the embedded * Servlet context. By default this method will first attempt to find @@ -319,4 +336,45 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext return this.embeddedServletContainer; } + /** + * Utility class to store and restore any user defined scopes. This allow scopes to be + * registered in an ApplicationContextInitializer in the same way as they would in a + * classic non-embedded web application context. + */ + public static class ExistingWebApplicationScopes { + + private static final Set SCOPES; + static { + Set scopes = new LinkedHashSet(); + scopes.add(WebApplicationContext.SCOPE_REQUEST); + scopes.add(WebApplicationContext.SCOPE_SESSION); + scopes.add(WebApplicationContext.SCOPE_GLOBAL_SESSION); + SCOPES = Collections.unmodifiableSet(scopes); + } + + private final ConfigurableListableBeanFactory beanFactory; + + private final Map scopes = new HashMap(); + + public ExistingWebApplicationScopes(ConfigurableListableBeanFactory beanFactory) { + this.beanFactory = beanFactory; + for (String scopeName : SCOPES) { + Scope scope = beanFactory.getRegisteredScope(scopeName); + if (scope != null) { + this.scopes.put(scopeName, scope); + } + } + } + + public void restore() { + for (Map.Entry entry : this.scopes.entrySet()) { + if (logger.isInfoEnabled()) { + logger.info("Restoring user defined scope " + entry.getKey()); + } + this.beanFactory.registerScope(entry.getKey(), entry.getValue()); + } + } + + } + } diff --git a/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java b/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java index 3e8341ad213..68b7e4c39e1 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2013 the original author or authors. + * Copyright 2012-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,7 +37,9 @@ import org.junit.rules.ExpectedException; import org.mockito.InOrder; import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.ConstructorArgumentValues; +import org.springframework.beans.factory.config.Scope; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.ApplicationContextException; import org.springframework.context.ApplicationListener; @@ -53,6 +55,7 @@ import org.springframework.web.filter.GenericFilterBean; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; @@ -408,6 +411,24 @@ public class EmbeddedWebApplicationContextTests { equalTo(8080)); } + @Test + public void doesNotReplaceExistingScopes() throws Exception { // gh-2082 + Scope scope = mock(Scope.class); + ConfigurableListableBeanFactory factory = this.context.getBeanFactory(); + factory.registerScope(WebApplicationContext.SCOPE_REQUEST, scope); + factory.registerScope(WebApplicationContext.SCOPE_SESSION, scope); + factory.registerScope(WebApplicationContext.SCOPE_GLOBAL_SESSION, scope); + addEmbeddedServletContainerFactoryBean(); + this.context.refresh(); + assertThat(factory.getRegisteredScope(WebApplicationContext.SCOPE_REQUEST), + sameInstance(scope)); + assertThat(factory.getRegisteredScope(WebApplicationContext.SCOPE_SESSION), + sameInstance(scope)); + assertThat( + factory.getRegisteredScope(WebApplicationContext.SCOPE_GLOBAL_SESSION), + sameInstance(scope)); + } + private void addEmbeddedServletContainerFactoryBean() { this.context.registerBeanDefinition("embeddedServletContainerFactory", new RootBeanDefinition(MockEmbeddedServletContainerFactory.class));