Browse Source

Retain original requested bean class for SpringContainedBean

Closes GH-36115

Signed-off-by: Yanming Zhou <zhouyanming@gmail.com>
pull/36125/head
Yanming Zhou 3 weeks ago committed by Juergen Hoeller
parent
commit
d3a385d222
  1. 45
      spring-orm/src/main/java/org/springframework/orm/jpa/hibernate/SpringBeanContainer.java
  2. 16
      spring-orm/src/test/java/org/springframework/orm/jpa/hibernate/HibernateNativeEntityManagerFactorySpringBeanContainerIntegrationTests.java

45
spring-orm/src/main/java/org/springframework/orm/jpa/hibernate/SpringBeanContainer.java

@ -69,6 +69,7 @@ import org.springframework.util.ConcurrentReferenceHashMap; @@ -69,6 +69,7 @@ import org.springframework.util.ConcurrentReferenceHashMap;
* integration will be registered out of the box.
*
* @author Juergen Hoeller
* @author Yanming Zhou
* @since 7.0
* @see LocalSessionFactoryBean#setBeanFactory
* @see LocalSessionFactoryBuilder#setBeanContainer
@ -139,17 +140,18 @@ public final class SpringBeanContainer implements BeanContainer { @@ -139,17 +140,18 @@ public final class SpringBeanContainer implements BeanContainer {
}
private SpringContainedBean<?> createBean(
Class<?> beanType, LifecycleOptions lifecycleOptions, BeanInstanceProducer fallbackProducer) {
private <B> SpringContainedBean<B> createBean(
Class<B> beanType, LifecycleOptions lifecycleOptions, BeanInstanceProducer fallbackProducer) {
try {
if (lifecycleOptions.useJpaCompliantCreation()) {
return new SpringContainedBean<>(
beanType,
this.beanFactory.createBean(beanType),
this.beanFactory::destroyBean);
}
else {
return new SpringContainedBean<>(this.beanFactory.getBean(beanType));
return new SpringContainedBean<>(beanType, this.beanFactory.getBean(beanType));
}
}
catch (BeansException ex) {
@ -158,7 +160,7 @@ public final class SpringBeanContainer implements BeanContainer { @@ -158,7 +160,7 @@ public final class SpringBeanContainer implements BeanContainer {
beanType + ": " + ex);
}
try {
return new SpringContainedBean<>(fallbackProducer.produceBeanInstance(beanType));
return new SpringContainedBean<>(beanType, fallbackProducer.produceBeanInstance(beanType));
}
catch (RuntimeException ex2) {
if (ex instanceof BeanCreationException) {
@ -176,42 +178,44 @@ public final class SpringBeanContainer implements BeanContainer { @@ -176,42 +178,44 @@ public final class SpringBeanContainer implements BeanContainer {
}
}
private SpringContainedBean<?> createBean(
String name, Class<?> beanType, LifecycleOptions lifecycleOptions, BeanInstanceProducer fallbackProducer) {
@SuppressWarnings("unchecked")
private <B> SpringContainedBean<B> createBean(
String name, Class<B> beanType, LifecycleOptions lifecycleOptions, BeanInstanceProducer fallbackProducer) {
try {
if (lifecycleOptions.useJpaCompliantCreation()) {
Object bean = null;
B bean = null;
if (fallbackProducer instanceof TypeBootstrapContext) {
// Special Hibernate type construction rules, including TypeBootstrapContext resolution.
bean = fallbackProducer.produceBeanInstance(name, beanType);
}
if (this.beanFactory.containsBean(name)) {
if (bean == null) {
bean = this.beanFactory.autowire(beanType, AutowireCapableBeanFactory.AUTOWIRE_CONSTRUCTOR, false);
bean = (B) this.beanFactory.autowire(beanType, AutowireCapableBeanFactory.AUTOWIRE_CONSTRUCTOR, false);
}
this.beanFactory.autowireBeanProperties(bean, AutowireCapableBeanFactory.AUTOWIRE_NO, false);
this.beanFactory.applyBeanPropertyValues(bean, name);
bean = this.beanFactory.initializeBean(bean, name);
return new SpringContainedBean<>(bean, beanInstance -> this.beanFactory.destroyBean(name, beanInstance));
bean = (B) this.beanFactory.initializeBean(bean, name);
return new SpringContainedBean<>(beanType, bean, beanInstance -> this.beanFactory.destroyBean(name, beanInstance));
}
else if (bean != null) {
// No bean found by name but constructed with TypeBootstrapContext rules
this.beanFactory.autowireBeanProperties(bean, AutowireCapableBeanFactory.AUTOWIRE_NO, false);
bean = this.beanFactory.initializeBean(bean, name);
return new SpringContainedBean<>(bean, this.beanFactory::destroyBean);
bean = (B) this.beanFactory.initializeBean(bean, name);
return new SpringContainedBean<>(beanType, bean, this.beanFactory::destroyBean);
}
else {
// No bean found by name -> construct by type using createBean
return new SpringContainedBean<>(
beanType,
this.beanFactory.createBean(beanType),
this.beanFactory::destroyBean);
}
}
else {
return (this.beanFactory.containsBean(name) ?
new SpringContainedBean<>(this.beanFactory.getBean(name, beanType)) :
new SpringContainedBean<>(this.beanFactory.getBean(beanType)));
new SpringContainedBean<>(beanType, this.beanFactory.getBean(name, beanType)) :
new SpringContainedBean<>(beanType, this.beanFactory.getBean(beanType)));
}
}
catch (BeansException ex) {
@ -220,7 +224,7 @@ public final class SpringBeanContainer implements BeanContainer { @@ -220,7 +224,7 @@ public final class SpringBeanContainer implements BeanContainer {
beanType + " with name '" + name + "': " + ex);
}
try {
return new SpringContainedBean<>(fallbackProducer.produceBeanInstance(name, beanType));
return new SpringContainedBean<>(beanType, fallbackProducer.produceBeanInstance(name, beanType));
}
catch (RuntimeException ex2) {
if (ex instanceof BeanCreationException) {
@ -241,15 +245,19 @@ public final class SpringBeanContainer implements BeanContainer { @@ -241,15 +245,19 @@ public final class SpringBeanContainer implements BeanContainer {
private static final class SpringContainedBean<B> implements ContainedBean<B> {
private final Class<B> beanClass;
private final B beanInstance;
private @Nullable Consumer<B> destructionCallback;
public SpringContainedBean(B beanInstance) {
public SpringContainedBean(Class<B> beanClass, B beanInstance) {
this.beanClass = beanClass;
this.beanInstance = beanInstance;
}
public SpringContainedBean(B beanInstance, Consumer<B> destructionCallback) {
public SpringContainedBean(Class<B> beanClass, B beanInstance, Consumer<B> destructionCallback) {
this.beanClass = beanClass;
this.beanInstance = beanInstance;
this.destructionCallback = destructionCallback;
}
@ -260,9 +268,8 @@ public final class SpringBeanContainer implements BeanContainer { @@ -260,9 +268,8 @@ public final class SpringBeanContainer implements BeanContainer {
}
@Override
@SuppressWarnings("unchecked")
public Class<B> getBeanClass() {
return (Class<B>) this.beanInstance.getClass();
return this.beanClass;
}
public void destroyIfNecessary() {

16
spring-orm/src/test/java/org/springframework/orm/jpa/hibernate/HibernateNativeEntityManagerFactorySpringBeanContainerIntegrationTests.java

@ -33,6 +33,7 @@ import org.springframework.orm.jpa.hibernate.beans.BeanSource; @@ -33,6 +33,7 @@ import org.springframework.orm.jpa.hibernate.beans.BeanSource;
import org.springframework.orm.jpa.hibernate.beans.MultiplePrototypesInSpringContextTestBean;
import org.springframework.orm.jpa.hibernate.beans.NoDefinitionInSpringContextTestBean;
import org.springframework.orm.jpa.hibernate.beans.SinglePrototypeInSpringContextTestBean;
import org.springframework.orm.jpa.hibernate.beans.TestBean;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -42,6 +43,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -42,6 +43,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
*
* @author Yoann Rodiere
* @author Juergen Hoeller
* @author Yanming Zhou
*/
class HibernateNativeEntityManagerFactorySpringBeanContainerIntegrationTests
extends AbstractEntityManagerFactoryIntegrationTests {
@ -275,6 +277,20 @@ class HibernateNativeEntityManagerFactorySpringBeanContainerIntegrationTests @@ -275,6 +277,20 @@ class HibernateNativeEntityManagerFactorySpringBeanContainerIntegrationTests
));
}
@Test
void testRetrieveBeanShouldRetainOriginalBeanType() {
BeanContainer beanContainer = getBeanContainer();
assertThat(beanContainer).isNotNull();
ContainedBean<TestBean> bean = beanContainer.getBean(
"single", TestBean.class,
NativeLifecycleOptions.INSTANCE, IneffectiveBeanInstanceProducer.INSTANCE
);
assertThat(bean).isNotNull();
assertThat(bean.getBeanClass()).isSameAs(TestBean.class);
}
/**
* The lifecycle options mandated by the JPA spec and used as a default in Hibernate ORM.

Loading…
Cancel
Save