Skip to content

Commit

Permalink
fix: EXPOSED-97 Unsigned column types truncate MySQL values (#1796)
Browse files Browse the repository at this point in the history
MySQL supports unsigned numeric types using the UNSIGNED prefix, which Exposed
uses for unsigned column types. Logic in these column classes, however, truncates
the values being sent to and received from the database. This logic has been
refactored to allow values in the full allowed range to actually be inserted/stored
in MySQL.

All unsigned column type tests are removed from the heavy DDLTests.kt suite and
placed in their own test file inside the 'types' package, along with a new unit test.

Fix Detekt issues in altered files.
  • Loading branch information
bog-walk authored Jul 20, 2023
1 parent e1bbf73 commit 212a616
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 85 deletions.
104 changes: 79 additions & 25 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ class AutoIncColumnType(

/** Returns the name of the sequence used to generate new values for this auto-increment column. */
val autoincSeq: String?
get() = _autoincSeq.takeIf { currentDialect.supportsCreateSequence } ?: fallbackSeqName.takeIf { currentDialect.needsSequenceToAutoInc }
get() = _autoincSeq.takeIf { currentDialect.supportsCreateSequence }
?: fallbackSeqName.takeIf { currentDialect.needsSequenceToAutoInc }

val nextValExpression: NextVal<*>? get() = nextValValue.takeIf { autoincSeq != null }

Expand Down Expand Up @@ -156,11 +157,13 @@ class AutoIncColumnType(
}

/** Returns `true` if this is an auto-increment column, `false` otherwise. */
val IColumnType.isAutoInc: Boolean get() = this is AutoIncColumnType || (this is EntityIDColumnType<*> && idColumn.columnType.isAutoInc)
val IColumnType.isAutoInc: Boolean
get() = this is AutoIncColumnType || (this is EntityIDColumnType<*> && idColumn.columnType.isAutoInc)

/** Returns the name of the auto-increment sequence of this column. */
val Column<*>.autoIncColumnType: AutoIncColumnType?
get() = (columnType as? AutoIncColumnType) ?: (columnType as? EntityIDColumnType<*>)?.idColumn?.columnType as? AutoIncColumnType
get() = (columnType as? AutoIncColumnType)
?: (columnType as? EntityIDColumnType<*>)?.idColumn?.columnType as? AutoIncColumnType

internal fun IColumnType.rawSqlType(): IColumnType = when {
this is AutoIncColumnType -> delegate
Expand Down Expand Up @@ -240,19 +243,27 @@ class UByteColumnType : ColumnType() {
return when (value) {
is UByte -> value
is Byte -> value.takeIf { it >= 0 }?.toUByte()
is Number -> value.toByte().takeIf { it >= 0 }?.toUByte()
is Number -> value.toShort().takeIf { it >= 0 && it <= UByte.MAX_VALUE.toShort() }?.toUByte()
is String -> value.toUByte()
else -> error("Unexpected value of type Byte: $value of ${value::class.qualifiedName}")
} ?: error("negative value but type is UByte: $value")
} ?: error("Negative value but type is UByte: $value")
}

override fun setParameter(stmt: PreparedStatementApi, index: Int, value: Any?) {
val v = if (value is UByte) value.toByte() else value
val v = when {
value is UByte && currentDialect is MysqlDialect -> value.toShort()
value is UByte -> value.toByte()
else -> value
}
super.setParameter(stmt, index, v)
}

override fun notNullValueToDB(value: Any): Any {
val v = if (value is UByte) value.toByte() else value
val v = when {
value is UByte && currentDialect is MysqlDialect -> value.toShort()
value is UByte -> value.toByte()
else -> value
}
return super.notNullValueToDB(v)
}
}
Expand All @@ -279,19 +290,27 @@ class UShortColumnType : ColumnType() {
return when (value) {
is UShort -> value
is Short -> value.takeIf { it >= 0 }?.toUShort()
is Number -> value.toShort().takeIf { it >= 0 }?.toUShort()
is Number -> value.toInt().takeIf { it >= 0 && it <= UShort.MAX_VALUE.toInt() }?.toUShort()
is String -> value.toUShort()
else -> error("Unexpected value of type Short: $value of ${value::class.qualifiedName}")
} ?: error("negative value but type is UShort: $value")
} ?: error("Negative value but type is UShort: $value")
}

override fun setParameter(stmt: PreparedStatementApi, index: Int, value: Any?) {
val v = if (value is UShort) value.toShort() else value
val v = when {
value is UShort && currentDialect is MysqlDialect -> value.toInt()
value is UShort -> value.toShort()
else -> value
}
super.setParameter(stmt, index, v)
}

