diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java b/src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java index 4921a8428..0ce55ab5c 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContext.java @@ -21,7 +21,6 @@ import java.util.ArrayList; import java.util.List; import org.jspecify.annotations.Nullable; - import org.springframework.core.ResolvableType; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.core.annotation.MergedAnnotationSelectors; @@ -50,8 +49,8 @@ public class AotQueryMethodGenerationContext { private final MethodMetadata targetMethodMetadata; private final VariableNameFactory variableNameFactory; - protected AotQueryMethodGenerationContext(RepositoryInformation repositoryInformation, Method method, QueryMethod queryMethod, - AotRepositoryFragmentMetadata targetTypeMetadata) { + protected AotQueryMethodGenerationContext(RepositoryInformation repositoryInformation, Method method, + QueryMethod queryMethod, AotRepositoryFragmentMetadata targetTypeMetadata) { this.method = method; this.annotations = MergedAnnotations.from(method); @@ -270,8 +269,7 @@ public class AotQueryMethodGenerationContext { * @return the parameter name for the {@link org.springframework.data.domain.Sort sort parameter} or {@code null} if * the method does not declare a sort parameter. */ - @Nullable - public String getSortParameterName() { + public @Nullable String getSortParameterName() { return getParameterName(queryMethod.getParameters().getSortIndex()); } @@ -279,8 +277,7 @@ public class AotQueryMethodGenerationContext { * @return the parameter name for the {@link org.springframework.data.domain.Pageable pageable parameter} or * {@code null} if the method does not declare a pageable parameter. */ - @Nullable - public String getPageableParameterName() { + public @Nullable String getPageableParameterName() { return getParameterName(queryMethod.getParameters().getPageableIndex()); } @@ -288,8 +285,7 @@ public class AotQueryMethodGenerationContext { * @return the parameter name for the {@link org.springframework.data.domain.Limit limit parameter} or {@code null} if * the method does not declare a limit parameter. */ - @Nullable - public String getLimitParameterName() { + public @Nullable String getLimitParameterName() { return getParameterName(queryMethod.getParameters().getLimitIndex()); } @@ -297,8 +293,7 @@ public class AotQueryMethodGenerationContext { * @return the parameter name for the {@link org.springframework.data.domain.ScrollPosition scroll position parameter} * or {@code null} if the method does not declare a scroll position parameter. */ - @Nullable - public String getScrollPositionParameterName() { + public @Nullable String getScrollPositionParameterName() { return getParameterName(queryMethod.getParameters().getScrollPositionIndex()); } @@ -306,9 +301,32 @@ public class AotQueryMethodGenerationContext { * @return the parameter name for the {@link Class dynamic projection parameter} or {@code null} if the method does * not declare a dynamic projection parameter. */ - @Nullable - public String getDynamicProjectionParameterName() { + public @Nullable String getDynamicProjectionParameterName() { return getParameterName(queryMethod.getParameters().getDynamicProjectionIndex()); } + /** + * @return the parameter name for the {@link org.springframework.data.domain.Vector vector parameter} or {@code null} + * if the method does not declare a vector type parameter. + */ + public @Nullable String getVectorParameterName() { + return getParameterName(queryMethod.getParameters().getVectorIndex()); + } + + /** + * @return the parameter name for the {@link org.springframework.data.domain.Score score parameter} or {@code null} if + * the method does not declare a score type parameter. + */ + public @Nullable String getScoreParameterName() { + return getParameterName(queryMethod.getParameters().getScoreIndex()); + } + + /** + * @return the parameter name for the {@link org.springframework.data.domain.Range score range parameter} or + * {@code null} if the method does not declare a score range type parameter. + */ + public @Nullable String getScoreRangeParameterName() { + return getParameterName(queryMethod.getParameters().getScoreRangeIndex()); + } + } diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContextUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContextUnitTests.java index c529af7a8..9fa9edaa5 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContextUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotQueryMethodGenerationContextUnitTests.java @@ -25,7 +25,11 @@ import org.mockito.Mockito; import org.springframework.data.domain.Limit; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Vector; import org.springframework.data.domain.Window; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.repository.Repository; @@ -74,6 +78,19 @@ class AotQueryMethodGenerationContextUnitTests { assertThat(ctx.getDynamicProjectionParameterName()).isNull(); } + @Test // GH-3265 + void returnsCorrectParameterNamesForVectorSearch() throws NoSuchMethodException { + + AotQueryMethodGenerationContext ctx = ctxFor("searchAllNearWithScore"); + + assertThat(ctx.getVectorParameterName()).isEqualTo("vEctOR"); + assertThat(ctx.getScoreParameterName()).isEqualTo("distance"); + + ctx = ctxFor("searchAllNearInRange"); + + assertThat(ctx.getScoreRangeParameterName()).isEqualTo("rDistance"); + } + AotQueryMethodGenerationContext ctxFor(String methodName) throws NoSuchMethodException { Method target = null; @@ -105,5 +122,9 @@ class AotQueryMethodGenerationContextUnitTests { Window limitScrollPositionDynamicProjection(Limit l, ScrollPosition sp, Class projection); Page pageable(Pageable p); + + SearchResults searchAllNearWithScore(Vector vEctOR, Score distance, Limit limit); + + SearchResults searchAllNearInRange(Vector vEctOR, Range rDistance, Limit limit); } }