Improve TestCompiler and allow lookup based class defines

Update the `TestCompiler` so that classes can be defined using
a `Lookup`. This update allows package-private classes to be
accessed without needing a quite so unusual classloader setup.

The `@CompileWithTargetClassAccess` should be added to any
test that needs to use `Lookup` based defines. The test will
run with a completed forked classloader so not to pollute the
main classloader.

This commit also adds some useful additional APIs.

See gh-28120
This commit is contained in:
Phillip Webb 2022-04-13 16:48:13 -07:00
parent b3efdf3c2b
commit 4b82546b97
18 changed files with 622 additions and 60 deletions

View File

@ -4,6 +4,9 @@ dependencies {
api(project(":spring-core"))
api("org.assertj:assertj-core")
api("com.thoughtworks.qdox:qdox")
compileOnly("org.junit.jupiter:junit-jupiter")
compileOnly("org.junit.platform:junit-platform-engine")
compileOnly("org.junit.platform:junit-platform-launcher")
}
tasks.withType(PublishToMavenRepository).configureEach {

View File

@ -16,6 +16,11 @@
package org.springframework.aot.test.generator.compile;
import org.springframework.aot.test.generator.file.ResourceFile;
import org.springframework.aot.test.generator.file.ResourceFiles;
import org.springframework.aot.test.generator.file.SourceFile;
import org.springframework.aot.test.generator.file.SourceFiles;
/**
* Exception thrown when code cannot compile.
*
@ -25,8 +30,29 @@ package org.springframework.aot.test.generator.compile;
@SuppressWarnings("serial")
public class CompilationException extends RuntimeException {
CompilationException(String message) {
super(message);
CompilationException(String errors, SourceFiles sourceFiles, ResourceFiles resourceFiles) {
super(buildMessage(errors, sourceFiles, resourceFiles));
}
private static String buildMessage(String errors, SourceFiles sourceFiles,
ResourceFiles resourceFiles) {
StringBuilder message = new StringBuilder();
message.append("Unable to compile source\n\n");
message.append(errors);
message.append("\n\n");
for (SourceFile sourceFile : sourceFiles) {
message.append("---- source: " + sourceFile.getPath() + "\n\n");
message.append(sourceFile.getContent());
message.append("\n\n");
}
for (ResourceFile resourceFile : resourceFiles) {
message.append("---- resource: " + resourceFile.getPath() + "\n\n");
message.append(resourceFile.getContent());
message.append("\n\n");
}
return message.toString();
}
}

View File

@ -0,0 +1,56 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.aot.test.generator.compile;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodHandles.Lookup;
import org.junit.jupiter.api.extension.ExtendWith;
/**
* Annotation that can be used on tests that need a {@link TestCompiler} with
* non-public access to a target class. Allows the compiler to use
* {@link MethodHandles#privateLookupIn} to {@link Lookup#defineClass define the
* class} without polluting the test {@link ClassLoader}.
*
* @author Phillip Webb
* @since 6.0
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ ElementType.TYPE, ElementType.METHOD })
@Documented
@ExtendWith(CompileWithTargetClassAccessExtension.class)
public @interface CompileWithTargetClassAccess {
/**
* The target class names.
* @return the class name
*/
String[] classNames() default {};
/**
* The target classes.
* @return the classes
*/
Class<?>[] classes() default {};
}

View File

