Browse Source

Polishing.

Improve encapsulation.

See #3338
pull/3345/head
Mark Paluch 5 months ago
parent
commit
dcb98dd8e9
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 4
      src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java
  2. 34
      src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilder.java
  3. 46
      src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java
  4. 14
      src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java

4
src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java

@ -265,7 +265,7 @@ public class AotQueryMethodGenerationContext {
* @return the variable name used in the generated code. * @return the variable name used in the generated code.
*/ */
public String localVariable(String variableName) { public String localVariable(String variableName) {
return targetMethodMetadata.getLocalVariables().computeIfAbsent(variableName, variableNameFactory::generateName); return targetMethodMetadata.getOrCreateLocalVariable(variableName, variableNameFactory::generateName);
} }
/** /**
@ -346,7 +346,7 @@ public class AotQueryMethodGenerationContext {
/** /**
* Obtain the {@link ExpressionMarker} for the current method. Will add a local class within the method that can be * Obtain the {@link ExpressionMarker} for the current method. Will add a local class within the method that can be
* referenced via {@link ExpressionMarker#enclosingMethod()}. * referenced via {@link ExpressionMarker#enclosingMethod()}.
* *
* @return the {@link ExpressionMarker} for this particular method. * @return the {@link ExpressionMarker} for this particular method.
*/ */
public ExpressionMarker getExpressionMarker() { public ExpressionMarker getExpressionMarker() {

34
src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilder.java

@ -17,6 +17,7 @@ package org.springframework.data.repository.aot.generate;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.lang.reflect.TypeVariable; import java.lang.reflect.TypeVariable;
import java.util.Map;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -25,9 +26,9 @@ import javax.lang.model.element.Modifier;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.ParameterSpec;
import org.springframework.javapoet.TypeName; import org.springframework.javapoet.TypeName;
import org.springframework.javapoet.TypeVariableName; import org.springframework.javapoet.TypeVariableName;
import org.springframework.util.StringUtils;
/** /**
* Builder for AOT repository query methods. * Builder for AOT repository query methods.
@ -82,27 +83,38 @@ class AotRepositoryMethodBuilder {
public MethodSpec buildMethod() { public MethodSpec buildMethod() {
CodeBlock methodBody = contribution.apply(context); CodeBlock methodBody = contribution.apply(context);
MethodSpec.Builder builder = initializeMethodBuilder();
if (context.getExpressionMarker().isInUse()) {
builder.addCode(context.getExpressionMarker().declaration());
}
builder.addCode(methodBody);
customizer.accept(context, builder);
return builder.build();
}
private MethodSpec.Builder initializeMethodBuilder() {
MethodSpec.Builder builder = MethodSpec.methodBuilder(context.getMethod().getName()).addModifiers(Modifier.PUBLIC); MethodSpec.Builder builder = MethodSpec.methodBuilder(context.getMethod().getName()).addModifiers(Modifier.PUBLIC);
builder.returns(TypeName.get(context.getReturnType().getType())); builder.returns(TypeName.get(context.getReturnType().getType()));
TypeVariable<Method>[] tvs = context.getMethod().getTypeParameters(); TypeVariable<Method>[] tvs = context.getMethod().getTypeParameters();
for (TypeVariable<Method> tv : tvs) { for (TypeVariable<Method> tv : tvs) {
builder.addTypeVariable(TypeVariableName.get(tv)); builder.addTypeVariable(TypeVariableName.get(tv));
} }
MethodMetadata methodMetadata = context.getTargetMethodMetadata();
Map<String, ParameterSpec> methodArguments = methodMetadata.getMethodArguments();
builder.addJavadoc("AOT generated implementation of {@link $T#$L($L)}.", context.getMethod().getDeclaringClass(), builder.addJavadoc("AOT generated implementation of {@link $T#$L($L)}.", context.getMethod().getDeclaringClass(),
context.getMethod().getName(), StringUtils.collectionToCommaDelimitedString(context.getTargetMethodMetadata() context.getMethod().getName(),
.getMethodArguments().values().stream().map(it -> it.type().toString()).collect(Collectors.toList()))); methodArguments.values().stream().map(it -> it.type().toString()).collect(Collectors.joining(", ")));
context.getTargetMethodMetadata().getMethodArguments().forEach((name, spec) -> builder.addParameter(spec));
if(context.getExpressionMarker().isInUse()) {
builder.addCode(context.getExpressionMarker().declaration());
}
builder.addCode(methodBody);
customizer.accept(context, builder);
return builder.build(); methodArguments.forEach((name, spec) -> builder.addParameter(spec));
return builder;
} }
} }

46
src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java

@ -16,13 +16,17 @@
package org.springframework.data.repository.aot.generate; package org.springframework.data.repository.aot.generate;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.function.Function;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterNameDiscoverer; import org.springframework.core.ParameterNameDiscoverer;
@ -40,8 +44,8 @@ import org.springframework.javapoet.TypeName;
*/ */
class MethodMetadata { class MethodMetadata {
private final Map<String, ParameterSpec> methodArguments = new LinkedHashMap<>(); private final Map<String, ParameterSpec> methodArguments;
private final Map<String, MethodParameter> methodParameters = new LinkedHashMap<>(); private final Map<String, MethodParameter> methodParameters;
private final Map<String, String> localVariables = new LinkedHashMap<>(); private final Map<String, String> localVariables = new LinkedHashMap<>();
private final ResolvableType actualReturnType; private final ResolvableType actualReturnType;
private final ResolvableType returnType; private final ResolvableType returnType;
@ -50,15 +54,24 @@ class MethodMetadata {
this.returnType = repositoryInformation.getReturnType(method).toResolvableType(); this.returnType = repositoryInformation.getReturnType(method).toResolvableType();
this.actualReturnType = repositoryInformation.getReturnedDomainTypeInformation(method).toResolvableType(); this.actualReturnType = repositoryInformation.getReturnedDomainTypeInformation(method).toResolvableType();
this.initParameters(repositoryInformation, method, new DefaultParameterNameDiscoverer());
}
private void initParameters(RepositoryInformation repositoryInformation, Method method, Map<String, ParameterSpec> methodArguments = new LinkedHashMap<>();
ParameterNameDiscoverer nameDiscoverer) { Map<String, MethodParameter> methodParameters = new LinkedHashMap<>();
ResolvableType repositoryInterface = ResolvableType.forClass(repositoryInformation.getRepositoryInterface()); ResolvableType repositoryInterface = ResolvableType.forClass(repositoryInformation.getRepositoryInterface());
ParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
initializeMethodArguments(method, nameDiscoverer, repositoryInterface, methodArguments, methodParameters);
this.methodArguments = Collections.unmodifiableMap(methodArguments);
this.methodParameters = Collections.unmodifiableMap(methodParameters);
}
for (java.lang.reflect.Parameter parameter : method.getParameters()) { private static void initializeMethodArguments(Method method, ParameterNameDiscoverer nameDiscoverer,
ResolvableType repositoryInterface, Map<String, ParameterSpec> methodArguments,
Map<String, MethodParameter> methodParameters) {
for (Parameter parameter : method.getParameters()) {
MethodParameter methodParameter = MethodParameter.forParameter(parameter); MethodParameter methodParameter = MethodParameter.forParameter(parameter);
methodParameter.initParameterNameDiscovery(nameDiscoverer); methodParameter.initParameterNameDiscovery(nameDiscoverer);
@ -66,8 +79,14 @@ class MethodMetadata {
TypeName parameterType = TypeName.get(resolvableParameterType.getType()); TypeName parameterType = TypeName.get(resolvableParameterType.getType());
addParameter(ParameterSpec.builder(parameterType, methodParameter.getParameterName()).build()); ParameterSpec parameterSpec = ParameterSpec.builder(parameterType, methodParameter.getParameterName()).build();
methodParameters.put(methodParameter.getParameterName(), methodParameter);
if (methodArguments.containsKey(parameterSpec.name())) {
throw new IllegalStateException("Parameter with name '" + parameterSpec.name() + "' already exists.");
}
methodArguments.put(parameterSpec.name(), parameterSpec);
methodParameters.put(parameterSpec.name(), methodParameter);
} }
} }
@ -79,10 +98,6 @@ class MethodMetadata {
return actualReturnType; return actualReturnType;
} }
void addParameter(ParameterSpec parameterSpec) {
this.methodArguments.put(parameterSpec.name(), parameterSpec);
}
Map<String, ParameterSpec> getMethodArguments() { Map<String, ParameterSpec> getMethodArguments() {
return methodArguments; return methodArguments;
} }
@ -105,8 +120,9 @@ class MethodMetadata {
return null; return null;
} }
Map<String, String> getLocalVariables() { public String getOrCreateLocalVariable(String variableName,
return localVariables; Function<? super String, ? extends String> mappingFunction) {
return localVariables.computeIfAbsent(variableName, mappingFunction);
} }
} }

14
src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java

@ -15,10 +15,9 @@
*/ */
package org.springframework.data.repository.aot.generate; package org.springframework.data.repository.aot.generate;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.*;
import static org.mockito.Mockito.when;
import example.UserRepository; import example.UserRepository;
import example.UserRepository.User; import example.UserRepository.User;
@ -33,11 +32,10 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.util.TypeInformation; import org.springframework.data.util.TypeInformation;
import org.springframework.javapoet.ParameterSpec;
import org.springframework.javapoet.ParameterizedTypeName;
/** /**
* @author Christoph Strobl * @author Christoph Strobl
@ -66,7 +64,6 @@ class AotRepositoryMethodBuilderUnitTests {
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any()); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method); MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata.addParameter(ParameterSpec.builder(String.class, "firstname").build());
when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata); when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata);
AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext); AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext);
@ -84,8 +81,6 @@ class AotRepositoryMethodBuilderUnitTests {
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any()); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method); MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata
.addParameter(ParameterSpec.builder(ParameterizedTypeName.get(List.class, String.class), "firstnames").build());
when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata); when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata);
AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext); AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext);
@ -104,7 +99,6 @@ class AotRepositoryMethodBuilderUnitTests {
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any()); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method); MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata.addParameter(ParameterSpec.builder(String.class, "firstname").build());
when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata); when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata);
when(methodGenerationContext.getExpressionMarker()).thenReturn(expressionMarker); when(methodGenerationContext.getExpressionMarker()).thenReturn(expressionMarker);

Loading…
Cancel
Save