Browse Source

Fix aggregation streams, count result conversion.

See: #4939
Original Pull Request: #4970
pull/4976/head
Mark Paluch 7 months ago
parent
commit
2e52276f39
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 4
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java
  2. 53
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java
  3. 30
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java
  4. 32
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateInteraction.java
  5. 7
      spring-data-mongodb/src/test/java/example/aot/UserRepository.java
  6. 14
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java
  7. 4
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java

4
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java

@ -103,6 +103,10 @@ public class MongoAotRepositoryFragmentSupport { @@ -103,6 +103,10 @@ public class MongoAotRepositoryFragmentSupport {
return list;
}
protected Object convertSimpleRawResult(Class<?> targetType, Document rawResult) {
return extractSimpleTypeResult(rawResult, targetType, mongoConverter);
}
private static <T> @Nullable T extractSimpleTypeResult(@Nullable Document source, Class<T> targetType,
MongoConverter converter) {

53
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

@ -19,6 +19,7 @@ import java.util.ArrayList; @@ -19,6 +19,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import org.bson.Document;
import org.jspecify.annotations.NullUnmarked;
@ -49,7 +50,6 @@ import org.springframework.data.mongodb.repository.query.MongoQueryExecution.Sli @@ -49,7 +50,6 @@ import org.springframework.data.mongodb.repository.query.MongoQueryExecution.Sli
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.data.util.ReflectionUtils;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.javapoet.TypeName;
@ -182,17 +182,15 @@ class MongoCodeBlocks { @@ -182,17 +182,15 @@ class MongoCodeBlocks {
String mongoOpsRef = context.fieldNameOf(MongoOperations.class);
Builder builder = CodeBlock.builder();
Class<?> domainType = context.getRepositoryInformation().getDomainType();
boolean isProjecting = context.getActualReturnType() != null
&& !ObjectUtils.nullSafeEquals(TypeName.get(context.getRepositoryInformation().getDomainType()),
context.getActualReturnType());
&& !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType());
Object actualReturnType = isProjecting ? context.getActualReturnType().getType()
: context.getRepositoryInformation().getDomainType();
Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType;
builder.add("\n");
builder.addStatement("$T<$T> remover = $L.remove($T.class)", ExecutableRemove.class,
context.getRepositoryInformation().getDomainType(), mongoOpsRef,
context.getRepositoryInformation().getDomainType());
builder.addStatement("$T<$T> $L = $L.remove($T.class)", ExecutableRemove.class, domainType,
context.localVariable("remover"), mongoOpsRef, domainType);
DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL;
if (!queryMethod.isCollectionQuery()) {
@ -204,11 +202,20 @@ class MongoCodeBlocks { @@ -204,11 +202,20 @@ class MongoCodeBlocks {
}
actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())
? ClassName.get(context.getMethod().getReturnType())
? TypeName.get(context.getMethod().getReturnType())
: queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType;
builder.addStatement("return ($T) new $T(remover, $T.$L).execute($L)", actualReturnType, DeleteExecution.class,
if (ClassUtils.isVoidType(context.getMethod().getReturnType())) {
builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, context.localVariable("remover"),
DeleteExecution.Type.class, type.name(), queryVariableName);
} else if (context.getMethod().getReturnType() == Optional.class) {
builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class,
actualReturnType, DeleteExecution.class, context.localVariable("remover"), DeleteExecution.Type.class,
type.name(), queryVariableName);
} else {
builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class,
context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName);
}
return builder.build();
}
@ -318,8 +325,18 @@ class MongoCodeBlocks { @@ -318,8 +325,18 @@ class MongoCodeBlocks {
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
if (queryMethod.isStreamQuery()) {
builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class,
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
builder.addStatement("return $L.map(it -> ($T) convertSimpleRawResult($T.class, it))",
context.localVariable("results"), returnType, returnType);
} else {
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
if (!queryMethod.isCollectionQuery()) {
builder.addStatement("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))",
CollectionUtils.class, returnType, returnType, context.localVariable("results"));
@ -327,6 +344,7 @@ class MongoCodeBlocks { @@ -327,6 +344,7 @@ class MongoCodeBlocks {
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
context.localVariable("results"));
}
}
} else {
if (queryMethod.isSliceQuery()) {
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
@ -339,10 +357,17 @@ class MongoCodeBlocks { @@ -339,10 +357,17 @@ class MongoCodeBlocks {
context.getPageableParameterName(), context.localVariable("results"), context.getPageableParameterName(),
context.localVariable("hasNext"));
} else {
if (queryMethod.isStreamQuery()) {
builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName,
outputType);
} else {
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
aggregationVariableName, outputType);
}
}
}
return builder.build();
}
@ -419,10 +444,18 @@ class MongoCodeBlocks { @@ -419,10 +444,18 @@ class MongoCodeBlocks {
builder.addStatement("return $L.matching($L).scroll($L)", context.localVariable("finder"), query.name(),
scrollPositionParameterName);
} else {
if (query.isCount() && !ClassUtils.isAssignable(Long.class, context.getActualReturnType().getRawClass())) {
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
builder.addStatement("return $T.convertNumberToTargetClass($L.matching($L).$L, $T.class)", NumberUtils.class,
context.localVariable("finder"), query.name(), terminatingMethod, returnType);
} else {
builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(),
terminatingMethod);
}
}
return builder.build();
}

