Browse Source

Polishing.

Replace code duplications with doWithBatch(…) method. Return most concrete type in DefaultDataAccessStrategy and MyBatisDataAccessStrategy.

See #1623
Original pull request: #1897
pull/1905/head
Mark Paluch 1 year ago
parent
commit
7cf81aed35
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 60
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java
  2. 14
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java
  3. 14
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java

60
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java

@ -16,6 +16,7 @@
package org.springframework.data.jdbc.core; package org.springframework.data.jdbc.core;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
@ -23,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.StreamSupport; import java.util.stream.StreamSupport;
@ -56,6 +58,7 @@ import org.springframework.data.util.Streamable;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
/** /**
* {@link JdbcAggregateOperations} implementation, storing aggregates in and obtaining them from a JDBC data store. * {@link JdbcAggregateOperations} implementation, storing aggregates in and obtaining them from a JDBC data store.
@ -173,19 +176,8 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
@Override @Override
public <T> List<T> saveAll(Iterable<T> instances) { public <T> List<T> saveAll(Iterable<T> instances) {
return doWithBatch(instances, entity -> changeCreatorSelectorForSave(entity).apply(entity), this::verifyIdProperty,
Assert.notNull(instances, "Aggregate instances must not be null"); this::performSaveAll);
if (!instances.iterator().hasNext()) {
return Collections.emptyList();
}
List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>();
for (T instance : instances) {
verifyIdProperty(instance);
entityAndChangeCreators.add(new EntityAndChangeCreator<>(instance, changeCreatorSelectorForSave(instance)));
}
return performSaveAll(entityAndChangeCreators);
} }
/** /**
@ -206,21 +198,7 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
@Override @Override
public <T> List<T> insertAll(Iterable<T> instances) { public <T> List<T> insertAll(Iterable<T> instances) {
return doWithBatch(instances, entity -> createInsertChange(prepareVersionForInsert(entity)), this::performSaveAll);
Assert.notNull(instances, "Aggregate instances must not be null");
if (!instances.iterator().hasNext()) {
return Collections.emptyList();
}
List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>();
for (T instance : instances) {
Function<T, RootAggregateChange<T>> changeCreator = entity -> createInsertChange(prepareVersionForInsert(entity));
EntityAndChangeCreator<T> entityChange = new EntityAndChangeCreator<>(instance, changeCreator);
entityAndChangeCreators.add(entityChange);
}
return performSaveAll(entityAndChangeCreators);
} }
/** /**
@ -241,21 +219,35 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
@Override @Override
public <T> List<T> updateAll(Iterable<T> instances) { public <T> List<T> updateAll(Iterable<T> instances) {
return doWithBatch(instances, entity -> createUpdateChange(prepareVersionForUpdate(entity)), this::performSaveAll);
}
private <T> List<T> doWithBatch(Iterable<T> iterable, Function<T, RootAggregateChange<T>> changeCreator,
Function<List<EntityAndChangeCreator<T>>, List<T>> performFunction) {
return doWithBatch(iterable, changeCreator, entity -> {}, performFunction);
}
Assert.notNull(instances, "Aggregate instances must not be null"); private <T> List<T> doWithBatch(Iterable<T> iterable, Function<T, RootAggregateChange<T>> changeCreator,
Consumer<T> beforeEntityChange, Function<List<EntityAndChangeCreator<T>>, List<T>> performFunction) {
if (!instances.iterator().hasNext()) { Assert.notNull(iterable, "Aggregate instances must not be null");
if (ObjectUtils.isEmpty(iterable)) {
return Collections.emptyList(); return Collections.emptyList();
} }
List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>(); List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>(
for (T instance : instances) { iterable instanceof Collection<?> c ? c.size() : 16);
for (T instance : iterable) {
beforeEntityChange.accept(instance);
Function<T, RootAggregateChange<T>> changeCreator = entity -> createUpdateChange(prepareVersionForUpdate(entity));
EntityAndChangeCreator<T> entityChange = new EntityAndChangeCreator<>(instance, changeCreator); EntityAndChangeCreator<T> entityChange = new EntityAndChangeCreator<>(instance, changeCreator);
entityAndChangeCreators.add(entityChange); entityAndChangeCreators.add(entityChange);
} }
return performSaveAll(entityAndChangeCreators);
return performFunction.apply(entityAndChangeCreators);
} }
@Override @Override

14
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

@ -272,12 +272,12 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
} }
@Override @Override
public <T> Iterable<T> findAll(Class<T> domainType) { public <T> List<T> findAll(Class<T> domainType) {
return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType)); return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType));
} }
@Override @Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) { public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
if (!ids.iterator().hasNext()) { if (!ids.iterator().hasNext()) {
return Collections.emptyList(); return Collections.emptyList();
@ -290,7 +290,7 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public Iterable<Object> findAllByPath(Identifier identifier, public List<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> propertyPath) { PersistentPropertyPath<? extends RelationalPersistentProperty> propertyPath) {
Assert.notNull(identifier, "identifier must not be null"); Assert.notNull(identifier, "identifier must not be null");
@ -338,12 +338,12 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
} }
@Override @Override
public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) { public <T> List<T> findAll(Class<T> domainType, Sort sort) {
return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType)); return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType));
} }
@Override @Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) { public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType)); return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType));
} }
@ -361,7 +361,7 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
} }
@Override @Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType) { public <T> List<T> findAll(Query query, Class<T> domainType) {
MapSqlParameterSource parameterSource = new MapSqlParameterSource(); MapSqlParameterSource parameterSource = new MapSqlParameterSource();
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource); String sqlQuery = sql(domainType).selectByQuery(query, parameterSource);
@ -370,7 +370,7 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
} }
@Override @Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) { public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
MapSqlParameterSource parameterSource = new MapSqlParameterSource(); MapSqlParameterSource parameterSource = new MapSqlParameterSource();
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource, pageable); String sqlQuery = sql(domainType).selectByQuery(query, parameterSource, pageable);

14
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java

@ -256,7 +256,7 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy {
} }
@Override @Override
public <T> Iterable<T> findAll(Class<T> domainType) { public <T> List<T> findAll(Class<T> domainType) {
String statement = namespace(domainType) + ".findAll"; String statement = namespace(domainType) + ".findAll";
MyBatisContext parameter = new MyBatisContext(null, null, domainType, Collections.emptyMap()); MyBatisContext parameter = new MyBatisContext(null, null, domainType, Collections.emptyMap());
@ -264,13 +264,13 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy {
} }
@Override @Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) { public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return sqlSession().selectList(namespace(domainType) + ".findAllById", return sqlSession().selectList(namespace(domainType) + ".findAllById",
new MyBatisContext(ids, null, domainType, Collections.emptyMap())); new MyBatisContext(ids, null, domainType, Collections.emptyMap()));
} }
@Override @Override
public Iterable<Object> findAllByPath(Identifier identifier, public List<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path) { PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
String statementName = namespace(getOwnerTyp(path)) + ".findAllByPath-" + path.toDotPath(); String statementName = namespace(getOwnerTyp(path)) + ".findAllByPath-" + path.toDotPath();
@ -288,7 +288,7 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy {
} }
@Override @Override
public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) { public <T> List<T> findAll(Class<T> domainType, Sort sort) {
Map<String, Object> additionalContext = new HashMap<>(); Map<String, Object> additionalContext = new HashMap<>();
additionalContext.put("sort", sort); additionalContext.put("sort", sort);
@ -297,7 +297,7 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy {
} }
@Override @Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) { public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
Map<String, Object> additionalContext = new HashMap<>(); Map<String, Object> additionalContext = new HashMap<>();
additionalContext.put("pageable", pageable); additionalContext.put("pageable", pageable);
@ -311,12 +311,12 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy {
} }
@Override @Override
public <T> Iterable<T> findAll(Query query, Class<T> probeType) { public <T> List<T> findAll(Query query, Class<T> probeType) {
throw new UnsupportedOperationException("Not implemented"); throw new UnsupportedOperationException("Not implemented");
} }
@Override @Override
public <T> Iterable<T> findAll(Query query, Class<T> probeType, Pageable pageable) { public <T> List<T> findAll(Query query, Class<T> probeType, Pageable pageable) {
throw new UnsupportedOperationException("Not implemented"); throw new UnsupportedOperationException("Not implemented");
} }

Loading…
Cancel
Save