Skip to content

Commit

Permalink
EXPOSED-65 Deprecate select() and replace with where()
Browse files Browse the repository at this point in the history
Step 1 of DSL design changes

Introduce SelectBuilder interface to allow where to accept both FieldSet and
Query as a receiver.

Adjust all tests to use where().
  • Loading branch information
bog-walk committed Nov 17, 2023
1 parent c6fe30e commit 07abbfd
Show file tree
Hide file tree
Showing 61 changed files with 1,255 additions and 556 deletions.
30 changes: 24 additions & 6 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Queries.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,31 @@ import org.jetbrains.exposed.sql.vendors.SQLServerDialect
import org.jetbrains.exposed.sql.vendors.currentDialect
import kotlin.sequences.Sequence

/**
* @sample org.jetbrains.exposed.sql.tests.shared.DMLTests.testSelect01
*/
inline fun FieldSet.select(where: SqlExpressionBuilder.() -> Op<Boolean>): Query = select(SqlExpressionBuilder.where())

@Deprecated(
message = "As the first step in DSL design changes, this will be removed in future releases, then slice() will be renamed to select()",
replaceWith = ReplaceWith("selectAll().where { where.invoke() }", "import org.jetbrains.exposed.sql.selectAll", "import org.jetbrains.exposed.sql.where"),
level = DeprecationLevel.WARNING
)
inline fun FieldSet.select(where: SqlExpressionBuilder.() -> Op<Boolean>): Query = selectAll().where(SqlExpressionBuilder.where())

@Deprecated(
message = "As the first step in DSL design changes, this will be removed in future releases, then slice() will be renamed to select()",
replaceWith = ReplaceWith("selectAll().where(where)", "import org.jetbrains.exposed.sql.selectAll", "import org.jetbrains.exposed.sql.where"),
level = DeprecationLevel.WARNING
)
fun FieldSet.select(where: Op<Boolean>): Query = Query(this, where)

sealed interface SelectBuilder

inline fun SelectBuilder.where(predicate: SqlExpressionBuilder.() -> Op<Boolean>): Query = where(SqlExpressionBuilder.predicate())

fun SelectBuilder.where(predicate: Op<Boolean>): Query = when (this) {
is FieldSet -> Query(this, predicate)
is Query -> this.where?.let {
error("WHERE clause is specified twice. Old value = '$it', new value = '$predicate'. Use either adjustWhere() or andWhere() instead.")
} ?: copy().adjustWhere { predicate }
}

