diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/CaseExpression.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/CaseExpression.java index 81c423220..78fcfb46f 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/CaseExpression.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/CaseExpression.java @@ -1,9 +1,11 @@ package org.springframework.data.relational.core.sql; +import org.springframework.lang.Nullable; + import java.util.ArrayList; import java.util.List; -import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.*; /** * Case with one or more conditions expression. @@ -22,73 +24,64 @@ import static java.util.stream.Collectors.joining; * @since 3.4 */ public class CaseExpression extends AbstractSegment implements Expression { - private final List whenList; - private final Expression elseExpression; - - private CaseExpression(List whenList, Expression elseExpression) { - - super(children(whenList, elseExpression)); - this.whenList = whenList; - this.elseExpression = elseExpression; - } - - /** - * Create CASE {@link Expression} with initial {@link When} condition. - * @param condition initial {@link When} condition - * @return the {@link CaseExpression} - */ - public static CaseExpression create(When condition) { - return new CaseExpression(List.of(condition), null); - } - - /** - * Add additional {@link When} condition - * @param condition the {@link When} condition - * @return the {@link CaseExpression} - */ - public CaseExpression when(When condition) { - List conditions = new ArrayList<>(this.whenList); - conditions.add(condition); - return new CaseExpression(conditions, elseExpression); - } - - /** - * Add ELSE clause - * @param elseExpression the {@link Expression} else value - * @return the {@link CaseExpression} - */ - public CaseExpression elseExpression(Literal elseExpression) { - return new CaseExpression(whenList, elseExpression); - } - - /** - * @return the {@link When} conditions - */ - public List getWhenList() { - return whenList; - } - - /** - * @return the ELSE {@link Literal} value - */ - public Expression getElseExpression() { - return elseExpression; - } - - @Override - public String toString() { - return "CASE " + whenList.stream().map(When::toString).collect(joining(" ")) + (elseExpression != null ? " ELSE " + elseExpression : "") + " END"; - } - - private static Segment[] children(List whenList, Expression elseExpression) { - - List segments = new ArrayList<>(); - segments.addAll(whenList); - - if (elseExpression != null) { - segments.add(elseExpression); - } - - return segments.toArray(new Segment[segments.size()]); - } + + private final List whenList; + @Nullable + private final Expression elseExpression; + + private static Segment[] children(List whenList, @Nullable Expression elseExpression) { + + List segments = new ArrayList<>(whenList); + + if (elseExpression != null) { + segments.add(elseExpression); + } + + return segments.toArray(new Segment[0]); + } + + private CaseExpression(List whenList, @Nullable Expression elseExpression) { + + super(children(whenList, elseExpression)); + + this.whenList = whenList; + this.elseExpression = elseExpression; + } + + /** + * Create CASE {@link Expression} with initial {@link When} condition. + * + * @param condition initial {@link When} condition + * @return the {@link CaseExpression} + */ + public static CaseExpression create(When condition) { + return new CaseExpression(List.of(condition), null); + } + + /** + * Add additional {@link When} condition + * + * @param condition the {@link When} condition + * @return the {@link CaseExpression} + */ + public CaseExpression when(When condition) { + List conditions = new ArrayList<>(this.whenList); + conditions.add(condition); + return new CaseExpression(conditions, elseExpression); + } + + /** + * Add ELSE clause + * + * @param elseExpression the {@link Expression} else value + * @return the {@link CaseExpression} + */ + public CaseExpression elseExpression(Expression elseExpression) { + return new CaseExpression(whenList, elseExpression); + } + + @Override + public String toString() { + return "CASE " + whenList.stream().map(When::toString).collect(joining(" ")) + (elseExpression != null ? " ELSE " + elseExpression : "") + " END"; + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/When.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/When.java index e90338a92..43ea34316 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/When.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/When.java @@ -10,6 +10,7 @@ package org.springframework.data.relational.core.sql; * @since 3.4 */ public class When extends AbstractSegment { + private final Condition condition; private final Expression value; diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/ExpressionVisitor.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/ExpressionVisitor.java index c3bdcebcd..65843cd34 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/ExpressionVisitor.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/ExpressionVisitor.java @@ -24,7 +24,7 @@ import org.springframework.util.Assert; * * @author Mark Paluch * @author Jens Schauder - * @since 1.1 + * @author Sven Rienstra * @see Column * @see SubselectExpression */ @@ -48,7 +48,7 @@ class ExpressionVisitor extends TypedSubtreeVisitor implements PartR /** * Creates an {@code ExpressionVisitor}. * - * @param context must not be {@literal null}. + * @param context must not be {@literal null}. * @param aliasHandling controls if columns should be rendered as their alias or using their table names. * @since 2.3 */ @@ -109,6 +109,7 @@ class ExpressionVisitor extends TypedSubtreeVisitor implements PartR partRenderer = visitor; return Delegation.delegateTo(visitor); } else if (segment instanceof CaseExpression) { + CaseExpressionVisitor visitor = new CaseExpressionVisitor(context); partRenderer = visitor; return Delegation.delegateTo(visitor); @@ -132,7 +133,7 @@ class ExpressionVisitor extends TypedSubtreeVisitor implements PartR if (segment instanceof InlineQuery) { - NoopVisitor partRenderer = new NoopVisitor(InlineQuery.class); + NoopVisitor partRenderer = new NoopVisitor<>(InlineQuery.class); return Delegation.delegateTo(partRenderer); } return super.enterNested(segment); diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/OrderByClauseVisitor.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/OrderByClauseVisitor.java index b536b8998..0ac551dc9 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/OrderByClauseVisitor.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/OrderByClauseVisitor.java @@ -16,8 +16,8 @@ package org.springframework.data.relational.core.sql.render; -import org.springframework.data.relational.core.sql.Column; import org.springframework.data.relational.core.sql.CaseExpression; +import org.springframework.data.relational.core.sql.Column; import org.springframework.data.relational.core.sql.Expressions; import org.springframework.data.relational.core.sql.OrderByField; import org.springframework.data.relational.core.sql.SimpleFunction; @@ -31,6 +31,7 @@ import org.springframework.lang.Nullable; * @author Jens Schauder * @author Chirag Tailor * @author Koen Punt + * @author Sven Rienstra * @since 1.1 */ class OrderByClauseVisitor extends TypedSubtreeVisitor implements PartRenderer { @@ -39,7 +40,8 @@ class OrderByClauseVisitor extends TypedSubtreeVisitor implements private final StringBuilder builder = new StringBuilder(); - @Nullable private PartRenderer delegate; + @Nullable + private PartRenderer delegate; private boolean first = true; @@ -69,7 +71,7 @@ class OrderByClauseVisitor extends TypedSubtreeVisitor implements String nullPrecedence = context.getSelectRenderContext().evaluateOrderByNullHandling(segment.getNullHandling()); if (!nullPrecedence.isEmpty()) { - + builder.append(" ") // .append(nullPrecedence); } @@ -82,12 +84,12 @@ class OrderByClauseVisitor extends TypedSubtreeVisitor implements if (segment instanceof SimpleFunction) { delegate = new SimpleFunctionVisitor(context); - return Delegation.delegateTo((SimpleFunctionVisitor)delegate); + return Delegation.delegateTo((SimpleFunctionVisitor) delegate); } if (segment instanceof Expressions.SimpleExpression || segment instanceof CaseExpression) { delegate = new ExpressionVisitor(context); - return Delegation.delegateTo((ExpressionVisitor)delegate); + return Delegation.delegateTo((ExpressionVisitor) delegate); } return super.enterNested(segment); diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/WhenVisitor.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/WhenVisitor.java index 70e8c1da5..ed872d805 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/WhenVisitor.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/WhenVisitor.java @@ -10,6 +10,7 @@ import org.springframework.data.relational.core.sql.When; * @since 3.4 */ public class WhenVisitor extends TypedSingleConditionRenderSupport implements PartRenderer { + private final StringBuilder part = new StringBuilder(); private boolean conditionRendered; diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/OrderByClauseVisitorUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/OrderByClauseVisitorUnitTests.java index baadc8b01..af8981fd5 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/OrderByClauseVisitorUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/OrderByClauseVisitorUnitTests.java @@ -29,6 +29,7 @@ import java.util.List; * @author Mark Paluch * @author Jens Schauder * @author Koen Punt + * @author Sven Rienstra */ class OrderByClauseVisitorUnitTests { @@ -125,15 +126,16 @@ class OrderByClauseVisitorUnitTests { @Test void shouldRenderOrderByCase() { + Table employee = SQL.table("employee").as("emp"); Column column = employee.column("name"); - CaseExpression caseExpression = CaseExpression.create(When.when(column.isNull(), SQL.literalOf(1))).elseExpression(SQL.literalOf(2)); + CaseExpression caseExpression = CaseExpression.create(When.when(column.isNull(), SQL.literalOf(1))).elseExpression(SQL.literalOf(column)); Select select = Select.builder().select(column).from(employee).orderBy(OrderByField.from(caseExpression).asc()).build(); OrderByClauseVisitor visitor = new OrderByClauseVisitor(new SimpleRenderContext(NamingStrategies.asIs())); select.visit(visitor); - assertThat(visitor.getRenderedPart().toString()).isEqualTo("CASE WHEN emp.name IS NULL THEN 1 ELSE 2 END ASC"); + assertThat(visitor.getRenderedPart().toString()).isEqualTo("CASE WHEN emp.name IS NULL THEN 1 ELSE emp.name END ASC"); } } diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/SelectRendererUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/SelectRendererUnitTests.java index 1251aefdd..4bc4d0c79 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/SelectRendererUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sql/render/SelectRendererUnitTests.java @@ -31,6 +31,7 @@ import java.util.List; * * @author Mark Paluch * @author Jens Schauder + * @author Sven Rienstra */ class SelectRendererUnitTests { @@ -690,11 +691,12 @@ class SelectRendererUnitTests { @Test void rendersCaseExpression() { + Table table = SQL.table("table"); Column column = table.column("name"); CaseExpression caseExpression = CaseExpression.create(When.when(column.isNull(), SQL.literalOf(1))) // - .when(When.when(column.isNotNull(), SQL.literalOf(2))) // + .when(When.when(column.isNotNull(), column)) // .elseExpression(SQL.literalOf(3)); Select select = StatementBuilder.select(caseExpression) // @@ -702,7 +704,7 @@ class SelectRendererUnitTests { .build(); String rendered = SqlRenderer.toString(select); - assertThat(rendered).isEqualTo("SELECT CASE WHEN table.name IS NULL THEN 1 WHEN table.name IS NOT NULL THEN 2 ELSE 3 END FROM table"); + assertThat(rendered).isEqualTo("SELECT CASE WHEN table.name IS NULL THEN 1 WHEN table.name IS NOT NULL THEN table.name ELSE 3 END FROM table"); } /**