diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java index 151a3a2e6..e8d69504d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java @@ -127,6 +127,7 @@ public abstract class AbstractMongoQuery implements RepositoryQuery { * @param accessor for providing invocation arguments. Never {@literal null}. * @param typeToRead the desired component target type. Can be {@literal null}. */ + @Nullable protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor, @Nullable Class typeToRead) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java index 444be1351..3e4c2fec1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java @@ -27,6 +27,7 @@ import org.springframework.data.domain.Sort.Order; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOperation; import org.springframework.data.mongodb.core.aggregation.AggregationOptions; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.query.Collation; import org.springframework.data.mongodb.core.query.Meta; @@ -125,7 +126,7 @@ abstract class AggregationUtils { * @param accessor * @param targetType */ - static void appendSortIfPresent(List aggregationPipeline, ConvertingParameterAccessor accessor, + static void appendSortIfPresent(AggregationPipeline aggregationPipeline, ConvertingParameterAccessor accessor, Class targetType) { if (accessor.getSort().isUnsorted()) { @@ -150,7 +151,7 @@ abstract class AggregationUtils { * @param aggregationPipeline * @param accessor */ - static void appendLimitAndOffsetIfPresent(List aggregationPipeline, + static void appendLimitAndOffsetIfPresent(AggregationPipeline aggregationPipeline, ConvertingParameterAccessor accessor) { appendLimitAndOffsetIfPresent(aggregationPipeline, accessor, LongUnaryOperator.identity(), IntUnaryOperator.identity()); @@ -166,7 +167,7 @@ abstract class AggregationUtils { * @param limitOperator * @since 3.3 */ - static void appendLimitAndOffsetIfPresent(List aggregationPipeline, + static void appendLimitAndOffsetIfPresent(AggregationPipeline aggregationPipeline, ConvertingParameterAccessor accessor, LongUnaryOperator offsetOperator, IntUnaryOperator limitOperator) { Pageable pageable = accessor.getPageable(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java index dca76ff7f..a512fd6e0 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java @@ -38,6 +38,7 @@ import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.UpdateDefinition; import org.springframework.data.support.PageableExecutionUtils; import org.springframework.data.util.TypeInformation; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -55,6 +56,7 @@ import com.mongodb.client.result.DeleteResult; @FunctionalInterface interface MongoQueryExecution { + @Nullable Object execute(Query query); /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregation.java index 1ebfdbfcf..b55bbf6fe 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregation.java @@ -15,6 +15,8 @@ */ package org.springframework.data.mongodb.repository.query; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; +import org.springframework.data.util.ReflectionUtils; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -81,7 +83,7 @@ public class ReactiveStringBasedAggregation extends AbstractReactiveMongoQuery { Class sourceType = method.getDomainClass(); Class targetType = typeToRead; - List pipeline = it; + AggregationPipeline pipeline = new AggregationPipeline(it); AggregationUtils.appendSortIfPresent(pipeline, accessor, typeToRead); AggregationUtils.appendLimitAndOffsetIfPresent(pipeline, accessor); @@ -93,10 +95,13 @@ public class ReactiveStringBasedAggregation extends AbstractReactiveMongoQuery { targetType = Document.class; } - AggregationOptions options = computeOptions(method, accessor); - TypedAggregation aggregation = new TypedAggregation<>(sourceType, pipeline, options); + AggregationOptions options = computeOptions(method, accessor, pipeline); + TypedAggregation aggregation = new TypedAggregation<>(sourceType, pipeline.getOperations(), options); Flux flux = reactiveMongoOperations.aggregate(aggregation, targetType); + if(ReflectionUtils.isVoid(typeToRead)) { + return flux.then(); + } if (isSimpleReturnType && !isRawReturnType) { flux = flux.handle((item, sink) -> { @@ -121,7 +126,7 @@ public class ReactiveStringBasedAggregation extends AbstractReactiveMongoQuery { return parseAggregationPipeline(getQueryMethod().getAnnotatedAggregation(), accessor); } - private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor) { + private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor, AggregationPipeline pipeline) { AggregationOptions.Builder builder = Aggregation.newAggregationOptions(); @@ -129,6 +134,9 @@ public class ReactiveStringBasedAggregation extends AbstractReactiveMongoQuery { expressionParser, evaluationContextProvider); AggregationUtils.applyMeta(builder, method); AggregationUtils.applyHint(builder, method); + if(ReflectionUtils.isVoid(method.getReturnType().getComponentType().getType()) && pipeline.isOutOrMerge()) { + builder.skipOutput(); + } return builder.build(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java index a3847a390..f69a49149 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java @@ -29,6 +29,7 @@ import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOperation; import org.springframework.data.mongodb.core.aggregation.AggregationOptions; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.AggregationOptions.Builder; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; @@ -37,8 +38,12 @@ import org.springframework.data.mongodb.core.mapping.MongoSimpleTypes; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.repository.query.QueryMethodEvaluationContextProvider; import org.springframework.data.repository.query.ResultProcessor; +import org.springframework.data.util.ReflectionUtils; import org.springframework.expression.ExpressionParser; +import org.springframework.lang.Nullable; import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; +import org.springframework.util.ObjectUtils; /** * {@link AbstractMongoQuery} implementation to run string-based aggregations using @@ -85,13 +90,14 @@ public class StringBasedAggregation extends AbstractMongoQuery { * @see org.springframework.data.mongodb.repository.query.AbstractReactiveMongoQuery#doExecute(org.springframework.data.mongodb.repository.query.MongoQueryMethod, org.springframework.data.repository.query.ResultProcessor, org.springframework.data.mongodb.repository.query.ConvertingParameterAccessor, java.lang.Class) */ @Override + @Nullable protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProcessor, ConvertingParameterAccessor accessor, Class typeToRead) { Class sourceType = method.getDomainClass(); Class targetType = typeToRead; - List pipeline = computePipeline(method, accessor); + AggregationPipeline pipeline = computePipeline(method, accessor); AggregationUtils.appendSortIfPresent(pipeline, accessor, typeToRead); if (method.isSliceQuery()) { @@ -112,8 +118,8 @@ public class StringBasedAggregation extends AbstractMongoQuery { targetType = method.getReturnType().getRequiredActualType().getRequiredComponentType().getType(); } - AggregationOptions options = computeOptions(method, accessor); - TypedAggregation aggregation = new TypedAggregation<>(sourceType, pipeline, options); + AggregationOptions options = computeOptions(method, accessor, pipeline); + TypedAggregation aggregation = new TypedAggregation<>(sourceType, pipeline.getOperations(), options); if (method.isStreamQuery()) { @@ -127,6 +133,9 @@ public class StringBasedAggregation extends AbstractMongoQuery { } AggregationResults result = (AggregationResults) mongoOperations.aggregate(aggregation, targetType); + if(ReflectionUtils.isVoid(typeToRead)) { + return null; + } if (isRawAggregationResult) { return result; @@ -168,11 +177,11 @@ public class StringBasedAggregation extends AbstractMongoQuery { return MongoSimpleTypes.HOLDER.isSimpleType(targetType); } - List computePipeline(MongoQueryMethod method, ConvertingParameterAccessor accessor) { - return parseAggregationPipeline(method.getAnnotatedAggregation(), accessor); + AggregationPipeline computePipeline(MongoQueryMethod method, ConvertingParameterAccessor accessor) { + return new AggregationPipeline(parseAggregationPipeline(method.getAnnotatedAggregation(), accessor)); } - private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor) { + private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor, AggregationPipeline pipeline) { AggregationOptions.Builder builder = Aggregation.newAggregationOptions(); @@ -181,6 +190,10 @@ public class StringBasedAggregation extends AbstractMongoQuery { AggregationUtils.applyMeta(builder, method); AggregationUtils.applyHint(builder, method); + if(ReflectionUtils.isVoid(method.getReturnType().getType()) && pipeline.isOutOrMerge()) { + builder.skipOutput(); + } + return builder.build(); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregationUnitTests.java index 99b0b241d..fb10dcb41 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregationUnitTests.java @@ -79,6 +79,7 @@ public class ReactiveStringBasedAggregationUnitTests { private static final String RAW_SORT_STRING = "{ '$sort' : { 'lastname' : -1 } }"; private static final String RAW_GROUP_BY_LASTNAME_STRING = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$firstname' } } }"; + private static final String RAW_OUT = "{ '$out' : 'authors' }"; private static final String GROUP_BY_LASTNAME_STRING_WITH_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', names : { '$addToSet' : '$?0' } } }"; private static final String GROUP_BY_LASTNAME_STRING_WITH_SPEL_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$?#{[0]}' } } }"; @@ -196,6 +197,22 @@ public class ReactiveStringBasedAggregationUnitTests { return new AggregationInvocation(aggregationCaptor.getValue(), targetTypeCaptor.getValue(), result); } + @Test // GH-4088 + void aggregateWithVoidReturnTypeSkipsResultOnOutStage() { + + AggregationInvocation invocation = executeAggregation("outSkipResult"); + + assertThat(skipResultsOf(invocation)).isTrue(); + } + + @Test // GH-4088 + void aggregateWithOutStageDoesNotSkipResults() { + + AggregationInvocation invocation = executeAggregation("outDoNotSkipResult"); + + assertThat(skipResultsOf(invocation)).isFalse(); + } + private ReactiveStringBasedAggregation createAggregationForMethod(String name, Class... parameters) { Method method = ClassUtils.getMethod(SampleRepository.class, name, parameters); @@ -230,6 +247,11 @@ public class ReactiveStringBasedAggregationUnitTests { : null; } + private Boolean skipResultsOf(AggregationInvocation invocation) { + return invocation.aggregation.getOptions() != null ? invocation.aggregation.getOptions().isSkipResults() + : false; + } + private Class targetTypeOf(AggregationInvocation invocation) { return invocation.getTargetType(); } @@ -261,6 +283,12 @@ public class ReactiveStringBasedAggregationUnitTests { @Hint("idx") @Aggregation(RAW_GROUP_BY_LASTNAME_STRING) String withHint(); + + @Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT }) + Flux outDoNotSkipResult(); + + @Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT }) + Mono outSkipResult(); } static class PersonAggregate { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java index a88534afd..91afca80b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java @@ -92,6 +92,7 @@ public class StringBasedAggregationUnitTests { private static final String RAW_SORT_STRING = "{ '$sort' : { 'lastname' : -1 } }"; private static final String RAW_GROUP_BY_LASTNAME_STRING = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$firstname' } } }"; + private static final String RAW_OUT = "{ '$out' : 'authors' }"; private static final String GROUP_BY_LASTNAME_STRING_WITH_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', names : { '$addToSet' : '$?0' } } }"; private static final String GROUP_BY_LASTNAME_STRING_WITH_SPEL_PARAMETER_PLACEHOLDER = "{ '$group': { '_id' : '$lastname', 'names' : { '$addToSet' : '$?#{[0]}' } } }"; @@ -268,6 +269,22 @@ public class StringBasedAggregationUnitTests { assertThat(hintOf(invocation)).isEqualTo("idx"); } + @Test // GH-4088 + void aggregateWithVoidReturnTypeSkipsResultOnOutStage() { + + AggregationInvocation invocation = executeAggregation("outSkipResult"); + + assertThat(skipResultsOf(invocation)).isTrue(); + } + + @Test // GH-4088 + void aggregateWithOutStageDoesNotSkipResults() { + + AggregationInvocation invocation = executeAggregation("outDoNotSkipResult"); + + assertThat(skipResultsOf(invocation)).isFalse(); + } + private AggregationInvocation executeAggregation(String name, Object... args) { Class[] argTypes = Arrays.stream(args).map(Object::getClass).toArray(Class[]::new); @@ -316,6 +333,11 @@ public class StringBasedAggregationUnitTests { : null; } + private Boolean skipResultsOf(AggregationInvocation invocation) { + return invocation.aggregation.getOptions() != null ? invocation.aggregation.getOptions().isSkipResults() + : false; + } + private Class targetTypeOf(AggregationInvocation invocation) { return invocation.getTargetType(); } @@ -368,6 +390,12 @@ public class StringBasedAggregationUnitTests { @Hint("idx") @Aggregation(RAW_GROUP_BY_LASTNAME_STRING) String withHint(); + + @Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT }) + List outDoNotSkipResult(); + + @Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, RAW_OUT }) + void outSkipResult(); } private interface UnsupportedRepository extends Repository { diff --git a/src/main/asciidoc/reference/mongo-repositories-aggregation.adoc b/src/main/asciidoc/reference/mongo-repositories-aggregation.adoc index 1e6b40ac3..d2a160eb0 100644 --- a/src/main/asciidoc/reference/mongo-repositories-aggregation.adoc +++ b/src/main/asciidoc/reference/mongo-repositories-aggregation.adoc @@ -37,6 +37,12 @@ public interface PersonRepository extends CrudRepository { @Aggregation("{ '$project': { '_id' : '$lastname' } }") List findAllLastnames(); <9> + + @Aggregation(pipeline = { + "{ $group : { _id : '$author', books: { $push: '$title' } } }", + "{ $out : 'authors' }" + }) + void groupAndOutSkippingOutput(); <10> } ---- [source,java] @@ -75,6 +81,7 @@ Therefore, the `Sort` properties are mapped against the methods return type `Per To gain more control, you might consider `AggregationResult` as method return type as shown in <7>. <8> Obtain the raw `AggregationResults` mapped to the generic target wrapper type `SumValue` or `org.bson.Document`. <9> Like in <6>, a single value can be directly obtained from multiple result ``Document``s. +<10> Skips the output of the `$out` stage when return type is `void`. ==== In some scenarios, aggregations might require additional options, such as a maximum run time, additional log comments, or the permission to temporarily write data to disk.