/**
* @sample org.jetbrains.exposed.sql.tests.shared.DMLTests.testSelectDistinct
*/
Expand Down Expand Up @@ -384,7 +402,7 @@ private fun FieldSet.selectBatched(
var lastOffset = 0L
while (true) {
val query =
select { whereOp and (autoIncColumn greater lastOffset) }
selectAll().where { whereOp and (autoIncColumn greater lastOffset) }
.limit(batchSize)
.orderBy(autoIncColumn, SortOrder.ASC)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ enum class SortOrder(val code: String) {
DESC_NULLS_LAST(code = "DESC NULLS LAST")
}

open class Query(override var set: FieldSet, where: Op<Boolean>?) : AbstractQuery<Query>(set.source.targetTables()) {
open class Query(
override var set: FieldSet,
where: Op<Boolean>?
) : AbstractQuery<Query>(set.source.targetTables()), SelectBuilder {
var distinct: Boolean = false
protected set

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ typealias JoinCondition = Pair<Expression<*>, Expression<*>>
/**
* Represents a set of expressions, contained in the given column set.
*/
interface FieldSet {
interface FieldSet : SelectBuilder {
/** Return the column set that contains this field set. */
val source: ColumnSet

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ open class Entity<ID : Comparable<ID>>(val id: EntityID<ID>) {
val readValues: ResultRow
get() = _readValues ?: run {
val table = klass.table
_readValues = klass.searchQuery(Op.build { table.id eq id }).firstOrNull() ?: table.select { table.id eq id }.first()
_readValues = klass.searchQuery(Op.build { table.id eq id }).firstOrNull() ?: table.selectAll().where { table.id eq id }.first()
_readValues!!
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ abstract class EntityClass<ID : Comparable<ID>, out T : Entity<ID>>(
}
column to value
}

exp is Column && exp.table == table -> null
else -> exp to value
}
Expand Down Expand Up @@ -210,7 +211,7 @@ abstract class EntityClass<ID : Comparable<ID>, out T : Entity<ID>>(
open val dependsOnColumns: List<Column<out Any?>> get() = dependsOnTables.columns

open fun searchQuery(op: Op<Boolean>): Query =
dependsOnTables.slice(dependsOnColumns).select { op }.setForUpdateStatus()
dependsOnTables.slice(dependsOnColumns).where { op }.setForUpdateStatus()

/**
* Count the amount of entities that conform to the [op] statement.
Expand Down Expand Up @@ -532,7 +533,7 @@ abstract class EntityClass<ID : Comparable<ID>, out T : Entity<ID>>(
else -> (dependsOnColumns + linkTable.columns + sourceRefColumn).distinct()
}

val query = entityTables.slice(columns).select { sourceRefColumn inList idsToLoad }
val query = entityTables.slice(columns).where { sourceRefColumn inList idsToLoad }
val targetEntities = mutableMapOf<EntityID<ID>, T>()
val entitiesWithRefs = when (forUpdate) {
true -> query.forUpdate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class InnerTableLink<SID : Comparable<SID>, Source : Entity<SID>, ID : Comparabl

val (columns, entityTables) = columnsAndTables

val query = { target.wrapRows(entityTables.slice(columns).select { sourceColumn eq o.id }) }
val query = { target.wrapRows(entityTables.slice(columns).where { sourceColumn eq o.id }) }
return transaction.entityCache.getOrPutReferrers(o.id, sourceColumn, query).also {
o.storeReferenceInCache(sourceColumn, it)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,23 +246,23 @@ class DefaultsTest : DatabaseTestsBase() {
val varcharType = currentDialectTest.dataTypeProvider.varcharType(100)
val q = db.identifierManager.quoteString
val baseExpression = "CREATE TABLE " + addIfNotExistsIfSupported() +
"${"t".inProperCase()} (" +
"${"id".inProperCase()} ${currentDialectTest.dataTypeProvider.integerAutoincType()} PRIMARY KEY, " +
"${"s".inProperCase()} $varcharType${testTable.s.constraintNamePart()} DEFAULT 'test' NOT NULL, " +
"${"sn".inProperCase()} $varcharType${testTable.sn.constraintNamePart()} DEFAULT 'testNullable' NULL, " +
"${"l".inProperCase()} ${currentDialectTest.dataTypeProvider.longType()}${testTable.l.constraintNamePart()} DEFAULT 42 NOT NULL, " +
"$q${"c".inProperCase()}$q CHAR${testTable.c.constraintNamePart()} DEFAULT 'X' NOT NULL, " +
"${"t1".inProperCase()} $dtType${testTable.t1.constraintNamePart()} ${currentDT.itOrNull()}, " +
"${"t2".inProperCase()} $dtType${testTable.t2.constraintNamePart()} ${nowExpression.itOrNull()}, " +
"${"t3".inProperCase()} $dtType${testTable.t3.constraintNamePart()} ${dtLiteral.itOrNull()}, " +
"${"t4".inProperCase()} DATE${testTable.t4.constraintNamePart()} ${dLiteral.itOrNull()}, " +
"${"t5".inProperCase()} $dtType${testTable.t5.constraintNamePart()} ${tsLiteral.itOrNull()}, " +
"${"t6".inProperCase()} $dtType${testTable.t6.constraintNamePart()} ${tsLiteral.itOrNull()}, " +
"${"t7".inProperCase()} $longType${testTable.t7.constraintNamePart()} ${durLiteral.itOrNull()}, " +
"${"t8".inProperCase()} $longType${testTable.t8.constraintNamePart()} ${durLiteral.itOrNull()}, " +
"${"t9".inProperCase()} $timeType${testTable.t9.constraintNamePart()} ${tLiteral.itOrNull()}, " +
"${"t10".inProperCase()} $timeType${testTable.t10.constraintNamePart()} ${tLiteral.itOrNull()}" +
")"
"${"t".inProperCase()} (" +
"${"id".inProperCase()} ${currentDialectTest.dataTypeProvider.integerAutoincType()} PRIMARY KEY, " +
"${"s".inProperCase()} $varcharType${testTable.s.constraintNamePart()} DEFAULT 'test' NOT NULL, " +
"${"sn".inProperCase()} $varcharType${testTable.sn.constraintNamePart()} DEFAULT 'testNullable' NULL, " +
"${"l".inProperCase()} ${currentDialectTest.dataTypeProvider.longType()}${testTable.l.constraintNamePart()} DEFAULT 42 NOT NULL, " +
"$q${"c".inProperCase()}$q CHAR${testTable.c.constraintNamePart()} DEFAULT 'X' NOT NULL, " +
"${"t1".inProperCase()} $dtType${testTable.t1.constraintNamePart()} ${currentDT.itOrNull()}, " +
"${"t2".inProperCase()} $dtType${testTable.t2.constraintNamePart()} ${nowExpression.itOrNull()}, " +
"${"t3".inProperCase()} $dtType${testTable.t3.constraintNamePart()} ${dtLiteral.itOrNull()}, " +
"${"t4".inProperCase()} DATE${testTable.t4.constraintNamePart()} ${dLiteral.itOrNull()}, " +
"${"t5".inProperCase()} $dtType${testTable.t5.constraintNamePart()} ${tsLiteral.itOrNull()}, " +
"${"t6".inProperCase()} $dtType${testTable.t6.constraintNamePart()} ${tsLiteral.itOrNull()}, " +
"${"t7".inProperCase()} $longType${testTable.t7.constraintNamePart()} ${durLiteral.itOrNull()}, " +
"${"t8".inProperCase()} $longType${testTable.t8.constraintNamePart()} ${durLiteral.itOrNull()}, " +
"${"t9".inProperCase()} $timeType${testTable.t9.constraintNamePart()} ${tLiteral.itOrNull()}, " +
"${"t10".inProperCase()} $timeType${testTable.t10.constraintNamePart()} ${tLiteral.itOrNull()}" +
")"

val expected = if (currentDialectTest is OracleDialect || currentDialectTest.h2Mode == H2Dialect.H2CompatibilityMode.Oracle) {
arrayListOf("CREATE SEQUENCE t_id_seq START WITH 1 MINVALUE 1 MAXVALUE 9223372036854775807", baseExpression)
Expand All @@ -274,7 +274,7 @@ class DefaultsTest : DatabaseTestsBase() {

val id1 = testTable.insertAndGetId { }

val row1 = testTable.select { testTable.id eq id1 }.single()
val row1 = testTable.selectAll().where { testTable.id eq id1 }.single()
assertEquals("test", row1[testTable.s])
assertEquals("testNullable", row1[testTable.sn])
assertEquals(42, row1[testTable.l])
Expand Down Expand Up @@ -309,7 +309,7 @@ class DefaultsTest : DatabaseTestsBase() {
val id = foo.insertAndGetId {
it[foo.name] = "bar"
}
val result = foo.select { foo.id eq id }.single()
val result = foo.selectAll().where { foo.id eq id }.single()

assertEquals(today, result[foo.defaultDateTime].toLocalDate())
assertEquals(today, result[foo.defaultDate])
Expand All @@ -332,7 +332,7 @@ class DefaultsTest : DatabaseTestsBase() {
it[foo.defaultDateTime] = nonDefaultDate
}

val result = foo.select { foo.id eq id }.single()
val result = foo.selectAll().where { foo.id eq id }.single()

assertEquals("bar", result[foo.name])
assertEqualDateTime(nonDefaultDate, result[foo.defaultDateTime])
Expand All @@ -341,7 +341,7 @@ class DefaultsTest : DatabaseTestsBase() {
it[foo.name] = "baz"
}

val result2 = foo.select { foo.id eq id }.single()
val result2 = foo.selectAll().where { foo.id eq id }.single()
assertEquals("baz", result2[foo.name])
assertEqualDateTime(nonDefaultDate, result2[foo.defaultDateTime])
}
Expand All @@ -358,7 +358,7 @@ class DefaultsTest : DatabaseTestsBase() {
foo.insert { it[dt] = LocalDateTime.of(2019, 1, 1, 1, 1) }
foo.insert { it[dt] = dt2020 }
foo.insert { it[dt] = LocalDateTime.of(2021, 1, 1, 1, 1) }
val count = foo.select { foo.dt.between(dt2020.minusWeeks(1), dt2020.plusWeeks(1)) }.count()
val count = foo.selectAll().where { foo.dt.between(dt2020.minusWeeks(1), dt2020.plusWeeks(1)) }.count()
assertEquals(1, count)
}
}
Expand Down Expand Up @@ -422,11 +422,11 @@ class DefaultsTest : DatabaseTestsBase() {
val timestampWithTimeZoneType = currentDialectTest.dataTypeProvider.timestampWithTimeZoneType()

val baseExpression = "CREATE TABLE " + addIfNotExistsIfSupported() +
"${"t".inProperCase()} (" +
"${"id".inProperCase()} ${currentDialectTest.dataTypeProvider.integerAutoincType()} PRIMARY KEY, " +
"${"t1".inProperCase()} $timestampWithTimeZoneType${testTable.t1.constraintNamePart()} ${timestampWithTimeZoneLiteral.itOrNull()}, " +
"${"t2".inProperCase()} $timestampWithTimeZoneType${testTable.t2.constraintNamePart()} ${timestampWithTimeZoneLiteral.itOrNull()}" +
")"
"${"t".inProperCase()} (" +
"${"id".inProperCase()} ${currentDialectTest.dataTypeProvider.integerAutoincType()} PRIMARY KEY, " +
"${"t1".inProperCase()} $timestampWithTimeZoneType${testTable.t1.constraintNamePart()} ${timestampWithTimeZoneLiteral.itOrNull()}, " +
"${"t2".inProperCase()} $timestampWithTimeZoneType${testTable.t2.constraintNamePart()} ${timestampWithTimeZoneLiteral.itOrNull()}" +
")"

val expected = if (currentDialectTest is OracleDialect ||
currentDialectTest.h2Mode == H2Dialect.H2CompatibilityMode.Oracle
Expand All @@ -443,7 +443,7 @@ class DefaultsTest : DatabaseTestsBase() {

val id1 = testTable.insertAndGetId { }

val row1 = testTable.select { testTable.id eq id1 }.single()
val row1 = testTable.selectAll().where { testTable.id eq id1 }.single()
assertEqualDateTime(nowWithTimeZone, row1[testTable.t1])
assertEqualDateTime(nowWithTimeZone, row1[testTable.t2])
}
Expand Down
Loading

0 comments on commit 07abbfd

Please sign in to comment.