Browse Source

Add AOT support for dynamic projections, streaming/scroll queries and Meta annotation.

Closes: #4970
pull/4976/head
Mark Paluch 7 months ago
parent
commit
0e606d26bf
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 62
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java
  2. 7
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java
  3. 13
      spring-data-mongodb/src/test/java/example/aot/UserRepository.java
  4. 5
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java
  5. 43
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java
  6. 90
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorUnitTests.java
  7. 5
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/StubRepositoryInformation.java
  8. 6
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java
  9. 7
      src/main/antora/modules/ROOT/pages/mongodb/aot.adoc

62
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

@ -23,6 +23,7 @@ import java.util.regex.Pattern; @@ -23,6 +23,7 @@ import java.util.regex.Pattern;
import org.bson.Document;
import org.jspecify.annotations.NullUnmarked;
import org.jspecify.annotations.Nullable;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.data.domain.SliceImpl;
import org.springframework.data.domain.Sort.Order;
@ -40,6 +41,7 @@ import org.springframework.data.mongodb.core.query.BasicQuery; @@ -40,6 +41,7 @@ import org.springframework.data.mongodb.core.query.BasicQuery;
import org.springframework.data.mongodb.core.query.BasicUpdate;
import org.springframework.data.mongodb.core.query.Collation;
import org.springframework.data.mongodb.repository.Hint;
import org.springframework.data.mongodb.repository.Meta;
import org.springframework.data.mongodb.repository.ReadPreference;
import org.springframework.data.mongodb.repository.query.MongoQueryExecution.DeleteExecution;
import org.springframework.data.mongodb.repository.query.MongoQueryExecution.PagedExecution;
@ -256,15 +258,13 @@ class MongoCodeBlocks { @@ -256,15 +258,13 @@ class MongoCodeBlocks {
updateReference);
} else if (ClassUtils.isAssignable(Long.class, returnType)) {
builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()",
context.localVariable("updater"), queryVariableName,
updateReference);
context.localVariable("updater"), queryVariableName, updateReference);
} else {
builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class,
context.localVariable("modifiedCount"), context.localVariable("updater"),
queryVariableName, updateReference);
context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName,
updateReference);
builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class,
context.localVariable("modifiedCount"),
returnType);
context.localVariable("modifiedCount"), returnType);
}
return builder.build();
@ -319,11 +319,9 @@ class MongoCodeBlocks { @@ -319,11 +319,9 @@ class MongoCodeBlocks {
Class<?> returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType());
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
context.localVariable("results"), mongoOpsRef,
aggregationVariableName, outputType);
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
if (!queryMethod.isCollectionQuery()) {
builder.addStatement(
"return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))",
builder.addStatement("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))",
CollectionUtils.class, returnType, returnType, context.localVariable("results"));
} else {
builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType,
@ -332,8 +330,7 @@ class MongoCodeBlocks { @@ -332,8 +330,7 @@ class MongoCodeBlocks {
} else {
if (queryMethod.isSliceQuery()) {
builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class,
context.localVariable("results"), mongoOpsRef,
aggregationVariableName, outputType);
context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType);
builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()",
context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName());
builder.addStatement(
@ -378,12 +375,16 @@ class MongoCodeBlocks { @@ -378,12 +375,16 @@ class MongoCodeBlocks {
boolean isProjecting = context.getReturnedType().isProjecting();
Class<?> domainType = context.getRepositoryInformation().getDomainType();
Object actualReturnType = isProjecting ? context.getActualReturnType().getType()
Object actualReturnType = queryMethod.getParameters().hasDynamicProjection() || isProjecting
? TypeName.get(context.getActualReturnType().getType())
: domainType;
builder.add("\n");
if (isProjecting) {
if (queryMethod.getParameters().hasDynamicProjection()) {
builder.addStatement("$T<$T> $L = $L.query($T.class).as($L)", FindWithQuery.class, actualReturnType,
context.localVariable("finder"), mongoOpsRef, domainType, context.getDynamicProjectionParameterName());
} else if (isProjecting) {
builder.addStatement("$T<$T> $L = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType,
context.localVariable("finder"), mongoOpsRef, domainType, actualReturnType);
} else {
@ -400,6 +401,8 @@ class MongoCodeBlocks { @@ -400,6 +401,8 @@ class MongoCodeBlocks {
terminatingMethod = "count()";
} else if (query.isExists()) {
terminatingMethod = "exists()";
} else if (queryMethod.isStreamQuery()) {
terminatingMethod = "stream()";
} else {
terminatingMethod = Optional.class.isAssignableFrom(context.getReturnType().toClass()) ? "one()" : "oneValue()";
}
@ -410,6 +413,12 @@ class MongoCodeBlocks { @@ -410,6 +413,12 @@ class MongoCodeBlocks {
} else if (queryMethod.isSliceQuery()) {
builder.addStatement("return new $T($L, $L).execute($L)", SlicedExecution.class,
context.localVariable("finder"), context.getPageableParameterName(), query.name());
} else if (queryMethod.isScrollQuery()) {
String scrollPositionParameterName = context.getScrollPositionParameterName();
builder.addStatement("return $L.matching($L).scroll($L)", context.localVariable("finder"), query.name(),
scrollPositionParameterName);
} else {
builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(),
terminatingMethod);
@ -544,8 +553,7 @@ class MongoCodeBlocks { @@ -544,8 +553,7 @@ class MongoCodeBlocks {
Builder optionsBuilder = CodeBlock.builder();
optionsBuilder.add("$T $L = $T.builder()\n", AggregationOptions.class,
context.localVariable("aggregationOptions"),
AggregationOptions.class);
context.localVariable("aggregationOptions"), AggregationOptions.class);
optionsBuilder.indent();
for (CodeBlock optionBlock : options) {
optionsBuilder.add(optionBlock);
@ -709,7 +717,27 @@ class MongoCodeBlocks { @@ -709,7 +717,27 @@ class MongoCodeBlocks {
com.mongodb.ReadPreference.class, readPreference);
}
// TODO: Meta annotation
MergedAnnotation<Meta> metaAnnotation = context.getAnnotation(Meta.class);
if (metaAnnotation.isPresent()) {
long maxExecutionTimeMs = metaAnnotation.getLong("maxExecutionTimeMs");
if (maxExecutionTimeMs != -1) {
builder.addStatement("$L.maxTimeMsec($L)", queryVariableName, maxExecutionTimeMs);
}
int cursorBatchSize = metaAnnotation.getInt("cursorBatchSize");
if (cursorBatchSize != 0) {
builder.addStatement("$L.cursorBatchSize($L)", queryVariableName, cursorBatchSize);
}
String comment = metaAnnotation.getString("comment");
if (StringUtils.hasText("comment")) {
builder.addStatement("$L.comment($S)", queryVariableName, comment);
}
}
// TODO: Meta annotation: Disk usage
return builder.build();
}

7
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

@ -48,6 +48,7 @@ import org.springframework.util.StringUtils; @@ -48,6 +48,7 @@ import org.springframework.util.StringUtils;
* MongoDB specific {@link RepositoryContributor}.
*
* @author Christoph Strobl
* @author Mark Paluch
* @since 5.0
*/
public class MongoRepositoryContributor extends RepositoryContributor {
@ -159,8 +160,7 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -159,8 +160,7 @@ public class MongoRepositoryContributor extends RepositoryContributor {
private static boolean backoff(MongoQueryMethod method) {
boolean skip = method.isGeoNearQuery() || method.isScrollQuery() || method.isStreamQuery()
|| method.isSearchQuery();
boolean skip = method.isGeoNearQuery() || method.isSearchQuery();
if (skip && logger.isDebugEnabled()) {
logger.debug("Skipping AOT generation for [%s]. Method is either geo-near, streaming, search or scrolling query"
@ -225,8 +225,7 @@ public class MongoRepositoryContributor extends RepositoryContributor { @@ -225,8 +225,7 @@ public class MongoRepositoryContributor extends RepositoryContributor {
.usingAggregationVariableName(updateVariableName).pipelineOnly(true).build());
builder.addStatement("$T $L = $T.from($L.getOperations())", AggregationUpdate.class,
context.localVariable("aggregationUpdate"),
AggregationUpdate.class, updateVariableName);
context.localVariable("aggregationUpdate"), AggregationUpdate.class, updateVariableName);
builder.add(updateExecutionBlockBuilder(context, queryMethod).withFilter(filterVariableName)
.referencingUpdate(context.localVariable("aggregationUpdate")).build());

13
spring-data-mongodb/src/test/java/example/aot/UserRepository.java

@ -22,13 +22,16 @@ import java.util.List; @@ -22,13 +22,16 @@ import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import org.springframework.data.annotation.Id;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.ScrollPosition;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Window;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.repository.Aggregation;
import org.springframework.data.mongodb.repository.Hint;
@ -94,8 +97,10 @@ public interface UserRepository extends CrudRepository<User, String> { @@ -94,8 +97,10 @@ public interface UserRepository extends CrudRepository<User, String> {
Slice<User> findSliceOfUserByLastnameStartingWith(String lastname, Pageable page);
// TODO: Streaming
// TODO: Scrolling
Stream<User> streamByLastnameStartingWith(String lastname, Sort sort, Limit limit);
Window<User> findTop2WindowByLastnameStartingWithOrderByUsername(String lastname, ScrollPosition scrollPosition);
// TODO: GeoQueries
// TODO: TextSearch
@ -176,14 +181,14 @@ public interface UserRepository extends CrudRepository<User, String> { @@ -176,14 +181,14 @@ public interface UserRepository extends CrudRepository<User, String> {
@ReadPreference("no-such-read-preference")
User findWithReadPreferenceByUsername(String username);
// TODO: hints
/* Projecting Queries */
List<UserProjection> findUserProjectionByLastnameStartingWith(String lastname);
Page<UserProjection> findUserProjectionByLastnameStartingWith(String lastname, Pageable page);
<T> Page<T> findUserProjectionByLastnameStartingWith(String lastname, Pageable page, Class<T> projectionType);
/* Aggregations */
@Aggregation(pipeline = { //

5
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java

@ -62,7 +62,8 @@ public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProce @@ -62,7 +62,8 @@ public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProce
new MongoRepositoryContributor(repositoryContext).contribute(generationContext);
AbstractBeanDefinition aotGeneratedRepository = BeanDefinitionBuilder
.genericBeanDefinition(repositoryInterface.getName() + "Impl__Aot") //
.genericBeanDefinition(
repositoryInterface.getPackageName() + "." + repositoryInterface.getSimpleName() + "Impl__Aot") //
.addConstructorArgReference("mongoOperations") //
.addConstructorArgValue(getCreationContext(repositoryContext)).getBeanDefinition();
@ -80,6 +81,8 @@ public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProce @@ -80,6 +81,8 @@ public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProce
}).getBeanDefinition();
((BeanDefinitionRegistry) beanFactory).registerBeanDefinition("fragmentFacade", fragmentFacade);
beanFactory.registerSingleton("generationContext", generationContext);
}
private Object getFragmentFacadeProxy(Object fragment) {

43
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java

@ -34,11 +34,15 @@ import org.junit.jupiter.api.extension.ExtendWith; @@ -34,11 +34,15 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.domain.KeysetScrollPosition;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.OffsetScrollPosition;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.ScrollPosition;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Window;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
@ -271,6 +275,37 @@ class MongoRepositoryContributorTests { @@ -271,6 +275,37 @@ class MongoRepositoryContributorTests {
assertThat(slice.getContent()).extracting(User::getUsername).containsExactly("han", "kylo");
}
@Test
void testDerivedQueryReturningStream() {
List<User> results = fragment.streamByLastnameStartingWith("S", Sort.by("username"), Limit.of(2)).toList();
assertThat(results).hasSize(2);
assertThat(results).extracting(User::getUsername).containsExactly("han", "kylo");
}
@Test
void testDerivedQueryReturningWindowByOffset() {
Window<User> window1 = fragment.findTop2WindowByLastnameStartingWithOrderByUsername("S", ScrollPosition.offset());
assertThat(window1).extracting(User::getUsername).containsExactly("han", "kylo");
assertThat(window1.positionAt(1)).isInstanceOf(OffsetScrollPosition.class);
Window<User> window2 = fragment.findTop2WindowByLastnameStartingWithOrderByUsername("S", window1.positionAt(1));
assertThat(window2).extracting(User::getUsername).containsExactly("luke", "vader");
}
@Test
void testDerivedQueryReturningWindowByKeyset() {
Window<User> window1 = fragment.findTop2WindowByLastnameStartingWithOrderByUsername("S", ScrollPosition.keyset());
assertThat(window1).extracting(User::getUsername).containsExactly("han", "kylo");
assertThat(window1.positionAt(1)).isInstanceOf(KeysetScrollPosition.class);
Window<User> window2 = fragment.findTop2WindowByLastnameStartingWithOrderByUsername("S", window1.positionAt(1));
assertThat(window2).extracting(User::getUsername).containsExactly("luke", "vader");
}
@Test
void testAnnotatedFinderReturningSingleValueWithQuery() {
@ -439,6 +474,14 @@ class MongoRepositoryContributorTests { @@ -439,6 +474,14 @@ class MongoRepositoryContributorTests {
assertThat(users).extracting(UserProjection::getUsername).containsExactly("han", "kylo");
}
@Test
void testDerivedFinderReturningPageOfDynamicProjections() {
Page<UserProjection> users = fragment.findUserProjectionByLastnameStartingWith("S",
PageRequest.of(0, 2, Sort.by("username")), UserProjection.class);
assertThat(users).extracting(UserProjection::getUsername).containsExactly("han", "kylo");
}
@Test
void testUpdateWithDerivedQuery() {

90
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorUnitTests.java

@ -0,0 +1,90 @@ @@ -0,0 +1,90 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.aot;
import static org.mockito.Mockito.*;
import static org.springframework.data.mongodb.test.util.Assertions.*;
import example.aot.User;
import example.aot.UserRepository;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.aot.generate.GeneratedFiles;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.InputStreamSource;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.repository.Meta;
import org.springframework.data.mongodb.test.util.MongoClientExtension;
import org.springframework.data.repository.CrudRepository;
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
/**
* Unit tests for the {@link UserRepository} fragment sources via {@link MongoRepositoryContributor}.
*
* @author Mark Paluch
*/
@ExtendWith(MongoClientExtension.class)
@SpringJUnitConfig(classes = MongoRepositoryContributorUnitTests.MongoRepositoryContributorConfiguration.class)
class MongoRepositoryContributorUnitTests {
@Configuration
static class MongoRepositoryContributorConfiguration extends AotFragmentTestConfigurationSupport {
public MongoRepositoryContributorConfiguration() {
super(MetaUserRepository.class);
}
@Bean
MongoOperations mongoOperations() {
return mock(MongoOperations.class);
}
}
@Autowired TestGenerationContext generationContext;
@Test
void shouldConsiderMetaAnnotation() throws IOException {
InputStreamSource aotFragment = generationContext.getGeneratedFiles().getGeneratedFile(GeneratedFiles.Kind.SOURCE,
MetaUserRepository.class.getPackageName().replace('.', '/') + "/MetaUserRepositoryImpl__Aot.java");
String content = new InputStreamResource(aotFragment).getContentAsString(StandardCharsets.UTF_8);
assertThat(content).contains("filterQuery.maxTimeMsec(555)");
assertThat(content).contains("filterQuery.cursorBatchSize(1234)");
assertThat(content).contains("filterQuery.comment(\"foo\")");
}
interface MetaUserRepository extends CrudRepository<User, String> {
@Meta
User findAllByLastname(String lastname);
@Meta(cursorBatchSize = 1234, comment = "foo", maxExecutionTimeMs = 555)
User findWithMetaAllByLastname(String lastname);
}
}

5
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/StubRepositoryInformation.java

@ -69,6 +69,11 @@ class StubRepositoryInformation implements RepositoryInformation { @@ -69,6 +69,11 @@ class StubRepositoryInformation implements RepositoryInformation {
return metadata.getReturnedDomainClass(method);
}
@Override
public TypeInformation<?> getReturnedDomainTypeInformation(Method method) {
return metadata.getReturnedDomainTypeInformation(method);
}
@Override
public CrudMethods getCrudMethods() {
return metadata.getCrudMethods();

6
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/TestMongoAotRepositoryContext.java

@ -30,6 +30,7 @@ import org.springframework.core.io.ClassPathResource; @@ -30,6 +30,7 @@ import org.springframework.core.io.ClassPathResource;
import org.springframework.core.test.tools.ClassFile;
import org.springframework.data.mongodb.core.mapping.Document;
import org.springframework.data.repository.config.AotRepositoryContext;
import org.springframework.data.repository.config.RepositoryConfigurationSource;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.RepositoryComposition;
@ -70,6 +71,11 @@ class TestMongoAotRepositoryContext implements AotRepositoryContext { @@ -70,6 +71,11 @@ class TestMongoAotRepositoryContext implements AotRepositoryContext {
return "MongoDB";
}
@Override
public RepositoryConfigurationSource getConfigurationSource() {
return null;
}
@Override
public Set<String> getBasePackages() {
return Set.of("org.springframework.data.dummy.repository.aot");

7
src/main/antora/modules/ROOT/pages/mongodb/aot.adoc

@ -66,16 +66,15 @@ These are typically all query methods that are not backed by an xref:repositorie @@ -66,16 +66,15 @@ These are typically all query methods that are not backed by an xref:repositorie
* Query methods annotated with `@Query` (excluding those containing SpEL)
* Methods annotated with `@Aggregation`
* Methods using `@Update`
* `@Hint` & `@ReadPreference` support
* `@Hint`, `@Meta`, and `@ReadPreference` support
* `Page`, `Slice`, and `Optional` return types
* DTO Projections
**Limitations**
* `@Meta` annotations are not evaluated.
* `@Meta.allowDiskUse` and `flags` are not evaluated.
* Queries / Aggregations / Updates containing `SpEL` cannot be generated.
* Limited `Collation` detection.
* Reserved parameter names (must not be used in method signature) `finder`, `filterQuery`, `countQuery`, `deleteQuery`, `remover` `updateDefinition`, `aggregation`, `aggregationPipeline`, `aggregationUpdate`, `aggregationOptions`, `updater`, `results`, `fields`.
**Excluded methods**
@ -83,6 +82,4 @@ These are typically all query methods that are not backed by an xref:repositorie @@ -83,6 +82,4 @@ These are typically all query methods that are not backed by an xref:repositorie
* Querydsl and Query by Example methods
* Methods whose implementation would be overly complex
* Query Methods obtaining MQL from a file
** Methods accepting `ScrollPosition` (e.g. `Keyset` pagination)
** Dynamic projections
** Geospatial Queries

Loading…
Cancel
Save