From 66a571fe276eb49727b7d1bacc4da92548dd3573 Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Mon, 11 Sep 2023 09:47:44 +0200 Subject: [PATCH] Make constructorOrFactory method resolution optional This commit allows a custom code fragment to provide the code to create a bean without relying on ConstructorResolver. This is especially important for use cases that derive from the default behaviour and provide an instance supplier with the regular runtime scenario. This is a breaking change for code fragments providing a custom implementation of the related methods. As it turns out, almost all of them did not need the Executable argument. Configuration class parsing is the exception, where it needs to provide a different constructor in the case of the proxy. To make this use case possible, InstanceSupplierCodeGenerator has been made public. Closes gh-31117 --- ...opedProxyBeanRegistrationAotProcessor.java | 7 +- .../aot/BeanDefinitionMethodGenerator.java | 66 +------ .../aot/BeanRegistrationCodeFragments.java | 36 ++-- ...eanRegistrationCodeFragmentsDecorator.java | 12 +- .../aot/BeanRegistrationCodeGenerator.java | 12 +- .../DefaultBeanRegistrationCodeFragments.java | 17 +- .../aot/InstanceSupplierCodeGenerator.java | 83 ++++++++- ...ultBeanRegistrationCodeFragmentsTests.java | 166 +++++++++++++----- .../ConfigurationClassPostProcessor.java | 22 ++- ...agedTypesBeanRegistrationAotProcessor.java | 2 - 10 files changed, 260 insertions(+), 163 deletions(-) diff --git a/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java b/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java index dbf2df4fa19..63f40110c4b 100644 --- a/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java +++ b/spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java @@ -16,7 +16,6 @@ package org.springframework.aop.scope; -import java.lang.reflect.Executable; import java.util.function.Predicate; import javax.lang.model.element.Modifier; @@ -109,7 +108,7 @@ class ScopedProxyBeanRegistrationAotProcessor implements BeanRegistrationAotProc } @Override - public ClassName getTarget(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { + public ClassName getTarget(RegisteredBean registeredBean) { return ClassName.get(this.targetBeanDefinition.getResolvableType().toClass()); } @@ -139,9 +138,7 @@ class ScopedProxyBeanRegistrationAotProcessor implements BeanRegistrationAotProc @Override public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, - BeanRegistrationCode beanRegistrationCode, - Executable constructorOrFactoryMethod, - boolean allowDirectSupplierShortcut) { + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { GeneratedMethod generatedMethod = beanRegistrationCode.getMethods() .add("getScopedProxyInstance", method -> { diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java index 68a06fd7dbf..6986183d14f 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java @@ -16,10 +16,6 @@ package org.springframework.beans.factory.aot; -import java.lang.reflect.Constructor; -import java.lang.reflect.Executable; -import java.lang.reflect.Method; -import java.lang.reflect.Proxy; import java.util.List; import javax.lang.model.element.Modifier; @@ -29,14 +25,9 @@ import org.springframework.aot.generate.GeneratedMethod; import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; -import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.factory.config.BeanDefinition; -import org.springframework.beans.factory.config.DependencyDescriptor; -import org.springframework.beans.factory.support.AutowireCandidateResolver; -import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; -import org.springframework.core.MethodParameter; import org.springframework.javapoet.ClassName; import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; @@ -56,8 +47,6 @@ class BeanDefinitionMethodGenerator { private final RegisteredBean registeredBean; - private final Executable constructorOrFactoryMethod; - @Nullable private final String currentPropertyName; @@ -83,7 +72,6 @@ class BeanDefinitionMethodGenerator { } this.methodGeneratorFactory = methodGeneratorFactory; this.registeredBean = registeredBean; - this.constructorOrFactoryMethod = registeredBean.resolveConstructorOrFactoryMethod(); this.currentPropertyName = currentPropertyName; this.aotContributions = aotContributions; } @@ -98,9 +86,8 @@ class BeanDefinitionMethodGenerator { MethodReference generateBeanDefinitionMethod(GenerationContext generationContext, BeanRegistrationsCode beanRegistrationsCode) { - registerRuntimeHintsIfNecessary(generationContext.getRuntimeHints()); BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext, beanRegistrationsCode); - ClassName target = codeFragments.getTarget(this.registeredBean, this.constructorOrFactoryMethod); + ClassName target = codeFragments.getTarget(this.registeredBean); if (isWritablePackageName(target)) { GeneratedClass generatedClass = lookupGeneratedClass(generationContext, target); GeneratedMethods generatedMethods = generatedClass.getMethods().withPrefix(getName()); @@ -178,8 +165,7 @@ class BeanDefinitionMethodGenerator { BeanRegistrationCodeFragments codeFragments, Modifier modifier) { BeanRegistrationCodeGenerator codeGenerator = new BeanRegistrationCodeGenerator( - className, generatedMethods, this.registeredBean, - this.constructorOrFactoryMethod, codeFragments); + className, generatedMethods, this.registeredBean, codeFragments); this.aotContributions.forEach(aotContribution -> aotContribution.applyTo(generationContext, codeGenerator)); @@ -218,52 +204,4 @@ class BeanDefinitionMethodGenerator { return StringUtils.uncapitalize(beanName); } - private void registerRuntimeHintsIfNecessary(RuntimeHints runtimeHints) { - if (this.registeredBean.getBeanFactory() instanceof DefaultListableBeanFactory dlbf) { - ProxyRuntimeHintsRegistrar registrar = new ProxyRuntimeHintsRegistrar(dlbf.getAutowireCandidateResolver()); - if (this.constructorOrFactoryMethod instanceof Method method) { - registrar.registerRuntimeHints(runtimeHints, method); - } - else if (this.constructorOrFactoryMethod instanceof Constructor constructor) { - registrar.registerRuntimeHints(runtimeHints, constructor); - } - } - } - - - private static class ProxyRuntimeHintsRegistrar { - - private final AutowireCandidateResolver candidateResolver; - - public ProxyRuntimeHintsRegistrar(AutowireCandidateResolver candidateResolver) { - this.candidateResolver = candidateResolver; - } - - public void registerRuntimeHints(RuntimeHints runtimeHints, Method method) { - Class[] parameterTypes = method.getParameterTypes(); - for (int i = 0; i < parameterTypes.length; i++) { - MethodParameter methodParam = new MethodParameter(method, i); - DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(methodParam, true); - registerProxyIfNecessary(runtimeHints, dependencyDescriptor); - } - } - - public void registerRuntimeHints(RuntimeHints runtimeHints, Constructor constructor) { - Class[] parameterTypes = constructor.getParameterTypes(); - for (int i = 0; i < parameterTypes.length; i++) { - MethodParameter methodParam = new MethodParameter(constructor, i); - DependencyDescriptor dependencyDescriptor = new DependencyDescriptor( - methodParam, true); - registerProxyIfNecessary(runtimeHints, dependencyDescriptor); - } - } - - private void registerProxyIfNecessary(RuntimeHints runtimeHints, DependencyDescriptor dependencyDescriptor) { - Class proxyType = this.candidateResolver.getLazyResolutionProxyClass(dependencyDescriptor, null); - if (proxyType != null && Proxy.isProxyClass(proxyType)) { - runtimeHints.proxies().registerJdkProxy(proxyType.getInterfaces()); - } - } - } - } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java index d7f02a43e5e..db1bd2e8155 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,9 @@ package org.springframework.beans.factory.aot; -import java.lang.reflect.Executable; import java.util.List; import java.util.function.Predicate; +import java.util.function.UnaryOperator; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; @@ -31,9 +31,19 @@ import org.springframework.javapoet.CodeBlock; /** * Generate the various fragments of code needed to register a bean. - * + *

+ * A default implementation is provided that suits most needs and custom code + * fragments are only expected to be used by library authors having built custom + * arrangement on top of the core container. + *

+ * Users are not expected to implement this interface directly, but rather extends + * from {@link BeanRegistrationCodeFragmentsDecorator} and only override the + * necessary method(s). * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 + * @see BeanRegistrationCodeFragmentsDecorator + * @see BeanRegistrationAotContribution#withCustomCodeFragments(UnaryOperator) */ public interface BeanRegistrationCodeFragments { @@ -50,16 +60,19 @@ public interface BeanRegistrationCodeFragments { /** * Return the target for the registration. Used to determine where to write - * the code. + * the code. This should take into account visibility issue, such as + * package access of an element of the bean to register. * @param registeredBean the registered bean - * @param constructorOrFactoryMethod the constructor or factory method * @return the target {@link ClassName} */ - ClassName getTarget(RegisteredBean registeredBean, - Executable constructorOrFactoryMethod); + ClassName getTarget(RegisteredBean registeredBean); /** * Generate the code that defines the new bean definition instance. + *

+ * This should declare a variable named {@value BEAN_DEFINITION_VARIABLE} + * so that further fragments can refer to the variable to further tune + * the bean definition. * @param generationContext the generation context * @param beanType the bean type * @param beanRegistrationCode the bean registration code @@ -81,6 +94,11 @@ public interface BeanRegistrationCodeFragments { /** * Generate the code that sets the instance supplier on the bean definition. + *

+ * The {@code postProcessors} represent methods to be exposed once the + * instance has been created to further configure it. Each method should + * accept two parameters, the {@link RegisteredBean} and the bean + * instance, and should return the modified bean instance. * @param generationContext the generation context * @param beanRegistrationCode the bean registration code * @param instanceSupplierCode the instance supplier code supplier code @@ -96,15 +114,13 @@ public interface BeanRegistrationCodeFragments { * Generate the instance supplier code. * @param generationContext the generation context * @param beanRegistrationCode the bean registration code - * @param constructorOrFactoryMethod the constructor or factory method for - * the bean * @param allowDirectSupplierShortcut if direct suppliers may be used rather * than always needing an {@link InstanceSupplier} * @return the generated code */ CodeBlock generateInstanceSupplierCode( GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode, - Executable constructorOrFactoryMethod, boolean allowDirectSupplierShortcut); + boolean allowDirectSupplierShortcut); /** * Generate the return statement. diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsDecorator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsDecorator.java index e4ff961262e..4a493d0d939 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsDecorator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsDecorator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package org.springframework.beans.factory.aot; -import java.lang.reflect.Executable; import java.util.List; import java.util.function.Predicate; import java.util.function.UnaryOperator; @@ -51,8 +50,8 @@ public class BeanRegistrationCodeFragmentsDecorator implements BeanRegistrationC } @Override - public ClassName getTarget(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { - return this.delegate.getTarget(registeredBean, constructorOrFactoryMethod); + public ClassName getTarget(RegisteredBean registeredBean) { + return this.delegate.getTarget(registeredBean); } @Override @@ -83,11 +82,10 @@ public class BeanRegistrationCodeFragmentsDecorator implements BeanRegistrationC @Override public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, - BeanRegistrationCode beanRegistrationCode, Executable constructorOrFactoryMethod, - boolean allowDirectSupplierShortcut) { + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { return this.delegate.generateInstanceSupplierCode(generationContext, - beanRegistrationCode, constructorOrFactoryMethod, allowDirectSupplierShortcut); + beanRegistrationCode, allowDirectSupplierShortcut); } @Override diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java index 3547378b067..98564d4852e 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package org.springframework.beans.factory.aot; -import java.lang.reflect.Executable; import java.util.ArrayList; import java.util.List; import java.util.function.Predicate; @@ -47,19 +46,15 @@ class BeanRegistrationCodeGenerator implements BeanRegistrationCode { private final RegisteredBean registeredBean; - private final Executable constructorOrFactoryMethod; - private final BeanRegistrationCodeFragments codeFragments; BeanRegistrationCodeGenerator(ClassName className, GeneratedMethods generatedMethods, - RegisteredBean registeredBean, Executable constructorOrFactoryMethod, - BeanRegistrationCodeFragments codeFragments) { + RegisteredBean registeredBean, BeanRegistrationCodeFragments codeFragments) { this.className = className; this.generatedMethods = generatedMethods; this.registeredBean = registeredBean; - this.constructorOrFactoryMethod = constructorOrFactoryMethod; this.codeFragments = codeFragments; } @@ -87,8 +82,7 @@ class BeanRegistrationCodeGenerator implements BeanRegistrationCode { generationContext, this, this.registeredBean.getMergedBeanDefinition(), REJECT_ALL_ATTRIBUTES_FILTER)); CodeBlock instanceSupplierCode = this.codeFragments.generateInstanceSupplierCode( - generationContext, this, this.constructorOrFactoryMethod, - this.instancePostProcessors.isEmpty()); + generationContext, this, this.instancePostProcessors.isEmpty()); code.add(this.codeFragments.generateSetBeanInstanceSupplierCode(generationContext, this, instanceSupplierCode, this.instancePostProcessors)); code.add(this.codeFragments.generateReturnCode(generationContext, this)); diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index 26df73d45ee..e51aebd7f02 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -21,6 +21,7 @@ import java.lang.reflect.Executable; import java.lang.reflect.Modifier; import java.util.List; import java.util.function.Predicate; +import java.util.function.Supplier; import org.springframework.aot.generate.AccessControl; import org.springframework.aot.generate.GenerationContext; @@ -39,12 +40,14 @@ import org.springframework.javapoet.ParameterizedTypeName; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.function.SingletonSupplier; /** * Internal {@link BeanRegistrationCodeFragments} implementation used by * default. * * @author Phillip Webb + * @author Stephane Nicoll */ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragments { @@ -54,6 +57,8 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory; + private final Supplier constructorOrFactoryMethod; + DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode, RegisteredBean registeredBean, @@ -62,14 +67,13 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme this.beanRegistrationsCode = beanRegistrationsCode; this.registeredBean = registeredBean; this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory; + this.constructorOrFactoryMethod = SingletonSupplier.of(registeredBean::resolveConstructorOrFactoryMethod); } @Override - public ClassName getTarget(RegisteredBean registeredBean, - Executable constructorOrFactoryMethod) { - - Class target = extractDeclaringClass(registeredBean.getBeanType(), constructorOrFactoryMethod); + public ClassName getTarget(RegisteredBean registeredBean) { + Class target = extractDeclaringClass(registeredBean.getBeanType(), this.constructorOrFactoryMethod.get()); while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) { RegisteredBean parent = registeredBean.getParent(); Assert.state(parent != null, "No parent available for inner bean"); @@ -219,12 +223,11 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme @Override public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, - BeanRegistrationCode beanRegistrationCode, - Executable constructorOrFactoryMethod, boolean allowDirectSupplierShortcut) { + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { return new InstanceSupplierCodeGenerator(generationContext, beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) - .generateCode(this.registeredBean, constructorOrFactoryMethod); + .generateCode(this.registeredBean,this.constructorOrFactoryMethod.get()); } @Override diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java index e21ead29bf5..87de98852ec 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java @@ -21,6 +21,7 @@ import java.lang.reflect.Executable; import java.lang.reflect.Member; import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.lang.reflect.Proxy; import java.util.Arrays; import java.util.function.Consumer; @@ -38,9 +39,14 @@ import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator; import org.springframework.aot.hint.ExecutableMode; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.ReflectionHints; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.beans.factory.config.DependencyDescriptor; +import org.springframework.beans.factory.support.AutowireCandidateResolver; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.core.KotlinDetector; +import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; @@ -51,9 +57,10 @@ import org.springframework.util.ClassUtils; import org.springframework.util.function.ThrowingSupplier; /** - * Internal code generator to create an {@link InstanceSupplier}, usually in + * Default code generator to create an {@link InstanceSupplier}, usually in * the form of a {@link BeanInstanceSupplier} that retains the executable - * that is used to instantiate the bean. + * that is used to instantiate the bean. Takes care of registering the + * necessary hints if reflection or a JDK proxy is required. * *

Generated code is usually a method reference that generates the * {@link BeanInstanceSupplier}, but some shortcut can be used as well such as: @@ -66,8 +73,9 @@ import org.springframework.util.function.ThrowingSupplier; * @author Juergen Hoeller * @author Sebastien Deleuze * @since 6.0 + * @see BeanRegistrationCodeFragments */ -class InstanceSupplierCodeGenerator { +public class InstanceSupplierCodeGenerator { private static final String REGISTERED_BEAN_PARAMETER_NAME = "registeredBean"; @@ -89,7 +97,15 @@ class InstanceSupplierCodeGenerator { private final boolean allowDirectSupplierShortcut; - InstanceSupplierCodeGenerator(GenerationContext generationContext, + /** + * Create a new instance. + * @param generationContext the generation context + * @param className the class name of the bean to instantiate + * @param generatedMethods the generated methods + * @param allowDirectSupplierShortcut whether a direct supplier may be used rather + * than always needing an {@link InstanceSupplier} + */ + public InstanceSupplierCodeGenerator(GenerationContext generationContext, ClassName className, GeneratedMethods generatedMethods, boolean allowDirectSupplierShortcut) { this.generationContext = generationContext; @@ -98,8 +114,14 @@ class InstanceSupplierCodeGenerator { this.allowDirectSupplierShortcut = allowDirectSupplierShortcut; } - - CodeBlock generateCode(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { + /** + * Generate the instance supplier code. + * @param registeredBean the bean to handle + * @param constructorOrFactoryMethod the executable to use to create the bean + * @return the generated code + */ + public CodeBlock generateCode(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { + registerRuntimeHintsIfNecessary(registeredBean, constructorOrFactoryMethod); if (constructorOrFactoryMethod instanceof Constructor constructor) { return generateCodeForConstructor(registeredBean, constructor); } @@ -110,6 +132,19 @@ class InstanceSupplierCodeGenerator { "No suitable executor found for " + registeredBean.getBeanName()); } + private void registerRuntimeHintsIfNecessary(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { + if (registeredBean.getBeanFactory() instanceof DefaultListableBeanFactory dlbf) { + RuntimeHints runtimeHints = this.generationContext.getRuntimeHints(); + ProxyRuntimeHintsRegistrar registrar = new ProxyRuntimeHintsRegistrar(dlbf.getAutowireCandidateResolver()); + if (constructorOrFactoryMethod instanceof Method method) { + registrar.registerRuntimeHints(runtimeHints, method); + } + else if (constructorOrFactoryMethod instanceof Constructor constructor) { + registrar.registerRuntimeHints(runtimeHints, constructor); + } + } + } + private CodeBlock generateCodeForConstructor(RegisteredBean registeredBean, Constructor constructor) { String beanName = registeredBean.getBeanName(); Class beanClass = registeredBean.getBeanClass(); @@ -372,4 +407,40 @@ class InstanceSupplierCodeGenerator { } + + private static class ProxyRuntimeHintsRegistrar { + + private final AutowireCandidateResolver candidateResolver; + + public ProxyRuntimeHintsRegistrar(AutowireCandidateResolver candidateResolver) { + this.candidateResolver = candidateResolver; + } + + public void registerRuntimeHints(RuntimeHints runtimeHints, Method method) { + Class[] parameterTypes = method.getParameterTypes(); + for (int i = 0; i < parameterTypes.length; i++) { + MethodParameter methodParam = new MethodParameter(method, i); + DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(methodParam, true); + registerProxyIfNecessary(runtimeHints, dependencyDescriptor); + } + } + + public void registerRuntimeHints(RuntimeHints runtimeHints, Constructor constructor) { + Class[] parameterTypes = constructor.getParameterTypes(); + for (int i = 0; i < parameterTypes.length; i++) { + MethodParameter methodParam = new MethodParameter(constructor, i); + DependencyDescriptor dependencyDescriptor = new DependencyDescriptor( + methodParam, true); + registerProxyIfNecessary(runtimeHints, dependencyDescriptor); + } + } + + private void registerProxyIfNecessary(RuntimeHints runtimeHints, DependencyDescriptor dependencyDescriptor) { + Class proxyType = this.candidateResolver.getLazyResolutionProxyClass(dependencyDescriptor, null); + if (proxyType != null && Proxy.isProxyClass(proxyType)) { + runtimeHints.proxies().registerJdkProxy(proxyType.getInterfaces()); + } + } + } + } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java index 1ab505e3d35..859cbe62b66 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java @@ -16,10 +16,14 @@ package org.springframework.beans.factory.aot; +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; import java.lang.reflect.Method; +import java.util.function.UnaryOperator; import org.junit.jupiter.api.Test; +import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.annotation.InjectAnnotationBeanPostProcessorTests.StringFactoryBean; @@ -28,6 +32,7 @@ import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.factory.DummyFactory; import org.springframework.beans.testfixture.beans.factory.aot.GenericFactoryBean; +import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationCode; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; import org.springframework.beans.testfixture.beans.factory.aot.NumberFactoryBean; import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean; @@ -35,9 +40,14 @@ import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanConfigu import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanFactoryBean; import org.springframework.core.ResolvableType; import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.lang.Nullable; import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; /** * Tests for {@link DefaultBeanRegistrationCodeFragments}. @@ -48,136 +58,202 @@ class DefaultBeanRegistrationCodeFragmentsTests { private final BeanRegistrationsCode beanRegistrationsCode = new MockBeanRegistrationsCode(new TestGenerationContext()); + private final GenerationContext generationContext = new TestGenerationContext(); + private final DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); @Test void getTargetOnConstructor() { - RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - SimpleBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + RegisteredBean registeredBean = registerTestBean(SimpleBean.class, + SimpleBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class); } @Test void getTargetOnConstructorToPublicFactoryBean() { - RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + RegisteredBean registeredBean = registerTestBean(SimpleBean.class, + SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class); } @Test void getTargetOnConstructorToPublicGenericFactoryBeanExtractTargetFromFactoryBeanType() { - RegisteredBean registeredBean = registerTestBean(ResolvableType - .forClassWithGenerics(GenericFactoryBean.class, SimpleBean.class)); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - GenericFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + ResolvableType beanType = ResolvableType.forClassWithGenerics( + GenericFactoryBean.class, SimpleBean.class); + RegisteredBean registeredBean = registerTestBean(beanType, + GenericFactoryBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class); } @Test void getTargetOnConstructorToPublicGenericFactoryBeanWithBoundExtractTargetFromFactoryBeanType() { - RegisteredBean registeredBean = registerTestBean(ResolvableType - .forClassWithGenerics(NumberFactoryBean.class, Integer.class)); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - NumberFactoryBean.class.getDeclaredConstructors()[0]), Integer.class); + ResolvableType beanType = ResolvableType.forClassWithGenerics( + NumberFactoryBean.class, Integer.class); + RegisteredBean registeredBean = registerTestBean(beanType, + NumberFactoryBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), Integer.class); } @Test void getTargetOnConstructorToPublicGenericFactoryBeanUseBeanTypeAsFallback() { - RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - GenericFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + RegisteredBean registeredBean = registerTestBean(SimpleBean.class, + GenericFactoryBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class); } @Test void getTargetOnConstructorToProtectedFactoryBean() { - RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, - PrivilegedTestBeanFactoryBean.class.getDeclaredConstructors()[0]), + RegisteredBean registeredBean = registerTestBean(SimpleBean.class, + PrivilegedTestBeanFactoryBean.class.getDeclaredConstructors()[0]); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), PrivilegedTestBeanFactoryBean.class); } @Test void getTargetOnMethod() { - RegisteredBean registeredBean = registerTestBean(SimpleBean.class); Method method = ReflectionUtils.findMethod(SimpleBeanConfiguration.class, "simpleBean"); assertThat(method).isNotNull(); - assertTarget(createInstance(registeredBean).getTarget(registeredBean, method), + RegisteredBean registeredBean = registerTestBean(SimpleBean.class, method); + assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBeanConfiguration.class); } @Test void getTargetOnMethodWithInnerBeanInJavaPackage() { RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", - new RootBeanDefinition(String.class)); Method method = ReflectionUtils.findMethod(getClass(), "createString"); assertThat(method).isNotNull(); - assertTarget(createInstance(innerBean).getTarget(innerBean, method), getClass()); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", + applyConstructorOrFactoryMethod(new RootBeanDefinition(String.class), method)); + assertTarget(createInstance(innerBean).getTarget(innerBean), getClass()); } @Test void getTargetOnConstructorWithInnerBeanInJavaPackage() { RegisteredBean registeredBean = registerTestBean(SimpleBean.class); - RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class)); - assertTarget(createInstance(innerBean).getTarget(innerBean, - String.class.getDeclaredConstructors()[0]), SimpleBean.class); + RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod( + new RootBeanDefinition(String.class), String.class.getDeclaredConstructors()[0]); + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", + innerBeanDefinition); + assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class); } @Test void getTargetOnConstructorWithInnerBeanOnTypeInJavaPackage() { RegisteredBean registeredBean = registerTestBean(SimpleBean.class); + RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod( + new RootBeanDefinition(StringFactoryBean.class), + StringFactoryBean.class.getDeclaredConstructors()[0]); RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", - new RootBeanDefinition(StringFactoryBean.class)); - assertTarget(createInstance(innerBean).getTarget(innerBean, - StringFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + innerBeanDefinition); + assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class); } @Test void getTargetOnMethodWithInnerBeanInRegularPackage() { RegisteredBean registeredBean = registerTestBean(DummyFactory.class); - RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", - new RootBeanDefinition(SimpleBean.class)); Method method = ReflectionUtils.findMethod(SimpleBeanConfiguration.class, "simpleBean"); assertThat(method).isNotNull(); - assertTarget(createInstance(innerBean).getTarget(innerBean, method), + RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", + applyConstructorOrFactoryMethod(new RootBeanDefinition(SimpleBean.class), method)); + assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBeanConfiguration.class); } @Test void getTargetOnConstructorWithInnerBeanInRegularPackage() { RegisteredBean registeredBean = registerTestBean(DummyFactory.class); + RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod( + new RootBeanDefinition(SimpleBean.class), SimpleBean.class.getDeclaredConstructors()[0]); RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", - new RootBeanDefinition(SimpleBean.class)); - assertTarget(createInstance(innerBean).getTarget(innerBean, - SimpleBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + innerBeanDefinition); + assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class); } @Test void getTargetOnConstructorWithInnerBeanOnFactoryBeanOnTypeInRegularPackage() { RegisteredBean registeredBean = registerTestBean(DummyFactory.class); + RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod( + new RootBeanDefinition(SimpleBean.class), + SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]); RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", - new RootBeanDefinition(SimpleBean.class)); - assertTarget(createInstance(innerBean).getTarget(innerBean, - SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); + innerBeanDefinition); + assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class); + } + + @Test + void customizedGetTargetDoesNotResolveConstructorOrFactoryMethod() { + RegisteredBean registeredBean = spy(registerTestBean(SimpleBean.class)); + BeanRegistrationCodeFragments customCodeFragments = createCustomCodeFragments(registeredBean, codeFragments -> new BeanRegistrationCodeFragmentsDecorator(codeFragments) { + @Override + public ClassName getTarget(RegisteredBean registeredBean) { + return ClassName.get(String.class); + } + }); + assertTarget(customCodeFragments.getTarget(registeredBean), String.class); + verify(registeredBean, never()).resolveConstructorOrFactoryMethod(); + } + + @Test + void customizedGenerateInstanceSupplierCodeDoesNotResolveConstructorOrFactoryMethod() { + RegisteredBean registeredBean = spy(registerTestBean(SimpleBean.class)); + BeanRegistrationCodeFragments customCodeFragments = createCustomCodeFragments(registeredBean, codeFragments -> new BeanRegistrationCodeFragmentsDecorator(codeFragments) { + @Override + public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { + return CodeBlock.of("// Hello"); + } + }); + assertThat(customCodeFragments.generateInstanceSupplierCode(this.generationContext, + new MockBeanRegistrationCode(this.generationContext), false)).hasToString("// Hello"); + verify(registeredBean, never()).resolveConstructorOrFactoryMethod(); + } + + private BeanRegistrationCodeFragments createCustomCodeFragments(RegisteredBean registeredBean, UnaryOperator customFragments) { + BeanRegistrationAotContribution aotContribution = BeanRegistrationAotContribution. + withCustomCodeFragments(customFragments); + BeanRegistrationCodeFragments defaultCodeFragments = createInstance(registeredBean); + return aotContribution.customizeBeanRegistrationCodeFragments( + this.generationContext, defaultCodeFragments); } private void assertTarget(ClassName target, Class expected) { assertThat(target).isEqualTo(ClassName.get(expected)); } - private RegisteredBean registerTestBean(Class beanType) { - this.beanFactory.registerBeanDefinition("testBean", - new RootBeanDefinition(beanType)); + return registerTestBean(beanType, null); + } + + private RegisteredBean registerTestBean(Class beanType, + @Nullable Executable constructorOrFactoryMethod) { + this.beanFactory.registerBeanDefinition("testBean", applyConstructorOrFactoryMethod( + new RootBeanDefinition(beanType), constructorOrFactoryMethod)); return RegisteredBean.of(this.beanFactory, "testBean"); } - private RegisteredBean registerTestBean(ResolvableType beanType) { + + private RegisteredBean registerTestBean(ResolvableType beanType, + @Nullable Executable constructorOrFactoryMethod) { RootBeanDefinition beanDefinition = new RootBeanDefinition(); beanDefinition.setTargetType(beanType); - this.beanFactory.registerBeanDefinition("testBean", beanDefinition); + this.beanFactory.registerBeanDefinition("testBean", + applyConstructorOrFactoryMethod(beanDefinition, constructorOrFactoryMethod)); return RegisteredBean.of(this.beanFactory, "testBean"); } + private RootBeanDefinition applyConstructorOrFactoryMethod(RootBeanDefinition beanDefinition, + @Nullable Executable constructorOrFactoryMethod) { + + if (constructorOrFactoryMethod instanceof Method method) { + beanDefinition.setResolvedFactoryMethod(method); + } + else if (constructorOrFactoryMethod instanceof Constructor constructor) { + beanDefinition.setAttribute(RootBeanDefinition.PREFERRED_CONSTRUCTORS_ATTRIBUTE, constructor); + } + return beanDefinition; + } + private BeanRegistrationCodeFragments createInstance(RegisteredBean registeredBean) { return new DefaultBeanRegistrationCodeFragments(this.beanRegistrationsCode, registeredBean, new BeanDefinitionMethodGeneratorFactory(this.beanFactory)); diff --git a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java index aefd1f7ae10..87ad3b39aed 100644 --- a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java +++ b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java @@ -58,6 +58,7 @@ import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor; import org.springframework.beans.factory.aot.BeanRegistrationCode; import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments; import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator; +import org.springframework.beans.factory.aot.InstanceSupplierCodeGenerator; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinitionHolder; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; @@ -315,9 +316,8 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo Object configClassAttr = registeredBean.getMergedBeanDefinition() .getAttribute(ConfigurationClassUtils.CONFIGURATION_CLASS_ATTRIBUTE); if (ConfigurationClassUtils.CONFIGURATION_CLASS_FULL.equals(configClassAttr)) { - Class proxyClass = registeredBean.getBeanType().toClass(); return BeanRegistrationAotContribution.withCustomCodeFragments(codeFragments -> - new ConfigurationClassProxyBeanRegistrationCodeFragments(codeFragments, proxyClass)); + new ConfigurationClassProxyBeanRegistrationCodeFragments(codeFragments, registeredBean)); } return null; } @@ -749,12 +749,15 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo private static class ConfigurationClassProxyBeanRegistrationCodeFragments extends BeanRegistrationCodeFragmentsDecorator { + private final RegisteredBean registeredBean; + private final Class proxyClass; public ConfigurationClassProxyBeanRegistrationCodeFragments(BeanRegistrationCodeFragments codeFragments, - Class proxyClass) { + RegisteredBean registeredBean) { super(codeFragments); - this.proxyClass = proxyClass; + this.registeredBean = registeredBean; + this.proxyClass = registeredBean.getBeanType().toClass(); } @Override @@ -770,11 +773,14 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo @Override public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, - BeanRegistrationCode beanRegistrationCode, Executable constructorOrFactoryMethod, + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { - Executable executableToUse = proxyExecutable(generationContext.getRuntimeHints(), constructorOrFactoryMethod); - return super.generateInstanceSupplierCode(generationContext, beanRegistrationCode, - executableToUse, allowDirectSupplierShortcut); + + Executable executableToUse = proxyExecutable(generationContext.getRuntimeHints(), + this.registeredBean.resolveConstructorOrFactoryMethod()); + return new InstanceSupplierCodeGenerator(generationContext, + beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) + .generateCode(this.registeredBean, executableToUse); } private Executable proxyExecutable(RuntimeHints runtimeHints, Executable userExecutable) { diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java index c1a5cf8eae2..2a6d68a2b61 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java @@ -17,7 +17,6 @@ package org.springframework.orm.jpa.persistenceunit; import java.lang.annotation.Annotation; -import java.lang.reflect.Executable; import java.util.List; import javax.lang.model.element.Modifier; @@ -97,7 +96,6 @@ class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistr @Override public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode, - Executable constructorOrFactoryMethod, boolean allowDirectSupplierShortcut) { PersistenceManagedTypes persistenceManagedTypes = this.registeredBean.getBeanFactory() .getBean(this.registeredBean.getBeanName(), PersistenceManagedTypes.class);