diff --git a/spring-test/src/main/java/org/springframework/test/context/aot/AotTestAttributesFactory.java b/spring-test/src/main/java/org/springframework/test/context/aot/AotTestAttributesFactory.java
index 5f86413a3e6..0ca8bb80ccd 100644
--- a/spring-test/src/main/java/org/springframework/test/context/aot/AotTestAttributesFactory.java
+++ b/spring-test/src/main/java/org/springframework/test/context/aot/AotTestAttributesFactory.java
@@ -16,16 +16,11 @@
package org.springframework.test.context.aot;
-import java.lang.reflect.Method;
-import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.aot.AotDetector;
import org.springframework.lang.Nullable;
-import org.springframework.util.Assert;
-import org.springframework.util.ClassUtils;
-import org.springframework.util.ReflectionUtils;
/**
* Factory for {@link AotTestAttributes}.
@@ -64,7 +59,7 @@ final class AotTestAttributesFactory {
}
/**
- * Reset AOT test attributes.
+ * Reset the factory.
*
Only for internal use.
*/
static void reset() {
@@ -73,23 +68,11 @@ final class AotTestAttributesFactory {
}
}
- @SuppressWarnings({ "rawtypes", "unchecked" })
+ @SuppressWarnings("unchecked")
private static Map loadAttributesMap() {
String className = AotTestAttributesCodeGenerator.GENERATED_ATTRIBUTES_CLASS_NAME;
String methodName = AotTestAttributesCodeGenerator.GENERATED_ATTRIBUTES_METHOD_NAME;
- try {
- Class> clazz = ClassUtils.forName(className, null);
- Method method = ReflectionUtils.findMethod(clazz, methodName);
- Assert.state(method != null, () -> "No %s() method found in %s".formatted(methodName, clazz.getName()));
- Map attributes = (Map) ReflectionUtils.invokeMethod(method, null);
- return Collections.unmodifiableMap(attributes);
- }
- catch (IllegalStateException ex) {
- throw ex;
- }
- catch (Exception ex) {
- throw new IllegalStateException("Failed to invoke %s() method on %s".formatted(methodName, className), ex);
- }
+ return GeneratedMapUtils.loadMap(className, methodName);
}
}
diff --git a/spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializers.java b/spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializers.java
index e82b05075c3..7dbf09f45d2 100644
--- a/spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializers.java
+++ b/spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializers.java
@@ -16,50 +16,35 @@
package org.springframework.test.context.aot;
-import java.lang.reflect.Method;
import java.util.Map;
import java.util.function.Supplier;
+import org.springframework.aot.AotDetector;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.lang.Nullable;
-import org.springframework.util.Assert;
-import org.springframework.util.ClassUtils;
-import org.springframework.util.ReflectionUtils;
/**
* {@code AotTestContextInitializers} provides mappings from test classes to
* AOT-optimized context initializers.
*
- * If a test class is not {@linkplain #isSupportedTestClass(Class) supported} in
- * AOT mode, {@link #getContextInitializer(Class)} will return {@code null}.
+ *
Intended solely for internal use within the framework.
*
- *
Reflectively accesses {@link #GENERATED_MAPPINGS_CLASS_NAME} generated by
- * the {@link TestContextAotGenerator} to retrieve the mappings generated during
- * AOT processing.
+ *
If we are not running in {@linkplain AotDetector#useGeneratedArtifacts()
+ * AOT mode} or if a test class is not {@linkplain #isSupportedTestClass(Class)
+ * supported} in AOT mode, {@link #getContextInitializer(Class)} will return
+ * {@code null}.
*
* @author Sam Brannen
- * @author Stephane Nicoll
* @since 6.0
*/
public class AotTestContextInitializers {
- // TODO Add support in ClassNameGenerator for supplying a predefined class name.
- // There is a similar issue in Spring Boot where code relies on a generated name.
- // Ideally we would generate a class named: org.springframework.test.context.aot.GeneratedAotTestContextInitializers
- static final String GENERATED_MAPPINGS_CLASS_NAME = AotTestContextInitializers.class.getName() + "__Generated";
-
- static final String GENERATED_MAPPINGS_METHOD_NAME = "getContextInitializers";
-
private final Map>> contextInitializers;
public AotTestContextInitializers() {
- this(GENERATED_MAPPINGS_CLASS_NAME);
- }
-
- AotTestContextInitializers(String initializerClassName) {
- this(loadContextInitializersMap(initializerClassName));
+ this(AotTestContextInitializersFactory.getContextInitializers());
}
AotTestContextInitializers(Map>> contextInitializers) {
@@ -90,26 +75,4 @@ public class AotTestContextInitializers {
return (supplier != null ? supplier.get() : null);
}
-
- @SuppressWarnings({ "rawtypes", "unchecked" })
- private static Map>>
- loadContextInitializersMap(String className) {
-
- String methodName = GENERATED_MAPPINGS_METHOD_NAME;
-
- try {
- Class> clazz = ClassUtils.forName(className, null);
- Method method = ReflectionUtils.findMethod(clazz, methodName);
- Assert.state(method != null, () -> "No %s() method found in %s".formatted(methodName, clazz.getName()));
- return (Map>>)
- ReflectionUtils.invokeMethod(method, null);
- }
- catch (IllegalStateException ex) {
- throw ex;
- }
- catch (Exception ex) {
- throw new IllegalStateException("Failed to invoke %s() method in %s".formatted(methodName, className), ex);
- }
- }
-
}
diff --git a/spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializersCodeGenerator.java b/spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializersCodeGenerator.java
index c7732cf2615..798023c62ae 100644
--- a/spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializersCodeGenerator.java
+++ b/spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializersCodeGenerator.java
@@ -61,6 +61,15 @@ class AotTestContextInitializersCodeGenerator {
private static final TypeName CONTEXT_SUPPLIER_MAP = ParameterizedTypeName
.get(ClassName.get(Map.class), ClassName.get(String.class), CONTEXT_INITIALIZER_SUPPLIER);
+ private static final String GENERATED_SUFFIX = "Generated";
+
+ // TODO Add support in ClassNameGenerator for supplying a predefined class name.
+ // There is a similar issue in Spring Boot where code relies on a generated name.
+ // Ideally we would generate a class named: org.springframework.test.context.aot.GeneratedAotTestContextInitializers
+ static final String GENERATED_MAPPINGS_CLASS_NAME = AotTestContextInitializers.class.getName() + "__" + GENERATED_SUFFIX;
+
+ static final String GENERATED_MAPPINGS_METHOD_NAME = "getContextInitializers";
+
private final MultiValueMap> initializerClassMappings;
@@ -71,7 +80,7 @@ class AotTestContextInitializersCodeGenerator {
GeneratedClasses generatedClasses) {
this.initializerClassMappings = initializerClassMappings;
- this.generatedClass = generatedClasses.addForFeature("Generated", this::generateType);
+ this.generatedClass = generatedClasses.addForFeature(GENERATED_SUFFIX, this::generateType);
}
@@ -88,7 +97,7 @@ class AotTestContextInitializersCodeGenerator {
}
private MethodSpec generateMappingMethod() {
- MethodSpec.Builder method = MethodSpec.methodBuilder(AotTestContextInitializers.GENERATED_MAPPINGS_METHOD_NAME);
+ MethodSpec.Builder method = MethodSpec.methodBuilder(GENERATED_MAPPINGS_METHOD_NAME);
method.addModifiers(Modifier.PUBLIC, Modifier.STATIC);
method.returns(CONTEXT_SUPPLIER_MAP);
method.addCode(generateMappingCode());
diff --git a/spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializersFactory.java b/spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializersFactory.java
new file mode 100644
index 00000000000..e714dcf6dbe
--- /dev/null
+++ b/spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializersFactory.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.test.context.aot;
+
+import java.util.Map;
+import java.util.function.Supplier;
+
+import org.springframework.aot.AotDetector;
+import org.springframework.context.ApplicationContextInitializer;
+import org.springframework.context.ConfigurableApplicationContext;
+import org.springframework.lang.Nullable;
+
+/**
+ * Factory for {@link AotTestContextInitializers}.
+ *
+ * @author Sam Brannen
+ * @since 6.0
+ */
+final class AotTestContextInitializersFactory {
+
+ @Nullable
+ private static volatile Map>> contextInitializers;
+
+
+ private AotTestContextInitializersFactory() {
+ }
+
+ /**
+ * Get the underlying map.
+ * If the map is not already loaded, this method loads the map from the
+ * generated class when running in {@linkplain AotDetector#useGeneratedArtifacts()
+ * AOT execution mode} and otherwise creates an immutable, empty map.
+ */
+ static Map>> getContextInitializers() {
+ Map>> initializers = contextInitializers;
+ if (initializers == null) {
+ synchronized (AotTestContextInitializersFactory.class) {
+ initializers = contextInitializers;
+ if (initializers == null) {
+ initializers = (AotDetector.useGeneratedArtifacts() ? loadContextInitializersMap() : Map.of());
+ contextInitializers = initializers;
+ }
+ }
+ }
+ return initializers;
+ }
+
+ /**
+ * Reset the factory.
+ * Only for internal use.
+ */
+ static void reset() {
+ synchronized (AotTestContextInitializersFactory.class) {
+ contextInitializers = null;
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private static Map>> loadContextInitializersMap() {
+ String className = AotTestContextInitializersCodeGenerator.GENERATED_MAPPINGS_CLASS_NAME;
+ String methodName = AotTestContextInitializersCodeGenerator.GENERATED_MAPPINGS_METHOD_NAME;
+ return GeneratedMapUtils.loadMap(className, methodName);
+ }
+
+}
diff --git a/spring-test/src/main/java/org/springframework/test/context/aot/GeneratedMapUtils.java b/spring-test/src/main/java/org/springframework/test/context/aot/GeneratedMapUtils.java
new file mode 100644
index 00000000000..f6e3f8703d2
--- /dev/null
+++ b/spring-test/src/main/java/org/springframework/test/context/aot/GeneratedMapUtils.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.test.context.aot;
+
+import java.lang.reflect.Method;
+import java.util.Collections;
+import java.util.Map;
+
+import org.springframework.util.Assert;
+import org.springframework.util.ClassUtils;
+import org.springframework.util.ReflectionUtils;
+
+/**
+ * Utilities for loading generated maps.
+ *
+ * @author Sam Brannen
+ * @author Stephane Nicoll
+ * @since 6.0
+ */
+final class GeneratedMapUtils {
+
+ private GeneratedMapUtils() {
+ }
+
+ /**
+ * Load a generated map.
+ * @param className the name of the class in which the static method resides
+ * @param methodName the name of the static method to invoke
+ * @return an unmodifiable map retrieved from a static method
+ */
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ static Map loadMap(String className, String methodName) {
+ try {
+ Class> clazz = ClassUtils.forName(className, null);
+ Method method = ReflectionUtils.findMethod(clazz, methodName);
+ Assert.state(method != null, () -> "No %s() method found in %s".formatted(methodName, className));
+ Map map = (Map) ReflectionUtils.invokeMethod(method, null);
+ return Collections.unmodifiableMap(map);
+ }
+ catch (IllegalStateException ex) {
+ throw ex;
+ }
+ catch (Exception ex) {
+ throw new IllegalStateException("Failed to invoke %s() method on %s".formatted(methodName, className), ex);
+ }
+ }
+
+}
diff --git a/spring-test/src/main/java/org/springframework/test/context/aot/TestContextAotGenerator.java b/spring-test/src/main/java/org/springframework/test/context/aot/TestContextAotGenerator.java
index c717774f18d..ae67e17a472 100644
--- a/spring-test/src/main/java/org/springframework/test/context/aot/TestContextAotGenerator.java
+++ b/spring-test/src/main/java/org/springframework/test/context/aot/TestContextAotGenerator.java
@@ -24,6 +24,7 @@ import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.springframework.aot.AotDetector;
import org.springframework.aot.generate.ClassNameGenerator;
import org.springframework.aot.generate.DefaultGenerationContext;
import org.springframework.aot.generate.GeneratedClasses;
@@ -109,9 +110,9 @@ public class TestContextAotGenerator {
* @throws TestContextAotException if an error occurs during AOT processing
*/
public void processAheadOfTime(Stream> testClasses) throws TestContextAotException {
+ Assert.state(!AotDetector.useGeneratedArtifacts(), "Cannot perform AOT processing during AOT run-time execution");
try {
- // Make sure AOT attributes are cleared before processing
- AotTestAttributesFactory.reset();
+ resetAotFactories();
MultiValueMap> mergedConfigMappings = new LinkedMultiValueMap<>();
testClasses.forEach(testClass -> mergedConfigMappings.add(buildMergedContextConfiguration(testClass), testClass));
@@ -121,11 +122,15 @@ public class TestContextAotGenerator {
generateAotTestAttributes();
}
finally {
- // Clear AOT attributes after processing
- AotTestAttributesFactory.reset();
+ resetAotFactories();
}
}
+ private void resetAotFactories() {
+ AotTestAttributesFactory.reset();
+ AotTestContextInitializersFactory.reset();
+ }
+
private MultiValueMap> processAheadOfTime(MultiValueMap> mergedConfigMappings) {
MultiValueMap> initializerClassMappings = new LinkedMultiValueMap<>();
mergedConfigMappings.forEach((mergedConfig, testClasses) -> {
diff --git a/spring-test/src/main/java/org/springframework/test/context/cache/DefaultCacheAwareContextLoaderDelegate.java b/spring-test/src/main/java/org/springframework/test/context/cache/DefaultCacheAwareContextLoaderDelegate.java
index f7a3eab00e2..a2919db9341 100644
--- a/spring-test/src/main/java/org/springframework/test/context/cache/DefaultCacheAwareContextLoaderDelegate.java
+++ b/spring-test/src/main/java/org/springframework/test/context/cache/DefaultCacheAwareContextLoaderDelegate.java
@@ -19,7 +19,6 @@ package org.springframework.test.context.cache;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-import org.springframework.aot.AotDetector;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ConfigurableApplicationContext;
@@ -56,8 +55,7 @@ public class DefaultCacheAwareContextLoaderDelegate implements CacheAwareContext
*/
static final ContextCache defaultContextCache = new DefaultContextCache();
- @Nullable
- private final AotTestContextInitializers aotTestContextInitializers = getAotTestContextInitializers();
+ private final AotTestContextInitializers aotTestContextInitializers = new AotTestContextInitializers();
private final ContextCache contextCache;
@@ -200,21 +198,7 @@ public class DefaultCacheAwareContextLoaderDelegate implements CacheAwareContext
* Determine if we are running in AOT mode for the supplied test class.
*/
private boolean runningInAotMode(Class> testClass) {
- return (this.aotTestContextInitializers != null &&
- this.aotTestContextInitializers.isSupportedTestClass(testClass));
- }
-
- @Nullable
- private static AotTestContextInitializers getAotTestContextInitializers() {
- if (AotDetector.useGeneratedArtifacts()) {
- try {
- return new AotTestContextInitializers();
- }
- catch (Exception ex) {
- throw new IllegalStateException("Failed to instantiate AotTestContextInitializers", ex);
- }
- }
- return null;
+ return this.aotTestContextInitializers.isSupportedTestClass(testClass);
}
}
diff --git a/spring-test/src/main/java/org/springframework/test/context/support/DependencyInjectionTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/support/DependencyInjectionTestExecutionListener.java
index ec7b7839065..bba77ea4576 100644
--- a/spring-test/src/main/java/org/springframework/test/context/support/DependencyInjectionTestExecutionListener.java
+++ b/spring-test/src/main/java/org/springframework/test/context/support/DependencyInjectionTestExecutionListener.java
@@ -19,14 +19,12 @@ package org.springframework.test.context.support;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-import org.springframework.aot.AotDetector;
import org.springframework.beans.factory.annotation.AutowiredAnnotationBeanPostProcessor;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.Conventions;
-import org.springframework.lang.Nullable;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.aot.AotTestContextInitializers;
@@ -60,8 +58,7 @@ public class DependencyInjectionTestExecutionListener extends AbstractTestExecut
private static final Log logger = LogFactory.getLog(DependencyInjectionTestExecutionListener.class);
- @Nullable
- private final AotTestContextInitializers aotTestContextInitializers = getAotTestContextInitializers();
+ private final AotTestContextInitializers aotTestContextInitializers = new AotTestContextInitializers();
/**
@@ -162,21 +159,7 @@ public class DependencyInjectionTestExecutionListener extends AbstractTestExecut
* Determine if we are running in AOT mode for the supplied test class.
*/
private boolean runningInAotMode(Class> testClass) {
- return (this.aotTestContextInitializers != null &&
- this.aotTestContextInitializers.isSupportedTestClass(testClass));
- }
-
- @Nullable
- private static AotTestContextInitializers getAotTestContextInitializers() {
- if (AotDetector.useGeneratedArtifacts()) {
- try {
- return new AotTestContextInitializers();
- }
- catch (Exception ex) {
- throw new IllegalStateException("Failed to instantiate AotTestContextInitializers", ex);
- }
- }
- return null;
+ return this.aotTestContextInitializers.isSupportedTestClass(testClass);
}
}
diff --git a/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorTests.java b/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorTests.java
index d7000fdaddb..bcbd7c58b39 100644
--- a/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorTests.java
+++ b/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorTests.java
@@ -114,6 +114,7 @@ class TestContextAotGeneratorTests extends AbstractAotTests {
try {
System.setProperty(AotDetector.AOT_ENABLED, "true");
AotTestAttributesFactory.reset();
+ AotTestContextInitializersFactory.reset();
AotTestAttributes aotAttributes = AotTestAttributes.getInstance();
assertThatExceptionOfType(UnsupportedOperationException.class)
@@ -153,7 +154,7 @@ class TestContextAotGeneratorTests extends AbstractAotTests {
}
private static void assertRuntimeHints(RuntimeHints runtimeHints) {
- assertReflectionRegistered(runtimeHints, AotTestContextInitializers.GENERATED_MAPPINGS_CLASS_NAME, INVOKE_PUBLIC_METHODS);
+ assertReflectionRegistered(runtimeHints, AotTestContextInitializersCodeGenerator.GENERATED_MAPPINGS_CLASS_NAME, INVOKE_PUBLIC_METHODS);
assertReflectionRegistered(runtimeHints, AotTestAttributesCodeGenerator.GENERATED_ATTRIBUTES_CLASS_NAME, INVOKE_PUBLIC_METHODS);
Stream.of(