30
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

@ -18,6 +18,7 @@ package org.springframework.data.mongodb.repository.aot; @@ -18,6 +18,7 @@ package org.springframework.data.mongodb.repository.aot;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*;
import java.lang.reflect.Method;
import java.util.Locale;
import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
@ -119,16 +120,25 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -119,16 +120,25 @@ public class MongoRepositoryContributor extends RepositoryContributor {
if (queryMethod.isModifyingQuery()) {
int updateIndex = queryMethod.getParameters().getUpdateIndex();
if (updateIndex != -1) {
UpdateInteraction update = new UpdateInteraction(query, null, updateIndex);
return updateMethodContributor(queryMethod, update);
} else {
Update updateSource = queryMethod.getUpdateSource();
if (StringUtils.hasText(updateSource.value())) {
UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()));
UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()), null);
return updateMethodContributor(queryMethod, update);
}
if (!ObjectUtils.isEmpty(updateSource.pipeline())) {
AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline());
return aggregationUpdateMethodContributor(queryMethod, update);
}
}
}
return queryMethodContributor(queryMethod, query);
}
@ -160,10 +170,12 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -160,10 +170,12 @@ public class MongoRepositoryContributor extends RepositoryContributor {
private static boolean backoff(MongoQueryMethod method) {
boolean skip = method.isGeoNearQuery() || method.isSearchQuery();
// TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays.
boolean skip = method.isGeoNearQuery() || method.isSearchQuery()
|| method.getName().toLowerCase(Locale.ROOT).contains("regex") || method.getReturnType().getType().isArray();
if (skip && logger.isDebugEnabled()) {
logger.debug("Skipping AOT generation for [%s]. Method is either geo-near, streaming, search or scrolling query"
logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query"
.formatted(method.getName()));
}
return skip;
@ -197,9 +209,15 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -197,9 +209,15 @@ public class MongoRepositoryContributor extends RepositoryContributor {
.usingQueryVariableName(filterVariableName).build());
// update definition
String updateVariableName = context.localVariable("updateDefinition");
builder.add(
updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName).build());
String updateVariableName;
if (update.hasUpdateDefinitionParameter()) {
updateVariableName = context.getParameterName(update.getRequiredUpdateDefinitionParameter());
} else {
updateVariableName = context.localVariable("updateDefinition");
builder.add(updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName)
.build());
}
builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName)
.referencingUpdate(updateVariableName).build());

32
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateInteraction.java

