From 50de1d6a0ed27752c4113034ba88589d5a88da5d Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 15 Jul 2025 16:28:44 +0200 Subject: [PATCH] Use consistently ParameterBindingDocumentCodec to parse queries and aggregations in AOT-generated code. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We now use ParameterBindingDocumentCodec instead of Document.parse(…) to reinstate lenient MQL parsing and to align with reflective behavior. Previously, we've used Document.parse(…) requiring a stricter syntax for e.g. values. Closes #5018 --- .../MongoAotRepositoryFragmentSupport.java | 4 ++ .../repository/aot/MongoCodeBlocks.java | 47 +++---------------- .../mongodb/repository/aot/QueryBlocks.java | 11 +++-- .../repository/aot/VectorSearchBlocks.java | 2 +- .../json/ParameterBindingDocumentCodec.java | 4 +- .../test/java/example/aot/UserRepository.java | 19 +++----- .../aot/MongoRepositoryContributorTests.java | 11 +++++ .../aot/QueryMethodContributionUnitTests.java | 2 +- .../ParameterBindingJsonReaderUnitTests.java | 8 ++-- 9 files changed, 41 insertions(+), 67 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java index 86025b98f..b7f65df1a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java @@ -101,6 +101,10 @@ public class MongoAotRepositoryFragmentSupport { it -> valueExpressions.createValueContextProvider(mongoParameters.get().get(it)))); } + protected Document parse(String json) { + return CODEC.decode(json); + } + protected Document bindParameters(Method method, String source, Object... args) { expandGeoShapes(args); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 48cd04f5f..64a67c2bd 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -15,9 +15,6 @@ */ package org.springframework.data.mongodb.repository.aot; -import java.util.Iterator; -import java.util.Map; -import java.util.Map.Entry; import java.util.regex.Pattern; import org.bson.Document; @@ -171,10 +168,10 @@ class MongoCodeBlocks { Builder builder = CodeBlock.builder(); if (!StringUtils.hasText(source)) { builder.add("new $T()", Document.class); - } else if (!containsPlaceholder(source)) { - builder.add("$T.parse($S)", Document.class, source); - } else { + } else if (containsPlaceholder(source)) { builder.add("bindParameters(ExpressionMarker.class.getEnclosingMethod(), $S, $L);\n", source, argNames); + } else { + builder.add("parse($S)", source); } return builder.build(); } @@ -185,47 +182,15 @@ class MongoCodeBlocks { Builder builder = CodeBlock.builder(); if (!StringUtils.hasText(source)) { builder.addStatement("$1T $2L = new $1T()", Document.class, variableName); - } else if (!containsPlaceholder(source)) { - builder.addStatement("$1T $2L = $1T.parse($3S)", Document.class, variableName, source); - } else { + } else if (containsPlaceholder(source)) { builder.add("$T $L = bindParameters(ExpressionMarker.class.getEnclosingMethod(), $S, $L);\n", Document.class, variableName, source, argNames); + } else { + builder.addStatement("$1T $2L = parse($3S)", Document.class, variableName, source); } return builder.build(); } - static CodeBlock renderArgumentMap(Map arguments) { - - Builder builder = CodeBlock.builder(); - builder.add("argumentMap("); - Iterator> iterator = arguments.entrySet().iterator(); - while (iterator.hasNext()) { - Entry next = iterator.next(); - builder.add("$S, ", next.getKey()); - builder.add(next.getValue()); - if (iterator.hasNext()) { - builder.add(", "); - } - } - builder.add(")"); - return builder.build(); - } - - static CodeBlock renderArgumentArray(Map arguments) { - - Builder builder = CodeBlock.builder(); - builder.add("arguments("); - Iterator iterator = arguments.values().iterator(); - while (iterator.hasNext()) { - builder.add(iterator.next()); - if (iterator.hasNext()) { - builder.add(", "); - } - } - builder.add(")"); - return builder.build(); - } - static CodeBlock evaluateNumberPotentially(String value, Class targetType, AotQueryMethodGenerationContext context) { try { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java index b1a4a680e..cd30d48e8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java @@ -258,13 +258,14 @@ class QueryBlocks { String source = this.source.getQuery().getQueryString(); if (!StringUtils.hasText(source)) { return CodeBlock.of("new $T(new $T())", BasicQuery.class, Document.class); + } else if (MongoCodeBlocks.containsPlaceholder(source)) { + Builder builder = CodeBlock.builder(); + builder.add("createQuery(ExpressionMarker.class.getEnclosingMethod(), $S, $L)", source, parameterNames); + return builder.build(); } - if (!MongoCodeBlocks.containsPlaceholder(source)) { - return CodeBlock.of("new $T($T.parse($S))", BasicQuery.class, Document.class, source); + else { + return CodeBlock.of("new $T(parse($S))", BasicQuery.class, source); } - Builder builder = CodeBlock.builder(); - builder.add("createQuery(ExpressionMarker.class.getEnclosingMethod(), $S, $L)", source, parameterNames); - return builder.build(); } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBlocks.java index cc1958538..940d90696 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBlocks.java @@ -165,7 +165,7 @@ class VectorSearchBlocks { builder.add("($T) ($L) -> {\n", AggregationOperation.class, ctx); builder.indent(); - builder.add("$1T $4L = $5L.getMappedObject($1T.parse($2S), $3T.class);\n", Document.class, filter.getSortString(), + builder.add("$1T $4L = $5L.getMappedObject(parse($2S), $3T.class);\n", Document.class, filter.getSortString(), context.getActualReturnType().getType(), mappedSort, ctx); builder.add("return new $1T($2S, $3L.append(\"__score__\", -1));\n", Document.class, "$sort", mappedSort); builder.unindent(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/json/ParameterBindingDocumentCodec.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/json/ParameterBindingDocumentCodec.java index 8138f397a..de65cdfec 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/json/ParameterBindingDocumentCodec.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/json/ParameterBindingDocumentCodec.java @@ -168,7 +168,7 @@ public class ParameterBindingDocumentCodec implements CollectibleCodec } // Spring Data Customization START - public Document decode(@Nullable String json, Object[] values) { + public Document decode(@Nullable String json, Object... values) { return decode(json, new ParameterBindingContext((index) -> values[index], new SpelExpressionParser(), () -> EvaluationContextProvider.DEFAULT.getEvaluationContext(values))); @@ -221,7 +221,7 @@ public class ParameterBindingDocumentCodec implements CollectibleCodec return document; } else if (bindingReader.currentValue instanceof String stringValue) { try { - return decode(stringValue, new Object[0]); + return decode(stringValue); } catch (JsonParseException jsonParseException) { throw new IllegalArgumentException("Expression result is not a valid json document", jsonParseException); } diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index ee1058fdc..5145159a3 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -26,18 +26,7 @@ import java.util.regex.Pattern; import java.util.stream.Stream; import org.springframework.data.annotation.Id; -import org.springframework.data.domain.Limit; -import org.springframework.data.domain.Page; -import org.springframework.data.domain.Pageable; -import org.springframework.data.domain.Range; -import org.springframework.data.domain.Score; -import org.springframework.data.domain.ScrollPosition; -import org.springframework.data.domain.SearchResults; -import org.springframework.data.domain.Similarity; -import org.springframework.data.domain.Slice; -import org.springframework.data.domain.Sort; -import org.springframework.data.domain.Vector; -import org.springframework.data.domain.Window; +import org.springframework.data.domain.*; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; @@ -297,6 +286,9 @@ public interface UserRepository extends CrudRepository { "{ '$project': { '_id' : '$last_name' } }" }, collation = "no_collation") List findAllLastnamesWithCollation(); + @Aggregation("{ $group : { _id : $customerId, total : { $sum : 1 } } }") + List totalOrdersPerCustomer(Sort sort); + // Vector Search @VectorSearch(indexName = "embedding.vector_cos", filter = "{lastname: ?0}", numCandidates = "#{10+10}", @@ -362,4 +354,7 @@ public interface UserRepository extends CrudRepository { return Objects.hash(lastname, names); } } + + record OrdersPerCustomer(Object id, long total) { + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index 40e1f35fa..61f434b15 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -30,6 +30,7 @@ import java.util.regex.Pattern; import org.bson.BsonString; import org.bson.Document; +import org.bson.json.JsonParseException; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -688,6 +689,16 @@ class MongoRepositoryContributorTests { .withMessageContaining("'locale' is invalid"); } + @Test // GH-5018 + void aggregationIsParsedLeniently() { + + List result = fragment.totalOrdersPerCustomer(Sort.by("_id")); + assertThat(result).hasSize(1); + + assertThatExceptionOfType(JsonParseException.class) + .isThrownBy(() -> Document.parse("{ $group : { _id : $customerId, total : { $sum : 1 } } }")); + } + @Test // GH-5004 void testNear() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java index 151bb1741..bd5c190e2 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -388,7 +388,7 @@ class QueryMethodContributionUnitTests { .containsSubsequence("var $sort = ", // "(ctx) -> {", // "mappedSort = ctx.getMappedObject(", // - "Document.parse(\"{\\\"firstname\\\": 1}\")", // + "parse(\"{\\\"firstname\\\": 1}\")", // "Document(\"$sort\", mappedSort.append(\"__score__\", -1))"); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/util/json/ParameterBindingJsonReaderUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/util/json/ParameterBindingJsonReaderUnitTests.java index 20b5060f7..dc3cae8bd 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/util/json/ParameterBindingJsonReaderUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/util/json/ParameterBindingJsonReaderUnitTests.java @@ -15,8 +15,7 @@ */ package org.springframework.data.mongodb.util.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.*; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -31,6 +30,7 @@ import org.bson.BsonRegularExpression; import org.bson.Document; import org.bson.codecs.DecoderContext; import org.junit.jupiter.api.Test; + import org.springframework.data.expression.ValueExpressionParser; import org.springframework.data.spel.EvaluationContextProvider; import org.springframework.data.spel.ExpressionDependencies; @@ -635,9 +635,7 @@ class ParameterBindingJsonReaderUnitTests { } private static Document parse(String json, Object... args) { - - ParameterBindingJsonReader reader = new ParameterBindingJsonReader(json, args); - return new ParameterBindingDocumentCodec().decode(reader, DecoderContext.builder().build()); + return new ParameterBindingDocumentCodec().decode(json, args); } // DATAMONGO-2545