Browse Source

Support `Streamable` return type in AOT repositories.

Closes #5089
Original pull request: #5090
pull/5093/merge
Christoph Strobl 1 month ago committed by Mark Paluch
parent
commit
4d94ee126e
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 10
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java
  2. 10
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java
  3. 8
      spring-data-mongodb/src/test/java/example/aot/UserRepository.java
  4. 38
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java
  5. 8
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java
  6. 19
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java

10
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java

@ -41,6 +41,7 @@ import org.springframework.data.mongodb.repository.ReadPreference;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.data.util.ReflectionUtils; import org.springframework.data.util.ReflectionUtils;
import org.springframework.data.util.Streamable;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder; import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
@ -145,8 +146,15 @@ class AggregationBlocks {
builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName,
outputType); outputType);
} else { } else {
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
CodeBlock resultBlock = CodeBlock.of("$L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
aggregationVariableName, outputType); aggregationVariableName, outputType);
if (queryMethod.getReturnType().getType().equals(Streamable.class)) {
resultBlock = CodeBlock.of("$T.of($L)", Streamable.class, resultBlock);
}
builder.addStatement("return $L", resultBlock);
} }
} }
} }

10
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java

@ -34,6 +34,7 @@ import org.springframework.data.mongodb.repository.query.MongoQueryExecution.Sli
import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.data.util.Lazy; import org.springframework.data.util.Lazy;
import org.springframework.data.util.Streamable;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder; import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.javapoet.TypeName; import org.springframework.javapoet.TypeName;
@ -145,8 +146,15 @@ class QueryBlocks {
context.localVariable("finder"), query.name(), terminatingMethod, returnType); context.localVariable("finder"), query.name(), terminatingMethod, returnType);
} else { } else {
builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(),
CodeBlock resultBlock = CodeBlock.of("$L.matching($L).$L", context.localVariable("finder"), query.name(),
terminatingMethod); terminatingMethod);
if (queryMethod.getReturnType().getType().equals(Streamable.class)) {
resultBlock = CodeBlock.of("$T.of($L)", Streamable.class, resultBlock);
}
builder.addStatement("return $L", resultBlock);
} }
} }

8
spring-data-mongodb/src/test/java/example/aot/UserRepository.java