@ -0,0 +1,78 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.aot.test.generator.compile;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.Enumeration;
/**
* {@link ClassLoader} implementation to support
* {@link CompileWithTargetClassAccess @CompileWithTargetClassAccess}.
*
* @author Phillip Webb
* @since 6.0
*/
final class CompileWithTargetClassAccessClassLoader extends ClassLoader {
private final ClassLoader testClassLoader;
private final String[] targetClasses;
public CompileWithTargetClassAccessClassLoader(ClassLoader testClassLoader,
String[] targetClasses) {
super(testClassLoader.getParent());
this.testClassLoader = testClassLoader;
this.targetClasses = targetClasses;
}
public String[] getTargetClasses() {
return this.targetClasses;
}
@Override
public Class<?> loadClass(String name) throws ClassNotFoundException {
if (name.startsWith("org.junit") || name.startsWith("org.hamcrest")) {
return Class.forName(name, false, this.testClassLoader);
}
return super.loadClass(name);
}
@Override
protected Class<?> findClass(String name) throws ClassNotFoundException {
String resourceName = name.replace(".", "/") + ".class";
InputStream stream = this.testClassLoader.getResourceAsStream(resourceName);
if (stream != null) {
try (stream) {
byte[] bytes = stream.readAllBytes();
return defineClass(name, bytes, 0, bytes.length, null);
}
catch (IOException ex) {
}
}
return super.findClass(name);
}
@Override
protected Enumeration<URL> findResources(String name) throws IOException {
return this.testClassLoader.getResources(name);
}
}

View File

@ -0,0 +1,197 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.aot.test.generator.compile;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.InvocationInterceptor;
import org.junit.jupiter.api.extension.ReflectiveInvocationContext;
import org.junit.platform.engine.discovery.DiscoverySelectors;
import org.junit.platform.launcher.Launcher;
import org.junit.platform.launcher.LauncherDiscoveryRequest;
import org.junit.platform.launcher.TestPlan;
import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder;
import org.junit.platform.launcher.core.LauncherFactory;
import org.junit.platform.launcher.listeners.SummaryGeneratingListener;
import org.junit.platform.launcher.listeners.TestExecutionSummary;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ReflectionUtils;
/**
* JUnit {@link InvocationInterceptor} to support
* {@link CompileWithTargetClassAccess @CompileWithTargetClassAccess}.
*
* @author Christoph Dreis
* @author Phillip Webb
* @since 6.0
*/
class CompileWithTargetClassAccessExtension implements InvocationInterceptor {
@Override
public void interceptBeforeAllMethod(Invocation<Void> invocation,
ReflectiveInvocationContext<Method> invocationContext,
ExtensionContext extensionContext) throws Throwable {
intercept(invocation, extensionContext);
}
@Override
public void interceptBeforeEachMethod(Invocation<Void> invocation,
ReflectiveInvocationContext<Method> invocationContext,
ExtensionContext extensionContext) throws Throwable {
intercept(invocation, extensionContext);
}
@Override
public void interceptAfterEachMethod(Invocation<Void> invocation,
ReflectiveInvocationContext<Method> invocationContext,
ExtensionContext extensionContext) throws Throwable {
intercept(invocation, extensionContext);
}
@Override
public void interceptAfterAllMethod(Invocation<Void> invocation,
ReflectiveInvocationContext<Method> invocationContext,
ExtensionContext extensionContext) throws Throwable {
intercept(invocation, extensionContext);
}
@Override
public void interceptTestMethod(Invocation<Void> invocation,
ReflectiveInvocationContext<Method> invocationContext,
ExtensionContext extensionContext) throws Throwable {
intercept(invocation, extensionContext,
() -> runTestWithModifiedClassPath(invocationContext, extensionContext));
}
private void intercept(Invocation<Void> invocation, ExtensionContext extensionContext)
throws Throwable {
intercept(invocation, extensionContext, Action.NONE);
}
private void intercept(Invocation<Void> invocation, ExtensionContext extensionContext,
Action action) throws Throwable {
if (isUsingForkedClassPathLoader(extensionContext)) {
invocation.proceed();
return;
}
invocation.skip();
action.run();
}
private boolean isUsingForkedClassPathLoader(ExtensionContext extensionContext) {
Class<?> testClass = extensionContext.getRequiredTestClass();
ClassLoader classLoader = testClass.getClassLoader();
return classLoader.getClass().getName()
.equals(CompileWithTargetClassAccessClassLoader.class.getName());
}
private void runTestWithModifiedClassPath(
ReflectiveInvocationContext<Method> invocationContext,
ExtensionContext extensionContext) throws Throwable {
Class<?> testClass = extensionContext.getRequiredTestClass();
Method testMethod = invocationContext.getExecutable();
String[] targetClasses = getTargetClasses(testClass, testMethod);
ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader();
ClassLoader forkedClassPathClassLoader = new CompileWithTargetClassAccessClassLoader(
testClass.getClassLoader(), targetClasses);
Thread.currentThread().setContextClassLoader(forkedClassPathClassLoader);
try {
runTest(forkedClassPathClassLoader, testClass.getName(), testMethod.getName());
}
finally {
Thread.currentThread().setContextClassLoader(originalClassLoader);
}
}
private String[] getTargetClasses(AnnotatedElement... elements) {
Set<String> targetClasses = new LinkedHashSet<>();
for (AnnotatedElement element : elements) {
MergedAnnotation<?> annotation = MergedAnnotations.from(element)
.get(CompileWithTargetClassAccess.class);
if (annotation.isPresent()) {
Arrays.stream(annotation.getStringArray("classNames")).forEach(targetClasses::add);
Arrays.stream(annotation.getClassArray("classes")).map(Class::getName).forEach(targetClasses::add);
if (element instanceof Class<?> clazz) {
targetClasses.add(clazz.getName());
}
}
}
return targetClasses.toArray(String[]::new);
}
private void runTest(ClassLoader classLoader, String testClassName,
String testMethodName) throws Throwable {
Class<?> testClass = classLoader.loadClass(testClassName);
Method testMethod = findMethod(testClass, testMethodName);
LauncherDiscoveryRequest request = LauncherDiscoveryRequestBuilder.request()
.selectors(DiscoverySelectors.selectMethod(testClass, testMethod))
.build();
Launcher launcher = LauncherFactory.create();
TestPlan testPlan = launcher.discover(request);
SummaryGeneratingListener listener = new SummaryGeneratingListener();
launcher.registerTestExecutionListeners(listener);
launcher.execute(testPlan);
TestExecutionSummary summary = listener.getSummary();
if (!CollectionUtils.isEmpty(summary.getFailures())) {
throw summary.getFailures().get(0).getException();
}
}
private Method findMethod(Class<?> testClass, String testMethodName) {
Method method = ReflectionUtils.findMethod(testClass, testMethodName);
if (method == null) {
Method[] methods = ReflectionUtils.getUniqueDeclaredMethods(testClass);
for (Method candidate : methods) {
if (candidate.getName().equals(testMethodName)) {
return candidate;
}
}
}
Assert.state(method != null, () -> "Unable to find " + testClass + "." + testMethodName);
return method;
}
interface Action {
static Action NONE = () -> {
};
void run() throws Throwable;
}
}

