From be7029c35dca4d70de15ea097649799aec19ac88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Wed, 10 Jul 2024 12:26:03 +0200 Subject: [PATCH] Retain previous factory method in case of nested invocation with AOT This commit harmonizes the invocation of a bean supplier with what SimpleInstantiationStrategy does. Previously, the current factory method was set to `null` once the invocation completes. This did not take into account recursive scenarios where an instance supplier triggers another instance supplier. For consistency, the thread local is removed now if we attempt to set the current method to null. SimpleInstantiationStrategy itself uses the shortcut to align the code as much as possible. Closes gh-33185 --- .../factory/aot/BeanInstanceSupplier.java | 5 ++- .../support/SimpleInstantiationStrategy.java | 23 +++++----- .../aot/BeanInstanceSupplierTests.java | 45 ++++++++++++++++++- 3 files changed, 59 insertions(+), 14 deletions(-) diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java index 265e939e01c..5c2bb461fd1 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanInstanceSupplier.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -209,12 +209,13 @@ public final class BeanInstanceSupplier extends AutowiredElementResolver impl if (!(executable instanceof Method method)) { return beanSupplier.get(); } + Method priorInvokedFactoryMethod = SimpleInstantiationStrategy.getCurrentlyInvokedFactoryMethod(); try { SimpleInstantiationStrategy.setCurrentlyInvokedFactoryMethod(method); return beanSupplier.get(); } finally { - SimpleInstantiationStrategy.setCurrentlyInvokedFactoryMethod(null); + SimpleInstantiationStrategy.setCurrentlyInvokedFactoryMethod(priorInvokedFactoryMethod); } } diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/SimpleInstantiationStrategy.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/SimpleInstantiationStrategy.java index 0fc15c5a353..56f5dffe14b 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/SimpleInstantiationStrategy.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/SimpleInstantiationStrategy.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,12 +54,18 @@ public class SimpleInstantiationStrategy implements InstantiationStrategy { } /** - * Set the factory method currently being invoked or {@code null} to reset. + * Set the factory method currently being invoked or {@code null} to remove + * the current value, if any. * @param method the factory method currently being invoked or {@code null} * @since 6.0 */ public static void setCurrentlyInvokedFactoryMethod(@Nullable Method method) { - currentlyInvokedFactoryMethod.set(method); + if (method != null) { + currentlyInvokedFactoryMethod.set(method); + } + else { + currentlyInvokedFactoryMethod.remove(); + } } @@ -133,9 +139,9 @@ public class SimpleInstantiationStrategy implements InstantiationStrategy { try { ReflectionUtils.makeAccessible(factoryMethod); - Method priorInvokedFactoryMethod = currentlyInvokedFactoryMethod.get(); + Method priorInvokedFactoryMethod = getCurrentlyInvokedFactoryMethod(); try { - currentlyInvokedFactoryMethod.set(factoryMethod); + setCurrentlyInvokedFactoryMethod(factoryMethod); Object result = factoryMethod.invoke(factoryBean, args); if (result == null) { result = new NullBean(); @@ -143,12 +149,7 @@ public class SimpleInstantiationStrategy implements InstantiationStrategy { return result; } finally { - if (priorInvokedFactoryMethod != null) { - currentlyInvokedFactoryMethod.set(priorInvokedFactoryMethod); - } - else { - currentlyInvokedFactoryMethod.remove(); - } + setCurrentlyInvokedFactoryMethod(priorInvokedFactoryMethod); } } catch (IllegalArgumentException ex) { diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java index 422de9d4420..ec25e328f4c 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanInstanceSupplierTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.stream.Stream; @@ -51,6 +52,7 @@ import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.InstanceSupplier; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.beans.factory.support.SimpleInstantiationStrategy; import org.springframework.core.env.Environment; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.ResourceLoader; @@ -292,6 +294,33 @@ class BeanInstanceSupplierTests { assertThat(instance).isEqualTo("1"); } + @Test // gh-33180 + void getWithNestedInvocationRetainsFactoryMethod() throws Exception { + AtomicReference testMethodReference = new AtomicReference<>(); + AtomicReference anotherMethodReference = new AtomicReference<>(); + + BeanInstanceSupplier nestedInstanceSupplier = BeanInstanceSupplier + .forFactoryMethod(AnotherTestStringFactory.class, "another") + .withGenerator(registeredBean -> { + anotherMethodReference.set(SimpleInstantiationStrategy.getCurrentlyInvokedFactoryMethod()); + return "Another"; + }); + RegisteredBean nestedRegisteredBean = new Source(String.class, nestedInstanceSupplier).registerBean(this.beanFactory); + BeanInstanceSupplier instanceSupplier = BeanInstanceSupplier + .forFactoryMethod(TestStringFactory.class, "test") + .withGenerator(registeredBean -> { + Object nested = nestedInstanceSupplier.get(nestedRegisteredBean); + testMethodReference.set(SimpleInstantiationStrategy.getCurrentlyInvokedFactoryMethod()); + return "custom" + nested; + }); + RegisteredBean registeredBean = new Source(String.class, instanceSupplier).registerBean(this.beanFactory); + Object value = instanceSupplier.get(registeredBean); + + assertThat(value).isEqualTo("customAnother"); + assertThat(testMethodReference.get()).isEqualTo(instanceSupplier.getFactoryMethod()); + assertThat(anotherMethodReference.get()).isEqualTo(nestedInstanceSupplier.getFactoryMethod()); + } + @Test void resolveArgumentsWithNoArgConstructor() { RootBeanDefinition beanDefinition = new RootBeanDefinition( @@ -934,4 +963,18 @@ class BeanInstanceSupplierTests { } + static class TestStringFactory { + + String test() { + return "test"; + } + } + + static class AnotherTestStringFactory { + + String another() { + return "another"; + } + } + }