Skip to content

Commit

Permalink
fix: EXPOSED-107 Incorrect mapping for UByte data type
Browse files Browse the repository at this point in the history
Currently, when attempting to insert a UByte value outside of the range [0, 127],
Exposed truncates the value by calling value.toByte() before sending it to the
DB, causing overflow. The value is stored successfully as a negative number because
all databases (except MySQL and MariaDB) don't support unsigned types natively,
which means Exposed is actually mapping to 1-byte TINYINT, which accepts
the range [-128, 127].

Note: SQL Server is an exception as that DB represents TINYINT as holding the unsigned range [0, 255]; so SQL Server requires no changes.

Change the default mapping to the next higher-up integer data type SMALLINT (technically
a 2-byte storage type) and remove the truncation conversions so that an accurate
value is sent/received to/from the database. To ensure that the intended behavior
cannot be overriden using exec() directly, a check constraint is auto-applied to
the column when registered if the database is not MySQL/MariaDB/SQL Server.

Oracle TINYINT is an alias for NUMBER(38), so this has been overriden to a reduced
type NUMBER(4).
  • Loading branch information
bog-walk committed Jul 28, 2023
1 parent 74fba00 commit ba007bc
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,33 +235,35 @@ class ByteColumnType : ColumnType() {

/**
* Numeric column for storing unsigned 1-byte integers.
*
* **Note:** If the database being used is not MySQL, MariaDB, or SQL Server, this column will represent the
* database's 2-byte integer type with a check constraint that ensures storage of only values
* between 0 and [UByte.MAX_VALUE] inclusive.
*/
class UByteColumnType : ColumnType() {
override fun sqlType(): String = currentDialect.dataTypeProvider.ubyteType()

override fun valueFromDB(value: Any): UByte {
return when (value) {
is UByte -> value
is Byte -> value.takeIf { it >= 0 }?.toUByte()
is Number -> value.toShort().takeIf { it >= 0 && it <= UByte.MAX_VALUE.toShort() }?.toUByte()
is Byte -> value.toUByte()
is Number -> value.toShort().toUByte()
is String -> value.toUByte()
else -> error("Unexpected value of type Byte: $value of ${value::class.qualifiedName}")
} ?: error("Negative value but type is UByte: $value")
}
}

override fun setParameter(stmt: PreparedStatementApi, index: Int, value: Any?) {
val v = when {
value is UByte && currentDialect is MysqlDialect -> value.toShort()
value is UByte -> value.toByte()
val v = when (value) {
is UByte -> value.toShort()
else -> value
}
super.setParameter(stmt, index, v)
}

override fun notNullValueToDB(value: Any): Any {
val v = when {
value is UByte && currentDialect is MysqlDialect -> value.toShort()
value is UByte -> value.toByte()
val v = when (value) {
is UByte -> value.toShort()
else -> value
}
return super.notNullValueToDB(v)
Expand Down
27 changes: 18 additions & 9 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,15 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
/** Creates a numeric column, with the specified [name], for storing 1-byte integers. */
fun byte(name: String): Column<Byte> = registerColumn(name, ByteColumnType())

/** Creates a numeric column, with the specified [name], for storing 1-byte unsigned integers. */
fun ubyte(name: String): Column<UByte> = registerColumn(name, UByteColumnType())
/** Creates a numeric column, with the specified [name], for storing 1-byte unsigned integers.
*
* **Note:** If the database being used is not MySQL, MariaDB, or SQL Server, this column will use the
* database's 2-byte integer type with a check constraint that ensures storage of only values
* between 0 and [UByte.MAX_VALUE] inclusive.
*/
fun ubyte(name: String): Column<UByte> = registerColumn<UByte>(name, UByteColumnType()).apply {
check("${generatedCheckPrefix}byte_$name") { it.between(0u, UByte.MAX_VALUE) }
}

/** Creates a numeric column, with the specified [name], for storing 2-byte integers. */
fun short(name: String): Column<Short> = registerColumn(name, ShortColumnType())
Expand Down Expand Up @@ -1153,13 +1160,15 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
}

if (checkConstraints.isNotEmpty()) {
val filteredChecks = if (currentDialect is MysqlDialect) {
checkConstraints
.filterNot { (name, _) -> name.startsWith(generatedCheckPrefix) }
.ifEmpty { null }
} else {
checkConstraints
}
val filteredChecks = when (currentDialect) {
is MysqlDialect -> checkConstraints.filterNot { (name, _) ->
name.startsWith(generatedCheckPrefix)
}
is SQLServerDialect -> checkConstraints.filterNot { (name, _) ->
name.startsWith("${generatedCheckPrefix}byte_")
}
else -> checkConstraints
}.ifEmpty { null }
filteredChecks?.mapIndexed { index, (name, op) ->
val resolvedName = name.ifBlank { "check_${tableName}_$index" }
CheckConstraint.from(this@Table, resolvedName, op).checkPart
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@ abstract class DataTypeProvider {
/** Numeric type for storing 1-byte integers. */
open fun byteType(): String = "TINYINT"

/** Numeric type for storing 1-byte unsigned integers. */
open fun ubyteType(): String = "TINYINT"
/** Numeric type for storing 1-byte unsigned integers.
*
* **Note:** If the database being used is not MySQL, MariaDB, or SQL Server, this will represent the 2-byte
* integer type.
*/
open fun ubyteType(): String = "SMALLINT"

/** Numeric type for storing 2-byte integers. */
open fun shortType(): String = "SMALLINT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import org.jetbrains.exposed.sql.transactions.TransactionManager

internal object OracleDataTypeProvider : DataTypeProvider() {
override fun byteType(): String = "SMALLINT"
override fun ubyteType(): String = "SMALLINT"
override fun ubyteType(): String = "NUMBER(4)"
override fun ushortType(): String = "NUMBER(6)"
override fun integerType(): String = "NUMBER(12)"
override fun integerAutoincType(): String = "NUMBER(12)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ internal object PostgreSQLDataTypeProvider : DataTypeProvider() {
override fun blobType(): String = "bytea"
override fun uuidToDB(value: UUID): Any = value
override fun dateTimeType(): String = "TIMESTAMP"
override fun ubyteType(): String = "SMALLINT"
override fun jsonBType(): String = "JSONB"
override fun hexToDb(hexString: String): String = """E'\\x$hexString'"""
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.util.*

internal object SQLServerDataTypeProvider : DataTypeProvider() {
override fun ubyteType(): String {
return if ((currentDialect as? H2Dialect)?.h2Mode == H2Dialect.H2CompatibilityMode.SQLServer) {
"SMALLINT"
} else {
"TINYINT"
}
}
override fun integerAutoincType(): String = "INT IDENTITY(1,1)"
override fun longAutoincType(): String = "BIGINT IDENTITY(1,1)"
override fun binaryType(): String {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.jetbrains.exposed.sql.tests.shared.assertFailAndRollback
import org.jetbrains.exposed.sql.tests.shared.assertTrue
import org.jetbrains.exposed.sql.vendors.MysqlDialect
import org.jetbrains.exposed.sql.vendors.SQLServerDialect
import org.junit.Test

class UnsignedColumnTypeTests : DatabaseTestsBase() {
Expand Down Expand Up @@ -41,6 +42,70 @@ class UnsignedColumnTypeTests : DatabaseTestsBase() {
}
}

@Test
fun testUByteWithCheckConstraint() {
withTables(UByteTable) {
val ddlEnding = when (currentDialectTest) {
is MysqlDialect -> "(ubyte TINYINT UNSIGNED NOT NULL)"
is SQLServerDialect -> "(ubyte TINYINT NOT NULL)"
else -> "CHECK (ubyte BETWEEN 0 and ${UByte.MAX_VALUE}))"
}
assertTrue(UByteTable.ddl.single().endsWith(ddlEnding, ignoreCase = true))

val number = 191.toUByte()
assertTrue(number in Byte.MAX_VALUE.toUByte()..UByte.MAX_VALUE)

UByteTable.insert { it[unsignedByte] = number }

val result = UByteTable.selectAll()
assertEquals(number, result.single()[UByteTable.unsignedByte])

// test that column itself blocks same out-of-range value that compiler blocks
assertFailAndRollback("Check constraint violation (or out-of-range error in MySQL/MariaDB/SQL Server)") {
val tableName = UByteTable.nameInDatabaseCase()
val columnName = UByteTable.unsignedByte.nameInDatabaseCase()
val outOfRangeValue = UByte.MAX_VALUE + 1u
exec("""INSERT INTO $tableName ($columnName) VALUES ($outOfRangeValue)""")
}
}
}

@Test
fun testPreviousUByteColumnTypeWorksWithNewSmallIntType() {
// MySQL and MariaDB type hasn't changed, and PostgreSQL and Oracle never supported TINYINT
withDb(TestDB.allH2TestDB - TestDB.H2_PSQL + TestDB.SQLITE) { testDb ->
try {
val tableName = UByteTable.nameInDatabaseCase()
val columnName = UByteTable.unsignedByte.nameInDatabaseCase()
exec("""CREATE TABLE ${addIfNotExistsIfSupported()}$tableName ($columnName TINYINT NOT NULL)""")

val number1 = Byte.MAX_VALUE.toUByte()
UByteTable.insert { it[unsignedByte] = number1 }

val result1 = UByteTable.select { UByteTable.unsignedByte eq number1 }.count()
assertEquals(1, result1)

// TINYINT maps to INTEGER in SQLite, so it will not throw OoR error
if (testDb != TestDB.SQLITE) {
val number2 = (Byte.MAX_VALUE + 1).toUByte()
assertFailAndRollback("Out-of-range (OoR) error") {
UByteTable.insert { it[unsignedByte] = number2 }
assertEquals(0, UByteTable.select { UByteTable.unsignedByte less 0u }.count())
}

// modify column to now have SMALLINT type
exec(UByteTable.unsignedByte.modifyStatement().first())
UByteTable.insert { it[unsignedByte] = number2 }

val result2 = UByteTable.selectAll().map { it[UByteTable.unsignedByte] }
assertEqualCollections(listOf(number1, number2), result2)
}
} finally {
SchemaUtils.drop(UByteTable)
}
}
}

@Test
fun testUShortColumnType() {
withTables(UShortTable) {
Expand Down

0 comments on commit ba007bc

Please sign in to comment.