diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/AotServices.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/AotServices.java index ab271011dfa..912eb48d101 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/AotServices.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/AotServices.java @@ -17,12 +17,15 @@ package org.springframework.beans.factory.aot; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; +import java.util.IdentityHashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.stream.Stream; +import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.beans.factory.config.ConfigurableBeanFactory; @@ -30,6 +33,7 @@ import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.core.io.support.SpringFactoriesLoader; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; /** * A collection of AOT services that can be {@link Loader loaded} from @@ -50,16 +54,30 @@ public final class AotServices implements Iterable { private final Map beans; + private final Map sources; + private AotServices(List loaded, Map beans) { + this.services = collectServices(loaded, beans); + this.sources = collectSources(loaded, beans.values()); + this.beans = beans; + } + + private List collectServices(List loaded, Map beans) { List services = new ArrayList<>(); services.addAll(beans.values()); services.addAll(loaded); AnnotationAwareOrderComparator.sort(services); - this.services = Collections.unmodifiableList(services); - this.beans = beans; + return Collections.unmodifiableList(services); } + private Map collectSources(Collection loaded, + Collection beans) { + Map sources = new IdentityHashMap<>(); + loaded.forEach(service -> sources.put(service, Source.SPRING_FACTORIES_LOADER)); + beans.forEach(service -> sources.put(service, Source.BEAN_FACTORY)); + return Collections.unmodifiableMap(sources); + } /** * Return a new {@link Loader} that will obtain AOT services from @@ -154,6 +172,18 @@ public final class AotServices implements Iterable { return this.beans.get(beanName); } + /** + * Return the source of the given service. + * @param service the service instance + * @return the source of the service + */ + public Source getSource(T service) { + Source source = this.sources.get(service); + Assert.state(source != null, + "Unable to find service " + ObjectUtils.identityToString(source)); + return source; + } + /** * Loader class used to actually load the services. @@ -189,4 +219,21 @@ public final class AotServices implements Iterable { } + /** + * Sources from which services were obtained. + */ + public enum Source { + + /** + * An AOT service loaded from {@link SpringFactoriesLoader}. + */ + SPRING_FACTORIES_LOADER, + + /** + * An AOT service loaded from a {@link BeanFactory}. + */ + BEAN_FACTORY + + } + } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/AotServicesTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/AotServicesTests.java index 8d11204ed03..a637b0e897d 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/AotServicesTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/AotServicesTests.java @@ -22,6 +22,7 @@ import java.util.Enumeration; import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.aot.AotServices.Source; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.Ordered; @@ -30,6 +31,8 @@ import org.springframework.core.mock.MockSpringFactoriesLoader; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.mockito.Mockito.mock; /** * Tests for {@link AotServices}. @@ -163,6 +166,25 @@ class AotServicesTests { assertThat(loaded).map(Object::toString).containsExactly("b1", "l1", "b2", "l2"); } + @Test + void getSourceReturnsSource() { + MockSpringFactoriesLoader loader = new MockSpringFactoriesLoader(); + loader.addInstance(TestService.class, new TestServiceImpl()); + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("test", new RootBeanDefinition(TestBean.class)); + AotServices loaded = AotServices.factoriesAndBeans(loader, beanFactory).load(TestService.class); + assertThat(loaded.getSource(loaded.asList().get(0))).isEqualTo(Source.SPRING_FACTORIES_LOADER); + assertThat(loaded.getSource(loaded.asList().get(1))).isEqualTo(Source.BEAN_FACTORY); + TestService missing = mock(TestService.class); + assertThatIllegalStateException().isThrownBy(()->loaded.getSource(missing)); + } + + @Test + void getSourceWhenMissingThrowsException() { + AotServices loaded = AotServices.factories().load(TestService.class); + TestService missing = mock(TestService.class); + assertThatIllegalStateException().isThrownBy(()->loaded.getSource(missing)); + } interface TestService { }