Polish RegistrationBean logic

This commit is contained in:
Phillip Webb 2013-08-01 00:24:52 -07:00
parent d835ee548d
commit 25f31b6358
6 changed files with 49 additions and 56 deletions

View File

@ -217,7 +217,7 @@ public abstract class AbstractEmbeddedServletContainerFactory implements
@Override
public void addErrorPages(ErrorPage... errorPages) {
Assert.notNull(this.initializers, "ErrorPages must not be null");
Assert.notNull(errorPages, "ErrorPages must not be null");
this.errorPages.addAll(Arrays.asList(errorPages));
}

View File

@ -20,7 +20,6 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map.Entry;
@ -64,9 +63,10 @@ import org.springframework.web.context.support.WebApplicationContextUtils;
* <p>
* In addition, any {@link Servlet} or {@link Filter} beans defined in the context will be
* automatically registered with the embedded Servlet container. In the case of a single
* Servlet bean, the '/*' mapping will be used. If multiple Servlet beans are found then
* the lowercase bean name will be used as a mapping prefix. Filter beans will be mapped
* to all URLs ('/*').
* Servlet bean, the '/' mapping will be used. If multiple Servlet beans are found then
* the lowercase bean name will be used as a mapping prefix. Any Servlet named
* 'dispatcherServlet' will always be mapped to '/'. Filter beans will be mapped to all
* URLs ('/*').
*
* <p>
* For more advanced configuration, the context can instead define beans that implement
@ -205,43 +205,47 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
protected Collection<ServletContextInitializer> getServletContextInitializerBeans() {
Set<ServletContextInitializer> initializers = new LinkedHashSet<ServletContextInitializer>();
Set<Object> targets = new HashSet<Object>();
Set<Servlet> servletRegistrations = new LinkedHashSet<Servlet>();
Set<Filter> filterRegistrations = new LinkedHashSet<Filter>();
for (Entry<String, ServletContextInitializer> initializerBean : getOrderedBeansOfType(ServletContextInitializer.class)) {
ServletContextInitializer initializer = initializerBean.getValue();
if (initializer instanceof RegistrationBean) {
targets.add(((RegistrationBean) initializer).getRegistrationTarget());
}
initializers.add(initializer);
if (initializer instanceof ServletRegistrationBean) {
servletRegistrations.add(((ServletRegistrationBean) initializer)
.getServlet());
}
if (initializer instanceof FilterRegistrationBean) {
filterRegistrations.add(((FilterRegistrationBean) initializer)
.getFilter());
}
}
List<Entry<String, Servlet>> servletBeans = getOrderedBeansOfType(Servlet.class);
for (Entry<String, Servlet> servletBean : servletBeans) {
final String name = servletBean.getKey();
Servlet servlet = servletBean.getValue();
if (targets.contains(servlet)) {
continue;
if (!servletRegistrations.contains(servlet)) {
String url = (servletBeans.size() == 1 ? "/" : "/" + name + "/");
if (name.equals(DISPATCHER_SERVLET_NAME)) {
url = "/"; // always map the main dispatcherServlet to "/"
}
ServletRegistrationBean registration = new ServletRegistrationBean(
servlet, url);
registration.setName(name);
registration.setMultipartConfig(getMultipartConfig());
initializers.add(registration);
}
String url = (servletBeans.size() == 1 ? "/" : "/" + name + "/*");
if (name.equals(DISPATCHER_SERVLET_NAME)) {
url = "/"; // always map the main dispatcherServlet to "/"
}
ServletRegistrationBean registration = new ServletRegistrationBean(servlet,
url);
registration.setName(name);
registration.setMultipartConfig(getMultipartConfig());
initializers.add(registration);
}
for (Entry<String, Filter> filterBean : getOrderedBeansOfType(Filter.class)) {
String name = filterBean.getKey();
Filter filter = filterBean.getValue();
if (targets.contains(filter)) {
continue;
if (!servletRegistrations.contains(filter)) {
FilterRegistrationBean registration = new FilterRegistrationBean(filter);
registration.setName(name);
initializers.add(registration);
}
FilterRegistrationBean registration = new FilterRegistrationBean(filter);
registration.setName(name);
initializers.add(registration);
}
return initializers;

View File

@ -93,6 +93,13 @@ public class FilterRegistrationBean extends RegistrationBean {
}
}
/**
* Returns the filter being registered.
*/
protected Filter getFilter() {
return this.filter;
}
/**
* Set the filter to be registered.
*/
@ -220,12 +227,7 @@ public class FilterRegistrationBean extends RegistrationBean {
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
Assert.notNull(this.filter, "Filter must not be null");
configure(servletContext.addFilter(getName(), this.filter));
}
@Override
public Object getRegistrationTarget() {
return this.filter;
configure(servletContext.addFilter(getOrDeduceName(this.filter), this.filter));
}
/**

View File

@ -48,13 +48,6 @@ public abstract class RegistrationBean implements ServletContextInitializer {
this.name = name;
}
/**
* @return the name
*/
public String getName() {
return getOrDeduceName(getRegistrationTarget());
}
/**
* Sets if asynchronous operations are support for this registration. If not specified
* defaults to {@code true}.
@ -98,20 +91,12 @@ public abstract class RegistrationBean implements ServletContextInitializer {
this.initParameters.put(name, value);
}
/**
* The target of the registration (e.g. a Servlet or a Filter) that can be used to
* guess its name if none is supplied explicitly.
*
* @return the target of this registration
*/
public abstract Object getRegistrationTarget();
/**
* Deduces the name for this registration. Will return user specified name or fallback
* to convention based naming.
* @param value the object used for convention based names
*/
private String getOrDeduceName(Object value) {
protected final String getOrDeduceName(Object value) {
return (this.name != null ? this.name : Conventions.getVariableName(value));
}

View File

@ -75,6 +75,13 @@ public class ServletRegistrationBean extends RegistrationBean {
this.urlMappings.addAll(Arrays.asList(urlMappings));
}
/**
* Returns the servlet being registered.
*/
protected Servlet getServlet() {
return this.servlet;
}
/**
* Sets the servlet to be registered.
*/
@ -140,12 +147,7 @@ public class ServletRegistrationBean extends RegistrationBean {
* Returns the servlet name that will be registered.
*/
public String getServletName() {
return getName();
}
@Override
public Object getRegistrationTarget() {
return this.servlet;
return getOrDeduceName(this.servlet);
}
@Override

View File

@ -190,9 +190,9 @@ public class EmbeddedWebApplicationContextTests {
ordered.verify(servletContext).addServlet("servletBean1", servlet1);
ordered.verify(servletContext).addServlet("servletBean2", servlet2);
verify(escf.getRegisteredServlet(0).getRegistration()).addMapping(
"/servletBean1/*");
"/servletBean1/");
verify(escf.getRegisteredServlet(1).getRegistration()).addMapping(
"/servletBean2/*");
"/servletBean2/");
}
@Test
@ -215,7 +215,7 @@ public class EmbeddedWebApplicationContextTests {
ordered.verify(servletContext).addServlet("servletBean2", servlet2);
verify(escf.getRegisteredServlet(0).getRegistration()).addMapping("/");
verify(escf.getRegisteredServlet(1).getRegistration()).addMapping(
"/servletBean2/*");
"/servletBean2/");
}
@Test