Browse Source

Add Stream support to JdbcAggregateOperations

See #1714
Original pull request #1963

Signed-off-by: Sergey Korotaev <sergey.evgen.kor2501@gmail.com>
pull/1955/head
Sergey Korotaev 12 months ago committed by Jens Schauder
parent
commit
ea296429df
No known key found for this signature in database
GPG Key ID: 74F6C554AE971567
  1. 44
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateOperations.java
  2. 34
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java
  3. 22
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java
  4. 48
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java
  5. 34
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java
  6. 22
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java
  7. 44
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ReadingDataAccessStrategy.java
  8. 22
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java
  9. 38
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java
  10. 46
      spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java
  11. 123
      spring-data-jdbc/src/test/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategyUnitTests.java

44
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateOperations.java

@ -17,6 +17,7 @@ package org.springframework.data.jdbc.core;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException; import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException;
import org.springframework.data.domain.Example; import org.springframework.data.domain.Example;
@ -35,6 +36,7 @@ import org.springframework.lang.Nullable;
* @author Chirag Tailor * @author Chirag Tailor
* @author Diego Krupitza * @author Diego Krupitza
* @author Myeonghyeon Lee * @author Myeonghyeon Lee
* @author Sergey Korotaev
*/ */
public interface JdbcAggregateOperations { public interface JdbcAggregateOperations {
@ -165,6 +167,17 @@ public interface JdbcAggregateOperations {
*/ */
<T> List<T> findAllById(Iterable<?> ids, Class<T> domainType); <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType);
/**
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
*
* @param ids the Ids of the entities to load. Must not be {@code null}.
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> type of entities to load.
* @return the loaded entities. Guaranteed to be not {@code null}.
*/
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);
/** /**
* Load all aggregates of a given type. * Load all aggregates of a given type.
* *
@ -174,6 +187,15 @@ public interface JdbcAggregateOperations {
*/ */
<T> List<T> findAll(Class<T> domainType); <T> List<T> findAll(Class<T> domainType);
/**
* Load all aggregates of a given type to a {@link Stream}.
*
* @param domainType the type of the aggregate roots. Must not be {@code null}.
* @param <T> the type of the aggregate roots. Must not be {@code null}.
* @return Guaranteed to be not {@code null}.
*/
<T> Stream<T> streamAll(Class<T> domainType);
/** /**
* Load all aggregates of a given type, sorted. * Load all aggregates of a given type, sorted.
* *
@ -185,6 +207,17 @@ public interface JdbcAggregateOperations {
*/ */
<T> List<T> findAll(Class<T> domainType, Sort sort); <T> List<T> findAll(Class<T> domainType, Sort sort);
/**
* Loads all entities of the given type to a {@link Stream}, sorted.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @param sort the sorting information. Must not be {@code null}.
* @return Guaranteed to be not {@code null}.
* @since 2.0
*/
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);
/** /**
* Load a page of (potentially sorted) aggregates of a given type. * Load a page of (potentially sorted) aggregates of a given type.
* *
@ -218,6 +251,17 @@ public interface JdbcAggregateOperations {
*/ */
<T> List<T> findAll(Query query, Class<T> domainType); <T> List<T> findAll(Query query, Class<T> domainType);
/**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
*
* @param query must not be {@literal null}.
* @param domainType the type of entities. Must not be {@code null}.
* @return a non-null list with all the matching results.
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
* @since 3.0
*/
<T> Stream<T> streamAll(Query query, Class<T> domainType);
/** /**
* Returns a {@link Page} of entities matching the given {@link Query}. In case no match could be found, an empty * Returns a {@link Page} of entities matching the given {@link Query}. In case no match could be found, an empty
* {@link Page} is returned. * {@link Page} is returned.

34
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java

@ -25,6 +25,7 @@ import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport; import java.util.stream.StreamSupport;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
@ -68,6 +69,7 @@ import org.springframework.util.ClassUtils;
* @author Myeonghyeon Lee * @author Myeonghyeon Lee
* @author Chirag Tailor * @author Chirag Tailor
* @author Diego Krupitza * @author Diego Krupitza
* @author Sergey Korotaev
*/ */
public class JdbcAggregateTemplate implements JdbcAggregateOperations { public class JdbcAggregateTemplate implements JdbcAggregateOperations {
@ -283,6 +285,16 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
return triggerAfterConvert(all); return triggerAfterConvert(all);
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
Assert.notNull(domainType, "Domain type must not be null");
Stream<T> allStreamable = accessStrategy.streamAll(domainType, sort);
return allStreamable.map(this::triggerAfterConvert);
}
@Override @Override
public <T> Page<T> findAll(Class<T> domainType, Pageable pageable) { public <T> Page<T> findAll(Class<T> domainType, Pageable pageable) {
@ -307,6 +319,11 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
return triggerAfterConvert(all); return triggerAfterConvert(all);
} }
@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
return accessStrategy.streamAll(query, domainType).map(this::triggerAfterConvert);
}
@Override @Override
public <T> Page<T> findAll(Query query, Class<T> domainType, Pageable pageable) { public <T> Page<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
@ -325,6 +342,12 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
return triggerAfterConvert(all); return triggerAfterConvert(all);
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
Iterable<T> items = triggerAfterConvert(accessStrategy.findAll(domainType));
return StreamSupport.stream(items.spliterator(), false).map(this::triggerAfterConvert);
}
@Override @Override
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) { public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
@ -335,6 +358,17 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
return triggerAfterConvert(allById); return triggerAfterConvert(allById);
} }
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
Assert.notNull(ids, "Ids must not be null");
Assert.notNull(domainType, "Domain type must not be null");
Stream<T> allByIdStreamable = accessStrategy.streamAllByIds(ids, domainType);
return allByIdStreamable.map(this::triggerAfterConvert);
}
@Override @Override
public <S> void delete(S aggregateRoot) { public <S> void delete(S aggregateRoot) {

22
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java

@ -22,6 +22,7 @@ import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Stream;
import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort;
@ -42,6 +43,7 @@ import org.springframework.data.relational.core.sql.LockMode;
* @author Myeonghyeon Lee * @author Myeonghyeon Lee
* @author Chirag Tailor * @author Chirag Tailor
* @author Diego Krupitza * @author Diego Krupitza
* @author Sergey Korotaev
* @since 1.1 * @since 1.1
*/ */
public class CascadingDataAccessStrategy implements DataAccessStrategy { public class CascadingDataAccessStrategy implements DataAccessStrategy {
@ -132,11 +134,21 @@ public class CascadingDataAccessStrategy implements DataAccessStrategy {
return collect(das -> das.findAll(domainType)); return collect(das -> das.findAll(domainType));
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
return collect(das -> das.streamAll(domainType));
}
@Override @Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) { public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return collect(das -> das.findAllById(ids, domainType)); return collect(das -> das.findAllById(ids, domainType));
} }
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
return collect(das -> das.streamAllByIds(ids, domainType));
}
@Override @Override
public Iterable<Object> findAllByPath(Identifier identifier, public Iterable<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path) { PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
@ -153,6 +165,11 @@ public class CascadingDataAccessStrategy implements DataAccessStrategy {
return collect(das -> das.findAll(domainType, sort)); return collect(das -> das.findAll(domainType, sort));
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
return collect(das -> das.streamAll(domainType, sort));
}
@Override @Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) { public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
return collect(das -> das.findAll(domainType, pageable)); return collect(das -> das.findAll(domainType, pageable));
@ -168,6 +185,11 @@ public class CascadingDataAccessStrategy implements DataAccessStrategy {
return collect(das -> das.findAll(query, domainType)); return collect(das -> das.findAll(query, domainType));
} }
@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
return collect(das -> das.streamAll(query, domainType));
}
@Override @Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) { public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
return collect(das -> das.findAll(query, domainType, pageable)); return collect(das -> das.findAll(query, domainType, pageable));

48
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java

@ -18,6 +18,7 @@ package org.springframework.data.jdbc.core.convert;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.dao.OptimisticLockingFailureException;
import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Pageable;
@ -41,6 +42,7 @@ import org.springframework.lang.Nullable;
* @author Myeonghyeon Lee * @author Myeonghyeon Lee
* @author Chirag Tailor * @author Chirag Tailor
* @author Diego Krupitza * @author Diego Krupitza
* @author Sergey Korotaev
*/ */
public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationResolver { public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationResolver {
@ -252,6 +254,16 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
@Override @Override
<T> Iterable<T> findAll(Class<T> domainType); <T> Iterable<T> findAll(Class<T> domainType);
/**
* Loads all entities of the given type to a {@link Stream}.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @return Guaranteed to be not {@code null}.
*/
@Override
<T> Stream<T> streamAll(Class<T> domainType);
/** /**
* Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids * Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids
* passed in matches the number of entities returned. * passed in matches the number of entities returned.
@ -264,6 +276,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
@Override @Override
<T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType); <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType);
/**
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
*
* @param ids the Ids of the entities to load. Must not be {@code null}.
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> type of entities to load.
* @return the loaded entities. Guaranteed to be not {@code null}.
*/
@Override
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);
@Override @Override
Iterable<Object> findAllByPath(Identifier identifier, Iterable<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path); PersistentPropertyPath<? extends RelationalPersistentProperty> path);
@ -280,6 +304,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
@Override @Override
<T> Iterable<T> findAll(Class<T> domainType, Sort sort); <T> Iterable<T> findAll(Class<T> domainType, Sort sort);
/**
* Loads all entities of the given type to a {@link Stream}, sorted.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @param sort the sorting information. Must not be {@code null}.
* @return Guaranteed to be not {@code null}.
* @since 2.0
*/
@Override
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);
/** /**
* Loads all entities of the given type, paged and sorted. * Loads all entities of the given type, paged and sorted.
* *
@ -316,6 +352,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
@Override @Override
<T> Iterable<T> findAll(Query query, Class<T> domainType); <T> Iterable<T> findAll(Query query, Class<T> domainType);
/**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
*
* @param query must not be {@literal null}.
* @param domainType the type of entities. Must not be {@code null}.
* @return a non-null list with all the matching results.
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
* @since 3.0
*/
@Override
<T> Stream<T> streamAll(Query query, Class<T> domainType);
/** /**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable} * Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable}
* to the result. * to the result.

34
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

@ -22,6 +22,7 @@ import java.sql.SQLException;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.dao.EmptyResultDataAccessException; import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.dao.OptimisticLockingFailureException;
@ -60,6 +61,7 @@ import org.springframework.util.Assert;
* @author Radim Tlusty * @author Radim Tlusty
* @author Chirag Tailor * @author Chirag Tailor
* @author Diego Krupitza * @author Diego Krupitza
* @author Sergey Korotaev
* @since 1.1 * @since 1.1
*/ */
public class DefaultDataAccessStrategy implements DataAccessStrategy { public class DefaultDataAccessStrategy implements DataAccessStrategy {
@ -276,6 +278,11 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType)); return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType));
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
return operations.queryForStream(sql(domainType).getFindAll(), new MapSqlParameterSource(), getEntityRowMapper(domainType));
}
@Override @Override
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) { public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
@ -288,6 +295,19 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
return operations.query(findAllInListSql, parameterSource, getEntityRowMapper(domainType)); return operations.query(findAllInListSql, parameterSource, getEntityRowMapper(domainType));
} }
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
if (!ids.iterator().hasNext()) {
return Stream.empty();
}
SqlParameterSource parameterSource = sqlParametersFactory.forQueryByIds(ids, domainType);
String findAllInListSql = sql(domainType).getFindAllInList();
return operations.queryForStream(findAllInListSql, parameterSource, getEntityRowMapper(domainType));
}
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public List<Object> findAllByPath(Identifier identifier, public List<Object> findAllByPath(Identifier identifier,
@ -342,6 +362,11 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType)); return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType));
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
return operations.queryForStream(sql(domainType).getFindAll(sort), new MapSqlParameterSource(), getEntityRowMapper(domainType));
}
@Override @Override
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) { public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType)); return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType));
@ -369,6 +394,15 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
return operations.query(sqlQuery, parameterSource, getEntityRowMapper(domainType)); return operations.query(sqlQuery, parameterSource, getEntityRowMapper(domainType));
} }
@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource);
return operations.queryForStream(sqlQuery, parameterSource, getEntityRowMapper(domainType));
}
@Override @Override
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) { public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {

22
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java

@ -17,6 +17,7 @@ package org.springframework.data.jdbc.core.convert;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort;
@ -37,6 +38,7 @@ import org.springframework.util.Assert;
* @author Myeonghyeon Lee * @author Myeonghyeon Lee
* @author Chirag Tailor * @author Chirag Tailor
* @author Diego Krupitza * @author Diego Krupitza
* @author Sergey Korotaev
* @since 1.1 * @since 1.1
*/ */
public class DelegatingDataAccessStrategy implements DataAccessStrategy { public class DelegatingDataAccessStrategy implements DataAccessStrategy {
@ -135,11 +137,21 @@ public class DelegatingDataAccessStrategy implements DataAccessStrategy {
return delegate.findAll(domainType); return delegate.findAll(domainType);
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
return delegate.streamAll(domainType);
}
@Override @Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) { public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return delegate.findAllById(ids, domainType); return delegate.findAllById(ids, domainType);
} }
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
return delegate.streamAllByIds(ids, domainType);
}
@Override @Override
public Iterable<Object> findAllByPath(Identifier identifier, public Iterable<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path) { PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
@ -156,6 +168,11 @@ public class DelegatingDataAccessStrategy implements DataAccessStrategy {
return delegate.findAll(domainType, sort); return delegate.findAll(domainType, sort);
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
return delegate.streamAll(domainType, sort);
}
@Override @Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) { public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
return delegate.findAll(domainType, pageable); return delegate.findAll(domainType, pageable);
@ -171,6 +188,11 @@ public class DelegatingDataAccessStrategy implements DataAccessStrategy {
return delegate.findAll(query, domainType); return delegate.findAll(query, domainType);
} }
@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
return delegate.streamAll(query, domainType);
}
@Override @Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) { public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
return delegate.findAll(query, domainType, pageable); return delegate.findAll(query, domainType, pageable);

44
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/ReadingDataAccessStrategy.java

@ -17,6 +17,7 @@
package org.springframework.data.jdbc.core.convert; package org.springframework.data.jdbc.core.convert;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort;
@ -27,6 +28,7 @@ import org.springframework.lang.Nullable;
* The finding methods of a {@link DataAccessStrategy}. * The finding methods of a {@link DataAccessStrategy}.
* *
* @author Jens Schauder * @author Jens Schauder
* @author Sergey Korotaev
* @since 3.2 * @since 3.2
*/ */
interface ReadingDataAccessStrategy { interface ReadingDataAccessStrategy {
@ -51,6 +53,15 @@ interface ReadingDataAccessStrategy {
*/ */
<T> Iterable<T> findAll(Class<T> domainType); <T> Iterable<T> findAll(Class<T> domainType);
/**
* Loads all entities of the given type to a {@link Stream}.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @return Guaranteed to be not {@code null}.
*/
<T> Stream<T> streamAll(Class<T> domainType);
/** /**
* Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids * Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids
* passed in matches the number of entities returned. * passed in matches the number of entities returned.
@ -62,6 +73,17 @@ interface ReadingDataAccessStrategy {
*/ */
<T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType); <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType);
/**
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
*
* @param ids the Ids of the entities to load. Must not be {@code null}.
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> type of entities to load.
* @return the loaded entities. Guaranteed to be not {@code null}.
*/
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);
/** /**
* Loads all entities of the given type, sorted. * Loads all entities of the given type, sorted.
* *
@ -73,6 +95,17 @@ interface ReadingDataAccessStrategy {
*/ */
<T> Iterable<T> findAll(Class<T> domainType, Sort sort); <T> Iterable<T> findAll(Class<T> domainType, Sort sort);
/**
* Loads all entities of the given type to a {@link Stream}, sorted.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @param sort the sorting information. Must not be {@code null}.
* @return Guaranteed to be not {@code null}.
* @since 2.0
*/
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);
/** /**
* Loads all entities of the given type, paged and sorted. * Loads all entities of the given type, paged and sorted.
* *
@ -106,6 +139,17 @@ interface ReadingDataAccessStrategy {
*/ */
<T> Iterable<T> findAll(Query query, Class<T> domainType); <T> Iterable<T> findAll(Query query, Class<T> domainType);
/**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
*
* @param query must not be {@literal null}.
* @param domainType the type of entities. Must not be {@code null}.
* @return a non-null list with all the matching results.
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
* @since 3.0
*/
<T> Stream<T> streamAll(Query query, Class<T> domainType);
/** /**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable} * Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable}
* to the result. * to the result.

22
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java

@ -18,6 +18,7 @@ package org.springframework.data.jdbc.core.convert;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Stream;
import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort;
@ -32,6 +33,7 @@ import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
* *
* @author Jens Schauder * @author Jens Schauder
* @author Mark Paluch * @author Mark Paluch
* @author Sergey Korotaev
* @since 3.2 * @since 3.2
*/ */
class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy { class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy {
@ -56,16 +58,31 @@ class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy {
return aggregateReader.findAll(getPersistentEntity(domainType)); return aggregateReader.findAll(getPersistentEntity(domainType));
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
throw new UnsupportedOperationException();
}
@Override @Override
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) { public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return aggregateReader.findAllById(ids, getPersistentEntity(domainType)); return aggregateReader.findAllById(ids, getPersistentEntity(domainType));
} }
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
throw new UnsupportedOperationException();
}
@Override @Override
public <T> List<T> findAll(Class<T> domainType, Sort sort) { public <T> List<T> findAll(Class<T> domainType, Sort sort) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
throw new UnsupportedOperationException();
}
@Override @Override
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) { public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
@ -81,6 +98,11 @@ class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy {
return aggregateReader.findAll(query, getPersistentEntity(domainType)); return aggregateReader.findAll(query, getPersistentEntity(domainType));
} }
@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
throw new UnsupportedOperationException();
}
@Override @Override
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) { public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

38
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java

@ -22,7 +22,10 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.ibatis.cursor.Cursor;
import org.apache.ibatis.session.SqlSession; import org.apache.ibatis.session.SqlSession;
import org.mybatis.spring.SqlSessionTemplate; import org.mybatis.spring.SqlSessionTemplate;
import org.springframework.dao.EmptyResultDataAccessException; import org.springframework.dao.EmptyResultDataAccessException;
@ -59,6 +62,7 @@ import org.springframework.util.Assert;
* @author Chirag Tailor * @author Chirag Tailor
* @author Christopher Klein * @author Christopher Klein
* @author Mikhail Polivakha * @author Mikhail Polivakha
* @author Sergey Korotaev
*/ */
public class MyBatisDataAccessStrategy implements DataAccessStrategy { public class MyBatisDataAccessStrategy implements DataAccessStrategy {
@ -263,12 +267,28 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy {
return sqlSession().selectList(statement, parameter); return sqlSession().selectList(statement, parameter);
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
String statement = namespace(domainType) + ".streamAll";
MyBatisContext parameter = new MyBatisContext(null, null, domainType, Collections.emptyMap());
Cursor<T> cursor = sqlSession().selectCursor(statement, parameter);
return StreamSupport.stream(cursor.spliterator(), false);
}
@Override @Override
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) { public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return sqlSession().selectList(namespace(domainType) + ".findAllById", return sqlSession().selectList(namespace(domainType) + ".findAllById",
new MyBatisContext(ids, null, domainType, Collections.emptyMap())); new MyBatisContext(ids, null, domainType, Collections.emptyMap()));
} }
@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
String statement = namespace(domainType) + ".streamAllByIds";
MyBatisContext parameter = new MyBatisContext(ids, null, domainType, Collections.emptyMap());
Cursor<T> cursor = sqlSession().selectCursor(statement, parameter);
return StreamSupport.stream(cursor.spliterator(), false);
}
@Override @Override
public List<Object> findAllByPath(Identifier identifier, public List<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path) { PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
@ -296,6 +316,19 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy {
new MyBatisContext(null, null, domainType, additionalContext)); new MyBatisContext(null, null, domainType, additionalContext));
} }
@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
Map<String, Object> additionalContext = new HashMap<>();
additionalContext.put("sort", sort);
String statement = namespace(domainType) + ".streamAllSorted";
MyBatisContext parameter = new MyBatisContext(null, null, domainType, additionalContext);
Cursor<T> cursor = sqlSession().selectCursor(statement, parameter);
return StreamSupport.stream(cursor.spliterator(), false);
}
@Override @Override
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) { public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
@ -315,6 +348,11 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy {
throw new UnsupportedOperationException("Not implemented"); throw new UnsupportedOperationException("Not implemented");
} }
@Override
public <T> Stream<T> streamAll(Query query, Class<T> probeType) {
throw new UnsupportedOperationException("Not implemented");
}
@Override @Override
public <T> List<T> findAll(Query query, Class<T> probeType, Pageable pageable) { public <T> List<T> findAll(Query query, Class<T> probeType, Pageable pageable) {
throw new UnsupportedOperationException("Not implemented"); throw new UnsupportedOperationException("Not implemented");

46
spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java

@ -27,8 +27,8 @@ import java.util.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.assertj.core.api.SoftAssertions;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisher;
@ -81,6 +81,7 @@ import org.springframework.test.context.ContextConfiguration;
* @author Mikhail Polivakha * @author Mikhail Polivakha
* @author Chirag Tailor * @author Chirag Tailor
* @author Vincent Galloy * @author Vincent Galloy
* @author Sergey Korotaev
*/ */
@IntegrationTest @IntegrationTest
abstract class AbstractJdbcAggregateTemplateIntegrationTests { abstract class AbstractJdbcAggregateTemplateIntegrationTests {
@ -309,6 +310,18 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests {
.containsExactly(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content)); .containsExactly(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
} }
@Test // GH-1714
void saveAndLoadManeEntitiesWithReferenceEntityLikeStream() {
template.save(legoSet);
Stream<LegoSet> streamable = template.streamAll(LegoSet.class);
assertThat(streamable)
.extracting("id", "manual.id", "manual.content") //
.containsExactly(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
}
@Test // DATAJDBC-101 @Test // DATAJDBC-101
void saveAndLoadManyEntitiesWithReferencedEntitySorted() { void saveAndLoadManyEntitiesWithReferencedEntitySorted() {
@ -323,6 +336,20 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests {
.containsExactly("Frozen", "Lava", "Star"); .containsExactly("Frozen", "Lava", "Star");
} }
@Test // GH-1714
void saveAndLoadManyEntitiesWithReferencedEntitySortedLikeStream() {
template.save(createLegoSet("Lava"));
template.save(createLegoSet("Star"));
template.save(createLegoSet("Frozen"));
Stream<LegoSet> reloadedLegoSets = template.streamAll(LegoSet.class, Sort.by("name"));
assertThat(reloadedLegoSets) //
.extracting("name") //
.containsExactly("Frozen", "Lava", "Star");
}
@Test // DATAJDBC-101 @Test // DATAJDBC-101
void saveAndLoadManyEntitiesWithReferencedEntitySortedAndPaged() { void saveAndLoadManyEntitiesWithReferencedEntitySortedAndPaged() {
@ -360,6 +387,12 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests {
.isInstanceOf(InvalidPersistentPropertyPath.class); .isInstanceOf(InvalidPersistentPropertyPath.class);
} }
@Test // GH-1714
void findByNonPropertySortLikeStreamFails() {
assertThatThrownBy(() -> template.streamAll(LegoSet.class, Sort.by("somethingNotExistant")))
.isInstanceOf(InvalidPersistentPropertyPath.class);
}
@Test // DATAJDBC-112 @Test // DATAJDBC-112
void saveAndLoadManyEntitiesByIdWithReferencedEntity() { void saveAndLoadManyEntitiesByIdWithReferencedEntity() {
@ -371,6 +404,17 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests {
.contains(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content)); .contains(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
} }
@Test // GH-1714
void saveAndLoadManyEntitiesByIdWithReferencedEntityLikeStream() {
template.save(legoSet);
Stream<LegoSet> reloadedLegoSets = template.streamAllByIds(singletonList(legoSet.id), LegoSet.class);
assertThat(reloadedLegoSets).hasSize(1).extracting("id", "manual.id", "manual.content")
.contains(tuple(legoSet.id, legoSet.manual.id, legoSet.manual.content));
}
@Test // DATAJDBC-112 @Test // DATAJDBC-112
void saveAndLoadAnEntityWithReferencedNullEntity() { void saveAndLoadAnEntityWithReferencedNullEntity() {

123
spring-data-jdbc/src/test/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategyUnitTests.java

@ -22,7 +22,12 @@ import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
import static org.springframework.data.relational.core.sql.SqlIdentifier.*; import static org.springframework.data.relational.core.sql.SqlIdentifier.*;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;
import org.apache.ibatis.cursor.Cursor;
import org.apache.ibatis.session.SqlSession; import org.apache.ibatis.session.SqlSession;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
@ -43,6 +48,7 @@ import org.springframework.data.relational.core.mapping.RelationalPersistentProp
* @author Mark Paluch * @author Mark Paluch
* @author Tyler Van Gorder * @author Tyler Van Gorder
* @author Chirag Tailor * @author Chirag Tailor
* @author Sergey Korotaev
*/ */
public class MyBatisDataAccessStrategyUnitTests { public class MyBatisDataAccessStrategyUnitTests {
@ -241,6 +247,36 @@ public class MyBatisDataAccessStrategyUnitTests {
); );
} }
@Test
public void streamAll() {
String value = "some answer";
Cursor<String> cursor = getCursor(value);
when(session.selectCursor(anyString(), any())).then(answer -> cursor);
Stream<String> streamable = accessStrategy.streamAll(String.class);
verify(session).selectCursor(eq("java.lang.StringMapper.streamAll"), captor.capture());
assertThat(streamable).isNotNull().containsExactly(value);
assertThat(captor.getValue()) //
.isNotNull() //
.extracting( //
MyBatisContext::getInstance, //
MyBatisContext::getId, //
MyBatisContext::getDomainType, //
c -> c.get("key") //
).containsExactly( //
null, //
null, //
String.class, //
null //
);
}
@Test // DATAJDBC-123 @Test // DATAJDBC-123
public void findAllById() { public void findAllById() {
@ -263,6 +299,33 @@ public class MyBatisDataAccessStrategyUnitTests {
); );
} }
@Test
public void streamAllByIds() {
String value = "some answer 2";
Cursor<String> cursor = getCursor(value);
when(session.selectCursor(anyString(), any())).then(answer -> cursor);
accessStrategy.streamAllByIds(asList("id1", "id2"), String.class);
verify(session).selectCursor(eq("java.lang.StringMapper.streamAllByIds"), captor.capture());
assertThat(captor.getValue()) //
.isNotNull() //
.extracting( //
MyBatisContext::getInstance, //
MyBatisContext::getId, //
MyBatisContext::getDomainType, //
c -> c.get("key") //
).containsExactly( //
null, //
asList("id1", "id2"), //
String.class, //
null //
);
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test // DATAJDBC-384 @Test // DATAJDBC-384
public void findAllByPath() { public void findAllByPath() {
@ -367,6 +430,33 @@ public class MyBatisDataAccessStrategyUnitTests {
); );
} }
@Test
public void streamAllSorted() {
String value = "some answer 3";
Cursor<String> cursor = getCursor(value);
when(session.selectCursor(anyString(), any())).then(answer -> cursor);
accessStrategy.streamAll(String.class, Sort.by("length"));
verify(session).selectCursor(eq("java.lang.StringMapper.streamAllSorted"), captor.capture());
assertThat(captor.getValue()) //
.isNotNull() //
.extracting( //
MyBatisContext::getInstance, //
MyBatisContext::getId, //
MyBatisContext::getDomainType, //
c -> c.get("sort") //
).containsExactly( //
null, //
null, //
String.class, //
Sort.by("length") //
);
}
@Test // DATAJDBC-101 @Test // DATAJDBC-101
public void findAllPaged() { public void findAllPaged() {
@ -399,5 +489,36 @@ public class MyBatisDataAccessStrategyUnitTests {
ChildTwo two; ChildTwo two;
} }
private static class ChildTwo {} private static class ChildTwo {
}
private Cursor<String> getCursor(String value) {
return new Cursor<>() {
@Override
public boolean isOpen() {
return false;
}
@Override
public boolean isConsumed() {
return false;
}
@Override
public int getCurrentIndex() {
return 0;
}
@Override
public void close() {
}
@NotNull
@Override
public Iterator<String> iterator() {
return List.of(value).iterator();
}
};
}
} }

Loading…
Cancel
Save