Browse Source

Make constructorOrFactory method resolution optional

This commit allows a custom code fragment to provide the code to
create a bean without relying on ConstructorResolver. This is especially
important for use cases that derive from the default behaviour and
provide an instance supplier with the regular runtime scenario.

This is a breaking change for code fragments providing a custom
implementation of the related methods. As it turns out, almost all of
them did not need the Executable argument. Configuration class parsing
is the exception, where it needs to provide a different constructor in
the case of the proxy. To make this use case possible,
InstanceSupplierCodeGenerator has been made public.

Closes gh-31117
pull/31206/head
Stephane Nicoll 2 years ago
parent
commit
66a571fe27
  1. 7
      spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java
  2. 66
      spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java
  3. 36
      spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java
  4. 12
      spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsDecorator.java
  5. 12
      spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java
  6. 17
      spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java
  7. 83
      spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java
  8. 166
      spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java
  9. 22
      spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java
  10. 2
      spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java

7
spring-aop/src/main/java/org/springframework/aop/scope/ScopedProxyBeanRegistrationAotProcessor.java

@ -16,7 +16,6 @@
package org.springframework.aop.scope; package org.springframework.aop.scope;
import java.lang.reflect.Executable;
import java.util.function.Predicate; import java.util.function.Predicate;
import javax.lang.model.element.Modifier; import javax.lang.model.element.Modifier;
@ -109,7 +108,7 @@ class ScopedProxyBeanRegistrationAotProcessor implements BeanRegistrationAotProc
} }
@Override @Override
public ClassName getTarget(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { public ClassName getTarget(RegisteredBean registeredBean) {
return ClassName.get(this.targetBeanDefinition.getResolvableType().toClass()); return ClassName.get(this.targetBeanDefinition.getResolvableType().toClass());
} }
@ -139,9 +138,7 @@ class ScopedProxyBeanRegistrationAotProcessor implements BeanRegistrationAotProc
@Override @Override
public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext,
BeanRegistrationCode beanRegistrationCode, BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) {
Executable constructorOrFactoryMethod,
boolean allowDirectSupplierShortcut) {
GeneratedMethod generatedMethod = beanRegistrationCode.getMethods() GeneratedMethod generatedMethod = beanRegistrationCode.getMethods()
.add("getScopedProxyInstance", method -> { .add("getScopedProxyInstance", method -> {

66
spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java

@ -16,10 +16,6 @@
package org.springframework.beans.factory.aot; package org.springframework.beans.factory.aot;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.List; import java.util.List;
import javax.lang.model.element.Modifier; import javax.lang.model.element.Modifier;
@ -29,14 +25,9 @@ import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GeneratedMethods; import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.DependencyDescriptor;
import org.springframework.beans.factory.support.AutowireCandidateResolver;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.MethodParameter;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -56,8 +47,6 @@ class BeanDefinitionMethodGenerator {
private final RegisteredBean registeredBean; private final RegisteredBean registeredBean;
private final Executable constructorOrFactoryMethod;
@Nullable @Nullable
private final String currentPropertyName; private final String currentPropertyName;
@ -83,7 +72,6 @@ class BeanDefinitionMethodGenerator {
} }
this.methodGeneratorFactory = methodGeneratorFactory; this.methodGeneratorFactory = methodGeneratorFactory;
this.registeredBean = registeredBean; this.registeredBean = registeredBean;
this.constructorOrFactoryMethod = registeredBean.resolveConstructorOrFactoryMethod();
this.currentPropertyName = currentPropertyName; this.currentPropertyName = currentPropertyName;
this.aotContributions = aotContributions; this.aotContributions = aotContributions;
} }
@ -98,9 +86,8 @@ class BeanDefinitionMethodGenerator {
MethodReference generateBeanDefinitionMethod(GenerationContext generationContext, MethodReference generateBeanDefinitionMethod(GenerationContext generationContext,
BeanRegistrationsCode beanRegistrationsCode) { BeanRegistrationsCode beanRegistrationsCode) {
registerRuntimeHintsIfNecessary(generationContext.getRuntimeHints());
BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext, beanRegistrationsCode); BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext, beanRegistrationsCode);
ClassName target = codeFragments.getTarget(this.registeredBean, this.constructorOrFactoryMethod); ClassName target = codeFragments.getTarget(this.registeredBean);
if (isWritablePackageName(target)) { if (isWritablePackageName(target)) {
GeneratedClass generatedClass = lookupGeneratedClass(generationContext, target); GeneratedClass generatedClass = lookupGeneratedClass(generationContext, target);
GeneratedMethods generatedMethods = generatedClass.getMethods().withPrefix(getName()); GeneratedMethods generatedMethods = generatedClass.getMethods().withPrefix(getName());
@ -178,8 +165,7 @@ class BeanDefinitionMethodGenerator {
BeanRegistrationCodeFragments codeFragments, Modifier modifier) { BeanRegistrationCodeFragments codeFragments, Modifier modifier) {
BeanRegistrationCodeGenerator codeGenerator = new BeanRegistrationCodeGenerator( BeanRegistrationCodeGenerator codeGenerator = new BeanRegistrationCodeGenerator(
className, generatedMethods, this.registeredBean, className, generatedMethods, this.registeredBean, codeFragments);
this.constructorOrFactoryMethod, codeFragments);
this.aotContributions.forEach(aotContribution -> aotContribution.applyTo(generationContext, codeGenerator)); this.aotContributions.forEach(aotContribution -> aotContribution.applyTo(generationContext, codeGenerator));
@ -218,52 +204,4 @@ class BeanDefinitionMethodGenerator {
return StringUtils.uncapitalize(beanName); return StringUtils.uncapitalize(beanName);
} }
private void registerRuntimeHintsIfNecessary(RuntimeHints runtimeHints) {
if (this.registeredBean.getBeanFactory() instanceof DefaultListableBeanFactory dlbf) {
ProxyRuntimeHintsRegistrar registrar = new ProxyRuntimeHintsRegistrar(dlbf.getAutowireCandidateResolver());
if (this.constructorOrFactoryMethod instanceof Method method) {
registrar.registerRuntimeHints(runtimeHints, method);
}
else if (this.constructorOrFactoryMethod instanceof Constructor<?> constructor) {
registrar.registerRuntimeHints(runtimeHints, constructor);
}
}
}
private static class ProxyRuntimeHintsRegistrar {
private final AutowireCandidateResolver candidateResolver;
public ProxyRuntimeHintsRegistrar(AutowireCandidateResolver candidateResolver) {
this.candidateResolver = candidateResolver;
}
public void registerRuntimeHints(RuntimeHints runtimeHints, Method method) {
Class<?>[] parameterTypes = method.getParameterTypes();
for (int i = 0; i < parameterTypes.length; i++) {
MethodParameter methodParam = new MethodParameter(method, i);
DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(methodParam, true);
registerProxyIfNecessary(runtimeHints, dependencyDescriptor);
}
}
public void registerRuntimeHints(RuntimeHints runtimeHints, Constructor<?> constructor) {
Class<?>[] parameterTypes = constructor.getParameterTypes();
for (int i = 0; i < parameterTypes.length; i++) {
MethodParameter methodParam = new MethodParameter(constructor, i);
DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(
methodParam, true);
registerProxyIfNecessary(runtimeHints, dependencyDescriptor);
}
}
private void registerProxyIfNecessary(RuntimeHints runtimeHints, DependencyDescriptor dependencyDescriptor) {
Class<?> proxyType = this.candidateResolver.getLazyResolutionProxyClass(dependencyDescriptor, null);
if (proxyType != null && Proxy.isProxyClass(proxyType)) {
runtimeHints.proxies().registerJdkProxy(proxyType.getInterfaces());
}
}
}
} }

36
spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragments.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,9 +16,9 @@
package org.springframework.beans.factory.aot; package org.springframework.beans.factory.aot;
import java.lang.reflect.Executable;
import java.util.List; import java.util.List;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.function.UnaryOperator;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generate.MethodReference;
@ -31,9 +31,19 @@ import org.springframework.javapoet.CodeBlock;
/** /**
* Generate the various fragments of code needed to register a bean. * Generate the various fragments of code needed to register a bean.
* * <p>
* A default implementation is provided that suits most needs and custom code
* fragments are only expected to be used by library authors having built custom
* arrangement on top of the core container.
* <p>
* Users are not expected to implement this interface directly, but rather extends
* from {@link BeanRegistrationCodeFragmentsDecorator} and only override the
* necessary method(s).
* @author Phillip Webb * @author Phillip Webb
* @author Stephane Nicoll
* @since 6.0 * @since 6.0
* @see BeanRegistrationCodeFragmentsDecorator
* @see BeanRegistrationAotContribution#withCustomCodeFragments(UnaryOperator)
*/ */
public interface BeanRegistrationCodeFragments { public interface BeanRegistrationCodeFragments {
@ -50,16 +60,19 @@ public interface BeanRegistrationCodeFragments {
/** /**
* Return the target for the registration. Used to determine where to write * Return the target for the registration. Used to determine where to write
* the code. * the code. This should take into account visibility issue, such as
* package access of an element of the bean to register.
* @param registeredBean the registered bean * @param registeredBean the registered bean
* @param constructorOrFactoryMethod the constructor or factory method
* @return the target {@link ClassName} * @return the target {@link ClassName}
*/ */
ClassName getTarget(RegisteredBean registeredBean, ClassName getTarget(RegisteredBean registeredBean);
Executable constructorOrFactoryMethod);
/** /**
* Generate the code that defines the new bean definition instance. * Generate the code that defines the new bean definition instance.
* <p>
* This should declare a variable named {@value BEAN_DEFINITION_VARIABLE}
* so that further fragments can refer to the variable to further tune
* the bean definition.
* @param generationContext the generation context * @param generationContext the generation context
* @param beanType the bean type * @param beanType the bean type
* @param beanRegistrationCode the bean registration code * @param beanRegistrationCode the bean registration code
@ -81,6 +94,11 @@ public interface BeanRegistrationCodeFragments {
/** /**
* Generate the code that sets the instance supplier on the bean definition. * Generate the code that sets the instance supplier on the bean definition.
* <p>
* The {@code postProcessors} represent methods to be exposed once the
* instance has been created to further configure it. Each method should
* accept two parameters, the {@link RegisteredBean} and the bean
* instance, and should return the modified bean instance.
* @param generationContext the generation context * @param generationContext the generation context
* @param beanRegistrationCode the bean registration code * @param beanRegistrationCode the bean registration code
* @param instanceSupplierCode the instance supplier code supplier code * @param instanceSupplierCode the instance supplier code supplier code
@ -96,15 +114,13 @@ public interface BeanRegistrationCodeFragments {
* Generate the instance supplier code. * Generate the instance supplier code.
* @param generationContext the generation context * @param generationContext the generation context
* @param beanRegistrationCode the bean registration code * @param beanRegistrationCode the bean registration code
* @param constructorOrFactoryMethod the constructor or factory method for
* the bean
* @param allowDirectSupplierShortcut if direct suppliers may be used rather * @param allowDirectSupplierShortcut if direct suppliers may be used rather
* than always needing an {@link InstanceSupplier} * than always needing an {@link InstanceSupplier}
* @return the generated code * @return the generated code
*/ */
CodeBlock generateInstanceSupplierCode( CodeBlock generateInstanceSupplierCode(
GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode, GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode,
Executable constructorOrFactoryMethod, boolean allowDirectSupplierShortcut); boolean allowDirectSupplierShortcut);
/** /**
* Generate the return statement. * Generate the return statement.

12
spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeFragmentsDecorator.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,7 +16,6 @@
package org.springframework.beans.factory.aot; package org.springframework.beans.factory.aot;
import java.lang.reflect.Executable;
import java.util.List; import java.util.List;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.function.UnaryOperator; import java.util.function.UnaryOperator;
@ -51,8 +50,8 @@ public class BeanRegistrationCodeFragmentsDecorator implements BeanRegistrationC
} }
@Override @Override
public ClassName getTarget(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { public ClassName getTarget(RegisteredBean registeredBean) {
return this.delegate.getTarget(registeredBean, constructorOrFactoryMethod); return this.delegate.getTarget(registeredBean);
} }
@Override @Override
@ -83,11 +82,10 @@ public class BeanRegistrationCodeFragmentsDecorator implements BeanRegistrationC
@Override @Override
public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext,
BeanRegistrationCode beanRegistrationCode, Executable constructorOrFactoryMethod, BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) {
boolean allowDirectSupplierShortcut) {
return this.delegate.generateInstanceSupplierCode(generationContext, return this.delegate.generateInstanceSupplierCode(generationContext,
beanRegistrationCode, constructorOrFactoryMethod, allowDirectSupplierShortcut); beanRegistrationCode, allowDirectSupplierShortcut);
} }
@Override @Override

12
spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanRegistrationCodeGenerator.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2022 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,7 +16,6 @@
package org.springframework.beans.factory.aot; package org.springframework.beans.factory.aot;
import java.lang.reflect.Executable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.function.Predicate; import java.util.function.Predicate;
@ -47,19 +46,15 @@ class BeanRegistrationCodeGenerator implements BeanRegistrationCode {
private final RegisteredBean registeredBean; private final RegisteredBean registeredBean;
private final Executable constructorOrFactoryMethod;
private final BeanRegistrationCodeFragments codeFragments; private final BeanRegistrationCodeFragments codeFragments;
BeanRegistrationCodeGenerator(ClassName className, GeneratedMethods generatedMethods, BeanRegistrationCodeGenerator(ClassName className, GeneratedMethods generatedMethods,
RegisteredBean registeredBean, Executable constructorOrFactoryMethod, RegisteredBean registeredBean, BeanRegistrationCodeFragments codeFragments) {
BeanRegistrationCodeFragments codeFragments) {
this.className = className; this.className = className;
this.generatedMethods = generatedMethods; this.generatedMethods = generatedMethods;
this.registeredBean = registeredBean; this.registeredBean = registeredBean;
this.constructorOrFactoryMethod = constructorOrFactoryMethod;
this.codeFragments = codeFragments; this.codeFragments = codeFragments;
} }
@ -87,8 +82,7 @@ class BeanRegistrationCodeGenerator implements BeanRegistrationCode {
generationContext, this, this.registeredBean.getMergedBeanDefinition(), generationContext, this, this.registeredBean.getMergedBeanDefinition(),
REJECT_ALL_ATTRIBUTES_FILTER)); REJECT_ALL_ATTRIBUTES_FILTER));
CodeBlock instanceSupplierCode = this.codeFragments.generateInstanceSupplierCode( CodeBlock instanceSupplierCode = this.codeFragments.generateInstanceSupplierCode(
generationContext, this, this.constructorOrFactoryMethod, generationContext, this, this.instancePostProcessors.isEmpty());
this.instancePostProcessors.isEmpty());
code.add(this.codeFragments.generateSetBeanInstanceSupplierCode(generationContext, code.add(this.codeFragments.generateSetBeanInstanceSupplierCode(generationContext,
this, instanceSupplierCode, this.instancePostProcessors)); this, instanceSupplierCode, this.instancePostProcessors));
code.add(this.codeFragments.generateReturnCode(generationContext, this)); code.add(this.codeFragments.generateReturnCode(generationContext, this));

17
spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java

@ -21,6 +21,7 @@ import java.lang.reflect.Executable;
import java.lang.reflect.Modifier; import java.lang.reflect.Modifier;
import java.util.List; import java.util.List;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.function.Supplier;
import org.springframework.aot.generate.AccessControl; import org.springframework.aot.generate.AccessControl;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
@ -39,12 +40,14 @@ import org.springframework.javapoet.ParameterizedTypeName;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.util.function.SingletonSupplier;
/** /**
* Internal {@link BeanRegistrationCodeFragments} implementation used by * Internal {@link BeanRegistrationCodeFragments} implementation used by
* default. * default.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Stephane Nicoll
*/ */
class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragments { class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragments {
@ -54,6 +57,8 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme
private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory; private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory;
private final Supplier<Executable> constructorOrFactoryMethod;
DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode, DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode,
RegisteredBean registeredBean, RegisteredBean registeredBean,
@ -62,14 +67,13 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme
this.beanRegistrationsCode = beanRegistrationsCode; this.beanRegistrationsCode = beanRegistrationsCode;
this.registeredBean = registeredBean; this.registeredBean = registeredBean;
this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory; this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory;
this.constructorOrFactoryMethod = SingletonSupplier.of(registeredBean::resolveConstructorOrFactoryMethod);
} }
@Override @Override
public ClassName getTarget(RegisteredBean registeredBean, public ClassName getTarget(RegisteredBean registeredBean) {
Executable constructorOrFactoryMethod) { Class<?> target = extractDeclaringClass(registeredBean.getBeanType(), this.constructorOrFactoryMethod.get());
Class<?> target = extractDeclaringClass(registeredBean.getBeanType(), constructorOrFactoryMethod);
while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) { while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) {
RegisteredBean parent = registeredBean.getParent(); RegisteredBean parent = registeredBean.getParent();
Assert.state(parent != null, "No parent available for inner bean"); Assert.state(parent != null, "No parent available for inner bean");
@ -219,12 +223,11 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme
@Override @Override
public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext,
BeanRegistrationCode beanRegistrationCode, BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) {
Executable constructorOrFactoryMethod, boolean allowDirectSupplierShortcut) {
return new InstanceSupplierCodeGenerator(generationContext, return new InstanceSupplierCodeGenerator(generationContext,
beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut)
.generateCode(this.registeredBean, constructorOrFactoryMethod); .generateCode(this.registeredBean,this.constructorOrFactoryMethod.get());
} }
@Override @Override

83
spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java

@ -21,6 +21,7 @@ import java.lang.reflect.Executable;
import java.lang.reflect.Member; import java.lang.reflect.Member;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.lang.reflect.Modifier; import java.lang.reflect.Modifier;
import java.lang.reflect.Proxy;
import java.util.Arrays; import java.util.Arrays;
import java.util.function.Consumer; import java.util.function.Consumer;
@ -38,9 +39,14 @@ import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.hint.ExecutableMode; import org.springframework.aot.hint.ExecutableMode;
import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.ReflectionHints; import org.springframework.aot.hint.ReflectionHints;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.config.DependencyDescriptor;
import org.springframework.beans.factory.support.AutowireCandidateResolver;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.InstanceSupplier;
import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.core.KotlinDetector; import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
@ -51,9 +57,10 @@ import org.springframework.util.ClassUtils;
import org.springframework.util.function.ThrowingSupplier; import org.springframework.util.function.ThrowingSupplier;
/** /**
* Internal code generator to create an {@link InstanceSupplier}, usually in * Default code generator to create an {@link InstanceSupplier}, usually in
* the form of a {@link BeanInstanceSupplier} that retains the executable * the form of a {@link BeanInstanceSupplier} that retains the executable
* that is used to instantiate the bean. * that is used to instantiate the bean. Takes care of registering the
* necessary hints if reflection or a JDK proxy is required.
* *
* <p>Generated code is usually a method reference that generates the * <p>Generated code is usually a method reference that generates the
* {@link BeanInstanceSupplier}, but some shortcut can be used as well such as: * {@link BeanInstanceSupplier}, but some shortcut can be used as well such as:
@ -66,8 +73,9 @@ import org.springframework.util.function.ThrowingSupplier;
* @author Juergen Hoeller * @author Juergen Hoeller
* @author Sebastien Deleuze * @author Sebastien Deleuze
* @since 6.0 * @since 6.0
* @see BeanRegistrationCodeFragments
*/ */
class InstanceSupplierCodeGenerator { public class InstanceSupplierCodeGenerator {
private static final String REGISTERED_BEAN_PARAMETER_NAME = "registeredBean"; private static final String REGISTERED_BEAN_PARAMETER_NAME = "registeredBean";
@ -89,7 +97,15 @@ class InstanceSupplierCodeGenerator {
private final boolean allowDirectSupplierShortcut; private final boolean allowDirectSupplierShortcut;
InstanceSupplierCodeGenerator(GenerationContext generationContext, /**
* Create a new instance.
* @param generationContext the generation context
* @param className the class name of the bean to instantiate
* @param generatedMethods the generated methods
* @param allowDirectSupplierShortcut whether a direct supplier may be used rather
* than always needing an {@link InstanceSupplier}
*/
public InstanceSupplierCodeGenerator(GenerationContext generationContext,
ClassName className, GeneratedMethods generatedMethods, boolean allowDirectSupplierShortcut) { ClassName className, GeneratedMethods generatedMethods, boolean allowDirectSupplierShortcut) {
this.generationContext = generationContext; this.generationContext = generationContext;
@ -98,8 +114,14 @@ class InstanceSupplierCodeGenerator {
this.allowDirectSupplierShortcut = allowDirectSupplierShortcut; this.allowDirectSupplierShortcut = allowDirectSupplierShortcut;
} }
/**
CodeBlock generateCode(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) { * Generate the instance supplier code.
* @param registeredBean the bean to handle
* @param constructorOrFactoryMethod the executable to use to create the bean
* @return the generated code
*/
public CodeBlock generateCode(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) {
registerRuntimeHintsIfNecessary(registeredBean, constructorOrFactoryMethod);
if (constructorOrFactoryMethod instanceof Constructor<?> constructor) { if (constructorOrFactoryMethod instanceof Constructor<?> constructor) {
return generateCodeForConstructor(registeredBean, constructor); return generateCodeForConstructor(registeredBean, constructor);
} }
@ -110,6 +132,19 @@ class InstanceSupplierCodeGenerator {
"No suitable executor found for " + registeredBean.getBeanName()); "No suitable executor found for " + registeredBean.getBeanName());
} }
private void registerRuntimeHintsIfNecessary(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) {
if (registeredBean.getBeanFactory() instanceof DefaultListableBeanFactory dlbf) {
RuntimeHints runtimeHints = this.generationContext.getRuntimeHints();
ProxyRuntimeHintsRegistrar registrar = new ProxyRuntimeHintsRegistrar(dlbf.getAutowireCandidateResolver());
if (constructorOrFactoryMethod instanceof Method method) {
registrar.registerRuntimeHints(runtimeHints, method);
}
else if (constructorOrFactoryMethod instanceof Constructor<?> constructor) {
registrar.registerRuntimeHints(runtimeHints, constructor);
}
}
}
private CodeBlock generateCodeForConstructor(RegisteredBean registeredBean, Constructor<?> constructor) { private CodeBlock generateCodeForConstructor(RegisteredBean registeredBean, Constructor<?> constructor) {
String beanName = registeredBean.getBeanName(); String beanName = registeredBean.getBeanName();
Class<?> beanClass = registeredBean.getBeanClass(); Class<?> beanClass = registeredBean.getBeanClass();
@ -372,4 +407,40 @@ class InstanceSupplierCodeGenerator {
} }
private static class ProxyRuntimeHintsRegistrar {
private final AutowireCandidateResolver candidateResolver;
public ProxyRuntimeHintsRegistrar(AutowireCandidateResolver candidateResolver) {
this.candidateResolver = candidateResolver;
}
public void registerRuntimeHints(RuntimeHints runtimeHints, Method method) {
Class<?>[] parameterTypes = method.getParameterTypes();
for (int i = 0; i < parameterTypes.length; i++) {
MethodParameter methodParam = new MethodParameter(method, i);
DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(methodParam, true);
registerProxyIfNecessary(runtimeHints, dependencyDescriptor);
}
}
public void registerRuntimeHints(RuntimeHints runtimeHints, Constructor<?> constructor) {
Class<?>[] parameterTypes = constructor.getParameterTypes();
for (int i = 0; i < parameterTypes.length; i++) {
MethodParameter methodParam = new MethodParameter(constructor, i);
DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(
methodParam, true);
registerProxyIfNecessary(runtimeHints, dependencyDescriptor);
}
}
private void registerProxyIfNecessary(RuntimeHints runtimeHints, DependencyDescriptor dependencyDescriptor) {
Class<?> proxyType = this.candidateResolver.getLazyResolutionProxyClass(dependencyDescriptor, null);
if (proxyType != null && Proxy.isProxyClass(proxyType)) {
runtimeHints.proxies().registerJdkProxy(proxyType.getInterfaces());
}
}
}
} }

166
spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java

@ -16,10 +16,14 @@
package org.springframework.beans.factory.aot; package org.springframework.beans.factory.aot;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.function.UnaryOperator;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.annotation.InjectAnnotationBeanPostProcessorTests.StringFactoryBean; import org.springframework.beans.factory.annotation.InjectAnnotationBeanPostProcessorTests.StringFactoryBean;
@ -28,6 +32,7 @@ import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.factory.DummyFactory; import org.springframework.beans.testfixture.beans.factory.DummyFactory;
import org.springframework.beans.testfixture.beans.factory.aot.GenericFactoryBean; import org.springframework.beans.testfixture.beans.factory.aot.GenericFactoryBean;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationCode;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode;
import org.springframework.beans.testfixture.beans.factory.aot.NumberFactoryBean; import org.springframework.beans.testfixture.beans.factory.aot.NumberFactoryBean;
import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean; import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean;
@ -35,9 +40,14 @@ import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanConfigu
import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanFactoryBean; import org.springframework.beans.testfixture.beans.factory.aot.SimpleBeanFactoryBean;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.lang.Nullable;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
/** /**
* Tests for {@link DefaultBeanRegistrationCodeFragments}. * Tests for {@link DefaultBeanRegistrationCodeFragments}.
@ -48,136 +58,202 @@ class DefaultBeanRegistrationCodeFragmentsTests {
private final BeanRegistrationsCode beanRegistrationsCode = new MockBeanRegistrationsCode(new TestGenerationContext()); private final BeanRegistrationsCode beanRegistrationsCode = new MockBeanRegistrationsCode(new TestGenerationContext());
private final GenerationContext generationContext = new TestGenerationContext();
private final DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); private final DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
@Test @Test
void getTargetOnConstructor() { void getTargetOnConstructor() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class); RegisteredBean registeredBean = registerTestBean(SimpleBean.class,
assertTarget(createInstance(registeredBean).getTarget(registeredBean, SimpleBean.class.getDeclaredConstructors()[0]);
SimpleBean.class.getDeclaredConstructors()[0]), SimpleBean.class); assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class);
} }
@Test @Test
void getTargetOnConstructorToPublicFactoryBean() { void getTargetOnConstructorToPublicFactoryBean() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class); RegisteredBean registeredBean = registerTestBean(SimpleBean.class,
assertTarget(createInstance(registeredBean).getTarget(registeredBean, SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]);
SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class);
} }
@Test @Test
void getTargetOnConstructorToPublicGenericFactoryBeanExtractTargetFromFactoryBeanType() { void getTargetOnConstructorToPublicGenericFactoryBeanExtractTargetFromFactoryBeanType() {
RegisteredBean registeredBean = registerTestBean(ResolvableType ResolvableType beanType = ResolvableType.forClassWithGenerics(
.forClassWithGenerics(GenericFactoryBean.class, SimpleBean.class)); GenericFactoryBean.class, SimpleBean.class);
assertTarget(createInstance(registeredBean).getTarget(registeredBean, RegisteredBean registeredBean = registerTestBean(beanType,
GenericFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); GenericFactoryBean.class.getDeclaredConstructors()[0]);
assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class);
} }
@Test @Test
void getTargetOnConstructorToPublicGenericFactoryBeanWithBoundExtractTargetFromFactoryBeanType() { void getTargetOnConstructorToPublicGenericFactoryBeanWithBoundExtractTargetFromFactoryBeanType() {
RegisteredBean registeredBean = registerTestBean(ResolvableType ResolvableType beanType = ResolvableType.forClassWithGenerics(
.forClassWithGenerics(NumberFactoryBean.class, Integer.class)); NumberFactoryBean.class, Integer.class);
assertTarget(createInstance(registeredBean).getTarget(registeredBean, RegisteredBean registeredBean = registerTestBean(beanType,
NumberFactoryBean.class.getDeclaredConstructors()[0]), Integer.class); NumberFactoryBean.class.getDeclaredConstructors()[0]);
assertTarget(createInstance(registeredBean).getTarget(registeredBean), Integer.class);
} }
@Test @Test
void getTargetOnConstructorToPublicGenericFactoryBeanUseBeanTypeAsFallback() { void getTargetOnConstructorToPublicGenericFactoryBeanUseBeanTypeAsFallback() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class); RegisteredBean registeredBean = registerTestBean(SimpleBean.class,
assertTarget(createInstance(registeredBean).getTarget(registeredBean, GenericFactoryBean.class.getDeclaredConstructors()[0]);
GenericFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); assertTarget(createInstance(registeredBean).getTarget(registeredBean), SimpleBean.class);
} }
@Test @Test
void getTargetOnConstructorToProtectedFactoryBean() { void getTargetOnConstructorToProtectedFactoryBean() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class); RegisteredBean registeredBean = registerTestBean(SimpleBean.class,
assertTarget(createInstance(registeredBean).getTarget(registeredBean, PrivilegedTestBeanFactoryBean.class.getDeclaredConstructors()[0]);
PrivilegedTestBeanFactoryBean.class.getDeclaredConstructors()[0]), assertTarget(createInstance(registeredBean).getTarget(registeredBean),
PrivilegedTestBeanFactoryBean.class); PrivilegedTestBeanFactoryBean.class);
} }
@Test @Test
void getTargetOnMethod() { void getTargetOnMethod() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class);
Method method = ReflectionUtils.findMethod(SimpleBeanConfiguration.class, "simpleBean"); Method method = ReflectionUtils.findMethod(SimpleBeanConfiguration.class, "simpleBean");
assertThat(method).isNotNull(); assertThat(method).isNotNull();
assertTarget(createInstance(registeredBean).getTarget(registeredBean, method), RegisteredBean registeredBean = registerTestBean(SimpleBean.class, method);
assertTarget(createInstance(registeredBean).getTarget(registeredBean),
SimpleBeanConfiguration.class); SimpleBeanConfiguration.class);
} }
@Test @Test
void getTargetOnMethodWithInnerBeanInJavaPackage() { void getTargetOnMethodWithInnerBeanInJavaPackage() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class); RegisteredBean registeredBean = registerTestBean(SimpleBean.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(String.class));
Method method = ReflectionUtils.findMethod(getClass(), "createString"); Method method = ReflectionUtils.findMethod(getClass(), "createString");
assertThat(method).isNotNull(); assertThat(method).isNotNull();
assertTarget(createInstance(innerBean).getTarget(innerBean, method), getClass()); RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
applyConstructorOrFactoryMethod(new RootBeanDefinition(String.class), method));
assertTarget(createInstance(innerBean).getTarget(innerBean), getClass());
} }
@Test @Test
void getTargetOnConstructorWithInnerBeanInJavaPackage() { void getTargetOnConstructorWithInnerBeanInJavaPackage() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class); RegisteredBean registeredBean = registerTestBean(SimpleBean.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class)); RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod(
assertTarget(createInstance(innerBean).getTarget(innerBean, new RootBeanDefinition(String.class), String.class.getDeclaredConstructors()[0]);
String.class.getDeclaredConstructors()[0]), SimpleBean.class); RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
innerBeanDefinition);
assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class);
} }
@Test @Test
void getTargetOnConstructorWithInnerBeanOnTypeInJavaPackage() { void getTargetOnConstructorWithInnerBeanOnTypeInJavaPackage() {
RegisteredBean registeredBean = registerTestBean(SimpleBean.class); RegisteredBean registeredBean = registerTestBean(SimpleBean.class);
RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod(
new RootBeanDefinition(StringFactoryBean.class),
StringFactoryBean.class.getDeclaredConstructors()[0]);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(StringFactoryBean.class)); innerBeanDefinition);
assertTarget(createInstance(innerBean).getTarget(innerBean, assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class);
StringFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class);
} }
@Test @Test
void getTargetOnMethodWithInnerBeanInRegularPackage() { void getTargetOnMethodWithInnerBeanInRegularPackage() {
RegisteredBean registeredBean = registerTestBean(DummyFactory.class); RegisteredBean registeredBean = registerTestBean(DummyFactory.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(SimpleBean.class));
Method method = ReflectionUtils.findMethod(SimpleBeanConfiguration.class, "simpleBean"); Method method = ReflectionUtils.findMethod(SimpleBeanConfiguration.class, "simpleBean");
assertThat(method).isNotNull(); assertThat(method).isNotNull();
assertTarget(createInstance(innerBean).getTarget(innerBean, method), RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
applyConstructorOrFactoryMethod(new RootBeanDefinition(SimpleBean.class), method));
assertTarget(createInstance(innerBean).getTarget(innerBean),
SimpleBeanConfiguration.class); SimpleBeanConfiguration.class);
} }
@Test @Test
void getTargetOnConstructorWithInnerBeanInRegularPackage() { void getTargetOnConstructorWithInnerBeanInRegularPackage() {
RegisteredBean registeredBean = registerTestBean(DummyFactory.class); RegisteredBean registeredBean = registerTestBean(DummyFactory.class);
RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod(
new RootBeanDefinition(SimpleBean.class), SimpleBean.class.getDeclaredConstructors()[0]);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(SimpleBean.class)); innerBeanDefinition);
assertTarget(createInstance(innerBean).getTarget(innerBean, assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class);
SimpleBean.class.getDeclaredConstructors()[0]), SimpleBean.class);
} }
@Test @Test
void getTargetOnConstructorWithInnerBeanOnFactoryBeanOnTypeInRegularPackage() { void getTargetOnConstructorWithInnerBeanOnFactoryBeanOnTypeInRegularPackage() {
RegisteredBean registeredBean = registerTestBean(DummyFactory.class); RegisteredBean registeredBean = registerTestBean(DummyFactory.class);
RootBeanDefinition innerBeanDefinition = applyConstructorOrFactoryMethod(
new RootBeanDefinition(SimpleBean.class),
SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(SimpleBean.class)); innerBeanDefinition);
assertTarget(createInstance(innerBean).getTarget(innerBean, assertTarget(createInstance(innerBean).getTarget(innerBean), SimpleBean.class);
SimpleBeanFactoryBean.class.getDeclaredConstructors()[0]), SimpleBean.class); }
@Test
void customizedGetTargetDoesNotResolveConstructorOrFactoryMethod() {
RegisteredBean registeredBean = spy(registerTestBean(SimpleBean.class));
BeanRegistrationCodeFragments customCodeFragments = createCustomCodeFragments(registeredBean, codeFragments -> new BeanRegistrationCodeFragmentsDecorator(codeFragments) {
@Override
public ClassName getTarget(RegisteredBean registeredBean) {
return ClassName.get(String.class);
}
});
assertTarget(customCodeFragments.getTarget(registeredBean), String.class);
verify(registeredBean, never()).resolveConstructorOrFactoryMethod();
}
@Test
void customizedGenerateInstanceSupplierCodeDoesNotResolveConstructorOrFactoryMethod() {
RegisteredBean registeredBean = spy(registerTestBean(SimpleBean.class));
BeanRegistrationCodeFragments customCodeFragments = createCustomCodeFragments(registeredBean, codeFragments -> new BeanRegistrationCodeFragmentsDecorator(codeFragments) {
@Override
public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext,
BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) {
return CodeBlock.of("// Hello");
}
});
assertThat(customCodeFragments.generateInstanceSupplierCode(this.generationContext,
new MockBeanRegistrationCode(this.generationContext), false)).hasToString("// Hello");
verify(registeredBean, never()).resolveConstructorOrFactoryMethod();
}
private BeanRegistrationCodeFragments createCustomCodeFragments(RegisteredBean registeredBean, UnaryOperator<BeanRegistrationCodeFragments> customFragments) {
BeanRegistrationAotContribution aotContribution = BeanRegistrationAotContribution.
withCustomCodeFragments(customFragments);
BeanRegistrationCodeFragments defaultCodeFragments = createInstance(registeredBean);
return aotContribution.customizeBeanRegistrationCodeFragments(
this.generationContext, defaultCodeFragments);
} }
private void assertTarget(ClassName target, Class<?> expected) { private void assertTarget(ClassName target, Class<?> expected) {
assertThat(target).isEqualTo(ClassName.get(expected)); assertThat(target).isEqualTo(ClassName.get(expected));
} }
private RegisteredBean registerTestBean(Class<?> beanType) { private RegisteredBean registerTestBean(Class<?> beanType) {
this.beanFactory.registerBeanDefinition("testBean", return registerTestBean(beanType, null);
new RootBeanDefinition(beanType)); }
private RegisteredBean registerTestBean(Class<?> beanType,
@Nullable Executable constructorOrFactoryMethod) {
this.beanFactory.registerBeanDefinition("testBean", applyConstructorOrFactoryMethod(
new RootBeanDefinition(beanType), constructorOrFactoryMethod));
return RegisteredBean.of(this.beanFactory, "testBean"); return RegisteredBean.of(this.beanFactory, "testBean");
} }
private RegisteredBean registerTestBean(ResolvableType beanType) {
private RegisteredBean registerTestBean(ResolvableType beanType,
@Nullable Executable constructorOrFactoryMethod) {
RootBeanDefinition beanDefinition = new RootBeanDefinition(); RootBeanDefinition beanDefinition = new RootBeanDefinition();
beanDefinition.setTargetType(beanType); beanDefinition.setTargetType(beanType);
this.beanFactory.registerBeanDefinition("testBean", beanDefinition); this.beanFactory.registerBeanDefinition("testBean",
applyConstructorOrFactoryMethod(beanDefinition, constructorOrFactoryMethod));
return RegisteredBean.of(this.beanFactory, "testBean"); return RegisteredBean.of(this.beanFactory, "testBean");
} }
private RootBeanDefinition applyConstructorOrFactoryMethod(RootBeanDefinition beanDefinition,
@Nullable Executable constructorOrFactoryMethod) {
if (constructorOrFactoryMethod instanceof Method method) {
beanDefinition.setResolvedFactoryMethod(method);
}
else if (constructorOrFactoryMethod instanceof Constructor<?> constructor) {
beanDefinition.setAttribute(RootBeanDefinition.PREFERRED_CONSTRUCTORS_ATTRIBUTE, constructor);
}
return beanDefinition;
}
private BeanRegistrationCodeFragments createInstance(RegisteredBean registeredBean) { private BeanRegistrationCodeFragments createInstance(RegisteredBean registeredBean) {
return new DefaultBeanRegistrationCodeFragments(this.beanRegistrationsCode, registeredBean, return new DefaultBeanRegistrationCodeFragments(this.beanRegistrationsCode, registeredBean,
new BeanDefinitionMethodGeneratorFactory(this.beanFactory)); new BeanDefinitionMethodGeneratorFactory(this.beanFactory));

22
spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java

@ -58,6 +58,7 @@ import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
import org.springframework.beans.factory.aot.BeanRegistrationCode; import org.springframework.beans.factory.aot.BeanRegistrationCode;
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments; import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments;
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator; import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator;
import org.springframework.beans.factory.aot.InstanceSupplierCodeGenerator;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanDefinitionHolder; import org.springframework.beans.factory.config.BeanDefinitionHolder;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
@ -315,9 +316,8 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo
Object configClassAttr = registeredBean.getMergedBeanDefinition() Object configClassAttr = registeredBean.getMergedBeanDefinition()
.getAttribute(ConfigurationClassUtils.CONFIGURATION_CLASS_ATTRIBUTE); .getAttribute(ConfigurationClassUtils.CONFIGURATION_CLASS_ATTRIBUTE);
if (ConfigurationClassUtils.CONFIGURATION_CLASS_FULL.equals(configClassAttr)) { if (ConfigurationClassUtils.CONFIGURATION_CLASS_FULL.equals(configClassAttr)) {
Class<?> proxyClass = registeredBean.getBeanType().toClass();
return BeanRegistrationAotContribution.withCustomCodeFragments(codeFragments -> return BeanRegistrationAotContribution.withCustomCodeFragments(codeFragments ->
new ConfigurationClassProxyBeanRegistrationCodeFragments(codeFragments, proxyClass)); new ConfigurationClassProxyBeanRegistrationCodeFragments(codeFragments, registeredBean));
} }
return null; return null;
} }
@ -749,12 +749,15 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo
private static class ConfigurationClassProxyBeanRegistrationCodeFragments extends BeanRegistrationCodeFragmentsDecorator { private static class ConfigurationClassProxyBeanRegistrationCodeFragments extends BeanRegistrationCodeFragmentsDecorator {
private final RegisteredBean registeredBean;
private final Class<?> proxyClass; private final Class<?> proxyClass;
public ConfigurationClassProxyBeanRegistrationCodeFragments(BeanRegistrationCodeFragments codeFragments, public ConfigurationClassProxyBeanRegistrationCodeFragments(BeanRegistrationCodeFragments codeFragments,
Class<?> proxyClass) { RegisteredBean registeredBean) {
super(codeFragments); super(codeFragments);
this.proxyClass = proxyClass; this.registeredBean = registeredBean;
this.proxyClass = registeredBean.getBeanType().toClass();
} }
@Override @Override
@ -770,11 +773,14 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo
@Override @Override
public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext,
BeanRegistrationCode beanRegistrationCode, Executable constructorOrFactoryMethod, BeanRegistrationCode beanRegistrationCode,
boolean allowDirectSupplierShortcut) { boolean allowDirectSupplierShortcut) {
Executable executableToUse = proxyExecutable(generationContext.getRuntimeHints(), constructorOrFactoryMethod);
return super.generateInstanceSupplierCode(generationContext, beanRegistrationCode, Executable executableToUse = proxyExecutable(generationContext.getRuntimeHints(),
executableToUse, allowDirectSupplierShortcut); this.registeredBean.resolveConstructorOrFactoryMethod());
return new InstanceSupplierCodeGenerator(generationContext,
beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut)
.generateCode(this.registeredBean, executableToUse);
} }
private Executable proxyExecutable(RuntimeHints runtimeHints, Executable userExecutable) { private Executable proxyExecutable(RuntimeHints runtimeHints, Executable userExecutable) {

2
spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java

@ -17,7 +17,6 @@
package org.springframework.orm.jpa.persistenceunit; package org.springframework.orm.jpa.persistenceunit;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.lang.reflect.Executable;
import java.util.List; import java.util.List;
import javax.lang.model.element.Modifier; import javax.lang.model.element.Modifier;
@ -97,7 +96,6 @@ class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistr
@Override @Override
public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext,
BeanRegistrationCode beanRegistrationCode, BeanRegistrationCode beanRegistrationCode,
Executable constructorOrFactoryMethod,
boolean allowDirectSupplierShortcut) { boolean allowDirectSupplierShortcut) {
PersistenceManagedTypes persistenceManagedTypes = this.registeredBean.getBeanFactory() PersistenceManagedTypes persistenceManagedTypes = this.registeredBean.getBeanFactory()
.getBean(this.registeredBean.getBeanName(), PersistenceManagedTypes.class); .getBean(this.registeredBean.getBeanName(), PersistenceManagedTypes.class);

Loading…
Cancel
Save