From 1be3b9e8f4e15d1a7d490f09e5ec4b9c2474cda2 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Tue, 11 Nov 2025 10:30:17 +0100 Subject: [PATCH] Support custom Streamable return type in AOT repository. This commit uses a conversion service to convert custom streamable types. See: #5089 --- .../repository/aot/MongoCodeBlocks.java | 21 ++++++++++++++++--- ...tractPersonRepositoryIntegrationTests.java | 12 +++++++++++ .../mongodb/repository/PersonRepository.java | 16 ++++++++++++++ 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 841fe1c2f..432113833 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -20,6 +20,8 @@ import java.util.regex.Pattern; import org.bson.Document; import org.jspecify.annotations.Nullable; import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.data.mapping.model.SimpleTypeHolder; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.aot.AggregationBlocks.AggregationCodeBlockBuilder; @@ -38,6 +40,7 @@ import org.springframework.data.repository.aot.generate.MethodReturn; import org.springframework.data.util.Streamable; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.util.ClassUtils; import org.springframework.util.NumberUtils; import org.springframework.util.StringUtils; @@ -238,9 +241,21 @@ class MongoCodeBlocks { * {@link MethodReturn} indicates so. */ public static CodeBlock potentiallyWrapStreamable(MethodReturn methodReturn, CodeBlock returningIterable) { - return methodReturn.toClass().equals(Streamable.class) - ? CodeBlock.of("$T.of($L)", Streamable.class, returningIterable) - : returningIterable; + + Class returnType = methodReturn.toClass(); + + if (returnType.equals(Streamable.class)) { + return CodeBlock.of("$T.of($L)", Streamable.class, returningIterable); + } + + if (ClassUtils.isAssignable(Streamable.class, returnType)) { + + return CodeBlock.of( + "($1T) $2T.getSharedInstance().convert($3T.of($4L), $5T.valueOf($3T.class), $5T.valueOf($1T.class))", + returnType, DefaultConversionService.class, Streamable.class, returningIterable, TypeDescriptor.class); + } + + return returningIterable; } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java index c7eca37c3..9cd2127a1 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java @@ -73,6 +73,7 @@ import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Update; import org.springframework.data.mongodb.repository.Person.Sex; +import org.springframework.data.mongodb.repository.PersonRepository.Persons; import org.springframework.data.mongodb.repository.SampleEvaluationContextExtension.SampleSecurityContextHolder; import org.springframework.data.mongodb.test.util.DirtiesStateExtension; import org.springframework.data.mongodb.test.util.DirtiesStateExtension.DirtiesState; @@ -324,6 +325,17 @@ public abstract class AbstractPersonRepositoryIntegrationTests implements Dirtie assertThat(result).hasSize(1).contains(dave); } + @Test // GH-5089 + void useCustomReturnTypeImplementingStreamable() { + + Address address = new Address("Foo Street 1", "C0123", "Bar"); + dave.setAddress(address); + repository.save(dave); + + Persons result = repository.streamPersonsByAddress(address); + assertThat(result).hasSize(1).contains(dave); + } + @Test // GH-5089 void streamPersonByAddressCorrectlyWhenPaged() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java index 7718c1241..cf8e265ca 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java @@ -17,6 +17,7 @@ package org.springframework.data.mongodb.repository; import java.util.Collection; import java.util.Date; +import java.util.Iterator; import java.util.List; import java.util.Optional; import java.util.UUID; @@ -214,6 +215,8 @@ public interface PersonRepository extends MongoRepository, Query Streamable streamByAddress(Address address); + Persons streamPersonsByAddress(Address address); + Streamable streamByAddress(Address address, Pageable pageable); List findByAddressZipCode(String zipCode); @@ -502,4 +505,17 @@ public interface PersonRepository extends MongoRepository, Query List findBySpiritAnimal(User user); + class Persons implements Streamable { + + private final Streamable streamable; + + public Persons(Streamable streamable) { + this.streamable = streamable; + } + + @Override + public Iterator iterator() { + return streamable.iterator(); + } + } }