From 0085c8063a888b8fe62528743e988c8da23fb574 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Thu, 23 Jul 2020 13:12:06 +0200 Subject: [PATCH] DATAMONGO-2557 - Use configured CodecRegistry when parsing String based queries instead of default one. Original pull request: #879. --- .../repository/query/AbstractMongoQuery.java | 34 ++++++++++++- .../query/AbstractReactiveMongoQuery.java | 51 +++++++++++++------ .../repository/query/AggregationUtils.java | 36 +------------ .../query/ReactiveStringBasedAggregation.java | 32 ++++++------ .../query/ReactiveStringBasedMongoQuery.java | 30 ++++++----- .../query/StringBasedAggregation.java | 31 +++++++++-- .../query/StringBasedMongoQuery.java | 27 ++++------ ...activeStringBasedAggregationUnitTests.java | 11 ++++ ...eactiveStringBasedMongoQueryUnitTests.java | 4 ++ .../StringBasedAggregationUnitTests.java | 13 +++++ .../query/StringBasedMongoQueryUnitTests.java | 4 ++ 11 files changed, 173 insertions(+), 100 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java index 6c713d444..e74e7d64b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java @@ -16,6 +16,8 @@ package org.springframework.data.mongodb.repository.query; import org.bson.Document; +import org.bson.codecs.configuration.CodecRegistry; +import org.springframework.data.mapping.model.SpELExpressionEvaluator; import org.springframework.data.mongodb.core.ExecutableFindOperation.ExecutableFind; import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; import org.springframework.data.mongodb.core.ExecutableFindOperation.TerminatingFind; @@ -30,11 +32,14 @@ import org.springframework.data.repository.query.ParameterAccessor; import org.springframework.data.repository.query.QueryMethodEvaluationContextProvider; import org.springframework.data.repository.query.RepositoryQuery; import org.springframework.data.repository.query.ResultProcessor; +import org.springframework.data.spel.ExpressionDependencies; +import org.springframework.expression.EvaluationContext; import org.springframework.expression.ExpressionParser; -import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import com.mongodb.client.MongoDatabase; + /** * Base class for {@link RepositoryQuery} implementations for Mongo. * @@ -209,6 +214,29 @@ public abstract class AbstractMongoQuery implements RepositoryQuery { return applyQueryMetaAttributesWhenPresent(createQuery(accessor)); } + /** + * Obtain a the {@link EvaluationContext} suitable to evaluate expressions backed by the given dependencies. + * + * @param dependencies must not be {@literal null}. + * @param accessor must not be {@literal null}. + * @return the {@link SpELExpressionEvaluator}. + * @since 2.4 + */ + protected SpELExpressionEvaluator getSpELExpressionEvaluatorFor(ExpressionDependencies dependencies, + ConvertingParameterAccessor accessor) { + + return new DefaultSpELExpressionEvaluator(expressionParser, evaluationContextProvider + .getEvaluationContext(getQueryMethod().getParameters(), accessor.getValues(), dependencies)); + } + + /** + * @return the {@link CodecRegistry} used. + * @since 2.4 + */ + protected CodecRegistry getCodecRegistry() { + return operations.execute(AbstractMongoQuery::obtainCodecRegistry); + } + /** * Creates a {@link Query} instance using the given {@link ParameterAccessor} * @@ -247,4 +275,8 @@ public abstract class AbstractMongoQuery implements RepositoryQuery { * @since 2.0.4 */ protected abstract boolean isLimiting(); + + private static CodecRegistry obtainCodecRegistry(MongoDatabase db) { + return db.getCodecRegistry(); + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractReactiveMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractReactiveMongoQuery.java index a0d8e8780..aba7a46c2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractReactiveMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractReactiveMongoQuery.java @@ -18,6 +18,7 @@ package org.springframework.data.mongodb.repository.query; import reactor.core.publisher.Mono; import org.bson.Document; +import org.bson.codecs.configuration.CodecRegistry; import org.reactivestreams.Publisher; import org.springframework.core.convert.converter.Converter; import org.springframework.data.mapping.model.EntityInstantiators; @@ -42,6 +43,9 @@ import org.springframework.expression.ExpressionParser; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import com.mongodb.MongoClientSettings; +import com.mongodb.reactivestreams.client.MongoDatabase; + /** * Base class for reactive {@link RepositoryQuery} implementations for MongoDB. * @@ -257,6 +261,35 @@ public abstract class AbstractReactiveMongoQuery implements RepositoryQuery { return createQuery(accessor).map(this::applyQueryMetaAttributesWhenPresent); } + /** + * Obtain a {@link Mono publisher} emitting the {@link SpELExpressionEvaluator} suitable to evaluate expressions + * backed by the given dependencies. + * + * @param dependencies must not be {@literal null}. + * @param accessor must not be {@literal null}. + * @return a {@link Mono} emitting the {@link SpELExpressionEvaluator} when ready. + * @since 2.4 + */ + protected Mono getSpelEvaluatorFor(ExpressionDependencies dependencies, + ConvertingParameterAccessor accessor) { + + return evaluationContextProvider + .getEvaluationContextLater(getQueryMethod().getParameters(), accessor.getValues(), dependencies) + .map(evaluationContext -> (SpELExpressionEvaluator) new DefaultSpELExpressionEvaluator(expressionParser, + evaluationContext)) + .defaultIfEmpty(DefaultSpELExpressionEvaluator.unsupported()); + } + + /** + * @return a {@link Mono} emitting the {@link CodecRegistry} when ready. + * @since 2.4 + */ + protected Mono getCodecRegistry() { + + return Mono.from(operations.execute(AbstractReactiveMongoQuery::obtainCodecRegistry)) + .defaultIfEmpty(MongoClientSettings.getDefaultCodecRegistry()); + } + /** * Creates a {@link Query} instance using the given {@link ParameterAccessor} * @@ -296,21 +329,7 @@ public abstract class AbstractReactiveMongoQuery implements RepositoryQuery { */ protected abstract boolean isLimiting(); - /** - * Obtain a {@link Mono publisher} emitting the {@link SpELExpressionEvaluator} suitable to evaluate expressions - * backed by the given dependencies. - * - * @param dependencies must not be {@literal null}. - * @param accessor must not be {@literal null}. - * @return a {@link Mono} emitting the {@link SpELExpressionEvaluator} when ready. - */ - protected Mono getSpelEvaluatorFor(ExpressionDependencies dependencies, - ConvertingParameterAccessor accessor) { - - return evaluationContextProvider - .getEvaluationContextLater(getQueryMethod().getParameters(), accessor.getValues(), dependencies) - .map(evaluationContext -> (SpELExpressionEvaluator) new DefaultSpELExpressionEvaluator(expressionParser, - evaluationContext)) - .defaultIfEmpty(DefaultSpELExpressionEvaluator.unsupported()); + private static Mono obtainCodecRegistry(MongoDatabase db) { + return Mono.justOrEmpty(db.getCodecRegistry()); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java index 6e77b9558..e04e7945a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java @@ -16,7 +16,6 @@ package org.springframework.data.mongodb.repository.query; import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -25,17 +24,13 @@ import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort.Order; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOperation; -import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.AggregationOptions; import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.query.Collation; import org.springframework.data.mongodb.core.query.Meta; import org.springframework.data.mongodb.core.query.Query; -import org.springframework.data.mongodb.util.json.ParameterBindingContext; -import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; import org.springframework.data.repository.query.QueryMethodEvaluationContextProvider; import org.springframework.expression.ExpressionParser; -import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.lang.Nullable; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; @@ -50,10 +45,7 @@ import org.springframework.util.StringUtils; */ abstract class AggregationUtils { - private static final ParameterBindingDocumentCodec CODEC = new ParameterBindingDocumentCodec(); - - private AggregationUtils() { - } + private AggregationUtils() {} /** * Apply a collation extracted from the given {@literal collationExpression} to the given @@ -106,32 +98,6 @@ abstract class AggregationUtils { return builder; } - /** - * Compute the {@link AggregationOperation aggregation} pipeline for the given {@link MongoQueryMethod}. The raw - * {@link org.springframework.data.mongodb.repository.Aggregation#pipeline()} is parsed with a - * {@link ParameterBindingDocumentCodec} to obtain the MongoDB native {@link Document} representation returned by - * {@link AggregationOperation#toDocument(AggregationOperationContext)} that is mapped against the domain type - * properties. - * - * @param method - * @param accessor - * @param expressionParser - * @param evaluationContextProvider - * @return - */ - static List computePipeline(MongoQueryMethod method, ConvertingParameterAccessor accessor, - ExpressionParser expressionParser, QueryMethodEvaluationContextProvider evaluationContextProvider) { - - ParameterBindingContext bindingContext = new ParameterBindingContext((accessor::getBindableValue), expressionParser, - () -> evaluationContextProvider.getEvaluationContext(method.getParameters(), accessor.getValues())); - - List target = new ArrayList<>(method.getAnnotatedAggregation().length); - for (String source : method.getAnnotatedAggregation()) { - target.add(ctx -> ctx.getMappedObject(CODEC.decode(source, bindingContext), method.getDomainClass())); - } - return target; - } - /** * Append {@code $sort} aggregation stage if {@link ConvertingParameterAccessor#getSort()} is present. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregation.java index df9608eae..cacf351d0 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregation.java @@ -49,8 +49,6 @@ import org.springframework.util.ClassUtils; */ public class ReactiveStringBasedAggregation extends AbstractReactiveMongoQuery { - private static final ParameterBindingDocumentCodec CODEC = new ParameterBindingDocumentCodec(); - private final ExpressionParser expressionParser; private final ReactiveQueryMethodEvaluationContextProvider evaluationContextProvider; private final ReactiveMongoOperations reactiveMongoOperations; @@ -125,25 +123,29 @@ public class ReactiveStringBasedAggregation extends AbstractReactiveMongoQuery { private Mono> computePipeline(ConvertingParameterAccessor accessor) { - MongoQueryMethod method = getQueryMethod(); - - List> stages = new ArrayList<>(method.getAnnotatedAggregation().length); + return getCodecRegistry().map(ParameterBindingDocumentCodec::new).flatMap(codec -> { - for (String source : method.getAnnotatedAggregation()) { + String[] sourcePipeline = getQueryMethod().getAnnotatedAggregation(); - ExpressionDependencies dependencies = CODEC.captureExpressionDependencies(source, - accessor::getBindableValue, expressionParser); + List> stages = new ArrayList<>(sourcePipeline.length); + for (String source : sourcePipeline) { + stages.add(computePipelineStage(source, accessor, codec)); + } + return Flux.concat(stages).collectList(); + }); + } - Mono stage = getSpelEvaluatorFor(dependencies, accessor).map(it -> { + private Mono computePipelineStage(String source, ConvertingParameterAccessor accessor, + ParameterBindingDocumentCodec codec) { - ParameterBindingContext bindingContext = new ParameterBindingContext(accessor::getBindableValue, it); + ExpressionDependencies dependencies = codec.captureExpressionDependencies(source, accessor::getBindableValue, + expressionParser); - return ctx -> ctx.getMappedObject(CODEC.decode(source, bindingContext), method.getDomainClass()); - }); - stages.add(stage); - } + return getSpelEvaluatorFor(dependencies, accessor).map(it -> { - return Flux.concat(stages).collectList(); + ParameterBindingContext bindingContext = new ParameterBindingContext(accessor::getBindableValue, it); + return ctx -> ctx.getMappedObject(codec.decode(source, bindingContext), getQueryMethod().getDomainClass()); + }); } private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedMongoQuery.java index 1ab7762d3..43ef9374b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedMongoQuery.java @@ -44,7 +44,6 @@ public class ReactiveStringBasedMongoQuery extends AbstractReactiveMongoQuery { private static final String COUNT_EXISTS_AND_DELETE = "Manually defined query for %s cannot be a count and exists or delete query at the same time!"; private static final Logger LOG = LoggerFactory.getLogger(ReactiveStringBasedMongoQuery.class); - private static final ParameterBindingDocumentCodec CODEC = new ParameterBindingDocumentCodec(); private final String query; private final String fieldSpec; @@ -121,27 +120,30 @@ public class ReactiveStringBasedMongoQuery extends AbstractReactiveMongoQuery { @Override protected Mono createQuery(ConvertingParameterAccessor accessor) { - Mono queryObject = getBindingContext(accessor, expressionParser, this.query) - .map(it -> CODEC.decode(this.query, it)); - Mono fieldsObject = getBindingContext(accessor, expressionParser, this.fieldSpec) - .map(it -> CODEC.decode(this.fieldSpec, it)); + return getCodecRegistry().map(ParameterBindingDocumentCodec::new).flatMap(codec -> { - return queryObject.zipWith(fieldsObject).map(tuple -> { + Mono queryObject = getBindingContext(query, accessor, codec) + .map(context -> codec.decode(query, context)); + Mono fieldsObject = getBindingContext(fieldSpec, accessor, codec) + .map(context -> codec.decode(fieldSpec, context)); - Query query = new BasicQuery(tuple.getT1(), tuple.getT2()).with(accessor.getSort()); + return queryObject.zipWith(fieldsObject).map(tuple -> { - if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Created query %s for %s fields.", query.getQueryObject(), query.getFieldsObject())); - } + Query query = new BasicQuery(tuple.getT1(), tuple.getT2()).with(accessor.getSort()); + + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Created query %s for %s fields.", query.getQueryObject(), query.getFieldsObject())); + } - return query; + return query; + }); }); } - private Mono getBindingContext(ConvertingParameterAccessor accessor, - ExpressionParser expressionParser, String json) { + private Mono getBindingContext(String json, ConvertingParameterAccessor accessor, + ParameterBindingDocumentCodec codec) { - ExpressionDependencies dependencies = CODEC.captureExpressionDependencies(json, accessor::getBindableValue, + ExpressionDependencies dependencies = codec.captureExpressionDependencies(json, accessor::getBindableValue, expressionParser); return getSpelEvaluatorFor(dependencies, accessor) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java index 2dea1ff89..ecbbfa058 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java @@ -15,11 +15,12 @@ */ package org.springframework.data.mongodb.repository.query; +import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; import org.bson.Document; - +import org.springframework.data.mapping.model.SpELExpressionEvaluator; import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.Aggregation; @@ -30,8 +31,11 @@ import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.mapping.MongoSimpleTypes; import org.springframework.data.mongodb.core.query.Query; +import org.springframework.data.mongodb.util.json.ParameterBindingContext; +import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; import org.springframework.data.repository.query.QueryMethodEvaluationContextProvider; import org.springframework.data.repository.query.ResultProcessor; +import org.springframework.data.spel.ExpressionDependencies; import org.springframework.expression.ExpressionParser; import org.springframework.util.ClassUtils; @@ -73,7 +77,9 @@ public class StringBasedAggregation extends AbstractMongoQuery { ConvertingParameterAccessor accessor, Class typeToRead) { if (method.isPageQuery() || method.isSliceQuery()) { - throw new InvalidMongoDbApiUsageException(String.format("Repository aggregation method '%s' does not support '%s' return type. Please use eg. 'List' instead.", method.getName(), method.getReturnType().getType().getSimpleName())); + throw new InvalidMongoDbApiUsageException(String.format( + "Repository aggregation method '%s' does not support '%s' return type. Please use eg. 'List' instead.", + method.getName(), method.getReturnType().getType().getSimpleName())); } Class sourceType = method.getDomainClass(); @@ -125,7 +131,26 @@ public class StringBasedAggregation extends AbstractMongoQuery { } List computePipeline(MongoQueryMethod method, ConvertingParameterAccessor accessor) { - return AggregationUtils.computePipeline(method, accessor, expressionParser, evaluationContextProvider); + + ParameterBindingDocumentCodec codec = new ParameterBindingDocumentCodec(getCodecRegistry()); + String[] sourcePipeline = method.getAnnotatedAggregation(); + + List stages = new ArrayList<>(sourcePipeline.length); + for (String source : sourcePipeline) { + stages.add(computePipelineStage(source, accessor, codec)); + } + return stages; + } + + private AggregationOperation computePipelineStage(String source, ConvertingParameterAccessor accessor, + ParameterBindingDocumentCodec codec) { + + ExpressionDependencies dependencies = codec.captureExpressionDependencies(source, accessor::getBindableValue, + expressionParser); + + SpELExpressionEvaluator evaluator = getSpELExpressionEvaluatorFor(dependencies, accessor); + ParameterBindingContext bindingContext = new ParameterBindingContext(accessor::getBindableValue, evaluator); + return ctx -> ctx.getMappedObject(codec.decode(source, bindingContext), getQueryMethod().getDomainClass()); } private AggregationOptions computeOptions(MongoQueryMethod method, ConvertingParameterAccessor accessor) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedMongoQuery.java index bb6864233..fb86830f4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedMongoQuery.java @@ -16,7 +16,6 @@ package org.springframework.data.mongodb.repository.query; import org.bson.Document; -import org.bson.codecs.configuration.CodecRegistry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.data.mapping.model.SpELExpressionEvaluator; @@ -31,9 +30,6 @@ import org.springframework.expression.ExpressionParser; import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.util.Assert; -import com.mongodb.MongoClientSettings; -import com.mongodb.client.MongoDatabase; - /** * Query to use a plain JSON String to create the {@link Query} to actually execute. * @@ -50,7 +46,6 @@ public class StringBasedMongoQuery extends AbstractMongoQuery { private final String query; private final String fieldSpec; - private final ParameterBindingDocumentCodec codec; private final ExpressionParser expressionParser; private final QueryMethodEvaluationContextProvider evaluationContextProvider; @@ -112,10 +107,6 @@ public class StringBasedMongoQuery extends AbstractMongoQuery { this.isExistsQuery = false; this.isDeleteQuery = false; } - - CodecRegistry codecRegistry = mongoOperations.execute(MongoDatabase::getCodecRegistry); - this.codec = new ParameterBindingDocumentCodec( - codecRegistry != null ? codecRegistry : MongoClientSettings.getDefaultCodecRegistry()); } /* @@ -125,8 +116,10 @@ public class StringBasedMongoQuery extends AbstractMongoQuery { @Override protected Query createQuery(ConvertingParameterAccessor accessor) { - Document queryObject = codec.decode(this.query, getBindingContext(accessor, expressionParser, this.query)); - Document fieldsObject = codec.decode(this.fieldSpec, getBindingContext(accessor, expressionParser, this.fieldSpec)); + ParameterBindingDocumentCodec codec = getParameterBindingCodec(); + + Document queryObject = codec.decode(this.query, getBindingContext(this.query, accessor, codec)); + Document fieldsObject = codec.decode(this.fieldSpec, getBindingContext(this.fieldSpec, accessor, codec)); Query query = new BasicQuery(queryObject, fieldsObject).with(accessor.getSort()); @@ -137,15 +130,13 @@ public class StringBasedMongoQuery extends AbstractMongoQuery { return query; } - private ParameterBindingContext getBindingContext(ConvertingParameterAccessor accessor, - ExpressionParser expressionParser, String json) { + private ParameterBindingContext getBindingContext(String json, ConvertingParameterAccessor accessor, + ParameterBindingDocumentCodec codec) { ExpressionDependencies dependencies = codec.captureExpressionDependencies(json, accessor::getBindableValue, expressionParser); - SpELExpressionEvaluator evaluator = new DefaultSpELExpressionEvaluator(expressionParser, evaluationContextProvider - .getEvaluationContext(getQueryMethod().getParameters(), accessor.getValues(), dependencies)); - + SpELExpressionEvaluator evaluator = getSpELExpressionEvaluatorFor(dependencies, accessor); return new ParameterBindingContext(accessor::getBindableValue, evaluator); } @@ -189,4 +180,8 @@ public class StringBasedMongoQuery extends AbstractMongoQuery { boolean isDeleteQuery) { return BooleanUtil.countBooleanTrueValues(isCountQuery, isExistsQuery, isDeleteQuery) > 1; } + + private ParameterBindingDocumentCodec getParameterBindingCodec() { + return new ParameterBindingDocumentCodec(getCodecRegistry()); + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregationUnitTests.java index 35fb113cb..f7c255bed 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedAggregationUnitTests.java @@ -91,6 +91,7 @@ public class ReactiveStringBasedAggregationUnitTests { converter = new MappingMongoConverter(dbRefResolver, new MongoMappingContext()); when(operations.getConverter()).thenReturn(converter); when(operations.aggregate(any(TypedAggregation.class), any())).thenReturn(Flux.empty()); + when(operations.execute(any())).thenReturn(Flux.empty()); } @Test // DATAMONGO-2153 @@ -166,6 +167,13 @@ public class ReactiveStringBasedAggregationUnitTests { assertThat(collationOf(invocation)).isEqualTo(Collation.of("en_US")); } + @Test // DATAMONGO-2557 + void aggregationRetrievesCodecFromDriverJustOnceForMultipleAggregationOperationsInPipeline() { + + executeAggregation("multiOperationPipeline", "firstname"); + verify(operations).execute(any()); + } + private AggregationInvocation executeAggregation(String name, Object... args) { Class[] argTypes = Arrays.stream(args).map(Object::getClass).toArray(size -> new Class[size]); @@ -228,6 +236,9 @@ public class ReactiveStringBasedAggregationUnitTests { @Aggregation(GROUP_BY_LASTNAME_STRING_WITH_SPEL_PARAMETER_PLACEHOLDER) Mono spelParameterReplacementAggregation(String arg0); + @Aggregation(pipeline = {RAW_GROUP_BY_LASTNAME_STRING, GROUP_BY_LASTNAME_STRING_WITH_SPEL_PARAMETER_PLACEHOLDER}) + Mono multiOperationPipeline(String arg0); + @Aggregation(pipeline = RAW_GROUP_BY_LASTNAME_STRING, collation = "de_AT") Mono aggregateWithCollation(); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedMongoQueryUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedMongoQueryUnitTests.java index 19f1ae485..1fd5f68e0 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedMongoQueryUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveStringBasedMongoQueryUnitTests.java @@ -19,6 +19,8 @@ import static org.assertj.core.api.Assertions.*; import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.*; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -62,6 +64,7 @@ import org.springframework.util.Base64Utils; * @author Christoph Strobl */ @ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) public class ReactiveStringBasedMongoQueryUnitTests { SpelExpressionParser PARSER = new SpelExpressionParser(); @@ -76,6 +79,7 @@ public class ReactiveStringBasedMongoQueryUnitTests { public void setUp() { when(operations.query(any())).thenReturn(reactiveFind); + when(operations.execute(any())).thenReturn(Flux.empty()); this.converter = new MappingMongoConverter(factory, new MongoMappingContext()); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java index c24701dea..9c2d828cf 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java @@ -66,6 +66,8 @@ import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.lang.Nullable; import org.springframework.util.ClassUtils; +import com.mongodb.MongoClientSettings; + /** * Unit tests for {@link StringBasedAggregation}. * @@ -97,6 +99,7 @@ public class StringBasedAggregationUnitTests { converter = new MappingMongoConverter(dbRefResolver, new MongoMappingContext()); when(operations.getConverter()).thenReturn(converter); when(operations.aggregate(any(TypedAggregation.class), any())).thenReturn(aggregationResults); + when(operations.execute(any())).thenReturn(MongoClientSettings.getDefaultCodecRegistry()); } @Test // DATAMONGO-2153 @@ -218,6 +221,13 @@ public class StringBasedAggregationUnitTests { .withMessageContaining("Page"); } + @Test // DATAMONGO-2557 + void aggregationRetrievesCodecFromDriverJustOnceForMultipleAggregationOperationsInPipeline() { + + executeAggregation("multiOperationPipeline", "firstname"); + verify(operations).execute(any()); + } + private AggregationInvocation executeAggregation(String name, Object... args) { Class[] argTypes = Arrays.stream(args).map(Object::getClass).toArray(Class[]::new); @@ -291,6 +301,9 @@ public class StringBasedAggregationUnitTests { @Aggregation(GROUP_BY_LASTNAME_STRING_WITH_SPEL_PARAMETER_PLACEHOLDER) PersonAggregate spelParameterReplacementAggregation(String arg0); + @Aggregation(pipeline = { RAW_GROUP_BY_LASTNAME_STRING, GROUP_BY_LASTNAME_STRING_WITH_SPEL_PARAMETER_PLACEHOLDER }) + PersonAggregate multiOperationPipeline(String arg0); + @Aggregation(pipeline = RAW_GROUP_BY_LASTNAME_STRING, collation = "de_AT") PersonAggregate aggregateWithCollation(); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedMongoQueryUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedMongoQueryUnitTests.java index 30b32cb91..9366d5e4a 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedMongoQueryUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedMongoQueryUnitTests.java @@ -39,6 +39,8 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; import org.springframework.data.mongodb.core.DbCallback; import org.springframework.data.mongodb.core.DocumentTestUtils; import org.springframework.data.mongodb.core.ExecutableFindOperation.ExecutableFind; @@ -72,6 +74,7 @@ import com.mongodb.reactivestreams.client.MongoClients; * @author Mark Paluch */ @ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) public class StringBasedMongoQueryUnitTests { SpelExpressionParser PARSER = new SpelExpressionParser(); @@ -88,6 +91,7 @@ public class StringBasedMongoQueryUnitTests { this.converter = new MappingMongoConverter(factory, new MongoMappingContext()); doReturn(findOperation).when(operations).query(any()); + doReturn(MongoClientSettings.getDefaultCodecRegistry()).when(operations).execute(any()); } @Test