Skip to content

Commit

Permalink
fix: EXPOSED-54 CaseWhen.Else returns narrow Expression<R> (#1800)
Browse files Browse the repository at this point in the history
* fix: EXPOSED-54 CaseWhen.Else returns narrow Expression<R>

The class CaseWhen has a function Else() with a return type Expression<R>, which
creates an instance of the CaseWhenElse class. According to commit history, the
CaseWhenElse class uses to also extend Expression<R>, but this was changed after a
few years to extend ExpressionWithColumnType<R> instead. Inspite of this change,
the return type of Else() did not change.

This means that the same case written as a more verbose CaseWhenElse instance
or written as a CaseWhen instance invoking Else() cannot be used interchangeably,
for example as the argument in Coalesce(), which takes ExpressionWithColumnType<*>.

Update Else() return type to be as broad as the instance it returns.

Remove extra whitespace in Case query builder.

Add unit test to ensure consistency.

Fix Detekt issues.
  • Loading branch information
bog-walk authored Jul 27, 2023
1 parent 55ed36b commit e55f3ef
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 23 deletions.
2 changes: 1 addition & 1 deletion exposed-core/api/exposed-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public final class org/jetbrains/exposed/sql/Case {

public final class org/jetbrains/exposed/sql/CaseWhen {
public fun <init> (Lorg/jetbrains/exposed/sql/Expression;)V
public final fun Else (Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/Expression;
public final fun Else (Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/ExpressionWithColumnType;
public final fun When (Lorg/jetbrains/exposed/sql/Expression;Lorg/jetbrains/exposed/sql/Expression;)Lorg/jetbrains/exposed/sql/CaseWhen;
public final fun getCases ()Ljava/util/List;
public final fun getValue ()Lorg/jetbrains/exposed/sql/Expression;
Expand Down
52 changes: 35 additions & 17 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ class Random(
class CharLength<T : String?>(
val expr: Expression<T>
) : Function<Int?>(IntegerColumnType()) {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.charLength(expr, queryBuilder)
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
currentDialect.functionProvider.charLength(expr, queryBuilder)
}
}

/**
Expand Down Expand Up @@ -105,7 +107,9 @@ class Concat(
/** Returns the expressions being concatenated. */
vararg val expr: Expression<*>
) : Function<String>(TextColumnType()) {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.concat(separator, queryBuilder, expr = expr)
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
currentDialect.functionProvider.concat(separator, queryBuilder, expr = expr)
}
}

/**
Expand All @@ -121,7 +125,9 @@ class GroupConcat<T : String?>(
/** Returns the order in which the elements of each group are sorted. */
vararg val orderBy: Pair<Expression<*>, SortOrder>
) : Function<T>(TextColumnType()) {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.groupConcat(this, queryBuilder)
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
currentDialect.functionProvider.groupConcat(this, queryBuilder)
}
}

/**
Expand All @@ -133,7 +139,9 @@ class Substring<T : String?>(
/** Returns the length of the substring. */
val length: Expression<Int>
) : Function<T>(TextColumnType()) {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.substring(expr, start, length, queryBuilder)
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
currentDialect.functionProvider.substring(expr, start, length, queryBuilder)
}
}

