Browse Source

Support `Streamable` return type in AOT repositories.

Closes #5089
Original pull request: #5090
pull/5093/merge
Christoph Strobl 1 month ago committed by Mark Paluch
parent
commit
4d94ee126e
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 10
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java
  2. 10
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java
  3. 8
      spring-data-mongodb/src/test/java/example/aot/UserRepository.java
  4. 38
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java
  5. 8
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java
  6. 19
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java

10
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java

@ -41,6 +41,7 @@ import org.springframework.data.mongodb.repository.ReadPreference; @@ -41,6 +41,7 @@ import org.springframework.data.mongodb.repository.ReadPreference;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.data.util.ReflectionUtils;
import org.springframework.data.util.Streamable;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.util.ClassUtils;
@ -145,8 +146,15 @@ class AggregationBlocks { @@ -145,8 +146,15 @@ class AggregationBlocks {
builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName,
outputType);
} else {
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
CodeBlock resultBlock = CodeBlock.of("$L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
aggregationVariableName, outputType);
if (queryMethod.getReturnType().getType().equals(Streamable.class)) {
resultBlock = CodeBlock.of("$T.of($L)", Streamable.class, resultBlock);
}
builder.addStatement("return $L", resultBlock);
}
}
}

10
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java

@ -34,6 +34,7 @@ import org.springframework.data.mongodb.repository.query.MongoQueryExecution.Sli @@ -34,6 +34,7 @@ import org.springframework.data.mongodb.repository.query.MongoQueryExecution.Sli
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.data.util.Lazy;
import org.springframework.data.util.Streamable;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.javapoet.TypeName;
@ -145,8 +146,15 @@ class QueryBlocks { @@ -145,8 +146,15 @@ class QueryBlocks {
context.localVariable("finder"), query.name(), terminatingMethod, returnType);
} else {
builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(),
CodeBlock resultBlock = CodeBlock.of("$L.matching($L).$L", context.localVariable("finder"), query.name(),
terminatingMethod);
if (queryMethod.getReturnType().getType().equals(Streamable.class)) {
resultBlock = CodeBlock.of("$T.of($L)", Streamable.class, resultBlock);
}
builder.addStatement("return $L", resultBlock);
}
}

8
spring-data-mongodb/src/test/java/example/aot/UserRepository.java