override fun notNullValueToDB(value: Any): Any {
val v = if (value is UShort) value.toShort() else value
val v = when {
value is UShort && currentDialect is MysqlDialect -> value.toInt()
value is UShort -> value.toShort()
else -> value
}
return super.notNullValueToDB(v)
}
}
Expand All @@ -318,19 +337,27 @@ class UIntegerColumnType : ColumnType() {
return when (value) {
is UInt -> value
is Int -> value.takeIf { it >= 0 }?.toUInt()
is Number -> value.toLong().takeIf { it >= 0 && it < UInt.MAX_VALUE.toLong() }?.toUInt()
is Number -> value.toLong().takeIf { it >= 0 && it <= UInt.MAX_VALUE.toLong() }?.toUInt()
is String -> value.toUInt()
else -> error("Unexpected value of type Int: $value of ${value::class.qualifiedName}")
} ?: error("negative value but type is UInt: $value")
} ?: error("Negative value but type is UInt: $value")
}

override fun setParameter(stmt: PreparedStatementApi, index: Int, value: Any?) {
val v = if (value is UInt) value.toInt() else value
val v = when {
value is UInt && currentDialect is MysqlDialect -> value.toLong()
value is UInt -> value.toInt()
else -> value
}
super.setParameter(stmt, index, v)
}

override fun notNullValueToDB(value: Any): Any {
val v = if (value is UInt) value.toInt() else value
val v = when {
value is UInt && currentDialect is MysqlDialect -> value.toLong()
value is UInt -> value.toInt()
else -> value
}
return super.notNullValueToDB(v)
}
}
Expand All @@ -357,19 +384,35 @@ class ULongColumnType : ColumnType() {
return when (value) {
is ULong -> value
is Long -> value.takeIf { it >= 0 }?.toULong()
is Number -> value.toLong().takeIf { it >= 0 }?.toULong()
is Number -> {
if (currentDialect is MysqlDialect) {
value.toString().toBigInteger().takeIf {
it >= "0".toBigInteger() && it <= ULong.MAX_VALUE.toString().toBigInteger()
}?.toString()?.toULong()
} else {
value.toLong().takeIf { it >= 0 }?.toULong()
}
}
is String -> value.toULong()
else -> error("Unexpected value of type Long: $value of ${value::class.qualifiedName}")
} ?: error("negative value but type is ULong: $value")
} ?: error("Negative value but type is ULong: $value")
}

override fun setParameter(stmt: PreparedStatementApi, index: Int, value: Any?) {
val v = if (value is ULong) value.toLong() else value
val v = when {
value is ULong && currentDialect is MysqlDialect -> value.toString()
value is ULong -> value.toLong()
else -> value
}
super.setParameter(stmt, index, v)
}

override fun notNullValueToDB(value: Any): Any {
val v = if (value is ULong) value.toLong() else value
val v = when {
value is ULong && currentDialect is MysqlDialect -> value.toString()
value is ULong -> value.toLong()
else -> value
}
return super.notNullValueToDB(v)
}
}
Expand Down Expand Up @@ -642,11 +685,17 @@ open class TextColumnType(collate: String? = null, val eagerLoading: Boolean = f
}
}

