diff --git a/src/main/java/org/springframework/data/repository/query/ResultProcessor.java b/src/main/java/org/springframework/data/repository/query/ResultProcessor.java index 492e070ad..c37452168 100644 --- a/src/main/java/org/springframework/data/repository/query/ResultProcessor.java +++ b/src/main/java/org/springframework/data/repository/query/ResultProcessor.java @@ -23,11 +23,16 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Stream; import org.springframework.core.CollectionFactory; +import org.springframework.core.convert.ConversionService; import org.springframework.core.convert.converter.Converter; +import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.data.domain.Slice; import org.springframework.data.projection.ProjectionFactory; +import org.springframework.data.util.ReflectionUtils; import org.springframework.util.Assert; /** @@ -127,7 +132,8 @@ public class ResultProcessor { Assert.notNull(preparingConverter, "Preparing converter must not be null!"); - ChainingConverter converter = ChainingConverter.of(type.getReturnedType(), preparingConverter).and(this.converter); + final ChainingConverter converter = ChainingConverter.of(type.getReturnedType(), preparingConverter) + .and(this.converter); if (source instanceof Slice && method.isPageQuery() || method.isSliceQuery()) { return (T) ((Slice) source).map(converter); @@ -145,6 +151,17 @@ public class ResultProcessor { return (T) target; } + if (ReflectionUtils.isJava8StreamType(source.getClass()) && method.isStreamQuery()) { + + return (T) ((Stream) source).map(new Function() { + + @Override + public T apply(Object t) { + return (T) (type.isInstance(t) ? t : converter.convert(t)); + } + }); + } + return (T) converter.convert(source); } @@ -211,6 +228,7 @@ public class ResultProcessor { private final @NonNull ReturnedType type; private final @NonNull ProjectionFactory factory; + private final ConversionService conversionService = new DefaultConversionService(); /* * (non-Javadoc) @@ -218,7 +236,14 @@ public class ResultProcessor { */ @Override public Object convert(Object source) { - return factory.createProjection(type.getReturnedType(), getProjectionTarget(source)); + + Class targetType = type.getReturnedType(); + + if (targetType.isInterface()) { + return factory.createProjection(targetType, getProjectionTarget(source)); + } + + return conversionService.convert(source, targetType); } private Object getProjectionTarget(Object source) { diff --git a/src/test/java/org/springframework/data/repository/query/ResultProcessorUnitTests.java b/src/test/java/org/springframework/data/repository/query/ResultProcessorUnitTests.java index 905289593..d832e9a5d 100644 --- a/src/test/java/org/springframework/data/repository/query/ResultProcessorUnitTests.java +++ b/src/test/java/org/springframework/data/repository/query/ResultProcessorUnitTests.java @@ -25,6 +25,8 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.junit.Test; import org.springframework.beans.factory.annotation.Value; @@ -183,11 +185,11 @@ public class ResultProcessorUnitTests { @Override public Object convert(Object source) { - return new SampleDTO(); + return new SampleDto(); } }); - assertThat(result, is(instanceOf(SampleDTO.class))); + assertThat(result, is(instanceOf(SampleDto.class))); } /** @@ -215,6 +217,34 @@ public class ResultProcessorUnitTests { assertThat(content.get(0), is(instanceOf(SampleProjection.class))); } + /** + * @see DATACMNS-859 + */ + @Test + public void supportsStreamAsReturnWrapper() throws Exception { + + Stream samples = Arrays.asList(new Sample("Dave", "Matthews")).stream(); + + Object result = getProcessor("findStreamProjection").processResult(samples); + + assertThat(result, is(instanceOf(Stream.class))); + List content = ((Stream) result).collect(Collectors.toList()); + + assertThat(content, is(not(empty()))); + assertThat(content.get(0), is(instanceOf(SampleProjection.class))); + } + + /** + * @see DATACMNS-860 + */ + @Test + public void supportsWrappingDto() throws Exception { + + Object result = getProcessor("findOneWrappingDto").processResult(new Sample("Dave", "Matthews")); + + assertThat(result, is(instanceOf(WrappingDto.class))); + } + private static ResultProcessor getProcessor(String methodName, Class... parameters) throws Exception { return getQueryMethod(methodName, parameters).getResultProcessor(); } @@ -230,13 +260,15 @@ public class ResultProcessorUnitTests { List findAll(); - List findAllDtos(); + List findAllDtos(); List findAllProjection(); Sample findOne(); - SampleDTO findOneDto(); + SampleDto findOneDto(); + + WrappingDto findOneWrappingDto(); SampleProjection findOneProjection(); @@ -247,6 +279,8 @@ public class ResultProcessorUnitTests { Slice findSliceProjection(Pageable pageable); T findOneDynamic(Class type); + + Stream findStreamProjection(); } static class Sample { @@ -258,10 +292,15 @@ public class ResultProcessorUnitTests { } } - static class SampleDTO {} + static class SampleDto {} - interface SampleProjection { + @lombok.Value + // Needs to be public until https://jira.spring.io/browse/SPR-14304 is resolved + public static class WrappingDto { + Sample sample; + } + interface SampleProjection { String getLastname(); }