@ -48,6 +48,7 @@ import org.springframework.data.mongodb.repository.Update;
import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.repository.CrudRepository; import org.springframework.data.repository.CrudRepository;
import org.springframework.data.repository.query.Param; import org.springframework.data.repository.query.Param;
import org.springframework.data.util.Streamable;
/** /**
* @author Christoph Strobl * @author Christoph Strobl
@ -58,6 +59,8 @@ public interface UserRepository extends CrudRepository<User, String> {
List<User> findUserNoArgumentsBy(); List<User> findUserNoArgumentsBy();
Streamable<User> streamUserNoArgumentsBy();
User findOneByUsername(String username); User findOneByUsername(String username);
Optional<User> findOptionalOneByUsername(String username); Optional<User> findOptionalOneByUsername(String username);
@ -267,6 +270,11 @@ public interface UserRepository extends CrudRepository<User, String> {
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" }) "{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
Stream<UserAggregate> streamGroupByLastnameAndAsAggregationResults(String property); Stream<UserAggregate> streamGroupByLastnameAndAsAggregationResults(String property);
@Aggregation(pipeline = { //
"{ '$match' : { 'last_name' : { '$ne' : null } } }", //
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
Streamable<UserAggregate> streamAsStreamableGroupByLastnameAndAsAggregationResults(String property);
@Aggregation(pipeline = { // @Aggregation(pipeline = { //
"{ '$match' : { 'posts' : { '$ne' : null } } }", // "{ '$match' : { 'posts' : { '$ne' : null } } }", //
"{ '$project': { 'nrPosts' : {'$size': '$posts' } } }", // "{ '$project': { 'nrPosts' : {'$size': '$posts' } } }", //

38
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java

@ -79,6 +79,7 @@ import org.springframework.data.mongodb.test.util.DirtiesStateExtension.DirtiesS
import org.springframework.data.mongodb.test.util.DirtiesStateExtension.ProvidesState; import org.springframework.data.mongodb.test.util.DirtiesStateExtension.ProvidesState;
import org.springframework.data.mongodb.test.util.EnableIfMongoServerVersion; import org.springframework.data.mongodb.test.util.EnableIfMongoServerVersion;
import org.springframework.data.querydsl.QSort; import org.springframework.data.querydsl.QSort;
import org.springframework.data.util.Streamable;
import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.util.ReflectionTestUtils;
@ -312,6 +313,33 @@ public abstract class AbstractPersonRepositoryIntegrationTests implements Dirtie
assertThat(result).hasSize(1).contains(dave); assertThat(result).hasSize(1).contains(dave);
} }
@Test // GH-5089
void streamPersonByAddressCorrectly() {
Address address = new Address("Foo Street 1", "C0123", "Bar");
dave.setAddress(address);
repository.save(dave);
Streamable<Person> result = repository.streamByAddress(address);
assertThat(result).hasSize(1).contains(dave);
}
@Test // GH-5089
void streamPersonByAddressCorrectlyWhenPaged() {
Address address = new Address("Foo Street 1", "C0123", "Bar");
dave.setAddress(address);
oliver.setAddress(address);
repository.saveAll(List.of(dave, oliver));
Streamable<Person> result = repository.streamByAddress(address,
PageRequest.of(0, 1, Sort.by(Direction.DESC, "firstname")));
assertThat(result).containsExactly(oliver);
result = repository.streamByAddress(address, PageRequest.of(1, 1, Sort.by(Direction.DESC, "firstname")));
assertThat(result).containsExactly(dave);
}
@Test @Test
void findsPeopleByZipCode() { void findsPeopleByZipCode() {
@ -1516,6 +1544,16 @@ public abstract class AbstractPersonRepositoryIntegrationTests implements Dirtie
new PersonAggregate("Matthews", Arrays.asList("Dave", "Oliver August"))); new PersonAggregate("Matthews", Arrays.asList("Dave", "Oliver August")));
} }
@Test // GH-5089
void annotatedAggregationReturningStreamable() {
assertThat(repository.streamGroupByLastnameAnd("firstname", PageRequest.of(1, 2, Sort.by("lastname")))) //
.isInstanceOf(Streamable.class) //
.containsExactly( //
new PersonAggregate("Lessard", Collections.singletonList("Stefan")), //
new PersonAggregate("Matthews", Arrays.asList("Dave", "Oliver August")));
}
@Test // DATAMONGO-2153 @Test // DATAMONGO-2153
void annotatedAggregationWithSingleSimpleResult() { void annotatedAggregationWithSingleSimpleResult() {
assertThat(repository.sumAge()).isEqualTo(245); assertThat(repository.sumAge()).isEqualTo(245);

8
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java

@ -45,6 +45,7 @@ import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.mongodb.repository.Person.Sex; import org.springframework.data.mongodb.repository.Person.Sex;
import org.springframework.data.querydsl.QuerydslPredicateExecutor; import org.springframework.data.querydsl.QuerydslPredicateExecutor;
import org.springframework.data.repository.query.Param; import org.springframework.data.repository.query.Param;
import org.springframework.data.util.Streamable;
/** /**
* Sample repository managing {@link Person} entities. * Sample repository managing {@link Person} entities.
@ -211,6 +212,10 @@ public interface PersonRepository extends MongoRepository<Person, String>, Query
*/ */
List<Person> findByAddress(Address address); List<Person> findByAddress(Address address);
Streamable<Person> streamByAddress(Address address);
Streamable<Person> streamByAddress(Address address, Pageable pageable);
List<Person> findByAddressZipCode(String zipCode); List<Person> findByAddressZipCode(String zipCode);
List<Person> findByLastnameLikeAndAgeBetween(String lastname, int from, int to); List<Person> findByLastnameLikeAndAgeBetween(String lastname, int from, int to);
@ -442,6 +447,9 @@ public interface PersonRepository extends MongoRepository<Person, String>, Query
@Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }") @Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }")
List<PersonAggregate> groupByLastnameAnd(String property, Pageable page); List<PersonAggregate> groupByLastnameAnd(String property, Pageable page);
@Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }")
Streamable<PersonAggregate> streamGroupByLastnameAnd(String property, Pageable page);
@Aggregation(pipeline = "{ '$group' : { '_id' : null, 'total' : { $sum: '$age' } } }") @Aggregation(pipeline = "{ '$group' : { '_id' : null, 'total' : { $sum: '$age' } } }")
int sumAge(); int sumAge();

19
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java

@ -387,6 +387,25 @@ class QueryMethodContributionUnitTests {
"Document(\"$sort\", mappedSort.append(\"__score__\", -1))"); "Document(\"$sort\", mappedSort.append(\"__score__\", -1))");
} }
@Test // GH-5089
void rendersStreamableReturnType() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "streamUserNoArgumentsBy");
assertThat(methodSpec.toString()) //
.containsSubsequence("return", "Streamable.of(", "all())");
}
@Test // GH-5089
void rendersStreamableReturnTypeForAggregation() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "streamAsStreamableGroupByLastnameAndAsAggregationResults",
String.class);
assertThat(methodSpec.toString()) //
.containsSubsequence("return", "Streamable.of(", "getMappedResults())");
}
private static MethodSpec codeOf(Class<?> repository, String methodName, Class<?>... args) private static MethodSpec codeOf(Class<?> repository, String methodName, Class<?>... args)
throws NoSuchMethodException { throws NoSuchMethodException {

Loading…
Cancel
Save