diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java index 02d1d991638..4b1bc3a1aa9 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGenerator.java @@ -16,16 +16,28 @@ package org.springframework.beans.factory.aot; +import java.beans.BeanInfo; +import java.beans.IntrospectionException; +import java.beans.Introspector; +import java.beans.PropertyDescriptor; import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.function.BiFunction; import java.util.function.BiPredicate; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.hint.ExecutableHint; +import org.springframework.aot.hint.ExecutableMode; import org.springframework.aot.hint.RuntimeHints; +import org.springframework.beans.BeanInfoFactory; +import org.springframework.beans.ExtendedBeanInfoFactory; import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.PropertyValue; import org.springframework.beans.factory.config.BeanDefinition; @@ -36,6 +48,7 @@ import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.lang.Nullable; import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; import org.springframework.util.ReflectionUtils; @@ -69,6 +82,9 @@ class BeanDefinitionPropertiesCodeGenerator { private static final String BEAN_DEFINITION_VARIABLE = BeanRegistrationCodeFragments.BEAN_DEFINITION_VARIABLE; + private static final Consumer INVOKE_HINT = hint -> hint.withMode(ExecutableMode.INVOKE); + + private static final BeanInfoFactory beanInfoFactory = new ExtendedBeanInfoFactory(); private final RuntimeHints hints; @@ -125,14 +141,14 @@ class BeanDefinitionPropertiesCodeGenerator { AbstractBeanDefinition beanDefinition, String[] methodNames, String format) { if (!ObjectUtils.isEmpty(methodNames)) { - Class beanUserClass = ClassUtils + Class beanType = ClassUtils .getUserClass(beanDefinition.getResolvableType().toClass()); Builder arguments = CodeBlock.builder(); for (int i = 0; i < methodNames.length; i++) { String methodName = methodNames[i]; if (!AbstractBeanDefinition.INFER_METHOD.equals(methodName)) { arguments.add((i != 0) ? ", $S" : "$S", methodName); - addInitDestroyHint(beanUserClass, methodName); + addInitDestroyHint(beanType, methodName); } } builder.addStatement(format, BEAN_DEFINITION_VARIABLE, arguments.build()); @@ -181,9 +197,53 @@ class BeanDefinitionPropertiesCodeGenerator { builder.addStatement("$L.getPropertyValues().addPropertyValue($S, $L)", BEAN_DEFINITION_VARIABLE, propertyValue.getName(), code); } + Class beanType = ClassUtils + .getUserClass(beanDefinition.getResolvableType().toClass()); + BeanInfo beanInfo = (beanType != Object.class) ? getBeanInfo(beanType) : null; + if (beanInfo != null) { + Map writeMethods = getWriteMethods(beanInfo); + for (PropertyValue propertyValue : propertyValues) { + Method writeMethod = writeMethods.get(propertyValue.getName()); + if (writeMethod != null) { + this.hints.reflection().registerMethod(writeMethod, INVOKE_HINT); + } + } + } } } + @Nullable + private BeanInfo getBeanInfo(Class beanType) { + try { + BeanInfo beanInfo = beanInfoFactory.getBeanInfo(beanType); + if (beanInfo != null) { + return beanInfo; + } + return Introspector.getBeanInfo(beanType, Introspector.IGNORE_ALL_BEANINFO); + } + catch (IntrospectionException ex) { + return null; + } + } + + private Map getWriteMethods(BeanInfo beanInfo) { + Map writeMethods = new HashMap<>(); + for (PropertyDescriptor propertyDescriptor : beanInfo.getPropertyDescriptors()) { + writeMethods.put(propertyDescriptor.getName(), + propertyDescriptor.getWriteMethod()); + } + return Collections.unmodifiableMap(writeMethods); + } + + @Nullable + private Method findWriteMethod(BeanInfo beanInfo, String propertyName) { + return Arrays.stream(beanInfo.getPropertyDescriptors()) + .filter(pd -> propertyName.equals(pd.getName())) + .map(java.beans.PropertyDescriptor::getWriteMethod) + .filter(Objects::nonNull).findFirst().orElse(null); + } + + private void addAttributes(CodeBlock.Builder builder, BeanDefinition beanDefinition) { String[] attributeNames = beanDefinition.attributeNames(); if (!ObjectUtils.isEmpty(attributeNames)) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java index cba19031d51..f234301fa7e 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionPropertiesCodeGeneratorTests.java @@ -225,7 +225,8 @@ class BeanDefinitionPropertiesCodeGeneratorTests { this.beanDefinition.setInitMethodName("i1"); testCompiledResult((actual, compiled) -> assertThat(actual.getInitMethodNames()) .containsExactly("i1")); - assertHasMethodInvokeHints("i1"); + String[] methodNames = { "i1" }; + assertHasMethodInvokeHints(InitDestroyBean.class, methodNames); } @Test @@ -234,7 +235,8 @@ class BeanDefinitionPropertiesCodeGeneratorTests { this.beanDefinition.setInitMethodNames("i1", "i2"); testCompiledResult((actual, compiled) -> assertThat(actual.getInitMethodNames()) .containsExactly("i1", "i2")); - assertHasMethodInvokeHints("i1", "i2"); + String[] methodNames = { "i1", "i2" }; + assertHasMethodInvokeHints(InitDestroyBean.class, methodNames); } @Test @@ -244,7 +246,8 @@ class BeanDefinitionPropertiesCodeGeneratorTests { testCompiledResult( (actual, compiled) -> assertThat(actual.getDestroyMethodNames()) .containsExactly("d1")); - assertHasMethodInvokeHints("d1"); + String[] methodNames = { "d1" }; + assertHasMethodInvokeHints(InitDestroyBean.class, methodNames); } @Test @@ -254,20 +257,20 @@ class BeanDefinitionPropertiesCodeGeneratorTests { testCompiledResult( (actual, compiled) -> assertThat(actual.getDestroyMethodNames()) .containsExactly("d1", "d2")); - assertHasMethodInvokeHints("d1", "d2"); - } - - private void assertHasMethodInvokeHints(String... methodNames) { - assertThat(hints.reflection().getTypeHint(InitDestroyBean.class)) - .satisfies(typeHint -> { - for (String methodName : methodNames) { - assertThat(typeHint.methods()).anySatisfy(methodHint -> { - assertThat(methodHint.getName()).isEqualTo(methodName); - assertThat(methodHint.getModes()) - .containsExactly(ExecutableMode.INVOKE); - }); - } + String[] methodNames = { "d1", "d2" }; + assertHasMethodInvokeHints(InitDestroyBean.class, methodNames); + } + + private void assertHasMethodInvokeHints(Class beanType, String... methodNames) { + assertThat(this.hints.reflection().getTypeHint(beanType)).satisfies(typeHint -> { + for (String methodName : methodNames) { + assertThat(typeHint.methods()).anySatisfy(methodHint -> { + assertThat(methodHint.getName()).isEqualTo(methodName); + assertThat(methodHint.getModes()) + .containsExactly(ExecutableMode.INVOKE); }); + } + }); } @Test @@ -289,12 +292,15 @@ class BeanDefinitionPropertiesCodeGeneratorTests { @Test void propertyValuesWhenValues() { + this.beanDefinition.setTargetType(PropertyValuesBean.class); this.beanDefinition.getPropertyValues().add("test", String.class); this.beanDefinition.getPropertyValues().add("spring", "framework"); testCompiledResult((actual, compiled) -> { assertThat(actual.getPropertyValues().get("test")).isEqualTo(String.class); assertThat(actual.getPropertyValues().get("spring")).isEqualTo("framework"); }); + String[] methodNames = { "setTest", "setSpring" }; + assertHasMethodInvokeHints(PropertyValuesBean.class, methodNames); } @Test @@ -438,4 +444,28 @@ class BeanDefinitionPropertiesCodeGeneratorTests { } + static class PropertyValuesBean { + + private Class test; + + private String spring; + + public Class getTest() { + return this.test; + } + + public void setTest(Class test) { + this.test = test; + } + + public String getSpring() { + return this.spring; + } + + public void setSpring(String spring) { + this.spring = spring; + } + + } + }