diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/EnvironmentConverter.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/EnvironmentConverter.java index 840636088e4..7e6c325c975 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/EnvironmentConverter.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/EnvironmentConverter.java @@ -26,7 +26,6 @@ import org.springframework.core.env.MutablePropertySources; import org.springframework.core.env.PropertySource; import org.springframework.core.env.StandardEnvironment; import org.springframework.util.ClassUtils; -import org.springframework.web.context.ConfigurableWebEnvironment; import org.springframework.web.context.support.StandardServletEnvironment; /** @@ -34,6 +33,7 @@ import org.springframework.web.context.support.StandardServletEnvironment; * * @author Ethan Rubinson * @author Andy Wilkinson + * @author Madhura Bhave */ final class EnvironmentConverter { @@ -61,46 +61,44 @@ final class EnvironmentConverter { } /** - * Converts the given {@code environment} to a {@link StandardEnvironment}. If the - * environment is already a {@code StandardEnvironment} and is not a - * {@link ConfigurableWebEnvironment} no conversion is performed and it is returned - * unchanged. + * Converts the given {@code environment} to the given {@link StandardEnvironment} + * type. If the environment is already of the same type, no conversion is performed + * and it is returned unchanged. * @param environment the Environment to convert + * @param conversionType the type to convert the Environment to * @return the converted Environment */ - StandardEnvironment convertToStandardEnvironmentIfNecessary( - ConfigurableEnvironment environment) { - if (environment instanceof StandardEnvironment - && !isWebEnvironment(environment, this.classLoader)) { + StandardEnvironment convertEnvironmentIfNecessary(ConfigurableEnvironment environment, + Class conversionType) { + if (conversionType.equals(environment.getClass())) { return (StandardEnvironment) environment; } - return convertToStandardEnvironment(environment); + return convertEnvironment(environment, conversionType); } - private boolean isWebEnvironment(ConfigurableEnvironment environment, - ClassLoader classLoader) { - try { - Class webEnvironmentClass = ClassUtils - .forName(CONFIGURABLE_WEB_ENVIRONMENT_CLASS, classLoader); - return (webEnvironmentClass.isInstance(environment)); - } - catch (Throwable ex) { - return false; - } - } - - private StandardEnvironment convertToStandardEnvironment( - ConfigurableEnvironment environment) { - StandardEnvironment result = new StandardEnvironment(); + private StandardEnvironment convertEnvironment(ConfigurableEnvironment environment, + Class conversionType) { + StandardEnvironment result = createEnvironment(conversionType); result.setActiveProfiles(environment.getActiveProfiles()); result.setConversionService(environment.getConversionService()); - copyNonServletPropertySources(environment, result); + copyPropertySources(environment, result); return result; } - private void copyNonServletPropertySources(ConfigurableEnvironment source, + private StandardEnvironment createEnvironment( + Class conversionType) { + try { + return conversionType.newInstance(); + } + catch (Exception ex) { + return new StandardEnvironment(); + } + } + + private void copyPropertySources(ConfigurableEnvironment source, StandardEnvironment target) { - removeAllPropertySources(target.getPropertySources()); + removePropertySources(target.getPropertySources(), + isServletEnvironment(target.getClass(), this.classLoader)); for (PropertySource propertySource : source.getPropertySources()) { if (!SERVLET_ENVIRONMENT_SOURCE_NAMES.contains(propertySource.getName())) { target.getPropertySources().addLast(propertySource); @@ -108,13 +106,31 @@ final class EnvironmentConverter { } } - private void removeAllPropertySources(MutablePropertySources propertySources) { + private boolean isServletEnvironment(Class conversionType, + ClassLoader classLoader) { + try { + Class webEnvironmentClass = ClassUtils + .forName(CONFIGURABLE_WEB_ENVIRONMENT_CLASS, classLoader); + return webEnvironmentClass.isAssignableFrom(conversionType); + } + catch (Throwable ex) { + return false; + } + } + + private void removePropertySources(MutablePropertySources propertySources, + boolean isServletEnvironment) { Set names = new HashSet<>(); for (PropertySource propertySource : propertySources) { names.add(propertySource.getName()); } for (String name : names) { - propertySources.remove(name); + if (!isServletEnvironment) { + propertySources.remove(name); + } + else if (!SERVLET_ENVIRONMENT_SOURCE_NAMES.contains(name)) { + propertySources.remove(name); + } } } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java index d9e5bf6e3f0..6625901fc6c 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java @@ -45,6 +45,7 @@ import org.springframework.boot.Banner.Mode; import org.springframework.boot.context.properties.bind.Bindable; import org.springframework.boot.context.properties.bind.Binder; import org.springframework.boot.context.properties.source.ConfigurationPropertySources; +import org.springframework.boot.web.reactive.context.StandardReactiveWebEnvironment; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.ApplicationListener; @@ -239,6 +240,8 @@ public class SpringApplication { private boolean allowBeanDefinitionOverriding; + private boolean isCustomEnvironment = false; + /** * Create a new {@link SpringApplication} instance. The application context will load * beans from the specified primary sources (see {@link SpringApplication class-level} @@ -364,14 +367,24 @@ public class SpringApplication { configureEnvironment(environment, applicationArguments.getSourceArgs()); listeners.environmentPrepared(environment); bindToSpringApplication(environment); - if (this.webApplicationType == WebApplicationType.NONE) { + if (!this.isCustomEnvironment) { environment = new EnvironmentConverter(getClassLoader()) - .convertToStandardEnvironmentIfNecessary(environment); + .convertEnvironmentIfNecessary(environment, deduceEnvironmentClass()); } ConfigurationPropertySources.attach(environment); return environment; } + private Class deduceEnvironmentClass() { + if (this.webApplicationType == WebApplicationType.SERVLET) { + return StandardServletEnvironment.class; + } + if (this.webApplicationType == WebApplicationType.REACTIVE) { + return StandardReactiveWebEnvironment.class; + } + return StandardEnvironment.class; + } + private void prepareContext(ConfigurableApplicationContext context, ConfigurableEnvironment environment, SpringApplicationRunListeners listeners, ApplicationArguments applicationArguments, Banner printedBanner) { @@ -469,6 +482,9 @@ public class SpringApplication { if (this.webApplicationType == WebApplicationType.SERVLET) { return new StandardServletEnvironment(); } + if (this.webApplicationType == WebApplicationType.REACTIVE) { + return new StandardReactiveWebEnvironment(); + } return new StandardEnvironment(); } @@ -1071,6 +1087,7 @@ public class SpringApplication { * @param environment the environment */ public void setEnvironment(ConfigurableEnvironment environment) { + this.isCustomEnvironment = true; this.environment = environment; } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/EnvironmentConverterTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/EnvironmentConverterTests.java index d65ef80211e..73234b9b025 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/EnvironmentConverterTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/EnvironmentConverterTests.java @@ -16,10 +16,14 @@ package org.springframework.boot; +import java.util.HashSet; +import java.util.Set; + import org.junit.Test; import org.springframework.core.convert.support.ConfigurableConversionService; import org.springframework.core.env.AbstractEnvironment; +import org.springframework.core.env.PropertySource; import org.springframework.core.env.StandardEnvironment; import org.springframework.mock.env.MockEnvironment; import org.springframework.web.context.support.StandardServletEnvironment; @@ -32,6 +36,7 @@ import static org.mockito.Mockito.mock; * * @author Ethan Rubinson * @author Andy Wilkinson + * @author Madhura Bhave */ public class EnvironmentConverterTests { @@ -43,7 +48,8 @@ public class EnvironmentConverterTests { AbstractEnvironment originalEnvironment = new MockEnvironment(); originalEnvironment.setActiveProfiles("activeProfile1", "activeProfile2"); StandardEnvironment convertedEnvironment = this.environmentConverter - .convertToStandardEnvironmentIfNecessary(originalEnvironment); + .convertEnvironmentIfNecessary(originalEnvironment, + StandardEnvironment.class); assertThat(convertedEnvironment.getActiveProfiles()) .containsExactly("activeProfile1", "activeProfile2"); } @@ -55,16 +61,18 @@ public class EnvironmentConverterTests { ConfigurableConversionService.class); originalEnvironment.setConversionService(conversionService); StandardEnvironment convertedEnvironment = this.environmentConverter - .convertToStandardEnvironmentIfNecessary(originalEnvironment); + .convertEnvironmentIfNecessary(originalEnvironment, + StandardEnvironment.class); assertThat(convertedEnvironment.getConversionService()) .isEqualTo(conversionService); } @Test - public void standardEnvironmentIsReturnedUnconverted() { + public void envClassSameShouldReturnEnvironmentUnconverted() { StandardEnvironment standardEnvironment = new StandardEnvironment(); StandardEnvironment convertedEnvironment = this.environmentConverter - .convertToStandardEnvironmentIfNecessary(standardEnvironment); + .convertEnvironmentIfNecessary(standardEnvironment, + StandardEnvironment.class); assertThat(convertedEnvironment).isSameAs(standardEnvironment); } @@ -72,8 +80,53 @@ public class EnvironmentConverterTests { public void standardServletEnvironmentIsConverted() { StandardServletEnvironment standardServletEnvironment = new StandardServletEnvironment(); StandardEnvironment convertedEnvironment = this.environmentConverter - .convertToStandardEnvironmentIfNecessary(standardServletEnvironment); + .convertEnvironmentIfNecessary(standardServletEnvironment, + StandardEnvironment.class); assertThat(convertedEnvironment).isNotSameAs(standardServletEnvironment); } + @Test + public void servletPropertySourcesAreNotCopiedOverIfNotWebEnvironment() { + StandardServletEnvironment standardServletEnvironment = new StandardServletEnvironment(); + StandardEnvironment convertedEnvironment = this.environmentConverter + .convertEnvironmentIfNecessary(standardServletEnvironment, + StandardEnvironment.class); + assertThat(convertedEnvironment).isNotSameAs(standardServletEnvironment); + Set names = new HashSet<>(); + for (PropertySource propertySource : convertedEnvironment + .getPropertySources()) { + names.add(propertySource.getName()); + } + assertThat(names).doesNotContain( + StandardServletEnvironment.SERVLET_CONTEXT_PROPERTY_SOURCE_NAME, + StandardServletEnvironment.SERVLET_CONFIG_PROPERTY_SOURCE_NAME, + StandardServletEnvironment.JNDI_PROPERTY_SOURCE_NAME); + } + + @Test + public void envClassSameShouldReturnEnvironmentUnconvertedEvenForWeb() { + StandardServletEnvironment standardServletEnvironment = new StandardServletEnvironment(); + StandardEnvironment convertedEnvironment = this.environmentConverter + .convertEnvironmentIfNecessary(standardServletEnvironment, + StandardServletEnvironment.class); + assertThat(convertedEnvironment).isSameAs(standardServletEnvironment); + } + + @Test + public void servletPropertySourcesArePresentWhenTypeToConvertIsWeb() { + StandardEnvironment standardEnvironment = new StandardEnvironment(); + StandardEnvironment convertedEnvironment = this.environmentConverter + .convertEnvironmentIfNecessary(standardEnvironment, + StandardServletEnvironment.class); + assertThat(convertedEnvironment).isNotSameAs(standardEnvironment); + Set names = new HashSet<>(); + for (PropertySource propertySource : convertedEnvironment + .getPropertySources()) { + names.add(propertySource.getName()); + } + assertThat(names).contains( + StandardServletEnvironment.SERVLET_CONTEXT_PROPERTY_SOURCE_NAME, + StandardServletEnvironment.SERVLET_CONFIG_PROPERTY_SOURCE_NAME); + } + } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java index 0c7bbb1a1f4..993fb8e2c0d 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java @@ -61,6 +61,7 @@ import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactory; import org.springframework.boot.web.reactive.context.AnnotationConfigReactiveWebServerApplicationContext; import org.springframework.boot.web.reactive.context.ReactiveWebApplicationContext; +import org.springframework.boot.web.reactive.context.StandardReactiveWebEnvironment; import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; @@ -415,6 +416,25 @@ public class SpringApplicationTests { .isInstanceOf(AnnotationConfigReactiveWebServerApplicationContext.class); } + @Test + public void environmentForWeb() { + SpringApplication application = new SpringApplication(ExampleWebConfig.class); + application.setWebApplicationType(WebApplicationType.SERVLET); + this.context = application.run(); + assertThat(this.context.getEnvironment()) + .isInstanceOf(StandardServletEnvironment.class); + } + + @Test + public void environmentForReactiveWeb() { + SpringApplication application = new SpringApplication( + ExampleReactiveWebConfig.class); + application.setWebApplicationType(WebApplicationType.REACTIVE); + this.context = application.run(); + assertThat(this.context.getEnvironment()) + .isInstanceOf(StandardReactiveWebEnvironment.class); + } + @Test public void customEnvironment() { TestSpringApplication application = new TestSpringApplication( @@ -1100,6 +1120,35 @@ public class SpringApplicationTests { .isNotInstanceOfAny(ConfigurableWebEnvironment.class); } + @Test + public void webApplicationConfiguredViaAPropertyHasTheCorrectTypeOfContextAndEnvironment() { + ConfigurableApplicationContext context = new SpringApplication( + ExampleWebConfig.class).run("--spring.main.web-application-type=servlet"); + assertThat(context).isInstanceOfAny(WebApplicationContext.class); + assertThat(context.getEnvironment()) + .isInstanceOfAny(StandardServletEnvironment.class); + } + + @Test + public void reactiveApplicationConfiguredViaAPropertyHasTheCorrectTypeOfContextAndEnvironment() { + ConfigurableApplicationContext context = new SpringApplication( + ExampleReactiveWebConfig.class) + .run("--spring.main.web-application-type=reactive"); + assertThat(context).isInstanceOfAny(ReactiveWebApplicationContext.class); + assertThat(context.getEnvironment()) + .isInstanceOfAny(StandardReactiveWebEnvironment.class); + } + + @Test + public void environmentIsConvertedIfTypeDoesNotMatch() { + ConfigurableApplicationContext context = new SpringApplication( + ExampleReactiveWebConfig.class) + .run("--spring.profiles.active=withwebapplicationtype"); + assertThat(context).isInstanceOfAny(ReactiveWebApplicationContext.class); + assertThat(context.getEnvironment()) + .isInstanceOfAny(StandardReactiveWebEnvironment.class); + } + @Test public void failureResultsInSingleStackTrace() throws Exception { ThreadGroup group = new ThreadGroup("main"); diff --git a/spring-boot-project/spring-boot/src/test/resources/application-withwebapplicationtype.yml b/spring-boot-project/spring-boot/src/test/resources/application-withwebapplicationtype.yml new file mode 100644 index 00000000000..9c913ff8db5 --- /dev/null +++ b/spring-boot-project/spring-boot/src/test/resources/application-withwebapplicationtype.yml @@ -0,0 +1 @@ +spring.main.web-application-type: reactive \ No newline at end of file