Track the sources used to load AOT services

Update `AotServices` so that it tracks which source was used to
provide a given instance.

See gh-28866
This commit is contained in:
Phillip Webb 2022-07-28 12:38:59 +01:00
parent 5218cf4c16
commit a4f3d1d6e8
2 changed files with 71 additions and 2 deletions

View File

@ -17,12 +17,15 @@
package org.springframework.beans.factory.aot;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
@ -30,6 +33,7 @@ import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.core.io.support.SpringFactoriesLoader;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
/**
* A collection of AOT services that can be {@link Loader loaded} from
@ -50,16 +54,30 @@ public final class AotServices<T> implements Iterable<T> {
private final Map<String, T> beans;
private final Map<T, Source> sources;
private AotServices(List<T> loaded, Map<String, T> beans) {
this.services = collectServices(loaded, beans);
this.sources = collectSources(loaded, beans.values());
this.beans = beans;
}
private List<T> collectServices(List<T> loaded, Map<String, T> beans) {
List<T> services = new ArrayList<>();
services.addAll(beans.values());
services.addAll(loaded);
AnnotationAwareOrderComparator.sort(services);
this.services = Collections.unmodifiableList(services);
this.beans = beans;
return Collections.unmodifiableList(services);
}
private Map<T, Source> collectSources(Collection<T> loaded,
Collection<T> beans) {
Map<T, Source> sources = new IdentityHashMap<>();
loaded.forEach(service -> sources.put(service, Source.SPRING_FACTORIES_LOADER));
beans.forEach(service -> sources.put(service, Source.BEAN_FACTORY));
return Collections.unmodifiableMap(sources);
}
/**
* Return a new {@link Loader} that will obtain AOT services from
@ -154,6 +172,18 @@ public final class AotServices<T> implements Iterable<T> {
return this.beans.get(beanName);
}
/**
* Return the source of the given service.
* @param service the service instance
* @return the source of the service
*/
public Source getSource(T service) {
Source source = this.sources.get(service);
Assert.state(source != null,
"Unable to find service " + ObjectUtils.identityToString(source));
return source;
}
/**
* Loader class used to actually load the services.
@ -189,4 +219,21 @@ public final class AotServices<T> implements Iterable<T> {
}
/**
* Sources from which services were obtained.
*/
public enum Source {
/**
* An AOT service loaded from {@link SpringFactoriesLoader}.
*/
SPRING_FACTORIES_LOADER,
/**
* An AOT service loaded from a {@link BeanFactory}.
*/
BEAN_FACTORY
}
}

View File

@ -22,6 +22,7 @@ import java.util.Enumeration;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.aot.AotServices.Source;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.Ordered;
@ -30,6 +31,8 @@ import org.springframework.core.mock.MockSpringFactoriesLoader;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link AotServices}.
@ -163,6 +166,25 @@ class AotServicesTests {
assertThat(loaded).map(Object::toString).containsExactly("b1", "l1", "b2", "l2");
}
@Test
void getSourceReturnsSource() {
MockSpringFactoriesLoader loader = new MockSpringFactoriesLoader();
loader.addInstance(TestService.class, new TestServiceImpl());
DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
beanFactory.registerBeanDefinition("test", new RootBeanDefinition(TestBean.class));
AotServices<TestService> loaded = AotServices.factoriesAndBeans(loader, beanFactory).load(TestService.class);
assertThat(loaded.getSource(loaded.asList().get(0))).isEqualTo(Source.SPRING_FACTORIES_LOADER);
assertThat(loaded.getSource(loaded.asList().get(1))).isEqualTo(Source.BEAN_FACTORY);
TestService missing = mock(TestService.class);
assertThatIllegalStateException().isThrownBy(()->loaded.getSource(missing));
}
@Test
void getSourceWhenMissingThrowsException() {
AotServices<TestService> loaded = AotServices.factories().load(TestService.class);
TestService missing = mock(TestService.class);
assertThatIllegalStateException().isThrownBy(()->loaded.getSource(missing));
}
interface TestService {
}