diff --git a/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/CompileWithTargetClassAccessClassLoader.java b/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/CompileWithTargetClassAccessClassLoader.java index c944995bb24..cfc90bc9770 100644 --- a/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/CompileWithTargetClassAccessClassLoader.java +++ b/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/CompileWithTargetClassAccessClassLoader.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.io.InputStream; import java.net.URL; import java.util.Enumeration; +import java.util.function.Function; import org.springframework.lang.Nullable; @@ -34,12 +35,19 @@ final class CompileWithTargetClassAccessClassLoader extends ClassLoader { private final ClassLoader testClassLoader; + private Function classResourceLookup = name -> null; + public CompileWithTargetClassAccessClassLoader(ClassLoader testClassLoader) { super(testClassLoader.getParent()); this.testClassLoader = testClassLoader; } + // Invoked reflectively by DynamicClassLoader constructor + @SuppressWarnings("unused") + void setClassResourceLookup(Function classResourceLookup) { + this.classResourceLookup = classResourceLookup; + } @Override public Class loadClass(String name) throws ClassNotFoundException { @@ -51,25 +59,36 @@ final class CompileWithTargetClassAccessClassLoader extends ClassLoader { @Override protected Class findClass(String name) throws ClassNotFoundException { + byte[] bytes = findClassBytes(name); + return (bytes != null) ? defineClass(name, bytes, 0, bytes.length, null) : super.findClass(name); + } + + @Nullable + private byte[] findClassBytes(String name) { + byte[] bytes = this.classResourceLookup.apply(name); + if (bytes != null) { + return bytes; + } String resourceName = name.replace(".", "/") + ".class"; InputStream stream = this.testClassLoader.getResourceAsStream(resourceName); if (stream != null) { try (stream) { - byte[] bytes = stream.readAllBytes(); - return defineClass(name, bytes, 0, bytes.length, null); + return stream.readAllBytes(); } catch (IOException ex) { + // ignore } } - return super.findClass(name); + return null; } - // Invoked reflectively by DynamicClassLoader.findDefineClassMethod(ClassLoader) + @SuppressWarnings("unused") Class defineClassWithTargetAccess(String name, byte[] b, int off, int len) { return super.defineClass(name, b, off, len); } + @Override protected Enumeration findResources(String name) throws IOException { return this.testClassLoader.getResources(name); diff --git a/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicClassFileObject.java b/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicClassFileObject.java index 35be0ccb973..e598a737778 100644 --- a/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicClassFileObject.java +++ b/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicClassFileObject.java @@ -16,7 +16,10 @@ package org.springframework.aot.test.generate.compile; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; import java.net.URI; @@ -31,23 +34,42 @@ import javax.tools.SimpleJavaFileObject; */ class DynamicClassFileObject extends SimpleJavaFileObject { - private volatile byte[] bytes = new byte[0]; + private static final byte[] NO_BYTES = new byte[0]; + + private final String className; + + private volatile byte[] bytes; DynamicClassFileObject(String className) { + this(className, NO_BYTES); + } + + DynamicClassFileObject(String className, byte[] bytes) { super(URI.create("class:///" + className.replace('.', '/') + ".class"), Kind.CLASS); + this.className = className; + this.bytes = bytes; } - @Override - public OutputStream openOutputStream() { - return new JavaClassOutputStream(); + String getClassName() { + return this.className; } byte[] getBytes() { return this.bytes; } + @Override + public InputStream openInputStream() throws IOException { + return new ByteArrayInputStream(this.bytes); + } + + @Override + public OutputStream openOutputStream() { + return new JavaClassOutputStream(); + } + class JavaClassOutputStream extends ByteArrayOutputStream { diff --git a/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicClassLoader.java b/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicClassLoader.java index 36cf2372955..4b8c080de43 100644 --- a/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicClassLoader.java +++ b/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicClassLoader.java @@ -26,6 +26,8 @@ import java.net.URLConnection; import java.net.URLStreamHandler; import java.nio.charset.StandardCharsets; import java.util.Enumeration; +import java.util.Map; +import java.util.function.Function; import org.springframework.aot.test.generate.file.ClassFile; import org.springframework.aot.test.generate.file.ClassFiles; @@ -47,47 +49,65 @@ public class DynamicClassLoader extends ClassLoader { private final ClassFiles classFiles; + private final Map compiledClasses; + @Nullable private final Method defineClassMethod; public DynamicClassLoader(ClassLoader parent, ResourceFiles resourceFiles, - ClassFiles classFiles) { + ClassFiles classFiles, Map compiledClasses) { super(parent); this.resourceFiles = resourceFiles; this.classFiles = classFiles; - this.defineClassMethod = findDefineClassMethod(parent); - if (this.defineClassMethod != null) { - classFiles.forEach(this::defineClass); - } - } - - @Nullable - private Method findDefineClassMethod(ClassLoader parent) { + this.compiledClasses = compiledClasses; Class parentClass = parent.getClass(); if (parentClass.getName().equals(CompileWithTargetClassAccessClassLoader.class.getName())) { - Method defineClassMethod = ReflectionUtils.findMethod(parentClass, + Method setClassResourceLookupMethod = ReflectionUtils.findMethod(parentClass, + "setClassResourceLookup", Function.class); + ReflectionUtils.makeAccessible(setClassResourceLookupMethod); + ReflectionUtils.invokeMethod(setClassResourceLookupMethod, + getParent(), (Function) this::findClassBytes); + this.defineClassMethod = ReflectionUtils.findMethod(parentClass, "defineClassWithTargetAccess", String.class, byte[].class, int.class, int.class); - ReflectionUtils.makeAccessible(defineClassMethod); - return defineClassMethod; + ReflectionUtils.makeAccessible(this.defineClassMethod); + this.compiledClasses.forEach((name, file) -> defineClass(name, file.getBytes())); + } + else { + this.defineClassMethod = null; } - return null; } @Override protected Class findClass(String name) throws ClassNotFoundException { + byte[] bytes = findClassBytes(name); + if (bytes != null) { + return defineClass(name, bytes); + } + return super.findClass(name); + } + + @Nullable + private byte[] findClassBytes(String name) { + DynamicClassFileObject compiledClass = this.compiledClasses.get(name); + if(compiledClass != null) { + return compiledClass.getBytes(); + } + return findClassFileBytes(name); + } + + @Nullable + private byte[] findClassFileBytes(String name) { ClassFile classFile = this.classFiles.get(name); if (classFile != null) { - return defineClass(classFile); + return classFile.getContent(); } - return super.findClass(name); + return null; } - private Class defineClass(ClassFile classFile) { - String name = classFile.getName(); - byte[] bytes = classFile.getContent(); + private Class defineClass(String name, byte[] bytes) { if (this.defineClassMethod != null) { return (Class) ReflectionUtils.invokeMethod(this.defineClassMethod, getParent(), name, bytes, 0, bytes.length); diff --git a/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicJavaFileManager.java b/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicJavaFileManager.java index 80e56a97a62..0df5291af21 100644 --- a/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicJavaFileManager.java +++ b/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/DynamicJavaFileManager.java @@ -16,10 +16,7 @@ package org.springframework.aot.test.generate.compile; -import java.io.ByteArrayInputStream; import java.io.IOException; -import java.io.InputStream; -import java.net.URI; import java.util.ArrayList; import java.util.Collections; import java.util.LinkedHashMap; @@ -32,7 +29,6 @@ import javax.tools.ForwardingJavaFileManager; import javax.tools.JavaFileManager; import javax.tools.JavaFileObject; import javax.tools.JavaFileObject.Kind; -import javax.tools.SimpleJavaFileObject; import org.springframework.aot.test.generate.file.ClassFile; import org.springframework.aot.test.generate.file.ClassFiles; @@ -48,19 +44,18 @@ import org.springframework.util.ClassUtils; */ class DynamicJavaFileManager extends ForwardingJavaFileManager { - private final ClassFiles existingClasses; - private final ClassLoader classLoader; + private final ClassFiles classFiles; + private final Map compiledClasses = Collections.synchronizedMap( new LinkedHashMap<>()); - DynamicJavaFileManager(JavaFileManager fileManager, ClassLoader classLoader, - ClassFiles existingClasses) { + DynamicJavaFileManager(JavaFileManager fileManager, ClassLoader classLoader, ClassFiles classFiles) { super(fileManager); this.classLoader = classLoader; - this.existingClasses = existingClasses; + this.classFiles = classFiles; } @@ -84,49 +79,27 @@ class DynamicJavaFileManager extends ForwardingJavaFileManager Set kinds, boolean recurse) throws IOException { List result = new ArrayList<>(); if (kinds.contains(Kind.CLASS)) { - for (ClassFile existingClass : this.existingClasses) { - String existingPackageName = ClassUtils.getPackageName(existingClass.getName()); + for (ClassFile candidate : this.classFiles) { + String existingPackageName = ClassUtils.getPackageName(candidate.getName()); if (existingPackageName.equals(packageName) || (recurse && existingPackageName.startsWith(packageName + "."))) { - result.add(new ClassFileJavaFileObject(existingClass)); + result.add(new DynamicClassFileObject(candidate.getName(), candidate.getContent())); } } } - Iterable listed = super.list(location, packageName, kinds, recurse); - listed.forEach(result::add); + super.list(location, packageName, kinds, recurse).forEach(result::add); return result; } @Override public String inferBinaryName(Location location, JavaFileObject file) { - if (file instanceof ClassFileJavaFileObject classFile) { - return classFile.getClassName(); + if (file instanceof DynamicClassFileObject dynamicClassFileObject) { + return dynamicClassFileObject.getClassName(); } return super.inferBinaryName(location, file); } - ClassFiles getClassFiles() { - return this.existingClasses.and(this.compiledClasses.entrySet().stream().map(entry -> - ClassFile.of(entry.getKey(), entry.getValue().getBytes())).toList()); - } - - private static final class ClassFileJavaFileObject extends SimpleJavaFileObject { - - private final ClassFile classFile; - - private ClassFileJavaFileObject(ClassFile classFile) { - super(URI.create("class:///" + classFile.getName().replace('.', '/') + ".class"), Kind.CLASS); - this.classFile = classFile; - } - - public String getClassName() { - return this.classFile.getName(); - } - - @Override - public InputStream openInputStream() { - return new ByteArrayInputStream(this.classFile.getContent()); - } - + Map getCompiledClasses() { + return this.compiledClasses; } } diff --git a/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/TestCompiler.java b/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/TestCompiler.java index dbe2f6127b3..40021f42bf9 100644 --- a/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/TestCompiler.java +++ b/spring-core-test/src/main/java/org/springframework/aot/test/generate/compile/TestCompiler.java @@ -304,7 +304,7 @@ public final class TestCompiler { throw new CompilationException(errors.toString(), this.sourceFiles, this.resourceFiles); } } - return new DynamicClassLoader(classLoaderToUse, this.resourceFiles, fileManager.getClassFiles()); + return new DynamicClassLoader(classLoaderToUse, this.resourceFiles, this.classFiles, fileManager.getCompiledClasses()); } /** diff --git a/spring-core-test/src/test/java/org/springframework/aot/test/generate/compile/DynamicJavaFileManagerTests.java b/spring-core-test/src/test/java/org/springframework/aot/test/generate/compile/DynamicJavaFileManagerTests.java index 6faa4729b99..f3af3fd7f62 100644 --- a/spring-core-test/src/test/java/org/springframework/aot/test/generate/compile/DynamicJavaFileManagerTests.java +++ b/spring-core-test/src/test/java/org/springframework/aot/test/generate/compile/DynamicJavaFileManagerTests.java @@ -101,8 +101,8 @@ class DynamicJavaFileManagerTests { Kind.CLASS, null); this.fileManager.getJavaFileForOutput(this.location, "com.example.MyClass2", Kind.CLASS, null); - assertThat(this.fileManager.getClassFiles().stream().map(ClassFile::getName)) - .contains("com.example.MyClass1", "com.example.MyClass2"); + assertThat(this.fileManager.getCompiledClasses()).containsKeys( + "com.example.MyClass1", "com.example.MyClass2"); } @Test