From 8a8c8fe00edf09c3df4885ef92e3f2f728d21704 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Mon, 22 Apr 2024 09:45:12 +0200 Subject: [PATCH] Detect target of factory method with AOT Previously, if a factory method is defined on a parent, the generated code would blindly use the method's declaring class for both the target of the generated code, and the signature of the method. This commit improves the resolution by considering the factory metadata in the BeanDefinition. Closes gh-32609 --- .../DefaultBeanRegistrationCodeFragments.java | 28 ++++----- .../aot/InstanceSupplierCodeGenerator.java | 62 ++++++++++++------- .../beans/factory/support/RegisteredBean.java | 39 +++++++++++- .../BeanDefinitionMethodGeneratorTests.java | 8 ++- ...ultBeanRegistrationCodeFragmentsTests.java | 25 ++++++-- .../InstanceSupplierCodeGeneratorTests.java | 30 +++++++-- ...nstanceSupplierCodeGeneratorKotlinTests.kt | 8 +-- .../aot/DefaultSimpleBeanContract.java | 25 ++++++++ .../beans/factory/aot/SimpleBeanContract.java | 30 +++++++++ .../ConfigurationClassPostProcessor.java | 15 +++-- ...stContextAotGeneratorIntegrationTests.java | 1 - .../management/ManagementConfiguration.java | 6 +- 12 files changed, 214 insertions(+), 63 deletions(-) create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/DefaultSimpleBeanContract.java create mode 100644 spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanContract.java diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index 0ef92102283..feddeb21d4c 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ package org.springframework.beans.factory.aot; import java.lang.reflect.Constructor; -import java.lang.reflect.Executable; import java.lang.reflect.Modifier; import java.util.List; import java.util.function.Predicate; @@ -35,6 +34,7 @@ import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RegisteredBean.InstantiationDescriptor; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.ResolvableType; import org.springframework.javapoet.ClassName; @@ -62,7 +62,7 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory; - private final Supplier constructorOrFactoryMethod; + private final Supplier instantiationDescriptor; DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode, @@ -72,7 +72,7 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme this.beanRegistrationsCode = beanRegistrationsCode; this.registeredBean = registeredBean; this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory; - this.constructorOrFactoryMethod = SingletonSupplier.of(registeredBean::resolveConstructorOrFactoryMethod); + this.instantiationDescriptor = SingletonSupplier.of(registeredBean::resolveInstantiationDescriptor); } @@ -82,7 +82,7 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme throw new IllegalStateException("Default code generation is not supported for bean definitions " + "declaring an instance supplier callback: " + registeredBean.getMergedBeanDefinition()); } - Class target = extractDeclaringClass(registeredBean.getBeanType(), this.constructorOrFactoryMethod.get()); + Class target = extractDeclaringClass(registeredBean, this.instantiationDescriptor.get()); while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) { RegisteredBean parent = registeredBean.getParent(); Assert.state(parent != null, "No parent available for inner bean"); @@ -91,14 +91,14 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme return (target.isArray() ? ClassName.get(target.getComponentType()) : ClassName.get(target)); } - private Class extractDeclaringClass(ResolvableType beanType, Executable executable) { - Class declaringClass = ClassUtils.getUserClass(executable.getDeclaringClass()); - if (executable instanceof Constructor - && AccessControl.forMember(executable).isPublic() + private Class extractDeclaringClass(RegisteredBean registeredBean, InstantiationDescriptor instantiationDescriptor) { + Class declaringClass = ClassUtils.getUserClass(instantiationDescriptor.targetClass()); + if (instantiationDescriptor.executable() instanceof Constructor + && AccessControl.forMember(instantiationDescriptor.executable()).isPublic() && FactoryBean.class.isAssignableFrom(declaringClass)) { - return extractTargetClassFromFactoryBean(declaringClass, beanType); + return extractTargetClassFromFactoryBean(declaringClass, registeredBean.getBeanType()); } - return executable.getDeclaringClass(); + return declaringClass; } /** @@ -238,9 +238,9 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme throw new IllegalStateException("Default code generation is not supported for bean definitions declaring " + "an instance supplier callback: " + this.registeredBean.getMergedBeanDefinition()); } - return new InstanceSupplierCodeGenerator(generationContext, - beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) - .generateCode(this.registeredBean, this.constructorOrFactoryMethod.get()); + return new InstanceSupplierCodeGenerator(generationContext, beanRegistrationCode.getClassName(), + beanRegistrationCode.getMethods(), allowDirectSupplierShortcut).generateCode( + this.registeredBean, this.instantiationDescriptor.get()); } @Override diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java index 3087d8b481a..acc796df582 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,6 +46,7 @@ import org.springframework.beans.factory.support.AutowireCandidateResolver; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RegisteredBean.InstantiationDescriptor; import org.springframework.core.KotlinDetector; import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; @@ -120,14 +121,29 @@ public class InstanceSupplierCodeGenerator { * @param registeredBean the bean to handle * @param constructorOrFactoryMethod the executable to use to create the bean * @return the generated code + * @deprecated in favor of {@link #generateCode(RegisteredBean, InstantiationDescriptor)} */ + @Deprecated(since = "6.1.7") public CodeBlock generateCode(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { + return generateCode(registeredBean, new InstantiationDescriptor( + constructorOrFactoryMethod, constructorOrFactoryMethod.getDeclaringClass())); + } + + /** + * Generate the instance supplier code. + * @param registeredBean the bean to handle + * @param instantiationDescriptor the executable to use to create the bean + * @return the generated code + * @since 6.1.7 + */ + public CodeBlock generateCode(RegisteredBean registeredBean, InstantiationDescriptor instantiationDescriptor) { + Executable constructorOrFactoryMethod = instantiationDescriptor.executable(); registerRuntimeHintsIfNecessary(registeredBean, constructorOrFactoryMethod); if (constructorOrFactoryMethod instanceof Constructor constructor) { return generateCodeForConstructor(registeredBean, constructor); } if (constructorOrFactoryMethod instanceof Method method) { - return generateCodeForFactoryMethod(registeredBean, method); + return generateCodeForFactoryMethod(registeredBean, method, instantiationDescriptor.targetClass()); } throw new IllegalStateException( "No suitable executor found for " + registeredBean.getBeanName()); @@ -253,21 +269,21 @@ public class InstanceSupplierCodeGenerator { declaringClass.getSimpleName(), args); } - private CodeBlock generateCodeForFactoryMethod(RegisteredBean registeredBean, Method factoryMethod) { + private CodeBlock generateCodeForFactoryMethod(RegisteredBean registeredBean, Method factoryMethod, Class targetClass) { String beanName = registeredBean.getBeanName(); - Class declaringClass = ClassUtils.getUserClass(factoryMethod.getDeclaringClass()); + Class targetClassToUse = ClassUtils.getUserClass(targetClass); boolean dependsOnBean = !Modifier.isStatic(factoryMethod.getModifiers()); Visibility accessVisibility = getAccessVisibility(registeredBean, factoryMethod); if (accessVisibility != Visibility.PRIVATE) { return generateCodeForAccessibleFactoryMethod( - beanName, factoryMethod, declaringClass, dependsOnBean); + beanName, factoryMethod, targetClassToUse, dependsOnBean); } - return generateCodeForInaccessibleFactoryMethod(beanName, factoryMethod, declaringClass); + return generateCodeForInaccessibleFactoryMethod(beanName, factoryMethod, targetClassToUse); } private CodeBlock generateCodeForAccessibleFactoryMethod(String beanName, - Method factoryMethod, Class declaringClass, boolean dependsOnBean) { + Method factoryMethod, Class targetClass, boolean dependsOnBean) { this.generationContext.getRuntimeHints().reflection().registerMethod( factoryMethod, ExecutableMode.INTROSPECT); @@ -276,20 +292,20 @@ public class InstanceSupplierCodeGenerator { Class suppliedType = ClassUtils.resolvePrimitiveIfNecessary(factoryMethod.getReturnType()); CodeBlock.Builder code = CodeBlock.builder(); code.add("$T.<$T>forFactoryMethod($T.class, $S)", BeanInstanceSupplier.class, - suppliedType, declaringClass, factoryMethod.getName()); + suppliedType, targetClass, factoryMethod.getName()); code.add(".withGenerator(($L) -> $T.$L())", REGISTERED_BEAN_PARAMETER_NAME, - declaringClass, factoryMethod.getName()); + targetClass, factoryMethod.getName()); return code.build(); } GeneratedMethod getInstanceMethod = generateGetInstanceSupplierMethod(method -> buildGetInstanceMethodForFactoryMethod(method, beanName, factoryMethod, - declaringClass, dependsOnBean, PRIVATE_STATIC)); + targetClass, dependsOnBean, PRIVATE_STATIC)); return generateReturnStatement(getInstanceMethod); } private CodeBlock generateCodeForInaccessibleFactoryMethod( - String beanName, Method factoryMethod, Class declaringClass) { + String beanName, Method factoryMethod, Class targetClass) { this.generationContext.getRuntimeHints().reflection().registerMethod(factoryMethod, ExecutableMode.INVOKE); GeneratedMethod getInstanceMethod = generateGetInstanceSupplierMethod(method -> { @@ -298,19 +314,19 @@ public class InstanceSupplierCodeGenerator { method.addModifiers(PRIVATE_STATIC); method.returns(ParameterizedTypeName.get(BeanInstanceSupplier.class, suppliedType)); method.addStatement(generateInstanceSupplierForFactoryMethod( - factoryMethod, suppliedType, declaringClass, factoryMethod.getName())); + factoryMethod, suppliedType, targetClass, factoryMethod.getName())); }); return generateReturnStatement(getInstanceMethod); } private void buildGetInstanceMethodForFactoryMethod(MethodSpec.Builder method, - String beanName, Method factoryMethod, Class declaringClass, + String beanName, Method factoryMethod, Class targetClass, boolean dependsOnBean, javax.lang.model.element.Modifier... modifiers) { String factoryMethodName = factoryMethod.getName(); Class suppliedType = ClassUtils.resolvePrimitiveIfNecessary(factoryMethod.getReturnType()); CodeWarnings codeWarnings = new CodeWarnings(); - codeWarnings.detectDeprecation(declaringClass, factoryMethod, suppliedType) + codeWarnings.detectDeprecation(targetClass, factoryMethod, suppliedType) .detectDeprecation(Arrays.stream(factoryMethod.getParameters()).map(Parameter::getType)); method.addJavadoc("Get the bean instance supplier for '$L'.", beanName); @@ -320,41 +336,41 @@ public class InstanceSupplierCodeGenerator { CodeBlock.Builder code = CodeBlock.builder(); code.add(generateInstanceSupplierForFactoryMethod( - factoryMethod, suppliedType, declaringClass, factoryMethodName)); + factoryMethod, suppliedType, targetClass, factoryMethodName)); boolean hasArguments = factoryMethod.getParameterCount() > 0; CodeBlock arguments = hasArguments ? - new AutowiredArgumentsCodeGenerator(declaringClass, factoryMethod) + new AutowiredArgumentsCodeGenerator(targetClass, factoryMethod) .generateCode(factoryMethod.getParameterTypes()) : NO_ARGS; CodeBlock newInstance = generateNewInstanceCodeForMethod( - dependsOnBean, declaringClass, factoryMethodName, arguments); + dependsOnBean, targetClass, factoryMethodName, arguments); code.add(generateWithGeneratorCode(hasArguments, newInstance)); method.addStatement(code.build()); } private CodeBlock generateInstanceSupplierForFactoryMethod(Method factoryMethod, - Class suppliedType, Class declaringClass, String factoryMethodName) { + Class suppliedType, Class targetClass, String factoryMethodName) { if (factoryMethod.getParameterCount() == 0) { return CodeBlock.of("return $T.<$T>forFactoryMethod($T.class, $S)", - BeanInstanceSupplier.class, suppliedType, declaringClass, factoryMethodName); + BeanInstanceSupplier.class, suppliedType, targetClass, factoryMethodName); } CodeBlock parameterTypes = generateParameterTypesCode(factoryMethod.getParameterTypes(), 0); return CodeBlock.of("return $T.<$T>forFactoryMethod($T.class, $S, $L)", - BeanInstanceSupplier.class, suppliedType, declaringClass, factoryMethodName, parameterTypes); + BeanInstanceSupplier.class, suppliedType, targetClass, factoryMethodName, parameterTypes); } private CodeBlock generateNewInstanceCodeForMethod(boolean dependsOnBean, - Class declaringClass, String factoryMethodName, CodeBlock args) { + Class targetClass, String factoryMethodName, CodeBlock args) { if (!dependsOnBean) { - return CodeBlock.of("$T.$L($L)", declaringClass, factoryMethodName, args); + return CodeBlock.of("$T.$L($L)", targetClass, factoryMethodName, args); } return CodeBlock.of("$L.getBeanFactory().getBean($T.class).$L($L)", - REGISTERED_BEAN_PARAMETER_NAME, declaringClass, factoryMethodName, args); + REGISTERED_BEAN_PARAMETER_NAME, targetClass, factoryMethodName, args); } private CodeBlock generateReturnStatement(GeneratedMethod generatedMethod) { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/RegisteredBean.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/RegisteredBean.java index 8b80359352d..7ed86ef98f7 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/RegisteredBean.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/RegisteredBean.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ package org.springframework.beans.factory.support; import java.lang.reflect.Executable; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Supplier; @@ -206,12 +208,33 @@ public final class RegisteredBean { /** * Resolve the constructor or factory method to use for this bean. * @return the {@link java.lang.reflect.Constructor} or {@link java.lang.reflect.Method} + * @deprecated in favor of {@link #resolveInstantiationDescriptor()} */ + @Deprecated(since = "6.1.7") public Executable resolveConstructorOrFactoryMethod() { return new ConstructorResolver((AbstractAutowireCapableBeanFactory) getBeanFactory()) .resolveConstructorOrFactoryMethod(getBeanName(), getMergedBeanDefinition()); } + /** + * Resolve the {@linkplain InstantiationDescriptor descriptor} to use to + * instantiate this bean. It defines the {@link java.lang.reflect.Constructor} + * or {@link java.lang.reflect.Method} to use as well as additional metadata. + * @since 6.1.7 + */ + public InstantiationDescriptor resolveInstantiationDescriptor() { + Executable executable = resolveConstructorOrFactoryMethod(); + if (executable instanceof Method method && !Modifier.isStatic(method.getModifiers())) { + String factoryBeanName = getMergedBeanDefinition().getFactoryBeanName(); + if (factoryBeanName != null && this.beanFactory.containsBean(factoryBeanName)) { + Class target = this.beanFactory.getMergedBeanDefinition(factoryBeanName) + .getResolvableType().toClass(); + return new InstantiationDescriptor(executable, target); + } + } + return new InstantiationDescriptor(executable, executable.getDeclaringClass()); + } + /** * Resolve an autowired argument. * @param descriptor the descriptor for the dependency (field/method/constructor) @@ -237,6 +260,20 @@ public final class RegisteredBean { .append("mergedBeanDefinition", getMergedBeanDefinition()).toString(); } + /** + * Describe how a bean should be instantiated. While the {@code targetClass} + * is usually the declaring class of the {@code executable}, there are cases + * where retaining the actual concrete type is necessary. + * @param executable the {@link Executable} to invoke + * @param targetClass the target {@link Class} of the executable + * @since 6.1.7 + */ + public record InstantiationDescriptor(Executable executable, Class targetClass) { + + public InstantiationDescriptor(Executable executable) { + this(executable, executable.getDeclaringClass()); + } + } /** * Resolver used to obtain inner-bean details. diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 991578a2fca..0cd797b8afd 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -161,7 +161,8 @@ class BeanDefinitionMethodGeneratorTests { @Test void generateWithBeanClassAndFactoryMethodNameSetsTargetTypeAndBeanClass() { - this.beanFactory.registerSingleton("factory", new SimpleBeanConfiguration()); + this.beanFactory.registerBeanDefinition("factory", + new RootBeanDefinition(SimpleBeanConfiguration.class)); RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleBean.class); beanDefinition.setFactoryBeanName("factory"); beanDefinition.setFactoryMethodName("simpleBean"); @@ -182,7 +183,8 @@ class BeanDefinitionMethodGeneratorTests { @Test void generateWithTargetTypeAndFactoryMethodNameSetsOnlyBeanClass() { - this.beanFactory.registerSingleton("factory", new SimpleBeanConfiguration()); + this.beanFactory.registerBeanDefinition("factory", + new RootBeanDefinition(SimpleBeanConfiguration.class)); RootBeanDefinition beanDefinition = new RootBeanDefinition(); beanDefinition.setTargetType(SimpleBean.class); beanDefinition.setFactoryBeanName("factory"); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java index c1ded5f4a29..d6f77e1b0ff 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java @@ -31,6 +31,7 @@ import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.factory.DummyFactory; import org.springframework.beans.testfixture.beans.factory.StringFactoryBean; +import org.springframework.beans.testfixture.beans.factory.aot.DefaultSimpleBeanContract; import org.springframework.beans.testfixture.beans.factory.aot.GenericFactoryBean; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationCode; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; @@ -38,6 +39,7 @@ import org.springframework.beans.testfixture.beans.factory.aot.NumberFactoryBean import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean; import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanArrayFactoryBean; import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanConfiguration; +import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanContract; import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanFactoryBean; import org.springframework.core.ResolvableType; import org.springframework.javapoet.ClassName; @@ -126,6 +128,21 @@ class DefaultBeanRegistrationCodeFragmentsTests { SimpleBeanConfiguration.class); } + @Test // gh-32609 + void getTargetOnMethodFromInterface() { + this.beanFactory.registerBeanDefinition("configuration", + new RootBeanDefinition(DefaultSimpleBeanContract.class)); + Method method = ReflectionUtils.findMethod(SimpleBeanContract.class, "simpleBean"); + assertThat(method).isNotNull(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleBean.class); + applyConstructorOrFactoryMethod(beanDefinition, method); + beanDefinition.setFactoryBeanName("configuration"); + this.beanFactory.registerBeanDefinition("testBean", beanDefinition); + RegisteredBean registeredBean = RegisteredBean.of(this.beanFactory, "testBean"); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), + DefaultSimpleBeanContract.class); + } + @Test void getTargetOnMethodWithInnerBeanInJavaPackage() { RegisteredBean registeredBean = registerTestBean(SimpleBean.class); @@ -190,7 +207,7 @@ class DefaultBeanRegistrationCodeFragmentsTests { } @Test - void customizedGetTargetDoesNotResolveConstructorOrFactoryMethod() { + void customizedGetTargetDoesNotResolveInstantiationDescriptor() { RegisteredBean registeredBean = spy(registerTestBean(SimpleBean.class)); BeanRegistrationCodeFragments customCodeFragments = createCustomCodeFragments(registeredBean, codeFragments -> new BeanRegistrationCodeFragmentsDecorator(codeFragments) { @Override @@ -199,11 +216,11 @@ class DefaultBeanRegistrationCodeFragmentsTests { } }); assertTarget(customCodeFragments.getTarget(registeredBean), String.class); - verify(registeredBean, never()).resolveConstructorOrFactoryMethod(); + verify(registeredBean, never()).resolveInstantiationDescriptor(); } @Test - void customizedGenerateInstanceSupplierCodeDoesNotResolveConstructorOrFactoryMethod() { + void customizedGenerateInstanceSupplierCodeDoesNotResolveInstantiationDescriptor() { RegisteredBean registeredBean = spy(registerTestBean(SimpleBean.class)); BeanRegistrationCodeFragments customCodeFragments = createCustomCodeFragments(registeredBean, codeFragments -> new BeanRegistrationCodeFragmentsDecorator(codeFragments) { @Override @@ -214,7 +231,7 @@ class DefaultBeanRegistrationCodeFragmentsTests { }); assertThat(customCodeFragments.generateInstanceSupplierCode(this.generationContext, new MockBeanRegistrationCode(this.generationContext), false)).hasToString("// Hello"); - verify(registeredBean, never()).resolveConstructorOrFactoryMethod(); + verify(registeredBean, never()).resolveInstantiationDescriptor(); } private BeanRegistrationCodeFragments createCustomCodeFragments(RegisteredBean registeredBean, UnaryOperator customFragments) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java index 0f60beecac6..84e5480494e 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package org.springframework.beans.factory.aot; -import java.lang.reflect.Executable; import java.util.function.BiConsumer; import java.util.function.Supplier; @@ -38,10 +37,14 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RegisteredBean.InstantiationDescriptor; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.TestBeanWithPrivateConstructor; +import org.springframework.beans.testfixture.beans.factory.aot.DefaultSimpleBeanContract; import org.springframework.beans.testfixture.beans.factory.aot.DeferredTypeBuilder; +import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean; +import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanContract; import org.springframework.beans.testfixture.beans.factory.generator.InnerComponentConfiguration; import org.springframework.beans.testfixture.beans.factory.generator.InnerComponentConfiguration.EnvironmentAwareComponent; import org.springframework.beans.testfixture.beans.factory.generator.InnerComponentConfiguration.NoDependencyComponent; @@ -185,6 +188,23 @@ class InstanceSupplierCodeGeneratorTests { .satisfies(hasMethodWithMode(ExecutableMode.INTROSPECT)); } + @Test + void generateWhenHasFactoryMethodOnInterface() { + BeanDefinition beanDefinition = BeanDefinitionBuilder + .rootBeanDefinition(SimpleBean.class) + .setFactoryMethodOnBean("simpleBean", "config").getBeanDefinition(); + this.beanFactory.registerBeanDefinition("config", BeanDefinitionBuilder + .rootBeanDefinition(DefaultSimpleBeanContract.class).getBeanDefinition()); + compile(beanDefinition, (instanceSupplier, compiled) -> { + Object bean = getBean(beanDefinition, instanceSupplier); + assertThat(bean).isInstanceOf(SimpleBean.class); + assertThat(compiled.getSourceFile()).contains( + "getBeanFactory().getBean(DefaultSimpleBeanContract.class).simpleBean()"); + }); + assertThat(getReflectionHints().getTypeHint(SimpleBeanContract.class)) + .satisfies(hasMethodWithMode(ExecutableMode.INTROSPECT)); + } + @Test void generateWhenHasPrivateStaticFactoryMethodWithNoArg() { BeanDefinition beanDefinition = BeanDefinitionBuilder @@ -402,9 +422,9 @@ class InstanceSupplierCodeGeneratorTests { InstanceSupplierCodeGenerator generator = new InstanceSupplierCodeGenerator( this.generationContext, generateClass.getName(), generateClass.getMethods(), false); - Executable constructorOrFactoryMethod = registeredBean.resolveConstructorOrFactoryMethod(); - assertThat(constructorOrFactoryMethod).isNotNull(); - CodeBlock generatedCode = generator.generateCode(registeredBean, constructorOrFactoryMethod); + InstantiationDescriptor instantiationDescriptor = registeredBean.resolveInstantiationDescriptor(); + assertThat(instantiationDescriptor).isNotNull(); + CodeBlock generatedCode = generator.generateCode(registeredBean, instantiationDescriptor); typeBuilder.set(type -> { type.addModifiers(Modifier.PUBLIC); type.addSuperinterface(ParameterizedTypeName.get(Supplier.class, InstanceSupplier.class)); diff --git a/spring-beans/src/test/kotlin/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorKotlinTests.kt b/spring-beans/src/test/kotlin/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorKotlinTests.kt index 6623bb4581f..65c71ce2dda 100644 --- a/spring-beans/src/test/kotlin/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorKotlinTests.kt +++ b/spring-beans/src/test/kotlin/org/springframework/beans/factory/aot/InstanceSupplierCodeGeneratorKotlinTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -116,9 +116,9 @@ class InstanceSupplierCodeGeneratorKotlinTests { generationContext, generateClass.name, generateClass.methods, false ) - val constructorOrFactoryMethod = registeredBean.resolveConstructorOrFactoryMethod() - Assertions.assertThat(constructorOrFactoryMethod).isNotNull() - val generatedCode = generator.generateCode(registeredBean, constructorOrFactoryMethod) + val instantiationDescriptor = registeredBean.resolveInstantiationDescriptor() + Assertions.assertThat(instantiationDescriptor).isNotNull() + val generatedCode = generator.generateCode(registeredBean, instantiationDescriptor) typeBuilder.set { type: TypeSpec.Builder -> type.addModifiers(Modifier.PUBLIC) type.addSuperinterface( diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/DefaultSimpleBeanContract.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/DefaultSimpleBeanContract.java new file mode 100644 index 00000000000..4893369a9d2 --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/DefaultSimpleBeanContract.java @@ -0,0 +1,25 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans.factory.aot; + +public class DefaultSimpleBeanContract implements SimpleBeanContract { + + public SimpleBean anotherSimpleBean() { + return new SimpleBean(); + } + +} diff --git a/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanContract.java b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanContract.java new file mode 100644 index 00000000000..96d72498740 --- /dev/null +++ b/spring-beans/src/testFixtures/java/org/springframework/beans/testfixture/beans/factory/aot/SimpleBeanContract.java @@ -0,0 +1,30 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.testfixture.beans.factory.aot; + +/** + * Showcase a factory method that is defined on an interface. + * + * @author Stephane Nicoll + */ +public interface SimpleBeanContract { + + default SimpleBean simpleBean() { + return new SimpleBean(); + } + +} diff --git a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java index 6396f826eae..bba880ba512 100644 --- a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java +++ b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java @@ -75,6 +75,7 @@ import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProce import org.springframework.beans.factory.support.BeanNameGenerator; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RegisteredBean.InstantiationDescriptor; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.ApplicationStartupAware; import org.springframework.context.EnvironmentAware; @@ -795,24 +796,26 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { - Executable executableToUse = proxyExecutable(generationContext.getRuntimeHints(), - this.registeredBean.resolveConstructorOrFactoryMethod()); + InstantiationDescriptor instantiationDescriptor = proxyInstantiationDescriptor( + generationContext.getRuntimeHints(), this.registeredBean.resolveInstantiationDescriptor()); return new InstanceSupplierCodeGenerator(generationContext, beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) - .generateCode(this.registeredBean, executableToUse); + .generateCode(this.registeredBean, instantiationDescriptor); } - private Executable proxyExecutable(RuntimeHints runtimeHints, Executable userExecutable) { + private InstantiationDescriptor proxyInstantiationDescriptor(RuntimeHints runtimeHints, InstantiationDescriptor instantiationDescriptor) { + Executable userExecutable = instantiationDescriptor.executable(); if (userExecutable instanceof Constructor userConstructor) { try { runtimeHints.reflection().registerConstructor(userConstructor, ExecutableMode.INTROSPECT); - return this.proxyClass.getConstructor(userExecutable.getParameterTypes()); + Constructor constructor = this.proxyClass.getConstructor(userExecutable.getParameterTypes()); + return new InstantiationDescriptor(constructor); } catch (NoSuchMethodException ex) { throw new IllegalStateException("No matching constructor found on proxy " + this.proxyClass, ex); } } - return userExecutable; + return instantiationDescriptor; } } diff --git a/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorIntegrationTests.java b/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorIntegrationTests.java index c056ffb84cc..0fdbe04b1ef 100644 --- a/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorIntegrationTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorIntegrationTests.java @@ -432,7 +432,6 @@ class TestContextAotGeneratorIntegrationTests extends AbstractAotTests { "org/springframework/test/context/aot/samples/web/WebTestConfiguration__TestContext006_BeanDefinitions.java", "org/springframework/web/servlet/config/annotation/DelegatingWebMvcConfiguration__TestContext006_Autowiring.java", "org/springframework/web/servlet/config/annotation/DelegatingWebMvcConfiguration__TestContext006_BeanDefinitions.java", - "org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport__TestContext006_BeanDefinitions.java", // XmlSpringJupiterTests "org/springframework/context/event/DefaultEventListenerFactory__TestContext007_BeanDefinitions.java", "org/springframework/context/event/EventListenerMethodProcessor__TestContext007_BeanDefinitions.java", diff --git a/spring-test/src/test/java/org/springframework/test/context/aot/samples/management/ManagementConfiguration.java b/spring-test/src/test/java/org/springframework/test/context/aot/samples/management/ManagementConfiguration.java index 7500bb6f44a..dc444f3465c 100644 --- a/spring-test/src/test/java/org/springframework/test/context/aot/samples/management/ManagementConfiguration.java +++ b/spring-test/src/test/java/org/springframework/test/context/aot/samples/management/ManagementConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import org.springframework.aot.generate.GenerationContext; import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor; import org.springframework.beans.factory.aot.BeanRegistrationCode; +import org.springframework.beans.factory.support.RegisteredBean.InstantiationDescriptor; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.aot.ApplicationContextAotGenerator; @@ -44,7 +45,8 @@ public class ManagementConfiguration { @Bean static BeanRegistrationAotProcessor beanRegistrationAotProcessor() { return registeredBean -> { - Executable factoryMethod = registeredBean.resolveConstructorOrFactoryMethod(); + InstantiationDescriptor instantiationDescriptor = registeredBean.resolveInstantiationDescriptor(); + Executable factoryMethod = instantiationDescriptor.executable(); // Make AOT contribution for @Managed @Bean methods. if (AnnotatedElementUtils.hasAnnotation(factoryMethod, Managed.class)) { return new AotContribution(createManagementContext());