diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java index 8e4f59a3e..8dcd1b785 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java @@ -32,6 +32,8 @@ import java.util.stream.Stream; import org.jspecify.annotations.Nullable; import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.data.domain.SliceImpl; import org.springframework.data.domain.Sort; import org.springframework.data.javapoet.LordOfTheStrings; @@ -812,11 +814,22 @@ class JdbcCodeBlocks { return builder.build(); } - if (methodReturn.toClass().equals(Streamable.class)) { + if (isStreamable(methodReturn)) { builder.addStatement( "return ($1T) $1T.of(($2T) convertMany($3L, %s))".formatted(dynamicProjection ? "$4L" : "$4T.class"), Streamable.class, Iterable.class, result, queryResultTypeRef); + } else if (isStreamableWrapper(methodReturn) && canConvert(Streamable.class, methodReturn)) { + + builder.addStatement( + "$1T $2L = ($1T) convertMany($3L, %s)".formatted(dynamicProjection ? "$4L" : "$4T.class"), Iterable.class, + context.localVariable("converted"), result, queryResultTypeRef); + + builder.addStatement( + "return ($1T) $2T.getSharedInstance().convert($3T.of($4L), $5T.valueOf($3T.class), $5T.valueOf($1T.class))", + methodReturn.toClass(), DefaultConversionService.class, Streamable.class, + context.localVariable("converted"), TypeDescriptor.class); } else { + builder.addStatement("return ($T) convertMany($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"), methodReturn.getTypeName(), result, queryResultTypeRef); } @@ -843,6 +856,18 @@ class JdbcCodeBlocks { return builder.build(); } + private boolean canConvert(Class from, MethodReturn methodReturn) { + return DefaultConversionService.getSharedInstance().canConvert(from, methodReturn.toClass()); + } + + private static boolean isStreamable(MethodReturn methodReturn) { + return methodReturn.toClass().equals(Streamable.class); + } + + private static boolean isStreamableWrapper(MethodReturn methodReturn) { + return !isStreamable(methodReturn) && Streamable.class.isAssignableFrom(methodReturn.toClass()); + } + public static boolean returnsModifying(Class returnType) { return ClassUtils.resolvePrimitiveIfNecessary(returnType) == Integer.class diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java index 7dd18feb8..1efa4337b 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryIntegrationTests.java @@ -30,6 +30,7 @@ import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -603,6 +604,16 @@ public class JdbcRepositoryIntegrationTests { assertThat(slice.hasNext()).isTrue(); } + @Test // GH-2175 + public void streamableWrapperByNameShouldReturnCorrectResult() { + + repository.saveAll(Arrays.asList(new DummyEntity("a1"), new DummyEntity("a2"), new DummyEntity("a3"))); + + DummyEntities entities = repository.findEntitiesByNameContains("a", Limit.of(5)); + + assertThat(entities).hasSize(3); + } + @Test // GH-935 public void queryByOffsetDateTime() { @@ -1585,6 +1596,8 @@ public class JdbcRepositoryIntegrationTests { Slice findSliceByNameContains(String name, Pageable pageable); + DummyEntities findEntitiesByNameContains(String name, Limit limit); + @Query("SELECT * FROM DUMMY_ENTITY WHERE OFFSET_DATE_TIME > :threshhold") List findByOffsetDateTime(@Param("threshhold") OffsetDateTime threshhold); @@ -1620,6 +1633,7 @@ public class JdbcRepositoryIntegrationTests { @Query("SELECT * FROM DUMMY_ENTITY WHERE BYTES = :bytes") List findByBytes(byte[] bytes); + } public interface RootRepository extends ListCrudRepository { @@ -2002,6 +2016,20 @@ public class JdbcRepositoryIntegrationTests { } } + public static class DummyEntities implements Streamable { + + private final Streamable delegate; + + public DummyEntities(Streamable delegate) { + this.delegate = delegate; + } + + @Override + public Iterator iterator() { + return delegate.iterator(); + } + } + public static class DummyEntity { @Id Long idProp;