open class MediumTextColumnType(collate: String? = null, eagerLoading: Boolean = false) : TextColumnType(collate, eagerLoading) {
open class MediumTextColumnType(
collate: String? = null,
eagerLoading: Boolean = false
) : TextColumnType(collate, eagerLoading) {
override fun preciseType(): String = currentDialect.dataTypeProvider.mediumTextType()
}

open class LargeTextColumnType(collate: String? = null, eagerLoading: Boolean = false) : TextColumnType(collate, eagerLoading) {
open class LargeTextColumnType(
collate: String? = null,
eagerLoading: Boolean = false
) : TextColumnType(collate, eagerLoading) {
override fun preciseType(): String = currentDialect.dataTypeProvider.largeTextType()
}

Expand Down Expand Up @@ -784,7 +833,8 @@ class UUIDColumnType : ColumnType() {
}

companion object {
private val uuidRegexp = Regex("[0-9A-F]{8}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{12}", RegexOption.IGNORE_CASE)
private val uuidRegexp =
Regex("[0-9A-F]{8}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{12}", RegexOption.IGNORE_CASE)
}
}

Expand All @@ -802,10 +852,12 @@ class BooleanColumnType : ColumnType() {
else -> value.toString().toBoolean()
}

override fun nonNullValueToString(value: Any): String = currentDialect.dataTypeProvider.booleanToStatementString(value as Boolean)
override fun nonNullValueToString(value: Any): String =
currentDialect.dataTypeProvider.booleanToStatementString(value as Boolean)

override fun notNullValueToDB(value: Any): Any = when {
value is Boolean && (currentDialect is OracleDialect || currentDialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle) ->
value is Boolean &&
(currentDialect is OracleDialect || currentDialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle) ->
nonNullValueToString(value)
else -> value
}
Expand Down Expand Up @@ -871,7 +923,9 @@ class EnumerationNameColumnType<T : Enum<T>>(

@Suppress("UNCHECKED_CAST")
override fun valueFromDB(value: Any): T = when (value) {
is String -> enumConstants[value] ?: error("$value can't be associated with any from enum ${klass.qualifiedName}")
is String -> {
enumConstants[value] ?: error("$value can't be associated with any from enum ${klass.qualifiedName}")
}
is Enum<*> -> value as T
else -> error("$value of ${value::class.qualifiedName} is not valid for enum ${klass.qualifiedName}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -725,66 +725,6 @@ class DDLTests : DatabaseTestsBase() {
}
}

@Test fun testUByteColumnType() {
val UbyteTable = object : Table("ubyteTable") {
val ubyte = ubyte("ubyte")
}

withTables(UbyteTable) {
UbyteTable.insert {
it[ubyte] = 123u
}
val result = UbyteTable.selectAll().toList()
assertEquals(1, result.size)
assertEquals(123u, result.single()[UbyteTable.ubyte])
}
}

@Test fun testUshortColumnType() {
val UshortTable = object : Table("ushortTable") {
val ushort = ushort("ushort")
}

withTables(UshortTable) {
UshortTable.insert {
it[ushort] = 123u
}
val result = UshortTable.selectAll().toList()
assertEquals(1, result.size)
assertEquals(123u, result.single()[UshortTable.ushort])
}
}

@Test fun testUintColumnType() {
val UintTable = object : Table("uintTable") {
val uint = uinteger("uint")
}

withTables(UintTable) {
UintTable.insert {
it[uint] = 123u
}
val result = UintTable.selectAll().toList()
assertEquals(1, result.size)
assertEquals(123u, result.single()[UintTable.uint])
}
}

@Test fun testUlongColumnType() {
val UlongTable = object : Table("ulongTable") {
val ulong = ulong("ulong")
}

withTables(UlongTable) {
UlongTable.insert {
it[ulong] = 123uL
}
val result = UlongTable.selectAll().toList()
assertEquals(1, result.size)
assertEquals(123uL, result.single()[UlongTable.ulong])
}
}

@Test fun tableWithDifferentTextTypes() {
val TestTable = object : Table("different_text_column_types") {
val id = integer("id").autoIncrement()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package org.jetbrains.exposed.sql.tests.shared.types

import org.jetbrains.exposed.sql.SchemaUtils
import org.jetbrains.exposed.sql.Table
import org.jetbrains.exposed.sql.insert
import org.jetbrains.exposed.sql.selectAll
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.TestDB
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.junit.Test

class UnsignedColumnTypeTests : DatabaseTestsBase() {
object UByteTable : Table("uByteTable") {
val unsignedByte = ubyte("uByte")
}

object UShortTable : Table("uShortTable") {
val unsignedShort = ushort("uShort")
}

object UIntTable : Table("uIntTable") {
val unsignedInt = uinteger("uInt")
}

object ULongTable : Table("uLongTable") {
val unsignedLong = ulong("uLong")
}

@Test
fun testUByteColumnType() {
withTables(UByteTable) {
UByteTable.insert {
it[unsignedByte] = 123u
}

val result = UByteTable.selectAll().toList()
assertEquals(1, result.size)
assertEquals(123u, result.single()[UByteTable.unsignedByte])
}
}

@Test
fun testUShortColumnType() {
withTables(UShortTable) {
UShortTable.insert {
it[unsignedShort] = 123u
}

val result = UShortTable.selectAll().toList()
assertEquals(1, result.size)
assertEquals(123u, result.single()[UShortTable.unsignedShort])
}
}

@Test
fun testUIntColumnType() {
withTables(UIntTable) {
UIntTable.insert {
it[unsignedInt] = 123u
}

val result = UIntTable.selectAll().toList()
assertEquals(1, result.size)
assertEquals(123u, result.single()[UIntTable.unsignedInt])
}
}

@Test
fun testULongColumnType() {
withTables(ULongTable) {
ULongTable.insert {
it[unsignedLong] = 123uL
}

val result = ULongTable.selectAll().toList()
assertEquals(1, result.size)
assertEquals(123uL, result.single()[ULongTable.unsignedLong])
}
}

@Test
fun testMaxUnsignedTypesInMySql() {
withDb(listOf(TestDB.MYSQL, TestDB.MARIADB)) {
SchemaUtils.create(UByteTable, UShortTable, UIntTable, ULongTable)

UByteTable.insert { it[unsignedByte] = UByte.MAX_VALUE }
assertEquals(UByte.MAX_VALUE, UByteTable.selectAll().single()[UByteTable.unsignedByte])

UShortTable.insert { it[unsignedShort] = UShort.MAX_VALUE }
assertEquals(UShort.MAX_VALUE, UShortTable.selectAll().single()[UShortTable.unsignedShort])

UIntTable.insert { it[unsignedInt] = UInt.MAX_VALUE }
assertEquals(UInt.MAX_VALUE, UIntTable.selectAll().single()[UIntTable.unsignedInt])

ULongTable.insert { it[unsignedLong] = ULong.MAX_VALUE }
assertEquals(ULong.MAX_VALUE, ULongTable.selectAll().single()[ULongTable.unsignedLong])

SchemaUtils.drop(UByteTable, UShortTable, UIntTable, ULongTable)
}
}
}

0 comments on commit 212a616

Please sign in to comment.