Browse Source

Polishing.

Add configuration for constructor origin.

See #3344
Original pull request: #3351
issue/3353
Mark Paluch 4 months ago
parent
commit
2aeaa707f1
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 60
      src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecorator.java
  2. 176
      src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryConstructorBuilder.java
  3. 21
      src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryCreator.java
  4. 28
      src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java
  5. 41
      src/main/java/org/springframework/data/repository/aot/generate/DefaultParameterOrigin.java
  6. 160
      src/main/java/org/springframework/data/repository/aot/generate/RepositoryConstructorBuilder.java
  7. 17
      src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java
  8. 1
      src/main/java/org/springframework/data/repository/aot/generate/VariableNameFactory.java
  9. 63
      src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests.java

60
src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecorator.java

@ -15,7 +15,9 @@ @@ -15,7 +15,9 @@
*/
package org.springframework.data.repository.aot.generate;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.function.Supplier;
@ -23,14 +25,18 @@ import java.util.function.Supplier; @@ -23,14 +25,18 @@ import java.util.function.Supplier;
import javax.lang.model.element.Modifier;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata.ConstructorArgument;
import org.springframework.data.repository.core.support.RepositoryComposition;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.TypeName;
import org.springframework.javapoet.TypeSpec;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
/**
@ -44,14 +50,15 @@ import org.springframework.util.StringUtils; @@ -44,14 +50,15 @@ import org.springframework.util.StringUtils;
*/
public class AotRepositoryBeanDefinitionPropertiesDecorator {
private static final Map<ResolvableType, String> RESERVED_TYPES;
static final Map<ResolvableType, String> RESERVED_TYPES;
private final Supplier<CodeBlock> inheritedProperties;
private final RepositoryContributor repositoryContributor;
static {
RESERVED_TYPES = new LinkedHashMap<>(2);
RESERVED_TYPES = new LinkedHashMap<>(3);
RESERVED_TYPES.put(ResolvableType.forClass(BeanDefinition.class), "beanDefinition");
RESERVED_TYPES.put(ResolvableType.forClass(BeanFactory.class), "beanFactory");
RESERVED_TYPES.put(ResolvableType.forClass(RepositoryFactoryBeanSupport.FragmentCreationContext.class), "context");
}
@ -90,9 +97,17 @@ public class AotRepositoryBeanDefinitionPropertiesDecorator { @@ -90,9 +97,17 @@ public class AotRepositoryBeanDefinitionPropertiesDecorator {
MethodSpec.Builder callbackMethod = MethodSpec.methodBuilder("getRepositoryFragments").addModifiers(Modifier.PUBLIC)
.returns(RepositoryComposition.RepositoryFragments.class);
for (Entry<ResolvableType, String> entry : RESERVED_TYPES.entrySet()) {
callbackMethod.addParameter(entry.getKey().toClass(), entry.getValue());
}
ReflectionUtils.doWithMethods(RepositoryFactoryBeanSupport.RepositoryFragmentsFunction.class, it -> {
for (int i = 0; i < it.getParameterCount(); i++) {
MethodParameter parameter = new MethodParameter(it, i);
parameter.initParameterNameDiscovery(new DefaultParameterNameDiscoverer());
callbackMethod.addParameter(parameter.getParameterType(), parameter.getParameterName());
}
}, method -> method.getName().equals("getRepositoryFragments"));
callbackMethod.addCode(buildCallbackBody());
@ -109,26 +124,33 @@ public class AotRepositoryBeanDefinitionPropertiesDecorator { @@ -109,26 +124,33 @@ public class AotRepositoryBeanDefinitionPropertiesDecorator {
private CodeBlock buildCallbackBody() {
CodeBlock.Builder callback = CodeBlock.builder();
List<Object> arguments = new ArrayList<>();
for (Entry<String, ResolvableType> entry : repositoryContributor.requiredArgs().entrySet()) {
for (Entry<String, ConstructorArgument> entry : repositoryContributor.getConstructorArguments().entrySet()) {
TypeName argumentType = TypeName.get(entry.getValue().getType());
String reservedArgumentName = RESERVED_TYPES.get(entry.getValue());
if (reservedArgumentName == null) {
callback.addStatement("$1T $2L = beanFactory.getBean($1T.class)", argumentType, entry.getKey());
} else {
ConstructorArgument argument = entry.getValue();
AotRepositoryConstructorBuilder.ParameterOrigin parameterOrigin = argument.parameterOrigin();
if (reservedArgumentName.equals(entry.getKey())) {
continue;
}
String ref = parameterOrigin.getReference();
CodeBlock codeBlock = parameterOrigin.getCodeBlock();
callback.addStatement("$T $L = $L", argumentType, entry.getKey(), reservedArgumentName);
if (StringUtils.hasText(ref)) {
arguments.add(ref);
if (!codeBlock.isEmpty()) {
callback.add(codeBlock);
}
} else {
arguments.add(codeBlock);
}
}
callback.addStatement("return $T.just(new $L($L))", RepositoryComposition.RepositoryFragments.class,
repositoryContributor.getContributedTypeName().getCanonicalName(),
StringUtils.collectionToDelimitedString(repositoryContributor.requiredArgs().keySet(), ", "));
List<Object> args = new ArrayList<>();
args.add(RepositoryComposition.RepositoryFragments.class);
args.add(repositoryContributor.getContributedTypeName().getCanonicalName());
args.addAll(arguments);
callback.addStatement("return $T.just(new $L(%s%s))".formatted("$L".repeat(arguments.isEmpty() ? 0 : 1),
", $L".repeat(Math.max(0, arguments.size() - 1))), args.toArray());
return callback.build();
}

176
src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryConstructorBuilder.java

@ -15,8 +15,15 @@ @@ -15,8 +15,15 @@
*/
package org.springframework.data.repository.aot.generate;
import java.util.function.Consumer;
import java.util.function.Function;
import org.jspecify.annotations.Nullable;
import org.springframework.beans.factory.config.BeanReference;
import org.springframework.core.ResolvableType;
import org.springframework.javapoet.CodeBlock;
import org.springframework.util.Assert;
/**
* Builder for AOT Repository Constructors.
@ -65,7 +72,31 @@ public interface AotRepositoryConstructorBuilder { @@ -65,7 +72,31 @@ public interface AotRepositoryConstructorBuilder {
* @param type parameter type.
* @param bindToField whether to create a field for the parameter and assign its value to the field.
*/
void addParameter(String parameterName, ResolvableType type, boolean bindToField);
default void addParameter(String parameterName, ResolvableType type, boolean bindToField) {
addParameter(parameterName, type, c -> c.bindToField(bindToField));
}
/**
* Add constructor parameter.
*
* @param parameterName name of the parameter.
* @param type parameter type.
* @param parameterCustomizer customizer for the parameter.
*/
default void addParameter(String parameterName, Class<?> type,
Consumer<ConstructorParameterCustomizer> parameterCustomizer) {
addParameter(parameterName, ResolvableType.forClass(type), parameterCustomizer);
}
/**
* Add constructor parameter.
*
* @param parameterName name of the parameter.
* @param type parameter type.
* @param parameterCustomizer customizer for the parameter.
*/
void addParameter(String parameterName, ResolvableType type,
Consumer<ConstructorParameterCustomizer> parameterCustomizer);
/**
* Add constructor body customizer. The customizer is invoked after adding constructor arguments and before assigning
@ -89,4 +120,147 @@ public interface AotRepositoryConstructorBuilder { @@ -89,4 +120,147 @@ public interface AotRepositoryConstructorBuilder {
}
/**
* Customizer for a AOT repository constructor parameter.
*/
interface ConstructorParameterCustomizer {
/**
* Bind the constructor parameter to a field of the same type using the original parameter name.
*
* @return {@code this} for method chaining.
*/
default ConstructorParameterCustomizer bindToField() {
return bindToField(true);
}
/**
* Bind the constructor parameter to a field of the same type using the original parameter name.
*
* @return {@code this} for method chaining.
*/
ConstructorParameterCustomizer bindToField(boolean bindToField);
/**
* Use the given {@link BeanReference} to define the constructor parameter origin. Bean references can be by name,
* by type or by type and name. Using a bean reference renders a lookup to a local variable using the parameter name
* as guidance for the local variable name
*
* @see FragmentParameterContext#localVariable(String)
*/
void origin(BeanReference reference);
/**
* Use the given {@link BeanReference} to define the constructor parameter origin. Bean references can be by name,
* by type or by type and name.
*/
void origin(Function<FragmentParameterContext, ParameterOrigin> originFunction);
}
/**
* Context to obtain a constructor parameter value when declaring the constructor parameter origin.
*/
interface FragmentParameterContext {
/**
* @return variable name of the {@link org.springframework.beans.factory.BeanFactory}.
*/
String beanFactory();
/**
* @return parameter origin to obtain the {@link org.springframework.beans.factory.BeanFactory}.
*/
default ParameterOrigin getBeanFactory() {
return ParameterOrigin.ofReference(beanFactory());
}
/**
* @return variable name of the
* {@link org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport.FragmentCreationContext}.
*/
String fragmentCreationContext();
/**
* @return parameter origin to obtain the fragment creation context.
*/
default ParameterOrigin getFragmentCreationContext() {
return ParameterOrigin.ofReference(fragmentCreationContext());
}
/**
* Obtain a naming-clash free variant for the given logical variable name within the local method context. Returns
* the target variable name when called multiple times with the same {@code variableName}.
*
* @param variableName the logical variable name.
* @return the variable name used in the generated code.
*/
String localVariable(String variableName);
}
/**
* Interface describing the origin of a constructor parameter. The parameter value can be obtained either from a
* {@link #getReference() reference} variable, a {@link #getCodeBlock() code block} or a combination of both.
*
* @author Mark Paluch
* @since 4.0
*/
interface ParameterOrigin {
/**
* Construct a {@link ParameterOrigin} from the given {@link CodeBlock} and reference name.
*
* @param reference the reference name to obtain the parameter value from.
* @param codeBlock the code block that is required to set up the parameter value.
* @return a {@link ParameterOrigin} from the given {@link CodeBlock} and reference name.
*/
static ParameterOrigin of(String reference, CodeBlock codeBlock) {
Assert.hasText(reference, "Parameter reference must not be empty");
return new DefaultParameterOrigin(reference, codeBlock);
}
/**
* Construct a {@link ParameterOrigin} from the given {@link CodeBlock}.
*
* @param codeBlock the code block that produces the parameter value.
* @return a {@link ParameterOrigin} from the given {@link CodeBlock}.
*/
static ParameterOrigin of(CodeBlock codeBlock) {
Assert.notNull(codeBlock, "CodeBlock reference must not be empty");
return new DefaultParameterOrigin(null, codeBlock);
}
/**
* Construct a {@link ParameterOrigin} from the given reference name.
*
* @param reference the reference name of the variable to obtain the parameter value from.
* @return a {@link ParameterOrigin} from the given reference name.
*/
static ParameterOrigin ofReference(String reference) {
Assert.hasText(reference, "Parameter reference must not be empty");
return of(reference, CodeBlock.builder().build());
}
/**
* Obtain the reference name to obtain the parameter value from. Can be {@code null} if the parameter value is
* solely obtained from the {@link #getCodeBlock() code block}.
*
* @return name of the reference or {@literal null} if absent.
*/
@Nullable
String getReference();
/**
* Obtain the code block to obtain the parameter value from. Never {@literal null}, can be empty.
*/
CodeBlock getCodeBlock();
}
}

21
src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryCreator.java

@ -110,6 +110,10 @@ class AotRepositoryCreator { @@ -110,6 +110,10 @@ class AotRepositoryCreator {
return autowireFields;
}
Map<String, ConstructorArgument> getConstructorArguments() {
return generationMetadata.getConstructorArguments();
}
RepositoryInformation getRepositoryInformation() {
return repositoryInformation;
}
@ -304,6 +308,8 @@ class AotRepositoryCreator { @@ -304,6 +308,8 @@ class AotRepositoryCreator {
logger.trace("Skipping method [%s.%s] contribution, no MethodContributor available"
.formatted(repositoryInformation.getRepositoryInterface().getName(), method.getName()));
}
return;
}
if (contributor.contributesMethodSpec() && !repositoryInformation.isReactiveRepository()) {
@ -313,21 +319,6 @@ class AotRepositoryCreator { @@ -313,21 +319,6 @@ class AotRepositoryCreator {
}
}
/**
* Customizer interface to customize the AOT repository fragment constructor through
* {@link AotRepositoryConstructorBuilder}.
*/
public interface ConstructorCustomizer {
/**
* Apply customization ot the AOT repository fragment constructor.
*
* @param constructorBuilder the builder to be customized.
*/
void customize(AotRepositoryConstructorBuilder constructorBuilder);
}
/**
* Factory interface to conditionally create {@link MethodContributor} instances. An implementation may decide whether
* to return a {@link MethodContributor} or {@literal null}, if no method (code or metadata) should be contributed.

28
src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java

@ -20,10 +20,12 @@ import java.util.HashMap; @@ -20,10 +20,12 @@ import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.function.Supplier;
import javax.lang.model.element.Modifier;
import org.jspecify.annotations.Nullable;
import org.springframework.core.ResolvableType;
import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.repository.query.QueryMethod;
@ -89,16 +91,21 @@ class AotRepositoryFragmentMetadata { @@ -89,16 +91,21 @@ class AotRepositoryFragmentMetadata {
*
* @param parameterName name of the constructor parameter to add. Must be unique.
* @param type type of the constructor parameter.
* @param fieldName name of the field to bind the constructor parameter to, or {@literal null} if no field should be
* created.
* @param argumentSupplier supplier to create the constructor argument.
*/
public void addConstructorArgument(String parameterName, ResolvableType type, @Nullable String fieldName) {
public void addConstructorArgument(String parameterName, ResolvableType type,
Supplier<ConstructorArgument> argumentSupplier) {
this.constructorArguments.putIfAbsent(parameterName, new ConstructorArgument(parameterName, type, fieldName));
this.constructorArguments.computeIfAbsent(parameterName, it -> {
if (fieldName != null) {
addField(parameterName, type, Modifier.PRIVATE, Modifier.FINAL);
}
ConstructorArgument constructorArgument = argumentSupplier.get();
if (constructorArgument.isBoundToField()) {
addField(parameterName, type, Modifier.PRIVATE, Modifier.FINAL);
}
return constructorArgument;
});
}
public void addRepositoryMethod(Method source, MethodContributor<? extends QueryMethod> methodContributor) {
@ -148,12 +155,13 @@ class AotRepositoryFragmentMetadata { @@ -148,12 +155,13 @@ class AotRepositoryFragmentMetadata {
*
* @param parameterName
* @param parameterType
* @param fieldName
* @param bindToField
*/
public record ConstructorArgument(String parameterName, ResolvableType parameterType, @Nullable String fieldName) {
public record ConstructorArgument(String parameterName, ResolvableType parameterType, boolean bindToField,
AotRepositoryConstructorBuilder.ParameterOrigin parameterOrigin) {
boolean isBoundToField() {
return fieldName != null;
return bindToField;
}
TypeName typeName() {

41
src/main/java/org/springframework/data/repository/aot/generate/DefaultParameterOrigin.java

@ -0,0 +1,41 @@ @@ -0,0 +1,41 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.repository.aot.generate;
import org.jspecify.annotations.Nullable;
import org.springframework.javapoet.CodeBlock;
/**
* Default implementation of {@link AotRepositoryConstructorBuilder.ParameterOrigin}.
*
* @author Mark Paluch
* @since 4.0
*/
record DefaultParameterOrigin(@Nullable String reference,
CodeBlock codeBlock) implements AotRepositoryConstructorBuilder.ParameterOrigin {
@Override
public @Nullable String getReference() {
return reference();
}
@Override
public CodeBlock getCodeBlock() {
return codeBlock();
}
}

160
src/main/java/org/springframework/data/repository/aot/generate/RepositoryConstructorBuilder.java

@ -15,16 +15,28 @@ @@ -15,16 +15,28 @@
*/
package org.springframework.data.repository.aot.generate;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.lang.model.element.Modifier;
import org.jspecify.annotations.Nullable;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.config.BeanReference;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.core.ResolvableType;
import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata.ConstructorArgument;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.TypeName;
import org.springframework.util.Assert;
/**
@ -36,25 +48,138 @@ import org.springframework.util.Assert; @@ -36,25 +48,138 @@ import org.springframework.util.Assert;
*/
class RepositoryConstructorBuilder implements AotRepositoryConstructorBuilder {
private final String beanFactory = AotRepositoryBeanDefinitionPropertiesDecorator.RESERVED_TYPES
.get(ResolvableType.forClass(BeanFactory.class));
private final String fragmentCreationContext = AotRepositoryBeanDefinitionPropertiesDecorator.RESERVED_TYPES
.get(ResolvableType.forClass(RepositoryFactoryBeanSupport.FragmentCreationContext.class));
private final AotRepositoryFragmentMetadata metadata;
private ConstructorCustomizer customizer = (builder) -> {};
private final Set<String> parametersAdded = new HashSet<>();
private final Set<String> parametersAdded = new LinkedHashSet<>();
private final Map<String, String> localVariables = new LinkedHashMap<>();
private final VariableNameFactory variableNameFactory;
// add super call with all parameters added
private ConstructorCustomizer customizer = (builder) -> {
builder.addStatement("super(%s%s)".formatted( //
"$L".repeat(parametersAdded.isEmpty() ? 0 : 1), //
", $L".repeat(Math.max(0, parametersAdded.size() - 1))), parametersAdded.toArray());
};
RepositoryConstructorBuilder(AotRepositoryFragmentMetadata metadata) {
this.metadata = metadata;
this.variableNameFactory = new LocalVariableNameFactory(
AotRepositoryBeanDefinitionPropertiesDecorator.RESERVED_TYPES.values());
}
/**
* Add constructor parameter.
*
* @param parameterName name of the parameter.
* @param type parameter type.
* @param bindToField whether to create a field for the parameter and assign its value to the field.
*/
@Override
public void addParameter(String parameterName, ResolvableType type, boolean bindToField) {
public void addParameter(String parameterName, ResolvableType type,
Consumer<ConstructorParameterCustomizer> customizer) {
this.parametersAdded.add(parameterName);
this.metadata.addConstructorArgument(parameterName, type, bindToField ? parameterName : null);
Supplier<ConstructorArgument> constructorArgumentSupplier = () -> {
ConstructorParameterContext context = new ConstructorParameterContext(this::localVariable, parameterName, type);
customizer.accept(context);
return new ConstructorArgument(parameterName, type, context.bindToField, context.getRequiredParameterOrigin());
};
this.metadata.addConstructorArgument(parameterName, type, constructorArgumentSupplier);
}
/**
* Context to customize a constructor parameter.
*/
class ConstructorParameterContext implements ConstructorParameterCustomizer {
private final VariableNameFactory variableFactory;
private final String parameterName;
private final TypeName typeName;
boolean bindToField;
@Nullable ParameterOrigin block;
ConstructorParameterContext(VariableNameFactory variableFactory, String parameterName,
ResolvableType resolvableType) {
this.variableFactory = variableFactory;
this.parameterName = parameterName;
this.typeName = AotRepositoryFragmentMetadata.typeNameOf(resolvableType);
if (resolvableType.isAssignableFrom(BeanFactory.class)) {
origin(FragmentParameterContext::getBeanFactory);
} else if (resolvableType.isAssignableFrom(RepositoryFactoryBeanSupport.FragmentCreationContext.class)) {
origin(FragmentParameterContext::getFragmentCreationContext);
} else {
origin(new RuntimeBeanReference(resolvableType.toClass()));
}
}
@Override
public ConstructorParameterCustomizer bindToField(boolean bindToField) {
this.bindToField = bindToField;
return this;
}
@Override
public void origin(BeanReference reference) {
origin(ctx -> {
CodeBlock.Builder builder = CodeBlock.builder();
String variableName = ctx.localVariable(parameterName);
if (reference instanceof RuntimeBeanReference rbr && rbr.getBeanType() != null) {
if (rbr.getBeanName().equals(rbr.getBeanType().getName())) {
builder.addStatement("$1T $2L = $3L.getBean($4T.class)", typeName, variableName, ctx.beanFactory(),
rbr.getBeanType());
} else {
builder.addStatement("$1T $2L = $3L.getBean($4S, $5T.class)", typeName, variableName, ctx.beanFactory(),
rbr.getBeanName(), rbr.getBeanType());
}
} else {
builder.addStatement("$1T $2L = ($1T) $3L.getBean($4S)", typeName, variableName, ctx.beanFactory(),
reference.getBeanName());
}
return ParameterOrigin.of(variableName, builder.build());
});
}
@Override
public void origin(Function<FragmentParameterContext, ParameterOrigin> originFunction) {
FragmentParameterContext ctx = new FragmentParameterContext() {
@Override
public String beanFactory() {
return beanFactory;
}
@Override
public String fragmentCreationContext() {
return fragmentCreationContext;
}
@Override
public String localVariable(String variableName) {
return variableFactory.generateName(variableName);
}
};
this.block = originFunction.apply(ctx);
Assert.state(block != null, "Resulting ParameterOriginBlock must not be null");
}
public ParameterOrigin getRequiredParameterOrigin() {
Assert.state(block != null, "ParameterOriginBlock must not be null");
return block;
}
}
/**
@ -70,6 +195,17 @@ class RepositoryConstructorBuilder implements AotRepositoryConstructorBuilder { @@ -70,6 +195,17 @@ class RepositoryConstructorBuilder implements AotRepositoryConstructorBuilder {
this.customizer = customizer;
}
/**
* Obtain a naming-clash free variant for the given logical variable name within the local method context. Returns the
* target variable name when called multiple times with the same {@code variableName}.
*
* @param variableName the logical variable name.
* @return the variable name used in the generated code.
*/
public String localVariable(String variableName) {
return localVariables.computeIfAbsent(variableName, variableNameFactory::generateName);
}
public MethodSpec buildConstructor() {
MethodSpec.Builder builder = MethodSpec.constructorBuilder().addModifiers(Modifier.PUBLIC);

17
src/main/java/org/springframework/data/repository/aot/generate/RepositoryContributor.java

@ -35,6 +35,7 @@ import org.springframework.data.repository.config.AotRepositoryContext; @@ -35,6 +35,7 @@ import org.springframework.data.repository.config.AotRepositoryContext;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.TypeSpec;
/**
* Contributor for AOT repository fragments.
@ -107,6 +108,17 @@ public class RepositoryContributor { @@ -107,6 +108,17 @@ public class RepositoryContributor {
return Collections.unmodifiableMap(creator.getAutowireFields());
}
/**
* Get the required constructor arguments for the to be generated repository implementation.
* <p>
* Can be overridden if required. Needs to match arguments of generated repository implementation.
*
* @return key/value pairs of required argument required to instantiate the generated fragment.
*/
java.util.Map<String, AotRepositoryFragmentMetadata.ConstructorArgument> getConstructorArguments() {
return Collections.unmodifiableMap(creator.getConstructorArguments());
}
/**
* Contribute the AOT repository fragment to the given {@link GenerationContext}. This method will prepare the
* metadata, generate the source code and write it to the {@link GenerationContext}.
@ -142,13 +154,14 @@ public class RepositoryContributor { @@ -142,13 +154,14 @@ public class RepositoryContributor {
-------------------
""".formatted(aotBundle.repositoryJsonFileName(), repositoryJson));
JavaFile javaFile = JavaFile.builder(creator.packageName(), targetTypeSpec.build()).build();
TypeSpec typeSpec = targetTypeSpec.build();
JavaFile javaFile = JavaFile.builder(creator.packageName(), typeSpec).build();
logger.trace("""
------ AOT Generated Repository: %s ------
%s
-------------------
""".formatted(null, javaFile));
""".formatted(typeSpec.name(), javaFile));
}
generationContext.getGeneratedFiles().addResourceFile(aotBundle.repositoryJsonFileName(), repositoryJson);

1
src/main/java/org/springframework/data/repository/aot/generate/VariableNameFactory.java

@ -24,6 +24,7 @@ import org.springframework.lang.CheckReturnValue; @@ -24,6 +24,7 @@ import org.springframework.lang.CheckReturnValue;
* @author Christoph Strobl
* @since 4.0
*/
@FunctionalInterface
interface VariableNameFactory {
/**

63
src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests.java

@ -18,9 +18,7 @@ package org.springframework.data.repository.aot.generate; @@ -18,9 +18,7 @@ package org.springframework.data.repository.aot.generate;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.jupiter.api.BeforeEach;
@ -31,6 +29,7 @@ import org.mockito.junit.jupiter.MockitoExtension; @@ -31,6 +29,7 @@ import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.aot.generate.GeneratedTypeReference;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.core.ResolvableType;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.data.util.Version;
@ -41,6 +40,7 @@ import org.springframework.javapoet.CodeBlock; @@ -41,6 +40,7 @@ import org.springframework.javapoet.CodeBlock;
* Unit testa for {@link AotRepositoryBeanDefinitionPropertiesDecorator}.
*
* @author Christoph Strobl
* @author Mark Paluch
*/
@ExtendWith(MockitoExtension.class)
class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests {
@ -48,7 +48,8 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @@ -48,7 +48,8 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests {
private static final String TYPE_NAME = "com.example.UserRepositoryImpl__AotRepository";
@Mock RepositoryContributor contributor;
CodeBlock.Builder inheritedSource;
AotRepositoryFragmentMetadata metadata = new AotRepositoryFragmentMetadata();
RepositoryConstructorBuilder constructorBuilder = new RepositoryConstructorBuilder(metadata);
AotRepositoryBeanDefinitionPropertiesDecorator decorator;
@BeforeEach
@ -57,6 +58,8 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @@ -57,6 +58,8 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests {
when(contributor.getContributedTypeName()).thenReturn(GeneratedTypeReference.of(ClassName.bestGuess(TYPE_NAME)));
inheritedSource = CodeBlock.builder();
decorator = new AotRepositoryBeanDefinitionPropertiesDecorator(() -> inheritedSource.build(), contributor);
when(contributor.getConstructorArguments()).thenReturn(metadata.getConstructorArguments());
}
@Test // GH-3344
@ -64,7 +67,6 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @@ -64,7 +67,6 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests {
inheritedSource.add("beanDefinition.getPropertyValues().addPropertyValue($S, $S)", "repositoryBaseClass",
"org.springframework.data.BaseRepository");
when(contributor.requiredArgs()).thenReturn(Map.of());
CodeBlock decorate = decorator.decorate();
@ -76,8 +78,6 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @@ -76,8 +78,6 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests {
@Test // GH-3344
void addsFragmentFunction() {
when(contributor.requiredArgs()).thenReturn(Map.of());
CodeBlock decorate = decorator.decorate();
assertThat(decorate.toString()) //
@ -90,8 +90,6 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @@ -90,8 +90,6 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests {
@Test // GH-3344
void addsPlainNoArgConstructorForEmptyArgs() {
when(contributor.requiredArgs()).thenReturn(Map.of());
CodeBlock decorate = decorator.decorate();
assertThat(decorate.toString()) //
@ -101,12 +99,9 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @@ -101,12 +99,9 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests {
@Test // GH-3344
void resolvesAndAddsArgumentsForCtor() {
Map<String, ResolvableType> ctorArgs = new LinkedHashMap<>(3);
ctorArgs.put("plain", ResolvableType.forClass(Version.class));
ctorArgs.put("noGenericsDefined", ResolvableType.forClass(List.class));
ctorArgs.put("withGenerics", ResolvableType.forClassWithGenerics(Set.class, String.class));
when(contributor.requiredArgs()).thenReturn(ctorArgs);
constructorBuilder.addParameter("plain", ResolvableType.forClass(Version.class));
constructorBuilder.addParameter("noGenericsDefined", ResolvableType.forClass(List.class));
constructorBuilder.addParameter("withGenerics", ResolvableType.forClassWithGenerics(Set.class, String.class));
CodeBlock decorate = decorator.decorate();
@ -118,10 +113,33 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @@ -118,10 +113,33 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests {
.formatted(TYPE_NAME));
}
@Test // GH-3344
void resolvesValueFromCodeblock() {
constructorBuilder.addParameter("byTypeAndName", ResolvableType.forClass(Version.class), customizer -> {
customizer.origin(new RuntimeBeanReference("foo", Version.class));
});
constructorBuilder.addParameter("byName", ResolvableType.forClass(Version.class), customizer -> {
customizer.origin(new RuntimeBeanReference("bar"));
});
constructorBuilder.addParameter("foo", Integer.class,
customizer -> customizer.origin(ctx -> AotRepositoryConstructorBuilder.ParameterOrigin.of(CodeBlock.of("1"))));
CodeBlock decorate = decorator.decorate();
assertThat(decorate.toString()) //
.containsSubsequence("Version byTypeAndName = beanFactory.getBean(\"foo\"", "Version.class)") //
.containsSubsequence("Version byName = ", "Version) beanFactory.getBean(\"bar\"") //
.containsSubsequence("return ", "RepositoryFragments.just(new %s(byTypeAndName, byName, 1));" //
.formatted(TYPE_NAME));
}
@Test // GH-3344
void passesOnBeanFactoryIfRequested() {
when(contributor.requiredArgs()).thenReturn(Map.of("beanFactory", ResolvableType.forClass(BeanFactory.class)));
constructorBuilder.addParameter("beanFactory", BeanFactory.class);
CodeBlock decorate = decorator.decorate();
@ -133,8 +151,7 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @@ -133,8 +151,7 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests {
@Test // GH-3344
void passesOnContextIfRequested() {
when(contributor.requiredArgs()).thenReturn(
Map.of("context", ResolvableType.forClass(RepositoryFactoryBeanSupport.FragmentCreationContext.class)));
constructorBuilder.addParameter("context", RepositoryFactoryBeanSupport.FragmentCreationContext.class);
CodeBlock decorate = decorator.decorate();
@ -146,27 +163,23 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests { @@ -146,27 +163,23 @@ class AotRepositoryBeanDefinitionPropertiesDecoratorUnitTests {
@Test // GH-3344
void passesOnContextWithDifferentNameIfRequested() {
when(contributor.requiredArgs()).thenReturn(
Map.of("theContext", ResolvableType.forClass(RepositoryFactoryBeanSupport.FragmentCreationContext.class)));
constructorBuilder.addParameter("theContext", RepositoryFactoryBeanSupport.FragmentCreationContext.class);
CodeBlock decorate = decorator.decorate();
assertThat(decorate.toString()) //
.doesNotContain("FragmentCreationContext theContext = beanFactory.getBean(") //
.containsSubsequence("FragmentCreationContext theContext = context") //
.containsSubsequence("return ", "RepositoryFragments.just(new %s(theContext));".formatted(TYPE_NAME));
.containsSubsequence("return ", "RepositoryFragments.just(new %s(context));".formatted(TYPE_NAME));
}
@Test // GH-3344
void passesOnBeanFactoryDifferentNameIfRequested() {
when(contributor.requiredArgs()).thenReturn(Map.of("myBeanFactory", ResolvableType.forClass(BeanFactory.class)));
constructorBuilder.addParameter("myBeanFactory", BeanFactory.class);
CodeBlock decorate = decorator.decorate();
assertThat(decorate.toString()) //
.doesNotContain("BeanFactory myBeanFactory = beanFactory.getBean(") //
.containsSubsequence("BeanFactory myBeanFactory = beanFactory") //
.containsSubsequence("return ", "RepositoryFragments.just(new %s(myBeanFactory));".formatted(TYPE_NAME));
.containsSubsequence("return ", "RepositoryFragments.just(new %s(beanFactory));".formatted(TYPE_NAME));
}
}

Loading…
Cancel
Save