Skip to content

Commit

Permalink
fix: EXPOSED-82 Inaccurate UShort column type mapping (#1799)
Browse files Browse the repository at this point in the history
* fix: EXPOSED-82 Inaccurate UShort column type mapping

Currently, when attempting to insert a UShort value outside of the range [0, 32767],
Exposed truncates the value by calling value.toShort() before sending it to the
DB, causing overflow. The value is stored successfully as a negative number because
all databases (except MySQL and MariaDB) don't support unsigned types natively,
which means Exposed is actually mapping to 2-byte `SMALLINT`, which accepts
the range [-32768, 32767].

Change the default mapping to the next higher-up integer data type INT (technically
a 4-byte storage type) and remove the truncation conversions so that an accurate
value is sent/received to/from the database. To ensure that the intended behavior
cannot be overriden using exec() directly, a check constraint is auto-applied to
the column when registered if the database is not MySQL/MariaDB.

Oracle INT is an alias for NUMBER(38), so this has been overriden to a reduced
type NUMBER(6).
  • Loading branch information
bog-walk authored Jul 28, 2023
1 parent 799737c commit 9de706c
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -283,32 +283,33 @@ class ShortColumnType : ColumnType() {

/**
* Numeric column for storing unsigned 2-byte integers.
*
* **Note:** If the database being used is not MySQL or MariaDB, this column will represent the database's 4-byte
* integer type with a check constraint that ensures storage of only values between 0 and [UShort.MAX_VALUE] inclusive.
*/
class UShortColumnType : ColumnType() {
override fun sqlType(): String = currentDialect.dataTypeProvider.ushortType()
override fun valueFromDB(value: Any): UShort {
return when (value) {
is UShort -> value
is Short -> value.takeIf { it >= 0 }?.toUShort()
is Number -> value.toInt().takeIf { it >= 0 && it <= UShort.MAX_VALUE.toInt() }?.toUShort()
is Short -> value.toUShort()
is Number -> value.toInt().toUShort()
is String -> value.toUShort()
else -> error("Unexpected value of type Short: $value of ${value::class.qualifiedName}")
} ?: error("Negative value but type is UShort: $value")
}
}

override fun setParameter(stmt: PreparedStatementApi, index: Int, value: Any?) {
val v = when {
value is UShort && currentDialect is MysqlDialect -> value.toInt()
value is UShort -> value.toShort()
val v = when (value) {
is UShort -> value.toInt()
else -> value
}
super.setParameter(stmt, index, v)
}

override fun notNullValueToDB(value: Any): Any {
val v = when {
value is UShort && currentDialect is MysqlDialect -> value.toInt()
value is UShort -> value.toShort()
val v = when (value) {
is UShort -> value.toInt()
else -> value
}
return super.notNullValueToDB(v)
Expand Down
45 changes: 33 additions & 12 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,8 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {

private val checkConstraints = mutableListOf<Pair<String, Op<Boolean>>>()

private val generatedCheckPrefix = "chk_unsigned_"

/**
* Returns the table name in proper case.
* Should be called within transaction or default [tableName] will be returned.
Expand Down Expand Up @@ -483,8 +485,14 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
/** Creates a numeric column, with the specified [name], for storing 2-byte integers. */
fun short(name: String): Column<Short> = registerColumn(name, ShortColumnType())

/** Creates a numeric column, with the specified [name], for storing 2-byte unsigned integers. */
fun ushort(name: String): Column<UShort> = registerColumn(name, UShortColumnType())
/** Creates a numeric column, with the specified [name], for storing 2-byte unsigned integers.
*
* **Note:** If the database being used is not MySQL or MariaDB, this column will use the database's 4-byte
* integer type with a check constraint that ensures storage of only values between 0 and [UShort.MAX_VALUE] inclusive.
*/
fun ushort(name: String): Column<UShort> = registerColumn<UShort>(name, UShortColumnType()).apply {
check("$generatedCheckPrefix$name") { it.between(0u, UShort.MAX_VALUE) }
}

/** Creates a numeric column, with the specified [name], for storing 4-byte integers. */
fun integer(name: String): Column<Int> = registerColumn(name, IntegerColumnType())
Expand Down Expand Up @@ -1043,29 +1051,33 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {

/**
* Creates a check constraint in this column.
* @param name The name to identify the constraint, optional. Must be **unique** (case-insensitive) to this table, otherwise, the constraint will
* not be created. All names are [trimmed][String.trim], blank names are ignored and the database engine decides the default name.
* @param name The name to identify the constraint, optional. Must be **unique** (case-insensitive) to this table,
* otherwise, the constraint will not be created. All names are [trimmed][String.trim], blank names are ignored and
* the database engine decides the default name.
* @param op The expression against which the newly inserted values will be compared.
*/
fun <T> Column<T>.check(name: String = "", op: SqlExpressionBuilder.(Column<T>) -> Op<Boolean>): Column<T> = apply {
if (name.isEmpty() || table.checkConstraints.none { it.first.equals(name, true) }) {
table.checkConstraints.add(name to SqlExpressionBuilder.op(this))
} else {
exposedLogger.warn("A CHECK constraint with name '$name' was ignored because there is already one with that name")
exposedLogger
.warn("A CHECK constraint with name '$name' was ignored because there is already one with that name")
}
}

/**
* Creates a check constraint in this table.
* @param name The name to identify the constraint, optional. Must be **unique** (case-insensitive) to this table, otherwise, the constraint will
* not be created. All names are [trimmed][String.trim], blank names are ignored and the database engine decides the default name.
* @param name The name to identify the constraint, optional. Must be **unique** (case-insensitive) to this table,
* otherwise, the constraint will not be created. All names are [trimmed][String.trim], blank names are ignored and
* the database engine decides the default name.
* @param op The expression against which the newly inserted values will be compared.
*/
fun check(name: String = "", op: SqlExpressionBuilder.() -> Op<Boolean>) {
if (name.isEmpty() || checkConstraints.none { it.first.equals(name, true) }) {
checkConstraints.add(name to SqlExpressionBuilder.op())
} else {
exposedLogger.warn("A CHECK constraint with name '$name' was ignored because there is already one with that name")
exposedLogger
.warn("A CHECK constraint with name '$name' was ignored because there is already one with that name")
}
}

Expand Down Expand Up @@ -1104,7 +1116,8 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
return primaryKey?.let { primaryKey ->
val tr = TransactionManager.current()
val constraint = tr.db.identifierManager.cutIfNecessaryAndQuote(primaryKey.name)
return primaryKey.columns.joinToString(prefix = "CONSTRAINT $constraint PRIMARY KEY (", postfix = ")", transform = tr::identity)
return primaryKey.columns
.joinToString(prefix = "CONSTRAINT $constraint PRIMARY KEY (", postfix = ")", transform = tr::identity)
}
}

Expand Down Expand Up @@ -1140,10 +1153,17 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
}

if (checkConstraints.isNotEmpty()) {
checkConstraints.mapIndexed { index, (name, op) ->
val filteredChecks = if (currentDialect is MysqlDialect) {
checkConstraints
.filterNot { (name, _) -> name.startsWith(generatedCheckPrefix) }
.ifEmpty { null }
} else {
checkConstraints
}
filteredChecks?.mapIndexed { index, (name, op) ->
val resolvedName = name.ifBlank { "check_${tableName}_$index" }
CheckConstraint.from(this@Table, resolvedName, op).checkPart
}.joinTo(this, prefix = ", ")
}?.joinTo(this, prefix = ", ")
}

append(")")
Expand All @@ -1159,7 +1179,8 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
return createSequence + createTable + createConstraint
}

override fun modifyStatement(): List<String> = throw UnsupportedOperationException("Use modify on columns and indices")
override fun modifyStatement(): List<String> =
throw UnsupportedOperationException("Use modify on columns and indices")

override fun dropStatement(): List<String> {
val dropTable = buildString {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ abstract class DataTypeProvider {
/** Numeric type for storing 2-byte integers. */
open fun shortType(): String = "SMALLINT"

/** Numeric type for storing 2-byte unsigned integers. */
open fun ushortType(): String = "SMALLINT"
/** Numeric type for storing 2-byte unsigned integers.
*
* **Note:** If the database being used is not MySQL or MariaDB, this will represent the 4-byte integer type.
*/
open fun ushortType(): String = "INT"

/** Numeric type for storing 4-byte integers. */
open fun integerType(): String = "INT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.jetbrains.exposed.sql.transactions.TransactionManager
internal object OracleDataTypeProvider : DataTypeProvider() {
override fun byteType(): String = "SMALLINT"
override fun ubyteType(): String = "SMALLINT"
override fun ushortType(): String = "NUMBER(6)"
override fun integerType(): String = "NUMBER(12)"
override fun integerAutoincType(): String = "NUMBER(12)"
override fun uintegerType(): String = "NUMBER(13)"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
package org.jetbrains.exposed.sql.tests.shared.types

import org.jetbrains.exposed.sql.SchemaUtils
import org.jetbrains.exposed.sql.Table
import org.jetbrains.exposed.sql.insert
import org.jetbrains.exposed.sql.selectAll
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.TestDB
import org.jetbrains.exposed.sql.tests.currentDialectTest
import org.jetbrains.exposed.sql.tests.shared.assertEqualCollections
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.jetbrains.exposed.sql.tests.shared.assertFailAndRollback
import org.jetbrains.exposed.sql.tests.shared.assertTrue
import org.jetbrains.exposed.sql.vendors.MysqlDialect
import org.junit.Test

class UnsignedColumnTypeTests : DatabaseTestsBase() {
object UByteTable : Table("uByteTable") {
val unsignedByte = ubyte("uByte")
object UByteTable : Table("ubyte_table") {
val unsignedByte = ubyte("ubyte")
}

object UShortTable : Table("uShortTable") {
val unsignedShort = ushort("uShort")
object UShortTable : Table("ushort_table") {
val unsignedShort = ushort("ushort")
}

object UIntTable : Table("uIntTable") {
val unsignedInt = uinteger("uInt")
object UIntTable : Table("uint_table") {
val unsignedInt = uinteger("uint")
}

object ULongTable : Table("uLongTable") {
val unsignedLong = ulong("uLong")
object ULongTable : Table("ulong_table") {
val unsignedLong = ulong("ulong")
}

@Test
Expand Down Expand Up @@ -52,6 +54,70 @@ class UnsignedColumnTypeTests : DatabaseTestsBase() {
}
}

@Test
fun testUShortWithCheckConstraint() {
withTables(UShortTable) {
val ddlEnding = if (currentDialectTest is MysqlDialect) {
"(ushort SMALLINT UNSIGNED NOT NULL)"
} else {
"CHECK (ushort BETWEEN 0 and ${UShort.MAX_VALUE}))"
}
assertTrue(UShortTable.ddl.single().endsWith(ddlEnding, ignoreCase = true))

val number = 49151.toUShort()
assertTrue(number in Short.MAX_VALUE.toUShort()..UShort.MAX_VALUE)

UShortTable.insert { it[unsignedShort] = number }

val result = UShortTable.selectAll()
assertEquals(number, result.single()[UShortTable.unsignedShort])

// test that column itself blocks same out-of-range value that compiler blocks
assertFailAndRollback("Check constraint violation (or out-of-range error in MySQL/MariaDB)") {
val tableName = UShortTable.nameInDatabaseCase()
val columnName = UShortTable.unsignedShort.nameInDatabaseCase()
val outOfRangeValue = UShort.MAX_VALUE + 1u
exec("""INSERT INTO $tableName ($columnName) VALUES ($outOfRangeValue)""")
}
}
}

@Test
fun testPreviousUShortColumnTypeWorksWithNewIntType() {
withDb(excludeSettings = listOf(TestDB.MYSQL, TestDB.MARIADB)) { testDb ->
try {
val tableName = UShortTable.nameInDatabaseCase()
val columnName = UShortTable.unsignedShort.nameInDatabaseCase()
// create table using previous column type SMALLINT
exec("""CREATE TABLE ${addIfNotExistsIfSupported()}$tableName ($columnName SMALLINT NOT NULL)""")

val number1 = Short.MAX_VALUE.toUShort()
UShortTable.insert { it[unsignedShort] = number1 }

val result1 = UShortTable.select { UShortTable.unsignedShort eq number1 }.count()
assertEquals(1, result1)

// SMALLINT maps to INTEGER in SQLite and NUMBER(38) in Oracle, so they will not throw OoR error
if (testDb != TestDB.SQLITE && testDb != TestDB.ORACLE) {
val number2 = (Short.MAX_VALUE + 1).toUShort()
assertFailAndRollback("Out-of-range (OoR) error") {
UShortTable.insert { it[unsignedShort] = number2 }
assertEquals(0, UShortTable.select { UShortTable.unsignedShort less 0u }.count())
}

// modify column to now have INT type
exec(UShortTable.unsignedShort.modifyStatement().first())
UShortTable.insert { it[unsignedShort] = number2 }

val result2 = UShortTable.selectAll().map { it[UShortTable.unsignedShort] }
assertEqualCollections(listOf(number1, number2), result2)
}
} finally {
SchemaUtils.drop(UShortTable)
}
}
}

@Test
fun testUIntColumnType() {
withTables(UIntTable) {
Expand Down

0 comments on commit 9de706c

Please sign in to comment.