View File

@ -74,6 +74,28 @@ public class Compiled {
return this.sourceFiles.getSingle();
}
/**
* Return the single matching source file that was compiled.
* @param pattern the pattern used to find the file
* @return the single source file
* @throws IllegalStateException if the compiler wasn't passed exactly one
* file
*/
public SourceFile getSourceFile(String pattern) {
return this.sourceFiles.getSingle(pattern);
}
/**
* Return the single source file that was compiled in the given package.
* @param packageName the package name to check
* @return the single source file
* @throws IllegalStateException if the compiler wasn't passed exactly one
* file
*/
public SourceFile getSourceFileFromPackage(String packageName) {
return this.sourceFiles.getSingleFromPackage(packageName);
}
/**
* Return all source files that were compiled.
* @return the source files used by the compiler

View File

@ -23,7 +23,7 @@ import java.lang.System.Logger;
import java.lang.System.Logger.Level;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodHandles.Lookup;
import java.lang.reflect.Modifier;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
@ -34,9 +34,9 @@ import java.util.Map;
import org.springframework.aot.test.generator.file.ResourceFile;
import org.springframework.aot.test.generator.file.ResourceFiles;
import org.springframework.aot.test.generator.file.SourceFile;
import org.springframework.aot.test.generator.file.SourceFiles;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
/**
* {@link ClassLoader} used to expose dynamically generated content.
@ -50,20 +50,15 @@ public class DynamicClassLoader extends ClassLoader {
private static final Logger logger = System.getLogger(DynamicClassLoader.class.getName());
private final SourceFiles sourceFiles;
private final ResourceFiles resourceFiles;
private final Map<String, DynamicClassFileObject> classFiles;
private final ClassLoader sourceLoader;
public DynamicClassLoader(ClassLoader parent, ResourceFiles resourceFiles,
Map<String, DynamicClassFileObject> classFiles) {
public DynamicClassLoader(ClassLoader sourceLoader, SourceFiles sourceFiles,
ResourceFiles resourceFiles, Map<String, DynamicClassFileObject> classFiles) {
super(sourceLoader.getParent());
this.sourceLoader = sourceLoader;
this.sourceFiles = sourceFiles;
super(parent);
this.resourceFiles = resourceFiles;
this.classFiles = classFiles;
}
@ -75,42 +70,43 @@ public class DynamicClassLoader extends ClassLoader {
if (classFile != null) {
return defineClass(name, classFile);
}
try {
Class<?> fromSourceLoader = this.sourceLoader.loadClass(name);
if (Modifier.isPublic(fromSourceLoader.getModifiers())) {
return fromSourceLoader;
}
}
catch (Exception ex) {
// Continue
}
try (InputStream classStream = this.sourceLoader.getResourceAsStream(name.replace(".", "/") + ".class")) {
byte[] bytes = classStream.readAllBytes();
return defineClass(name, bytes, 0, bytes.length, null);
}
catch (IOException ex) {
throw new ClassNotFoundException(name);
}
return super.findClass(name);
}
private Class<?> defineClass(String name, DynamicClassFileObject classFile) {
byte[] bytes = classFile.getBytes();
SourceFile sourceFile = this.sourceFiles.get(name);
if (sourceFile != null && sourceFile.getTarget() != null) {
Class<?> targetClass = getTargetClass(name);
if (targetClass != null) {
try {
Lookup lookup = MethodHandles.privateLookupIn(sourceFile.getTarget(),
MethodHandles.lookup());
Lookup lookup = MethodHandles.privateLookupIn(targetClass, MethodHandles.lookup());
return lookup.defineClass(bytes);
}
catch (IllegalAccessException ex) {
logger.log(Level.WARNING,
"Unable to define class using MethodHandles Lookup, "
+ "only public methods and classes will be accessible");
logger.log(Level.WARNING, "Unable to define class using MethodHandles Lookup, "
+ "only public methods and classes will be accessible");
}
}
return defineClass(name, bytes, 0, bytes.length, null);
}
private Class<?> getTargetClass(String name) {
ClassLoader parentClassLoader = getParent();
if (parentClassLoader.getClass().getName()
.equals(CompileWithTargetClassAccessClassLoader.class.getName())) {
String packageName = ClassUtils.getPackageName(name);
Method method = ReflectionUtils.findMethod(parentClassLoader.getClass(), "getTargetClasses");
ReflectionUtils.makeAccessible(method);
String[] targetCasses = (String[]) ReflectionUtils.invokeMethod(method, parentClassLoader);
for (String targetClass : targetCasses) {
String targetPackageName = ClassUtils.getPackageName(targetClass);
if (targetPackageName.equals(packageName)) {
return ClassUtils.resolveClassName(targetClass, this);
}
}
}
return null;
}
@Override
protected Enumeration<URL> findResources(String name) throws IOException {
URL resource = findResource(name);

View File

@ -16,6 +16,7 @@
package org.springframework.aot.test.generator.compile;
import java.io.PrintStream;
import java.util.List;
import java.util.Locale;
import java.util.function.Consumer;
@ -92,6 +93,16 @@ public final class TestCompiler {
this.sourceFiles.and(sourceFiles), this.resourceFiles);
}
/**
* Return a new {@link TestCompiler} instance with addition source files.
* @param sourceFiles the additional source files
* @return a new {@link TestCompiler} instance
*/
public TestCompiler withSources(Iterable<SourceFile> sourceFiles) {
return new TestCompiler(this.classLoader, this.compiler,
this.sourceFiles.and(sourceFiles), this.resourceFiles);
}
/**
* Return a new {@link TestCompiler} instance with addition source files.
* @param sourceFiles the additional source files
@ -112,6 +123,16 @@ public final class TestCompiler {
this.resourceFiles.and(resourceFiles));
}
/**
* Return a new {@link TestCompiler} instance with addition source files.
* @param resourceFiles the additional source files
* @return a new {@link TestCompiler} instance
*/
public TestCompiler withResources(Iterable<ResourceFile> resourceFiles) {
return new TestCompiler(this.classLoader, this.compiler, this.sourceFiles,
this.resourceFiles.and(resourceFiles));
}
/**
* Return a new {@link TestCompiler} instance with addition resource files.
* @param resourceFiles the additional resource files
@ -179,8 +200,11 @@ public final class TestCompiler {
ClassLoader previousClassLoader = Thread.currentThread().getContextClassLoader();
try {
Thread.currentThread().setContextClassLoader(dynamicClassLoader);
compiled.accept(new Compiled(dynamicClassLoader, this.sourceFiles,
this.resourceFiles));
compiled.accept(new Compiled(dynamicClassLoader, this.sourceFiles, this.resourceFiles));
}
catch (IllegalAccessError ex) {
throw new IllegalAccessError(ex.getMessage() + ". " +
"For non-public access ensure annotate your tests with @CompileWithTargetClassAccess");
}
finally {
Thread.currentThread().setContextClassLoader(previousClassLoader);
@ -202,11 +226,30 @@ public final class TestCompiler {
null, compilationUnits);
boolean result = task.call();
if (!result || errors.hasReportedErrors()) {
throw new CompilationException("Unable to compile source" + errors);
throw new CompilationException(errors.toString(), this.sourceFiles, this.resourceFiles);
}
}
return new DynamicClassLoader(classLoaderToUse, this.sourceFiles,
this.resourceFiles, fileManager.getClassFiles());
return new DynamicClassLoader(classLoaderToUse, this.resourceFiles, fileManager.getClassFiles());
}
/**
* Print the contents of the source and resource files to the specified
* {@link PrintStream}.
* @param printStream the destination print stream
* @return this instance
*/
public TestCompiler printFiles(PrintStream printStream) {
for (SourceFile sourceFile : this.sourceFiles) {
printStream.append("---- source: " + sourceFile.getPath() + "\n\n");
printStream.append(sourceFile.getContent());
printStream.append("\n\n");
}
for (ResourceFile resourceFile : this.resourceFiles) {
printStream.append("---- resource: " + resourceFile.getPath() + "\n\n");
printStream.append(resourceFile.getContent());
printStream.append("\n\n");
}
return this;
}

