From e55f3ef09e321446f70a45401b9a00067e87538f Mon Sep 17 00:00:00 2001 From: bog-walk <82039410+bog-walk@users.noreply.github.com> Date: Thu, 27 Jul 2023 04:30:15 -0400 Subject: [PATCH] fix: EXPOSED-54 CaseWhen.Else returns narrow Expression (#1800) * fix: EXPOSED-54 CaseWhen.Else returns narrow Expression The class CaseWhen has a function Else() with a return type Expression, which creates an instance of the CaseWhenElse class. According to commit history, the CaseWhenElse class uses to also extend Expression, but this was changed after a few years to extend ExpressionWithColumnType instead. Inspite of this change, the return type of Else() did not change. This means that the same case written as a more verbose CaseWhenElse instance or written as a CaseWhen instance invoking Else() cannot be used interchangeably, for example as the argument in Coalesce(), which takes ExpressionWithColumnType<*>. Update Else() return type to be as broad as the instance it returns. Remove extra whitespace in Case query builder. Add unit test to ensure consistency. Fix Detekt issues. --- exposed-core/api/exposed-core.api | 2 +- .../org/jetbrains/exposed/sql/Function.kt | 52 +++++++++++++------ .../sql/tests/shared/dml/ConditionsTests.kt | 45 ++++++++++++++-- 3 files changed, 76 insertions(+), 23 deletions(-) diff --git a/exposed-core/api/exposed-core.api b/exposed-core/api/exposed-core.api index da691569b1..5cc1b8fc7d 100644 --- a/exposed-core/api/exposed-core.api +++ b/exposed-core/api/exposed-core.api @@ -259,7 +259,7 @@ public final class org/jetbrains/exposed/sql/Case { public final class org/jetbrains/exposed/sql/CaseWhen { public fun (Lorg/jetbrains/exposed/sql/Expression;)V - public final fun Else (Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/Expression; + public final fun Else (Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/ExpressionWithColumnType; public final fun When (Lorg/jetbrains/exposed/sql/Expression;Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/CaseWhen; public final fun getCases ()Ljava/util/List; public final fun getValue ()Lorg/jetbrains/exposed/sql/Expression; diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt index d195852e26..d013a7942f 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt @@ -73,7 +73,9 @@ class Random( class CharLength( val expr: Expression ) : Function(IntegerColumnType()) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.charLength(expr, queryBuilder) + override fun toQueryBuilder(queryBuilder: QueryBuilder) { + currentDialect.functionProvider.charLength(expr, queryBuilder) + } } /** @@ -105,7 +107,9 @@ class Concat( /** Returns the expressions being concatenated. */ vararg val expr: Expression<*> ) : Function(TextColumnType()) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.concat(separator, queryBuilder, expr = expr) + override fun toQueryBuilder(queryBuilder: QueryBuilder) { + currentDialect.functionProvider.concat(separator, queryBuilder, expr = expr) + } } /** @@ -121,7 +125,9 @@ class GroupConcat( /** Returns the order in which the elements of each group are sorted. */ vararg val orderBy: Pair, SortOrder> ) : Function(TextColumnType()) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.groupConcat(this, queryBuilder) + override fun toQueryBuilder(queryBuilder: QueryBuilder) { + currentDialect.functionProvider.groupConcat(this, queryBuilder) + } } /** @@ -133,7 +139,9 @@ class Substring( /** Returns the length of the substring. */ val length: Expression ) : Function(TextColumnType()) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.substring(expr, start, length, queryBuilder) + override fun toQueryBuilder(queryBuilder: QueryBuilder) { + currentDialect.functionProvider.substring(expr, start, length, queryBuilder) + } } /** @@ -346,7 +354,9 @@ sealed class NextVal( columnType: IColumnType ) : Function(columnType) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.nextVal(seq, queryBuilder) + override fun toQueryBuilder(queryBuilder: QueryBuilder) { + currentDialect.functionProvider.nextVal(seq, queryBuilder) + } class IntNextVal(seq: Sequence) : NextVal(seq, IntegerColumnType()) class LongNextVal(seq: Sequence) : NextVal(seq, LongColumnType()) @@ -368,27 +378,33 @@ class CaseWhen(val value: Expression<*>?) { return this as CaseWhen } - fun Else(e: Expression): Expression = CaseWhenElse(this, e) + fun Else(e: Expression): ExpressionWithColumnType = CaseWhenElse(this, e) } -class CaseWhenElse(val caseWhen: CaseWhen, val elseResult: Expression) : ExpressionWithColumnType(), ComplexExpression { +class CaseWhenElse( + val caseWhen: CaseWhen, + val elseResult: Expression +) : ExpressionWithColumnType(), ComplexExpression { override val columnType: IColumnType = (elseResult as? ExpressionWithColumnType)?.columnType ?: caseWhen.cases.map { it.second }.filterIsInstance>().firstOrNull()?.columnType ?: BooleanColumnType.INSTANCE - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { - append("CASE ") - if (caseWhen.value != null) { - +caseWhen.value - } + override fun toQueryBuilder(queryBuilder: QueryBuilder) { + queryBuilder { + append("CASE ") + if (caseWhen.value != null) { + +caseWhen.value + +" " + } - for ((first, second) in caseWhen.cases) { - append(" WHEN ", first, " THEN ", second) - } + for ((first, second) in caseWhen.cases) { + append("WHEN ", first, " THEN ", second) + } - append(" ELSE ", elseResult, " END") + append(" ELSE ", elseResult, " END") + } } } @@ -419,5 +435,7 @@ class Cast( val expr: Expression<*>, columnType: IColumnType ) : Function(columnType) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.cast(expr, columnType, queryBuilder) + override fun toQueryBuilder(queryBuilder: QueryBuilder) { + currentDialect.functionProvider.cast(expr, columnType, queryBuilder) + } } diff --git a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ConditionsTests.kt b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ConditionsTests.kt index 32def78f3b..5e5e90f815 100644 --- a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ConditionsTests.kt +++ b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ConditionsTests.kt @@ -14,9 +14,9 @@ class ConditionsTests : DatabaseTestsBase() { @Test fun testTRUEandFALSEOps() { withCitiesAndUsers { cities, _, _ -> - val allSities = cities.selectAll().toCityNameList() + val allCities = cities.selectAll().toCityNameList() assertEquals(0L, cities.select { Op.FALSE }.count()) - assertEquals(allSities.size.toLong(), cities.select { Op.TRUE }.count()) + assertEquals(allCities.size.toLong(), cities.select { Op.TRUE }.count()) } } @@ -159,9 +159,9 @@ class ConditionsTests : DatabaseTestsBase() { @Test fun nullOpInCaseTest() { withCitiesAndUsers { cities, _, _ -> - val caseCondition = Case(). - When(Op.build { cities.id eq 1 }, Op.nullOp()). - Else(cities.name) + val caseCondition = Case() + .When(Op.build { cities.id eq 1 }, Op.nullOp()) + .Else(cities.name) var nullBranchWasExecuted = false cities.slice(cities.id, cities.name, caseCondition).selectAll().forEach { val result = it[caseCondition] @@ -175,4 +175,39 @@ class ConditionsTests : DatabaseTestsBase() { assertEquals(true, nullBranchWasExecuted) } } + + @Test + fun testCaseWhenElseAsArgument() { + withCitiesAndUsers { cities, _, _ -> + val original = "ORIGINAL" + val copy = "COPY" + val condition = Op.build { cities.id eq 1 } + + val caseCondition1 = Case() + .When(condition, stringLiteral(original)) + .Else(Op.nullOp()) + // Case().When().Else() invokes CaseWhenElse() so the 2 formats should be interchangeable as arguments + val caseCondition2 = CaseWhenElse( + Case().When(condition, stringLiteral(original)), + Op.nullOp() + ) + val function1 = Coalesce(caseCondition1, stringLiteral(copy)) + val function2 = Coalesce(caseCondition2, stringLiteral(copy)) + + // confirm both formats produce identical SQL + val query1 = cities.slice(cities.id, function1).selectAll().prepareSQL(this, prepared = false) + val query2 = cities.slice(cities.id, function2).selectAll().prepareSQL(this, prepared = false) + assertEquals(query1, query2) + + val results1 = cities.slice(cities.id, function1).selectAll().toList() + cities.slice(cities.id, function2).selectAll().forEachIndexed { i, row -> + val currentId = row[cities.id] + val functionResult = row[function2] + + assertEquals(if (currentId == 1) original else copy, functionResult) + assertEquals(currentId, results1[i][cities.id]) + assertEquals(functionResult, results1[i][function1]) + } + } + } }