Register hints for types exposed via PersistenceManagedTypes

Add binding hints for managed types and hints for
@EntityListeners, @IdClass and @Converter.

Closes gh-29096
This commit is contained in:
Sébastien Deleuze 2022-09-07 11:55:13 +02:00
parent d373435856
commit 5e1b5af0e0
7 changed files with 261 additions and 2 deletions

View File

@ -21,16 +21,25 @@ import java.util.List;
import javax.lang.model.element.Modifier; import javax.lang.model.element.Modifier;
import jakarta.persistence.Converter;
import jakarta.persistence.EntityListeners;
import jakarta.persistence.IdClass;
import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.BindingReflectionHintsRegistrar;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor; import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
import org.springframework.beans.factory.aot.BeanRegistrationCode; import org.springframework.beans.factory.aot.BeanRegistrationCode;
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments; import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments;
import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.ParameterizedTypeName; import org.springframework.javapoet.ParameterizedTypeName;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;
/** /**
* {@link BeanRegistrationAotProcessor} implementations for persistence managed * {@link BeanRegistrationAotProcessor} implementations for persistence managed
@ -40,6 +49,7 @@ import org.springframework.lang.Nullable;
* and replaced by a hard-coded list of managed class names and packages. * and replaced by a hard-coded list of managed class names and packages.
* *
* @author Stephane Nicoll * @author Stephane Nicoll
* @author Sebastien Deleuze
* @since 6.0 * @since 6.0
*/ */
class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor { class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor {
@ -60,6 +70,8 @@ class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistr
private final RegisteredBean registeredBean; private final RegisteredBean registeredBean;
private final BindingReflectionHintsRegistrar bindingRegistrar = new BindingReflectionHintsRegistrar();
public JpaManagedTypesBeanRegistrationCodeFragments(BeanRegistrationCodeFragments codeFragments, public JpaManagedTypesBeanRegistrationCodeFragments(BeanRegistrationCodeFragments codeFragments,
RegisteredBean registeredBean) { RegisteredBean registeredBean) {
super(codeFragments); super(codeFragments);
@ -73,6 +85,7 @@ class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistr
boolean allowDirectSupplierShortcut) { boolean allowDirectSupplierShortcut) {
PersistenceManagedTypes persistenceManagedTypes = this.registeredBean.getBeanFactory() PersistenceManagedTypes persistenceManagedTypes = this.registeredBean.getBeanFactory()
.getBean(this.registeredBean.getBeanName(), PersistenceManagedTypes.class); .getBean(this.registeredBean.getBeanName(), PersistenceManagedTypes.class);
contributeHints(generationContext.getRuntimeHints(), persistenceManagedTypes.getManagedClassNames());
GeneratedMethod generatedMethod = beanRegistrationCode.getMethods() GeneratedMethod generatedMethod = beanRegistrationCode.getMethods()
.add("getInstance", method -> { .add("getInstance", method -> {
Class<?> beanType = PersistenceManagedTypes.class; Class<?> beanType = PersistenceManagedTypes.class;
@ -93,5 +106,43 @@ class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistr
return CodeBlock.join(values.stream().map(value -> CodeBlock.of("$S", value)).toList(), ", "); return CodeBlock.join(values.stream().map(value -> CodeBlock.of("$S", value)).toList(), ", ");
} }
private void contributeHints(RuntimeHints hints, List<String> managedClassNames) {
for (String managedClassName : managedClassNames) {
try {
Class<?> managedClass = ClassUtils.forName(managedClassName, null);
this.bindingRegistrar.registerReflectionHints(hints.reflection(), managedClass);
contributeEntityListenersHints(hints, managedClass);
contributeIdClassHints(hints, managedClass);
contributeConverterHints(hints, managedClass);
}
catch (ClassNotFoundException ex) {
throw new IllegalArgumentException("Failed to instantiate the managed class: " + managedClassName, ex);
}
}
}
private void contributeEntityListenersHints(RuntimeHints hints, Class<?> managedClass) {
EntityListeners entityListeners = AnnotationUtils.findAnnotation(managedClass, EntityListeners.class);
if (entityListeners != null) {
for (Class<?> entityListener : entityListeners.value()) {
hints.reflection().registerType(entityListener, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_PUBLIC_METHODS);
}
}
}
private void contributeIdClassHints(RuntimeHints hints, Class<?> managedClass) {
IdClass idClass = AnnotationUtils.findAnnotation(managedClass, IdClass.class);
if (idClass != null) {
this.bindingRegistrar.registerReflectionHints(hints.reflection(), idClass.value());
}
}
private void contributeConverterHints(RuntimeHints hints, Class<?> managedClass) {
Converter converter = AnnotationUtils.findAnnotation(managedClass, Converter.class);
if (converter != null) {
hints.reflection().registerType(managedClass, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_PUBLIC_METHODS);
}
}
} }
} }

View File

@ -0,0 +1,62 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.orm.jpa.domain;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.IdClass;
@Entity
@IdClass(EmployeeId.class)
public class Employee {
@Id
@Column
private String name;
@Id
@Column
private String department;
private EmployeeLocation location;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getDepartment() {
return department;
}
public void setDepartment(String department) {
this.department = department;
}
public EmployeeLocation getLocation() {
return location;
}
public void setLocation(EmployeeLocation location) {
this.location = location;
}
}

View File

@ -0,0 +1,29 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.orm.jpa.domain;
import java.io.Serial;
import java.io.Serializable;
public class EmployeeId implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
private String name;
private String department;
}

View File

@ -0,0 +1,30 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.orm.jpa.domain;
public class EmployeeLocation {
private String location;
public String getLocation() {
return location;
}
public void setLocation(String location) {
this.location = location;
}
}

View File

@ -0,0 +1,42 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.orm.jpa.domain;
import jakarta.persistence.AttributeConverter;
import jakarta.persistence.Converter;
@Converter(autoApply = true)
public class EmployeeLocationConverter implements AttributeConverter<EmployeeLocation, String> {
@Override
public String convertToDatabaseColumn(EmployeeLocation employeeLocation) {
if (employeeLocation != null) {
return employeeLocation.getLocation();
}
return null;
}
@Override
public EmployeeLocation convertToEntityAttribute(String data) {
if (data != null) {
EmployeeLocation employeeLocation = new EmployeeLocation();
employeeLocation.setLocation(data);
return employeeLocation;
}
return null;
}
}

View File

@ -17,11 +17,15 @@
package org.springframework.orm.jpa.persistenceunit; package org.springframework.orm.jpa.persistenceunit;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.Consumer;
import javax.sql.DataSource; import javax.sql.DataSource;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.aot.test.generate.compile.Compiled; import org.springframework.aot.test.generate.compile.Compiled;
import org.springframework.aot.test.generate.compile.TestCompiler; import org.springframework.aot.test.generate.compile.TestCompiler;
@ -35,7 +39,12 @@ import org.springframework.core.io.ResourceLoader;
import org.springframework.orm.jpa.JpaVendorAdapter; import org.springframework.orm.jpa.JpaVendorAdapter;
import org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean; import org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean;
import org.springframework.orm.jpa.domain.DriversLicense; import org.springframework.orm.jpa.domain.DriversLicense;
import org.springframework.orm.jpa.domain.Employee;
import org.springframework.orm.jpa.domain.EmployeeId;
import org.springframework.orm.jpa.domain.EmployeeLocation;
import org.springframework.orm.jpa.domain.EmployeeLocationConverter;
import org.springframework.orm.jpa.domain.Person; import org.springframework.orm.jpa.domain.Person;
import org.springframework.orm.jpa.domain.PersonListener;
import org.springframework.orm.jpa.vendor.Database; import org.springframework.orm.jpa.vendor.Database;
import org.springframework.orm.jpa.vendor.HibernateJpaVendorAdapter; import org.springframework.orm.jpa.vendor.HibernateJpaVendorAdapter;
@ -46,6 +55,7 @@ import static org.mockito.Mockito.mock;
* Tests for {@link PersistenceManagedTypesBeanRegistrationAotProcessor}. * Tests for {@link PersistenceManagedTypesBeanRegistrationAotProcessor}.
* *
* @author Stephane Nicoll * @author Stephane Nicoll
* @author Sebastien Deleuze
*/ */
class PersistenceManagedTypesBeanRegistrationAotProcessorTests { class PersistenceManagedTypesBeanRegistrationAotProcessorTests {
@ -59,13 +69,38 @@ class PersistenceManagedTypesBeanRegistrationAotProcessorTests {
PersistenceManagedTypes persistenceManagedTypes = freshApplicationContext.getBean( PersistenceManagedTypes persistenceManagedTypes = freshApplicationContext.getBean(
"persistenceManagedTypes", PersistenceManagedTypes.class); "persistenceManagedTypes", PersistenceManagedTypes.class);
assertThat(persistenceManagedTypes.getManagedClassNames()).containsExactlyInAnyOrder( assertThat(persistenceManagedTypes.getManagedClassNames()).containsExactlyInAnyOrder(
DriversLicense.class.getName(), Person.class.getName()); DriversLicense.class.getName(), Person.class.getName(), Employee.class.getName(),
EmployeeLocationConverter.class.getName());
assertThat(persistenceManagedTypes.getManagedPackages()).isEmpty(); assertThat(persistenceManagedTypes.getManagedPackages()).isEmpty();
assertThat(freshApplicationContext.getBean( assertThat(freshApplicationContext.getBean(
EntityManagerWithPackagesToScanConfiguration.class).scanningInvoked).isFalse(); EntityManagerWithPackagesToScanConfiguration.class).scanningInvoked).isFalse();
}); });
} }
@Test
void contributeHints() {
GenericApplicationContext context = new AnnotationConfigApplicationContext();
context.registerBean(EntityManagerWithPackagesToScanConfiguration.class);
contributeHints(context, hints -> {
assertThat(RuntimeHintsPredicates.reflection().onType(DriversLicense.class)
.withMemberCategories(MemberCategory.DECLARED_FIELDS)).accepts(hints);
assertThat(RuntimeHintsPredicates.reflection().onType(Person.class)
.withMemberCategories(MemberCategory.DECLARED_FIELDS)).accepts(hints);
assertThat(RuntimeHintsPredicates.reflection().onType(PersonListener.class)
.withMemberCategories(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_PUBLIC_METHODS))
.accepts(hints);
assertThat(RuntimeHintsPredicates.reflection().onType(Employee.class)
.withMemberCategories(MemberCategory.DECLARED_FIELDS)).accepts(hints);
assertThat(RuntimeHintsPredicates.reflection().onType(EmployeeId.class)
.withMemberCategories(MemberCategory.DECLARED_FIELDS)).accepts(hints);
assertThat(RuntimeHintsPredicates.reflection().onType(EmployeeLocationConverter.class)
.withMemberCategories(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_PUBLIC_METHODS))
.accepts(hints);
assertThat(RuntimeHintsPredicates.reflection().onType(EmployeeLocation.class)
.withMemberCategories(MemberCategory.DECLARED_FIELDS)).accepts(hints);
});
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private void compile(GenericApplicationContext applicationContext, private void compile(GenericApplicationContext applicationContext,
@ -86,6 +121,13 @@ class PersistenceManagedTypesBeanRegistrationAotProcessorTests {
return freshApplicationContext; return freshApplicationContext;
} }
private void contributeHints(GenericApplicationContext applicationContext, Consumer<RuntimeHints> result) {
ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator();
TestGenerationContext generationContext = new TestGenerationContext();
generator.processAheadOfTime(applicationContext, generationContext);
result.accept(generationContext.getRuntimeHints());
}
@Configuration(proxyBeanMethods = false) @Configuration(proxyBeanMethods = false)
public static class EntityManagerWithPackagesToScanConfiguration { public static class EntityManagerWithPackagesToScanConfiguration {

View File

@ -22,6 +22,8 @@ import org.springframework.context.testfixture.index.CandidateComponentsTestClas
import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.orm.jpa.domain.DriversLicense; import org.springframework.orm.jpa.domain.DriversLicense;
import org.springframework.orm.jpa.domain.Employee;
import org.springframework.orm.jpa.domain.EmployeeLocationConverter;
import org.springframework.orm.jpa.domain.Person; import org.springframework.orm.jpa.domain.Person;
import org.springframework.orm.jpa.domain2.entity.User; import org.springframework.orm.jpa.domain2.entity.User;
@ -40,7 +42,8 @@ class PersistenceManagedTypesScannerTests {
void scanPackageWithOnlyEntities() { void scanPackageWithOnlyEntities() {
PersistenceManagedTypes managedTypes = this.scanner.scan("org.springframework.orm.jpa.domain"); PersistenceManagedTypes managedTypes = this.scanner.scan("org.springframework.orm.jpa.domain");
assertThat(managedTypes.getManagedClassNames()).containsExactlyInAnyOrder( assertThat(managedTypes.getManagedClassNames()).containsExactlyInAnyOrder(
Person.class.getName(), DriversLicense.class.getName()); Person.class.getName(), DriversLicense.class.getName(), Employee.class.getName(),
EmployeeLocationConverter.class.getName());
assertThat(managedTypes.getManagedPackages()).isEmpty(); assertThat(managedTypes.getManagedPackages()).isEmpty();
} }