diff --git a/exposed-core/api/exposed-core.api b/exposed-core/api/exposed-core.api index da691569b1..5cc1b8fc7d 100644 --- a/exposed-core/api/exposed-core.api +++ b/exposed-core/api/exposed-core.api @@ -259,7 +259,7 @@ public final class org/jetbrains/exposed/sql/Case { public final class org/jetbrains/exposed/sql/CaseWhen { public fun (Lorg/jetbrains/exposed/sql/Expression;)V - public final fun Else (Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/Expression; + public final fun Else (Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/ExpressionWithColumnType; public final fun When (Lorg/jetbrains/exposed/sql/Expression;Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/CaseWhen; public final fun getCases ()Ljava/util/List; public final fun getValue ()Lorg/jetbrains/exposed/sql/Expression; diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt index d195852e26..20ca398d92 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt @@ -73,7 +73,8 @@ class Random( class CharLength( val expr: Expression ) : Function(IntegerColumnType()) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.charLength(expr, queryBuilder) + override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = + currentDialect.functionProvider.charLength(expr, queryBuilder) } /** @@ -105,7 +106,8 @@ class Concat( /** Returns the expressions being concatenated. */ vararg val expr: Expression<*> ) : Function(TextColumnType()) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.concat(separator, queryBuilder, expr = expr) + override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = + currentDialect.functionProvider.concat(separator, queryBuilder, expr = expr) } /** @@ -121,7 +123,8 @@ class GroupConcat( /** Returns the order in which the elements of each group are sorted. */ vararg val orderBy: Pair, SortOrder> ) : Function(TextColumnType()) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.groupConcat(this, queryBuilder) + override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = + currentDialect.functionProvider.groupConcat(this, queryBuilder) } /** @@ -133,7 +136,8 @@ class Substring( /** Returns the length of the substring. */ val length: Expression ) : Function(TextColumnType()) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.substring(expr, start, length, queryBuilder) + override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = + currentDialect.functionProvider.substring(expr, start, length, queryBuilder) } /** @@ -346,7 +350,8 @@ sealed class NextVal( columnType: IColumnType ) : Function(columnType) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.nextVal(seq, queryBuilder) + override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = + currentDialect.functionProvider.nextVal(seq, queryBuilder) class IntNextVal(seq: Sequence) : NextVal(seq, IntegerColumnType()) class LongNextVal(seq: Sequence) : NextVal(seq, LongColumnType()) @@ -368,10 +373,13 @@ class CaseWhen(val value: Expression<*>?) { return this as CaseWhen } - fun Else(e: Expression): Expression = CaseWhenElse(this, e) + fun Else(e: Expression): ExpressionWithColumnType = CaseWhenElse(this, e) } -class CaseWhenElse(val caseWhen: CaseWhen, val elseResult: Expression) : ExpressionWithColumnType(), ComplexExpression { +class CaseWhenElse( + val caseWhen: CaseWhen, + val elseResult: Expression +) : ExpressionWithColumnType(), ComplexExpression { override val columnType: IColumnType = (elseResult as? ExpressionWithColumnType)?.columnType @@ -382,10 +390,11 @@ class CaseWhenElse(val caseWhen: CaseWhen, val elseResult: Expressi append("CASE ") if (caseWhen.value != null) { +caseWhen.value + +" " } for ((first, second) in caseWhen.cases) { - append(" WHEN ", first, " THEN ", second) + append("WHEN ", first, " THEN ", second) } append(" ELSE ", elseResult, " END") @@ -419,5 +428,6 @@ class Cast( val expr: Expression<*>, columnType: IColumnType ) : Function(columnType) { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.cast(expr, columnType, queryBuilder) + override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = + currentDialect.functionProvider.cast(expr, columnType, queryBuilder) } diff --git a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ConditionsTests.kt b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ConditionsTests.kt index f68c38c046..8d017f9c2f 100644 --- a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ConditionsTests.kt +++ b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/ConditionsTests.kt @@ -13,9 +13,9 @@ class ConditionsTests : DatabaseTestsBase() { @Test fun testTRUEandFALSEOps() { withCitiesAndUsers { cities, _, _ -> - val allSities = cities.selectAll().toCityNameList() + val allCities = cities.selectAll().toCityNameList() assertEquals(0L, cities.select { Op.FALSE }.count()) - assertEquals(allSities.size.toLong(), cities.select { Op.TRUE }.count()) + assertEquals(allCities.size.toLong(), cities.select { Op.TRUE }.count()) } } @@ -158,9 +158,9 @@ class ConditionsTests : DatabaseTestsBase() { @Test fun nullOpInCaseTest() { withCitiesAndUsers { cities, _, _ -> - val caseCondition = Case(). - When(Op.build { cities.id eq 1 }, Op.nullOp()). - Else(cities.name) + val caseCondition = Case() + .When(Op.build { cities.id eq 1 }, Op.nullOp()) + .Else(cities.name) var nullBranchWasExecuted = false cities.slice(cities.id, cities.name, caseCondition).selectAll().forEach { val result = it[caseCondition] @@ -174,4 +174,40 @@ class ConditionsTests : DatabaseTestsBase() { assertEquals(true, nullBranchWasExecuted) } } + + @Test + fun testCaseWhenElseAsArgument() { + withCitiesAndUsers { cities, _, _ -> + addLogger(StdOutSqlLogger) + val original = "ORIGINAL" + val copy = "COPY" + val condition = Op.build { cities.id eq 1 } + + val caseCondition1 = Case() + .When(condition, stringLiteral(original)) + .Else(Op.nullOp()) + // Case().When().Else() invokes CaseWhenElse() so the 2 formats should be interchangeable as arguments + val caseCondition2 = CaseWhenElse( + Case().When(condition, stringLiteral(original)), + Op.nullOp() + ) + val function1 = Coalesce(caseCondition1, stringLiteral(copy)) + val function2 = Coalesce(caseCondition2, stringLiteral(copy)) + + // confirm both formats produce identical SQL + val query1 = cities.slice(cities.id, function1).selectAll().prepareSQL(this, prepared = false) + val query2 = cities.slice(cities.id, function2).selectAll().prepareSQL(this, prepared = false) + assertEquals(query1, query2) + + val results1 = cities.slice(cities.id, function1).selectAll().toList() + cities.slice(cities.id, function2).selectAll().forEachIndexed { i, row -> + val currentId = row[cities.id] + val functionResult = row[function2] + + assertEquals(if (currentId == 1) original else copy, functionResult) + assertEquals(currentId, results1[i][cities.id]) + assertEquals(functionResult, results1[i][function1]) + } + } + } }