Browse Source

Add Stream support to JdbcAggregateOperations

See #1714
Original pull request #1963

Signed-off-by: Sergey Korotaev <sergey.evgen.kor2501@gmail.com>
pull/1955/head
Sergey Korotaev 12 months ago committed by Jens Schauder
parent
commit
ea296429df
No known key found for this signature in database
GPG Key ID: 74F6C554AE971567
  1. 44
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateOperations.java
  2. 34
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java
  3. 22
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java
  4. 48
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java
  5. 34
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java
  6. 22
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java
  7. 44
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ReadingDataAccessStrategy.java
  8. 22
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java
  9. 38
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java
  10. 46
      spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java
  11. 123
      spring-data-jdbc/src/test/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategyUnitTests.java

44
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateOperations.java

@ -17,6 +17,7 @@ package org.springframework.data.jdbc.core; @@ -17,6 +17,7 @@ package org.springframework.data.jdbc.core;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException;
import org.springframework.data.domain.Example;
@ -35,6 +36,7 @@ import org.springframework.lang.Nullable; @@ -35,6 +36,7 @@ import org.springframework.lang.Nullable;
* @author Chirag Tailor
* @author Diego Krupitza
* @author Myeonghyeon Lee
* @author Sergey Korotaev
*/
public interface JdbcAggregateOperations {
@ -165,6 +167,17 @@ public interface JdbcAggregateOperations { @@ -165,6 +167,17 @@ public interface JdbcAggregateOperations {
*/
<T> List<T> findAllById(Iterable<?> ids, Class<T> domainType);
/**
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
*
* @param ids the Ids of the entities to load. Must not be {@code null}.
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> type of entities to load.
* @return the loaded entities. Guaranteed to be not {@code null}.
*/
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);
/**
* Load all aggregates of a given type.
*
@ -174,6 +187,15 @@ public interface JdbcAggregateOperations { @@ -174,6 +187,15 @@ public interface JdbcAggregateOperations {
*/
<T> List<T> findAll(Class<T> domainType);
/**
* Load all aggregates of a given type to a {@link Stream}.
*
* @param domainType the type of the aggregate roots. Must not be {@code null}.
* @param <T> the type of the aggregate roots. Must not be {@code null}.
* @return Guaranteed to be not {@code null}.
*/
<T> Stream<T> streamAll(Class<T> domainType);
/**
* Load all aggregates of a given type, sorted.
*
@ -185,6 +207,17 @@ public interface JdbcAggregateOperations { @@ -185,6 +207,17 @@ public interface JdbcAggregateOperations {
*/
<T> List<T> findAll(Class<T> domainType, Sort sort);
/**
* Loads all entities of the given type to a {@link Stream}, sorted.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @param sort the sorting information. Must not be {@code null}.
* @return Guaranteed to be not {@code null}.
* @since 2.0
*/
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);
/**
* Load a page of (potentially sorted) aggregates of a given type.
*
@ -218,6 +251,17 @@ public interface JdbcAggregateOperations { @@ -218,6 +251,17 @@ public interface JdbcAggregateOperations {
*/
<T> List<T> findAll(Query query, Class<T> domainType);
/**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
*
* @param query must not be {@literal null}.
* @param domainType the type of entities. Must not be {@code null}.
* @return a non-null list with all the matching results.
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
* @since 3.0
*/
<T> Stream<T> streamAll(Query query, Class<T> domainType);
/**
* Returns a {@link Page} of entities matching the given {@link Query}. In case no match could be found, an empty
* {@link Page} is returned.

34
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java

@ -25,6 +25,7 @@ import java.util.Map; @@ -25,6 +25,7 @@ import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.springframework.context.ApplicationContext;
@ -68,6 +69,7 @@ import org.springframework.util.ClassUtils; @@ -68,6 +69,7 @@ import org.springframework.util.ClassUtils;
* @author Myeonghyeon Lee
* @author Chirag Tailor
* @author Diego Krupitza
* @author Sergey Korotaev
*/
public class JdbcAggregateTemplate implements JdbcAggregateOperations {
@ -283,6 +285,16 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations { @@ -283,6 +285,16 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
return triggerAfterConvert(all);
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
Assert.notNull(domainType, "Domain type must not be null");
Stream<T> allStreamable = accessStrategy.streamAll(domainType, sort);
return allStreamable.map(this::triggerAfterConvert);
}
@Override
public <T> Page<T> findAll(Class<T> domainType, Pageable pageable) {
@ -307,6 +319,11 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations { @@ -307,6 +319,11 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
return triggerAfterConvert(all);
}
@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
return accessStrategy.streamAll(query, domainType).map(this::triggerAfterConvert);
}
@Override
public <T> Page<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
@ -325,6 +342,12 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations { @@ -325,6 +342,12 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
return triggerAfterConvert(all);
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
Iterable<T> items = triggerAfterConvert(accessStrategy.findAll(domainType));
return StreamSupport.stream(items.spliterator(), false).map(this::triggerAfterConvert);
}
@Override
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
@ -335,6 +358,17 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations { @@ -335,6 +358,17 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
return triggerAfterConvert(allById);
}
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
Assert.notNull(ids, "Ids must not be null");
Assert.notNull(domainType, "Domain type must not be null");
Stream<T> allByIdStreamable = accessStrategy.streamAllByIds(ids, domainType);
return allByIdStreamable.map(this::triggerAfterConvert);
}
@Override
public <S> void delete(S aggregateRoot) {

22
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java

@ -22,6 +22,7 @@ import java.util.List; @@ -22,6 +22,7 @@ import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
@ -42,6 +43,7 @@ import org.springframework.data.relational.core.sql.LockMode; @@ -42,6 +43,7 @@ import org.springframework.data.relational.core.sql.LockMode;
* @author Myeonghyeon Lee
* @author Chirag Tailor
* @author Diego Krupitza
* @author Sergey Korotaev
* @since 1.1
*/
public class CascadingDataAccessStrategy implements DataAccessStrategy {
@ -132,11 +134,21 @@ public class CascadingDataAccessStrategy implements DataAccessStrategy { @@ -132,11 +134,21 @@ public class CascadingDataAccessStrategy implements DataAccessStrategy {
return collect(das -> das.findAll(domainType));
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
return collect(das -> das.streamAll(domainType));
}
@Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return collect(das -> das.findAllById(ids, domainType));
}
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
return collect(das -> das.streamAllByIds(ids, domainType));
}
@Override
public Iterable<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
@ -153,6 +165,11 @@ public class CascadingDataAccessStrategy implements DataAccessStrategy { @@ -153,6 +165,11 @@ public class CascadingDataAccessStrategy implements DataAccessStrategy {
return collect(das -> das.findAll(domainType, sort));
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
return collect(das -> das.streamAll(domainType, sort));
}
@Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
return collect(das -> das.findAll(domainType, pageable));
@ -168,6 +185,11 @@ public class CascadingDataAccessStrategy implements DataAccessStrategy { @@ -168,6 +185,11 @@ public class CascadingDataAccessStrategy implements DataAccessStrategy {
return collect(das -> das.findAll(query, domainType));
}
@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
return collect(das -> das.streamAll(query, domainType));
}
@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
return collect(das -> das.findAll(query, domainType, pageable));

48
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java

@ -18,6 +18,7 @@ package org.springframework.data.jdbc.core.convert; @@ -18,6 +18,7 @@ package org.springframework.data.jdbc.core.convert;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.dao.OptimisticLockingFailureException;
import org.springframework.data.domain.Pageable;
@ -41,6 +42,7 @@ import org.springframework.lang.Nullable; @@ -41,6 +42,7 @@ import org.springframework.lang.Nullable;
* @author Myeonghyeon Lee
* @author Chirag Tailor
* @author Diego Krupitza
* @author Sergey Korotaev
*/
public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationResolver {
@ -252,6 +254,16 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR @@ -252,6 +254,16 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
@Override
<T> Iterable<T> findAll(Class<T> domainType);
/**
* Loads all entities of the given type to a {@link Stream}.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @return Guaranteed to be not {@code null}.
*/
@Override
<T> Stream<T> streamAll(Class<T> domainType);
/**
* Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids
* passed in matches the number of entities returned.
@ -264,6 +276,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR @@ -264,6 +276,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
@Override
<T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType);
/**
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
*
* @param ids the Ids of the entities to load. Must not be {@code null}.
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> type of entities to load.
* @return the loaded entities. Guaranteed to be not {@code null}.
*/
@Override
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);
@Override
Iterable<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path);
@ -280,6 +304,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR @@ -280,6 +304,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
@Override
<T> Iterable<T> findAll(Class<T> domainType, Sort sort);
/**
* Loads all entities of the given type to a {@link Stream}, sorted.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @param sort the sorting information. Must not be {@code null}.
* @return Guaranteed to be not {@code null}.
* @since 2.0
*/
@Override
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);
/**
* Loads all entities of the given type, paged and sorted.
*
@ -316,6 +352,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR @@ -316,6 +352,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
@Override
<T> Iterable<T> findAll(Query query, Class<T> domainType);
/**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
*
* @param query must not be {@literal null}.
* @param domainType the type of entities. Must not be {@code null}.
* @return a non-null list with all the matching results.
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
* @since 3.0
*/
@Override
<T> Stream<T> streamAll(Query query, Class<T> domainType);
/**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable}
* to the result.

34
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

@ -22,6 +22,7 @@ import java.sql.SQLException; @@ -22,6 +22,7 @@ import java.sql.SQLException;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.dao.OptimisticLockingFailureException;
@ -60,6 +61,7 @@ import org.springframework.util.Assert; @@ -60,6 +61,7 @@ import org.springframework.util.Assert;
* @author Radim Tlusty
* @author Chirag Tailor
* @author Diego Krupitza
* @author Sergey Korotaev
* @since 1.1
*/
public class DefaultDataAccessStrategy implements DataAccessStrategy {
@ -276,6 +278,11 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy { @@ -276,6 +278,11 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType));
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
return operations.queryForStream(sql(domainType).getFindAll(), new MapSqlParameterSource(), getEntityRowMapper(domainType));
}
@Override
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
@ -288,6 +295,19 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy { @@ -288,6 +295,19 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
return operations.query(findAllInListSql, parameterSource, getEntityRowMapper(domainType));
}
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
if (!ids.iterator().hasNext()) {
return Stream.empty();
}
SqlParameterSource parameterSource = sqlParametersFactory.forQueryByIds(ids, domainType);
String findAllInListSql = sql(domainType).getFindAllInList();
return operations.queryForStream(findAllInListSql, parameterSource, getEntityRowMapper(domainType));
}
@Override
@SuppressWarnings("unchecked")
public List<Object> findAllByPath(Identifier identifier,
@ -342,6 +362,11 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy { @@ -342,6 +362,11 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType));
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
return operations.queryForStream(sql(domainType).getFindAll(sort), new MapSqlParameterSource(), getEntityRowMapper(domainType));
}
@Override
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType));
@ -369,6 +394,15 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy { @@ -369,6 +394,15 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
return operations.query(sqlQuery, parameterSource, getEntityRowMapper(domainType));
}
@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource);
return operations.queryForStream(sqlQuery, parameterSource, getEntityRowMapper(domainType));
}
@Override
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {

22
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java

@ -17,6 +17,7 @@ package org.springframework.data.jdbc.core.convert; @@ -17,6 +17,7 @@ package org.springframework.data.jdbc.core.convert;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
@ -37,6 +38,7 @@ import org.springframework.util.Assert; @@ -37,6 +38,7 @@ import org.springframework.util.Assert;
* @author Myeonghyeon Lee
* @author Chirag Tailor
* @author Diego Krupitza
* @author Sergey Korotaev
* @since 1.1
*/
public class DelegatingDataAccessStrategy implements DataAccessStrategy {
@ -135,11 +137,21 @@ public class DelegatingDataAccessStrategy implements DataAccessStrategy { @@ -135,11 +137,21 @@ public class DelegatingDataAccessStrategy implements DataAccessStrategy {
return delegate.findAll(domainType);
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
return delegate.streamAll(domainType);
}
@Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return delegate.findAllById(ids, domainType);
}
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
return delegate.streamAllByIds(ids, domainType);
}
@Override
public Iterable<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
@ -156,6 +168,11 @@ public class DelegatingDataAccessStrategy implements DataAccessStrategy { @@ -156,6 +168,11 @@ public class DelegatingDataAccessStrategy implements DataAccessStrategy {
return delegate.findAll(domainType, sort);
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
return delegate.streamAll(domainType, sort);
}
@Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
return delegate.findAll(domainType, pageable);
@ -171,6 +188,11 @@ public class DelegatingDataAccessStrategy implements DataAccessStrategy { @@ -171,6 +188,11 @@ public class DelegatingDataAccessStrategy implements DataAccessStrategy {
return delegate.findAll(query, domainType);
}
@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
return delegate.streamAll(query, domainType);
}
@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
return delegate.findAll(query, domainType, pageable);

44
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ReadingDataAccessStrategy.java

@ -17,6 +17,7 @@ @@ -17,6 +17,7 @@
package org.springframework.data.jdbc.core.convert;
import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
@ -27,6 +28,7 @@ import org.springframework.lang.Nullable; @@ -27,6 +28,7 @@ import org.springframework.lang.Nullable;
* The finding methods of a {@link DataAccessStrategy}.
*
* @author Jens Schauder
* @author Sergey Korotaev
* @since 3.2
*/
interface ReadingDataAccessStrategy {
@ -51,6 +53,15 @@ interface ReadingDataAccessStrategy { @@ -51,6 +53,15 @@ interface ReadingDataAccessStrategy {
*/
<T> Iterable<T> findAll(Class<T> domainType);
/**
* Loads all entities of the given type to a {@link Stream}.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @return Guaranteed to be not {@code null}.
*/
<T> Stream<T> streamAll(Class<T> domainType);
/**
* Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids
* passed in matches the number of entities returned.
@ -62,6 +73,17 @@ interface ReadingDataAccessStrategy { @@ -62,6 +73,17 @@ interface ReadingDataAccessStrategy {
*/
<T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType);
/**
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
*
* @param ids the Ids of the entities to load. Must not be {@code null}.
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> type of entities to load.
* @return the loaded entities. Guaranteed to be not {@code null}.
*/
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);
/**
* Loads all entities of the given type, sorted.
*
@ -73,6 +95,17 @@ interface ReadingDataAccessStrategy { @@ -73,6 +95,17 @@ interface ReadingDataAccessStrategy {
*/
<T> Iterable<T> findAll(Class<T> domainType, Sort sort);
/**
* Loads all entities of the given type to a {@link Stream}, sorted.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @param sort the sorting information. Must not be {@code null}.
* @return Guaranteed to be not {@code null}.
* @since 2.0
*/
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);
/**
* Loads all entities of the given type, paged and sorted.
*
@ -106,6 +139,17 @@ interface ReadingDataAccessStrategy { @@ -106,6 +139,17 @@ interface ReadingDataAccessStrategy {
*/
<T> Iterable<T> findAll(Query query, Class<T> domainType);
/**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
*
* @param query must not be {@literal null}.
* @param domainType the type of entities. Must not be {@code null}.
* @return a non-null list with all the matching results.
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
* @since 3.0
*/
<T> Stream<T> streamAll(Query query, Class<T> domainType);
/**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable}
* to the result.

22
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java

@ -18,6 +18,7 @@ package org.springframework.data.jdbc.core.convert; @@ -18,6 +18,7 @@ package org.springframework.data.jdbc.core.convert;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
@ -32,6 +33,7 @@ import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; @@ -32,6 +33,7 @@ import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
*
* @author Jens Schauder
* @author Mark Paluch
* @author Sergey Korotaev
* @since 3.2
*/
class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy {
@ -56,16 +58,31 @@ class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy { @@ -56,16 +58,31 @@ class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy {
return aggregateReader.findAll(getPersistentEntity(domainType));
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
throw new UnsupportedOperationException();
}
@Override
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return aggregateReader.findAllById(ids, getPersistentEntity(domainType));
}
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
throw new UnsupportedOperationException();
}
@Override
public <T> List<T> findAll(Class<T> domainType, Sort sort) {
throw new UnsupportedOperationException();
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
throw new UnsupportedOperationException();
}
@Override
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
throw new UnsupportedOperationException();
@ -81,6 +98,11 @@ class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy { @@ -81,6 +98,11 @@ class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy {
return aggregateReader.findAll(query, getPersistentEntity(domainType));
}
@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
throw new UnsupportedOperationException();
}
@Override
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
throw new UnsupportedOperationException();

38
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java

@ -22,7 +22,10 @@ import java.util.HashMap; @@ -22,7 +22,10 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.ibatis.cursor.Cursor;
import org.apache.ibatis.session.SqlSession;
import org.mybatis.spring.SqlSessionTemplate;
import org.springframework.dao.EmptyResultDataAccessException;
@ -59,6 +62,7 @@ import org.springframework.util.Assert; @@ -59,6 +62,7 @@ import org.springframework.util.Assert;
* @author Chirag Tailor
* @author Christopher Klein
* @author Mikhail Polivakha
* @author Sergey Korotaev
*/
public class MyBatisDataAccessStrategy implements DataAccessStrategy {
@ -263,12 +267,28 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy { @@ -263,12 +267,28 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy {
return sqlSession().selectList(statement, parameter);
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
String statement = namespace(domainType) + ".streamAll";
MyBatisContext parameter = new MyBatisContext(null, null, domainType, Collections.emptyMap());
Cursor<T> cursor = sqlSession().selectCursor(statement, parameter);
return StreamSupport.stream(cursor.spliterator(), false);
}
@Override
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return sqlSession().selectList(namespace(domainType) + ".findAllById",
new MyBatisContext(ids, null, domainType, Collections.emptyMap()));
}
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
String statement = namespace(domainType) + ".streamAllByIds";
MyBatisContext parameter = new MyBatisContext(ids, null, domainType, Collections.emptyMap());
Cursor<T> cursor = sqlSession().selectCursor(statement, parameter);
return StreamSupport.stream(cursor.spliterator(), false);
}
@Override
public List<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
@ -296,6 +316,19 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy { @@ -296,6 +316,19 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy {
new MyBatisContext(null, null, domainType, additionalContext));
}
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
Map<String, Object> additionalContext = new HashMap<>();
additionalContext.put("sort", sort);
String statement = namespace(domainType) + ".streamAllSorted";
MyBatisContext parameter = new MyBatisContext(null, null, domainType, additionalContext);
Cursor<T> cursor = sqlSession().selectCursor(statement, parameter);
return StreamSupport.stream(cursor.spliterator(), false);
}
@Override
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
@ -315,6 +348,11 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy { @@ -315,6 +348,11 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public <T> Stream<T> streamAll(Query query, Class<T> probeType) {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public <T> List<T> findAll(Query query, Class<T> probeType, Pageable pageable) {
throw new UnsupportedOperationException("Not implemented");

46
spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java

@ -27,8 +27,8 @@ import java.util.*; @@ -27,8 +27,8 @@ import java.util.*;
import java.util.ArrayList;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.assertj.core.api.SoftAssertions;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEventPublisher;
@ -81,6 +81,7 @@ import org.springframework.test.context.ContextConfiguration; @@ -81,6 +81,7 @@ import org.springframework.test.context.ContextConfiguration;
* @author Mikhail Polivakha
* @author Chirag Tailor
* @author Vincent Galloy
* @author Sergey Korotaev
*/
@IntegrationTest
abstract class AbstractJdbcAggregateTemplateIntegrationTests {
@ -309,6 +310,18 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { @@ -309,6 +310,18 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests {
.containsExactly(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
}
@Test // GH-1714
void saveAndLoadManeEntitiesWithReferenceEntityLikeStream() {
template.save(legoSet);
Stream<LegoSet> streamable = template.streamAll(LegoSet.class);
assertThat(streamable)
.extracting("id", "manual.id", "manual.content") //
.containsExactly(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
}
@Test // DATAJDBC-101
void saveAndLoadManyEntitiesWithReferencedEntitySorted() {
@ -323,6 +336,20 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { @@ -323,6 +336,20 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests {
.containsExactly("Frozen", "Lava", "Star");
}
@Test // GH-1714
void saveAndLoadManyEntitiesWithReferencedEntitySortedLikeStream() {
template.save(createLegoSet("Lava"));
template.save(createLegoSet("Star"));
template.save(createLegoSet("Frozen"));
Stream<LegoSet> reloadedLegoSets = template.streamAll(LegoSet.class, Sort.by("name"));
assertThat(reloadedLegoSets) //
.extracting("name") //
.containsExactly("Frozen", "Lava", "Star");
}
@Test // DATAJDBC-101
void saveAndLoadManyEntitiesWithReferencedEntitySortedAndPaged() {
@ -360,6 +387,12 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { @@ -360,6 +387,12 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests {
.isInstanceOf(InvalidPersistentPropertyPath.class);
}
@Test // GH-1714
void findByNonPropertySortLikeStreamFails() {
assertThatThrownBy(() -> template.streamAll(LegoSet.class, Sort.by("somethingNotExistant")))
.isInstanceOf(InvalidPersistentPropertyPath.class);
}
@Test // DATAJDBC-112
void saveAndLoadManyEntitiesByIdWithReferencedEntity() {
@ -371,6 +404,17 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { @@ -371,6 +404,17 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests {
.contains(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
}
@Test // GH-1714
void saveAndLoadManyEntitiesByIdWithReferencedEntityLikeStream() {
template.save(legoSet);
Stream<LegoSet> reloadedLegoSets = template.streamAllByIds(singletonList(legoSet.id), LegoSet.class);
assertThat(reloadedLegoSets).hasSize(1).extracting("id", "manual.id", "manual.content")
.contains(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
}
@Test // DATAJDBC-112
void saveAndLoadAnEntityWithReferencedNullEntity() {

123
spring-data-jdbc/src/test/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategyUnitTests.java

@ -22,7 +22,12 @@ import static org.mockito.ArgumentMatchers.*; @@ -22,7 +22,12 @@ import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
import static org.springframework.data.relational.core.sql.SqlIdentifier.*;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;
import org.apache.ibatis.cursor.Cursor;
import org.apache.ibatis.session.SqlSession;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
@ -43,6 +48,7 @@ import org.springframework.data.relational.core.mapping.RelationalPersistentProp @@ -43,6 +48,7 @@ import org.springframework.data.relational.core.mapping.RelationalPersistentProp
* @author Mark Paluch
* @author Tyler Van Gorder
* @author Chirag Tailor
* @author Sergey Korotaev
*/
public class MyBatisDataAccessStrategyUnitTests {
@ -241,6 +247,36 @@ public class MyBatisDataAccessStrategyUnitTests { @@ -241,6 +247,36 @@ public class MyBatisDataAccessStrategyUnitTests {
);
}
@Test
public void streamAll() {
String value = "some answer";
Cursor<String> cursor = getCursor(value);
when(session.selectCursor(anyString(), any())).then(answer -> cursor);
Stream<String> streamable = accessStrategy.streamAll(String.class);
verify(session).selectCursor(eq("java.lang.StringMapper.streamAll"), captor.capture());
assertThat(streamable).isNotNull().containsExactly(value);
assertThat(captor.getValue()) //
.isNotNull() //
.extracting( //
MyBatisContext::getInstance, //
MyBatisContext::getId, //
MyBatisContext::getDomainType, //
c -> c.get("key") //
).containsExactly( //
null, //
null, //
String.class, //
null //
);
}
@Test // DATAJDBC-123
public void findAllById() {
@ -263,6 +299,33 @@ public class MyBatisDataAccessStrategyUnitTests { @@ -263,6 +299,33 @@ public class MyBatisDataAccessStrategyUnitTests {
);
}
@Test
public void streamAllByIds() {
String value = "some answer 2";
Cursor<String> cursor = getCursor(value);
when(session.selectCursor(anyString(), any())).then(answer -> cursor);
accessStrategy.streamAllByIds(asList("id1", "id2"), String.class);
verify(session).selectCursor(eq("java.lang.StringMapper.streamAllByIds"), captor.capture());
assertThat(captor.getValue()) //
.isNotNull() //
.extracting( //
MyBatisContext::getInstance, //
MyBatisContext::getId, //
MyBatisContext::getDomainType, //
c -> c.get("key") //
).containsExactly( //
null, //
asList("id1", "id2"), //
String.class, //
null //
);
}
@SuppressWarnings("unchecked")
@Test // DATAJDBC-384
public void findAllByPath() {
@ -367,6 +430,33 @@ public class MyBatisDataAccessStrategyUnitTests { @@ -367,6 +430,33 @@ public class MyBatisDataAccessStrategyUnitTests {
);
}
@Test
public void streamAllSorted() {
String value = "some answer 3";
Cursor<String> cursor = getCursor(value);
when(session.selectCursor(anyString(), any())).then(answer -> cursor);
accessStrategy.streamAll(String.class, Sort.by("length"));
verify(session).selectCursor(eq("java.lang.StringMapper.streamAllSorted"), captor.capture());
assertThat(captor.getValue()) //
.isNotNull() //
.extracting( //
MyBatisContext::getInstance, //
MyBatisContext::getId, //
MyBatisContext::getDomainType, //
c -> c.get("sort") //
).containsExactly( //
null, //
null, //
String.class, //
Sort.by("length") //
);
}
@Test // DATAJDBC-101
public void findAllPaged() {
@ -399,5 +489,36 @@ public class MyBatisDataAccessStrategyUnitTests { @@ -399,5 +489,36 @@ public class MyBatisDataAccessStrategyUnitTests {
ChildTwo two;
}
private static class ChildTwo {}
private static class ChildTwo {
}
private Cursor<String> getCursor(String value) {
return new Cursor<>() {
@Override
public boolean isOpen() {
return false;
}
@Override
public boolean isConsumed() {
return false;
}
@Override
public int getCurrentIndex() {
return 0;
}
@Override
public void close() {
}
@NotNull
@Override
public Iterator<String> iterator() {
return List.of(value).iterator();
}
};
}
}

Loading…
Cancel
Save