diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java index 38ac76c0b..56ca6db4e 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java @@ -41,6 +41,7 @@ import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; import org.springframework.data.util.ReflectionUtils; +import org.springframework.data.util.Streamable; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock.Builder; import org.springframework.util.ClassUtils; @@ -145,8 +146,15 @@ class AggregationBlocks { builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); } else { - builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, + + CodeBlock resultBlock = CodeBlock.of("$L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, aggregationVariableName, outputType); + + if (queryMethod.getReturnType().getType().equals(Streamable.class)) { + resultBlock = CodeBlock.of("$T.of($L)", Streamable.class, resultBlock); + } + + builder.addStatement("return $L", resultBlock); } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java index a1b86ab01..2ed605c02 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java @@ -34,6 +34,7 @@ import org.springframework.data.mongodb.repository.query.MongoQueryExecution.Sli import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; import org.springframework.data.util.Lazy; +import org.springframework.data.util.Streamable; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock.Builder; import org.springframework.javapoet.TypeName; @@ -145,8 +146,15 @@ class QueryBlocks { context.localVariable("finder"), query.name(), terminatingMethod, returnType); } else { - builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(), + + CodeBlock resultBlock = CodeBlock.of("$L.matching($L).$L", context.localVariable("finder"), query.name(), terminatingMethod); + + if (queryMethod.getReturnType().getType().equals(Streamable.class)) { + resultBlock = CodeBlock.of("$T.of($L)", Streamable.class, resultBlock); + } + + builder.addStatement("return $L", resultBlock); } } diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index 2f8b8a4cb..b7ffe2219 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -48,6 +48,7 @@ import org.springframework.data.mongodb.repository.Update; import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.repository.CrudRepository; import org.springframework.data.repository.query.Param; +import org.springframework.data.util.Streamable; /** * @author Christoph Strobl @@ -58,6 +59,8 @@ public interface UserRepository extends CrudRepository { List findUserNoArgumentsBy(); + Streamable streamUserNoArgumentsBy(); + User findOneByUsername(String username); Optional findOptionalOneByUsername(String username); @@ -267,6 +270,11 @@ public interface UserRepository extends CrudRepository { "{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" }) Stream streamGroupByLastnameAndAsAggregationResults(String property); + @Aggregation(pipeline = { // + "{ '$match' : { 'last_name' : { '$ne' : null } } }", // + "{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" }) + Streamable streamAsStreamableGroupByLastnameAndAsAggregationResults(String property); + @Aggregation(pipeline = { // "{ '$match' : { 'posts' : { '$ne' : null } } }", // "{ '$project': { 'nrPosts' : {'$size': '$posts' } } }", // diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java index c361d5513..c7eca37c3 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java @@ -79,6 +79,7 @@ import org.springframework.data.mongodb.test.util.DirtiesStateExtension.DirtiesS import org.springframework.data.mongodb.test.util.DirtiesStateExtension.ProvidesState; import org.springframework.data.mongodb.test.util.EnableIfMongoServerVersion; import org.springframework.data.querydsl.QSort; +import org.springframework.data.util.Streamable; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.test.util.ReflectionTestUtils; @@ -312,6 +313,33 @@ public abstract class AbstractPersonRepositoryIntegrationTests implements Dirtie assertThat(result).hasSize(1).contains(dave); } + @Test // GH-5089 + void streamPersonByAddressCorrectly() { + + Address address = new Address("Foo Street 1", "C0123", "Bar"); + dave.setAddress(address); + repository.save(dave); + + Streamable result = repository.streamByAddress(address); + assertThat(result).hasSize(1).contains(dave); + } + + @Test // GH-5089 + void streamPersonByAddressCorrectlyWhenPaged() { + + Address address = new Address("Foo Street 1", "C0123", "Bar"); + dave.setAddress(address); + oliver.setAddress(address); + repository.saveAll(List.of(dave, oliver)); + + Streamable result = repository.streamByAddress(address, + PageRequest.of(0, 1, Sort.by(Direction.DESC, "firstname"))); + assertThat(result).containsExactly(oliver); + + result = repository.streamByAddress(address, PageRequest.of(1, 1, Sort.by(Direction.DESC, "firstname"))); + assertThat(result).containsExactly(dave); + } + @Test void findsPeopleByZipCode() { @@ -1516,6 +1544,16 @@ public abstract class AbstractPersonRepositoryIntegrationTests implements Dirtie new PersonAggregate("Matthews", Arrays.asList("Dave", "Oliver August"))); } + @Test // GH-5089 + void annotatedAggregationReturningStreamable() { + + assertThat(repository.streamGroupByLastnameAnd("firstname", PageRequest.of(1, 2, Sort.by("lastname")))) // + .isInstanceOf(Streamable.class) // + .containsExactly( // + new PersonAggregate("Lessard", Collections.singletonList("Stefan")), // + new PersonAggregate("Matthews", Arrays.asList("Dave", "Oliver August"))); + } + @Test // DATAMONGO-2153 void annotatedAggregationWithSingleSimpleResult() { assertThat(repository.sumAge()).isEqualTo(245); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java index a9f694e20..7718c1241 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java @@ -45,6 +45,7 @@ import org.springframework.data.mongodb.core.query.UpdateDefinition; import org.springframework.data.mongodb.repository.Person.Sex; import org.springframework.data.querydsl.QuerydslPredicateExecutor; import org.springframework.data.repository.query.Param; +import org.springframework.data.util.Streamable; /** * Sample repository managing {@link Person} entities. @@ -211,6 +212,10 @@ public interface PersonRepository extends MongoRepository, Query */ List findByAddress(Address address); + Streamable streamByAddress(Address address); + + Streamable streamByAddress(Address address, Pageable pageable); + List findByAddressZipCode(String zipCode); List findByLastnameLikeAndAgeBetween(String lastname, int from, int to); @@ -442,6 +447,9 @@ public interface PersonRepository extends MongoRepository, Query @Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }") List groupByLastnameAnd(String property, Pageable page); + @Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }") + Streamable streamGroupByLastnameAnd(String property, Pageable page); + @Aggregation(pipeline = "{ '$group' : { '_id' : null, 'total' : { $sum: '$age' } } }") int sumAge(); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index 9598e50c5..f5834a67e 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -387,6 +387,25 @@ class QueryMethodContributionUnitTests { "Document(\"$sort\", mappedSort.append(\"__score__\", -1))"); } + @Test // GH-5089 + void rendersStreamableReturnType() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "streamUserNoArgumentsBy"); + + assertThat(methodSpec.toString()) // + .containsSubsequence("return", "Streamable.of(", "all())"); + } + + @Test // GH-5089 + void rendersStreamableReturnTypeForAggregation() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "streamAsStreamableGroupByLastnameAndAsAggregationResults", + String.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("return", "Streamable.of(", "getMappedResults())"); + } + private static MethodSpec codeOf(Class repository, String methodName, Class... args) throws NoSuchMethodException {