@ -48,6 +48,7 @@ import org.springframework.data.mongodb.repository.Update; @@ -48,6 +48,7 @@ import org.springframework.data.mongodb.repository.Update;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.repository.CrudRepository;
import org.springframework.data.repository.query.Param;
import org.springframework.data.util.Streamable;
/**
* @author Christoph Strobl
@ -58,6 +59,8 @@ public interface UserRepository extends CrudRepository<User, String> { @@ -58,6 +59,8 @@ public interface UserRepository extends CrudRepository<User, String> {
List<User> findUserNoArgumentsBy();
Streamable<User> streamUserNoArgumentsBy();
User findOneByUsername(String username);
Optional<User> findOptionalOneByUsername(String username);
@ -267,6 +270,11 @@ public interface UserRepository extends CrudRepository<User, String> { @@ -267,6 +270,11 @@ public interface UserRepository extends CrudRepository<User, String> {
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
Stream<UserAggregate> streamGroupByLastnameAndAsAggregationResults(String property);
@Aggregation(pipeline = { //
"{ '$match' : { 'last_name' : { '$ne' : null } } }", //
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
Streamable<UserAggregate> streamAsStreamableGroupByLastnameAndAsAggregationResults(String property);
@Aggregation(pipeline = { //
"{ '$match' : { 'posts' : { '$ne' : null } } }", //
"{ '$project': { 'nrPosts' : {'$size': '$posts' } } }", //

38
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java

@ -79,6 +79,7 @@ import org.springframework.data.mongodb.test.util.DirtiesStateExtension.DirtiesS @@ -79,6 +79,7 @@ import org.springframework.data.mongodb.test.util.DirtiesStateExtension.DirtiesS
import org.springframework.data.mongodb.test.util.DirtiesStateExtension.ProvidesState;
import org.springframework.data.mongodb.test.util.EnableIfMongoServerVersion;
import org.springframework.data.querydsl.QSort;
import org.springframework.data.util.Streamable;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.springframework.test.util.ReflectionTestUtils;
@ -312,6 +313,33 @@ public abstract class AbstractPersonRepositoryIntegrationTests implements Dirtie @@ -312,6 +313,33 @@ public abstract class AbstractPersonRepositoryIntegrationTests implements Dirtie
assertThat(result).hasSize(1).contains(dave);
}
@Test // GH-5089
void streamPersonByAddressCorrectly() {
Address address = new Address("Foo Street 1", "C0123", "Bar");
dave.setAddress(address);
repository.save(dave);
Streamable<Person> result = repository.streamByAddress(address);
assertThat(result).hasSize(1).contains(dave);
}
@Test // GH-5089
void streamPersonByAddressCorrectlyWhenPaged() {
Address address = new Address("Foo Street 1", "C0123", "Bar");
dave.setAddress(address);
oliver.setAddress(address);
repository.saveAll(List.of(dave, oliver));
Streamable<Person> result = repository.streamByAddress(address,
PageRequest.of(0, 1, Sort.by(Direction.DESC, "firstname")));
assertThat(result).containsExactly(oliver);
result = repository.streamByAddress(address, PageRequest.of(1, 1, Sort.by(Direction.DESC, "firstname")));
assertThat(result).containsExactly(dave);
}
@Test
void findsPeopleByZipCode() {
@ -1516,6 +1544,16 @@ public abstract class AbstractPersonRepositoryIntegrationTests implements Dirtie @@ -1516,6 +1544,16 @@ public abstract class AbstractPersonRepositoryIntegrationTests implements Dirtie
new PersonAggregate("Matthews", Arrays.asList("Dave", "Oliver August")));
}
@Test // GH-5089
void annotatedAggregationReturningStreamable() {
assertThat(repository.streamGroupByLastnameAnd("firstname", PageRequest.of(1, 2, Sort.by("lastname")))) //
.isInstanceOf(Streamable.class) //
.containsExactly( //
new PersonAggregate("Lessard", Collections.singletonList("Stefan")), //
new PersonAggregate("Matthews", Arrays.asList("Dave", "Oliver August")));
}
@Test // DATAMONGO-2153
void annotatedAggregationWithSingleSimpleResult() {
assertThat(repository.sumAge()).isEqualTo(245);

8
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java

@ -45,6 +45,7 @@ import org.springframework.data.mongodb.core.query.UpdateDefinition; @@ -45,6 +45,7 @@ import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.data.mongodb.repository.Person.Sex;
import org.springframework.data.querydsl.QuerydslPredicateExecutor;
import org.springframework.data.repository.query.Param;
import org.springframework.data.util.Streamable;
/**
* Sample repository managing {@link Person} entities.
@ -211,6 +212,10 @@ public interface PersonRepository extends MongoRepository<Person, String>, Query @@ -211,6 +212,10 @@ public interface PersonRepository extends MongoRepository<Person, String>, Query
*/
List<Person> findByAddress(Address address);
Streamable<Person> streamByAddress(Address address);
Streamable<Person> streamByAddress(Address address, Pageable pageable);
List<Person> findByAddressZipCode(String zipCode);
List<Person> findByLastnameLikeAndAgeBetween(String lastname, int from, int to);
@ -442,6 +447,9 @@ public interface PersonRepository extends MongoRepository<Person, String>, Query @@ -442,6 +447,9 @@ public interface PersonRepository extends MongoRepository<Person, String>, Query
@Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }")
List<PersonAggregate> groupByLastnameAnd(String property, Pageable page);
@Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }")
Streamable<PersonAggregate> streamGroupByLastnameAnd(String property, Pageable page);
@Aggregation(pipeline = "{ '$group' : { '_id' : null, 'total' : { $sum: '$age' } } }")
int sumAge();

19
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java

@ -387,6 +387,25 @@ class QueryMethodContributionUnitTests { @@ -387,6 +387,25 @@ class QueryMethodContributionUnitTests {
"Document(\"$sort\", mappedSort.append(\"__score__\", -1))");
}
@Test // GH-5089
void rendersStreamableReturnType() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "streamUserNoArgumentsBy");
assertThat(methodSpec.toString()) //
.containsSubsequence("return", "Streamable.of(", "all())");
}
@Test // GH-5089
void rendersStreamableReturnTypeForAggregation() throws NoSuchMethodException {
MethodSpec methodSpec = codeOf(UserRepository.class, "streamAsStreamableGroupByLastnameAndAsAggregationResults",
String.class);
assertThat(methodSpec.toString()) //
.containsSubsequence("return", "Streamable.of(", "getMappedResults())");
}
private static MethodSpec codeOf(Class<?> repository, String methodName, Class<?>... args)
throws NoSuchMethodException {

Loading…
Cancel
Save