diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/support/InjectionCodeGenerator.java b/spring-orm/src/main/java/org/springframework/orm/jpa/support/InjectionCodeGenerator.java new file mode 100644 index 00000000000..7cb517b71db --- /dev/null +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/support/InjectionCodeGenerator.java @@ -0,0 +1,116 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.orm.jpa.support; + +import java.lang.reflect.Field; +import java.lang.reflect.Member; +import java.lang.reflect.Method; + +import org.springframework.aot.generate.AccessVisibility; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.javapoet.CodeBlock; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; + +/** + * Internal code generator that can inject a value into a field or single-arg + * method. + *

+ * Generates code in the form:

{@code
+ * instance.age = value;
+ * }
or
{@code
+ * instance.setAge(value);
+ * }
+ *

+ * Will also generate reflection based injection and register hints if the + * member is not visible. + * + * @author Phillip Webb + * @since 6.0 + */ +class InjectionCodeGenerator { + + private final RuntimeHints hints; + + + InjectionCodeGenerator(RuntimeHints hints) { + Assert.notNull(hints, "Hints must not be null"); + this.hints = hints; + } + + + CodeBlock generateInjectionCode(Member member, String instanceVariable, + CodeBlock resourceToInject) { + + if (member instanceof Field field) { + return generateFieldInjectionCode(field, instanceVariable, resourceToInject); + } + if (member instanceof Method method) { + return generateMethodInjectionCode(method, instanceVariable, + resourceToInject); + } + throw new IllegalStateException( + "Unsupported member type " + member.getClass().getName()); + } + + private CodeBlock generateFieldInjectionCode(Field field, String instanceVariable, + CodeBlock resourceToInject) { + + CodeBlock.Builder builder = CodeBlock.builder(); + AccessVisibility visibility = AccessVisibility.forMember(field); + if (visibility == AccessVisibility.PRIVATE + || visibility == AccessVisibility.PROTECTED) { + this.hints.reflection().registerField(field); + builder.addStatement("$T field = $T.findField($T.class, $S)", Field.class, + ReflectionUtils.class, field.getDeclaringClass(), field.getName()); + builder.addStatement("$T.makeAccessible($L)", ReflectionUtils.class, "field"); + builder.addStatement("$T.setField($L, $L, $L)", ReflectionUtils.class, + "field", instanceVariable, resourceToInject); + } + else { + builder.addStatement("$L.$L = $L", instanceVariable, field.getName(), + resourceToInject); + } + return builder.build(); + } + + private CodeBlock generateMethodInjectionCode(Method method, String instanceVariable, + CodeBlock resourceToInject) { + + Assert.isTrue(method.getParameterCount() == 1, + "Method '" + method.getName() + "' must declare a single parameter"); + CodeBlock.Builder builder = CodeBlock.builder(); + AccessVisibility visibility = AccessVisibility.forMember(method); + if (visibility == AccessVisibility.PRIVATE + || visibility == AccessVisibility.PROTECTED) { + this.hints.reflection().registerMethod(method); + builder.addStatement("$T method = $T.findMethod($T.class, $S, $T.class)", + Method.class, ReflectionUtils.class, method.getDeclaringClass(), + method.getName(), method.getParameterTypes()[0]); + builder.addStatement("$T.makeAccessible($L)", ReflectionUtils.class, + "method"); + builder.addStatement("$T.invokeMethod($L, $L, $L)", ReflectionUtils.class, + "method", instanceVariable, resourceToInject); + } + else { + builder.addStatement("$L.$L($L)", instanceVariable, method.getName(), + resourceToInject); + } + return builder.build(); + } + +} diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java index d61ee632468..a3e9cbe2c5c 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java @@ -29,6 +29,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Properties; +import java.util.TreeSet; import java.util.concurrent.ConcurrentHashMap; import jakarta.persistence.EntityManager; @@ -39,8 +40,15 @@ import jakarta.persistence.PersistenceProperty; import jakarta.persistence.PersistenceUnit; import jakarta.persistence.SynchronizationType; +import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GeneratedMethods; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.generate.MethodGenerator; +import org.springframework.aot.generate.MethodNameGenerator; +import org.springframework.aot.generate.MethodReference; import org.springframework.aot.generator.CodeContribution; import org.springframework.aot.generator.ProtectedAccess.Options; +import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.BeanUtils; import org.springframework.beans.PropertyValues; import org.springframework.beans.factory.BeanCreationException; @@ -50,6 +58,9 @@ import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.annotation.InjectionMetadata; import org.springframework.beans.factory.annotation.InjectionMetadata.InjectedElement; +import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; +import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor; +import org.springframework.beans.factory.aot.BeanRegistrationCode; import org.springframework.beans.factory.config.ConfigurableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.DestructionAwareBeanPostProcessor; @@ -59,12 +70,17 @@ import org.springframework.beans.factory.generator.AotContributingBeanPostProces import org.springframework.beans.factory.generator.BeanFieldGenerator; import org.springframework.beans.factory.generator.BeanInstantiationContribution; import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor; +import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.BridgeMethodResolver; import org.springframework.core.Ordered; import org.springframework.core.PriorityOrdered; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.TypeSpec; import org.springframework.javapoet.support.MultiStatement; import org.springframework.jndi.JndiLocatorDelegate; import org.springframework.jndi.JndiTemplate; @@ -176,6 +192,7 @@ import org.springframework.util.StringUtils; * @author Rod Johnson * @author Juergen Hoeller * @author Stephane Nicoll + * @author Phillip Webb * @since 2.0 * @see jakarta.persistence.PersistenceUnit * @see jakarta.persistence.PersistenceContext @@ -183,7 +200,7 @@ import org.springframework.util.StringUtils; @SuppressWarnings("serial") public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwareBeanPostProcessor, DestructionAwareBeanPostProcessor, - MergedBeanDefinitionPostProcessor, AotContributingBeanPostProcessor, + MergedBeanDefinitionPostProcessor, AotContributingBeanPostProcessor, BeanRegistrationAotProcessor, PriorityOrdered, BeanFactoryAware, Serializable { @Nullable @@ -358,6 +375,19 @@ public class PersistenceAnnotationBeanPostProcessor return null; } + @Override + public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) { + Class beanClass = registeredBean.getBeanClass(); + String beanName = registeredBean.getBeanName(); + RootBeanDefinition beanDefinition = registeredBean.getMergedBeanDefinition(); + InjectionMetadata metadata = findInjectionMetadata(beanDefinition, beanClass, beanName); + Collection injectedElements = metadata.getInjectedElements(); + if (!CollectionUtils.isEmpty(injectedElements)) { + return new AotContribution(beanClass, injectedElements); + } + return null; + } + private InjectionMetadata findInjectionMetadata(RootBeanDefinition beanDefinition, Class beanType, String beanName) { InjectionMetadata metadata = findPersistenceMetadata(beanName, beanType, null); metadata.checkConfigMembers(beanDefinition); @@ -815,4 +845,126 @@ public class PersistenceAnnotationBeanPostProcessor } + + private static class AotContribution implements BeanRegistrationAotContribution { + + private static final String APPLY_METHOD = "apply"; + + private static final String REGISTERED_BEAN_PARAMETER = "registeredBean"; + + private static final String INSTANCE_PARAMETER = "instance"; + + + private final Class target; + + private final Collection injectedElements; + + + AotContribution(Class target, Collection injectedElements) { + this.target = target; + this.injectedElements = injectedElements; + } + + + @Override + public void applyTo(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode) { + ClassName className = generationContext.getClassNameGenerator() + .generateClassName(this.target, "PersistenceInjection"); + TypeSpec.Builder classBuilder = TypeSpec.classBuilder(className); + classBuilder.addJavadoc("Persistence injection for {@link $T}.", this.target); + classBuilder.addModifiers(javax.lang.model.element.Modifier.PUBLIC); + GeneratedMethods methods = new GeneratedMethods( + new MethodNameGenerator(APPLY_METHOD)); + classBuilder.addMethod(generateMethod(generationContext.getRuntimeHints(), + className, methods)); + methods.doWithMethodSpecs(classBuilder::addMethod); + JavaFile javaFile = JavaFile + .builder(className.packageName(), classBuilder.build()).build(); + generationContext.getGeneratedFiles().addSourceFile(javaFile); + beanRegistrationCode.addInstancePostProcessor( + MethodReference.ofStatic(className, APPLY_METHOD)); + } + + private MethodSpec generateMethod(RuntimeHints hints, ClassName className, + MethodGenerator methodGenerator) { + MethodSpec.Builder builder = MethodSpec.methodBuilder(APPLY_METHOD); + builder.addJavadoc("Apply the persistence injection."); + builder.addModifiers(javax.lang.model.element.Modifier.PUBLIC, + javax.lang.model.element.Modifier.STATIC); + builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); + builder.addParameter(this.target, INSTANCE_PARAMETER); + builder.returns(this.target); + builder.addCode(generateMethodCode(hints, methodGenerator)); + return builder.build(); + } + + private CodeBlock generateMethodCode(RuntimeHints hints, + MethodGenerator methodGenerator) { + CodeBlock.Builder builder = CodeBlock.builder(); + InjectionCodeGenerator injectionCodeGenerator = new InjectionCodeGenerator( + hints); + for (InjectedElement injectedElement : this.injectedElements) { + CodeBlock resourceToInject = getResourceToInject(methodGenerator, + (PersistenceElement) injectedElement); + builder.add(injectionCodeGenerator.generateInjectionCode( + injectedElement.getMember(), INSTANCE_PARAMETER, + resourceToInject)); + } + builder.addStatement("return $L", INSTANCE_PARAMETER); + return builder.build(); + } + + private CodeBlock getResourceToInject(MethodGenerator methodGenerator, + PersistenceElement injectedElement) { + String unitName = injectedElement.unitName; + boolean requireEntityManager = (injectedElement.type != null); + if (!requireEntityManager) { + return CodeBlock.of( + "$T.findEntityManagerFactory(($T) $L.getBeanFactory(), $S)", + EntityManagerFactoryUtils.class, ListableBeanFactory.class, + REGISTERED_BEAN_PARAMETER, unitName); + } + GeneratedMethod getEntityManagerMethod = methodGenerator + .generateMethod("get", unitName, "EntityManager") + .using(builder -> buildGetEntityManagerMethod(builder, + injectedElement)); + return CodeBlock.of("$L($L)", getEntityManagerMethod.getName(), + REGISTERED_BEAN_PARAMETER); + } + + private void buildGetEntityManagerMethod(MethodSpec.Builder builder, + PersistenceElement injectedElement) { + String unitName = injectedElement.unitName; + Properties properties = injectedElement.properties; + builder.addJavadoc("Get the '$L' {@link $T}", + (StringUtils.hasLength(unitName)) ? unitName : "default", + EntityManager.class); + builder.addModifiers(javax.lang.model.element.Modifier.PUBLIC, + javax.lang.model.element.Modifier.STATIC); + builder.returns(EntityManager.class); + builder.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER); + builder.addStatement( + "$T entityManagerFactory = $T.findEntityManagerFactory(($T) $L.getBeanFactory(), $S)", + EntityManagerFactory.class, EntityManagerFactoryUtils.class, + ListableBeanFactory.class, REGISTERED_BEAN_PARAMETER, unitName); + boolean hasProperties = !CollectionUtils.isEmpty(properties); + if (hasProperties) { + builder.addStatement("$T properties = new Properties()", + Properties.class); + for (String propertyName : new TreeSet<>( + properties.stringPropertyNames())) { + builder.addStatement("properties.put($S, $S)", propertyName, + properties.getProperty(propertyName)); + } + } + builder.addStatement( + "return $T.createSharedEntityManager(entityManagerFactory, $L, $L)", + SharedEntityManagerCreator.class, + (hasProperties) ? "properties" : null, + injectedElement.synchronizedWithTransaction); + } + + } + } diff --git a/spring-orm/src/test/java/org/springframework/orm/jpa/support/InjectionCodeGeneratorTests.java b/spring-orm/src/test/java/org/springframework/orm/jpa/support/InjectionCodeGeneratorTests.java new file mode 100644 index 00000000000..9737a01b5cf --- /dev/null +++ b/spring-orm/src/test/java/org/springframework/orm/jpa/support/InjectionCodeGeneratorTests.java @@ -0,0 +1,154 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.orm.jpa.support; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.Test; + +import org.springframework.aot.hint.ExecutableMode; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.test.generator.compile.Compiled; +import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.beans.testfixture.beans.TestBean; +import org.springframework.beans.testfixture.beans.TestBeanWithPrivateMethod; +import org.springframework.beans.testfixture.beans.TestBeanWithPublicField; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.ParameterizedTypeName; +import org.springframework.javapoet.TypeSpec; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link InjectionCodeGenerator}. + * + * @author Phillip Webb + */ +class InjectionCodeGeneratorTests { + + private static final String INSTANCE_VARIABLE = "instance"; + + private RuntimeHints hints = new RuntimeHints(); + + private InjectionCodeGenerator generator = new InjectionCodeGenerator(hints); + + @Test + void generateCodeWhenPublicFieldInjectsValue() { + TestBeanWithPublicField bean = new TestBeanWithPublicField(); + Field field = ReflectionUtils.findField(bean.getClass(), "age"); + CodeBlock generatedCode = this.generator.generateInjectionCode(field, INSTANCE_VARIABLE, + CodeBlock.of("$L", 123)); + testCompiledResult(generatedCode, TestBeanWithPublicField.class, (actual, compiled) -> { + TestBeanWithPublicField instance = new TestBeanWithPublicField(); + actual.accept(instance); + assertThat(instance).extracting("age").isEqualTo(123); + assertThat(compiled.getSourceFile()).contains("instance.age = 123"); + }); + } + + @Test + void generateCodeWhenPrivateFieldInjectsValueUsingReflection() { + TestBean bean = new TestBean(); + Field field = ReflectionUtils.findField(bean.getClass(), "age"); + CodeBlock generatedCode = this.generator.generateInjectionCode(field, INSTANCE_VARIABLE, + CodeBlock.of("$L", 123)); + testCompiledResult(generatedCode, TestBean.class, (actual, compiled) -> { + TestBean instance = new TestBean(); + actual.accept(instance); + assertThat(instance).extracting("age").isEqualTo(123); + assertThat(compiled.getSourceFile()).contains("setField("); + }); + } + + @Test + void generateCodeWhenPrivateFieldAddsHint() { + TestBean bean = new TestBean(); + Field field = ReflectionUtils.findField(bean.getClass(), "age"); + this.generator.generateInjectionCode(field, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + assertThat(this.hints.reflection().getTypeHint(TestBean.class)) + .satisfies(hint -> assertThat(hint.fields()).anySatisfy(fieldHint -> { + assertThat(fieldHint.getName()).isEqualTo("age"); + assertThat(fieldHint.isAllowWrite()).isTrue(); + })); + } + + @Test + void generateCodeWhenPublicMethodInjectsValue() { + TestBean bean = new TestBean(); + Method method = ReflectionUtils.findMethod(bean.getClass(), "setAge", int.class); + CodeBlock generatedCode = this.generator.generateInjectionCode(method, INSTANCE_VARIABLE, + CodeBlock.of("$L", 123)); + testCompiledResult(generatedCode, TestBean.class, (actual, compiled) -> { + TestBean instance = new TestBean(); + actual.accept(instance); + assertThat(instance).extracting("age").isEqualTo(123); + assertThat(compiled.getSourceFile()).contains("instance.setAge("); + }); + } + + @Test + void generateCodeWhenPrivateMethodInjectsValueUsingReflection() { + TestBeanWithPrivateMethod bean = new TestBeanWithPrivateMethod(); + Method method = ReflectionUtils.findMethod(bean.getClass(), "setAge", int.class); + CodeBlock generatedCode = this.generator.generateInjectionCode(method, INSTANCE_VARIABLE, + CodeBlock.of("$L", 123)); + testCompiledResult(generatedCode, TestBeanWithPrivateMethod.class, (actual, compiled) -> { + TestBeanWithPrivateMethod instance = new TestBeanWithPrivateMethod(); + actual.accept(instance); + assertThat(instance).extracting("age").isEqualTo(123); + assertThat(compiled.getSourceFile()).contains("invokeMethod("); + }); + } + + @Test + void generateCodeWhenPrivateMethodAddsHint() { + TestBeanWithPrivateMethod bean = new TestBeanWithPrivateMethod(); + Method method = ReflectionUtils.findMethod(bean.getClass(), "setAge", int.class); + this.generator.generateInjectionCode(method, INSTANCE_VARIABLE, CodeBlock.of("$L", 123)); + assertThat(this.hints.reflection().getTypeHint(TestBeanWithPrivateMethod.class)) + .satisfies(hint -> assertThat(hint.methods()).anySatisfy(methodHint -> { + assertThat(methodHint.getName()).isEqualTo("setAge"); + assertThat(methodHint.getModes()).contains(ExecutableMode.INVOKE); + })); + } + + @SuppressWarnings("unchecked") + private void testCompiledResult(CodeBlock generatedCode, Class target, + BiConsumer, Compiled> result) { + JavaFile javaFile = createJavaFile(generatedCode, target); + TestCompiler.forSystem().compile(javaFile::writeTo, + compiled -> result.accept(compiled.getInstance(Consumer.class), compiled)); + } + + private JavaFile createJavaFile(CodeBlock generatedCode, Class target) { + TypeSpec.Builder builder = TypeSpec.classBuilder("Injector"); + builder.addModifiers(Modifier.PUBLIC); + builder.addSuperinterface(ParameterizedTypeName.get(Consumer.class, target)); + builder.addMethod(MethodSpec.methodBuilder("accept").addModifiers(Modifier.PUBLIC) + .addParameter(target, INSTANCE_VARIABLE).addCode(generatedCode).build()); + return JavaFile.builder("__", builder.build()).build(); + } + +} diff --git a/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorAotContributionTests.java b/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorAotContributionTests.java new file mode 100644 index 00000000000..f38e1da8eb4 --- /dev/null +++ b/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorAotContributionTests.java @@ -0,0 +1,267 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.orm.jpa.support; + +import java.lang.reflect.Field; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.util.List; +import java.util.function.BiConsumer; + +import jakarta.persistence.EntityManager; +import jakarta.persistence.EntityManagerFactory; +import jakarta.persistence.PersistenceContext; +import jakarta.persistence.PersistenceProperty; +import jakarta.persistence.PersistenceUnit; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generate.DefaultGenerationContext; +import org.springframework.aot.generate.InMemoryGeneratedFiles; +import org.springframework.aot.hint.TypeReference; +import org.springframework.aot.test.generator.compile.CompileWithTargetClassAccess; +import org.springframework.aot.test.generator.compile.Compiled; +import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; +import org.springframework.beans.factory.aot.BeanRegistrationCode; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link PersistenceAnnotationBeanPostProcessor} AOT contribution. + * + * @author Stephane Nicoll + * @author Phillip Webb + */ +@CompileWithTargetClassAccess +class PersistenceAnnotationBeanPostProcessorAotContributionTests { + + private DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + + private InMemoryGeneratedFiles generatedFiles; + + private DefaultGenerationContext generationContext; + + @BeforeEach + void setup() { + this.beanFactory = new DefaultListableBeanFactory(); + this.generatedFiles = new InMemoryGeneratedFiles(); + this.generationContext = new DefaultGenerationContext(generatedFiles); + } + + @Test + void processAheadOfTimeWhenPersistenceUnitOnPublicField() { + RegisteredBean registeredBean = registerBean(DefaultPersistenceUnitField.class); + testCompile(registeredBean, (actual, compiled) -> { + EntityManagerFactory entityManagerFactory = mock(EntityManagerFactory.class); + this.beanFactory.registerSingleton("entityManagerFactory", + entityManagerFactory); + DefaultPersistenceUnitField instance = new DefaultPersistenceUnitField(); + actual.accept(registeredBean, instance); + assertThat(instance).extracting("emf").isSameAs(entityManagerFactory); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()) + .isEmpty(); + }); + } + + @Test + void processAheadOfTimeWhenPersistenceUnitOnPublicSetter() { + RegisteredBean registeredBean = registerBean(DefaultPersistenceUnitMethod.class); + testCompile(registeredBean, (actual, compiled) -> { + EntityManagerFactory entityManagerFactory = mock(EntityManagerFactory.class); + this.beanFactory.registerSingleton("entityManagerFactory", + entityManagerFactory); + DefaultPersistenceUnitMethod instance = new DefaultPersistenceUnitMethod(); + actual.accept(registeredBean, instance); + assertThat(instance).extracting("emf").isSameAs(entityManagerFactory); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()) + .isEmpty(); + }); + } + + @Test + void processAheadOfTimeWhenCustomPersistenceUnitOnPublicSetter() { + RegisteredBean registeredBean = registerBean( + CustomUnitNamePublicPersistenceUnitMethod.class); + testCompile(registeredBean, (actual, compiled) -> { + EntityManagerFactory entityManagerFactory = mock(EntityManagerFactory.class); + this.beanFactory.registerSingleton("custom", entityManagerFactory); + CustomUnitNamePublicPersistenceUnitMethod instance = new CustomUnitNamePublicPersistenceUnitMethod(); + actual.accept(registeredBean, instance); + assertThat(instance).extracting("emf").isSameAs(entityManagerFactory); + assertThat(compiled.getSourceFile()).contains( + "findEntityManagerFactory((ListableBeanFactory) registeredBean.getBeanFactory(), \"custom\")"); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()) + .isEmpty(); + }); + } + + @Test + void processAheadOfTimeWhenPersistenceContextOnPrivateField() { + RegisteredBean registeredBean = registerBean( + DefaultPersistenceContextField.class); + testCompile(registeredBean, (actual, compiled) -> { + EntityManagerFactory entityManagerFactory = mock(EntityManagerFactory.class); + this.beanFactory.registerSingleton("entityManagerFactory", + entityManagerFactory); + DefaultPersistenceContextField instance = new DefaultPersistenceContextField(); + actual.accept(registeredBean, instance); + assertThat(instance).extracting("entityManager").isNotNull(); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()) + .singleElement().satisfies(typeHint -> { + assertThat(typeHint.getType()).isEqualTo( + TypeReference.of(DefaultPersistenceContextField.class)); + assertThat(typeHint.fields()).singleElement() + .satisfies(fieldHint -> { + assertThat(fieldHint.getName()) + .isEqualTo("entityManager"); + assertThat(fieldHint.isAllowWrite()).isTrue(); + assertThat(fieldHint.isAllowUnsafeAccess()).isFalse(); + }); + }); + }); + } + + @Test + void processAheadOfTimeWhenPersistenceContextWithCustomPropertiesOnMethod() { + RegisteredBean registeredBean = registerBean( + CustomPropertiesPersistenceContextMethod.class); + testCompile(registeredBean, (actual, compiled) -> { + EntityManagerFactory entityManagerFactory = mock(EntityManagerFactory.class); + this.beanFactory.registerSingleton("entityManagerFactory", + entityManagerFactory); + CustomPropertiesPersistenceContextMethod instance = new CustomPropertiesPersistenceContextMethod(); + actual.accept(registeredBean, instance); + Field field = ReflectionUtils.findField( + CustomPropertiesPersistenceContextMethod.class, "entityManager"); + ReflectionUtils.makeAccessible(field); + EntityManager sharedEntityManager = (EntityManager) ReflectionUtils + .getField(field, instance); + InvocationHandler invocationHandler = Proxy + .getInvocationHandler(sharedEntityManager); + assertThat(invocationHandler).extracting("properties") + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsEntry("jpa.test", "value") + .containsEntry("jpa.test2", "value2"); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()) + .isEmpty(); + }); + } + + private RegisteredBean registerBean(Class beanClass) { + String beanName = "testBean"; + this.beanFactory.registerBeanDefinition(beanName, + new RootBeanDefinition(beanClass)); + return RegisteredBean.of(this.beanFactory, beanName); + } + + private void testCompile(RegisteredBean registeredBean, + BiConsumer, Compiled> result) { + PersistenceAnnotationBeanPostProcessor postProcessor = new PersistenceAnnotationBeanPostProcessor(); + BeanRegistrationAotContribution contribution = postProcessor + .processAheadOfTime(registeredBean); + BeanRegistrationCode beanRegistrationCode = mock(BeanRegistrationCode.class); + contribution.applyTo(generationContext, beanRegistrationCode); + TestCompiler.forSystem().withFiles(generatedFiles) + .compile(compiled -> result.accept(new Invoker(compiled), compiled)); + } + + static class Invoker implements BiConsumer { + + private Compiled compiled; + + Invoker(Compiled compiled) { + this.compiled = compiled; + } + + @Override + public void accept(RegisteredBean registeredBean, Object instance) { + List> compiledClasses = compiled.getAllCompiledClasses(); + assertThat(compiledClasses).hasSize(1); + Class compiledClass = compiledClasses.get(0); + for (Method method : ReflectionUtils.getDeclaredMethods(compiledClass)) { + if (method.getName().equals("apply")) { + ReflectionUtils.invokeMethod(method, null, registeredBean, instance); + return; + } + } + throw new IllegalStateException("Did not find apply method"); + } + + } + + static class DefaultPersistenceUnitField { + + @PersistenceUnit + public EntityManagerFactory emf; + + } + + static class DefaultPersistenceUnitMethod { + + @SuppressWarnings("unused") + private EntityManagerFactory emf; + + @PersistenceUnit + public void setEmf(EntityManagerFactory emf) { + this.emf = emf; + } + + } + + static class CustomUnitNamePublicPersistenceUnitMethod { + + @SuppressWarnings("unused") + private EntityManagerFactory emf; + + @PersistenceUnit(unitName = "custom") + public void setEmf(EntityManagerFactory emf) { + this.emf = emf; + } + + } + + static class DefaultPersistenceContextField { + + @SuppressWarnings("unused") + @PersistenceContext + private EntityManager entityManager; + + } + + static class CustomPropertiesPersistenceContextMethod { + + @SuppressWarnings("unused") + private EntityManager entityManager; + + @PersistenceContext( + properties = { @PersistenceProperty(name = "jpa.test", value = "value"), + @PersistenceProperty(name = "jpa.test2", value = "value2") }) + public void setEntityManager(EntityManager entityManager) { + this.entityManager = entityManager; + } + + } + +}