Track the sources used to load AOT services
Update `AotServices` so that it tracks which source was used to provide a given instance. See gh-28866
This commit is contained in:
parent
5218cf4c16
commit
a4f3d1d6e8
|
|
@ -17,12 +17,15 @@
|
|||
package org.springframework.beans.factory.aot;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.IdentityHashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import org.springframework.beans.factory.BeanFactory;
|
||||
import org.springframework.beans.factory.BeanFactoryUtils;
|
||||
import org.springframework.beans.factory.ListableBeanFactory;
|
||||
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
|
||||
|
|
@ -30,6 +33,7 @@ import org.springframework.core.annotation.AnnotationAwareOrderComparator;
|
|||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.ObjectUtils;
|
||||
|
||||
/**
|
||||
* A collection of AOT services that can be {@link Loader loaded} from
|
||||
|
|
@ -50,16 +54,30 @@ public final class AotServices<T> implements Iterable<T> {
|
|||
|
||||
private final Map<String, T> beans;
|
||||
|
||||
private final Map<T, Source> sources;
|
||||
|
||||
|
||||
private AotServices(List<T> loaded, Map<String, T> beans) {
|
||||
this.services = collectServices(loaded, beans);
|
||||
this.sources = collectSources(loaded, beans.values());
|
||||
this.beans = beans;
|
||||
}
|
||||
|
||||
private List<T> collectServices(List<T> loaded, Map<String, T> beans) {
|
||||
List<T> services = new ArrayList<>();
|
||||
services.addAll(beans.values());
|
||||
services.addAll(loaded);
|
||||
AnnotationAwareOrderComparator.sort(services);
|
||||
this.services = Collections.unmodifiableList(services);
|
||||
this.beans = beans;
|
||||
return Collections.unmodifiableList(services);
|
||||
}
|
||||
|
||||
private Map<T, Source> collectSources(Collection<T> loaded,
|
||||
Collection<T> beans) {
|
||||
Map<T, Source> sources = new IdentityHashMap<>();
|
||||
loaded.forEach(service -> sources.put(service, Source.SPRING_FACTORIES_LOADER));
|
||||
beans.forEach(service -> sources.put(service, Source.BEAN_FACTORY));
|
||||
return Collections.unmodifiableMap(sources);
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new {@link Loader} that will obtain AOT services from
|
||||
|
|
@ -154,6 +172,18 @@ public final class AotServices<T> implements Iterable<T> {
|
|||
return this.beans.get(beanName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the source of the given service.
|
||||
* @param service the service instance
|
||||
* @return the source of the service
|
||||
*/
|
||||
public Source getSource(T service) {
|
||||
Source source = this.sources.get(service);
|
||||
Assert.state(source != null,
|
||||
"Unable to find service " + ObjectUtils.identityToString(source));
|
||||
return source;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Loader class used to actually load the services.
|
||||
|
|
@ -189,4 +219,21 @@ public final class AotServices<T> implements Iterable<T> {
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
* Sources from which services were obtained.
|
||||
*/
|
||||
public enum Source {
|
||||
|
||||
/**
|
||||
* An AOT service loaded from {@link SpringFactoriesLoader}.
|
||||
*/
|
||||
SPRING_FACTORIES_LOADER,
|
||||
|
||||
/**
|
||||
* An AOT service loaded from a {@link BeanFactory}.
|
||||
*/
|
||||
BEAN_FACTORY
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import java.util.Enumeration;
|
|||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.beans.factory.aot.AotServices.Source;
|
||||
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
|
||||
import org.springframework.beans.factory.support.RootBeanDefinition;
|
||||
import org.springframework.core.Ordered;
|
||||
|
|
@ -30,6 +31,8 @@ import org.springframework.core.mock.MockSpringFactoriesLoader;
|
|||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
||||
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/**
|
||||
* Tests for {@link AotServices}.
|
||||
|
|
@ -163,6 +166,25 @@ class AotServicesTests {
|
|||
assertThat(loaded).map(Object::toString).containsExactly("b1", "l1", "b2", "l2");
|
||||
}
|
||||
|
||||
@Test
|
||||
void getSourceReturnsSource() {
|
||||
MockSpringFactoriesLoader loader = new MockSpringFactoriesLoader();
|
||||
loader.addInstance(TestService.class, new TestServiceImpl());
|
||||
DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
|
||||
beanFactory.registerBeanDefinition("test", new RootBeanDefinition(TestBean.class));
|
||||
AotServices<TestService> loaded = AotServices.factoriesAndBeans(loader, beanFactory).load(TestService.class);
|
||||
assertThat(loaded.getSource(loaded.asList().get(0))).isEqualTo(Source.SPRING_FACTORIES_LOADER);
|
||||
assertThat(loaded.getSource(loaded.asList().get(1))).isEqualTo(Source.BEAN_FACTORY);
|
||||
TestService missing = mock(TestService.class);
|
||||
assertThatIllegalStateException().isThrownBy(()->loaded.getSource(missing));
|
||||
}
|
||||
|
||||
@Test
|
||||
void getSourceWhenMissingThrowsException() {
|
||||
AotServices<TestService> loaded = AotServices.factories().load(TestService.class);
|
||||
TestService missing = mock(TestService.class);
|
||||
assertThatIllegalStateException().isThrownBy(()->loaded.getSource(missing));
|
||||
}
|
||||
|
||||
interface TestService {
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue