Browse Source

Fix aggregation streams, count result conversion.

See: #4939
Original Pull Request: #4970
pull/4976/head
Mark Paluch 7 months ago
parent
commit
2e52276f39
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 4
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java
  2. 77
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java
  3. 42
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java
  4. 34
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateInteraction.java
  5. 7
      spring-data-mongodb/src/test/java/example/aot/UserRepository.java
  6. 14
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java
  7. 4
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java

4
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java

@ -103,6 +103,10 @@ public class MongoAotRepositoryFragmentSupport {
return list; return list;
} }
protected Object convertSimpleRawResult(Class<?> targetType, Document rawResult) {
return extractSimpleTypeResult(rawResult, targetType, mongoConverter);
}
private static <T> @Nullable T extractSimpleTypeResult(@Nullable Document source, Class<T> targetType, private static <T> @Nullable T extractSimpleTypeResult(@Nullable Document source, Class<T> targetType,
MongoConverter converter) { MongoConverter converter) {

77
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

@ -19,6 +19,7 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import java.util.stream.Stream;
import org.bson.Document; import org.bson.Document;
import org.jspecify.annotations.NullUnmarked; import org.jspecify.annotations.NullUnmarked;
@ -49,7 +50,6 @@ import org.springframework.data.mongodb.repository.query.MongoQueryExecution.Sli
import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.data.util.ReflectionUtils; import org.springframework.data.util.ReflectionUtils;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder; import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.javapoet.TypeName; import org.springframework.javapoet.TypeName;
@ -182,17 +182,15 @@ class MongoCodeBlocks {
String mongoOpsRef = context.fieldNameOf(MongoOperations.class); String mongoOpsRef = context.fieldNameOf(MongoOperations.class);
Builder builder = CodeBlock.builder(); Builder builder = CodeBlock.builder();
Class<?> domainType = context.getRepositoryInformation().getDomainType();
boolean isProjecting = context.getActualReturnType() != null boolean isProjecting = context.getActualReturnType() != null
&& !ObjectUtils.nullSafeEquals(TypeName.get(context.getRepositoryInformation().getDomainType()), && !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType());
context.getActualReturnType());
Object actualReturnType = isProjecting ? context.getActualReturnType().getType() Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType;
: context.getRepositoryInformation().getDomainType();
builder.add("\n"); builder.add("\n");
builder.addStatement("$T<$T> remover = $L.remove($T.class)", ExecutableRemove.class, builder.addStatement("$T<$T> $L = $L.remove($T.class)", ExecutableRemove.class, domainType,
context.getRepositoryInformation().getDomainType(), mongoOpsRef, context.localVariable("remover"), mongoOpsRef, domainType);
context.getRepositoryInformation().getDomainType());
DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL; DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL;
if (!queryMethod.isCollectionQuery()) { if (!queryMethod.isCollectionQuery()) {
@ -204,11 +202,20 @@ class MongoCodeBlocks {
} }
actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType()) actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())
? ClassName.get(context.getMethod().getReturnType()) ? TypeName.get(context.getMethod().getReturnType())
: queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType; : queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType;
builder.addStatement("return ($T) new $T(remover, $T.$L).execute($L)", actualReturnType, DeleteExecution.class, if (ClassUtils.isVoidType(context.getMethod().getReturnType())) {
DeleteExecution.Type.class, type.name(), queryVariableName); builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, context.localVariable("remover"),
DeleteExecution.Type.class, type.name(), queryVariableName);
} else if (context.getMethod().getReturnType() == Optional.class) {
builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class,
actualReturnType, DeleteExecution.class, context.localVariable("remover"), DeleteExecution.Type.class,
type.name(), queryVariableName);
} else {
builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class,
context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName);
}
return builder.build(); return builder.build();
} }
@ -318,14 +325,25 @@ class MongoCodeBlocks {
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, if (queryMethod.isStreamQuery()) {
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
if (!queryMethod.isCollectionQuery()) { builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class,
builder.addStatement("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))", context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
CollectionUtils.class, returnType, returnType, context.localVariable("results"));
builder.addStatement("return $L.map(it -> ($T) convertSimpleRawResult($T.class, it))",
context.localVariable("results"), returnType, returnType);
} else { } else {
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
context.localVariable("results")); builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
if (!queryMethod.isCollectionQuery()) {
builder.addStatement("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))",
CollectionUtils.class, returnType, returnType, context.localVariable("results"));
} else {
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
context.localVariable("results"));
}
} }
} else { } else {
if (queryMethod.isSliceQuery()) { if (queryMethod.isSliceQuery()) {
@ -339,8 +357,15 @@ class MongoCodeBlocks {
context.getPageableParameterName(), context.localVariable("results"), context.getPageableParameterName(), context.getPageableParameterName(), context.localVariable("results"), context.getPageableParameterName(),
context.localVariable("hasNext")); context.localVariable("hasNext"));
} else { } else {
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
aggregationVariableName, outputType); if (queryMethod.isStreamQuery()) {
builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName,
outputType);
} else {
builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef,
aggregationVariableName, outputType);
}
} }
} }
@ -420,8 +445,16 @@ class MongoCodeBlocks {
builder.addStatement("return $L.matching($L).scroll($L)", context.localVariable("finder"), query.name(), builder.addStatement("return $L.matching($L).scroll($L)", context.localVariable("finder"), query.name(),
scrollPositionParameterName); scrollPositionParameterName);
} else { } else {
builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(), if (query.isCount() && !ClassUtils.isAssignable(Long.class, context.getActualReturnType().getRawClass())) {
terminatingMethod);
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
builder.addStatement("return $T.convertNumberToTargetClass($L.matching($L).$L, $T.class)", NumberUtils.class,
context.localVariable("finder"), query.name(), terminatingMethod, returnType);
} else {
builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(),
terminatingMethod);
}
} }
return builder.build(); return builder.build();

42
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

@ -18,6 +18,7 @@ package org.springframework.data.mongodb.repository.aot;
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*; import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Locale;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
@ -119,14 +120,23 @@ public class MongoRepositoryContributor extends RepositoryContributor {
if (queryMethod.isModifyingQuery()) { if (queryMethod.isModifyingQuery()) {
Update updateSource = queryMethod.getUpdateSource(); int updateIndex = queryMethod.getParameters().getUpdateIndex();
if (StringUtils.hasText(updateSource.value())) { if (updateIndex != -1) {
UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()));
UpdateInteraction update = new UpdateInteraction(query, null, updateIndex);
return updateMethodContributor(queryMethod, update); return updateMethodContributor(queryMethod, update);
}
if (!ObjectUtils.isEmpty(updateSource.pipeline())) { } else {
AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline()); Update updateSource = queryMethod.getUpdateSource();
return aggregationUpdateMethodContributor(queryMethod, update); if (StringUtils.hasText(updateSource.value())) {
UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()), null);
return updateMethodContributor(queryMethod, update);
}
if (!ObjectUtils.isEmpty(updateSource.pipeline())) {
AggregationUpdateInteraction update = new AggregationUpdateInteraction(query, updateSource.pipeline());
return aggregationUpdateMethodContributor(queryMethod, update);
}
} }
} }
@ -160,10 +170,12 @@ public class MongoRepositoryContributor extends RepositoryContributor {
private static boolean backoff(MongoQueryMethod method) { private static boolean backoff(MongoQueryMethod method) {
boolean skip = method.isGeoNearQuery() || method.isSearchQuery(); // TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays.
boolean skip = method.isGeoNearQuery() || method.isSearchQuery()
|| method.getName().toLowerCase(Locale.ROOT).contains("regex") || method.getReturnType().getType().isArray();
if (skip && logger.isDebugEnabled()) { if (skip && logger.isDebugEnabled()) {
logger.debug("Skipping AOT generation for [%s]. Method is either geo-near, streaming, search or scrolling query" logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query"
.formatted(method.getName())); .formatted(method.getName()));
} }
return skip; return skip;
@ -197,9 +209,15 @@ public class MongoRepositoryContributor extends RepositoryContributor {
.usingQueryVariableName(filterVariableName).build()); .usingQueryVariableName(filterVariableName).build());
// update definition // update definition
String updateVariableName = context.localVariable("updateDefinition"); String updateVariableName;
builder.add(
updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName).build()); if (update.hasUpdateDefinitionParameter()) {
updateVariableName = context.getParameterName(update.getRequiredUpdateDefinitionParameter());
} else {
updateVariableName = context.localVariable("updateDefinition");
builder.add(updateBlockBuilder(context, queryMethod).update(update).usingUpdateVariableName(updateVariableName)
.build());
}
builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName) builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName)
.referencingUpdate(updateVariableName).build()); .referencingUpdate(updateVariableName).build());

34
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateInteraction.java

