From b5a86dec92eefb43795ac0e9f6802d589fff5956 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Wed, 10 Jul 2024 12:26:03 +0200 Subject: [PATCH] Retain previous factory method in case of nested invocation with AOT This commit harmonizes the invocation of a bean supplier with what SimpleInstantiationStrategy does. Previously, the current factory method was set to `null` once the invocation completes. This did not take into account recursive scenarios where an instance supplier triggers another instance supplier. For consistency, the thread local is removed now if we attempt to set the current method to null. SimpleInstantiationStrategy itself uses the shortcut to align the code as much as possible. Closes gh-33180 --- .../factory/aot/BeanInstanceSupplier.java | 5 ++- .../support/SimpleInstantiationStrategy.java | 23 +++++----- .../aot/BeanInstanceSupplierTests.java | 45 ++++++++++++++++++- 3 files changed, 59 insertions(+), 14 deletions(-) diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java index 56d3e79268c..a1992cb8876 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -213,12 +213,13 @@ public final class BeanInstanceSupplier extends AutowiredElementResolver impl if (!(executable instanceof Method method)) { return beanSupplier.get(); } + Method priorInvokedFactoryMethod = SimpleInstantiationStrategy.getCurrentlyInvokedFactoryMethod(); try { SimpleInstantiationStrategy.setCurrentlyInvokedFactoryMethod(method); return beanSupplier.get(); } finally { - SimpleInstantiationStrategy.setCurrentlyInvokedFactoryMethod(null); + SimpleInstantiationStrategy.setCurrentlyInvokedFactoryMethod(priorInvokedFactoryMethod); } } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/SimpleInstantiationStrategy.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/SimpleInstantiationStrategy.java index d1d98d35e5f..49c38d7e738 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/SimpleInstantiationStrategy.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/SimpleInstantiationStrategy.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,12 +55,18 @@ public class SimpleInstantiationStrategy implements InstantiationStrategy { } /** - * Set the factory method currently being invoked or {@code null} to reset. + * Set the factory method currently being invoked or {@code null} to remove + * the current value, if any. * @param method the factory method currently being invoked or {@code null} * @since 6.0 */ public static void setCurrentlyInvokedFactoryMethod(@Nullable Method method) { - currentlyInvokedFactoryMethod.set(method); + if (method != null) { + currentlyInvokedFactoryMethod.set(method); + } + else { + currentlyInvokedFactoryMethod.remove(); + } } @@ -134,9 +140,9 @@ public class SimpleInstantiationStrategy implements InstantiationStrategy { try { ReflectionUtils.makeAccessible(factoryMethod); - Method priorInvokedFactoryMethod = currentlyInvokedFactoryMethod.get(); + Method priorInvokedFactoryMethod = getCurrentlyInvokedFactoryMethod(); try { - currentlyInvokedFactoryMethod.set(factoryMethod); + setCurrentlyInvokedFactoryMethod(factoryMethod); Object result = factoryMethod.invoke(factoryBean, args); if (result == null) { result = new NullBean(); @@ -144,12 +150,7 @@ public class SimpleInstantiationStrategy implements InstantiationStrategy { return result; } finally { - if (priorInvokedFactoryMethod != null) { - currentlyInvokedFactoryMethod.set(priorInvokedFactoryMethod); - } - else { - currentlyInvokedFactoryMethod.remove(); - } + setCurrentlyInvokedFactoryMethod(priorInvokedFactoryMethod); } } catch (IllegalArgumentException ex) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java index 6d8da260546..1a1a8f6d56f 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.stream.Stream; @@ -51,6 +52,7 @@ import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.factory.support.SimpleInstantiationStrategy; import org.springframework.core.env.Environment; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.ResourceLoader; @@ -292,6 +294,33 @@ class BeanInstanceSupplierTests { assertThat(instance).isEqualTo("1"); } + @Test // gh-33180 + void getWithNestedInvocationRetainsFactoryMethod() throws Exception { + AtomicReference testMethodReference = new AtomicReference<>(); + AtomicReference anotherMethodReference = new AtomicReference<>(); + + BeanInstanceSupplier nestedInstanceSupplier = BeanInstanceSupplier + .forFactoryMethod(AnotherTestStringFactory.class, "another") + .withGenerator(registeredBean -> { + anotherMethodReference.set(SimpleInstantiationStrategy.getCurrentlyInvokedFactoryMethod()); + return "Another"; + }); + RegisteredBean nestedRegisteredBean = new Source(String.class, nestedInstanceSupplier).registerBean(this.beanFactory); + BeanInstanceSupplier instanceSupplier = BeanInstanceSupplier + .forFactoryMethod(TestStringFactory.class, "test") + .withGenerator(registeredBean -> { + Object nested = nestedInstanceSupplier.get(nestedRegisteredBean); + testMethodReference.set(SimpleInstantiationStrategy.getCurrentlyInvokedFactoryMethod()); + return "custom" + nested; + }); + RegisteredBean registeredBean = new Source(String.class, instanceSupplier).registerBean(this.beanFactory); + Object value = instanceSupplier.get(registeredBean); + + assertThat(value).isEqualTo("customAnother"); + assertThat(testMethodReference.get()).isEqualTo(instanceSupplier.getFactoryMethod()); + assertThat(anotherMethodReference.get()).isEqualTo(nestedInstanceSupplier.getFactoryMethod()); + } + @Test void resolveArgumentsWithNoArgConstructor() { RootBeanDefinition beanDefinition = new RootBeanDefinition( @@ -1030,4 +1059,18 @@ class BeanInstanceSupplierTests { } + static class TestStringFactory { + + String test() { + return "test"; + } + } + + static class AnotherTestStringFactory { + + String another() { + return "another"; + } + } + }