Introduce hooking of SpringApplication

Closes gh-30538
This commit is contained in:
Andy Wilkinson 2022-04-06 10:29:57 +01:00
parent 6d15e6d4ef
commit 9cd17f8a14
3 changed files with 217 additions and 10 deletions

View File

@ -288,6 +288,7 @@ public class SpringApplication {
* @return a running {@link ApplicationContext}
*/
public ConfigurableApplicationContext run(String... args) {
SpringApplicationHooks.hooks().preRun(this);
long startTime = System.nanoTime();
DefaultBootstrapContext bootstrapContext = createBootstrapContext();
ConfigurableApplicationContext context = null;
@ -302,27 +303,32 @@ public class SpringApplication {
context = createApplicationContext();
context.setApplicationStartup(this.applicationStartup);
prepareContext(bootstrapContext, context, environment, listeners, applicationArguments, printedBanner);
refreshContext(context);
afterRefresh(context, applicationArguments);
Duration timeTakenToStartup = Duration.ofNanos(System.nanoTime() - startTime);
if (this.logStartupInfo) {
new StartupInfoLogger(this.mainApplicationClass).logStarted(getApplicationLog(), timeTakenToStartup);
if (refreshContext(context)) {
afterRefresh(context, applicationArguments);
Duration timeTakenToStartup = Duration.ofNanos(System.nanoTime() - startTime);
if (this.logStartupInfo) {
new StartupInfoLogger(this.mainApplicationClass).logStarted(getApplicationLog(),
timeTakenToStartup);
}
listeners.started(context, timeTakenToStartup);
callRunners(context, applicationArguments);
}
listeners.started(context, timeTakenToStartup);
callRunners(context, applicationArguments);
}
catch (Throwable ex) {
handleRunFailure(context, ex, listeners);
throw new IllegalStateException(ex);
}
try {
Duration timeTakenToReady = Duration.ofNanos(System.nanoTime() - startTime);
listeners.ready(context, timeTakenToReady);
if (context.isRunning()) {
Duration timeTakenToReady = Duration.ofNanos(System.nanoTime() - startTime);
listeners.ready(context, timeTakenToReady);
}
}
catch (Throwable ex) {
handleRunFailure(context, ex, null);
throw new IllegalStateException(ex);
}
SpringApplicationHooks.hooks().postRun(this, context);
return context;
}
@ -397,11 +403,15 @@ public class SpringApplication {
listeners.contextLoaded(context);
}
private void refreshContext(ConfigurableApplicationContext context) {
private boolean refreshContext(ConfigurableApplicationContext context) {
if (!SpringApplicationHooks.hooks().preRefresh(this, context)) {
return false;
}
if (this.registerShutdownHook) {
shutdownHook.registerApplicationContext(context);
}
refresh(context);
return true;
}
private void configureHeadlessProperty() {

View File

@ -0,0 +1,169 @@
/*
* Copyright 2012-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot;
import java.util.ArrayList;
import java.util.List;
import org.springframework.context.ConfigurableApplicationContext;
/**
* Low-level hooks that can observe a {@link SpringApplication} and modify its behavior.
* Hooks are managed on a per-thread basis providing isolation when multiple applications
* are executed in parallel.
*
* @author Andy Wilkinson
*/
final class SpringApplicationHooks {
private static final ThreadLocal<Hooks> hooks = ThreadLocal.withInitial(Hooks::new);
private SpringApplicationHooks() {
}
/**
* Runs the given {@code action} with the given {@code hook} attached.
* @param hook the hook to attach
* @param action the action to run
* @param <T> the type of the action's result
* @return the result of the action
* @throws Exception if a failure occurs while performing the action
*/
static <T> T withHook(Hook hook, Action<T> action) throws Exception {
hooks.get().add(hook);
try {
return action.perform();
}
finally {
hooks.get().remove(hook);
}
}
/**
* Runs the given {@code action} with the given {@code hook} attached.
* @param hook the hook to attach
* @param action the action to run
*/
static void withHook(Hook hook, Runnable action) {
hooks.get().add(hook);
try {
action.run();
}
finally {
hooks.get().remove(hook);
}
}
static Hooks hooks() {
return hooks.get();
}
/**
* A hook that can observe and modify the behavior of a {@link SpringApplication}.
*/
interface Hook {
/**
* Called at the beginning of {@link SpringApplication#run(String...)}. Provides
* an opportunity to inspect and customise the application.
* @param application the application that is being run
*/
default void preRun(SpringApplication application) {
}
/**
* Called at the end of {@link SpringApplication#run(String...)}. Provides access
* to the {@link ConfigurableApplicationContext context} that has been created for
* the application.
* @param application the application that has been run
* @param context the application's context
*/
default void postRun(SpringApplication application, ConfigurableApplicationContext context) {
}
/**
* Called immediately before the given {@code context} is refreshed.
* @param application the application for which the context is being refreshed
* @param context the application's context
* @return whether to continue with refresh processing
*/
default boolean preRefresh(SpringApplication application, ConfigurableApplicationContext context) {
return true;
}
}
/**
* An action that can be performed with a hook attached.
* <p>
* <strong>For internal use only.</strong>
*
* @param <T> the type of the action's result
*/
interface Action<T> {
/**
* Perform the action.
* @return the result of the action
* @throws Exception if a failure occurs
*/
T perform() throws Exception;
}
static final class Hooks implements Hook {
private final List<Hook> delegates = new ArrayList<>();
private void add(Hook hook) {
this.delegates.add(hook);
}
private void remove(Hook hook) {
this.delegates.remove(hook);
}
@Override
public void preRun(SpringApplication application) {
for (Hook delegate : this.delegates) {
delegate.preRun(application);
}
}
@Override
public void postRun(SpringApplication application, ConfigurableApplicationContext context) {
for (Hook delegate : this.delegates) {
delegate.postRun(application, context);
}
}
@Override
public boolean preRefresh(SpringApplication application, ConfigurableApplicationContext context) {
for (Hook delegate : this.delegates) {
if (!delegate.preRefresh(application, context)) {
return false;
}
}
return true;
}
}
}

View File

@ -54,6 +54,8 @@ import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanNameGenerator;
import org.springframework.beans.factory.support.DefaultBeanNameGenerator;
import org.springframework.boot.BootstrapRegistry.InstanceSupplier;
import org.springframework.boot.SpringApplicationHooks.Action;
import org.springframework.boot.SpringApplicationHooks.Hook;
import org.springframework.boot.availability.AvailabilityChangeEvent;
import org.springframework.boot.availability.AvailabilityState;
import org.springframework.boot.availability.LivenessState;
@ -126,6 +128,7 @@ import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.then;
@ -1262,6 +1265,31 @@ class SpringApplicationTests {
.didNotRegisterApplicationContext(failure.getApplicationContext());
}
@Test
void hookIsCalledWhenApplicationIsRun() throws Exception {
Hook hook = mock(Hook.class);
SpringApplication application = new SpringApplication(ExampleConfig.class);
application.setWebApplicationType(WebApplicationType.NONE);
given(hook.preRefresh(eq(application), any(ConfigurableApplicationContext.class))).willReturn(true);
this.context = SpringApplicationHooks.withHook(hook, (Action<ConfigurableApplicationContext>) application::run);
then(hook).should().preRun(application);
then(hook).should().preRefresh(application, this.context);
then(hook).should().postRun(application, this.context);
assertThat(this.context.isRunning()).isTrue();
}
@Test
void hookIsCalledAndCanPreventRefreshWhenApplicationIsRun() throws Exception {
Hook hook = mock(Hook.class);
SpringApplication application = new SpringApplication(ExampleConfig.class);
application.setWebApplicationType(WebApplicationType.NONE);
this.context = SpringApplicationHooks.withHook(hook, (Action<ConfigurableApplicationContext>) application::run);
then(hook).should().preRun(application);
then(hook).should().preRefresh(application, this.context);
then(hook).should().postRun(application, this.context);
assertThat(this.context.isRunning()).isFalse();
}
private <S extends AvailabilityState> ArgumentMatcher<ApplicationEvent> isAvailabilityChangeEventWithState(
S state) {
return (argument) -> (argument instanceof AvailabilityChangeEvent<?>)