@ -17,37 +17,60 @@ package org.springframework.data.mongodb.repository.aot;
import java.util.Map; import java.util.Map;
import org.jspecify.annotations.Nullable;
import org.springframework.data.repository.aot.generate.QueryMetadata; import org.springframework.data.repository.aot.generate.QueryMetadata;
import org.springframework.util.Assert;
/** /**
* An {@link MongoInteraction} to execute an update. * An {@link MongoInteraction} to execute an update.
* *
* @author Christoph Strobl * @author Christoph Strobl
* @author Mark Paluch
* @since 5.0 * @since 5.0
*/ */
class UpdateInteraction extends MongoInteraction implements QueryMetadata { class UpdateInteraction extends MongoInteraction implements QueryMetadata {
private final QueryInteraction filter; private final QueryInteraction filter;
private final StringUpdate update; private final @Nullable StringUpdate update;
private final @Nullable Integer updateDefinitionParameter;
UpdateInteraction(QueryInteraction filter, StringUpdate update) { UpdateInteraction(QueryInteraction filter, @Nullable StringUpdate update,
@Nullable Integer updateDefinitionParameter) {
this.filter = filter; this.filter = filter;
this.update = update; this.update = update;
this.updateDefinitionParameter = updateDefinitionParameter;
} }
QueryInteraction getFilter() { public QueryInteraction getFilter() {
return filter; return filter;
} }
StringUpdate getUpdate() { public @Nullable StringUpdate getUpdate() {
return update; return update;
} }
public int getRequiredUpdateDefinitionParameter() {
Assert.notNull(updateDefinitionParameter, "UpdateDefinitionParameter must not be null!");
return updateDefinitionParameter;
}
public boolean hasUpdateDefinitionParameter() {
return updateDefinitionParameter != null;
}
@Override @Override
public Map<String, Object> serialize() { public Map<String, Object> serialize() {
Map<String, Object> serialized = filter.serialize(); Map<String, Object> serialized = filter.serialize();
serialized.put("update", update.getUpdateString());
if (update != null) {
serialized.put("filter", filter.getQuery().getQueryString());
serialized.put("update", update.getUpdateString());
}
return serialized; return serialized;
} }
@ -55,4 +78,5 @@ class UpdateInteraction extends MongoInteraction implements QueryMetadata {
InteractionType getExecutionType() { InteractionType getExecutionType() {
return InteractionType.UPDATE; return InteractionType.UPDATE;
} }
} }

7
spring-data-mongodb/src/test/java/example/aot/UserRepository.java

@ -55,6 +55,8 @@ public interface UserRepository extends CrudRepository<User, String> {
Long countUsersByLastname(String lastname); Long countUsersByLastname(String lastname);
int countUsersAsIntByLastname(String lastname);
Boolean existsUserByLastname(String lastname); Boolean existsUserByLastname(String lastname);
List<User> findByLastnameStartingWith(String lastname); List<User> findByLastnameStartingWith(String lastname);
@ -216,6 +218,11 @@ public interface UserRepository extends CrudRepository<User, String> {
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" }) "{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
AggregationResults<UserAggregate> groupByLastnameAndAsAggregationResults(String property); AggregationResults<UserAggregate> groupByLastnameAndAsAggregationResults(String property);
@Aggregation(pipeline = { //
"{ '$match' : { 'last_name' : { '$ne' : null } } }", //
"{ '$group': { '_id' : '$last_name', names : { $addToSet : '$?0' } } }" })
Stream<UserAggregate> streamGroupByLastnameAndAsAggregationResults(String property);
@Aggregation(pipeline = { // @Aggregation(pipeline = { //
"{ '$match' : { 'posts' : { '$ne' : null } } }", // "{ '$match' : { 'posts' : { '$ne' : null } } }", //
"{ '$project': { 'nrPosts' : {'$size': '$posts' } } }", // "{ '$project': { 'nrPosts' : {'$size': '$posts' } } }", //

14
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java

@ -107,8 +107,8 @@ class MongoRepositoryContributorTests {
@Test @Test
void testDerivedCount() { void testDerivedCount() {
Long value = fragment.countUsersByLastname("Skywalker"); assertThat(fragment.countUsersByLastname("Skywalker")).isEqualTo(2L);
assertThat(value).isEqualTo(2L); assertThat(fragment.countUsersAsIntByLastname("Skywalker")).isEqualTo(2);
} }
@Test @Test
@ -559,6 +559,16 @@ class MongoRepositoryContributorTests {
new UserAggregate("Solo", List.of("Han", "Ben"))); new UserAggregate("Solo", List.of("Han", "Ben")));
} }
@Test
void testAggregationStreamWithProjectedResultsWrappedInAggregationResults() {
List<UserAggregate> allLastnames = fragment.streamGroupByLastnameAndAsAggregationResults("first_name").toList();
assertThat(allLastnames).containsExactlyInAnyOrder(//
new UserAggregate("Skywalker", List.of("Anakin", "Luke")), //
new UserAggregate("Organa", List.of("Leia")), //
new UserAggregate("Solo", List.of("Han", "Ben")));
}
@Test @Test
void testAggregationWithSingleResultExtraction() { void testAggregationWithSingleResultExtraction() {
assertThat(fragment.sumPosts()).isEqualTo(5); assertThat(fragment.sumPosts()).isEqualTo(5);

4
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java

@ -37,12 +37,12 @@ import org.springframework.data.repository.core.support.RepositoryComposition;
/** /**
* @author Christoph Strobl * @author Christoph Strobl
*/ */
class TestMongoAotRepositoryContext implements AotRepositoryContext { public class TestMongoAotRepositoryContext implements AotRepositoryContext {
private final StubRepositoryInformation repositoryInformation; private final StubRepositoryInformation repositoryInformation;
private final Environment environment = new StandardEnvironment(); private final Environment environment = new StandardEnvironment();
TestMongoAotRepositoryContext(Class<?> repositoryInterface, @Nullable RepositoryComposition composition) { public TestMongoAotRepositoryContext(Class<?> repositoryInterface, @Nullable RepositoryComposition composition) {
this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition); this.repositoryInformation = new StubRepositoryInformation(repositoryInterface, composition);
} }

Loading…
Cancel
Save