diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java index 6b2ffa4f07..1a36b0a7c0 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java @@ -27,7 +27,6 @@ import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; import org.springframework.beans.BeanUtils; -import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.context.ApplicationContextException; @@ -363,9 +362,10 @@ public abstract class FrameworkServlet extends HttpServletBean implements Applic * @see #configureAndRefreshWebApplicationContext(ConfigurableWebApplicationContext) * @see #applyInitializers(ConfigurableApplicationContext) */ - public void setContextInitializers(ApplicationContextInitializer... contextInitializers) { - for (ApplicationContextInitializer initializer : contextInitializers) { - this.contextInitializers.add(initializer); + @SuppressWarnings("unchecked") + public void setContextInitializers(ApplicationContextInitializer... contextInitializers) { + for (ApplicationContextInitializer initializer : contextInitializers) { + this.contextInitializers.add((ApplicationContextInitializer) initializer); } } @@ -450,6 +450,23 @@ public abstract class FrameworkServlet extends HttpServletBean implements Applic this.dispatchTraceRequest = dispatchTraceRequest; } + /** + * Called by Spring via {@link ApplicationContextAware} to inject the current + * application context. This method allows FrameworkServlets to be registered as + * Spring beans inside an existing {@link WebApplicationContext} rather than + * {@link #findWebApplicationContext() finding} a + * {@link org.springframework.web.context.ContextLoaderListener bootstrapped} context. + *

Primarily added to support use in embedded servlet containers. + * @since 4.0 + */ + @Override + public void setApplicationContext(ApplicationContext applicationContext) { + if (this.webApplicationContext == null && applicationContext instanceof WebApplicationContext) { + this.webApplicationContext = (WebApplicationContext) applicationContext; + this.webApplicationContextInjected = true; + } + } + /** * Overridden method of {@link HttpServletBean}, invoked after any bean properties @@ -796,11 +813,9 @@ public abstract class FrameworkServlet extends HttpServletBean implements Applic */ @Override public void destroy() { - if (this.webApplicationContextInjected) { - return; - } getServletContext().log("Destroying Spring FrameworkServlet '" + getServletName() + "'"); - if (this.webApplicationContext instanceof ConfigurableApplicationContext) { + // Only call close() on WebApplicationContext if locally managed... + if (this.webApplicationContext instanceof ConfigurableApplicationContext && !this.webApplicationContextInjected) { ((ConfigurableApplicationContext) this.webApplicationContext).close(); } } @@ -1065,29 +1080,6 @@ public abstract class FrameworkServlet extends HttpServletBean implements Applic return (userPrincipal != null ? userPrincipal.getName() : null); } - /** - * Called by Spring via {@link ApplicationContextAware} to inject the current - * application context. This method allows FrameworkServlets to be registered as - * Spring Beans inside an existing {@link WebApplicationContext} rather than - * {@link #findWebApplicationContext() finding} a - * {@link org.springframework.web.context.ContextLoaderListener bootstrapped} context. - *

Primarily added to support use in embedded servlet containers, this method is not - * intended to be called directly. - * @since 4.0 - */ - @Override - public void setApplicationContext(ApplicationContext applicationContext) - throws BeansException { - if (this.webApplicationContext == null - && applicationContext instanceof WebApplicationContext) { - if (logger.isDebugEnabled()) { - logger.debug("Using existing application context for " - + ClassUtils.getShortName(getClass())); - } - this.webApplicationContext = (WebApplicationContext) applicationContext; - this.webApplicationContextInjected = true; - } - } /** * Subclasses must implement this method to do the work of request handling, diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java index 93ee5a1662..f5b6515b41 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java @@ -29,6 +29,7 @@ import junit.framework.TestCase; import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.PropertyValue; +import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.core.env.ConfigurableEnvironment; import org.springframework.core.env.DummyEnvironment; @@ -37,7 +38,9 @@ import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.mock.web.test.MockServletConfig; import org.springframework.mock.web.test.MockServletContext; import org.springframework.tests.sample.beans.TestBean; +import org.springframework.web.context.ConfigurableWebApplicationContext; import org.springframework.web.context.ConfigurableWebEnvironment; +import org.springframework.web.context.ContextLoader; import org.springframework.web.context.ServletConfigAwareBean; import org.springframework.web.context.ServletContextAwareBean; import org.springframework.web.context.WebApplicationContext; @@ -784,6 +787,46 @@ public class DispatcherServletTests extends TestCase { assertThat(response.getHeader("Allow"), equalTo("GET, HEAD, POST, PUT, DELETE, TRACE, OPTIONS, PATCH")); } + public void testContextInitializers() throws Exception { + DispatcherServlet servlet = new DispatcherServlet(); + servlet.setContextClass(SimpleWebApplicationContext.class); + servlet.setContextInitializers(new TestWebContextInitializer(), new OtherWebContextInitializer()); + servlet.init(servletConfig); + assertEquals("true", servletConfig.getServletContext().getAttribute("initialized")); + assertEquals("true", servletConfig.getServletContext().getAttribute("otherInitialized")); + } + + public void testContextInitializerClasses() throws Exception { + DispatcherServlet servlet = new DispatcherServlet(); + servlet.setContextClass(SimpleWebApplicationContext.class); + servlet.setContextInitializerClasses( + TestWebContextInitializer.class.getName() + "," + OtherWebContextInitializer.class.getName()); + servlet.init(servletConfig); + assertEquals("true", servletConfig.getServletContext().getAttribute("initialized")); + assertEquals("true", servletConfig.getServletContext().getAttribute("otherInitialized")); + } + + public void testGlobalInitializerClasses() throws Exception { + DispatcherServlet servlet = new DispatcherServlet(); + servlet.setContextClass(SimpleWebApplicationContext.class); + servletConfig.getServletContext().setInitParameter(ContextLoader.GLOBAL_INITIALIZER_CLASSES_PARAM, + TestWebContextInitializer.class.getName() + "," + OtherWebContextInitializer.class.getName()); + servlet.init(servletConfig); + assertEquals("true", servletConfig.getServletContext().getAttribute("initialized")); + assertEquals("true", servletConfig.getServletContext().getAttribute("otherInitialized")); + } + + public void testMixedInitializerClasses() throws Exception { + DispatcherServlet servlet = new DispatcherServlet(); + servlet.setContextClass(SimpleWebApplicationContext.class); + servletConfig.getServletContext().setInitParameter(ContextLoader.GLOBAL_INITIALIZER_CLASSES_PARAM, + TestWebContextInitializer.class.getName()); + servlet.setContextInitializerClasses(OtherWebContextInitializer.class.getName()); + servlet.init(servletConfig); + assertEquals("true", servletConfig.getServletContext().getAttribute("initialized")); + assertEquals("true", servletConfig.getServletContext().getAttribute("otherInitialized")); + } + public static class ControllerFromParent implements Controller { @@ -793,4 +836,22 @@ public class DispatcherServletTests extends TestCase { } } + + private static class TestWebContextInitializer implements ApplicationContextInitializer { + + @Override + public void initialize(ConfigurableWebApplicationContext applicationContext) { + applicationContext.getServletContext().setAttribute("initialized", "true"); + } + } + + + private static class OtherWebContextInitializer implements ApplicationContextInitializer { + + @Override + public void initialize(ConfigurableWebApplicationContext applicationContext) { + applicationContext.getServletContext().setAttribute("otherInitialized", "true"); + } + } + }