diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionGenerationException.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionGenerationException.java new file mode 100644 index 0000000000..9ccc8e721d --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionGenerationException.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import org.springframework.beans.factory.config.BeanDefinition; + +/** + * Thrown when a bean definition could not be generated. + * + * @author Stephane Nicoll + * @since 6.0 + */ +@SuppressWarnings("serial") +public class BeanDefinitionGenerationException extends RuntimeException { + + private final String beanName; + + private final BeanDefinition beanDefinition; + + public BeanDefinitionGenerationException(String beanName, BeanDefinition beanDefinition, String message, Throwable cause) { + super(message, cause); + this.beanName = beanName; + this.beanDefinition = beanDefinition; + } + + public BeanDefinitionGenerationException(String beanName, BeanDefinition beanDefinition, String message) { + super(message); + this.beanName = beanName; + this.beanDefinition = beanDefinition; + } + + /** + * Return the bean name that could not be generated. + * @return the bean name + */ + public String getBeanName() { + return this.beanName; + } + + /** + * Return the bean definition that could not be generated. + * @return the bean definition + */ + public BeanDefinition getBeanDefinition() { + return this.beanDefinition; + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionsContribution.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionsContribution.java new file mode 100644 index 0000000000..83e6d9a012 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanDefinitionsContribution.java @@ -0,0 +1,111 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.core.io.support.SpringFactoriesLoader; + +/** + * A {@link BeanFactoryContribution} that generates the bean definitions of a + * bean factory, using {@link BeanRegistrationContributionProvider} to use + * appropriate customizations if necessary. + * + *

{@link BeanRegistrationContributionProvider} can be ordered, with the default + * implementation always coming last. + * + * @author Stephane Nicoll + * @since 6.0 + * @see DefaultBeanRegistrationContributionProvider + */ +public class BeanDefinitionsContribution implements BeanFactoryContribution { + + private final DefaultListableBeanFactory beanFactory; + + private final List contributionProviders; + + private final Map contributions; + + BeanDefinitionsContribution(DefaultListableBeanFactory beanFactory, + List contributionProviders) { + this.beanFactory = beanFactory; + this.contributionProviders = contributionProviders; + this.contributions = new HashMap<>(); + } + + public BeanDefinitionsContribution(DefaultListableBeanFactory beanFactory) { + this(beanFactory, initializeProviders(beanFactory)); + } + + private static List initializeProviders(DefaultListableBeanFactory beanFactory) { + List providers = new ArrayList<>(SpringFactoriesLoader.loadFactories( + BeanRegistrationContributionProvider.class, beanFactory.getBeanClassLoader())); + providers.add(new DefaultBeanRegistrationContributionProvider(beanFactory)); + return providers; + } + + @Override + public void applyTo(BeanFactoryInitialization initialization) { + writeBeanDefinitions(initialization); + } + + private void writeBeanDefinitions(BeanFactoryInitialization initialization) { + for (String beanName : this.beanFactory.getBeanDefinitionNames()) { + handleMergedBeanDefinition(beanName, beanDefinition -> { + BeanFactoryContribution registrationContribution = getBeanRegistrationContribution( + beanName, beanDefinition); + registrationContribution.applyTo(initialization); + }); + } + } + + private BeanFactoryContribution getBeanRegistrationContribution( + String beanName, RootBeanDefinition beanDefinition) { + return this.contributions.computeIfAbsent(beanName, name -> { + for (BeanRegistrationContributionProvider provider : this.contributionProviders) { + BeanFactoryContribution contribution = provider.getContributionFor( + beanName, beanDefinition); + if (contribution != null) { + return contribution; + } + } + throw new BeanRegistrationContributionNotFoundException(beanName, beanDefinition); + }); + } + + private void handleMergedBeanDefinition(String beanName, Consumer consumer) { + RootBeanDefinition beanDefinition = (RootBeanDefinition) this.beanFactory.getMergedBeanDefinition(beanName); + try { + consumer.accept(beanDefinition); + } + catch (BeanDefinitionGenerationException ex) { + throw ex; + } + catch (Exception ex) { + String msg = String.format("Failed to handle bean with name '%s' and type '%s'", + beanName, beanDefinition.getResolvableType()); + throw new BeanDefinitionGenerationException(beanName, beanDefinition, msg, ex); + } + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanRegistrationContributionNotFoundException.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanRegistrationContributionNotFoundException.java new file mode 100644 index 0000000000..786ca801fc --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanRegistrationContributionNotFoundException.java @@ -0,0 +1,37 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import org.springframework.beans.factory.config.BeanDefinition; + +/** + * Thrown when no suitable {@link BeanFactoryContribution} can be provided + * for the registration of a given bean definition. + * + * @author Stephane Nicoll + * @since 6.0 + */ +@SuppressWarnings("serial") +public class BeanRegistrationContributionNotFoundException extends BeanDefinitionGenerationException { + + public BeanRegistrationContributionNotFoundException(String beanName, BeanDefinition beanDefinition) { + super(beanName, beanDefinition, String.format( + "No suitable contribution found for bean with name '%s' and type '%s'", + beanName, beanDefinition.getResolvableType())); + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanDefinitionsContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanDefinitionsContributionTests.java new file mode 100644 index 0000000000..17b1eb0b15 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanDefinitionsContributionTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.BDDMockito; +import org.mockito.Mockito; + +import org.springframework.aot.generator.DefaultGeneratedTypeContext; +import org.springframework.aot.generator.GeneratedType; +import org.springframework.aot.generator.GeneratedTypeContext; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.support.CodeSnippet; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link BeanDefinitionsContribution}. + * + * @author Stephane Nicoll + */ +class BeanDefinitionsContributionTests { + + @Test + void contributeThrowsContributionNotFoundIfNoContributionIsAvailable() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("test", new RootBeanDefinition()); + BeanDefinitionsContribution contribution = new BeanDefinitionsContribution(beanFactory, + List.of(Mockito.mock(BeanRegistrationContributionProvider.class))); + BeanFactoryInitialization initialization = new BeanFactoryInitialization(createGenerationContext()); + assertThatThrownBy(() -> contribution.applyTo(initialization)) + .isInstanceOfSatisfying(BeanRegistrationContributionNotFoundException.class, ex -> { + assertThat(ex.getBeanName()).isEqualTo("test"); + assertThat(ex.getBeanDefinition()).isSameAs(beanFactory.getMergedBeanDefinition("test")); + }); + } + + @Test + void contributeThrowsBeanRegistrationExceptionIfContributionThrowsException() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("test", new RootBeanDefinition()); + BeanFactoryContribution testContribution = Mockito.mock(BeanFactoryContribution.class); + IllegalStateException testException = new IllegalStateException(); + BDDMockito.willThrow(testException).given(testContribution).applyTo(ArgumentMatchers.any(BeanFactoryInitialization.class)); + BeanDefinitionsContribution contribution = new BeanDefinitionsContribution(beanFactory, + List.of(new TestBeanRegistrationContributionProvider("test", testContribution))); + BeanFactoryInitialization initialization = new BeanFactoryInitialization(createGenerationContext()); + assertThatThrownBy(() -> contribution.applyTo(initialization)) + .isInstanceOfSatisfying(BeanDefinitionGenerationException.class, ex -> { + assertThat(ex.getBeanName()).isEqualTo("test"); + assertThat(ex.getBeanDefinition()).isSameAs(beanFactory.getMergedBeanDefinition("test")); + assertThat(ex.getCause()).isEqualTo(testException); + }); + } + + @Test + void contributeGeneratesBeanDefinitionsInOrder() { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("counter", BeanDefinitionBuilder + .rootBeanDefinition(Integer.class, "valueOf").addConstructorArgValue(42).getBeanDefinition()); + beanFactory.registerBeanDefinition("name", BeanDefinitionBuilder + .rootBeanDefinition(String.class).addConstructorArgValue("Hello").getBeanDefinition()); + CodeSnippet code = contribute(beanFactory, createGenerationContext()); + assertThat(code.getSnippet()).isEqualTo(""" + BeanDefinitionRegistrar.of("counter", Integer.class).withFactoryMethod(Integer.class, "valueOf", int.class) + .instanceSupplier((instanceContext) -> instanceContext.create(beanFactory, (attributes) -> Integer.valueOf(attributes.get(0)))).customize((bd) -> bd.getConstructorArgumentValues().addIndexedArgumentValue(0, 42)).register(beanFactory); + BeanDefinitionRegistrar.of("name", String.class).withConstructor(String.class) + .instanceSupplier((instanceContext) -> instanceContext.create(beanFactory, (attributes) -> new String(attributes.get(0, String.class)))).customize((bd) -> bd.getConstructorArgumentValues().addIndexedArgumentValue(0, "Hello")).register(beanFactory); + """); + } + + private CodeSnippet contribute(DefaultListableBeanFactory beanFactory, GeneratedTypeContext generationContext) { + BeanDefinitionsContribution contribution = new BeanDefinitionsContribution(beanFactory); + BeanFactoryInitialization initialization = new BeanFactoryInitialization(generationContext); + contribution.applyTo(initialization); + return CodeSnippet.of(initialization.toCodeBlock()); + } + + private GeneratedTypeContext createGenerationContext() { + return new DefaultGeneratedTypeContext("com.example", packageName -> + GeneratedType.of(ClassName.get(packageName, "Test"))); + } + + static class TestBeanRegistrationContributionProvider implements BeanRegistrationContributionProvider { + + private final String beanName; + + private final BeanFactoryContribution contribution; + + public TestBeanRegistrationContributionProvider(String beanName, BeanFactoryContribution contribution) { + this.beanName = beanName; + this.contribution = contribution; + } + + @Override + public BeanFactoryContribution getContributionFor(String beanName, RootBeanDefinition beanDefinition) { + return (beanName.equals(this.beanName) ? this.contribution : null); + } + } + +}