From 9dbca5b8a54cfc3e9cf3e8ea7fd1ae7ad3d7687d Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Thu, 19 Mar 2026 09:08:35 +0100 Subject: [PATCH] Move over to UpsertRenderer. --- .../data/jdbc/core/JdbcAggregateTemplate.java | 2 +- .../convert/CascadingDataAccessStrategy.java | 8 +- .../jdbc/core/convert/DataAccessStrategy.java | 8 +- .../convert/DefaultDataAccessStrategy.java | 9 +- .../convert/DelegatingDataAccessStrategy.java | 4 +- .../mybatis/MyBatisDataAccessStrategy.java | 2 +- .../sql/render/ConflictColumnCollector.java | 16 +- .../sql/render/MySqlUpsertRenderContext.java | 39 +-- .../sql/render/OracleUpsertRenderContext.java | 59 +---- .../render/PostgresUpsertRenderContext.java | 49 +--- .../core/sql/render/SqlRenderer.java | 1 + .../render/SqlServerUpsertRenderContext.java | 14 +- .../StandardSqlUpsertRenderContext.java | 63 +---- .../core/sql/render/UpsertRenderContext.java | 61 +---- .../sql/render/UpsertStatementRenderer.java | 227 +++++++++++++++++ .../sql/render/UpsertStatementRenderers.java | 230 ++++++++++++++++++ .../sql/render/UpsertStatementVisitor.java | 47 +--- .../data/relational/DependencyTests.java | 1 + ...andardSqlUpsertRenderContextUnitTests.java | 20 +- .../render/UpsertRenderContextUnitTests.java | 66 ++--- 20 files changed, 556 insertions(+), 370 deletions(-) create mode 100644 spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementRenderer.java create mode 100644 spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementRenderers.java diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java index b4ee10bd3..984fc5a6a 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java @@ -274,7 +274,7 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations, Applicati Assert.notNull(instance, "Aggregate instance must not be null"); Class entityType = (Class) ClassUtils.getUserClass(instance); - accessStrategy.upsert(instance, entityType, Identifier.empty()); + accessStrategy.upsert(instance, entityType); return instance; } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java index 565579802..e274c9671 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java @@ -15,7 +15,7 @@ */ package org.springframework.data.jdbc.core.convert; -import static java.lang.Boolean.*; +import static java.lang.Boolean.TRUE; import java.util.ArrayList; import java.util.List; @@ -25,7 +25,6 @@ import java.util.function.Function; import java.util.stream.Stream; import org.jspecify.annotations.Nullable; - import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.mapping.PersistentPropertyPath; @@ -49,6 +48,7 @@ import org.springframework.util.Assert; * @author Chirag Tailor * @author Diego Krupitza * @author Sergey Korotaev + * @author Christoph Strobl * @since 1.1 */ public class CascadingDataAccessStrategy implements DataAccessStrategy { @@ -88,8 +88,8 @@ public class CascadingDataAccessStrategy implements DataAccessStrategy { } @Override - public int upsert(T instance, Class domainType, Identifier identifier) { - return collect(das -> das.upsert(instance, domainType, identifier)); + public int upsert(T instance, Class domainType) { + return collect(das -> das.upsert(instance, domainType)); } @Override diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java index 4820e961c..8c1217666 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java @@ -45,6 +45,7 @@ import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; * @author Chirag Tailor * @author Diego Krupitza * @author Sergey Korotaev + * @author Christoph Strobl */ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationResolver { @@ -119,18 +120,17 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR boolean updateWithVersion(T instance, Class domainType, Number previousVersion); /** - * Upserts the data of a single entity (insert if row for id does not exist, update if it exists). Requires a - * provided id. Only supported when the dialect supports single-statement upsert. + * Upserts the data of a single entity (insert if row for id does not exist, update if it exists). Requires the + * instance to hold an id. Only supported when the dialect supports single-statement upsert. * * @param instance the instance to upsert. Must not be {@code null}. Must have an id set. * @param domainType the type of the instance. Must not be {@code null}. - * @param identifier information about data that needs to be considered (e.g. back-references). May be empty for root. * @param the type of the instance. * @return the number of rows affected by the upsert. * @throws UnsupportedOperationException if the dialect does not support upsert. * @since 4.x */ - int upsert(T instance, Class domainType, Identifier identifier); + int upsert(T instance, Class domainType); /** * Deletes a single row identified by the id, from the table identified by the domainType. Does not handle cascading diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java index be99cbc64..dd171b54b 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java @@ -15,7 +15,7 @@ */ package org.springframework.data.jdbc.core.convert; -import static org.springframework.data.jdbc.core.convert.SqlGenerator.*; +import static org.springframework.data.jdbc.core.convert.SqlGenerator.VERSION_SQL_PARAMETER; import java.sql.ResultSet; import java.sql.SQLException; @@ -27,7 +27,6 @@ import java.util.stream.Stream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; - import org.springframework.dao.EmptyResultDataAccessException; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; @@ -184,14 +183,14 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy { } @Override - public int upsert(T instance, Class domainType, Identifier identifier) { + public int upsert(T instance, Class domainType) { - SqlIdentifierParameterSource parameterSource = sqlParametersFactory.forInsert(instance, domainType, identifier, + SqlIdentifierParameterSource parameterSource = sqlParametersFactory.forInsert(instance, domainType, Identifier.empty(), IdValueSource.PROVIDED); String statement = sql(domainType).getUpsert(parameterSource.getIdentifiers()); - if(logger.isTraceEnabled()) { + if (logger.isTraceEnabled()) { logger.trace("Upsert: [%s]".formatted(statement)); } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java index a271e60dc..9abf9f7d4 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java @@ -81,8 +81,8 @@ public class DelegatingDataAccessStrategy implements DataAccessStrategy { } @Override - public int upsert(T instance, Class domainType, Identifier identifier) { - return delegate.upsert(instance, domainType, identifier); + public int upsert(T instance, Class domainType) { + return delegate.upsert(instance, domainType); } @Override diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java index 6d542615b..921b0f1f8 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java @@ -185,7 +185,7 @@ public class MyBatisDataAccessStrategy implements DataAccessStrategy { } @Override - public int upsert(T instance, Class domainType, Identifier identifier) { + public int upsert(T instance, Class domainType) { throw new UnsupportedOperationException("Upsert is not supported by MyBatisDataAccessStrategy"); } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/ConflictColumnCollector.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/ConflictColumnCollector.java index a6f7de610..306b1b4e8 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/ConflictColumnCollector.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/ConflictColumnCollector.java @@ -22,27 +22,25 @@ import org.springframework.data.relational.core.sql.Column; import org.springframework.data.relational.core.sql.Comparison; import org.springframework.data.relational.core.sql.Condition; import org.springframework.data.relational.core.sql.MultipleCondition; -import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.data.relational.core.sql.Visitable; import org.springframework.data.relational.core.sql.Visitor; /** - * Collects conflict columns from a {@link Condition} by traversing equality comparisons. - * For {@link Comparison} with {@code =} and a {@link Column} on the left, the column name is collected. - * For {@link MultipleCondition} (e.g. AND), recurses into child conditions. + * Collects conflict columns from a {@link Condition} by traversing equality comparisons. For {@link Comparison} with + * {@code =} and a {@link Column} on the left, the column name is collected. For {@link MultipleCondition} (e.g. AND), + * recurses into child conditions. * * @since 4.x */ final class ConflictColumnCollector implements Visitor { - private final List conflictColumns = new ArrayList<>(); + private final List conflictColumns = new ArrayList<>(); @Override public void enter(Visitable segment) { - if (segment instanceof Comparison comparison && "=".equals(comparison.getComparator()) - && comparison.getLeft() instanceof Column column) { - conflictColumns.add(column.getName()); + if (segment instanceof Comparison comparison && comparison.getLeft() instanceof Column column) { + conflictColumns.add(column); } if (segment instanceof MultipleCondition multiple) { @@ -52,7 +50,7 @@ final class ConflictColumnCollector implements Visitor { } } - List getConflictColumns() { + List getConflictColumns() { return conflictColumns; } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/MySqlUpsertRenderContext.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/MySqlUpsertRenderContext.java index ab2fa1b47..3b2fd1de7 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/MySqlUpsertRenderContext.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/MySqlUpsertRenderContext.java @@ -15,14 +15,6 @@ */ package org.springframework.data.relational.core.sql.render; -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; - -import org.springframework.data.relational.core.sql.SqlIdentifier; -import org.springframework.data.relational.core.sql.Table; -import org.springframework.util.Assert; - /** * MySQL / MariaDB upsert using {@code INSERT ... ON DUPLICATE KEY UPDATE}. * @@ -34,34 +26,7 @@ public enum MySqlUpsertRenderContext implements UpsertRenderContext { INSTANCE; @Override - public String renderUpsert(Table table, Columns columns, Function bindMarkerFn) { - - Assert.notEmpty(columns.insertColumns(), "Insert columns must not be empty"); - Assert.notEmpty(columns.filterColumns(), "Filter columns must not be empty"); - - String tableName = columns.tableName(table); - String columnNames = String.join(", ", columns.insertColumnNames()); - String bindMarkers = String.join(", ", columns.insertColumnBindMarkers(bindMarkerFn)); - String setValues = setValuesSnippet(columns); - - return "INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE %s".formatted( // - tableName, // - columnNames, // - bindMarkers, // - setValues); - } - - private static String setValuesSnippet(Columns columns) { - - List updateColumns = columns.updateColumns(); - - if (updateColumns.isEmpty()) { - updateColumns = columns.filterColumns(); - } - - return updateColumns.stream().map(col -> { - String colName = col.toSql(columns.identifierProcessing()); - return "%s = VALUES(%s)".formatted(colName, colName); - }).collect(Collectors.joining(", ")); + public UpsertStatementRenderer renderer() { + return UpsertStatementRenderer.mySql(); } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/OracleUpsertRenderContext.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/OracleUpsertRenderContext.java index 3e218e041..79d3d1807 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/OracleUpsertRenderContext.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/OracleUpsertRenderContext.java @@ -15,13 +15,6 @@ */ package org.springframework.data.relational.core.sql.render; -import java.util.List; -import java.util.function.Function; - -import org.springframework.data.relational.core.sql.SqlIdentifier; -import org.springframework.data.relational.core.sql.Table; -import org.springframework.util.Assert; - /** * Oracle MERGE upsert. Uses {@code SELECT ... FROM DUAL} for source values. * @@ -32,55 +25,7 @@ public enum OracleUpsertRenderContext implements UpsertRenderContext { INSTANCE; @Override - public String renderUpsert(Table table, Columns columns, Function bindMarkerFn) { - - Assert.notEmpty(columns.insertColumns(), "Insert columns must not be empty"); - Assert.notEmpty(columns.filterColumns(), "Filter columns must not be empty"); - - String targetTableAlias = columns.identifierProcessing().quote(StandardSqlUpsertRenderContext.targetTableAlias); - String sourceTableAlias = columns.identifierProcessing().quote(StandardSqlUpsertRenderContext.sourceTableAlias); - - String tableName = columns.tableName(table); - String insertColumnNames = String.join(", ", columns.insertColumnNames()); - String sourceSelectList = String.join(", ", - columns.insertColumns().stream().map(col -> bindMarkerFn.apply(col) + " AS " + columns.column(col)).toList()); - - String onCondition = String.join(" AND ", columns.filterColumns().stream().map(col -> { - String colName = columns.column(col); - return "%s.%s = %s.%s".formatted(targetTableAlias, colName, sourceTableAlias, colName); - }).toList()); - - String insertValuesSql = String.join(", ", - columns.insertColumns().stream().map(col -> columns.column(sourceTableAlias, col)).toList()); - - String insertClause = "WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s)".formatted(insertColumnNames, - insertValuesSql); - - List updateColumns = columns.updateColumns(); - if (updateColumns.isEmpty()) { - // ORA-38104: columns referenced in ON cannot be updated; omit WHEN MATCHED so existing rows are left - // unchanged (same as a no-op update of key-only columns). - return "MERGE INTO %s %s USING (SELECT %s FROM DUAL) %s ON (%s) %s".formatted( // - tableName, // - targetTableAlias, // - sourceSelectList, // - sourceTableAlias, // - onCondition, // - insertClause); - } - - String updateSetClause = String.join(", ", updateColumns.stream().map(col -> { - String colName = columns.column(col); - return "%s.%s = %s.%s".formatted(targetTableAlias, colName, sourceTableAlias, colName); - }).toList()); - - return "MERGE INTO %s %s USING (SELECT %s FROM DUAL) %s ON (%s) WHEN MATCHED THEN UPDATE SET %s %s".formatted( // - tableName, // - targetTableAlias, // - sourceSelectList, // - sourceTableAlias, // - onCondition, // - updateSetClause, // - insertClause); + public UpsertStatementRenderer renderer() { + return UpsertStatementRenderer.oracle(); } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/PostgresUpsertRenderContext.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/PostgresUpsertRenderContext.java index 33cd00644..9349a975b 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/PostgresUpsertRenderContext.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/PostgresUpsertRenderContext.java @@ -15,14 +15,6 @@ */ package org.springframework.data.relational.core.sql.render; -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; - -import org.springframework.data.relational.core.sql.SqlIdentifier; -import org.springframework.data.relational.core.sql.Table; -import org.springframework.util.Assert; - /** * PostgreSQL upsert using {@code INSERT ... ON CONFLICT ... DO UPDATE SET}. * @@ -33,44 +25,7 @@ public enum PostgresUpsertRenderContext implements UpsertRenderContext { INSTANCE; @Override - public String renderUpsert(Table table, Columns columns, Function bindMarkerFn) { - - Assert.notEmpty(columns.insertColumns(), "Insert columns must not be empty"); - Assert.notEmpty(columns.filterColumns(), "Filter columns must not be empty"); - - String tableName = columns.tableName(table); - String insertColumnNames = String.join(", ", columns.insertColumnNames()); - String bindMarkers = String.join(", ", columns.insertColumnBindMarkers(bindMarkerFn)); - String filterColumnNames = String.join(", ", columns.filterColumnNames()); - - if(columns.updateColumns().isEmpty()) { - return "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO NOTHING".formatted(// - tableName, // - insertColumnNames, // - bindMarkers, // - filterColumnNames); - } - - String setValues = setValuesSnippet(columns); - return "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO UPDATE SET %s".formatted(// - tableName, // - insertColumnNames, // - bindMarkers, // - filterColumnNames, // - setValues); - } - - private static String setValuesSnippet(Columns columns) { - - List updateColumns = columns.updateColumns(); - - if (updateColumns.isEmpty()) { - updateColumns = columns.filterColumns(); - } - - return updateColumns.stream().map(col -> { - String colName = col.toSql(columns.identifierProcessing()); - return "%s = EXCLUDED.%s".formatted(colName, colName); - }).collect(Collectors.joining(", ")); + public UpsertStatementRenderer renderer() { + return UpsertStatementRenderer.postgres(); } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/SqlRenderer.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/SqlRenderer.java index 097dd61bb..d143bbd57 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/SqlRenderer.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/SqlRenderer.java @@ -27,6 +27,7 @@ import org.springframework.util.Assert; * * @author Mark Paluch * @author Jens Schauder + * @author Christoph Strobl * @since 1.1 * @see RenderContext */ diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/SqlServerUpsertRenderContext.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/SqlServerUpsertRenderContext.java index 906b673d5..a035d47b5 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/SqlServerUpsertRenderContext.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/SqlServerUpsertRenderContext.java @@ -15,13 +15,8 @@ */ package org.springframework.data.relational.core.sql.render; -import java.util.function.Function; - -import org.springframework.data.relational.core.sql.SqlIdentifier; -import org.springframework.data.relational.core.sql.Table; - /** - * SQL Server MERGE upsert. Delegates to {@link StandardSqlUpsertRenderContext} and appends a required semicolon. + * SQL Server MERGE upsert. Delegates to {@link UpsertStatementRenderers.StandardSql} and appends a required semicolon. * * @since 4.x */ @@ -29,11 +24,8 @@ public enum SqlServerUpsertRenderContext implements UpsertRenderContext { INSTANCE; - private static final String STATEMENT_TERMINATOR = ";"; - @Override - public String renderUpsert(Table table, Columns merge, Function bindMarkerFn) { - return StandardSqlUpsertRenderContext.INSTANCE.renderUpsert(table, merge, - bindMarkerFn) + STATEMENT_TERMINATOR; + public UpsertStatementRenderer renderer() { + return UpsertStatementRenderer.sqlServer(); } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/StandardSqlUpsertRenderContext.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/StandardSqlUpsertRenderContext.java index c54d8e007..09d4567d8 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/StandardSqlUpsertRenderContext.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/StandardSqlUpsertRenderContext.java @@ -15,13 +15,6 @@ */ package org.springframework.data.relational.core.sql.render; -import java.util.List; -import java.util.function.Function; - -import org.springframework.data.relational.core.sql.SqlIdentifier; -import org.springframework.data.relational.core.sql.Table; -import org.springframework.util.Assert; - /** * Standard SQL {@code MERGE} upsert for dialects that support it (like H2, HSQLDB, SQL Server, DB2). *

@@ -33,60 +26,8 @@ public enum StandardSqlUpsertRenderContext implements UpsertRenderContext { INSTANCE; - static final String targetTableAlias = "_t"; - static final String sourceTableAlias = "_s"; - @Override - public String renderUpsert(Table table, Columns columns, Function bindMarkerFn) { - - Assert.notEmpty(columns.insertColumns(), "Insert columns must not be empty"); - Assert.notEmpty(columns.filterColumns(), "Filter columns must not be empty"); - - String targetTableAlias = columns.identifierProcessing().quote(StandardSqlUpsertRenderContext.targetTableAlias); - String sourceTableAlias = columns.identifierProcessing().quote(StandardSqlUpsertRenderContext.sourceTableAlias); - - String tableName = columns.tableName(table); - String insertColumnNames = String.join(", ", columns.insertColumnNames()); - String bindMarkers = String.join(", ", columns.insertColumnBindMarkers(bindMarkerFn)); - - String onCondition = String.join(" AND ", columns.filterColumns().stream().map(col -> { - String colName = columns.column(col); - return "%s.%s = %s.%s".formatted(targetTableAlias, colName, sourceTableAlias, colName); - }).toList()); - - String insertValuesSql = String.join(", ", - columns.insertColumns().stream().map(col -> columns.column(sourceTableAlias, col)).toList()); - - String insertClause = "WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s)".formatted(insertColumnNames, - insertValuesSql); - - List updateColumns = columns.updateColumns(); - if (updateColumns.isEmpty()) { - // Matched rows are left unchanged. Updating only key columns is invalid on SQL Server (identity) and Oracle - // (ORA-38104). - return "MERGE INTO %s %s USING (VALUES (%s)) AS %s (%s) ON %s %s".formatted( // - tableName, // - targetTableAlias, // - bindMarkers, // - sourceTableAlias, // - insertColumnNames, // - onCondition, // - insertClause); - } - - String updateSetClause = String.join(", ", updateColumns.stream().map(col -> { - String colName = columns.column(col); - return "%s.%s = %s.%s".formatted(targetTableAlias, colName, sourceTableAlias, colName); - }).toList()); - - return "MERGE INTO %s %s USING (VALUES (%s)) AS %s (%s) ON %s WHEN MATCHED THEN UPDATE SET %s %s".formatted( // - tableName, // - targetTableAlias, // - bindMarkers, // - sourceTableAlias, // - insertColumnNames, // - onCondition, // - updateSetClause, // - insertClause); + public UpsertStatementRenderer renderer() { + return new UpsertStatementRenderers.StandardSql(); } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertRenderContext.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertRenderContext.java index 172b783f6..5613861aa 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertRenderContext.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertRenderContext.java @@ -15,70 +15,13 @@ */ package org.springframework.data.relational.core.sql.render; -import java.util.List; -import java.util.function.Function; - -import org.springframework.data.relational.core.sql.IdentifierProcessing; -import org.springframework.data.relational.core.sql.SqlIdentifier; -import org.springframework.data.relational.core.sql.Table; - /** - * Encapsulates dialect-specific rendering of a single-statement upsert (insert or update by id). Implementations - * produce vendor-specific SQL such as {@code INSERT ... ON CONFLICT ... DO UPDATE}, - * {@code INSERT ... ON DUPLICATE KEY UPDATE}, or standard {@code MERGE}. + * {@link UpsertStatementRenderers}. * * @since 4.x */ public interface UpsertRenderContext { - /** - * Render a full upsert statement. - * - * @param table the target table. - * @param columns the merge operation. - * @param bindMarkerFn function from column name to bind marker placeholder (e.g. {@code "id" -> ":id"}). - * @return the full upsert SQL statement. - */ - String renderUpsert(Table table, Columns columns, Function bindMarkerFn); - - /** - * @param insertColumns column names for INSERT (order preserved for VALUES clause). - * @param filterColumns columns that define the query for existing records (e.g. primary key). - * @param identifierProcessing identifier processing for rendering table and column names to SQL. - */ - record Columns(List insertColumns, List filterColumns, - IdentifierProcessing identifierProcessing) { - - String tableName(Table table) { - return table.getName().toSql(identifierProcessing); - } - - List insertColumnNames() { - return insertColumns.stream().map(this::column).toList(); - } - - List filterColumnNames(String tableAlias) { - return filterColumns.stream().map(col -> tableAlias + "." + column(col)).toList(); - } - - List filterColumnNames() { - return filterColumns.stream().map(this::column).toList(); - } - - List insertColumnBindMarkers(Function bindMarkerFn) { - return insertColumns.stream().map(bindMarkerFn).toList(); - } - - List updateColumns() { - return insertColumns.stream().filter(col -> !filterColumns.contains(col)).toList(); - } - - String column(String tableAlias, SqlIdentifier column) { - return tableAlias + "." + column(column); - } + UpsertStatementRenderer renderer(); - String column(SqlIdentifier column) { - return column.toSql(identifierProcessing); - } - } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementRenderer.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementRenderer.java new file mode 100644 index 000000000..f4e104423 --- /dev/null +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementRenderer.java @@ -0,0 +1,227 @@ +/* + * Copyright 2026-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.relational.core.sql.render; + +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collector; + +import org.springframework.data.relational.core.sql.Aliased; +import org.springframework.data.relational.core.sql.Column; +import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.data.relational.core.sql.Table; +import org.springframework.data.relational.core.sql.render.UpsertStatementRenderers.MySql; +import org.springframework.data.relational.core.sql.render.UpsertStatementRenderers.Oracle; +import org.springframework.data.relational.core.sql.render.UpsertStatementRenderers.Postgres; +import org.springframework.data.relational.core.sql.render.UpsertStatementRenderers.SqlServer; +import org.springframework.data.relational.core.sql.render.UpsertStatementRenderers.StandardSql; + +/** + * Dialect-specific upsert SQL as a single statement string (e.g. {@code MERGE}, {@code INSERT ... ON CONFLICT}, + * {@code INSERT ... ON DUPLICATE KEY UPDATE}). Callers resolve {@link Column}s and {@link Table}; implementations only + * assemble syntax and use {@link UpsertRenderingContext} so names and bind markers match the enclosing + * {@link RenderContext}. Concrete renderers are defined in {@link UpsertStatementRenderers}. + * + * @author Christoph Strobl + * @since 4.x + */ +public interface UpsertStatementRenderer { + + static UpsertStatementRenderer standardSql() { + return StandardSql.INSTANCE; + } + + static UpsertStatementRenderer mySql() { + return MySql.INSTANCE; + } + + static UpsertStatementRenderer oracle() { + return Oracle.INSTANCE; + } + + static UpsertStatementRenderer postgres() { + return Postgres.INSTANCE; + } + + static UpsertStatementRenderer sqlServer() { + return SqlServer.INSTANCE; + } + + /** + * Render the full upsert statement for {@code table}. + * + * @param table target table + * @param columns {@link Columns#insertColumns()} values to insert; {@link Columns#conflictColumns()} keys that + * identify an existing row for the dialect's conflict/merge semantics + * @param ctx rendering hooks (quoting, bind markers) tied to the current {@link RenderContext} + * @return executable upsert SQL text (parameter placeholders as produced by {@code ctx}) + */ + String render(Table table, Columns columns, UpsertRenderingContext ctx); + + /** + * Building blocks for {@link UpsertStatementRenderer}. + */ + interface UpsertRenderingContext { + + /** + * Backs upsert rendering with {@code renderContext} (quoting, bind marker style). + * + * @param renderContext active SQL render context + * @return context passed to {@link UpsertStatementRenderer#render} + */ + static UpsertRenderingContext of(RenderContext renderContext) { + return () -> renderContext; + } + + /** @return render context */ + RenderContext renderContext(); + + /** @return rendered table reference */ + default CharSequence tableName(Table table) { + return NameRenderer.render(renderContext(), table); + } + + /** @return rendered column reference without a table qualifier */ + default CharSequence columnName(Column column) { + return columnName(SqlIdentifier.EMPTY, column); + } + + /** @return {@code column} rendered with {@link Aliased#getAlias()} as qualifier */ + default CharSequence columnName(Aliased table, Column column) { + return columnName(table.getAlias(), column); + } + + /** + * @param tableAlias table or empty; if empty, column only, else {@code alias.column} + * @return rendered column reference + */ + default CharSequence columnName(SqlIdentifier tableAlias, Column column) { + if (tableAlias.equals(SqlIdentifier.EMPTY)) { + return NameRenderer.render(renderContext(), column); + } + return "%s.%s".formatted(NameRenderer.render(renderContext(), tableAlias), + NameRenderer.render(renderContext(), column)); + } + + /** @return each column name rendered (unqualified) and collected (e.g. comma-separated) */ + default CharSequence columnNames(List columns, + Collector collector) { + return columnNames(SqlIdentifier.EMPTY, columns, collector); + } + + /** @return like {@link #columnNames(List, Collector)} but with {@code tableAlias} on each column */ + default CharSequence columnNames(SqlIdentifier tableAlias, List columns, + Collector collector) { + return columns.stream().map(column -> columnName(tableAlias, column)).collect(collector); + } + + /** @return {@code :reference} bind marker from {@link Column#getName()} */ + default CharSequence bindMarker(Column column) { + return bindMarker(column, (columnName, bindMarker) -> bindMarker); + } + + /** + * @param bindMarkerFn receives rendered column name and default {@code :reference} marker; returns fragment to + * embed + * @return result of {@code bindMarkerFn} + */ + default CharSequence bindMarker(Column column, BiFunction bindMarkerFn) { + return bindMarkerFn.apply(columnName(column), ":%s".formatted(column.getName().getReference())); + } + + /** @return bind marker per column, collected */ + default CharSequence bindMarkers(List columns, + Collector collector) { + return columns.stream().map(column -> bindMarker(column, (columnName, bindMarker) -> bindMarker)) + .collect(collector); + } + + /** @return bind markers using {@code bindMarkerFn} per column, collected */ + default CharSequence bindMarkers(List columns, + BiFunction bindMarkerFn, + Collector collector) { + return columns.stream().map(column -> bindMarker(column, bindMarkerFn)).collect(collector); + } + + /** @return {@code targetColumn = sourceColumn} for the given aliases */ + default CharSequence assignment(SqlIdentifier targetTableAlias, Column column, SqlIdentifier sourceTableAlias) { + return assignment(targetTableAlias, column, sourceTableAlias, Function.identity()); + } + + /** + * @param sourceValueFn transforms the rendered source column reference (e.g. wrap in a function call) + * @return {@code targetColumn =} {@code sourceValueFn(sourceColumn)} + */ + default CharSequence assignment(SqlIdentifier targetTableAlias, Column column, SqlIdentifier sourceTableAlias, + Function sourceValueFn) { + + CharSequence targetColumn = columnName(targetTableAlias, column); + CharSequence sourceColumn = columnName(sourceTableAlias, column); + return "%s = %s".formatted(targetColumn, sourceValueFn.apply(sourceColumn)); + } + + /** @return one assignment per column, collected */ + default CharSequence assignments(SqlIdentifier targetTableAlias, List columns, + SqlIdentifier sourceTableAlias, Collector collector) { + return assignments(targetTableAlias, columns, sourceTableAlias, Function.identity(), collector); + } + + /** @return assignments with {@code sourceValueFn} applied to each source side, collected */ + default CharSequence assignments(SqlIdentifier targetTableAlias, List columns, + SqlIdentifier sourceTableAlias, Function sourceValueFn, + Collector collector) { + return columns.stream().map(column -> assignment(targetTableAlias, column, sourceTableAlias, sourceValueFn)) + .collect(collector); + } + } + + final class Columns { + + private final List insertColumns; + private final List conflictColumns; + private final List updateColumns; + + public Columns(List insertColumns, List conflictColumns) { + + this.insertColumns = insertColumns; + this.conflictColumns = conflictColumns; + this.updateColumns = insertColumns.stream() + .filter(col -> conflictColumns.stream().noneMatch(it -> it.getName().equals(col.getName()))).toList(); + } + + /** + * Columns to assign on update. + */ + List updateColumns() { + return updateColumns; + } + + /** + * Columns insert. + */ + public List insertColumns() { + return insertColumns; + } + + /** + * Columns defining the conflict condition. + */ + public List conflictColumns() { + return conflictColumns; + } + } +} diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementRenderers.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementRenderers.java new file mode 100644 index 000000000..bd92401e9 --- /dev/null +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementRenderers.java @@ -0,0 +1,230 @@ +/* + * Copyright 2026-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.relational.core.sql.render; + +import java.util.List; +import java.util.stream.Collectors; + +import org.springframework.data.relational.core.sql.Column; +import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.data.relational.core.sql.Table; +import org.springframework.util.Assert; + +/** + * Concrete {@link UpsertStatementRenderer} implementations. + * + * @author Christoph Strobl + * @since 4.x + */ +final class UpsertStatementRenderers { + + /** Target table alias in {@code MERGE} statements. */ + static final SqlIdentifier MERGE_TARGET_TABLE_ALIAS = SqlIdentifier.quoted("_t"); + + /** Source (values) alias in {@code MERGE} statements. */ + static final SqlIdentifier MERGE_SOURCE_TABLE_ALIAS = SqlIdentifier.quoted("_s"); + + private UpsertStatementRenderers() {} + + /** + * Standard SQL {@code MERGE} using a table value constructor {@code (VALUES (?, ?)) AS s (col1, col2)} (H2, HSQLDB, + * DB2, etc.). + */ + static class StandardSql implements UpsertStatementRenderer { + + static final StandardSql INSTANCE = new StandardSql(); + + @Override + public String render(Table table, Columns columns, UpsertRenderingContext ctx) { + + Assert.notEmpty(columns.insertColumns(), "Insert columns must not be empty"); + Assert.notEmpty(columns.conflictColumns(), "Conflict columns must not be empty"); + + CharSequence tableName = ctx.tableName(table); + CharSequence insertColumnNames = ctx.columnNames(columns.insertColumns(), Collectors.joining(", ")); + CharSequence bindMarkers = ctx.bindMarkers(columns.insertColumns(), Collectors.joining(", ")); + CharSequence onCondition = ctx.assignments(MERGE_TARGET_TABLE_ALIAS, columns.conflictColumns(), + MERGE_SOURCE_TABLE_ALIAS, Collectors.joining(" AND ")); + CharSequence insertValuesSql = ctx.columnNames(MERGE_SOURCE_TABLE_ALIAS, columns.insertColumns(), + Collectors.joining(", ")); + + String insertClause = "WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s)".formatted(insertColumnNames, + insertValuesSql); + + List updateColumns = columns.updateColumns(); + if (updateColumns.isEmpty()) { + return "MERGE INTO %s %s USING (VALUES (%s)) AS %s (%s) ON %s %s".formatted( // + tableName, // + MERGE_TARGET_TABLE_ALIAS, // + bindMarkers, // + MERGE_SOURCE_TABLE_ALIAS, // + insertColumnNames, // + onCondition, // + insertClause); + } + + CharSequence updateSetClause = ctx.assignments(MERGE_TARGET_TABLE_ALIAS, columns.updateColumns(), + MERGE_SOURCE_TABLE_ALIAS, Collectors.joining(", ")); + + return "MERGE INTO %s %s USING (VALUES (%s)) AS %s (%s) ON %s WHEN MATCHED THEN UPDATE SET %s %s".formatted( // + tableName, // + MERGE_TARGET_TABLE_ALIAS, // + bindMarkers, // + MERGE_SOURCE_TABLE_ALIAS, // + insertColumnNames, // + onCondition, // + updateSetClause, // + insertClause); + } + } + + /** PostgreSQL {@code INSERT ... ON CONFLICT ... DO UPDATE SET} / {@code DO NOTHING}. */ + static class Postgres implements UpsertStatementRenderer { + + static final Postgres INSTANCE = new Postgres(); + + @Override + public String render(Table table, Columns columns, UpsertRenderingContext ctx) { + + Assert.notEmpty(columns.insertColumns(), "Insert columns must not be empty"); + Assert.notEmpty(columns.conflictColumns(), "Conflict columns must not be empty"); + + CharSequence tableName = ctx.tableName(table); + CharSequence insertColumnNames = ctx.columnNames(columns.insertColumns(), Collectors.joining(", ")); + CharSequence conflictColumnNames = ctx.columnNames(columns.conflictColumns(), Collectors.joining(", ")); + CharSequence bindMarkers = ctx.bindMarkers(columns.insertColumns(), Collectors.joining(", ")); + + List updateColumns = columns.updateColumns(); + + if (updateColumns.isEmpty()) { + updateColumns = columns.conflictColumns(); + } + + CharSequence setValues = ctx.assignments(SqlIdentifier.EMPTY, updateColumns, SqlIdentifier.EMPTY, + "EXCLUDED.%s"::formatted, Collectors.joining(", ")); + + if (columns.updateColumns().isEmpty()) { + return "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO NOTHING".formatted(// + tableName, // + insertColumnNames, // + bindMarkers, // + conflictColumnNames); + } + + return "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO UPDATE SET %s".formatted(// + tableName, // + insertColumnNames, // + bindMarkers, // + conflictColumnNames, // + setValues); + + } + } + + /** MySQL / MariaDB {@code INSERT ... ON DUPLICATE KEY UPDATE}. */ + static class MySql implements UpsertStatementRenderer { + + static final MySql INSTANCE = new MySql(); + + @Override + public String render(Table table, Columns columns, UpsertRenderingContext ctx) { + + Assert.notEmpty(columns.insertColumns(), "Insert columns must not be empty"); + Assert.notEmpty(columns.conflictColumns(), "Conflict columns must not be empty"); + + CharSequence tableName = ctx.tableName(table); + CharSequence columnNames = ctx.columnNames(columns.insertColumns(), Collectors.joining(", ")); + CharSequence bindMarkers = ctx.bindMarkers(columns.insertColumns(), Collectors.joining(", ")); + + List updateColumns = columns.updateColumns(); + + if (updateColumns.isEmpty()) { + updateColumns = columns.conflictColumns(); + } + + CharSequence setValues = ctx.assignments(SqlIdentifier.EMPTY, updateColumns, SqlIdentifier.EMPTY, + "VALUES(%s)"::formatted, Collectors.joining(", ")); + + return "INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE %s".formatted( // + tableName, // + columnNames, // + bindMarkers, // + setValues); + } + } + + /** Oracle {@code MERGE} with {@code SELECT ... FROM DUAL} as source. */ + static class Oracle implements UpsertStatementRenderer { + + static final Oracle INSTANCE = new Oracle(); + + @Override + public String render(Table table, Columns columns, UpsertRenderingContext ctx) { + + Assert.notEmpty(columns.insertColumns(), "Insert columns must not be empty"); + Assert.notEmpty(columns.conflictColumns(), "Conflict columns must not be empty"); + + CharSequence tableName = ctx.tableName(table); + CharSequence insertColumnNames = ctx.columnNames(columns.insertColumns(), Collectors.joining(", ")); + CharSequence sourceSelectList = ctx.bindMarkers(columns.insertColumns(), + (columnName, bindMarker) -> "%s AS %s".formatted(bindMarker, columnName), Collectors.joining(", ")); + CharSequence onCondition = ctx.assignments(MERGE_TARGET_TABLE_ALIAS, columns.conflictColumns(), + MERGE_SOURCE_TABLE_ALIAS, Collectors.joining(" AND ")); + CharSequence insertValuesSql = ctx.columnNames(MERGE_SOURCE_TABLE_ALIAS, columns.insertColumns(), + Collectors.joining(", ")); + + String insertClause = "WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s)".formatted(insertColumnNames, + insertValuesSql); + + List updateColumns = columns.updateColumns(); + if (updateColumns.isEmpty()) { + return "MERGE INTO %s %s USING (SELECT %s FROM DUAL) %s ON (%s) %s".formatted( // + tableName, // + MERGE_TARGET_TABLE_ALIAS, // + sourceSelectList, // + MERGE_SOURCE_TABLE_ALIAS, // + onCondition, // + insertClause); + } + + CharSequence updateSetClause = ctx.assignments(MERGE_TARGET_TABLE_ALIAS, columns.updateColumns(), + MERGE_SOURCE_TABLE_ALIAS, Collectors.joining(", ")); + + return "MERGE INTO %s %s USING (SELECT %s FROM DUAL) %s ON (%s) WHEN MATCHED THEN UPDATE SET %s %s".formatted( // + tableName, // + MERGE_TARGET_TABLE_ALIAS, // + sourceSelectList, // + MERGE_SOURCE_TABLE_ALIAS, // + onCondition, // + updateSetClause, // + insertClause); + } + } + + /** + * SQL Server {@code MERGE}: same body as {@link StandardSql} with a trailing semicolon (batch separator). + */ + static class SqlServer extends StandardSql { + + private static final String STATEMENT_TERMINATOR = ";"; + static final SqlServer INSTANCE = new SqlServer(); + + @Override + public String render(Table table, Columns columns, UpsertRenderingContext ctx) { + return super.render(table, columns, ctx) + STATEMENT_TERMINATOR; + } + } +} diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementVisitor.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementVisitor.java index 29ce551af..4538f4571 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementVisitor.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/UpsertStatementVisitor.java @@ -17,30 +17,30 @@ package org.springframework.data.relational.core.sql.render; import java.util.ArrayList; import java.util.List; -import java.util.function.Function; import org.jspecify.annotations.Nullable; import org.springframework.data.relational.core.sql.AssignValue; +import org.springframework.data.relational.core.sql.Column; import org.springframework.data.relational.core.sql.Condition; -import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.data.relational.core.sql.Table; import org.springframework.data.relational.core.sql.Visitable; -import org.springframework.data.relational.core.sql.render.UpsertRenderContext.Columns; +import org.springframework.data.relational.core.sql.render.UpsertStatementRenderer.UpsertRenderingContext; import org.springframework.util.Assert; /** - * {@link PartRenderer} for {@link org.springframework.data.relational.core.sql.Upsert} statements. - * Traverses the Upsert AST (table, where/conflict condition, assignments), collects insert and conflict columns, - * and delegates dialect-specific rendering to {@link UpsertRenderContext}. + * {@link PartRenderer} for {@link org.springframework.data.relational.core.sql.Upsert} statements. Traverses the Upsert + * AST (table, where/conflict condition, assignments), collects insert and conflict columns, and delegates + * dialect-specific rendering via {@link UpsertRenderContext#renderer()}. * + * @author Christoph Strobl * @since 4.x */ public class UpsertStatementVisitor extends DelegatingVisitor implements PartRenderer { private final StringBuilder builder = new StringBuilder(); private final RenderContext context; - private final List insertColumns = new ArrayList<>(); - private final List conflictColumns = new ArrayList<>(); + private final List insertColumns = new ArrayList<>(); + private final List conflictColumns = new ArrayList<>(5); private @Nullable Table table; @@ -66,7 +66,7 @@ public class UpsertStatementVisitor extends DelegatingVisitor implements PartRen } if (segment instanceof AssignValue assignValue) { - this.insertColumns.add(assignValue.getColumn().getName()); + this.insertColumns.add(assignValue.getColumn()); return Delegation.retain(); } @@ -78,21 +78,12 @@ public class UpsertStatementVisitor extends DelegatingVisitor implements PartRen if (segment instanceof org.springframework.data.relational.core.sql.Upsert) { + Assert.state(table != null, "Upsert requires a table"); UpsertRenderContext upsertContext = context.getUpsertRenderContext(); - if (upsertContext == null) { - throw new UnsupportedOperationException( - "Upsert is not supported by the current render context; no UpsertRenderContext available."); - } - if (table == null) { - throw new IllegalStateException("Upsert statement has no table."); - } - Function bindMarkerFn = cn -> ":" - + sanitizeBindMarkerName(cn.getReference()); - - - String sql = upsertContext.renderUpsert(table, new Columns(new ArrayList<>(insertColumns), - new ArrayList<>(conflictColumns), context.getIdentifierProcessing()), bindMarkerFn); + UpsertRenderingContext renderingContext = UpsertRenderingContext.of(context); + String sql = upsertContext.renderer().render(table, + new UpsertStatementRenderer.Columns(insertColumns, conflictColumns), renderingContext); builder.append(sql); return Delegation.leave(); @@ -105,16 +96,4 @@ public class UpsertStatementVisitor extends DelegatingVisitor implements PartRen public CharSequence getRenderedPart() { return builder; } - - private static String sanitizeBindMarkerName(String rawName) { - - StringBuilder sb = new StringBuilder(rawName.length()); - for (int i = 0; i < rawName.length(); i++) { - char c = rawName.charAt(i); - if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') { - sb.append(c); - } - } - return sb.length() > 0 ? sb.toString() : rawName; - } } diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/DependencyTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/DependencyTests.java index f608c8837..08ff167d8 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/DependencyTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/DependencyTests.java @@ -41,6 +41,7 @@ import com.tngtech.archunit.library.dependencies.SlicesRuleDefinition; * * @author Jens Schauder * @author Mark Paluch + * @author Christoph Strobl */ public class DependencyTests { diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/StandardSqlUpsertRenderContextUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/StandardSqlUpsertRenderContextUnitTests.java index a23884f1e..2254050ef 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/StandardSqlUpsertRenderContextUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/StandardSqlUpsertRenderContextUnitTests.java @@ -18,13 +18,14 @@ package org.springframework.data.relational.core.sql.render; import static org.assertj.core.api.Assertions.assertThat; import java.util.List; -import java.util.function.Function; import org.junit.jupiter.api.Test; -import org.springframework.data.relational.core.sql.IdentifierProcessing; +import org.springframework.data.relational.core.dialect.AnsiDialect; +import org.springframework.data.relational.core.dialect.RenderContextFactory; +import org.springframework.data.relational.core.sql.Column; import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.data.relational.core.sql.Table; -import org.springframework.data.relational.core.sql.render.UpsertRenderContext.Columns; +import org.springframework.data.relational.core.sql.render.UpsertStatementRenderer.UpsertRenderingContext; /** * Unit tests for {@link StandardSqlUpsertRenderContext}. @@ -32,11 +33,6 @@ import org.springframework.data.relational.core.sql.render.UpsertRenderContext.C class StandardSqlUpsertRenderContextUnitTests { private static final Table TABLE = Table.create(SqlIdentifier.unquoted("my_table")); - private static final List INSERT_COLUMNS = List.of(SqlIdentifier.unquoted("id"), - SqlIdentifier.unquoted("name")); - private static final List CONFLICT_COLUMNS = List.of(SqlIdentifier.unquoted("id")); - private static final Function BIND_MARKER = id -> ":" + id.getReference(); - private static final IdentifierProcessing IDENTIFIER_PROCESSING = IdentifierProcessing.ANSI; @Test // GH-493 void mergeUpsertWithMultipleConflictColumnsBuildsFilterClauseWithAllColumns() { @@ -44,9 +40,13 @@ class StandardSqlUpsertRenderContextUnitTests { List insertColumns = List.of(SqlIdentifier.unquoted("tenant_id"), SqlIdentifier.unquoted("id"), SqlIdentifier.unquoted("name")); List conflictColumns = List.of(SqlIdentifier.unquoted("tenant_id"), SqlIdentifier.unquoted("id")); - Columns columns = new Columns(insertColumns, conflictColumns, IDENTIFIER_PROCESSING); - String sql = StandardSqlUpsertRenderContext.INSTANCE.renderUpsert(TABLE, columns, BIND_MARKER); + UpsertRenderingContext ctx = UpsertRenderingContext.of(new RenderContextFactory(AnsiDialect.INSTANCE).createRenderContext()); + List insertCols = insertColumns.stream().map(id -> Column.create(id, TABLE)).toList(); + List conflictCols = conflictColumns.stream().map(id -> Column.create(id, TABLE)).toList(); + + String sql = StandardSqlUpsertRenderContext.INSTANCE.renderer().render(TABLE, + new UpsertStatementRenderer.Columns(insertCols, conflictCols), ctx); assertThat(sql).contains("ON \"_t\".tenant_id = \"_s\".tenant_id AND \"_t\".id = \"_s\".id"); assertThat(sql).contains("WHEN MATCHED THEN UPDATE SET \"_t\".name = \"_s\".name"); diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/UpsertRenderContextUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/UpsertRenderContextUnitTests.java index b8012539b..3a0627523 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/UpsertRenderContextUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/UpsertRenderContextUnitTests.java @@ -18,16 +18,22 @@ package org.springframework.data.relational.core.sql.render; import static org.assertj.core.api.Assertions.assertThat; import java.util.List; -import java.util.function.Function; import org.junit.jupiter.api.Test; -import org.springframework.data.relational.core.sql.IdentifierProcessing; +import org.springframework.data.relational.core.dialect.AnsiDialect; +import org.springframework.data.relational.core.dialect.Dialect; +import org.springframework.data.relational.core.dialect.MySqlDialect; +import org.springframework.data.relational.core.dialect.OracleDialect; +import org.springframework.data.relational.core.dialect.PostgresDialect; +import org.springframework.data.relational.core.dialect.RenderContextFactory; +import org.springframework.data.relational.core.dialect.SqlServerDialect; +import org.springframework.data.relational.core.sql.Column; import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.data.relational.core.sql.Table; -import org.springframework.data.relational.core.sql.render.UpsertRenderContext.Columns; +import org.springframework.data.relational.core.sql.render.UpsertStatementRenderer.UpsertRenderingContext; /** - * Unit tests for {@link UpsertRenderContext} implementations. + * Unit tests for {@link UpsertRenderContext} implementations via {@link UpsertStatementRenderer#render}. */ class UpsertRenderContextUnitTests { @@ -35,14 +41,21 @@ class UpsertRenderContextUnitTests { private static final List INSERT_COLUMNS = List.of(SqlIdentifier.unquoted("id"), SqlIdentifier.unquoted("name")); private static final List CONFLICT_COLUMNS = List.of(SqlIdentifier.unquoted("id")); - private static final Function BIND_MARKER = id -> ":" + id.getReference(); - private static final IdentifierProcessing IDENTIFIER_PROCESSING = IdentifierProcessing.ANSI; + + private static String render(UpsertRenderContext upsertContext, Dialect dialect, Table table, + List insertColumns, List conflictColumns) { + + UpsertRenderingContext ctx = UpsertRenderingContext.of(new RenderContextFactory(dialect).createRenderContext()); + List insertCols = insertColumns.stream().map(id -> Column.create(id, table)).toList(); + List conflictCols = conflictColumns.stream().map(id -> Column.create(id, table)).toList(); + return upsertContext.renderer().render(table, new UpsertStatementRenderer.Columns(insertCols, conflictCols), ctx); + } @Test // GH-493 void standardUpsertRendersMergeInto() { - String sql = StandardSqlUpsertRenderContext.INSTANCE.renderUpsert(TABLE, - new Columns(INSERT_COLUMNS, CONFLICT_COLUMNS, IDENTIFIER_PROCESSING), BIND_MARKER); + String sql = render(StandardSqlUpsertRenderContext.INSTANCE, AnsiDialect.INSTANCE, TABLE, INSERT_COLUMNS, + CONFLICT_COLUMNS); assertThat(sql).isEqualTo( "MERGE INTO my_table \"_t\" USING (VALUES (:id, :name)) AS \"_s\" (id, name) ON \"_t\".id = \"_s\".id WHEN MATCHED THEN UPDATE SET \"_t\".name = \"_s\".name WHEN NOT MATCHED THEN INSERT (id, name) VALUES (\"_s\".id, \"_s\".name)"); @@ -54,9 +67,9 @@ class UpsertRenderContextUnitTests { List insertColumns = List.of(SqlIdentifier.unquoted("tenant_id"), SqlIdentifier.unquoted("id"), SqlIdentifier.unquoted("name")); List conflictColumns = List.of(SqlIdentifier.unquoted("tenant_id"), SqlIdentifier.unquoted("id")); - Columns columns = new Columns(insertColumns, conflictColumns, IDENTIFIER_PROCESSING); - String sql = StandardSqlUpsertRenderContext.INSTANCE.renderUpsert(TABLE, columns, BIND_MARKER); + String sql = render(StandardSqlUpsertRenderContext.INSTANCE, AnsiDialect.INSTANCE, TABLE, insertColumns, + conflictColumns); assertThat(sql).isEqualToIgnoringWhitespace( "MERGE INTO my_table \"_t\" USING (VALUES (:tenant_id, :id, :name)) AS \"_s\" (tenant_id, id, name) " @@ -68,8 +81,8 @@ class UpsertRenderContextUnitTests { @Test // GH-493 void postgresUpsertRendersInsertOnConflictDoUpdate() { - String sql = PostgresUpsertRenderContext.INSTANCE.renderUpsert(TABLE, - new Columns(INSERT_COLUMNS, CONFLICT_COLUMNS, IDENTIFIER_PROCESSING), BIND_MARKER); + String sql = render(PostgresUpsertRenderContext.INSTANCE, PostgresDialect.INSTANCE, TABLE, INSERT_COLUMNS, + CONFLICT_COLUMNS); assertThat(sql).isEqualToIgnoringWhitespace( "INSERT INTO my_table (id, name) VALUES (:id, :name) ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name"); @@ -78,29 +91,28 @@ class UpsertRenderContextUnitTests { @Test // GH-493 void postgresUpsertRendersInsertOnConflictDoNothing() { - String sql = PostgresUpsertRenderContext.INSTANCE.renderUpsert(TABLE, - new Columns(INSERT_COLUMNS, INSERT_COLUMNS, IDENTIFIER_PROCESSING), BIND_MARKER); + String sql = render(PostgresUpsertRenderContext.INSTANCE, PostgresDialect.INSTANCE, TABLE, INSERT_COLUMNS, + INSERT_COLUMNS); assertThat(sql).isEqualToIgnoringWhitespace( - "INSERT INTO my_table (id, name) VALUES (:id, :name) ON CONFLICT (id, name) DO NOTHING"); + "INSERT INTO my_table (id, name) VALUES (:id, :name) ON CONFLICT (id, name) DO NOTHING"); } @Test // GH-493 void mySqlUpsertRendersOnDuplicateKeyUpdate() { - String sql = MySqlUpsertRenderContext.INSTANCE.renderUpsert(TABLE, - new Columns(INSERT_COLUMNS, CONFLICT_COLUMNS, IDENTIFIER_PROCESSING), BIND_MARKER); + String sql = render(MySqlUpsertRenderContext.INSTANCE, MySqlDialect.INSTANCE, TABLE, INSERT_COLUMNS, + CONFLICT_COLUMNS); assertThat(sql).isEqualToIgnoringWhitespace( "INSERT INTO my_table (id, name) VALUES (:id, :name) ON DUPLICATE KEY UPDATE name = VALUES(name)"); } @Test // GH-493 - // TODO: should we have all values in the update or just a single one in this case. void mySqlUpsertRendersCorrectlyWhenUpdateCoversEntireKey() { - String sql = MySqlUpsertRenderContext.INSTANCE.renderUpsert(TABLE, - new Columns(INSERT_COLUMNS, INSERT_COLUMNS, IDENTIFIER_PROCESSING), BIND_MARKER); + String sql = render(MySqlUpsertRenderContext.INSTANCE, MySqlDialect.INSTANCE, TABLE, INSERT_COLUMNS, + INSERT_COLUMNS); assertThat(sql).isEqualToIgnoringWhitespace( "INSERT INTO my_table (id, name) VALUES (:id, :name) ON DUPLICATE KEY UPDATE id = VALUES(id), name = VALUES(name)"); @@ -109,8 +121,8 @@ class UpsertRenderContextUnitTests { @Test // GH-493 void oracleMergeUpsertRendersOnConditionInParentheses() { - String sql = OracleUpsertRenderContext.INSTANCE.renderUpsert(TABLE, - new Columns(INSERT_COLUMNS, CONFLICT_COLUMNS, IDENTIFIER_PROCESSING), BIND_MARKER); + String sql = render(OracleUpsertRenderContext.INSTANCE, OracleDialect.INSTANCE, TABLE, INSERT_COLUMNS, + CONFLICT_COLUMNS); assertThat(sql).isEqualToIgnoringWhitespace( "MERGE INTO my_table \"_t\" USING (SELECT :id AS id, :name AS name FROM DUAL) \"_s\" ON (\"_t\".id = \"_s\".id) WHEN MATCHED THEN UPDATE SET \"_t\".name = \"_s\".name WHEN NOT MATCHED THEN INSERT (id, name) VALUES (\"_s\".id, \"_s\".name)"); @@ -120,8 +132,7 @@ class UpsertRenderContextUnitTests { void standardMergeIdOnlyOmitsWhenMatchedUpdate() { List idOnly = List.of(SqlIdentifier.unquoted("id")); - String sql = StandardSqlUpsertRenderContext.INSTANCE.renderUpsert(TABLE, new Columns(idOnly, idOnly, - IDENTIFIER_PROCESSING), BIND_MARKER); + String sql = render(StandardSqlUpsertRenderContext.INSTANCE, AnsiDialect.INSTANCE, TABLE, idOnly, idOnly); assertThat(sql).isEqualTo( "MERGE INTO my_table \"_t\" USING (VALUES (:id)) AS \"_s\" (id) ON \"_t\".id = \"_s\".id WHEN NOT MATCHED THEN INSERT (id) VALUES (\"_s\".id)"); @@ -131,8 +142,7 @@ class UpsertRenderContextUnitTests { void oracleMergeIdOnlyOmitsWhenMatchedUpdate() { List idOnly = List.of(SqlIdentifier.unquoted("id")); - String sql = OracleUpsertRenderContext.INSTANCE.renderUpsert(TABLE, new Columns(idOnly, idOnly, - IDENTIFIER_PROCESSING), BIND_MARKER); + String sql = render(OracleUpsertRenderContext.INSTANCE, OracleDialect.INSTANCE, TABLE, idOnly, idOnly); assertThat(sql).isEqualToIgnoringWhitespace( "MERGE INTO my_table \"_t\" USING (SELECT :id AS id FROM DUAL) \"_s\" ON (\"_t\".id = \"_s\".id) WHEN NOT MATCHED THEN INSERT (id) VALUES (\"_s\".id)"); @@ -141,8 +151,8 @@ class UpsertRenderContextUnitTests { @Test // GH-493 void sqlServerUpsertRendersMergeWithSemicolon() { - String sql = SqlServerUpsertRenderContext.INSTANCE.renderUpsert(TABLE, - new Columns(INSERT_COLUMNS, CONFLICT_COLUMNS, IDENTIFIER_PROCESSING), BIND_MARKER); + String sql = render(SqlServerUpsertRenderContext.INSTANCE, SqlServerDialect.INSTANCE, TABLE, INSERT_COLUMNS, + CONFLICT_COLUMNS); assertThat(sql).contains("MERGE INTO"); assertThat(sql).contains("my_table");