@ -17,37 +17,60 @@ package org.springframework.data.mongodb.repository.aot; @@ -17,37 +17,60 @@ package org.springframework.data.mongodb.repository.aot;
import java.util.Map;
import org.jspecify.annotations.Nullable;
import org.springframework.data.repository.aot.generate.QueryMetadata;
import org.springframework.util.Assert;
/**
* An {@link MongoInteraction} to execute an update.
*
* @author Christoph Strobl
* @author Mark Paluch
* @since 5.0
*/
class UpdateInteraction extends MongoInteraction implements QueryMetadata {
private final QueryInteraction filter;
private final StringUpdate update;
private final @Nullable StringUpdate update;
private final @Nullable Integer updateDefinitionParameter;
UpdateInteraction(QueryInteraction filter, StringUpdate update) {
UpdateInteraction(QueryInteraction filter, @Nullable StringUpdate update,
@Nullable Integer updateDefinitionParameter) {
this.filter = filter;
this.update = update;
this.updateDefinitionParameter = updateDefinitionParameter;
}
QueryInteraction getFilter() {
public QueryInteraction getFilter() {
return filter;
}
StringUpdate getUpdate() {
public @Nullable StringUpdate getUpdate() {
return update;
}
public int getRequiredUpdateDefinitionParameter() {
Assert.notNull(updateDefinitionParameter, "UpdateDefinitionParameter must not be null!");
return updateDefinitionParameter;
}
public boolean hasUpdateDefinitionParameter() {
return updateDefinitionParameter != null;
}
@Override
public Map<String, Object> serialize() {
Map<String, Object> serialized = filter.serialize();
if (update != null) {
serialized.put("filter", filter.getQuery().getQueryString());
serialized.put("update", update.getUpdateString());
}
return serialized;
}
@ -55,4 +78,5 @@ class UpdateInteraction extends MongoInteraction implements QueryMetadata { @@ -55,4 +78,5 @@ class UpdateInteraction extends MongoInteraction implements QueryMetadata {
InteractionType getExecutionType() {
return InteractionType.UPDATE;
}
}

7
spring-data-mongodb/src/test/java/example/aot/UserRepository.java

@ -55,6 +55,8 @@ public interface UserRepository extends CrudRepository<User, String> { @@ -55,6 +55,8 @@ public interface UserRepository extends CrudRepository<User, String> {
Long countUsersByLastname(String lastname);
int countUsersAsIntByLastname(String lastname);
Boolean existsUserByLastname(String lastname);
List<User> findByLastnameStartingWith(String lastname);
@ -216,6 +218,11 @@ public interface UserRepository extends CrudRepository<User, String> { @@ -216,6 +218,11 @@ public interface UserRepository extends CrudRepository<User, String> {
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
AggregationResults<UserAggregate> groupByLastnameAndAsAggregationResults(String property);
@Aggregation(pipeline = { //
"{ '$match' : { 'last_name' : { '$ne' : null } } }", //
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
Stream<UserAggregate> streamGroupByLastnameAndAsAggregationResults(String property);
@Aggregation(pipeline = { //
"{ '$match' : { 'posts' : { '$ne' : null } } }", //
"{ '$project': { 'nrPosts' : {'$size': '$posts' } } }", //

14
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java

@ -107,8 +107,8 @@ class MongoRepositoryContributorTests { @@ -107,8 +107,8 @@ class MongoRepositoryContributorTests {
@Test
void testDerivedCount() {
Long value = fragment.countUsersByLastname("Skywalker");
assertThat(value).isEqualTo(2L);
assertThat(fragment.countUsersByLastname("Skywalker")).isEqualTo(2L);
assertThat(fragment.countUsersAsIntByLastname("Skywalker")).isEqualTo(2);
}
@Test
@ -559,6 +559,16 @@ class MongoRepositoryContributorTests { @@ -559,6 +559,16 @@ class MongoRepositoryContributorTests {
new UserAggregate("Solo", List.of("Han", "Ben")));
}
@Test
void testAggregationStreamWithProjectedResultsWrappedInAggregationResults() {
List<UserAggregate> allLastnames = fragment.streamGroupByLastnameAndAsAggregationResults("first_name").toList();
assertThat(allLastnames).containsExactlyInAnyOrder(//
new UserAggregate("Skywalker", List.of("Anakin", "Luke")), //
new UserAggregate("Organa", List.of("Leia")), //
new UserAggregate("Solo", List.of("Han", "Ben")));
}
@Test
void testAggregationWithSingleResultExtraction() {
assertThat(fragment.sumPosts()).isEqualTo(5);

4
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java

@ -37,12 +37,12 @@ import org.springframework.data.repository.core.support.RepositoryComposition; @@ -37,12 +37,12 @@ import org.springframework.data.repository.core.support.RepositoryComposition;
/**
* @author Christoph Strobl
*/
class TestMongoAotRepositoryContext implements AotRepositoryContext {
public class TestMongoAotRepositoryContext implements AotRepositoryContext {
private final StubRepositoryInformation repositoryInformation;
private final Environment environment = new StandardEnvironment();
TestMongoAotRepositoryContext(Class<?> repositoryInterface, @Nullable RepositoryComposition composition) {
public TestMongoAotRepositoryContext(Class<?> repositoryInterface, @Nullable RepositoryComposition composition) {
this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition);
}

Loading…
Cancel
Save