diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecorator.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecorator.java index 349091664..e39c9ffd4 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecorator.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecorator.java @@ -15,7 +15,9 @@ */ package org.springframework.data.repository.aot.generate; +import java.util.ArrayList; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.function.Supplier; @@ -23,14 +25,18 @@ import java.util.function.Supplier; import javax.lang.model.element.Modifier; import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.core.DefaultParameterNameDiscoverer; +import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; +import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata.ConstructorArgument; import org.springframework.data.repository.core.support.RepositoryComposition; import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; -import org.springframework.javapoet.TypeName; import org.springframework.javapoet.TypeSpec; import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; /** @@ -44,14 +50,15 @@ import org.springframework.util.StringUtils; */ public class AotRepositoryBeanDefinitionPropertiesDecorator { - private static final Map RESERVED_TYPES; + static final Map RESERVED_TYPES; private final Supplier inheritedProperties; private final RepositoryContributor repositoryContributor; static { - RESERVED_TYPES = new LinkedHashMap<>(2); + RESERVED_TYPES = new LinkedHashMap<>(3); + RESERVED_TYPES.put(ResolvableType.forClass(BeanDefinition.class), "beanDefinition"); RESERVED_TYPES.put(ResolvableType.forClass(BeanFactory.class), "beanFactory"); RESERVED_TYPES.put(ResolvableType.forClass(RepositoryFactoryBeanSupport.FragmentCreationContext.class), "context"); } @@ -90,9 +97,17 @@ public class AotRepositoryBeanDefinitionPropertiesDecorator { MethodSpec.Builder callbackMethod = MethodSpec.methodBuilder("getRepositoryFragments").addModifiers(Modifier.PUBLIC) .returns(RepositoryComposition.RepositoryFragments.class); - for (Entry entry : RESERVED_TYPES.entrySet()) { - callbackMethod.addParameter(entry.getKey().toClass(), entry.getValue()); - } + ReflectionUtils.doWithMethods(RepositoryFactoryBeanSupport.RepositoryFragmentsFunction.class, it -> { + + for (int i = 0; i < it.getParameterCount(); i++) { + + MethodParameter parameter = new MethodParameter(it, i); + parameter.initParameterNameDiscovery(new DefaultParameterNameDiscoverer()); + + callbackMethod.addParameter(parameter.getParameterType(), parameter.getParameterName()); + } + + }, method -> method.getName().equals("getRepositoryFragments")); callbackMethod.addCode(buildCallbackBody()); @@ -109,26 +124,33 @@ public class AotRepositoryBeanDefinitionPropertiesDecorator { private CodeBlock buildCallbackBody() { CodeBlock.Builder callback = CodeBlock.builder(); + List arguments = new ArrayList<>(); - for (Entry entry : repositoryContributor.requiredArgs().entrySet()) { + for (Entry entry : repositoryContributor.getConstructorArguments().entrySet()) { - TypeName argumentType = TypeName.get(entry.getValue().getType()); - String reservedArgumentName = RESERVED_TYPES.get(entry.getValue()); - if (reservedArgumentName == null) { - callback.addStatement("$1T $2L = beanFactory.getBean($1T.class)", argumentType, entry.getKey()); - } else { + ConstructorArgument argument = entry.getValue(); + AotRepositoryConstructorBuilder.ParameterOrigin parameterOrigin = argument.parameterOrigin(); - if (reservedArgumentName.equals(entry.getKey())) { - continue; - } + String ref = parameterOrigin.getReference(); + CodeBlock codeBlock = parameterOrigin.getCodeBlock(); - callback.addStatement("$T $L = $L", argumentType, entry.getKey(), reservedArgumentName); + if (StringUtils.hasText(ref)) { + arguments.add(ref); + if (!codeBlock.isEmpty()) { + callback.add(codeBlock); + } + } else { + arguments.add(codeBlock); } } - callback.addStatement("return $T.just(new $L($L))", RepositoryComposition.RepositoryFragments.class, - repositoryContributor.getContributedTypeName().getCanonicalName(), - StringUtils.collectionToDelimitedString(repositoryContributor.requiredArgs().keySet(), ", ")); + List args = new ArrayList<>(); + args.add(RepositoryComposition.RepositoryFragments.class); + args.add(repositoryContributor.getContributedTypeName().getCanonicalName()); + args.addAll(arguments); + + callback.addStatement("return $T.just(new $L(%s%s))".formatted("$L".repeat(arguments.isEmpty() ? 0 : 1), + ", $L".repeat(Math.max(0, arguments.size() - 1))), args.toArray()); return callback.build(); } diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryConstructorBuilder.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryConstructorBuilder.java index d54303ae6..fe09a5192 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryConstructorBuilder.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryConstructorBuilder.java @@ -15,8 +15,15 @@ */ package org.springframework.data.repository.aot.generate; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.jspecify.annotations.Nullable; + +import org.springframework.beans.factory.config.BeanReference; import org.springframework.core.ResolvableType; import org.springframework.javapoet.CodeBlock; +import org.springframework.util.Assert; /** * Builder for AOT Repository Constructors. @@ -65,7 +72,31 @@ public interface AotRepositoryConstructorBuilder { * @param type parameter type. * @param bindToField whether to create a field for the parameter and assign its value to the field. */ - void addParameter(String parameterName, ResolvableType type, boolean bindToField); + default void addParameter(String parameterName, ResolvableType type, boolean bindToField) { + addParameter(parameterName, type, c -> c.bindToField(bindToField)); + } + + /** + * Add constructor parameter. + * + * @param parameterName name of the parameter. + * @param type parameter type. + * @param parameterCustomizer customizer for the parameter. + */ + default void addParameter(String parameterName, Class type, + Consumer parameterCustomizer) { + addParameter(parameterName, ResolvableType.forClass(type), parameterCustomizer); + } + + /** + * Add constructor parameter. + * + * @param parameterName name of the parameter. + * @param type parameter type. + * @param parameterCustomizer customizer for the parameter. + */ + void addParameter(String parameterName, ResolvableType type, + Consumer parameterCustomizer); /** * Add constructor body customizer. The customizer is invoked after adding constructor arguments and before assigning @@ -89,4 +120,147 @@ public interface AotRepositoryConstructorBuilder { } + /** + * Customizer for a AOT repository constructor parameter. + */ + interface ConstructorParameterCustomizer { + + /** + * Bind the constructor parameter to a field of the same type using the original parameter name. + * + * @return {@code this} for method chaining. + */ + default ConstructorParameterCustomizer bindToField() { + return bindToField(true); + } + + /** + * Bind the constructor parameter to a field of the same type using the original parameter name. + * + * @return {@code this} for method chaining. + */ + ConstructorParameterCustomizer bindToField(boolean bindToField); + + /** + * Use the given {@link BeanReference} to define the constructor parameter origin. Bean references can be by name, + * by type or by type and name. Using a bean reference renders a lookup to a local variable using the parameter name + * as guidance for the local variable name + * + * @see FragmentParameterContext#localVariable(String) + */ + void origin(BeanReference reference); + + /** + * Use the given {@link BeanReference} to define the constructor parameter origin. Bean references can be by name, + * by type or by type and name. + */ + void origin(Function originFunction); + + } + + /** + * Context to obtain a constructor parameter value when declaring the constructor parameter origin. + */ + interface FragmentParameterContext { + + /** + * @return variable name of the {@link org.springframework.beans.factory.BeanFactory}. + */ + String beanFactory(); + + /** + * @return parameter origin to obtain the {@link org.springframework.beans.factory.BeanFactory}. + */ + default ParameterOrigin getBeanFactory() { + return ParameterOrigin.ofReference(beanFactory()); + } + + /** + * @return variable name of the + * {@link org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport.FragmentCreationContext}. + */ + String fragmentCreationContext(); + + /** + * @return parameter origin to obtain the fragment creation context. + */ + default ParameterOrigin getFragmentCreationContext() { + return ParameterOrigin.ofReference(fragmentCreationContext()); + } + + /** + * Obtain a naming-clash free variant for the given logical variable name within the local method context. Returns + * the target variable name when called multiple times with the same {@code variableName}. + * + * @param variableName the logical variable name. + * @return the variable name used in the generated code. + */ + String localVariable(String variableName); + + } + + /** + * Interface describing the origin of a constructor parameter. The parameter value can be obtained either from a + * {@link #getReference() reference} variable, a {@link #getCodeBlock() code block} or a combination of both. + * + * @author Mark Paluch + * @since 4.0 + */ + interface ParameterOrigin { + + /** + * Construct a {@link ParameterOrigin} from the given {@link CodeBlock} and reference name. + * + * @param reference the reference name to obtain the parameter value from. + * @param codeBlock the code block that is required to set up the parameter value. + * @return a {@link ParameterOrigin} from the given {@link CodeBlock} and reference name. + */ + static ParameterOrigin of(String reference, CodeBlock codeBlock) { + + Assert.hasText(reference, "Parameter reference must not be empty"); + + return new DefaultParameterOrigin(reference, codeBlock); + } + + /** + * Construct a {@link ParameterOrigin} from the given {@link CodeBlock}. + * + * @param codeBlock the code block that produces the parameter value. + * @return a {@link ParameterOrigin} from the given {@link CodeBlock}. + */ + static ParameterOrigin of(CodeBlock codeBlock) { + + Assert.notNull(codeBlock, "CodeBlock reference must not be empty"); + + return new DefaultParameterOrigin(null, codeBlock); + } + + /** + * Construct a {@link ParameterOrigin} from the given reference name. + * + * @param reference the reference name of the variable to obtain the parameter value from. + * @return a {@link ParameterOrigin} from the given reference name. + */ + static ParameterOrigin ofReference(String reference) { + + Assert.hasText(reference, "Parameter reference must not be empty"); + + return of(reference, CodeBlock.builder().build()); + } + + /** + * Obtain the reference name to obtain the parameter value from. Can be {@code null} if the parameter value is + * solely obtained from the {@link #getCodeBlock() code block}. + * + * @return name of the reference or {@literal null} if absent. + */ + @Nullable + String getReference(); + + /** + * Obtain the code block to obtain the parameter value from. Never {@literal null}, can be empty. + */ + CodeBlock getCodeBlock(); + + } } diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryCreator.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryCreator.java index d88d2556c..f8c0787c9 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryCreator.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryCreator.java @@ -110,6 +110,10 @@ class AotRepositoryCreator { return autowireFields; } + Map getConstructorArguments() { + return generationMetadata.getConstructorArguments(); + } + RepositoryInformation getRepositoryInformation() { return repositoryInformation; } @@ -304,6 +308,8 @@ class AotRepositoryCreator { logger.trace("Skipping method [%s.%s] contribution, no MethodContributor available" .formatted(repositoryInformation.getRepositoryInterface().getName(), method.getName())); } + + return; } if (contributor.contributesMethodSpec() && !repositoryInformation.isReactiveRepository()) { @@ -313,21 +319,6 @@ class AotRepositoryCreator { } } - /** - * Customizer interface to customize the AOT repository fragment constructor through - * {@link AotRepositoryConstructorBuilder}. - */ - public interface ConstructorCustomizer { - - /** - * Apply customization ot the AOT repository fragment constructor. - * - * @param constructorBuilder the builder to be customized. - */ - void customize(AotRepositoryConstructorBuilder constructorBuilder); - - } - /** * Factory interface to conditionally create {@link MethodContributor} instances. An implementation may decide whether * to return a {@link MethodContributor} or {@literal null}, if no method (code or metadata) should be contributed. diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java index e70045a4c..9f6c465f6 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java @@ -20,10 +20,12 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; import java.util.Map.Entry; +import java.util.function.Supplier; import javax.lang.model.element.Modifier; import org.jspecify.annotations.Nullable; + import org.springframework.core.ResolvableType; import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.repository.query.QueryMethod; @@ -89,16 +91,21 @@ class AotRepositoryFragmentMetadata { * * @param parameterName name of the constructor parameter to add. Must be unique. * @param type type of the constructor parameter. - * @param fieldName name of the field to bind the constructor parameter to, or {@literal null} if no field should be - * created. + * @param argumentSupplier supplier to create the constructor argument. */ - public void addConstructorArgument(String parameterName, ResolvableType type, @Nullable String fieldName) { + public void addConstructorArgument(String parameterName, ResolvableType type, + Supplier argumentSupplier) { - this.constructorArguments.putIfAbsent(parameterName, new ConstructorArgument(parameterName, type, fieldName)); + this.constructorArguments.computeIfAbsent(parameterName, it -> { - if (fieldName != null) { - addField(parameterName, type, Modifier.PRIVATE, Modifier.FINAL); - } + ConstructorArgument constructorArgument = argumentSupplier.get(); + + if (constructorArgument.isBoundToField()) { + addField(parameterName, type, Modifier.PRIVATE, Modifier.FINAL); + } + + return constructorArgument; + }); } public void addRepositoryMethod(Method source, MethodContributor methodContributor) { @@ -148,12 +155,13 @@ class AotRepositoryFragmentMetadata { * * @param parameterName * @param parameterType - * @param fieldName + * @param bindToField */ - public record ConstructorArgument(String parameterName, ResolvableType parameterType, @Nullable String fieldName) { + public record ConstructorArgument(String parameterName, ResolvableType parameterType, boolean bindToField, + AotRepositoryConstructorBuilder.ParameterOrigin parameterOrigin) { boolean isBoundToField() { - return fieldName != null; + return bindToField; } TypeName typeName() { diff --git a/src/main/java/org/springframework/data/repository/aot/generate/DefaultParameterOrigin.java b/src/main/java/org/springframework/data/repository/aot/generate/DefaultParameterOrigin.java new file mode 100644 index 000000000..c3494b40f --- /dev/null +++ b/src/main/java/org/springframework/data/repository/aot/generate/DefaultParameterOrigin.java @@ -0,0 +1,41 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.aot.generate; + +import org.jspecify.annotations.Nullable; + +import org.springframework.javapoet.CodeBlock; + +/** + * Default implementation of {@link AotRepositoryConstructorBuilder.ParameterOrigin}. + * + * @author Mark Paluch + * @since 4.0 + */ +record DefaultParameterOrigin(@Nullable String reference, + CodeBlock codeBlock) implements AotRepositoryConstructorBuilder.ParameterOrigin { + + @Override + public @Nullable String getReference() { + return reference(); + } + + @Override + public CodeBlock getCodeBlock() { + return codeBlock(); + } + +} diff --git a/src/main/java/org/springframework/data/repository/aot/generate/RepositoryConstructorBuilder.java b/src/main/java/org/springframework/data/repository/aot/generate/RepositoryConstructorBuilder.java index 1afae50b1..3193aa556 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/RepositoryConstructorBuilder.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/RepositoryConstructorBuilder.java @@ -15,16 +15,28 @@ */ package org.springframework.data.repository.aot.generate; -import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; import java.util.Map.Entry; import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; import javax.lang.model.element.Modifier; +import org.jspecify.annotations.Nullable; + +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.config.BeanReference; +import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.core.ResolvableType; import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata.ConstructorArgument; +import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.TypeName; import org.springframework.util.Assert; /** @@ -36,25 +48,138 @@ import org.springframework.util.Assert; */ class RepositoryConstructorBuilder implements AotRepositoryConstructorBuilder { + private final String beanFactory = AotRepositoryBeanDefinitionPropertiesDecorator.RESERVED_TYPES + .get(ResolvableType.forClass(BeanFactory.class)); + private final String fragmentCreationContext = AotRepositoryBeanDefinitionPropertiesDecorator.RESERVED_TYPES + .get(ResolvableType.forClass(RepositoryFactoryBeanSupport.FragmentCreationContext.class)); + private final AotRepositoryFragmentMetadata metadata; - private ConstructorCustomizer customizer = (builder) -> {}; - private final Set parametersAdded = new HashSet<>(); + private final Set parametersAdded = new LinkedHashSet<>(); + private final Map localVariables = new LinkedHashMap<>(); + private final VariableNameFactory variableNameFactory; + + // add super call with all parameters added + private ConstructorCustomizer customizer = (builder) -> { + builder.addStatement("super(%s%s)".formatted( // + "$L".repeat(parametersAdded.isEmpty() ? 0 : 1), // + ", $L".repeat(Math.max(0, parametersAdded.size() - 1))), parametersAdded.toArray()); + }; RepositoryConstructorBuilder(AotRepositoryFragmentMetadata metadata) { this.metadata = metadata; + this.variableNameFactory = new LocalVariableNameFactory( + AotRepositoryBeanDefinitionPropertiesDecorator.RESERVED_TYPES.values()); } - /** - * Add constructor parameter. - * - * @param parameterName name of the parameter. - * @param type parameter type. - * @param bindToField whether to create a field for the parameter and assign its value to the field. - */ @Override - public void addParameter(String parameterName, ResolvableType type, boolean bindToField) { + public void addParameter(String parameterName, ResolvableType type, + Consumer customizer) { + this.parametersAdded.add(parameterName); - this.metadata.addConstructorArgument(parameterName, type, bindToField ? parameterName : null); + + Supplier constructorArgumentSupplier = () -> { + + ConstructorParameterContext context = new ConstructorParameterContext(this::localVariable, parameterName, type); + customizer.accept(context); + + return new ConstructorArgument(parameterName, type, context.bindToField, context.getRequiredParameterOrigin()); + }; + + this.metadata.addConstructorArgument(parameterName, type, constructorArgumentSupplier); + } + + /** + * Context to customize a constructor parameter. + */ + class ConstructorParameterContext implements ConstructorParameterCustomizer { + + private final VariableNameFactory variableFactory; + private final String parameterName; + private final TypeName typeName; + + boolean bindToField; + @Nullable ParameterOrigin block; + + ConstructorParameterContext(VariableNameFactory variableFactory, String parameterName, + ResolvableType resolvableType) { + + this.variableFactory = variableFactory; + this.parameterName = parameterName; + this.typeName = AotRepositoryFragmentMetadata.typeNameOf(resolvableType); + + if (resolvableType.isAssignableFrom(BeanFactory.class)) { + origin(FragmentParameterContext::getBeanFactory); + } else if (resolvableType.isAssignableFrom(RepositoryFactoryBeanSupport.FragmentCreationContext.class)) { + origin(FragmentParameterContext::getFragmentCreationContext); + } else { + origin(new RuntimeBeanReference(resolvableType.toClass())); + } + } + + @Override + public ConstructorParameterCustomizer bindToField(boolean bindToField) { + this.bindToField = bindToField; + return this; + } + + @Override + public void origin(BeanReference reference) { + + origin(ctx -> { + + CodeBlock.Builder builder = CodeBlock.builder(); + String variableName = ctx.localVariable(parameterName); + + if (reference instanceof RuntimeBeanReference rbr && rbr.getBeanType() != null) { + + if (rbr.getBeanName().equals(rbr.getBeanType().getName())) { + builder.addStatement("$1T $2L = $3L.getBean($4T.class)", typeName, variableName, ctx.beanFactory(), + rbr.getBeanType()); + } else { + builder.addStatement("$1T $2L = $3L.getBean($4S, $5T.class)", typeName, variableName, ctx.beanFactory(), + rbr.getBeanName(), rbr.getBeanType()); + } + } else { + builder.addStatement("$1T $2L = ($1T) $3L.getBean($4S)", typeName, variableName, ctx.beanFactory(), + reference.getBeanName()); + } + + return ParameterOrigin.of(variableName, builder.build()); + }); + } + + @Override + public void origin(Function originFunction) { + + FragmentParameterContext ctx = new FragmentParameterContext() { + + @Override + public String beanFactory() { + return beanFactory; + } + + @Override + public String fragmentCreationContext() { + return fragmentCreationContext; + } + + @Override + public String localVariable(String variableName) { + return variableFactory.generateName(variableName); + } + }; + + this.block = originFunction.apply(ctx); + + Assert.state(block != null, "Resulting ParameterOriginBlock must not be null"); + } + + public ParameterOrigin getRequiredParameterOrigin() { + + Assert.state(block != null, "ParameterOriginBlock must not be null"); + + return block; + } } /** @@ -70,6 +195,17 @@ class RepositoryConstructorBuilder implements AotRepositoryConstructorBuilder { this.customizer = customizer; } + /** + * Obtain a naming-clash free variant for the given logical variable name within the local method context. Returns the + * target variable name when called multiple times with the same {@code variableName}. + * + * @param variableName the logical variable name. + * @return the variable name used in the generated code. + */ + public String localVariable(String variableName) { + return localVariables.computeIfAbsent(variableName, variableNameFactory::generateName); + } + public MethodSpec buildConstructor() { MethodSpec.Builder builder = MethodSpec.constructorBuilder().addModifiers(Modifier.PUBLIC); diff --git a/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java b/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java index 60b54a43c..2ed9945df 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java @@ -35,6 +35,7 @@ import org.springframework.data.repository.config.AotRepositoryContext; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.query.QueryMethod; import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.TypeSpec; /** * Contributor for AOT repository fragments. @@ -107,6 +108,17 @@ public class RepositoryContributor { return Collections.unmodifiableMap(creator.getAutowireFields()); } + /** + * Get the required constructor arguments for the to be generated repository implementation. + *

+ * Can be overridden if required. Needs to match arguments of generated repository implementation. + * + * @return key/value pairs of required argument required to instantiate the generated fragment. + */ + java.util.Map getConstructorArguments() { + return Collections.unmodifiableMap(creator.getConstructorArguments()); + } + /** * Contribute the AOT repository fragment to the given {@link GenerationContext}. This method will prepare the * metadata, generate the source code and write it to the {@link GenerationContext}. @@ -142,13 +154,14 @@ public class RepositoryContributor { ------------------- """.formatted(aotBundle.repositoryJsonFileName(), repositoryJson)); - JavaFile javaFile = JavaFile.builder(creator.packageName(), targetTypeSpec.build()).build(); + TypeSpec typeSpec = targetTypeSpec.build(); + JavaFile javaFile = JavaFile.builder(creator.packageName(), typeSpec).build(); logger.trace(""" ------ AOT Generated Repository: %s ------ %s ------------------- - """.formatted(null, javaFile)); + """.formatted(typeSpec.name(), javaFile)); } generationContext.getGeneratedFiles().addResourceFile(aotBundle.repositoryJsonFileName(), repositoryJson); diff --git a/src/main/java/org/springframework/data/repository/aot/generate/VariableNameFactory.java b/src/main/java/org/springframework/data/repository/aot/generate/VariableNameFactory.java index c5e4047fc..491f888a3 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/VariableNameFactory.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/VariableNameFactory.java @@ -24,6 +24,7 @@ import org.springframework.lang.CheckReturnValue; * @author Christoph Strobl * @since 4.0 */ +@FunctionalInterface interface VariableNameFactory { /** diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests.java index fb920625d..dadad5311 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests.java @@ -18,9 +18,7 @@ package org.springframework.data.repository.aot.generate; import static org.assertj.core.api.Assertions.*; import static org.mockito.Mockito.*; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.Set; import org.junit.jupiter.api.BeforeEach; @@ -31,6 +29,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.aot.generate.GeneratedTypeReference; import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.core.ResolvableType; import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; import org.springframework.data.util.Version; @@ -41,6 +40,7 @@ import org.springframework.javapoet.CodeBlock; * Unit testa for {@link AotRepositoryBeanDefinitionPropertiesDecorator}. * * @author Christoph Strobl + * @author Mark Paluch */ @ExtendWith(MockitoExtension.class) class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @@ -48,7 +48,8 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { private static final String TYPE_NAME = "com.example.UserRepositoryImpl__AotRepository"; @Mock RepositoryContributor contributor; CodeBlock.Builder inheritedSource; - + AotRepositoryFragmentMetadata metadata = new AotRepositoryFragmentMetadata(); + RepositoryConstructorBuilder constructorBuilder = new RepositoryConstructorBuilder(metadata); AotRepositoryBeanDefinitionPropertiesDecorator decorator; @BeforeEach @@ -57,6 +58,8 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { when(contributor.getContributedTypeName()).thenReturn(GeneratedTypeReference.of(ClassName.bestGuess(TYPE_NAME))); inheritedSource = CodeBlock.builder(); decorator = new AotRepositoryBeanDefinitionPropertiesDecorator(() -> inheritedSource.build(), contributor); + + when(contributor.getConstructorArguments()).thenReturn(metadata.getConstructorArguments()); } @Test // GH-3344 @@ -64,7 +67,6 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { inheritedSource.add("beanDefinition.getPropertyValues().addPropertyValue($S, $S)", "repositoryBaseClass", "org.springframework.data.BaseRepository"); - when(contributor.requiredArgs()).thenReturn(Map.of()); CodeBlock decorate = decorator.decorate(); @@ -76,8 +78,6 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @Test // GH-3344 void addsFragmentFunction() { - when(contributor.requiredArgs()).thenReturn(Map.of()); - CodeBlock decorate = decorator.decorate(); assertThat(decorate.toString()) // @@ -90,8 +90,6 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @Test // GH-3344 void addsPlainNoArgConstructorForEmptyArgs() { - when(contributor.requiredArgs()).thenReturn(Map.of()); - CodeBlock decorate = decorator.decorate(); assertThat(decorate.toString()) // @@ -101,12 +99,9 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @Test // GH-3344 void resolvesAndAddsArgumentsForCtor() { - Map ctorArgs = new LinkedHashMap<>(3); - ctorArgs.put("plain", ResolvableType.forClass(Version.class)); - ctorArgs.put("noGenericsDefined", ResolvableType.forClass(List.class)); - ctorArgs.put("withGenerics", ResolvableType.forClassWithGenerics(Set.class, String.class)); - - when(contributor.requiredArgs()).thenReturn(ctorArgs); + constructorBuilder.addParameter("plain", ResolvableType.forClass(Version.class)); + constructorBuilder.addParameter("noGenericsDefined", ResolvableType.forClass(List.class)); + constructorBuilder.addParameter("withGenerics", ResolvableType.forClassWithGenerics(Set.class, String.class)); CodeBlock decorate = decorator.decorate(); @@ -118,10 +113,33 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { .formatted(TYPE_NAME)); } + @Test // GH-3344 + void resolvesValueFromCodeblock() { + + constructorBuilder.addParameter("byTypeAndName", ResolvableType.forClass(Version.class), customizer -> { + customizer.origin(new RuntimeBeanReference("foo", Version.class)); + }); + + constructorBuilder.addParameter("byName", ResolvableType.forClass(Version.class), customizer -> { + customizer.origin(new RuntimeBeanReference("bar")); + }); + + constructorBuilder.addParameter("foo", Integer.class, + customizer -> customizer.origin(ctx -> AotRepositoryConstructorBuilder.ParameterOrigin.of(CodeBlock.of("1")))); + + CodeBlock decorate = decorator.decorate(); + + assertThat(decorate.toString()) // + .containsSubsequence("Version byTypeAndName = beanFactory.getBean(\"foo\"", "Version.class)") // + .containsSubsequence("Version byName = ", "Version) beanFactory.getBean(\"bar\"") // + .containsSubsequence("return ", "RepositoryFragments.just(new %s(byTypeAndName, byName, 1));" // + .formatted(TYPE_NAME)); + } + @Test // GH-3344 void passesOnBeanFactoryIfRequested() { - when(contributor.requiredArgs()).thenReturn(Map.of("beanFactory", ResolvableType.forClass(BeanFactory.class))); + constructorBuilder.addParameter("beanFactory", BeanFactory.class); CodeBlock decorate = decorator.decorate(); @@ -133,8 +151,7 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @Test // GH-3344 void passesOnContextIfRequested() { - when(contributor.requiredArgs()).thenReturn( - Map.of("context", ResolvableType.forClass(RepositoryFactoryBeanSupport.FragmentCreationContext.class))); + constructorBuilder.addParameter("context", RepositoryFactoryBeanSupport.FragmentCreationContext.class); CodeBlock decorate = decorator.decorate(); @@ -146,27 +163,23 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @Test // GH-3344 void passesOnContextWithDifferentNameIfRequested() { - when(contributor.requiredArgs()).thenReturn( - Map.of("theContext", ResolvableType.forClass(RepositoryFactoryBeanSupport.FragmentCreationContext.class))); + constructorBuilder.addParameter("theContext", RepositoryFactoryBeanSupport.FragmentCreationContext.class); CodeBlock decorate = decorator.decorate(); assertThat(decorate.toString()) // .doesNotContain("FragmentCreationContext theContext = beanFactory.getBean(") // - .containsSubsequence("FragmentCreationContext theContext = context") // - .containsSubsequence("return ", "RepositoryFragments.just(new %s(theContext));".formatted(TYPE_NAME)); + .containsSubsequence("return ", "RepositoryFragments.just(new %s(context));".formatted(TYPE_NAME)); } @Test // GH-3344 void passesOnBeanFactoryDifferentNameIfRequested() { - when(contributor.requiredArgs()).thenReturn(Map.of("myBeanFactory", ResolvableType.forClass(BeanFactory.class))); + constructorBuilder.addParameter("myBeanFactory", BeanFactory.class); CodeBlock decorate = decorator.decorate(); assertThat(decorate.toString()) // - .doesNotContain("BeanFactory myBeanFactory = beanFactory.getBean(") // - .containsSubsequence("BeanFactory myBeanFactory = beanFactory") // - .containsSubsequence("return ", "RepositoryFragments.just(new %s(myBeanFactory));".formatted(TYPE_NAME)); + .containsSubsequence("return ", "RepositoryFragments.just(new %s(beanFactory));".formatted(TYPE_NAME)); } }