View File

@ -44,6 +44,11 @@ public class DynamicFileAssert<A extends DynamicFileAssert<A, F>, F extends Dyna
return this.myself;
}
public A doesNotContain(CharSequence... values) {
assertThat(this.actual.getContent()).doesNotContain(values);
return this.myself;
}
public A isEqualTo(@Nullable Object expected) {
if (expected instanceof DynamicFile) {
return super.isEqualTo(expected);

View File

@ -20,12 +20,13 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Stream;
import org.springframework.lang.Nullable;
/**
* Internal class used by {@link SourceFiles} and {@link ResourceFiles} to
* manage {@link DynamicFile} instances.
@ -54,6 +55,12 @@ final class DynamicFiles<F extends DynamicFile> implements Iterable<F> {
return (DynamicFiles<F>) NONE;
}
DynamicFiles<F> and(Iterable<F> files) {
Map<String, F> merged = new LinkedHashMap<>(this.files);
files.forEach(file -> merged.put(file.getPath(), file));
return new DynamicFiles<>(Collections.unmodifiableMap(merged));
}
DynamicFiles<F> and(F[] files) {
Map<String, F> merged = new LinkedHashMap<>(this.files);
Arrays.stream(files).forEach(file -> merged.put(file.getPath(), file));
@ -85,10 +92,15 @@ final class DynamicFiles<F extends DynamicFile> implements Iterable<F> {
}
F getSingle() {
if (this.files.size() != 1) {
return getSingle(candidate -> true);
}
F getSingle(Predicate<F> filter) {
List<F> files = this.files.values().stream().filter(filter).toList();
if (files.size() != 1) {
throw new IllegalStateException("No single file available");
}
return this.files.values().iterator().next();
return files.iterator().next();
}
@Override

View File

@ -16,8 +16,15 @@
package org.springframework.aot.test.generator.file;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import org.assertj.core.api.AssertProvider;
import org.springframework.core.io.InputStreamSource;
import org.springframework.lang.Nullable;
import org.springframework.util.FileCopyUtils;
/**
* {@link DynamicFile} that holds resource file content and provides
* {@link ResourceFileAssert} support.
@ -46,6 +53,20 @@ public final class ResourceFile extends DynamicFile
return new ResourceFile(path, charSequence.toString());
}
/**
* Factory method to create a new {@link ResourceFile} from the given
* {@link InputStreamSource}.
* @param path the relative path of the file or {@code null} to have the
* path deduced
* @param inputStreamSource the source for the file
* @return a {@link SourceFile} instance
*/
public static ResourceFile of(@Nullable String path,
InputStreamSource inputStreamSource) {
return of(path, appendable -> appendable.append(FileCopyUtils.copyToString(
new InputStreamReader(inputStreamSource.getInputStream(), StandardCharsets.UTF_8))));
}
/**
* Factory method to create a new {@link SourceFile} from the given
* {@link WritableContent}.

View File

@ -21,7 +21,6 @@ import java.util.stream.Stream;
import org.springframework.lang.Nullable;
/**
* An immutable collection of {@link ResourceFile} instances.
*
@ -62,11 +61,21 @@ public final class ResourceFiles implements Iterable<ResourceFile> {
/**
* Return a new {@link ResourceFiles} instance that merges files from
* another array of {@link ResourceFile} instances.
* @param ResourceFiles the instances to merge
* @param resourceFiles the instances to merge
* @return a new {@link ResourceFiles} instance containing merged content
*/
public ResourceFiles and(ResourceFile... ResourceFiles) {
return new ResourceFiles(this.files.and(ResourceFiles));
public ResourceFiles and(ResourceFile... resourceFiles) {
return new ResourceFiles(this.files.and(resourceFiles));
}
/**
* Return a new {@link ResourceFiles} instance that merges files from another iterable
* of {@link ResourceFiles} instances.
* @param resourceFiles the instances to merge
* @return a new {@link ResourceFiles} instance containing merged content
*/
public ResourceFiles and(Iterable<ResourceFile> resourceFiles) {
return new ResourceFiles(this.files.and(resourceFiles));
}
/**

View File

@ -16,7 +16,10 @@
package org.springframework.aot.test.generator.file;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import com.thoughtworks.qdox.JavaProjectBuilder;
import com.thoughtworks.qdox.model.JavaClass;
@ -25,7 +28,9 @@ import com.thoughtworks.qdox.model.JavaSource;
import org.assertj.core.api.AssertProvider;
import org.assertj.core.util.Strings;
import org.springframework.core.io.InputStreamSource;
import org.springframework.lang.Nullable;
import org.springframework.util.FileCopyUtils;
/**
* {@link DynamicFile} that holds Java source code and provides
@ -59,7 +64,7 @@ public final class SourceFile extends DynamicFile
* @return a {@link SourceFile} instance
*/
public static SourceFile of(CharSequence charSequence) {
return of(null, appendable -> appendable.append(charSequence));
return of(null, charSequence);
}
/**
@ -74,6 +79,33 @@ public final class SourceFile extends DynamicFile
return of(path, appendable -> appendable.append(charSequence));
}
/**
* Factory method to create a new {@link SourceFile} from the given
* {@link InputStreamSource}.
* @param inputStreamSource the source for the file
* @return a {@link SourceFile} instance
*/
public static SourceFile of(InputStreamSource inputStreamSource) {
return of(null, inputStreamSource);
}
/**
* Factory method to create a new {@link SourceFile} from the given
* {@link InputStreamSource}.
* @param path the relative path of the file or {@code null} to have the
* path deduced
* @param inputStreamSource the source for the file
* @return a {@link SourceFile} instance
*/
public static SourceFile of(@Nullable String path, InputStreamSource inputStreamSource) {
return of(path, appendable -> appendable.append(copyToString(inputStreamSource)));
}
private static String copyToString(InputStreamSource inputStreamSource) throws IOException {
InputStreamReader reader = new InputStreamReader(inputStreamSource.getInputStream(), StandardCharsets.UTF_8);
return FileCopyUtils.copyToString(reader);
}
/**
* Factory method to create a new {@link SourceFile} from the given
* {@link WritableContent}.
@ -134,15 +166,9 @@ public final class SourceFile extends DynamicFile
}
/**
* Return the target class for this source file or {@code null}. The target
* class can be used if private lookup access is required.
* @return the target class
* Return the class name of the source file.
* @return the class name
*/
@Nullable
public Class<?> getTarget() {
return null; // Not yet supported
}
public String getClassName() {
return this.javaSource.getClasses().get(0).getFullyQualifiedName();
}

View File

@ -17,6 +17,8 @@
package org.springframework.aot.test.generator.file;
import java.util.Iterator;
import java.util.Objects;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import org.springframework.lang.Nullable;
@ -68,6 +70,16 @@ public final class SourceFiles implements Iterable<SourceFile> {
return new SourceFiles(this.files.and(sourceFiles));
}
/**
* Return a new {@link SourceFiles} instance that merges files from another
* array of {@link SourceFile} instances.
* @param sourceFiles the instances to merge
* @return a new {@link SourceFiles} instance containing merged content
*/
public SourceFiles and(Iterable<SourceFile> sourceFiles) {
return new SourceFiles(this.files.and(sourceFiles));
}
/**
* Return a new {@link SourceFiles} instance that merges files from another
* {@link SourceFiles} instance.
@ -120,6 +132,32 @@ public final class SourceFiles implements Iterable<SourceFile> {
return this.files.getSingle();
}
/**
* Return the single matching source file contained in the collection.
* @return the single file
* @throws IllegalStateException if the collection doesn't contain exactly
* one file
*/
public SourceFile getSingle(String pattern) throws IllegalStateException {
return getSingle(Pattern.compile(pattern));
}
private SourceFile getSingle(Pattern pattern) {
return this.files.getSingle(
candidate -> pattern.matcher(candidate.getClassName()).matches());
}
/**
* Return a single source file contained in the specified package.
* @return the single file
* @throws IllegalStateException if the collection doesn't contain exactly
* one file
*/
public SourceFile getSingleFromPackage(String packageName) {
return this.files.getSingle(candidate -> Objects.equals(packageName,
candidate.getJavaSource().getPackageName()));
}
@Override
public boolean equals(Object obj) {
if (this == obj) {

View File

@ -18,6 +18,9 @@ package org.springframework.aot.test.generator.compile;
import org.junit.jupiter.api.Test;
import org.springframework.aot.test.generator.file.ResourceFiles;
import org.springframework.aot.test.generator.file.SourceFiles;
import static org.assertj.core.api.Assertions.assertThat;
@ -31,8 +34,8 @@ class CompilationExceptionTests {
@Test
void getMessageReturnsMessage() {
CompilationException exception = new CompilationException("message");
assertThat(exception).hasMessage("message");
CompilationException exception = new CompilationException("message", SourceFiles.none(), ResourceFiles.none());
assertThat(exception).hasMessageContaining("message");
}
}

View File

@ -26,6 +26,7 @@ import org.springframework.aot.test.generator.file.ResourceFiles;
import org.springframework.aot.test.generator.file.SourceFile;
import org.springframework.aot.test.generator.file.SourceFiles;
import org.springframework.aot.test.generator.file.WritableContent;
import org.springframework.util.ClassUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -170,7 +171,8 @@ class TestCompilerTests {
}
@Test
void compiledCodeCanAccessExistingPackagePrivateClass() {
@CompileWithTargetClassAccess(classNames = "com.example.PackagePrivate")
void compiledCodeCanAccessExistingPackagePrivateClassIfAnnotated() throws ClassNotFoundException, LinkageError {
SourceFiles sourceFiles = SourceFiles.of(SourceFile.of("""
package com.example;
@ -187,6 +189,26 @@ class TestCompilerTests {
.isEqualTo("Hello from PackagePrivate"));
}
@Test
void compiledCodeCannotAccessExistingPackagePrivateClassIfNotAnnotated() {
SourceFiles sourceFiles = SourceFiles.of(SourceFile.of("""
package com.example;
public class Test implements PublicInterface {
public String perform() {
return new PackagePrivate().perform();
}
}
"""));
assertThatExceptionOfType(IllegalAccessError.class)
.isThrownBy(() -> TestCompiler.forSystem().compile(sourceFiles,
compiled -> compiled.getInstance(PublicInterface.class, "com.example.Test").perform()))
.withMessageContaining(ClassUtils.getShortName(CompileWithTargetClassAccess.class));
}
private void assertSuppliesHelloWorld(Compiled compiled) {
assertThat(compiled.getInstance(Supplier.class).get()).isEqualTo("Hello World!");
}

View File

@ -32,6 +32,7 @@ import org.springframework.aot.generator.DefaultGeneratedTypeContext;
import org.springframework.aot.generator.GeneratedType;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;
import org.springframework.aot.test.generator.compile.CompileWithTargetClassAccess;
import org.springframework.aot.test.generator.compile.TestCompiler;
import org.springframework.aot.test.generator.file.SourceFile;
import org.springframework.aot.test.generator.file.SourceFiles;
@ -126,6 +127,7 @@ class PersistenceAnnotationBeanPostProcessorTests {
}
@Test
@CompileWithTargetClassAccess(classes = DefaultPersistenceUnitField.class)
void generateEntityManagerFactoryInjection() {
GenericApplicationContext context = new AnnotationConfigApplicationContext();
context.registerBeanDefinition("test", new RootBeanDefinition(DefaultPersistenceUnitField.class));

View File

@ -38,6 +38,9 @@
<suppress files="ResolvableType" checks="FinalClass"/>
<suppress files="[\\/]src[\\/]testFixtures[\\/]java[\\/].+" checks="IllegalImport" id="bannedJUnitJupiterImports"/>
<!-- spring-core-test -->
<suppress files="CompileWithTargetClassAccess" checks="IllegalImport" id="bannedJUnitJupiterImports" />
<!-- spring-expression -->
<suppress files="ExpressionException" checks="MutableException"/>
<suppress files="SpelMessage" checks="JavadocVariable|JavadocStyle"/>