diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java index df635fcd2..8e9439e7f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java @@ -103,6 +103,10 @@ public class MongoAotRepositoryFragmentSupport { return list; } + protected Object convertSimpleRawResult(Class targetType, Document rawResult) { + return extractSimpleTypeResult(rawResult, targetType, mongoConverter); + } + private static @Nullable T extractSimpleTypeResult(@Nullable Document source, Class targetType, MongoConverter converter) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 8ac187cb7..999391f5e 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.regex.Pattern; +import java.util.stream.Stream; import org.bson.Document; import org.jspecify.annotations.NullUnmarked; @@ -49,7 +50,6 @@ import org.springframework.data.mongodb.repository.query.MongoQueryExecution.Sli import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; import org.springframework.data.util.ReflectionUtils; -import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock.Builder; import org.springframework.javapoet.TypeName; @@ -182,17 +182,15 @@ class MongoCodeBlocks { String mongoOpsRef = context.fieldNameOf(MongoOperations.class); Builder builder = CodeBlock.builder(); + Class domainType = context.getRepositoryInformation().getDomainType(); boolean isProjecting = context.getActualReturnType() != null - && !ObjectUtils.nullSafeEquals(TypeName.get(context.getRepositoryInformation().getDomainType()), - context.getActualReturnType()); + && !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType()); - Object actualReturnType = isProjecting ? context.getActualReturnType().getType() - : context.getRepositoryInformation().getDomainType(); + Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType; builder.add("\n"); - builder.addStatement("$T<$T> remover = $L.remove($T.class)", ExecutableRemove.class, - context.getRepositoryInformation().getDomainType(), mongoOpsRef, - context.getRepositoryInformation().getDomainType()); + builder.addStatement("$T<$T> $L = $L.remove($T.class)", ExecutableRemove.class, domainType, + context.localVariable("remover"), mongoOpsRef, domainType); DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL; if (!queryMethod.isCollectionQuery()) { @@ -204,11 +202,20 @@ class MongoCodeBlocks { } actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType()) - ? ClassName.get(context.getMethod().getReturnType()) + ? TypeName.get(context.getMethod().getReturnType()) : queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType; - builder.addStatement("return ($T) new $T(remover, $T.$L).execute($L)", actualReturnType, DeleteExecution.class, - DeleteExecution.Type.class, type.name(), queryVariableName); + if (ClassUtils.isVoidType(context.getMethod().getReturnType())) { + builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, context.localVariable("remover"), + DeleteExecution.Type.class, type.name(), queryVariableName); + } else if (context.getMethod().getReturnType() == Optional.class) { + builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class, + actualReturnType, DeleteExecution.class, context.localVariable("remover"), DeleteExecution.Type.class, + type.name(), queryVariableName); + } else { + builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class, + context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName); + } return builder.build(); } @@ -318,14 +325,25 @@ class MongoCodeBlocks { Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); - builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); - if (!queryMethod.isCollectionQuery()) { - builder.addStatement("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))", - CollectionUtils.class, returnType, returnType, context.localVariable("results")); + if (queryMethod.isStreamQuery()) { + + builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class, + context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + + builder.addStatement("return $L.map(it -> ($T) convertSimpleRawResult($T.class, it))", + context.localVariable("results"), returnType, returnType); } else { - builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, - context.localVariable("results")); + + builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, + context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + + if (!queryMethod.isCollectionQuery()) { + builder.addStatement("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))", + CollectionUtils.class, returnType, returnType, context.localVariable("results")); + } else { + builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, + context.localVariable("results")); + } } } else { if (queryMethod.isSliceQuery()) { @@ -339,8 +357,15 @@ class MongoCodeBlocks { context.getPageableParameterName(), context.localVariable("results"), context.getPageableParameterName(), context.localVariable("hasNext")); } else { - builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, - aggregationVariableName, outputType); + + if (queryMethod.isStreamQuery()) { + builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, + outputType); + } else { + + builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, + aggregationVariableName, outputType); + } } } @@ -420,8 +445,16 @@ class MongoCodeBlocks { builder.addStatement("return $L.matching($L).scroll($L)", context.localVariable("finder"), query.name(), scrollPositionParameterName); } else { - builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(), - terminatingMethod); + if (query.isCount() && !ClassUtils.isAssignable(Long.class, context.getActualReturnType().getRawClass())) { + + Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); + builder.addStatement("return $T.convertNumberToTargetClass($L.matching($L).$L, $T.class)", NumberUtils.class, + context.localVariable("finder"), query.name(), terminatingMethod, returnType); + + } else { + builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(), + terminatingMethod); + } } return builder.build(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index 354f4a629..424d067d7 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -18,6 +18,7 @@ package org.springframework.data.mongodb.repository.aot; import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*; import java.lang.reflect.Method; +import java.util.Locale; import java.util.regex.Pattern; import org.apache.commons.logging.Log; @@ -119,14 +120,23 @@ public class MongoRepositoryContributor extends RepositoryContributor { if (queryMethod.isModifyingQuery()) { - Update updateSource = queryMethod.getUpdateSource(); - if (StringUtils.hasText(updateSource.value())) { - UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value())); + int updateIndex = queryMethod.getParameters().getUpdateIndex(); + if (updateIndex != -1) { + + UpdateInteraction update = new UpdateInteraction(query, null, updateIndex); return updateMethodContributor(queryMethod, update); - } - if (!ObjectUtils.isEmpty(updateSource.pipeline())) { - AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline()); - return aggregationUpdateMethodContributor(queryMethod, update); + + } else { + Update updateSource = queryMethod.getUpdateSource(); + if (StringUtils.hasText(updateSource.value())) { + UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()), null); + return updateMethodContributor(queryMethod, update); + } + + if (!ObjectUtils.isEmpty(updateSource.pipeline())) { + AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline()); + return aggregationUpdateMethodContributor(queryMethod, update); + } } } @@ -160,10 +170,12 @@ public class MongoRepositoryContributor extends RepositoryContributor { private static boolean backoff(MongoQueryMethod method) { - boolean skip = method.isGeoNearQuery() || method.isSearchQuery(); + // TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays. + boolean skip = method.isGeoNearQuery() || method.isSearchQuery() + || method.getName().toLowerCase(Locale.ROOT).contains("regex") || method.getReturnType().getType().isArray(); if (skip && logger.isDebugEnabled()) { - logger.debug("Skipping AOT generation for [%s]. Method is either geo-near, streaming, search or scrolling query" + logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query" .formatted(method.getName())); } return skip; @@ -197,9 +209,15 @@ public class MongoRepositoryContributor extends RepositoryContributor { .usingQueryVariableName(filterVariableName).build()); // update definition - String updateVariableName = context.localVariable("updateDefinition"); - builder.add( - updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName).build()); + String updateVariableName; + + if (update.hasUpdateDefinitionParameter()) { + updateVariableName = context.getParameterName(update.getRequiredUpdateDefinitionParameter()); + } else { + updateVariableName = context.localVariable("updateDefinition"); + builder.add(updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName) + .build()); + } builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName) .referencingUpdate(updateVariableName).build()); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateInteraction.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateInteraction.java index bbc76bec5..525a4782a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateInteraction.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateInteraction.java @@ -17,37 +17,60 @@ package org.springframework.data.mongodb.repository.aot; import java.util.Map; +import org.jspecify.annotations.Nullable; + import org.springframework.data.repository.aot.generate.QueryMetadata; +import org.springframework.util.Assert; /** * An {@link MongoInteraction} to execute an update. * * @author Christoph Strobl + * @author Mark Paluch * @since 5.0 */ class UpdateInteraction extends MongoInteraction implements QueryMetadata { private final QueryInteraction filter; - private final StringUpdate update; + private final @Nullable StringUpdate update; + private final @Nullable Integer updateDefinitionParameter; - UpdateInteraction(QueryInteraction filter, StringUpdate update) { + UpdateInteraction(QueryInteraction filter, @Nullable StringUpdate update, + @Nullable Integer updateDefinitionParameter) { this.filter = filter; this.update = update; + this.updateDefinitionParameter = updateDefinitionParameter; } - QueryInteraction getFilter() { + public QueryInteraction getFilter() { return filter; } - StringUpdate getUpdate() { + public @Nullable StringUpdate getUpdate() { return update; } + public int getRequiredUpdateDefinitionParameter() { + + Assert.notNull(updateDefinitionParameter, "UpdateDefinitionParameter must not be null!"); + + return updateDefinitionParameter; + } + + public boolean hasUpdateDefinitionParameter() { + return updateDefinitionParameter != null; + } + @Override public Map serialize() { Map serialized = filter.serialize(); - serialized.put("update", update.getUpdateString()); + + if (update != null) { + serialized.put("filter", filter.getQuery().getQueryString()); + serialized.put("update", update.getUpdateString()); + } + return serialized; } @@ -55,4 +78,5 @@ class UpdateInteraction extends MongoInteraction implements QueryMetadata { InteractionType getExecutionType() { return InteractionType.UPDATE; } + } diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index 0b62f5979..5eb9fed68 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -55,6 +55,8 @@ public interface UserRepository extends CrudRepository { Long countUsersByLastname(String lastname); + int countUsersAsIntByLastname(String lastname); + Boolean existsUserByLastname(String lastname); List findByLastnameStartingWith(String lastname); @@ -216,6 +218,11 @@ public interface UserRepository extends CrudRepository { "{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" }) AggregationResults groupByLastnameAndAsAggregationResults(String property); + @Aggregation(pipeline = { // + "{ '$match' : { 'last_name' : { '$ne' : null } } }", // + "{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" }) + Stream streamGroupByLastnameAndAsAggregationResults(String property); + @Aggregation(pipeline = { // "{ '$match' : { 'posts' : { '$ne' : null } } }", // "{ '$project': { 'nrPosts' : {'$size': '$posts' } } }", // diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 5a86c9658..1c9796ead 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -107,8 +107,8 @@ class MongoRepositoryContributorTests { @Test void testDerivedCount() { - Long value = fragment.countUsersByLastname("Skywalker"); - assertThat(value).isEqualTo(2L); + assertThat(fragment.countUsersByLastname("Skywalker")).isEqualTo(2L); + assertThat(fragment.countUsersAsIntByLastname("Skywalker")).isEqualTo(2); } @Test @@ -559,6 +559,16 @@ class MongoRepositoryContributorTests { new UserAggregate("Solo", List.of("Han", "Ben"))); } + @Test + void testAggregationStreamWithProjectedResultsWrappedInAggregationResults() { + + List allLastnames = fragment.streamGroupByLastnameAndAsAggregationResults("first_name").toList(); + assertThat(allLastnames).containsExactlyInAnyOrder(// + new UserAggregate("Skywalker", List.of("Anakin", "Luke")), // + new UserAggregate("Organa", List.of("Leia")), // + new UserAggregate("Solo", List.of("Han", "Ben"))); + } + @Test void testAggregationWithSingleResultExtraction() { assertThat(fragment.sumPosts()).isEqualTo(5); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java index 39ce43e99..5f470dd55 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java @@ -37,12 +37,12 @@ import org.springframework.data.repository.core.support.RepositoryComposition; /** * @author Christoph Strobl */ -class TestMongoAotRepositoryContext implements AotRepositoryContext { +public class TestMongoAotRepositoryContext implements AotRepositoryContext { private final StubRepositoryInformation repositoryInformation; private final Environment environment = new StandardEnvironment(); - TestMongoAotRepositoryContext(Class repositoryInterface, @Nullable RepositoryComposition composition) { + public TestMongoAotRepositoryContext(Class repositoryInterface, @Nullable RepositoryComposition composition) { this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition); }