Browse Source

Improve location of generated bean definitions of FactoryBeans

This commit improves the location of generated bean definitions for
FactoryBean implementations by checking the type that the factory
bean generates, rather than the factory bean implementation itself.

Closes gh-28812
pull/28843/head
Stephane Nicoll 4 years ago
parent
commit
85d4a79cdc
  1. 17
      spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java
  2. 157
      spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java

17
spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java

@ -16,12 +16,14 @@
package org.springframework.beans.factory.aot; package org.springframework.beans.factory.aot;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable; import java.lang.reflect.Executable;
import java.util.List; import java.util.List;
import java.util.function.Predicate; import java.util.function.Predicate;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generate.MethodReference;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanDefinitionHolder; import org.springframework.beans.factory.config.BeanDefinitionHolder;
import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.InstanceSupplier;
@ -69,14 +71,21 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments
public Class<?> getTarget(RegisteredBean registeredBean, public Class<?> getTarget(RegisteredBean registeredBean,
Executable constructorOrFactoryMethod) { Executable constructorOrFactoryMethod) {
Class<?> target = ClassUtils Class<?> target = extractDeclaringClass(constructorOrFactoryMethod);
.getUserClass(constructorOrFactoryMethod.getDeclaringClass());
while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) { while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) {
target = registeredBean.getParent().getBeanClass(); target = registeredBean.getParent().getBeanClass();
} }
return target; return target;
} }
private Class<?> extractDeclaringClass(Executable executable) {
Class<?> declaringClass = ClassUtils.getUserClass(executable.getDeclaringClass());
if (executable instanceof Constructor<?> && FactoryBean.class.isAssignableFrom(declaringClass)) {
return ResolvableType.forType(declaringClass).as(FactoryBean.class).getGeneric(0).toClass();
}
return executable.getDeclaringClass();
}
@Override @Override
public CodeBlock generateNewBeanDefinitionCode(GenerationContext generationContext, public CodeBlock generateNewBeanDefinitionCode(GenerationContext generationContext,
ResolvableType beanType, BeanRegistrationCode beanRegistrationCode) { ResolvableType beanType, BeanRegistrationCode beanRegistrationCode) {
@ -107,7 +116,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments
generationContext.getRuntimeHints(), attributeFilter, generationContext.getRuntimeHints(), attributeFilter,
beanRegistrationCode.getMethods(), beanRegistrationCode.getMethods(),
(name, value) -> generateValueCode(generationContext, name, value)) (name, value) -> generateValueCode(generationContext, name, value))
.generateCode(beanDefinition); .generateCode(beanDefinition);
} }
@Nullable @Nullable
@ -171,7 +180,7 @@ class DefaultBeanRegistrationCodeFragments extends BeanRegistrationCodeFragments
return new InstanceSupplierCodeGenerator(generationContext, return new InstanceSupplierCodeGenerator(generationContext,
beanRegistrationCode.getClassName(), beanRegistrationCode.getClassName(),
beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) beanRegistrationCode.getMethods(), allowDirectSupplierShortcut)
.generateCode(this.registeredBean, constructorOrFactoryMethod); .generateCode(this.registeredBean, constructorOrFactoryMethod);
} }
@Override @Override

157
spring-beans/src/test/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragmentsTests.java

@ -0,0 +1,157 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.beans.factory.aot;
import java.lang.reflect.Method;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.annotation.InjectAnnotationBeanPostProcessorTests.StringFactoryBean;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.factory.DummyFactory;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link DefaultBeanRegistrationCodeFragments}.
*
* @author Stephane Nicoll
*/
class DefaultBeanRegistrationCodeFragmentsTests {
private final BeanRegistrationsCode beanRegistrationsCode = new MockBeanRegistrationsCode(new TestGenerationContext());
private final DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
@Test
void getTargetOnConstructor() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
assertThat(createInstance(registeredBean).getTarget(registeredBean,
TestBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
@Test
void getTargetOnConstructorToFactoryBean() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
assertThat(createInstance(registeredBean).getTarget(registeredBean,
TestBeanFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
@Test
void getTargetOnMethod() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
Method method = ReflectionUtils.findMethod(TestBeanFactoryBean.class, "getObject");
assertThat(method).isNotNull();
assertThat(createInstance(registeredBean).getTarget(registeredBean,
method)).isEqualTo(TestBeanFactoryBean.class);
}
@Test
void getTargetOnMethodWithInnerBeanInJavaPackage() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class));
Method method = ReflectionUtils.findMethod(getClass(), "createString");
assertThat(method).isNotNull();
assertThat(createInstance(innerBean).getTarget(innerBean,
method)).isEqualTo(getClass());
}
@Test
void getTargetOnConstructorWithInnerBeanInJavaPackage() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(String.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
String.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
@Test
void getTargetOnConstructorWithInnerBeanOnTypeInJavaPackage() {
RegisteredBean registeredBean = registerTestBean(TestBean.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(StringFactoryBean.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
StringFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
@Test
void getTargetOnMethodWithInnerBeanInRegularPackage() {
RegisteredBean registeredBean = registerTestBean(DummyFactory.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(TestBean.class));
Method method = ReflectionUtils.findMethod(TestBeanFactoryBean.class, "getObject");
assertThat(method).isNotNull();
assertThat(createInstance(innerBean).getTarget(innerBean, method)).isEqualTo(TestBeanFactoryBean.class);
}
@Test
void getTargetOnConstructorWithInnerBeanInRegularPackage() {
RegisteredBean registeredBean = registerTestBean(DummyFactory.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean", new RootBeanDefinition(TestBean.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
TestBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
@Test
void getTargetOnConstructorWithInnerBeanOnFactoryBeanOnTypeInRegularPackage() {
RegisteredBean registeredBean = registerTestBean(DummyFactory.class);
RegisteredBean innerBean = RegisteredBean.ofInnerBean(registeredBean, "innerTestBean",
new RootBeanDefinition(TestBean.class));
assertThat(createInstance(innerBean).getTarget(innerBean,
TestBeanFactoryBean.class.getDeclaredConstructors()[0])).isEqualTo(TestBean.class);
}
private RegisteredBean registerTestBean(Class<?> beanType) {
this.beanFactory.registerBeanDefinition("testBean",
new RootBeanDefinition(beanType));
return RegisteredBean.of(this.beanFactory, "testBean");
}
private BeanRegistrationCodeFragments createInstance(RegisteredBean registeredBean) {
return new DefaultBeanRegistrationCodeFragments(this.beanRegistrationsCode, registeredBean, new BeanDefinitionMethodGeneratorFactory(this.beanFactory));
}
@SuppressWarnings("unused")
static String createString() {
return "Test";
}
@SuppressWarnings("unused")
static class TestBean {
}
static class TestBeanFactoryBean implements FactoryBean<TestBean> {
@Override
public TestBean getObject() throws Exception {
return new TestBean();
}
@Override
public Class<?> getObjectType() {
return TestBean.class;
}
}
}
Loading…
Cancel
Save