From 3b2b36d0b8e275807c54586946e8dac4fd248d40 Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Mon, 3 Oct 2022 10:49:25 +0200 Subject: [PATCH] Allow AccessControl to determine visibility from a given type This commit adapts AccessVisibility so that it can determine if the member or type signature is accessible from a given package. This lets implementers figure out if reflection is necessary without assuming that package private visibility is OK. Closes gh-29245 --- .../AutowiredAnnotationBeanPostProcessor.java | 38 +-- .../DefaultBeanRegistrationCodeFragments.java | 8 +- .../aot/InstanceSupplierCodeGenerator.java | 25 +- ...nBeanRegistrationAotContributionTests.java | 59 +++- .../TestBeanWithPackagePrivateField.java | 23 ++ .../TestBeanWithPackagePrivateMethod.java | 28 ++ .../PackagePrivateFieldInjectionSample.java | 3 +- .../PackagePrivateMethodInjectionSample.java | 6 +- .../PrivateFieldInjectionSample.java | 3 +- .../PrivateMethodInjectionSample.java | 3 +- ...PrivateFieldInjectionFromParentSample.java | 25 ++ ...rivateMethodInjectionFromParentSample.java | 24 ++ .../aot/generate/AccessControl.java | 275 ++++++++++++++++++ .../aot/generate/AccessVisibility.java | 184 ------------ .../aot/generate/AccessControlTests.java | 227 +++++++++++++++ .../aot/generate/AccessVisibilityTests.java | 175 ----------- .../jpa/support/InjectionCodeGenerator.java | 19 +- ...ersistenceAnnotationBeanPostProcessor.java | 13 +- .../support/InjectionCodeGeneratorTests.java | 118 ++++++-- 19 files changed, 816 insertions(+), 440 deletions(-) create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateField.java create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateMethod.java rename spring-beans/src/{test/java/org/springframework => testFixtures/java/org/springframework/beans/testfixture}/beans/factory/annotation/PackagePrivateFieldInjectionSample.java (85%) rename spring-beans/src/{test/java/org/springframework => testFixtures/java/org/springframework/beans/testfixture}/beans/factory/annotation/PackagePrivateMethodInjectionSample.java (83%) rename spring-beans/src/{test/java/org/springframework => testFixtures/java/org/springframework/beans/testfixture}/beans/factory/annotation/PrivateFieldInjectionSample.java (85%) rename spring-beans/src/{test/java/org/springframework => testFixtures/java/org/springframework/beans/testfixture}/beans/factory/annotation/PrivateMethodInjectionSample.java (86%) create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/subpkg/PackagePrivateFieldInjectionFromParentSample.java create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/subpkg/PackagePrivateMethodInjectionFromParentSample.java create mode 100644 spring-core/src/main/java/org/springframework/aot/generate/AccessControl.java delete mode 100644 spring-core/src/main/java/org/springframework/aot/generate/AccessVisibility.java create mode 100644 spring-core/src/test/java/org/springframework/aot/generate/AccessControlTests.java delete mode 100644 spring-core/src/test/java/org/springframework/aot/generate/AccessVisibilityTests.java diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java index 92dec61c232..eabf7380b57 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanPostProcessor.java @@ -40,7 +40,7 @@ import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.aot.generate.AccessVisibility; +import org.springframework.aot.generate.AccessControl; import org.springframework.aot.generate.GeneratedClass; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GenerationContext; @@ -81,6 +81,7 @@ import org.springframework.core.annotation.AnnotationAttributes; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.core.annotation.MergedAnnotations; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -941,7 +942,8 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); method.addParameter(this.target, INSTANCE_PARAMETER); method.returns(this.target); - method.addCode(generateMethodCode(generationContext.getRuntimeHints())); + method.addCode(generateMethodCode(generatedClass.getName(), + generationContext.getRuntimeHints())); }); beanRegistrationCode.addInstancePostProcessor(generateMethod.toMethodReference()); @@ -950,41 +952,42 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA } } - private CodeBlock generateMethodCode(RuntimeHints hints) { + private CodeBlock generateMethodCode(ClassName targetClassName, RuntimeHints hints) { CodeBlock.Builder code = CodeBlock.builder(); for (AutowiredElement autowiredElement : this.autowiredElements) { - code.addStatement( - generateMethodStatementForElement(autowiredElement, hints)); + code.addStatement(generateMethodStatementForElement( + targetClassName, autowiredElement, hints)); } code.addStatement("return $L", INSTANCE_PARAMETER); return code.build(); } - private CodeBlock generateMethodStatementForElement( + private CodeBlock generateMethodStatementForElement(ClassName targetClassName, AutowiredElement autowiredElement, RuntimeHints hints) { Member member = autowiredElement.getMember(); boolean required = autowiredElement.required; if (member instanceof Field field) { - return generateMethodStatementForField(field, required, hints); + return generateMethodStatementForField( + targetClassName, field, required, hints); } if (member instanceof Method method) { - return generateMethodStatementForMethod(method, required, hints); + return generateMethodStatementForMethod( + targetClassName, method, required, hints); } throw new IllegalStateException( "Unsupported member type " + member.getClass().getName()); } - private CodeBlock generateMethodStatementForField(Field field, boolean required, - RuntimeHints hints) { + private CodeBlock generateMethodStatementForField(ClassName targetClassName, + Field field, boolean required, RuntimeHints hints) { hints.reflection().registerField(field); CodeBlock resolver = CodeBlock.of("$T.$L($S)", AutowiredFieldValueResolver.class, (!required) ? "forField" : "forRequiredField", field.getName()); - AccessVisibility visibility = AccessVisibility.forMember(field); - if (visibility == AccessVisibility.PRIVATE - || visibility == AccessVisibility.PROTECTED) { + AccessControl accessControl = AccessControl.forMember(field); + if (!accessControl.isAccessibleFrom(targetClassName)) { return CodeBlock.of("$L.resolveAndSet($L, $L)", resolver, REGISTERED_BEAN_PARAMETER, INSTANCE_PARAMETER); } @@ -992,8 +995,8 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA field.getName(), resolver, REGISTERED_BEAN_PARAMETER); } - private CodeBlock generateMethodStatementForMethod(Method method, - boolean required, RuntimeHints hints) { + private CodeBlock generateMethodStatementForMethod(ClassName targetClassName, + Method method, boolean required, RuntimeHints hints) { CodeBlock.Builder code = CodeBlock.builder(); code.add("$T.$L", AutowiredMethodArgumentsResolver.class, @@ -1004,9 +1007,8 @@ public class AutowiredAnnotationBeanPostProcessor implements SmartInstantiationA generateParameterTypesCode(method.getParameterTypes())); } code.add(")"); - AccessVisibility visibility = AccessVisibility.forMember(method); - if (visibility == AccessVisibility.PRIVATE - || visibility == AccessVisibility.PROTECTED) { + AccessControl accessControl = AccessControl.forMember(method); + if (!accessControl.isAccessibleFrom(targetClassName)) { hints.reflection().registerMethod(method, ExecutableMode.INVOKE); code.add(".resolveAndInvoke($L, $L)", REGISTERED_BEAN_PARAMETER, INSTANCE_PARAMETER); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index f756f210888..756bd21b157 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -21,7 +21,7 @@ import java.lang.reflect.Executable; import java.util.List; import java.util.function.Predicate; -import org.springframework.aot.generate.AccessVisibility; +import org.springframework.aot.generate.AccessControl; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; @@ -85,9 +85,9 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme private Class extractDeclaringClass(ResolvableType beanType, Executable executable) { Class declaringClass = ClassUtils.getUserClass(executable.getDeclaringClass()); - if (executable instanceof Constructor && - AccessVisibility.forMember(executable) == AccessVisibility.PUBLIC && - FactoryBean.class.isAssignableFrom(declaringClass)) { + if (executable instanceof Constructor + && AccessControl.forMember(executable).isPublic() + && FactoryBean.class.isAssignableFrom(declaringClass)) { return extractTargetClassFromFactoryBean(declaringClass, beanType); } return executable.getDeclaringClass(); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java index 217ac0f1e8f..7bc9e7f3612 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -24,7 +24,8 @@ import java.lang.reflect.Modifier; import java.util.Arrays; import java.util.function.Consumer; -import org.springframework.aot.generate.AccessVisibility; +import org.springframework.aot.generate.AccessControl; +import org.springframework.aot.generate.AccessControl.Visibility; import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; @@ -111,10 +112,9 @@ class InstanceSupplierCodeGenerator { Class declaringClass = ClassUtils .getUserClass(constructor.getDeclaringClass()); boolean dependsOnBean = ClassUtils.isInnerClass(declaringClass); - AccessVisibility accessVisibility = getAccessVisibility(registeredBean, - constructor); - if (accessVisibility == AccessVisibility.PUBLIC - || accessVisibility == AccessVisibility.PACKAGE_PRIVATE) { + Visibility accessVisibility = getAccessVisibility(registeredBean, constructor); + if (accessVisibility == Visibility.PUBLIC + || accessVisibility == Visibility.PACKAGE_PRIVATE) { return generateCodeForAccessibleConstructor(beanName, beanClass, constructor, dependsOnBean, declaringClass); } @@ -207,10 +207,9 @@ class InstanceSupplierCodeGenerator { Class declaringClass = ClassUtils .getUserClass(factoryMethod.getDeclaringClass()); boolean dependsOnBean = !Modifier.isStatic(factoryMethod.getModifiers()); - AccessVisibility accessVisibility = getAccessVisibility(registeredBean, - factoryMethod); - if (accessVisibility == AccessVisibility.PUBLIC - || accessVisibility == AccessVisibility.PACKAGE_PRIVATE) { + Visibility accessVisibility = getAccessVisibility(registeredBean, factoryMethod); + if (accessVisibility == Visibility.PUBLIC + || accessVisibility == Visibility.PACKAGE_PRIVATE) { return generateCodeForAccessibleFactoryMethod(beanName, beanClass, factoryMethod, declaringClass, dependsOnBean); } @@ -314,13 +313,13 @@ class InstanceSupplierCodeGenerator { return code.build(); } - protected AccessVisibility getAccessVisibility(RegisteredBean registeredBean, + private Visibility getAccessVisibility(RegisteredBean registeredBean, Member member) { - AccessVisibility beanTypeAccessVisibility = AccessVisibility + AccessControl beanTypeAccessControl = AccessControl .forResolvableType(registeredBean.getBeanType()); - AccessVisibility memberAccessVisibility = AccessVisibility.forMember(member); - return AccessVisibility.lowest(beanTypeAccessVisibility, memberAccessVisibility); + AccessControl memberAccessControl = AccessControl.forMember(member); + return AccessControl.lowest(beanTypeAccessControl, memberAccessControl).getVisibility(); } private CodeBlock generateParameterTypesCode(Class[] parameterTypes, int offset) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java index 6ef922f5d41..81773b39737 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java @@ -31,11 +31,18 @@ import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.testfixture.beans.factory.annotation.PackagePrivateFieldInjectionSample; +import org.springframework.beans.testfixture.beans.factory.annotation.PackagePrivateMethodInjectionSample; +import org.springframework.beans.testfixture.beans.factory.annotation.PrivateFieldInjectionSample; +import org.springframework.beans.testfixture.beans.factory.annotation.PrivateMethodInjectionSample; +import org.springframework.beans.testfixture.beans.factory.annotation.subpkg.PackagePrivateFieldInjectionFromParentSample; +import org.springframework.beans.testfixture.beans.factory.annotation.subpkg.PackagePrivateMethodInjectionFromParentSample; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationCode; import org.springframework.core.env.Environment; import org.springframework.core.env.StandardEnvironment; import org.springframework.core.test.tools.CompileWithForkedClassLoader; import org.springframework.core.test.tools.Compiled; +import org.springframework.core.test.tools.SourceFile; import org.springframework.core.test.tools.TestCompiler; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; @@ -79,7 +86,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { PrivateFieldInjectionSample instance = new PrivateFieldInjectionSample(); postProcessor.apply(registeredBean, instance); assertThat(instance).extracting("environment").isSameAs(environment); - assertThat(compiled.getSourceFileFromPackage(getClass().getPackageName())) + assertThat(getSourceFile(compiled, PrivateFieldInjectionSample.class)) .contains("resolveAndSet("); }); } @@ -98,11 +105,30 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { PackagePrivateFieldInjectionSample instance = new PackagePrivateFieldInjectionSample(); postProcessor.apply(registeredBean, instance); assertThat(instance).extracting("environment").isSameAs(environment); - assertThat(compiled.getSourceFileFromPackage(getClass().getPackageName())) + assertThat(getSourceFile(compiled, PackagePrivateFieldInjectionSample.class)) .contains("instance.environment ="); }); } + @Test + @CompileWithForkedClassLoader + void contributeWhenPackagePrivateFieldInjectionOnParentClassInjectsUsingReflection() { + Environment environment = new StandardEnvironment(); + this.beanFactory.registerSingleton("environment", environment); + RegisteredBean registeredBean = getAndApplyContribution( + PackagePrivateFieldInjectionFromParentSample.class); + assertThat(RuntimeHintsPredicates.reflection() + .onField(PackagePrivateFieldInjectionSample.class, "environment")) + .accepts(this.generationContext.getRuntimeHints()); + compile(registeredBean, (postProcessor, compiled) -> { + PackagePrivateFieldInjectionFromParentSample instance = new PackagePrivateFieldInjectionFromParentSample(); + postProcessor.apply(registeredBean, instance); + assertThat(instance).extracting("environment").isSameAs(environment); + assertThat(getSourceFile(compiled, PackagePrivateFieldInjectionFromParentSample.class)) + .contains("resolveAndSet"); + }); + } + @Test void contributeWhenPrivateMethodInjectionInjectsUsingReflection() { Environment environment = new StandardEnvironment(); @@ -116,7 +142,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { PrivateMethodInjectionSample instance = new PrivateMethodInjectionSample(); postProcessor.apply(registeredBean, instance); assertThat(instance).extracting("environment").isSameAs(environment); - assertThat(compiled.getSourceFileFromPackage(getClass().getPackageName())) + assertThat(getSourceFile(compiled, PrivateMethodInjectionSample.class)) .contains("resolveAndInvoke("); }); } @@ -134,12 +160,31 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { compile(registeredBean, (postProcessor, compiled) -> { PackagePrivateMethodInjectionSample instance = new PackagePrivateMethodInjectionSample(); postProcessor.apply(registeredBean, instance); - assertThat(instance).extracting("environment").isSameAs(environment); - assertThat(compiled.getSourceFileFromPackage(getClass().getPackageName())) + assertThat(instance.environment).isSameAs(environment); + assertThat(getSourceFile(compiled, PackagePrivateMethodInjectionSample.class)) .contains("args -> instance.setTestBean("); }); } + @Test + @CompileWithForkedClassLoader + void contributeWhenPackagePrivateMethodInjectionOnParentClassInjectsUsingReflection() { + Environment environment = new StandardEnvironment(); + this.beanFactory.registerSingleton("environment", environment); + RegisteredBean registeredBean = getAndApplyContribution( + PackagePrivateMethodInjectionFromParentSample.class); + assertThat(RuntimeHintsPredicates.reflection() + .onMethod(PackagePrivateMethodInjectionSample.class, "setTestBean")) + .accepts(this.generationContext.getRuntimeHints()); + compile(registeredBean, (postProcessor, compiled) -> { + PackagePrivateMethodInjectionFromParentSample instance = new PackagePrivateMethodInjectionFromParentSample(); + postProcessor.apply(registeredBean, instance); + assertThat(instance.environment).isSameAs(environment); + assertThat(getSourceFile(compiled, PackagePrivateMethodInjectionFromParentSample.class)) + .contains("resolveAndInvoke("); + }); + } + private RegisteredBean getAndApplyContribution(Class beanClass) { RegisteredBean registeredBean = registerBean(beanClass); BeanRegistrationAotContribution contribution = new AutowiredAnnotationBeanPostProcessor() @@ -156,6 +201,10 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { return RegisteredBean.of(this.beanFactory, beanName); } + private static SourceFile getSourceFile(Compiled compiled, Class sample) { + return compiled.getSourceFileFromPackage(sample.getPackageName()); + } + @SuppressWarnings("unchecked") private void compile(RegisteredBean registeredBean, BiConsumer, Compiled> result) { diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateField.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateField.java new file mode 100644 index 00000000000..d7cb24e27f0 --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateField.java @@ -0,0 +1,23 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans; + +public class TestBeanWithPackagePrivateField { + + int age; + +} diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateMethod.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateMethod.java new file mode 100644 index 00000000000..6c6e3ee4c7c --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/TestBeanWithPackagePrivateMethod.java @@ -0,0 +1,28 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans; + +@SuppressWarnings("unused") +public class TestBeanWithPackagePrivateMethod { + + private int age; + + void setAge(int age) { + this.age = age; + } + +} diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/PackagePrivateFieldInjectionSample.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PackagePrivateFieldInjectionSample.java similarity index 85% rename from spring-beans/src/test/java/org/springframework/beans/factory/annotation/PackagePrivateFieldInjectionSample.java rename to spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PackagePrivateFieldInjectionSample.java index 9616e8c3225..b1452ba360d 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/PackagePrivateFieldInjectionSample.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PackagePrivateFieldInjectionSample.java @@ -14,8 +14,9 @@ * limitations under the License. */ -package org.springframework.beans.factory.annotation; +package org.springframework.beans.testfixture.beans.factory.annotation; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.env.Environment; public class PackagePrivateFieldInjectionSample { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/PackagePrivateMethodInjectionSample.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PackagePrivateMethodInjectionSample.java similarity index 83% rename from spring-beans/src/test/java/org/springframework/beans/factory/annotation/PackagePrivateMethodInjectionSample.java rename to spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PackagePrivateMethodInjectionSample.java index 010065d1225..2ffcf726e72 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/PackagePrivateMethodInjectionSample.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PackagePrivateMethodInjectionSample.java @@ -14,14 +14,14 @@ * limitations under the License. */ -package org.springframework.beans.factory.annotation; +package org.springframework.beans.testfixture.beans.factory.annotation; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.env.Environment; public class PackagePrivateMethodInjectionSample { - @SuppressWarnings("unused") - private Environment environment; + public Environment environment; @Autowired void setTestBean(Environment environment) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/PrivateFieldInjectionSample.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PrivateFieldInjectionSample.java similarity index 85% rename from spring-beans/src/test/java/org/springframework/beans/factory/annotation/PrivateFieldInjectionSample.java rename to spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PrivateFieldInjectionSample.java index bbe50557ccf..9aa2c69187d 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/PrivateFieldInjectionSample.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PrivateFieldInjectionSample.java @@ -14,8 +14,9 @@ * limitations under the License. */ -package org.springframework.beans.factory.annotation; +package org.springframework.beans.testfixture.beans.factory.annotation; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.env.Environment; public class PrivateFieldInjectionSample { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/PrivateMethodInjectionSample.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PrivateMethodInjectionSample.java similarity index 86% rename from spring-beans/src/test/java/org/springframework/beans/factory/annotation/PrivateMethodInjectionSample.java rename to spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PrivateMethodInjectionSample.java index bdd48f2dc56..d741e69b4a4 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/PrivateMethodInjectionSample.java +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/PrivateMethodInjectionSample.java @@ -14,8 +14,9 @@ * limitations under the License. */ -package org.springframework.beans.factory.annotation; +package org.springframework.beans.testfixture.beans.factory.annotation; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.env.Environment; public class PrivateMethodInjectionSample { diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/subpkg/PackagePrivateFieldInjectionFromParentSample.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/subpkg/PackagePrivateFieldInjectionFromParentSample.java new file mode 100644 index 00000000000..33e51f6e1f2 --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/subpkg/PackagePrivateFieldInjectionFromParentSample.java @@ -0,0 +1,25 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans.factory.annotation.subpkg; + +import org.springframework.beans.testfixture.beans.factory.annotation.PackagePrivateFieldInjectionSample; + +public class PackagePrivateFieldInjectionFromParentSample extends PackagePrivateFieldInjectionSample { + + // see environment from parent + +} diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/subpkg/PackagePrivateMethodInjectionFromParentSample.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/subpkg/PackagePrivateMethodInjectionFromParentSample.java new file mode 100644 index 00000000000..f2afef824b1 --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/annotation/subpkg/PackagePrivateMethodInjectionFromParentSample.java @@ -0,0 +1,24 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans.factory.annotation.subpkg; + +import org.springframework.beans.testfixture.beans.factory.annotation.PackagePrivateMethodInjectionSample; + +public class PackagePrivateMethodInjectionFromParentSample extends PackagePrivateMethodInjectionSample { + + // see setTestBean from parent +} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/AccessControl.java b/spring-core/src/main/java/org/springframework/aot/generate/AccessControl.java new file mode 100644 index 00000000000..3f4cd401851 --- /dev/null +++ b/spring-core/src/main/java/org/springframework/aot/generate/AccessControl.java @@ -0,0 +1,275 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; +import java.lang.reflect.Field; +import java.lang.reflect.Member; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import java.util.function.IntFunction; + +import org.springframework.core.ResolvableType; +import org.springframework.javapoet.ClassName; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; + +/** + * Determine the access control of a {@link Member} or type signature. + * + * @author Stephane Nicoll + * @author Phillip Webb + * @since 6.0 + */ +public final class AccessControl { + + private final Class target; + + private final Visibility visibility; + + AccessControl(Class target, Visibility visibility) { + this.target = target; + this.visibility = visibility; + } + + /** + * Create an {@link AccessControl} for the given member. This considers the + * member modifier, parameter types, return types and any enclosing classes. + * The lowest overall {@link Visibility} is used. + * @param member the source member + * @return the {@link AccessControl} for the member + */ + public static AccessControl forMember(Member member) { + return new AccessControl(member.getDeclaringClass(), Visibility.forMember(member)); + } + + /** + * Create an {@link AccessControl} for the given {@link ResolvableType}. + * This considers the type itself as well as any generics. + * @param resolvableType the source resolvable type + * @return the {@link AccessControl} for the type + */ + public static AccessControl forResolvableType(ResolvableType resolvableType) { + return new AccessControl(resolvableType.toClass(), + Visibility.forResolvableType(resolvableType)); + } + + /** + * Create an {@link AccessControl} for the given {@link Class}. + * @param type the source class + * @return the {@link AccessControl} for the class + */ + public static AccessControl forClass(Class type) { + return new AccessControl(type, Visibility.forClass(type)); + } + + /** + * Returns the lowest {@link AccessControl} from the given candidates. + * @param candidates the candidates to check + * @return the lowest {@link AccessControl} from the candidates + */ + public static AccessControl lowest(AccessControl... candidates) { + int index = Visibility.lowestIndex(Arrays.stream(candidates) + .map(AccessControl::getVisibility).toArray(Visibility[]::new)); + return candidates[index]; + } + + /** + * Return the lowest {@link Visibility} of this instance. + * @return the visibility + */ + public Visibility getVisibility() { + return this.visibility; + } + + /** + * Return whether the member or type signature backed by ths instance is + * accessible from any package. + * @return {@code true} if it is public + */ + public boolean isPublic() { + return this.visibility == Visibility.PUBLIC; + } + + /** + * Specify whether the member or type signature backed by this instance is + * accessible from the specified {@link ClassName}. + * @param type the type to check + * @return {@code true} if it is accessible + */ + public boolean isAccessibleFrom(ClassName type) { + if (this.visibility == Visibility.PRIVATE) { + return false; + } + if (this.visibility == Visibility.PUBLIC) { + return true; + } + return this.target.getPackageName().equals(type.packageName()); + } + + /** + * Access visibility types as determined by the modifiers + * on a {@link Member} or {@link ResolvableType}. + */ + public enum Visibility { + + /** + * Public visibility. The member or type is visible to all classes. + */ + PUBLIC, + + /** + * Protected visibility. The member or type is only visible to classes + * in the same package or subclasses. + */ + PROTECTED, + + /** + * Package-private visibility. The member or type is only visible to classes + * in the same package. + */ + PACKAGE_PRIVATE, + + /** + * Private visibility. The member or type is not visible to other classes. + */ + PRIVATE; + + + private static Visibility forMember(Member member) { + Assert.notNull(member, "'member' must not be null"); + Visibility visibility = forModifiers(member.getModifiers()); + Visibility declaringClassVisibility = forClass(member.getDeclaringClass()); + visibility = lowest(visibility, declaringClassVisibility); + if (visibility != PRIVATE) { + if (member instanceof Field field) { + Visibility fieldVisibility = forResolvableType( + ResolvableType.forField(field)); + return lowest(visibility, fieldVisibility); + } + if (member instanceof Constructor constructor) { + Visibility parameterVisibility = forParameterTypes(constructor, + i -> ResolvableType.forConstructorParameter(constructor, i)); + return lowest(visibility, parameterVisibility); + } + if (member instanceof Method method) { + Visibility parameterVisibility = forParameterTypes(method, + i -> ResolvableType.forMethodParameter(method, i)); + Visibility returnTypeVisibility = forResolvableType( + ResolvableType.forMethodReturnType(method)); + return lowest(visibility, parameterVisibility, returnTypeVisibility); + } + } + return PRIVATE; + } + + private static Visibility forResolvableType(ResolvableType resolvableType) { + return forResolvableType(resolvableType, new HashSet<>()); + } + + private static Visibility forResolvableType(ResolvableType resolvableType, + Set seen) { + if (!seen.add(resolvableType)) { + return Visibility.PUBLIC; + } + Class userClass = ClassUtils.getUserClass(resolvableType.toClass()); + ResolvableType userType = resolvableType.as(userClass); + Visibility visibility = forClass(userType.toClass()); + for (ResolvableType generic : userType.getGenerics()) { + visibility = lowest(visibility, forResolvableType(generic, seen)); + } + return visibility; + } + + private static Visibility forParameterTypes(Executable executable, + IntFunction resolvableTypeFactory) { + Visibility visibility = Visibility.PUBLIC; + Class[] parameterTypes = executable.getParameterTypes(); + for (int i = 0; i < parameterTypes.length; i++) { + ResolvableType type = resolvableTypeFactory.apply(i); + visibility = lowest(visibility, forResolvableType(type)); + } + return visibility; + } + + private static Visibility forClass(Class clazz) { + clazz = ClassUtils.getUserClass(clazz); + Visibility visibility = forModifiers(clazz.getModifiers()); + if (clazz.isArray()) { + visibility = lowest(visibility, forClass(clazz.getComponentType())); + } + Class enclosingClass = clazz.getEnclosingClass(); + if (enclosingClass != null) { + visibility = lowest(visibility, forClass(clazz.getEnclosingClass())); + } + return visibility; + } + + private static Visibility forModifiers(int modifiers) { + if (Modifier.isPublic(modifiers)) { + return PUBLIC; + } + if (Modifier.isProtected(modifiers)) { + return PROTECTED; + } + if (Modifier.isPrivate(modifiers)) { + return PRIVATE; + } + return PACKAGE_PRIVATE; + } + + /** + * Returns the lowest {@link Visibility} from the given candidates. + * @param candidates the candidates to check + * @return the lowest {@link Visibility} from the candidates + */ + static Visibility lowest(Visibility... candidates) { + Visibility visibility = PUBLIC; + for (Visibility candidate : candidates) { + if (candidate.ordinal() > visibility.ordinal()) { + visibility = candidate; + } + } + return visibility; + } + + /** + * Returns the index of the lowest {@link Visibility} from the given + * candidates. + * @param candidates the candidates to check + * @return the index of the lowest {@link Visibility} from the candidates + */ + static int lowestIndex(Visibility... candidates) { + Visibility visibility = PUBLIC; + int index = 0; + for (int i = 0; i < candidates.length; i++) { + Visibility candidate = candidates[i]; + if (candidate.ordinal() > visibility.ordinal()) { + visibility = candidate; + index = i; + } + } + return index; + } + + } +} diff --git a/spring-core/src/main/java/org/springframework/aot/generate/AccessVisibility.java b/spring-core/src/main/java/org/springframework/aot/generate/AccessVisibility.java deleted file mode 100644 index a042945f17d..00000000000 --- a/spring-core/src/main/java/org/springframework/aot/generate/AccessVisibility.java +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import java.lang.reflect.Constructor; -import java.lang.reflect.Executable; -import java.lang.reflect.Field; -import java.lang.reflect.Member; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.util.HashSet; -import java.util.Set; -import java.util.function.IntFunction; - -import org.springframework.core.ResolvableType; -import org.springframework.util.Assert; -import org.springframework.util.ClassUtils; - -/** - * Access visibility types as determined by the modifiers - * on a {@link Member} or {@link ResolvableType}. - * - * @author Phillip Webb - * @author Stephane Nicoll - * @since 6.0 - * @see #forMember(Member) - * @see #forResolvableType(ResolvableType) - */ -public enum AccessVisibility { - - /** - * Public visibility. The member or type is visible to all classes. - */ - PUBLIC, - - /** - * Protected visibility. The member or type is only visible to subclasses. - */ - PROTECTED, - - /** - * Package-private visibility. The member or type is only visible to classes - * in the same package. - */ - PACKAGE_PRIVATE, - - /** - * Private visibility. The member or type is not visible to other classes. - */ - PRIVATE; - - - /** - * Determine the {@link AccessVisibility} for the given member. This method - * will consider the member modifier, parameter types, return types and any - * enclosing classes. The lowest overall visibility will be returned. - * @param member the source member - * @return the {@link AccessVisibility} for the member - */ - public static AccessVisibility forMember(Member member) { - Assert.notNull(member, "'member' must not be null"); - AccessVisibility visibility = forModifiers(member.getModifiers()); - AccessVisibility declaringClassVisibility = forClass(member.getDeclaringClass()); - visibility = lowest(visibility, declaringClassVisibility); - if (visibility != PRIVATE) { - if (member instanceof Field field) { - AccessVisibility fieldVisibility = forResolvableType( - ResolvableType.forField(field)); - return lowest(visibility, fieldVisibility); - } - if (member instanceof Constructor constructor) { - AccessVisibility parameterVisibility = forParameterTypes(constructor, - i -> ResolvableType.forConstructorParameter(constructor, i)); - return lowest(visibility, parameterVisibility); - } - if (member instanceof Method method) { - AccessVisibility parameterVisibility = forParameterTypes(method, - i -> ResolvableType.forMethodParameter(method, i)); - AccessVisibility returnTypeVisibility = forResolvableType( - ResolvableType.forMethodReturnType(method)); - return lowest(visibility, parameterVisibility, returnTypeVisibility); - } - } - return PRIVATE; - } - - /** - * Determine the {@link AccessVisibility} for the given - * {@link ResolvableType}. This method will consider the type itself as well - * as any generics. - * @param resolvableType the source resolvable type - * @return the {@link AccessVisibility} for the type - */ - public static AccessVisibility forResolvableType(ResolvableType resolvableType) { - return forResolvableType(resolvableType, new HashSet<>()); - } - - private static AccessVisibility forResolvableType(ResolvableType resolvableType, - Set seen) { - if (!seen.add(resolvableType)) { - return AccessVisibility.PUBLIC; - } - Class userClass = ClassUtils.getUserClass(resolvableType.toClass()); - ResolvableType userType = resolvableType.as(userClass); - AccessVisibility visibility = forClass(userType.toClass()); - for (ResolvableType generic : userType.getGenerics()) { - visibility = lowest(visibility, forResolvableType(generic, seen)); - } - return visibility; - } - - private static AccessVisibility forParameterTypes(Executable executable, - IntFunction resolvableTypeFactory) { - AccessVisibility visibility = AccessVisibility.PUBLIC; - Class[] parameterTypes = executable.getParameterTypes(); - for (int i = 0; i < parameterTypes.length; i++) { - ResolvableType type = resolvableTypeFactory.apply(i); - visibility = lowest(visibility, forResolvableType(type)); - } - return visibility; - } - - /** - * Determine the {@link AccessVisibility} for the given {@link Class}. - * @param clazz the source class - * @return the {@link AccessVisibility} for the class - */ - public static AccessVisibility forClass(Class clazz) { - clazz = ClassUtils.getUserClass(clazz); - AccessVisibility visibility = forModifiers(clazz.getModifiers()); - if (clazz.isArray()) { - visibility = lowest(visibility, forClass(clazz.getComponentType())); - } - Class enclosingClass = clazz.getEnclosingClass(); - if (enclosingClass != null) { - visibility = lowest(visibility, forClass(clazz.getEnclosingClass())); - } - return visibility; - } - - private static AccessVisibility forModifiers(int modifiers) { - if (Modifier.isPublic(modifiers)) { - return PUBLIC; - } - if (Modifier.isProtected(modifiers)) { - return PROTECTED; - } - if (Modifier.isPrivate(modifiers)) { - return PRIVATE; - } - return PACKAGE_PRIVATE; - } - - /** - * Returns the lowest {@link AccessVisibility} put of the given candidates. - * @param candidates the candidates to check - * @return the lowest {@link AccessVisibility} from the candidates - */ - public static AccessVisibility lowest(AccessVisibility... candidates) { - AccessVisibility visibility = PUBLIC; - for (AccessVisibility candidate : candidates) { - if (candidate.ordinal() > visibility.ordinal()) { - visibility = candidate; - } - } - return visibility; - } - -} diff --git a/spring-core/src/test/java/org/springframework/aot/generate/AccessControlTests.java b/spring-core/src/test/java/org/springframework/aot/generate/AccessControlTests.java new file mode 100644 index 00000000000..53bfb4acaeb --- /dev/null +++ b/spring-core/src/test/java/org/springframework/aot/generate/AccessControlTests.java @@ -0,0 +1,227 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.generate; + +import java.lang.reflect.Field; +import java.lang.reflect.Member; +import java.lang.reflect.Method; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.AccessControl.Visibility; +import org.springframework.core.ResolvableType; +import org.springframework.core.testfixture.aot.generator.visibility.ProtectedGenericParameter; +import org.springframework.core.testfixture.aot.generator.visibility.ProtectedParameter; +import org.springframework.core.testfixture.aot.generator.visibility.PublicFactoryBean; +import org.springframework.javapoet.ClassName; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link AccessControl}. + * + * @author Phillip Webb + * @author Stephane Nicoll + */ +class AccessControlTests { + + @Test + void isAccessibleWhenPublicVisibilityInSamePackage() { + AccessControl accessControl = new AccessControl(PublicClass.class, Visibility.PUBLIC); + assertThat(accessControl.isAccessibleFrom(ClassName.get(PublicClass.class))).isTrue(); + } + + @Test + void isAccessibleWhenPublicVisibilityInDifferentPackage() { + AccessControl accessControl = new AccessControl(PublicClass.class, Visibility.PUBLIC); + assertThat(accessControl.isAccessibleFrom(ClassName.get(String.class))).isTrue(); + } + + @Test + void isAccessibleWhenProtectedVisibilityInSamePackage() { + AccessControl accessControl = new AccessControl(PublicClass.class, Visibility.PROTECTED); + assertThat(accessControl.isAccessibleFrom(ClassName.get(PublicClass.class))).isTrue(); + } + + @Test + void isAccessibleWhenProtectedVisibilityInDifferentPackage() { + AccessControl accessControl = new AccessControl(PublicClass.class, Visibility.PROTECTED); + assertThat(accessControl.isAccessibleFrom(ClassName.get(String.class))).isFalse(); + } + + @Test + void isAccessibleWhenPackagePrivateVisibilityInSamePackage() { + AccessControl accessControl = new AccessControl(PublicClass.class, Visibility.PACKAGE_PRIVATE); + assertThat(accessControl.isAccessibleFrom(ClassName.get(PublicClass.class))).isTrue(); + } + + @Test + void isAccessibleWhenPackagePrivateVisibilityInDifferentPackage() { + AccessControl accessControl = new AccessControl(PublicClass.class, Visibility.PACKAGE_PRIVATE); + assertThat(accessControl.isAccessibleFrom(ClassName.get(String.class))).isFalse(); + } + + @Test + void isAccessibleWhenPrivateVisibilityInSamePackage() { + AccessControl accessControl = new AccessControl(PublicClass.class, Visibility.PRIVATE); + assertThat(accessControl.isAccessibleFrom(ClassName.get(PublicClass.class))).isFalse(); + } + + @Test + void isAccessibleWhenPrivateVisibilityInDifferentPackage() { + AccessControl accessControl = new AccessControl(PublicClass.class, Visibility.PRIVATE); + assertThat(accessControl.isAccessibleFrom(ClassName.get(String.class))).isFalse(); + } + + @Test + void forMemberWhenPublicConstructor() throws NoSuchMethodException { + Member member = PublicClass.class.getConstructor(); + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PUBLIC); + } + + @Test + void forMemberWhenPackagePrivateConstructor() { + Member member = ProtectedAccessor.class.getDeclaredConstructors()[0]; + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forMemberWhenPackagePrivateClassWithPublicConstructor() { + Member member = PackagePrivateClass.class.getDeclaredConstructors()[0]; + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forMemberWhenPackagePrivateClassWithPublicMethod() { + Member member = method(PackagePrivateClass.class, "stringBean"); + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forMemberWhenPublicClassWithPackagePrivateConstructorParameter() { + Member member = ProtectedParameter.class.getConstructors()[0]; + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forMemberWhenPublicClassWithPackagePrivateGenericOnConstructorParameter() { + Member member = ProtectedGenericParameter.class.getConstructors()[0]; + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forMemberWhenPublicClassWithPackagePrivateMethod() { + Member member = method(PublicClass.class, "getProtectedMethod"); + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forMemberWhenPublicClassWithPackagePrivateMethodReturnType() { + Member member = method(ProtectedAccessor.class, "methodWithProtectedReturnType"); + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forMemberWhenPublicClassWithPackagePrivateMethodParameter() { + Member member = method(ProtectedAccessor.class, "methodWithProtectedParameter", + PackagePrivateClass.class); + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forMemberWhenPublicClassWithPackagePrivateField() { + Field member = field(PublicClass.class, "protectedField"); + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forMemberWhenPublicClassWithPublicFieldAndPackagePrivateFieldType() { + Member member = field(PublicClass.class, "protectedClassField"); + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forMemberWhenPublicClassWithPrivateField() { + Member member = field(PublicClass.class, "privateField"); + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PRIVATE); + } + + @Test + void forMemberWhenPublicClassWithPublicMethodAndPackagePrivateGenericOnReturnType() { + Member member = method(PublicFactoryBean.class, "protectedTypeFactoryBean"); + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forMemberWhenPublicClassWithPackagePrivateArrayComponent() { + Member member = field(PublicClass.class, "packagePrivateClasses"); + AccessControl accessControl = AccessControl.forMember(member); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forResolvableTypeWhenPackagePrivateGeneric() { + ResolvableType resolvableType = PublicFactoryBean + .resolveToProtectedGenericParameter(); + AccessControl accessControl = AccessControl.forResolvableType(resolvableType); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + @Test + void forResolvableTypeWhenRecursiveType() { + ResolvableType resolvableType = ResolvableType + .forClassWithGenerics(SelfReference.class, SelfReference.class); + AccessControl accessControl = AccessControl.forResolvableType(resolvableType); + assertThat(accessControl.getVisibility()).isEqualTo(Visibility.PACKAGE_PRIVATE); + } + + + private static Method method(Class type, String name, Class... parameterTypes) { + Method method = ReflectionUtils.findMethod(type, name, parameterTypes); + assertThat(method).isNotNull(); + return method; + } + + private static Field field(Class type, String name) { + Field field = ReflectionUtils.findField(type, name); + assertThat(field).isNotNull(); + return field; + } + + static class SelfReference> { + + @SuppressWarnings({ "unchecked", "unused" }) + T getThis() { + return (T) this; + } + + } +} diff --git a/spring-core/src/test/java/org/springframework/aot/generate/AccessVisibilityTests.java b/spring-core/src/test/java/org/springframework/aot/generate/AccessVisibilityTests.java deleted file mode 100644 index 0b45d6827e9..00000000000 --- a/spring-core/src/test/java/org/springframework/aot/generate/AccessVisibilityTests.java +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Copyright 2002-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.aot.generate; - -import java.lang.reflect.Field; -import java.lang.reflect.Member; -import java.lang.reflect.Method; - -import org.junit.jupiter.api.Test; - -import org.springframework.core.ResolvableType; -import org.springframework.core.testfixture.aot.generator.visibility.ProtectedGenericParameter; -import org.springframework.core.testfixture.aot.generator.visibility.ProtectedParameter; -import org.springframework.core.testfixture.aot.generator.visibility.PublicFactoryBean; -import org.springframework.util.ReflectionUtils; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Tests for {@link AccessVisibility}. - * - * @author Phillip Webb - * @author Stephane Nicoll - */ -class AccessVisibilityTests { - - @Test - void forMemberWhenPublicConstructor() throws NoSuchMethodException { - Member member = PublicClass.class.getConstructor(); - assertThat(AccessVisibility.forMember(member)).isEqualTo(AccessVisibility.PUBLIC); - } - - @Test - void forMemberWhenPackagePrivateConstructor() { - Member member = ProtectedAccessor.class.getDeclaredConstructors()[0]; - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPackagePrivateClassWithPublicConstructor() { - Member member = PackagePrivateClass.class.getDeclaredConstructors()[0]; - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPackagePrivateClassWithPublicMethod() { - Member member = method(PackagePrivateClass.class, "stringBean"); - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPublicClassWithPackagePrivateConstructorParameter() { - Member member = ProtectedParameter.class.getConstructors()[0]; - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPublicClassWithPackagePrivateGenericOnConstructorParameter() { - Member member = ProtectedGenericParameter.class.getConstructors()[0]; - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPublicClassWithPackagePrivateMethod() { - Member member = method(PublicClass.class, "getProtectedMethod"); - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPublicClassWithPackagePrivateMethodReturnType() { - Member member = method(ProtectedAccessor.class, "methodWithProtectedReturnType"); - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPublicClassWithPackagePrivateMethodParameter() { - Member member = method(ProtectedAccessor.class, "methodWithProtectedParameter", - PackagePrivateClass.class); - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPublicClassWithPackagePrivateField() { - Field member = field(PublicClass.class, "protectedField"); - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPublicClassWithPublicFieldAndPackagePrivateFieldType() { - Member member = field(PublicClass.class, "protectedClassField"); - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPublicClassWithPublicMethodAndPackagePrivateGenericOnReturnType() { - Member member = method(PublicFactoryBean.class, "protectedTypeFactoryBean"); - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPublicClassWithPackagePrivateArrayComponent() { - Member member = field(PublicClass.class, "packagePrivateClasses"); - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forResolvableTypeWhenPackagePrivateGeneric() { - ResolvableType resolvableType = PublicFactoryBean - .resolveToProtectedGenericParameter(); - assertThat(AccessVisibility.forResolvableType(resolvableType)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forResolvableTypeWhenRecursiveType() { - ResolvableType resolvableType = ResolvableType - .forClassWithGenerics(SelfReference.class, SelfReference.class); - assertThat(AccessVisibility.forResolvableType(resolvableType)) - .isEqualTo(AccessVisibility.PACKAGE_PRIVATE); - } - - @Test - void forMemberWhenPublicClassWithPrivateField() { - Member member = field(PublicClass.class, "privateField"); - assertThat(AccessVisibility.forMember(member)) - .isEqualTo(AccessVisibility.PRIVATE); - } - - private static Method method(Class type, String name, Class... parameterTypes) { - Method method = ReflectionUtils.findMethod(type, name, parameterTypes); - assertThat(method).isNotNull(); - return method; - } - - private static Field field(Class type, String name) { - Field field = ReflectionUtils.findField(type, name); - assertThat(field).isNotNull(); - return field; - } - - static class SelfReference> { - - @SuppressWarnings("unchecked") - T getThis() { - return (T) this; - } - - } -} diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/support/InjectionCodeGenerator.java b/spring-orm/src/main/java/org/springframework/orm/jpa/support/InjectionCodeGenerator.java index 3121df3dd18..bd16b52b5cc 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/support/InjectionCodeGenerator.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/support/InjectionCodeGenerator.java @@ -20,9 +20,10 @@ import java.lang.reflect.Field; import java.lang.reflect.Member; import java.lang.reflect.Method; -import org.springframework.aot.generate.AccessVisibility; +import org.springframework.aot.generate.AccessControl; import org.springframework.aot.hint.ExecutableMode; import org.springframework.aot.hint.RuntimeHints; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; @@ -45,11 +46,15 @@ import org.springframework.util.ReflectionUtils; */ class InjectionCodeGenerator { + private final ClassName targetClassName; + private final RuntimeHints hints; - InjectionCodeGenerator(RuntimeHints hints) { + InjectionCodeGenerator(ClassName targetClassName, RuntimeHints hints) { + Assert.notNull(hints, "TargetClassName must not be null"); Assert.notNull(hints, "Hints must not be null"); + this.targetClassName = targetClassName; this.hints = hints; } @@ -72,9 +77,8 @@ class InjectionCodeGenerator { CodeBlock resourceToInject) { CodeBlock.Builder code = CodeBlock.builder(); - AccessVisibility visibility = AccessVisibility.forMember(field); - if (visibility == AccessVisibility.PRIVATE - || visibility == AccessVisibility.PROTECTED) { + AccessControl accessControl = AccessControl.forMember(field); + if (!accessControl.isAccessibleFrom(this.targetClassName)) { this.hints.reflection().registerField(field); code.addStatement("$T field = $T.findField($T.class, $S)", Field.class, ReflectionUtils.class, field.getDeclaringClass(), field.getName()); @@ -95,9 +99,8 @@ class InjectionCodeGenerator { Assert.isTrue(method.getParameterCount() == 1, "Method '" + method.getName() + "' must declare a single parameter"); CodeBlock.Builder code = CodeBlock.builder(); - AccessVisibility visibility = AccessVisibility.forMember(method); - if (visibility == AccessVisibility.PRIVATE - || visibility == AccessVisibility.PROTECTED) { + AccessControl accessControl = AccessControl.forMember(method); + if (!accessControl.isAccessibleFrom(this.targetClassName)) { this.hints.reflection().registerMethod(method, ExecutableMode.INVOKE); code.addStatement("$T method = $T.findMethod($T.class, $S, $T.class)", Method.class, ReflectionUtils.class, method.getDeclaringClass(), diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java index abac448ac02..fc286851271 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java @@ -794,16 +794,17 @@ public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwar method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); method.addParameter(this.target, INSTANCE_PARAMETER); method.returns(this.target); - method.addCode(generateMethodCode(generationContext.getRuntimeHints(), generatedClass.getMethods())); + method.addCode(generateMethodCode(generationContext.getRuntimeHints(), generatedClass)); }); beanRegistrationCode.addInstancePostProcessor(generatedMethod.toMethodReference()); } - private CodeBlock generateMethodCode(RuntimeHints hints, GeneratedMethods generatedMethods) { + private CodeBlock generateMethodCode(RuntimeHints hints, GeneratedClass generatedClass) { CodeBlock.Builder code = CodeBlock.builder(); - InjectionCodeGenerator injectionCodeGenerator = new InjectionCodeGenerator(hints); + InjectionCodeGenerator injectionCodeGenerator = new InjectionCodeGenerator( + generatedClass.getName(), hints); for (InjectedElement injectedElement : this.injectedElements) { - CodeBlock resourceToInject = generateResourceToInjectCode(generatedMethods, + CodeBlock resourceToInject = generateResourceToInjectCode(generatedClass.getMethods(), (PersistenceElement) injectedElement); code.add(injectionCodeGenerator.generateInjectionCode( injectedElement.getMember(), INSTANCE_PARAMETER, @@ -823,9 +824,9 @@ public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwar EntityManagerFactoryUtils.class, ListableBeanFactory.class, REGISTERED_BEAN_PARAMETER, unitName); } - String[] methodNameParts = { "get" , unitName, "EntityManager" }; + String[] methodNameParts = { "get", unitName, "EntityManager" }; GeneratedMethod generatedMethod = generatedMethods.add(methodNameParts, method -> - generateGetEntityManagerMethod(method, injectedElement)); + generateGetEntityManagerMethod(method, injectedElement)); return CodeBlock.of("$L($L)", generatedMethod.getName(), REGISTERED_BEAN_PARAMETER); } diff --git a/spring-orm/src/test/java/org/springframework/orm/jpa/support/InjectionCodeGeneratorTests.java b/spring-orm/src/test/java/org/springframework/orm/jpa/support/InjectionCodeGeneratorTests.java index 60c41dc4ba8..5aaa3008828 100644 --- a/spring-orm/src/test/java/org/springframework/orm/jpa/support/InjectionCodeGeneratorTests.java +++ b/spring-orm/src/test/java/org/springframework/orm/jpa/support/InjectionCodeGeneratorTests.java @@ -28,10 +28,14 @@ import org.junit.jupiter.api.Test; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; import org.springframework.beans.testfixture.beans.TestBean; +import org.springframework.beans.testfixture.beans.TestBeanWithPackagePrivateField; +import org.springframework.beans.testfixture.beans.TestBeanWithPackagePrivateMethod; import org.springframework.beans.testfixture.beans.TestBeanWithPrivateMethod; import org.springframework.beans.testfixture.beans.TestBeanWithPublicField; +import org.springframework.core.test.tools.CompileWithForkedClassLoader; import org.springframework.core.test.tools.Compiled; import org.springframework.core.test.tools.TestCompiler; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.JavaFile; import org.springframework.javapoet.MethodSpec; @@ -50,17 +54,18 @@ class InjectionCodeGeneratorTests { private static final String INSTANCE_VARIABLE = "instance"; - private RuntimeHints hints = new RuntimeHints(); + private static final ClassName TEST_TARGET = ClassName.get("com.example", "Test"); - private InjectionCodeGenerator generator = new InjectionCodeGenerator(hints); + private final RuntimeHints hints = new RuntimeHints(); @Test void generateCodeWhenPublicFieldInjectsValue() { TestBeanWithPublicField bean = new TestBeanWithPublicField(); Field field = ReflectionUtils.findField(bean.getClass(), "age"); - CodeBlock generatedCode = this.generator.generateInjectionCode(field, INSTANCE_VARIABLE, - CodeBlock.of("$L", 123)); - testCompiledResult(generatedCode, TestBeanWithPublicField.class, (actual, compiled) -> { + ClassName targetClassName = TEST_TARGET; + CodeBlock generatedCode = createGenerator(targetClassName).generateInjectionCode( + field, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + testCompiledResult(targetClassName, generatedCode, TestBeanWithPublicField.class, (actual, compiled) -> { TestBeanWithPublicField instance = new TestBeanWithPublicField(); actual.accept(instance); assertThat(instance).extracting("age").isEqualTo(123); @@ -68,13 +73,45 @@ class InjectionCodeGeneratorTests { }); } + @Test + @CompileWithForkedClassLoader + void generateCodeWhenPackagePrivateFieldInTargetPackageInjectsValue() { + TestBeanWithPackagePrivateField bean = new TestBeanWithPackagePrivateField(); + Field field = ReflectionUtils.findField(bean.getClass(), "age"); + ClassName targetClassName = ClassName.get(TestBeanWithPackagePrivateField.class.getPackageName(), "Test"); + CodeBlock generatedCode = createGenerator(targetClassName).generateInjectionCode( + field, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + testCompiledResult(targetClassName, generatedCode, TestBeanWithPackagePrivateField.class, (actual, compiled) -> { + TestBeanWithPackagePrivateField instance = new TestBeanWithPackagePrivateField(); + actual.accept(instance); + assertThat(instance).extracting("age").isEqualTo(123); + assertThat(compiled.getSourceFile()).contains("instance.age = 123"); + }); + } + + @Test + void generateCodeWhenPackagePrivateFieldInAnotherPackageUsesReflection() { + TestBeanWithPackagePrivateField bean = new TestBeanWithPackagePrivateField(); + Field field = ReflectionUtils.findField(bean.getClass(), "age"); + ClassName targetClassName = TEST_TARGET; + CodeBlock generatedCode = createGenerator(targetClassName).generateInjectionCode( + field, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + testCompiledResult(targetClassName, generatedCode, TestBeanWithPackagePrivateField.class, (actual, compiled) -> { + TestBeanWithPackagePrivateField instance = new TestBeanWithPackagePrivateField(); + actual.accept(instance); + assertThat(instance).extracting("age").isEqualTo(123); + assertThat(compiled.getSourceFile()).contains("setField("); + }); + } + @Test void generateCodeWhenPrivateFieldInjectsValueUsingReflection() { TestBean bean = new TestBean(); Field field = ReflectionUtils.findField(bean.getClass(), "age"); - CodeBlock generatedCode = this.generator.generateInjectionCode(field, INSTANCE_VARIABLE, - CodeBlock.of("$L", 123)); - testCompiledResult(generatedCode, TestBean.class, (actual, compiled) -> { + ClassName targetClassName = ClassName.get(TestBean.class); + CodeBlock generatedCode = createGenerator(targetClassName).generateInjectionCode( + field, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + testCompiledResult(targetClassName, generatedCode, TestBean.class, (actual, compiled) -> { TestBean instance = new TestBean(); actual.accept(instance); assertThat(instance).extracting("age").isEqualTo(123); @@ -86,7 +123,8 @@ class InjectionCodeGeneratorTests { void generateCodeWhenPrivateFieldAddsHint() { TestBean bean = new TestBean(); Field field = ReflectionUtils.findField(bean.getClass(), "age"); - this.generator.generateInjectionCode(field, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + createGenerator(TEST_TARGET).generateInjectionCode( + field, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); assertThat(RuntimeHintsPredicates.reflection().onField(TestBean.class, "age")) .accepts(this.hints); } @@ -95,9 +133,10 @@ class InjectionCodeGeneratorTests { void generateCodeWhenPublicMethodInjectsValue() { TestBean bean = new TestBean(); Method method = ReflectionUtils.findMethod(bean.getClass(), "setAge", int.class); - CodeBlock generatedCode = this.generator.generateInjectionCode(method, INSTANCE_VARIABLE, - CodeBlock.of("$L", 123)); - testCompiledResult(generatedCode, TestBean.class, (actual, compiled) -> { + ClassName targetClassName = TEST_TARGET; + CodeBlock generatedCode = createGenerator(targetClassName).generateInjectionCode( + method, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + testCompiledResult(targetClassName, generatedCode, TestBean.class, (actual, compiled) -> { TestBean instance = new TestBean(); actual.accept(instance); assertThat(instance).extracting("age").isEqualTo(123); @@ -105,13 +144,45 @@ class InjectionCodeGeneratorTests { }); } + @Test + @CompileWithForkedClassLoader + void generateCodeWhenPackagePrivateMethodInTargetPackageInjectsValue() { + TestBeanWithPackagePrivateMethod bean = new TestBeanWithPackagePrivateMethod(); + Method method = ReflectionUtils.findMethod(bean.getClass(), "setAge", int.class); + ClassName targetClassName = ClassName.get(TestBeanWithPackagePrivateMethod.class); + CodeBlock generatedCode = createGenerator(targetClassName).generateInjectionCode( + method, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + testCompiledResult(targetClassName, generatedCode, TestBeanWithPackagePrivateMethod.class, (actual, compiled) -> { + TestBeanWithPackagePrivateMethod instance = new TestBeanWithPackagePrivateMethod(); + actual.accept(instance); + assertThat(instance).extracting("age").isEqualTo(123); + assertThat(compiled.getSourceFile()).contains("instance.setAge("); + }); + } + + @Test + void generateCodeWhenPackagePrivateMethodInAnotherPackageUsesReflection() { + TestBeanWithPackagePrivateMethod bean = new TestBeanWithPackagePrivateMethod(); + Method method = ReflectionUtils.findMethod(bean.getClass(), "setAge", int.class); + ClassName targetClassName = TEST_TARGET; + CodeBlock generatedCode = createGenerator(targetClassName).generateInjectionCode( + method, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + testCompiledResult(targetClassName, generatedCode, TestBeanWithPackagePrivateMethod.class, (actual, compiled) -> { + TestBeanWithPackagePrivateMethod instance = new TestBeanWithPackagePrivateMethod(); + actual.accept(instance); + assertThat(instance).extracting("age").isEqualTo(123); + assertThat(compiled.getSourceFile()).contains("invokeMethod("); + }); + } + @Test void generateCodeWhenPrivateMethodInjectsValueUsingReflection() { TestBeanWithPrivateMethod bean = new TestBeanWithPrivateMethod(); Method method = ReflectionUtils.findMethod(bean.getClass(), "setAge", int.class); - CodeBlock generatedCode = this.generator.generateInjectionCode(method, INSTANCE_VARIABLE, - CodeBlock.of("$L", 123)); - testCompiledResult(generatedCode, TestBeanWithPrivateMethod.class, (actual, compiled) -> { + ClassName targetClassName = ClassName.get(TestBeanWithPrivateMethod.class); + CodeBlock generatedCode = createGenerator(targetClassName) + .generateInjectionCode(method, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + testCompiledResult(targetClassName, generatedCode, TestBeanWithPrivateMethod.class, (actual, compiled) -> { TestBeanWithPrivateMethod instance = new TestBeanWithPrivateMethod(); actual.accept(instance); assertThat(instance).extracting("age").isEqualTo(123); @@ -123,26 +194,31 @@ class InjectionCodeGeneratorTests { void generateCodeWhenPrivateMethodAddsHint() { TestBeanWithPrivateMethod bean = new TestBeanWithPrivateMethod(); Method method = ReflectionUtils.findMethod(bean.getClass(), "setAge", int.class); - this.generator.generateInjectionCode(method, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + createGenerator(TEST_TARGET).generateInjectionCode( + method, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); assertThat(RuntimeHintsPredicates.reflection() .onMethod(TestBeanWithPrivateMethod.class, "setAge").invoke()).accepts(this.hints); } + private InjectionCodeGenerator createGenerator(ClassName target) { + return new InjectionCodeGenerator(target, this.hints); + } + @SuppressWarnings("unchecked") - private void testCompiledResult(CodeBlock generatedCode, Class target, + private void testCompiledResult(ClassName generatedClasName, CodeBlock generatedCode, Class target, BiConsumer, Compiled> result) { - JavaFile javaFile = createJavaFile(generatedCode, target); + JavaFile javaFile = createJavaFile(generatedClasName, generatedCode, target); TestCompiler.forSystem().compile(javaFile::writeTo, compiled -> result.accept(compiled.getInstance(Consumer.class), compiled)); } - private JavaFile createJavaFile(CodeBlock generatedCode, Class target) { - TypeSpec.Builder builder = TypeSpec.classBuilder("Injector"); + private JavaFile createJavaFile(ClassName generatedClasName, CodeBlock generatedCode, Class target) { + TypeSpec.Builder builder = TypeSpec.classBuilder(generatedClasName.simpleName() + "__Injector"); builder.addModifiers(Modifier.PUBLIC); builder.addSuperinterface(ParameterizedTypeName.get(Consumer.class, target)); builder.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) .addParameter(target, INSTANCE_VARIABLE).addCode(generatedCode).build()); - return JavaFile.builder("__", builder.build()).build(); + return JavaFile.builder(generatedClasName.packageName(), builder.build()).build(); } }