diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java index 407089473..32813be26 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java @@ -48,30 +48,24 @@ import org.springframework.lang.Nullable; * intermediate {@link RowDocumentResultSetExtractor RowDocument} and mapped via * {@link org.springframework.data.relational.core.conversion.RelationalConverter#read(Class, RowDocument)}. * - * @param the type of aggregate produced by this reader. * @author Jens Schauder * @author Mark Paluch * @since 3.2 */ -class AggregateReader implements PathToColumnMapping { +class AggregateReader implements PathToColumnMapping { - private final RelationalPersistentEntity aggregate; - private final Table table; + private final AliasFactory aliasFactory; private final SqlGenerator sqlGenerator; private final JdbcConverter converter; private final NamedParameterJdbcOperations jdbcTemplate; - private final AliasFactory aliasFactory; private final RowDocumentResultSetExtractor extractor; - AggregateReader(Dialect dialect, JdbcConverter converter, AliasFactory aliasFactory, - NamedParameterJdbcOperations jdbcTemplate, RelationalPersistentEntity aggregate) { + AggregateReader(Dialect dialect, JdbcConverter converter, NamedParameterJdbcOperations jdbcTemplate) { + this.aliasFactory = new AliasFactory(); this.converter = converter; - this.aggregate = aggregate; this.jdbcTemplate = jdbcTemplate; - this.table = Table.create(aggregate.getQualifiedTableName()); - this.sqlGenerator = new SingleQuerySqlGenerator(converter.getMappingContext(), aliasFactory, dialect, aggregate); - this.aliasFactory = aliasFactory; + this.sqlGenerator = new SingleQuerySqlGenerator(converter.getMappingContext(), aliasFactory, dialect); this.extractor = new RowDocumentResultSetExtractor(converter.getMappingContext(), this); } @@ -92,55 +86,96 @@ class AggregateReader implements PathToColumnMapping { return aliasFactory.getKeyAlias(path); } + /** + * Select a single aggregate by its identifier. + * + * @param id the identifier, must not be {@literal null}. + * @param entity the persistent entity type must not be {@literal null}. + * @return the found aggregate root, or {@literal null} if not found. + * @param aggregator type. + */ @Nullable - public T findById(Object id) { + public T findById(Object id, RelationalPersistentEntity entity) { - Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).is(id)).limit(1); + Query query = Query.query(Criteria.where(entity.getRequiredIdProperty().getName()).is(id)).limit(1); - return findOne(query); + return findOne(query, entity); } + /** + * Select a single aggregate by a {@link Query}. + * + * @param query the query to run, must not be {@literal null}. + * @param entity the persistent entity type must not be {@literal null}. + * @return the found aggregate root, or {@literal null} if not found. + * @param aggregator type. + */ @Nullable - public T findOne(Query query) { - return doFind(query, this::extractZeroOrOne); + public T findOne(Query query, RelationalPersistentEntity entity) { + return doFind(query, entity, rs -> extractZeroOrOne(rs, entity)); } - public List findAllById(Iterable ids) { + /** + * Select aggregates by their identifiers. + * + * @param ids the identifiers, must not be {@literal null}. + * @param entity the persistent entity type must not be {@literal null}. + * @return the found aggregate roots. The resulting list can be empty or may not contain objects that correspond to + * the identifiers when the objects are not found in the database. + * @param aggregator type. + */ + public List findAllById(Iterable ids, RelationalPersistentEntity entity) { Collection identifiers = ids instanceof Collection idl ? idl : Streamable.of(ids).toList(); - Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).in(identifiers)); + Query query = Query.query(Criteria.where(entity.getRequiredIdProperty().getName()).in(identifiers)); - return findAll(query); + return findAll(query, entity); } + /** + * Select all aggregates by type. + * + * @param entity the persistent entity type must not be {@literal null}. + * @return the found aggregate roots. + * @param aggregator type. + */ @SuppressWarnings("ConstantConditions") - public List findAll() { - return jdbcTemplate.query(sqlGenerator.findAll(), this::extractAll); + public List findAll(RelationalPersistentEntity entity) { + return jdbcTemplate.query(sqlGenerator.findAll(entity), + (ResultSetExtractor>) rs -> extractAll(rs, entity)); } - public List findAll(Query query) { - return doFind(query, this::extractAll); + /** + * Select all aggregates by query. + * + * @param query the query to run, must not be {@literal null}. + * @param entity the persistent entity type must not be {@literal null}. + * @return the found aggregate roots. + * @param aggregator type. + */ + public List findAll(Query query, RelationalPersistentEntity entity) { + return doFind(query, entity, rs -> extractAll(rs, entity)); } @SuppressWarnings("ConstantConditions") - private R doFind(Query query, ResultSetExtractor extractor) { + private R doFind(Query query, RelationalPersistentEntity entity, ResultSetExtractor extractor) { MapSqlParameterSource parameterSource = new MapSqlParameterSource(); - Condition condition = createCondition(query, parameterSource); - String sql = sqlGenerator.findAll(condition); + Condition condition = createCondition(query, parameterSource, entity); + String sql = sqlGenerator.findAll(entity, condition); return jdbcTemplate.query(sql, parameterSource, extractor); } @Nullable - private Condition createCondition(Query query, MapSqlParameterSource parameterSource) { + private Condition createCondition(Query query, MapSqlParameterSource parameterSource, + RelationalPersistentEntity entity) { QueryMapper queryMapper = new QueryMapper(converter); Optional criteria = query.getCriteria(); - return criteria - .map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, table, aggregate)) - .orElse(null); + return criteria.map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, + Table.create(entity.getQualifiedTableName()), entity)).orElse(null); } /** @@ -152,12 +187,13 @@ class AggregateReader implements PathToColumnMapping { * @return a {@code List} of aggregates, fully converted. * @throws SQLException on underlying JDBC errors. */ - private List extractAll(ResultSet rs) throws SQLException { + private List extractAll(ResultSet rs, RelationalPersistentEntity entity) throws SQLException { - Iterator iterate = extractor.iterate(aggregate, rs); + Iterator iterate = extractor.iterate(entity, rs); List resultList = new ArrayList<>(); + while (iterate.hasNext()) { - resultList.add(converter.read(aggregate.getType(), iterate.next())); + resultList.add(converter.read(entity.getType(), iterate.next())); } return resultList; @@ -175,17 +211,19 @@ class AggregateReader implements PathToColumnMapping { * @throws IncorrectResultSizeDataAccessException when the conversion yields more than one instance. */ @Nullable - private T extractZeroOrOne(ResultSet rs) throws SQLException { + private T extractZeroOrOne(ResultSet rs, RelationalPersistentEntity entity) throws SQLException { + + Iterator iterate = extractor.iterate(entity, rs); - Iterator iterate = extractor.iterate(aggregate, rs); if (iterate.hasNext()) { RowDocument object = iterate.next(); if (iterate.hasNext()) { throw new IncorrectResultSizeDataAccessException(1); } - return converter.read(aggregate.getType(), object); + return converter.read(entity.getType(), object); } + return null; } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java index d5fc206e0..6b7c95142 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java @@ -25,9 +25,7 @@ import org.springframework.data.relational.core.dialect.Dialect; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.query.Query; -import org.springframework.data.relational.core.sqlgeneration.AliasFactory; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; -import org.springframework.util.ConcurrentLruCache; /** * A {@link ReadingDataAccessStrategy} that uses an {@link AggregateReader} to load entities with a single query. @@ -39,31 +37,28 @@ import org.springframework.util.ConcurrentLruCache; class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy { private final RelationalMappingContext mappingContext; - private final AliasFactory aliasFactory; - private final ConcurrentLruCache, AggregateReader> readerCache; + private final AggregateReader aggregateReader; public SingleQueryDataAccessStrategy(Dialect dialect, JdbcConverter converter, NamedParameterJdbcOperations jdbcTemplate) { this.mappingContext = converter.getMappingContext(); - this.aliasFactory = new AliasFactory(); - this.readerCache = new ConcurrentLruCache<>(256, - entity -> new AggregateReader<>(dialect, converter, aliasFactory, jdbcTemplate, entity)); + this.aggregateReader = new AggregateReader(dialect, converter, jdbcTemplate); } @Override public T findById(Object id, Class domainType) { - return getReader(domainType).findById(id); + return aggregateReader.findById(id, getPersistentEntity(domainType)); } @Override public List findAll(Class domainType) { - return getReader(domainType).findAll(); + return aggregateReader.findAll(getPersistentEntity(domainType)); } @Override public List findAllById(Iterable ids, Class domainType) { - return getReader(domainType).findAllById(ids); + return aggregateReader.findAllById(ids, getPersistentEntity(domainType)); } @Override @@ -78,12 +73,12 @@ class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy { @Override public Optional findOne(Query query, Class domainType) { - return Optional.ofNullable(getReader(domainType).findOne(query)); + return Optional.ofNullable(aggregateReader.findOne(query, getPersistentEntity(domainType))); } @Override public List findAll(Query query, Class domainType) { - return getReader(domainType).findAll(query); + return aggregateReader.findAll(query, getPersistentEntity(domainType)); } @Override @@ -92,11 +87,7 @@ class SingleQueryDataAccessStrategy implements ReadingDataAccessStrategy { } @SuppressWarnings("unchecked") - private AggregateReader getReader(Class domainType) { - - RelationalPersistentEntity persistentEntity = (RelationalPersistentEntity) mappingContext - .getRequiredPersistentEntity(domainType); - - return (AggregateReader) readerCache.get(persistentEntity); + private RelationalPersistentEntity getPersistentEntity(Class domainType) { + return (RelationalPersistentEntity) mappingContext.getRequiredPersistentEntity(domainType); } } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java index 20bc95686..e337383fa 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java @@ -18,12 +18,9 @@ package org.springframework.data.jdbc.core; import static java.util.Arrays.*; import static java.util.Collections.*; import static org.assertj.core.api.Assertions.*; -import static org.assertj.core.api.SoftAssertions.*; import static org.springframework.data.jdbc.testing.TestConfiguration.*; import static org.springframework.data.jdbc.testing.TestDatabaseFeatures.Feature.*; -import java.sql.ResultSet; -import java.sql.SQLException; import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Collections; @@ -38,8 +35,6 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.IntStream; -import org.assertj.core.api.SoftAssertions; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationEventPublisher; @@ -47,7 +42,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; import org.springframework.dao.IncorrectResultSizeDataAccessException; -import org.springframework.dao.DataAccessException; import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.data.annotation.Id; @@ -74,7 +68,6 @@ import org.springframework.data.relational.core.mapping.Table; import org.springframework.data.relational.core.query.Criteria; import org.springframework.data.relational.core.query.CriteriaDefinition; import org.springframework.data.relational.core.query.Query; -import org.springframework.jdbc.core.ResultSetExtractor; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.test.context.ActiveProfiles; import org.springframework.test.context.ContextConfiguration; @@ -191,11 +184,11 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { private static LegoSet createLegoSet(String name) { LegoSet entity = new LegoSet(); - entity.name = (name); + entity.name = name; Manual manual = new Manual(); - manual.content = ("Accelerates to 99% of light speed; Destroys almost everything. See https://what-if.xkcd.com/1/"); - entity.manual = (manual); + manual.content = "Accelerates to 99% of light speed; Destroys almost everything. See https://what-if.xkcd.com/1/"; + entity.manual = manual; return entity; } @@ -304,13 +297,10 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { assertThat(reloadedLegoSet.manual).isNotNull(); - assertSoftly(softly -> { - softly.assertThat(reloadedLegoSet.manual.id) // - .isEqualTo(legoSet.manual.id) // - .isNotNull(); - softly.assertThat(reloadedLegoSet.manual.content).isEqualTo(legoSet.manual.content); - }); - + assertThat(reloadedLegoSet.manual.id) // + .isEqualTo(legoSet.manual.id) // + .isNotNull(); + assertThat(reloadedLegoSet.manual.content).isEqualTo(legoSet.manual.content); } @Test // DATAJDBC-112 @@ -378,7 +368,6 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { assertThatThrownBy(() -> template.findAll(LegoSet.class, Sort.by("somethingNotExistant"))) .isInstanceOf(InvalidPersistentPropertyPath.class); - } @Test // DATAJDBC-112 @@ -397,7 +386,7 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { @EnabledOnFeature(SUPPORTS_QUOTED_IDS) void saveAndLoadAnEntityWithReferencedNullEntity() { - legoSet.manual = (null); + legoSet.manual = null; template.save(legoSet); @@ -414,11 +403,8 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { template.delete(legoSet); - assertSoftly(softly -> { - - softly.assertThat(template.findAll(LegoSet.class)).isEmpty(); - softly.assertThat(template.findAll(Manual.class)).isEmpty(); - }); + assertThat(template.findAll(LegoSet.class)).isEmpty(); + assertThat(template.findAll(Manual.class)).isEmpty(); } @Test // DATAJDBC-112 @@ -429,12 +415,8 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { template.deleteAll(LegoSet.class); - assertSoftly(softly -> { - - softly.assertThat(template.findAll(LegoSet.class)).isEmpty(); - softly.assertThat(template.findAll(Manual.class)).isEmpty(); - }); - + assertThat(template.findAll(LegoSet.class)).isEmpty(); + assertThat(template.findAll(Manual.class)).isEmpty(); } @Test // GH-537 @@ -447,11 +429,8 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { template.deleteAll(List.of(legoSet1, legoSet2)); - assertSoftly(softly -> { - - softly.assertThat(template.findAll(LegoSet.class)).extracting(l -> l.name).containsExactly("Some other Name"); - softly.assertThat(template.findAll(Manual.class)).hasSize(1); - }); + assertThat(template.findAll(LegoSet.class)).extracting(l -> l.name).containsExactly("Some other Name"); + assertThat(template.findAll(Manual.class)).hasSize(1); } @Test // GH-537 @@ -464,11 +443,8 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { template.deleteAllById(List.of(legoSet1.id, legoSet2.id), LegoSet.class); - assertSoftly(softly -> { - - softly.assertThat(template.findAll(LegoSet.class)).extracting(l -> l.name).containsExactly("Some other Name"); - softly.assertThat(template.findAll(Manual.class)).hasSize(1); - }); + assertThat(template.findAll(LegoSet.class)).extracting(l -> l.name).containsExactly("Some other Name"); + assertThat(template.findAll(Manual.class)).hasSize(1); } @Test @@ -525,9 +501,9 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { template.save(legoSet); Manual manual = new Manual(); - manual.id = (23L); - manual.content = ("Some content"); - legoSet.manual = (manual); + manual.id = 23L; + manual.content = "Some content"; + legoSet.manual = manual; template.save(legoSet); @@ -542,18 +518,14 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { template.save(legoSet); - legoSet.manual = (null); + legoSet.manual = null; template.save(legoSet); LegoSet reloadedLegoSet = template.findById(legoSet.id, LegoSet.class); - SoftAssertions softly = new SoftAssertions(); - - softly.assertThat(reloadedLegoSet.manual).isNull(); - softly.assertThat(template.findAll(Manual.class)).describedAs("Manuals failed to delete").isEmpty(); - - softly.assertAll(); + assertThat(reloadedLegoSet.manual).isNull(); + assertThat(template.findAll(Manual.class)).describedAs("Manuals failed to delete").isEmpty(); } @Test @@ -561,7 +533,7 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { void updateFailedRootDoesNotExist() { LegoSet entity = new LegoSet(); - entity.id = (100L); // does not exist in the database + entity.id = 100L; // does not exist in the database assertThatExceptionOfType(DbActionExecutionException.class) // .isThrownBy(() -> template.save(entity)) // @@ -575,18 +547,15 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { template.save(legoSet); Manual manual = new Manual(); - manual.content = ("other content"); - legoSet.manual = (manual); + manual.content = "other content"; + legoSet.manual = manual; template.save(legoSet); LegoSet reloadedLegoSet = template.findById(legoSet.id, LegoSet.class); - assertSoftly(softly -> { - - softly.assertThat(reloadedLegoSet.manual.content).isEqualTo("other content"); - softly.assertThat(template.findAll(Manual.class)).describedAs("There should be only one manual").hasSize(1); - }); + assertThat(reloadedLegoSet.manual.content).isEqualTo("other content"); + assertThat(template.findAll(Manual.class)).describedAs("There should be only one manual").hasSize(1); } @Test // DATAJDBC-112 @@ -595,7 +564,7 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { template.save(legoSet); - legoSet.manual.content = ("new content"); + legoSet.manual.content = "new content"; template.save(legoSet); @@ -678,14 +647,11 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { LegoSet reloadedLegoSet = template.findById(legoSet.id, LegoSet.class); - assertSoftly(softly -> { - - softly.assertThat(reloadedLegoSet.alternativeInstructions).isNotNull(); - softly.assertThat(reloadedLegoSet.alternativeInstructions.id).isNotNull(); - softly.assertThat(reloadedLegoSet.alternativeInstructions.id).isNotEqualTo(reloadedLegoSet.manual.id); - softly.assertThat(reloadedLegoSet.alternativeInstructions.content) - .isEqualTo(reloadedLegoSet.alternativeInstructions.content); - }); + assertThat(reloadedLegoSet.alternativeInstructions).isNotNull(); + assertThat(reloadedLegoSet.alternativeInstructions.id).isNotNull(); + assertThat(reloadedLegoSet.alternativeInstructions.id).isNotEqualTo(reloadedLegoSet.manual.id); + assertThat(reloadedLegoSet.alternativeInstructions.content) + .isEqualTo(reloadedLegoSet.alternativeInstructions.content); } @Test // DATAJDBC-276 @@ -927,14 +893,11 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { NoIdListChain4 saved = template.save(createNoIdTree()); template.deleteById(saved.four, NoIdListChain4.class); - assertSoftly(softly -> { - - softly.assertThat(count("NO_ID_LIST_CHAIN4")).describedAs("Chain4 elements got deleted").isEqualTo(0); - softly.assertThat(count("NO_ID_LIST_CHAIN3")).describedAs("Chain3 elements got deleted").isEqualTo(0); - softly.assertThat(count("NO_ID_LIST_CHAIN2")).describedAs("Chain2 elements got deleted").isEqualTo(0); - softly.assertThat(count("NO_ID_LIST_CHAIN1")).describedAs("Chain1 elements got deleted").isEqualTo(0); - softly.assertThat(count("NO_ID_LIST_CHAIN0")).describedAs("Chain0 elements got deleted").isEqualTo(0); - }); + assertThat(count("NO_ID_LIST_CHAIN4")).describedAs("Chain4 elements got deleted").isEqualTo(0); + assertThat(count("NO_ID_LIST_CHAIN3")).describedAs("Chain3 elements got deleted").isEqualTo(0); + assertThat(count("NO_ID_LIST_CHAIN2")).describedAs("Chain2 elements got deleted").isEqualTo(0); + assertThat(count("NO_ID_LIST_CHAIN1")).describedAs("Chain1 elements got deleted").isEqualTo(0); + assertThat(count("NO_ID_LIST_CHAIN0")).describedAs("Chain0 elements got deleted").isEqualTo(0); } @Test @@ -956,14 +919,11 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { NoIdMapChain4 saved = template.save(createNoIdMapTree()); template.deleteById(saved.four, NoIdMapChain4.class); - assertSoftly(softly -> { - - softly.assertThat(count("NO_ID_MAP_CHAIN4")).describedAs("Chain4 elements got deleted").isEqualTo(0); - softly.assertThat(count("NO_ID_MAP_CHAIN3")).describedAs("Chain3 elements got deleted").isEqualTo(0); - softly.assertThat(count("NO_ID_MAP_CHAIN2")).describedAs("Chain2 elements got deleted").isEqualTo(0); - softly.assertThat(count("NO_ID_MAP_CHAIN1")).describedAs("Chain1 elements got deleted").isEqualTo(0); - softly.assertThat(count("NO_ID_MAP_CHAIN0")).describedAs("Chain0 elements got deleted").isEqualTo(0); - }); + assertThat(count("NO_ID_MAP_CHAIN4")).describedAs("Chain4 elements got deleted").isEqualTo(0); + assertThat(count("NO_ID_MAP_CHAIN3")).describedAs("Chain3 elements got deleted").isEqualTo(0); + assertThat(count("NO_ID_MAP_CHAIN2")).describedAs("Chain2 elements got deleted").isEqualTo(0); + assertThat(count("NO_ID_MAP_CHAIN1")).describedAs("Chain1 elements got deleted").isEqualTo(0); + assertThat(count("NO_ID_MAP_CHAIN0")).describedAs("Chain0 elements got deleted").isEqualTo(0); } @Test // DATAJDBC-431 @@ -978,7 +938,7 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { assertThat( jdbcTemplate.queryForObject("SELECT read_only FROM with_read_only", Collections.emptyMap(), String.class)) - .isEqualTo("from-db"); + .isEqualTo("from-db"); } @Test @@ -1230,21 +1190,17 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { MultipleCollections reloaded = template.findById(aggregate.id, MultipleCollections.class); - assertSoftly(softly -> { - - softly.assertThat(reloaded.name).isEqualTo(aggregate.name); + assertThat(reloaded.name).isEqualTo(aggregate.name); - softly.assertThat(reloaded.listElements).containsExactly(aggregate.listElements.get(0), - aggregate.listElements.get(1), aggregate.listElements.get(2)); + assertThat(reloaded.listElements).containsExactly(aggregate.listElements.get(0), aggregate.listElements.get(1), + aggregate.listElements.get(2)); - softly.assertThat(reloaded.setElements) - .containsExactlyInAnyOrder(aggregate.setElements.toArray(new SetElement[0])); + assertThat(reloaded.setElements).containsExactlyInAnyOrder(aggregate.setElements.toArray(new SetElement[0])); - softly.assertThat(reloaded.mapElements.get("alpha")).isEqualTo(new MapElement("one")); - softly.assertThat(reloaded.mapElements.get("beta")).isEqualTo(new MapElement("two")); - softly.assertThat(reloaded.mapElements.get("gamma")).isEqualTo(new MapElement("three")); - softly.assertThat(reloaded.mapElements.get("delta")).isEqualTo(new MapElement("four")); - }); + assertThat(reloaded.mapElements.get("alpha")).isEqualTo(new MapElement("one")); + assertThat(reloaded.mapElements.get("beta")).isEqualTo(new MapElement("two")); + assertThat(reloaded.mapElements.get("gamma")).isEqualTo(new MapElement("three")); + assertThat(reloaded.mapElements.get("delta")).isEqualTo(new MapElement("four")); } @Test // GH-1448 @@ -1266,21 +1222,17 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { MultipleCollections reloaded = template.findById(aggregate.id, MultipleCollections.class); - assertSoftly(softly -> { - - softly.assertThat(reloaded.name).isEqualTo(aggregate.name); + assertThat(reloaded.name).isEqualTo(aggregate.name); - softly.assertThat(reloaded.listElements).containsExactly(aggregate.listElements.get(0), - aggregate.listElements.get(1), aggregate.listElements.get(2)); + assertThat(reloaded.listElements).containsExactly(aggregate.listElements.get(0), aggregate.listElements.get(1), + aggregate.listElements.get(2)); - softly.assertThat(reloaded.setElements) - .containsExactlyInAnyOrder(aggregate.setElements.toArray(new SetElement[0])); + assertThat(reloaded.setElements).containsExactlyInAnyOrder(aggregate.setElements.toArray(new SetElement[0])); - softly.assertThat(reloaded.mapElements.get("alpha")).isEqualTo(new MapElement("one")); - softly.assertThat(reloaded.mapElements.get("beta")).isEqualTo(new MapElement("two")); - softly.assertThat(reloaded.mapElements.get("gamma")).isEqualTo(new MapElement("three")); - softly.assertThat(reloaded.mapElements.get("delta")).isEqualTo(new MapElement("four")); - }); + assertThat(reloaded.mapElements.get("alpha")).isEqualTo(new MapElement("one")); + assertThat(reloaded.mapElements.get("beta")).isEqualTo(new MapElement("two")); + assertThat(reloaded.mapElements.get("gamma")).isEqualTo(new MapElement("three")); + assertThat(reloaded.mapElements.get("delta")).isEqualTo(new MapElement("four")); } @Test // GH-1448 @@ -1301,20 +1253,16 @@ abstract class AbstractJdbcAggregateTemplateIntegrationTests { MultipleCollections reloaded = template.findById(aggregate.id, MultipleCollections.class); - assertSoftly(softly -> { - - softly.assertThat(reloaded.name).isEqualTo(aggregate.name); + assertThat(reloaded.name).isEqualTo(aggregate.name); - softly.assertThat(reloaded.listElements).containsExactly(); + assertThat(reloaded.listElements).containsExactly(); - softly.assertThat(reloaded.setElements) - .containsExactlyInAnyOrder(aggregate.setElements.toArray(new SetElement[0])); + assertThat(reloaded.setElements).containsExactlyInAnyOrder(aggregate.setElements.toArray(new SetElement[0])); - softly.assertThat(reloaded.mapElements.get("alpha")).isEqualTo(new MapElement("one")); - softly.assertThat(reloaded.mapElements.get("beta")).isEqualTo(new MapElement("two")); - softly.assertThat(reloaded.mapElements.get("gamma")).isEqualTo(new MapElement("three")); - softly.assertThat(reloaded.mapElements.get("delta")).isEqualTo(new MapElement("four")); - }); + assertThat(reloaded.mapElements.get("alpha")).isEqualTo(new MapElement("one")); + assertThat(reloaded.mapElements.get("beta")).isEqualTo(new MapElement("two")); + assertThat(reloaded.mapElements.get("gamma")).isEqualTo(new MapElement("three")); + assertThat(reloaded.mapElements.get("delta")).isEqualTo(new MapElement("four")); } private void saveAndUpdateAggregateWithVersion(VersionedAggregate aggregate, diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQueryUnitTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQueryUnitTests.java index 500c3f133..30dc8d586 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQueryUnitTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQueryUnitTests.java @@ -29,7 +29,6 @@ import java.util.Set; import java.util.stream.Stream; import org.assertj.core.api.Assertions; -import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; @@ -441,7 +440,6 @@ class StringBasedJdbcQueryUnitTests { this.values = List.of(values); } - @NotNull @Override public Iterator iterator() { return values.iterator(); @@ -460,7 +458,7 @@ class StringBasedJdbcQueryUnitTests { } private static class DummyEntity { - private Long id; + private final Long id; public DummyEntity(Long id) { this.id = id; @@ -488,6 +486,7 @@ class StringBasedJdbcQueryUnitTests { } } + @Override public Object getRootObject() { return new ExtensionRoot(); } diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java index aef0f8a3d..e0749185d 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java @@ -22,8 +22,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import org.jetbrains.annotations.NotNull; - import org.springframework.data.r2dbc.convert.R2dbcConverter; import org.springframework.data.r2dbc.core.R2dbcEntityOperations; import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy; @@ -157,11 +155,8 @@ public class StringBasedR2dbcQuery extends AbstractR2dbcQuery { @Override public String toString() { - StringBuffer sb = new StringBuffer(); - sb.append(getClass().getSimpleName()); - sb.append(" [").append(expressionQuery.getQuery()); - sb.append(']'); - return sb.toString(); + String sb = getClass().getSimpleName() + " [" + expressionQuery.getQuery() + ']'; + return sb; } private class ExpandedQuery implements PreparedOperation { @@ -234,7 +229,6 @@ public class StringBasedR2dbcQuery extends AbstractR2dbcQuery { byName.put(identifier, toParameter(value)); } - @NotNull private Parameter toParameter(Object value) { return value instanceof Parameter ? (Parameter) value : Parameter.from(value); } diff --git a/spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorBenchmark.java b/spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorBenchmark.java index f8a1399b5..cb49fd72f 100644 --- a/spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorBenchmark.java +++ b/spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorBenchmark.java @@ -38,8 +38,7 @@ public class SingleQuerySqlGeneratorBenchmark extends BenchmarkSettings { @Benchmark public String findAll(StateHolder state) { - return new SingleQuerySqlGenerator(state.context, state.aliasFactory, PostgresDialect.INSTANCE, - state.persistentEntity).findAll(null); + return new SingleQuerySqlGenerator(state.context, state.aliasFactory, PostgresDialect.INSTANCE).findAll(state.persistentEntity, null); } @State(Scope.Benchmark) diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/EmbeddedRelationalPersistentEntity.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/EmbeddedRelationalPersistentEntity.java index 3bca90a2e..3c5b3f2c8 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/EmbeddedRelationalPersistentEntity.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/EmbeddedRelationalPersistentEntity.java @@ -18,7 +18,6 @@ package org.springframework.data.relational.core.mapping; import java.lang.annotation.Annotation; import java.util.Iterator; -import org.jetbrains.annotations.NotNull; import org.springframework.data.mapping.*; import org.springframework.data.mapping.model.PersistentPropertyAccessorFactory; import org.springframework.data.relational.core.sql.SqlIdentifier; @@ -67,7 +66,6 @@ class EmbeddedRelationalPersistentEntity implements RelationalPersistentEntit @Override public void verify() throws MappingException {} - @NotNull @Override public Iterator iterator() { diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java index 949dbc6aa..5e8c94024 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java @@ -20,8 +20,6 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; import org.springframework.data.mapping.PersistentProperty; import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.mapping.PersistentPropertyPaths; @@ -33,6 +31,7 @@ import org.springframework.data.relational.core.mapping.RelationalPersistentEnti import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.core.sql.*; import org.springframework.data.relational.core.sql.render.SqlRenderer; +import org.springframework.lang.Nullable; /** * A {@link SqlGenerator} that creates SQL statements for loading complete aggregates with a single statement. @@ -45,23 +44,20 @@ public class SingleQuerySqlGenerator implements SqlGenerator { private final RelationalMappingContext context; private final Dialect dialect; private final AliasFactory aliases; - private final RelationalPersistentEntity aggregate; - public SingleQuerySqlGenerator(RelationalMappingContext context, AliasFactory aliasFactory, Dialect dialect, - RelationalPersistentEntity aggregate) { + public SingleQuerySqlGenerator(RelationalMappingContext context, AliasFactory aliasFactory, Dialect dialect) { this.context = context; this.aliases = aliasFactory; this.dialect = dialect; - this.aggregate = aggregate; } @Override - public String findAll(@Nullable Condition condition) { - return createSelect(condition); + public String findAll(RelationalPersistentEntity aggregate, @Nullable Condition condition) { + return createSelect(aggregate, condition); } - String createSelect(@Nullable Condition condition) { + String createSelect(RelationalPersistentEntity aggregate, @Nullable Condition condition) { AggregatePath rootPath = context.getAggregatePath(aggregate); QueryMeta queryMeta = createInlineQuery(rootPath, condition); @@ -70,6 +66,7 @@ public class SingleQuerySqlGenerator implements SqlGenerator { List rownumbers = new ArrayList<>(); rownumbers.add(queryMeta.rowNumber); + PersistentPropertyPaths entityPaths = context .findPersistentPropertyPaths(aggregate.getType(), PersistentProperty::isEntity); List inlineQueries = createInlineQueries(entityPaths); @@ -83,44 +80,46 @@ public class SingleQuerySqlGenerator implements SqlGenerator { columns.add(totalRownumber); InlineQuery inlineQuery = createMainSelect(columns, rootPath, rootQuery, inlineQueries); + Expression rootId = just(aliases.getColumnAlias(rootPath.append(aggregate.getRequiredIdProperty()))); + + List selectList = getSelectList(queryMeta, inlineQueries, rootId); + Select fullQuery = StatementBuilder.select(selectList).from(inlineQuery).orderBy(rootId, just("rn")).build(false); - Expression rootIdExpression = just(aliases.getColumnAlias(rootPath.append(aggregate.getRequiredIdProperty()))); + return SqlRenderer.create(new RenderContextFactory(dialect).createRenderContext()).render(fullQuery); + } + + private static List getSelectList(QueryMeta queryMeta, List inlineQueries, Expression rootId) { + + List expressions = new ArrayList<>(inlineQueries.size() + queryMeta.simpleColumns.size() + 8); - List finalColumns = new ArrayList<>(); queryMeta.simpleColumns - .forEach(e -> finalColumns.add(filteredColumnExpression(queryMeta.rowNumber.toString(), e.toString()))); + .forEach(e -> expressions.add(filteredColumnExpression(queryMeta.rowNumber.toString(), e.toString()))); for (QueryMeta meta : inlineQueries) { + meta.simpleColumns - .forEach(e -> finalColumns.add(filteredColumnExpression(meta.rowNumber.toString(), e.toString()))); + .forEach(e -> expressions.add(filteredColumnExpression(meta.rowNumber.toString(), e.toString()))); + if (meta.id != null) { - finalColumns.add(meta.id); + expressions.add(meta.id); } if (meta.key != null) { - finalColumns.add(meta.key); + expressions.add(meta.key); } } - finalColumns.add(rootIdExpression); - - Select fullQuery = StatementBuilder.select(finalColumns).from(inlineQuery).orderBy(rootIdExpression, just("rn")) - .build(false); - - return SqlRenderer.create(new RenderContextFactory(dialect).createRenderContext()).render(fullQuery); + expressions.add(rootId); + return expressions; } - @NotNull private InlineQuery createMainSelect(List columns, AggregatePath rootPath, InlineQuery rootQuery, List inlineQueries) { SelectBuilder.SelectJoin select = StatementBuilder.select(columns).from(rootQuery); - select = applyJoins(rootPath, inlineQueries, select); - SelectBuilder.BuildSelect buildSelect = applyWhereCondition(rootPath, inlineQueries, select); - Select mainSelect = buildSelect.build(false); - - return InlineQuery.create(mainSelect, "main"); + SelectBuilder.BuildSelect buildSelect = applyWhereCondition(inlineQueries, select); + return InlineQuery.create(buildSelect.build(false), "main"); } /** @@ -159,16 +158,8 @@ public class SingleQuerySqlGenerator implements SqlGenerator { RelationalPersistentEntity entity = basePath.getRequiredLeafEntity(); Table table = Table.create(entity.getQualifiedTableName()); - List paths = new ArrayList<>(); - - entity.doWithProperties((RelationalPersistentProperty p) -> { - if (!p.isEntity()) { - paths.add(basePath.append(p)); - } - }); - + List paths = getAggregatePaths(basePath, entity); List columns = new ArrayList<>(); - List columnAliases = new ArrayList<>(); String rowNumberAlias = aliases.getRowNumberAlias(basePath); Expression rownumber = basePath.isRoot() ? new AliasedExpression(SQL.literalOf(1), rowNumberAlias) @@ -180,8 +171,10 @@ public class SingleQuerySqlGenerator implements SqlGenerator { : AnalyticFunction.create("count", Expressions.just("*")) .partitionBy(table.column(basePath.getTableInfo().reverseColumnInfo().name())).as(rowCountAlias); columns.add(count); + String backReferenceAlias = null; String keyAlias = null; + if (!basePath.isRoot()) { backReferenceAlias = aliases.getBackReferenceAlias(basePath); @@ -192,31 +185,57 @@ public class SingleQuerySqlGenerator implements SqlGenerator { ? table.column(basePath.getTableInfo().qualifierColumnInfo().name()).as(keyAlias) : createRowNumberExpression(basePath, table, keyAlias); columns.add(keyExpression); - } - String id = null; + String id = getIdentifierProperty(paths); + List columnAliases = getColumnAliases(table, paths, columns); + SelectBuilder.SelectWhere select = StatementBuilder.select(columns).from(table); + SelectBuilder.BuildSelect buildSelect = condition != null ? select.where(condition) : select; + + InlineQuery inlineQuery = InlineQuery.create(buildSelect.build(false), aliases.getTableAlias(basePath)); + return QueryMeta.of(basePath, inlineQuery, columnAliases, just(id), just(backReferenceAlias), just(keyAlias), + just(rowNumberAlias), just(rowCountAlias)); + } + + private List getColumnAliases(Table table, List paths, List columns) { + + List columnAliases = new ArrayList<>(); for (AggregatePath path : paths) { String alias = aliases.getColumnAlias(path); - if (path.getRequiredLeafProperty().isIdProperty()) { - id = alias; - } else { + if (!path.getRequiredLeafProperty().isIdProperty()) { columnAliases.add(just(alias)); } columns.add(table.column(path.getColumnInfo().name()).as(alias)); } + return columnAliases; + } - SelectBuilder.SelectWhere select = StatementBuilder.select(columns).from(table); + private static List getAggregatePaths(AggregatePath basePath, RelationalPersistentEntity entity) { - SelectBuilder.BuildSelect buildSelect = condition != null ? select.where(condition) : select; + List paths = new ArrayList<>(); - InlineQuery inlineQuery = InlineQuery.create(buildSelect.build(false), aliases.getTableAlias(basePath)); - return QueryMeta.of(basePath, inlineQuery, columnAliases, just(id), just(backReferenceAlias), just(keyAlias), - just(rowNumberAlias), just(rowCountAlias)); + for (RelationalPersistentProperty property : entity) { + if (!property.isEntity()) { + paths.add(basePath.append(property)); + } + } + + return paths; + } + + @Nullable + private String getIdentifierProperty(List paths) { + + for (AggregatePath path : paths) { + if (path.getRequiredLeafProperty().isIdProperty()) { + return aliases.getColumnAlias(path); + } + } + + return null; } - @NotNull private static AnalyticFunction createRowNumberExpression(AggregatePath basePath, Table table, String rowNumberAlias) { return AnalyticFunction.create("row_number") // @@ -261,17 +280,20 @@ public class SingleQuerySqlGenerator implements SqlGenerator { * null (when there is no child elements at all) or the values for rownumber 1 are used for that child * * - * @param rootPath path to the root entity that gets selected. * @param inlineQueries all in the inline queries for all the children, as returned by * {@link #createInlineQueries(PersistentPropertyPaths)} * @param select the select to which the where clause gets added. * @return the modified select. */ - private SelectBuilder.SelectOrdered applyWhereCondition(AggregatePath rootPath, List inlineQueries, + private SelectBuilder.SelectOrdered applyWhereCondition(List inlineQueries, SelectBuilder.SelectJoin select) { SelectBuilder.SelectWhere selectWhere = (SelectBuilder.SelectWhere) select; + if (inlineQueries.isEmpty()) { + return selectWhere; + } + Condition joins = null; for (int left = 0; left < inlineQueries.size(); left++) { @@ -288,8 +310,6 @@ public class SingleQuerySqlGenerator implements SqlGenerator { Expression rightRowNumber = just(aliases.getRowNumberAlias(rightPath)); Expression rightRowCount = just(aliases.getRowCountAlias(rightPath)); - System.out.println("joining: " + leftPath + " and " + rightPath); - Condition mutualJoin = Conditions.isEqual(leftRowNumber, rightRowNumber).or(Conditions.isNull(leftRowNumber)) .or(Conditions.isNull(rightRowNumber)) .or(Conditions.nest(Conditions.isGreater(leftRowNumber, rightRowCount) @@ -307,18 +327,7 @@ public class SingleQuerySqlGenerator implements SqlGenerator { } } - // for (QueryMeta queryMeta : inlineQueries) { - // - // AggregatePath path = queryMeta.basePath; - // Expression childRowNumber = just(aliases.getRowNumberAlias(path)); - // Condition pseudoJoinCondition = Conditions.isNull(childRowNumber) - // .or(Conditions.isEqual(childRowNumber, Expressions.just(aliases.getRowNumberAlias(rootPath)))) - // .or(Conditions.isGreater(childRowNumber, Expressions.just(aliases.getRowCountAlias(rootPath)))); - // - selectWhere = (SelectBuilder.SelectWhere) selectWhere.where(joins); - // } - - return selectWhere == null ? (SelectBuilder.SelectOrdered) select : selectWhere; + return selectWhere.where(joins); } @Override @@ -430,6 +439,5 @@ public class SingleQuerySqlGenerator implements SqlGenerator { return new QueryMeta(basePath, inlineQuery, simpleColumns, selectableExpressions, id, backReference, key, rowNumber, rowCount); } - } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java index fe783882a..38f2827d7 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java @@ -15,6 +15,7 @@ */ package org.springframework.data.relational.core.sqlgeneration; +import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.sql.Condition; import org.springframework.lang.Nullable; @@ -26,11 +27,11 @@ import org.springframework.lang.Nullable; */ public interface SqlGenerator { - default String findAll() { - return findAll(null); + default String findAll(RelationalPersistentEntity aggregate) { + return findAll(aggregate, null); } - String findAll(@Nullable Condition condition); + String findAll(RelationalPersistentEntity aggregate, @Nullable Condition condition); AliasFactory getAliasFactory(); } diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java index 2f8a86598..6185d30a6 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java @@ -52,7 +52,7 @@ class SingleQuerySqlGeneratorUnitTests { @Test // GH-1446 void createSelectForFindAll() { - String sql = sqlGenerator.findAll(); + String sql = sqlGenerator.findAll(persistentEntity); SqlAssert fullSelect = assertThatParsed(sql); fullSelect.extractOrderBy().isEqualTo(alias("id") + ", rn"); @@ -79,7 +79,7 @@ class SingleQuerySqlGeneratorUnitTests { void createSelectForFindById() { Table table = Table.create(persistentEntity.getQualifiedTableName()); - String sql = sqlGenerator.findAll(table.column("id").isEqualTo(Conditions.just(":id"))); + String sql = sqlGenerator.findAll(persistentEntity, table.column("id").isEqualTo(Conditions.just(":id"))); SqlAssert baseSelect = assertThatParsed(sql).hasInlineView(); @@ -104,7 +104,7 @@ class SingleQuerySqlGeneratorUnitTests { void createSelectForFindAllById() { Table table = Table.create(persistentEntity.getQualifiedTableName()); - String sql = sqlGenerator.findAll(table.column("id").in(Conditions.just(":ids"))); + String sql = sqlGenerator.findAll(persistentEntity, table.column("id").in(Conditions.just(":ids"))); SqlAssert baseSelect = assertThatParsed(sql).hasInlineView(); @@ -138,7 +138,7 @@ class SingleQuerySqlGeneratorUnitTests { void createSelectForFindById() { Table table = Table.create(persistentEntity.getQualifiedTableName()); - String sql = sqlGenerator.findAll(table.column("id").isEqualTo(Conditions.just(":id"))); + String sql = sqlGenerator.findAll(persistentEntity, table.column("id").isEqualTo(Conditions.just(":id"))); String rootRowNumber = rnAlias(); String rootCount = rcAlias(); @@ -161,8 +161,7 @@ class SingleQuerySqlGeneratorUnitTests { func("coalesce", col(trivialsRowNumber), lit(1))), // col(backref), // col(keyAlias) // - ) - .extractWhereClause() // + ).extractWhereClause() // .isEqualTo(""); baseSelect.hasInlineViewSelectingFrom("\"single_reference_aggregate\"") // .hasExactlyColumns( // @@ -216,7 +215,7 @@ class SingleQuerySqlGeneratorUnitTests { this.aggregateRootType = aggregateRootType; this.persistentEntity = context.getRequiredPersistentEntity(aggregateRootType); - this.sqlGenerator = new SingleQuerySqlGenerator(context, new AliasFactory(), dialect, persistentEntity); + this.sqlGenerator = new SingleQuerySqlGenerator(context, new AliasFactory(), dialect); this.aliases = sqlGenerator.getAliasFactory(); }