/**
Expand Down Expand Up @@ -346,7 +354,9 @@ sealed class NextVal<T>(
columnType: IColumnType
) : Function<T>(columnType) {

override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.nextVal(seq, queryBuilder)
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
currentDialect.functionProvider.nextVal(seq, queryBuilder)
}

class IntNextVal(seq: Sequence) : NextVal<Int>(seq, IntegerColumnType())
class LongNextVal(seq: Sequence) : NextVal<Long>(seq, LongColumnType())
Expand All @@ -368,27 +378,33 @@ class CaseWhen<T>(val value: Expression<*>?) {
return this as CaseWhen<R>
}

fun <R : T> Else(e: Expression<R>): Expression<R> = CaseWhenElse(this, e)
fun <R : T> Else(e: Expression<R>): ExpressionWithColumnType<R> = CaseWhenElse(this, e)
}

class CaseWhenElse<T, R : T>(val caseWhen: CaseWhen<T>, val elseResult: Expression<R>) : ExpressionWithColumnType<R>(), ComplexExpression {
class CaseWhenElse<T, R : T>(
val caseWhen: CaseWhen<T>,
val elseResult: Expression<R>
) : ExpressionWithColumnType<R>(), ComplexExpression {

override val columnType: IColumnType =
(elseResult as? ExpressionWithColumnType<R>)?.columnType
?: caseWhen.cases.map { it.second }.filterIsInstance<ExpressionWithColumnType<*>>().firstOrNull()?.columnType
?: BooleanColumnType.INSTANCE

override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder {
append("CASE ")
if (caseWhen.value != null) {
+caseWhen.value
}
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
queryBuilder {
append("CASE ")
if (caseWhen.value != null) {
+caseWhen.value
+" "
}

for ((first, second) in caseWhen.cases) {
append(" WHEN ", first, " THEN ", second)
}
for ((first, second) in caseWhen.cases) {
append("WHEN ", first, " THEN ", second)
}

append(" ELSE ", elseResult, " END")
append(" ELSE ", elseResult, " END")
}
}
}

Expand Down Expand Up @@ -419,5 +435,7 @@ class Cast<T>(
val expr: Expression<*>,
columnType: IColumnType
) : Function<T>(columnType) {
override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = currentDialect.functionProvider.cast(expr, columnType, queryBuilder)
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
currentDialect.functionProvider.cast(expr, columnType, queryBuilder)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ class ConditionsTests : DatabaseTestsBase() {
@Test
fun testTRUEandFALSEOps() {
withCitiesAndUsers { cities, _, _ ->
val allSities = cities.selectAll().toCityNameList()
val allCities = cities.selectAll().toCityNameList()
assertEquals(0L, cities.select { Op.FALSE }.count())
assertEquals(allSities.size.toLong(), cities.select { Op.TRUE }.count())
assertEquals(allCities.size.toLong(), cities.select { Op.TRUE }.count())
}
}

Expand Down Expand Up @@ -159,9 +159,9 @@ class ConditionsTests : DatabaseTestsBase() {
@Test
fun nullOpInCaseTest() {
withCitiesAndUsers { cities, _, _ ->
val caseCondition = Case().
When(Op.build { cities.id eq 1 }, Op.nullOp<String>()).
Else(cities.name)
val caseCondition = Case()
.When(Op.build { cities.id eq 1 }, Op.nullOp<String>())
.Else(cities.name)
var nullBranchWasExecuted = false
cities.slice(cities.id, cities.name, caseCondition).selectAll().forEach {
val result = it[caseCondition]
Expand All @@ -175,4 +175,39 @@ class ConditionsTests : DatabaseTestsBase() {
assertEquals(true, nullBranchWasExecuted)
}
}

@Test
fun testCaseWhenElseAsArgument() {
withCitiesAndUsers { cities, _, _ ->
val original = "ORIGINAL"
val copy = "COPY"
val condition = Op.build { cities.id eq 1 }

val caseCondition1 = Case()
.When(condition, stringLiteral(original))
.Else(Op.nullOp())
// Case().When().Else() invokes CaseWhenElse() so the 2 formats should be interchangeable as arguments
val caseCondition2 = CaseWhenElse(
Case().When(condition, stringLiteral(original)),
Op.nullOp()
)
val function1 = Coalesce(caseCondition1, stringLiteral(copy))
val function2 = Coalesce(caseCondition2, stringLiteral(copy))

// confirm both formats produce identical SQL
val query1 = cities.slice(cities.id, function1).selectAll().prepareSQL(this, prepared = false)
val query2 = cities.slice(cities.id, function2).selectAll().prepareSQL(this, prepared = false)
assertEquals(query1, query2)

val results1 = cities.slice(cities.id, function1).selectAll().toList()
cities.slice(cities.id, function2).selectAll().forEachIndexed { i, row ->
val currentId = row[cities.id]
val functionResult = row[function2]

assertEquals(if (currentId == 1) original else copy, functionResult)
assertEquals(currentId, results1[i][cities.id])
assertEquals(functionResult, results1[i][function1])
}
}
}
}

0 comments on commit e55f3ef

Please sign in to comment.