Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: EXPOSED-82 Inaccurate UShort column type mapping #1799

Merged
merged 4 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -283,32 +283,33 @@ class ShortColumnType : ColumnType() {

/**
* Numeric column for storing unsigned 2-byte integers.
*
* **Note:** If the database being used is not MySQL or MariaDB, this column will represent the database's 4-byte
* integer type with a check constraint that ensures storage of only values between 0 and [UShort.MAX_VALUE] inclusive.
*/
class UShortColumnType : ColumnType() {
override fun sqlType(): String = currentDialect.dataTypeProvider.ushortType()
override fun valueFromDB(value: Any): UShort {
return when (value) {
is UShort -> value
is Short -> value.takeIf { it >= 0 }?.toUShort()
is Number -> value.toInt().takeIf { it >= 0 && it <= UShort.MAX_VALUE.toInt() }?.toUShort()
is Short -> value.toUShort()
is Number -> value.toInt().toUShort()
is String -> value.toUShort()
else -> error("Unexpected value of type Short: $value of ${value::class.qualifiedName}")
} ?: error("Negative value but type is UShort: $value")
}
}

override fun setParameter(stmt: PreparedStatementApi, index: Int, value: Any?) {
val v = when {
value is UShort && currentDialect is MysqlDialect -> value.toInt()
value is UShort -> value.toShort()
val v = when (value) {
is UShort -> value.toInt()
else -> value
}
super.setParameter(stmt, index, v)
}

override fun notNullValueToDB(value: Any): Any {
val v = when {
value is UShort && currentDialect is MysqlDialect -> value.toInt()
value is UShort -> value.toShort()
val v = when (value) {
is UShort -> value.toInt()
else -> value
}
return super.notNullValueToDB(v)
Expand Down
45 changes: 33 additions & 12 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,8 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {

private val checkConstraints = mutableListOf<Pair<String, Op<Boolean>>>()

private val generatedCheckPrefix = "chk_unsigned_"

/**
* Returns the table name in proper case.
* Should be called within transaction or default [tableName] will be returned.
Expand Down Expand Up @@ -483,8 +485,14 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
/** Creates a numeric column, with the specified [name], for storing 2-byte integers. */
fun short(name: String): Column<Short> = registerColumn(name, ShortColumnType())

/** Creates a numeric column, with the specified [name], for storing 2-byte unsigned integers. */
fun ushort(name: String): Column<UShort> = registerColumn(name, UShortColumnType())
/** Creates a numeric column, with the specified [name], for storing 2-byte unsigned integers.
*
* **Note:** If the database being used is not MySQL or MariaDB, this column will use the database's 4-byte
* integer type with a check constraint that ensures storage of only values between 0 and [UShort.MAX_VALUE] inclusive.
*/
fun ushort(name: String): Column<UShort> = registerColumn<UShort>(name, UShortColumnType()).apply {
check("$generatedCheckPrefix$name") { it.between(0u, UShort.MAX_VALUE) }
}

/** Creates a numeric column, with the specified [name], for storing 4-byte integers. */
fun integer(name: String): Column<Int> = registerColumn(name, IntegerColumnType())
Expand Down Expand Up @@ -1043,29 +1051,33 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {

/**
* Creates a check constraint in this column.
* @param name The name to identify the constraint, optional. Must be **unique** (case-insensitive) to this table, otherwise, the constraint will
* not be created. All names are [trimmed][String.trim], blank names are ignored and the database engine decides the default name.
* @param name The name to identify the constraint, optional. Must be **unique** (case-insensitive) to this table,
* otherwise, the constraint will not be created. All names are [trimmed][String.trim], blank names are ignored and
* the database engine decides the default name.
* @param op The expression against which the newly inserted values will be compared.
*/
fun <T> Column<T>.check(name: String = "", op: SqlExpressionBuilder.(Column<T>) -> Op<Boolean>): Column<T> = apply {
if (name.isEmpty() || table.checkConstraints.none { it.first.equals(name, true) }) {
table.checkConstraints.add(name to SqlExpressionBuilder.op(this))
} else {
exposedLogger.warn("A CHECK constraint with name '$name' was ignored because there is already one with that name")
exposedLogger
.warn("A CHECK constraint with name '$name' was ignored because there is already one with that name")
}
}

/**
* Creates a check constraint in this table.
* @param name The name to identify the constraint, optional. Must be **unique** (case-insensitive) to this table, otherwise, the constraint will
* not be created. All names are [trimmed][String.trim], blank names are ignored and the database engine decides the default name.
* @param name The name to identify the constraint, optional. Must be **unique** (case-insensitive) to this table,
* otherwise, the constraint will not be created. All names are [trimmed][String.trim], blank names are ignored and
* the database engine decides the default name.
* @param op The expression against which the newly inserted values will be compared.
*/
fun check(name: String = "", op: SqlExpressionBuilder.() -> Op<Boolean>) {
if (name.isEmpty() || checkConstraints.none { it.first.equals(name, true) }) {
checkConstraints.add(name to SqlExpressionBuilder.op())
} else {
exposedLogger.warn("A CHECK constraint with name '$name' was ignored because there is already one with that name")
exposedLogger
.warn("A CHECK constraint with name '$name' was ignored because there is already one with that name")
}
}

Expand Down Expand Up @@ -1104,7 +1116,8 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
return primaryKey?.let { primaryKey ->
val tr = TransactionManager.current()
val constraint = tr.db.identifierManager.cutIfNecessaryAndQuote(primaryKey.name)
return primaryKey.columns.joinToString(prefix = "CONSTRAINT $constraint PRIMARY KEY (", postfix = ")", transform = tr::identity)
return primaryKey.columns
.joinToString(prefix = "CONSTRAINT $constraint PRIMARY KEY (", postfix = ")", transform = tr::identity)
}
}

Expand Down Expand Up @@ -1140,10 +1153,17 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
}

if (checkConstraints.isNotEmpty()) {
checkConstraints.mapIndexed { index, (name, op) ->
val filteredChecks = if (currentDialect is MysqlDialect) {
checkConstraints
.filterNot { (name, _) -> name.startsWith(generatedCheckPrefix) }
.ifEmpty { null }
} else {
checkConstraints
}
filteredChecks?.mapIndexed { index, (name, op) ->
val resolvedName = name.ifBlank { "check_${tableName}_$index" }
CheckConstraint.from(this@Table, resolvedName, op).checkPart
}.joinTo(this, prefix = ", ")
}?.joinTo(this, prefix = ", ")
}

append(")")
Expand All @@ -1159,7 +1179,8 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
return createSequence + createTable + createConstraint
}

override fun modifyStatement(): List<String> = throw UnsupportedOperationException("Use modify on columns and indices")
override fun modifyStatement(): List<String> =
throw UnsupportedOperationException("Use modify on columns and indices")

override fun dropStatement(): List<String> {
val dropTable = buildString {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ abstract class DataTypeProvider {
/** Numeric type for storing 2-byte integers. */
open fun shortType(): String = "SMALLINT"

/** Numeric type for storing 2-byte unsigned integers. */
open fun ushortType(): String = "SMALLINT"
/** Numeric type for storing 2-byte unsigned integers.
*
* **Note:** If the database being used is not MySQL or MariaDB, this will represent the 4-byte integer type.
*/
open fun ushortType(): String = "INT"

/** Numeric type for storing 4-byte integers. */
open fun integerType(): String = "INT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.jetbrains.exposed.sql.transactions.TransactionManager
internal object OracleDataTypeProvider : DataTypeProvider() {
override fun byteType(): String = "SMALLINT"
override fun ubyteType(): String = "SMALLINT"
override fun ushortType(): String = "NUMBER(6)"
override fun integerType(): String = "NUMBER(12)"
override fun integerAutoincType(): String = "NUMBER(12)"
override fun uintegerType(): String = "NUMBER(13)"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
package org.jetbrains.exposed.sql.tests.shared.types

import org.jetbrains.exposed.sql.SchemaUtils
import org.jetbrains.exposed.sql.Table
import org.jetbrains.exposed.sql.insert
import org.jetbrains.exposed.sql.selectAll
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.TestDB
import org.jetbrains.exposed.sql.tests.currentDialectTest
import org.jetbrains.exposed.sql.tests.shared.assertEqualCollections
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.jetbrains.exposed.sql.tests.shared.assertFailAndRollback
import org.jetbrains.exposed.sql.tests.shared.assertTrue
import org.jetbrains.exposed.sql.vendors.MysqlDialect
import org.junit.Test

class UnsignedColumnTypeTests : DatabaseTestsBase() {
object UByteTable : Table("uByteTable") {
val unsignedByte = ubyte("uByte")
object UByteTable : Table("ubyte_table") {
val unsignedByte = ubyte("ubyte")
}

object UShortTable : Table("uShortTable") {
val unsignedShort = ushort("uShort")
object UShortTable : Table("ushort_table") {
val unsignedShort = ushort("ushort")
}

object UIntTable : Table("uIntTable") {
val unsignedInt = uinteger("uInt")
object UIntTable : Table("uint_table") {
val unsignedInt = uinteger("uint")
}

object ULongTable : Table("uLongTable") {
val unsignedLong = ulong("uLong")
object ULongTable : Table("ulong_table") {
val unsignedLong = ulong("ulong")
}

@Test
Expand Down Expand Up @@ -52,6 +54,70 @@ class UnsignedColumnTypeTests : DatabaseTestsBase() {
}
}

@Test
fun testUShortWithCheckConstraint() {
withTables(UShortTable) {
val ddlEnding = if (currentDialectTest is MysqlDialect) {
"(ushort SMALLINT UNSIGNED NOT NULL)"
} else {
"CHECK (ushort BETWEEN 0 and ${UShort.MAX_VALUE}))"
}
assertTrue(UShortTable.ddl.single().endsWith(ddlEnding, ignoreCase = true))

val number = 49151.toUShort()
assertTrue(number in Short.MAX_VALUE.toUShort()..UShort.MAX_VALUE)

UShortTable.insert { it[unsignedShort] = number }

val result = UShortTable.selectAll()
assertEquals(number, result.single()[UShortTable.unsignedShort])

// test that column itself blocks same out-of-range value that compiler blocks
assertFailAndRollback("Check constraint violation (or out-of-range error in MySQL/MariaDB)") {
val tableName = UShortTable.nameInDatabaseCase()
val columnName = UShortTable.unsignedShort.nameInDatabaseCase()
val outOfRangeValue = UShort.MAX_VALUE + 1u
exec("""INSERT INTO $tableName ($columnName) VALUES ($outOfRangeValue)""")
}
}
}

@Test
fun testPreviousUShortColumnTypeWorksWithNewIntType() {
withDb(excludeSettings = listOf(TestDB.MYSQL, TestDB.MARIADB)) { testDb ->
try {
val tableName = UShortTable.nameInDatabaseCase()
val columnName = UShortTable.unsignedShort.nameInDatabaseCase()
// create table using previous column type SMALLINT
exec("""CREATE TABLE ${addIfNotExistsIfSupported()}$tableName ($columnName SMALLINT NOT NULL)""")

val number1 = Short.MAX_VALUE.toUShort()
UShortTable.insert { it[unsignedShort] = number1 }

val result1 = UShortTable.select { UShortTable.unsignedShort eq number1 }.count()
assertEquals(1, result1)

// SMALLINT maps to INTEGER in SQLite and NUMBER(38) in Oracle, so they will not throw OoR error
if (testDb != TestDB.SQLITE && testDb != TestDB.ORACLE) {
val number2 = (Short.MAX_VALUE + 1).toUShort()
assertFailAndRollback("Out-of-range (OoR) error") {
UShortTable.insert { it[unsignedShort] = number2 }
assertEquals(0, UShortTable.select { UShortTable.unsignedShort less 0u }.count())
}

// modify column to now have INT type
exec(UShortTable.unsignedShort.modifyStatement().first())
UShortTable.insert { it[unsignedShort] = number2 }

val result2 = UShortTable.selectAll().map { it[UShortTable.unsignedShort] }
assertEqualCollections(listOf(number1, number2), result2)
}
} finally {
SchemaUtils.drop(UShortTable)
}
}
}

@Test
fun testUIntColumnType() {
withTables(UIntTable) {
Expand Down
Loading