Browse Source

Polishing.

Improve encapsulation.

See #3338
pull/3345/head
Mark Paluch 4 months ago
parent
commit
dcb98dd8e9
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 4
      src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java
  2. 34
      src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilder.java
  3. 46
      src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java
  4. 14
      src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java

4
src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java

@ -265,7 +265,7 @@ public class AotQueryMethodGenerationContext { @@ -265,7 +265,7 @@ public class AotQueryMethodGenerationContext {
* @return the variable name used in the generated code.
*/
public String localVariable(String variableName) {
return targetMethodMetadata.getLocalVariables().computeIfAbsent(variableName, variableNameFactory::generateName);
return targetMethodMetadata.getOrCreateLocalVariable(variableName, variableNameFactory::generateName);
}
/**
@ -346,7 +346,7 @@ public class AotQueryMethodGenerationContext { @@ -346,7 +346,7 @@ public class AotQueryMethodGenerationContext {
/**
* Obtain the {@link ExpressionMarker} for the current method. Will add a local class within the method that can be
* referenced via {@link ExpressionMarker#enclosingMethod()}.
*
*
* @return the {@link ExpressionMarker} for this particular method.
*/
public ExpressionMarker getExpressionMarker() {

34
src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilder.java

@ -17,6 +17,7 @@ package org.springframework.data.repository.aot.generate; @@ -17,6 +17,7 @@ package org.springframework.data.repository.aot.generate;
import java.lang.reflect.Method;
import java.lang.reflect.TypeVariable;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;
@ -25,9 +26,9 @@ import javax.lang.model.element.Modifier; @@ -25,9 +26,9 @@ import javax.lang.model.element.Modifier;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.ParameterSpec;
import org.springframework.javapoet.TypeName;
import org.springframework.javapoet.TypeVariableName;
import org.springframework.util.StringUtils;
/**
* Builder for AOT repository query methods.
@ -82,27 +83,38 @@ class AotRepositoryMethodBuilder { @@ -82,27 +83,38 @@ class AotRepositoryMethodBuilder {
public MethodSpec buildMethod() {
CodeBlock methodBody = contribution.apply(context);
MethodSpec.Builder builder = initializeMethodBuilder();
if (context.getExpressionMarker().isInUse()) {
builder.addCode(context.getExpressionMarker().declaration());
}
builder.addCode(methodBody);
customizer.accept(context, builder);
return builder.build();
}
private MethodSpec.Builder initializeMethodBuilder() {
MethodSpec.Builder builder = MethodSpec.methodBuilder(context.getMethod().getName()).addModifiers(Modifier.PUBLIC);
builder.returns(TypeName.get(context.getReturnType().getType()));
TypeVariable<Method>[] tvs = context.getMethod().getTypeParameters();
for (TypeVariable<Method> tv : tvs) {
builder.addTypeVariable(TypeVariableName.get(tv));
}
MethodMetadata methodMetadata = context.getTargetMethodMetadata();
Map<String, ParameterSpec> methodArguments = methodMetadata.getMethodArguments();
builder.addJavadoc("AOT generated implementation of {@link $T#$L($L)}.", context.getMethod().getDeclaringClass(),
context.getMethod().getName(), StringUtils.collectionToCommaDelimitedString(context.getTargetMethodMetadata()
.getMethodArguments().values().stream().map(it -> it.type().toString()).collect(Collectors.toList())));
context.getTargetMethodMetadata().getMethodArguments().forEach((name, spec) -> builder.addParameter(spec));
if(context.getExpressionMarker().isInUse()) {
builder.addCode(context.getExpressionMarker().declaration());
}
builder.addCode(methodBody);
customizer.accept(context, builder);
context.getMethod().getName(),
methodArguments.values().stream().map(it -> it.type().toString()).collect(Collectors.joining(", ")));
return builder.build();
methodArguments.forEach((name, spec) -> builder.addParameter(spec));
return builder;
}
}

46
src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java

@ -16,13 +16,17 @@ @@ -16,13 +16,17 @@
package org.springframework.data.repository.aot.generate;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.function.Function;
import org.jspecify.annotations.Nullable;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterNameDiscoverer;
@ -40,8 +44,8 @@ import org.springframework.javapoet.TypeName; @@ -40,8 +44,8 @@ import org.springframework.javapoet.TypeName;
*/
class MethodMetadata {
private final Map<String, ParameterSpec> methodArguments = new LinkedHashMap<>();
private final Map<String, MethodParameter> methodParameters = new LinkedHashMap<>();
private final Map<String, ParameterSpec> methodArguments;
private final Map<String, MethodParameter> methodParameters;
private final Map<String, String> localVariables = new LinkedHashMap<>();
private final ResolvableType actualReturnType;
private final ResolvableType returnType;
@ -50,15 +54,24 @@ class MethodMetadata { @@ -50,15 +54,24 @@ class MethodMetadata {
this.returnType = repositoryInformation.getReturnType(method).toResolvableType();
this.actualReturnType = repositoryInformation.getReturnedDomainTypeInformation(method).toResolvableType();
this.initParameters(repositoryInformation, method, new DefaultParameterNameDiscoverer());
}
private void initParameters(RepositoryInformation repositoryInformation, Method method,
ParameterNameDiscoverer nameDiscoverer) {
Map<String, ParameterSpec> methodArguments = new LinkedHashMap<>();
Map<String, MethodParameter> methodParameters = new LinkedHashMap<>();
ResolvableType repositoryInterface = ResolvableType.forClass(repositoryInformation.getRepositoryInterface());
ParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
initializeMethodArguments(method, nameDiscoverer, repositoryInterface, methodArguments, methodParameters);
this.methodArguments = Collections.unmodifiableMap(methodArguments);
this.methodParameters = Collections.unmodifiableMap(methodParameters);
}
for (java.lang.reflect.Parameter parameter : method.getParameters()) {
private static void initializeMethodArguments(Method method, ParameterNameDiscoverer nameDiscoverer,
ResolvableType repositoryInterface, Map<String, ParameterSpec> methodArguments,
Map<String, MethodParameter> methodParameters) {
for (Parameter parameter : method.getParameters()) {
MethodParameter methodParameter = MethodParameter.forParameter(parameter);
methodParameter.initParameterNameDiscovery(nameDiscoverer);
@ -66,8 +79,14 @@ class MethodMetadata { @@ -66,8 +79,14 @@ class MethodMetadata {
TypeName parameterType = TypeName.get(resolvableParameterType.getType());
addParameter(ParameterSpec.builder(parameterType, methodParameter.getParameterName()).build());
methodParameters.put(methodParameter.getParameterName(), methodParameter);
ParameterSpec parameterSpec = ParameterSpec.builder(parameterType, methodParameter.getParameterName()).build();
if (methodArguments.containsKey(parameterSpec.name())) {
throw new IllegalStateException("Parameter with name '" + parameterSpec.name() + "' already exists.");
}
methodArguments.put(parameterSpec.name(), parameterSpec);
methodParameters.put(parameterSpec.name(), methodParameter);
}
}
@ -79,10 +98,6 @@ class MethodMetadata { @@ -79,10 +98,6 @@ class MethodMetadata {
return actualReturnType;
}
void addParameter(ParameterSpec parameterSpec) {
this.methodArguments.put(parameterSpec.name(), parameterSpec);
}
Map<String, ParameterSpec> getMethodArguments() {
return methodArguments;
}
@ -105,8 +120,9 @@ class MethodMetadata { @@ -105,8 +120,9 @@ class MethodMetadata {
return null;
}
Map<String, String> getLocalVariables() {
return localVariables;
public String getOrCreateLocalVariable(String variableName,
Function<? super String, ? extends String> mappingFunction) {
return localVariables.computeIfAbsent(variableName, mappingFunction);
}
}

14
src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java

@ -15,10 +15,9 @@ @@ -15,10 +15,9 @@
*/
package org.springframework.data.repository.aot.generate;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.when;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
import example.UserRepository;
import example.UserRepository.User;
@ -33,11 +32,10 @@ import org.junit.jupiter.params.ParameterizedTest; @@ -33,11 +32,10 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;
import org.springframework.core.ResolvableType;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.util.TypeInformation;
import org.springframework.javapoet.ParameterSpec;
import org.springframework.javapoet.ParameterizedTypeName;
/**
* @author Christoph Strobl
@ -66,7 +64,6 @@ class AotRepositoryMethodBuilderUnitTests { @@ -66,7 +64,6 @@ class AotRepositoryMethodBuilderUnitTests {
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata.addParameter(ParameterSpec.builder(String.class, "firstname").build());
when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata);
AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext);
@ -84,8 +81,6 @@ class AotRepositoryMethodBuilderUnitTests { @@ -84,8 +81,6 @@ class AotRepositoryMethodBuilderUnitTests {
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata
.addParameter(ParameterSpec.builder(ParameterizedTypeName.get(List.class, String.class), "firstnames").build());
when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata);
AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext);
@ -104,7 +99,6 @@ class AotRepositoryMethodBuilderUnitTests { @@ -104,7 +99,6 @@ class AotRepositoryMethodBuilderUnitTests {
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata.addParameter(ParameterSpec.builder(String.class, "firstname").build());
when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata);
when(methodGenerationContext.getExpressionMarker()).thenReturn(expressionMarker);

Loading…
Cancel
Save