From 85d77f5a19b086990578e03c9ec1310718b22b62 Mon Sep 17 00:00:00 2001 From: Jocelyne Date: Tue, 5 Mar 2024 01:21:40 +0100 Subject: [PATCH] trial --- .../org/jetbrains/exposed/dao/id/EntityID.kt | 17 +- .../kotlin/org/jetbrains/exposed/sql/Alias.kt | 4 +- .../org/jetbrains/exposed/sql/Column.kt | 17 +- .../org/jetbrains/exposed/sql/ColumnType.kt | 354 ++++++++---------- .../org/jetbrains/exposed/sql/Expression.kt | 14 +- .../org/jetbrains/exposed/sql/Function.kt | 72 ++-- .../kotlin/org/jetbrains/exposed/sql/Op.kt | 30 +- .../exposed/sql/SQLExpressionBuilder.kt | 38 +- .../org/jetbrains/exposed/sql/SchemaUtils.kt | 8 +- .../kotlin/org/jetbrains/exposed/sql/Table.kt | 30 +- .../org/jetbrains/exposed/sql/Transaction.kt | 6 +- .../jetbrains/exposed/sql/WindowFunction.kt | 2 +- .../sql/functions/array/ArrayFunctions.kt | 8 +- .../jetbrains/exposed/sql/ops/AllAnyOps.kt | 2 +- .../sql/statements/BatchUpdateStatement.kt | 2 +- .../exposed/sql/statements/DeleteStatement.kt | 2 +- .../sql/statements/InsertSelectStatement.kt | 2 +- .../exposed/sql/statements/InsertStatement.kt | 2 +- .../exposed/sql/statements/Statement.kt | 8 +- .../exposed/sql/statements/UpdateBuilder.kt | 17 +- .../exposed/sql/statements/UpdateStatement.kt | 10 +- .../exposed/sql/statements/UpsertStatement.kt | 6 +- .../statements/api/PreparedStatementApi.kt | 6 +- .../exposed/sql/vendors/FunctionProvider.kt | 8 +- .../exposed/sql/vendors/MysqlDialect.kt | 8 +- .../exposed/sql/vendors/OracleDialect.kt | 4 +- .../exposed/sql/vendors/PostgreSQL.kt | 8 +- .../exposed/sql/vendors/SQLServerDialect.kt | 2 +- .../exposed/sql/vendors/SQLiteDialect.kt | 4 +- .../crypt/EncryptedBinaryColumnType.kt | 19 +- .../crypt/EncryptedVarCharColumnType.kt | 15 +- .../sql/javatime/JavaDateColumnType.kt | 216 ++++------- .../org/jetbrains/exposed/DefaultsTest.kt | 2 +- .../sqlserver/SQLServerDefaultsTest.kt | 2 +- .../jdbc/JdbcPreparedStatementImpl.kt | 2 +- .../exposed/sql/jodatime/DateColumnType.kt | 114 +++--- .../jetbrains/exposed/JodaTimeDefaultsTest.kt | 2 +- .../exposed/sql/json/JsonColumnType.kt | 8 +- .../exposed/sql/json/JsonFunctions.kt | 4 +- .../kotlin/datetime/KotlinDateColumnType.kt | 208 ++++------ .../sqlserver/SQLServerDefaultsTest.kt | 2 +- .../exposed/sql/money/CurrencyColumnType.kt | 55 ++- 42 files changed, 581 insertions(+), 759 deletions(-) diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/dao/id/EntityID.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/dao/id/EntityID.kt index 4bea278b62..3cbceee5e6 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/dao/id/EntityID.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/dao/id/EntityID.kt @@ -18,16 +18,17 @@ open class EntityID> protected constructor(val table: IdTable< var _value: Any? = id /** The identity value of type [T] wrapped by this [EntityID] instance. */ - val value: T get() { - if (_value == null) { - invokeOnNoValue() - check(_value != null) { "Entity must be inserted" } + val value: T + get() { + if (_value == null) { + invokeOnNoValue() + check(_value != null) { "Entity must be inserted" } + } + + @Suppress("UNCHECKED_CAST") + return _value!! as T } - @Suppress("UNCHECKED_CAST") - return _value!! as T - } - /** Performs steps when the internal [_value] is accessed without first being initialized. */ protected open fun invokeOnNoValue() {} diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Alias.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Alias.kt index bc3f166f53..a62c61f271 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Alias.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Alias.kt @@ -8,7 +8,7 @@ class Alias(val delegate: T, val alias: String) : Table() { /** The table name along with its [alias]. */ val tableNameWithAlias: String = "${delegate.tableName} $alias" - private fun Column.clone() = Column(this@Alias, name, columnType) + private fun Column.clone() = Column(this@Alias, name, columnType) /** * Returns the original column from the [delegate] table, or `null` if the [column] is not associated @@ -111,7 +111,7 @@ class QueryAlias(val query: AbstractQuery<*>, val alias: String) : ColumnSet() { override infix fun crossJoin(otherTable: ColumnSet): Join = Join(this, otherTable, JoinType.CROSS) - private fun Column.clone() = Column(table.alias(alias), name, columnType) + private fun Column.clone() = Column(table.alias(alias), name, columnType) } /** diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Column.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Column.kt index f5311e1366..198691cc9a 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Column.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Column.kt @@ -21,7 +21,7 @@ class Column( /** Name of the column. */ val name: String, /** Data type of the column. */ - override val columnType: IColumnType + override val columnType: IColumnType ) : ExpressionWithColumnType(), DdlAware, Comparable> { /** The foreign key constraint on this column, or `null` if the column is not referencing. */ var foreignKey: ForeignKeyConstraint? = null @@ -63,12 +63,13 @@ class Column( private val isLastColumnInPK: Boolean get() = this == table.primaryKey?.columns?.last() - internal val isPrimaryConstraintWillBeDefined: Boolean get() = when { - currentDialect is SQLiteDialect && columnType.isAutoInc -> false - table.isCustomPKNameDefined() -> isLastColumnInPK - isOneColumnPK() -> false - else -> isLastColumnInPK - } + internal val isPrimaryConstraintWillBeDefined: Boolean + get() = when { + currentDialect is SQLiteDialect && columnType.isAutoInc -> false + table.isCustomPKNameDefined() -> isLastColumnInPK + isOneColumnPK() -> false + else -> isLastColumnInPK + } override fun createStatement(): List { val alterTablePrefix = "ALTER TABLE ${TransactionManager.current().identity(table)} ADD" @@ -154,7 +155,7 @@ class Column( /** * Returns a copy of this column, but with the given column type. */ - fun withColumnType(columnType: IColumnType) = Column( + fun withColumnType(columnType: IColumnType) = Column( table = this.table, name = this.name, columnType = columnType diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt index cbc11fc788..75d9e750d7 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt @@ -3,7 +3,6 @@ package org.jetbrains.exposed.sql import org.jetbrains.exposed.dao.id.EntityID import org.jetbrains.exposed.dao.id.EntityIDFunctionProvider import org.jetbrains.exposed.dao.id.IdTable -import org.jetbrains.exposed.sql.statements.DefaultValueMarker import org.jetbrains.exposed.sql.statements.api.ExposedBlob import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi import org.jetbrains.exposed.sql.vendors.* @@ -23,7 +22,7 @@ import kotlin.reflect.full.isSubclassOf /** * Interface common to all column types. */ -interface IColumnType { +interface IColumnType { /** Returns `true` if the column type is nullable, `false` otherwise. */ var nullable: Boolean @@ -34,31 +33,29 @@ interface IColumnType { * Converts the specified [value] (from the database) to an object of the appropriated type, for this column type. * Default implementation returns the same instance. */ - fun valueFromDB(value: Any): Any = value + fun valueFromDB(value: Any): T? /** Returns an object compatible with the database, from the specified [value], for this column type. */ - fun valueToDB(value: Any?): Any? = value?.let(::notNullValueToDB) + fun valueToDB(value: T?): Any? = value?.let(::notNullValueToDB) /** Returns an object compatible with the database, from the specified **non-null** [value], for this column type. */ - fun notNullValueToDB(value: Any): Any = value + fun notNullValueToDB(value: T & Any): Any = value /** * Returns the SQL representation of the specified [value], for this column type. * If the value is `null` and the column is not nullable, an exception will be thrown. * Used when generating an SQL statement and when logging that statement. */ - fun valueToString(value: Any?): String = when (value) { + fun valueToString(value: T?): String = when (value) { null -> { check(nullable) { "NULL in non-nullable column" } "NULL" } - DefaultValueMarker -> "DEFAULT" - is Iterable<*> -> value.joinToString(",", transform = ::valueToString) else -> nonNullValueToString(value) } /** Returns the SQL representation of the specified **non-null** [value], for this column type. */ - fun nonNullValueToString(value: Any): String = notNullValueToDB(value).toString() + fun nonNullValueToString(value: T & Any): String = notNullValueToDB(value).toString() /** * Returns the String representation of the specified [value] when [value] is set as the default for @@ -66,7 +63,7 @@ interface IColumnType { * If the value is `null` and the column is not nullable, an exception will be thrown. * Used for metadata default value comparison. */ - fun valueAsDefaultString(value: Any?): String = when (value) { + fun valueAsDefaultString(value: T?): String = when (value) { null -> { check(nullable) { "NULL in non-nullable column" } "NULL" @@ -78,7 +75,7 @@ interface IColumnType { * Returns the String representation of the specified **non-null** [value] when [value] is set as the default for * the column. */ - fun nonNullValueAsDefaultString(value: Any): String = nonNullValueToString(value) + fun nonNullValueAsDefaultString(value: T & Any): String = nonNullValueToString(value) /** Returns the object at the specified [index] in the [rs]. */ fun readObject(rs: ResultSet, index: Int): Any? = rs.getObject(index) @@ -97,22 +94,21 @@ interface IColumnType { * [value] can be of any type (including [Expression]) * */ @Throws(IllegalArgumentException::class) - fun validateValueBeforeUpdate(value: Any?) {} + fun validateValueBeforeUpdate(value: T?) {} } /** * Standard column type. */ -abstract class ColumnType(override var nullable: Boolean = false) : IColumnType { +abstract class ColumnType(override var nullable: Boolean = false) : IColumnType { override fun toString(): String = sqlType() override fun equals(other: Any?): Boolean { if (this === other) return true if (javaClass != other?.javaClass) return false - other as ColumnType + other as ColumnType<*> - if (nullable != other.nullable) return false - return true + return nullable == other.nullable } override fun hashCode(): Int = 31 * javaClass.hashCode() + nullable.hashCode() @@ -121,12 +117,12 @@ abstract class ColumnType(override var nullable: Boolean = false) : IColumnType /** * Auto-increment column type. */ -class AutoIncColumnType( +class AutoIncColumnType( /** Returns the base column type of this auto-increment column. */ - val delegate: ColumnType, + val delegate: ColumnType, private val _autoincSeq: String?, private val fallbackSeqName: String -) : IColumnType by delegate { +) : IColumnType by delegate { private val nextValValue = run { val sequence = Sequence(_autoincSeq ?: fallbackSeqName) @@ -142,7 +138,7 @@ class AutoIncColumnType( val nextValExpression: NextVal<*>? get() = nextValValue.takeIf { autoincSeq != null } - private fun resolveAutoIncType(columnType: IColumnType): String = when { + private fun resolveAutoIncType(columnType: IColumnType<*>): String = when { columnType is EntityIDColumnType<*> -> resolveAutoIncType(columnType.idColumn.columnType) columnType is IntegerColumnType && autoincSeq != null -> currentDialect.dataTypeProvider.integerType() columnType is IntegerColumnType -> currentDialect.dataTypeProvider.integerAutoincType() @@ -166,7 +162,7 @@ class AutoIncColumnType( other == null -> false this === other -> true this::class != other::class -> false - other !is AutoIncColumnType -> false + other !is AutoIncColumnType<*> -> false delegate != other.delegate -> false _autoincSeq != other._autoincSeq -> false fallbackSeqName != other.fallbackSeqName -> false @@ -183,15 +179,15 @@ class AutoIncColumnType( } /** Returns `true` if this is an auto-increment column, `false` otherwise. */ -val IColumnType.isAutoInc: Boolean +val IColumnType<*>.isAutoInc: Boolean get() = this is AutoIncColumnType || (this is EntityIDColumnType<*> && idColumn.columnType.isAutoInc) /** Returns the name of the auto-increment sequence of this column. */ -val Column<*>.autoIncColumnType: AutoIncColumnType? +val Column<*>.autoIncColumnType: AutoIncColumnType<*>? get() = (columnType as? AutoIncColumnType) ?: (columnType as? EntityIDColumnType<*>)?.idColumn?.columnType as? AutoIncColumnType -internal fun IColumnType.rawSqlType(): IColumnType = when { +internal fun IColumnType<*>.rawSqlType(): IColumnType<*> = when { this is AutoIncColumnType -> delegate this is EntityIDColumnType<*> && idColumn.columnType is AutoIncColumnType -> idColumn.columnType.delegate else -> this @@ -203,7 +199,7 @@ internal fun IColumnType.rawSqlType(): IColumnType = when { class EntityIDColumnType>( /** The underlying wrapped column storing the identity values. */ val idColumn: Column -) : ColumnType() { +) : ColumnType>() { init { require(idColumn.table is IdTable<*>) { "EntityId supported only for IdTables" } @@ -211,19 +207,9 @@ class EntityIDColumnType>( override fun sqlType(): String = idColumn.columnType.sqlType() - override fun notNullValueToDB(value: Any): Any = idColumn.columnType.notNullValueToDB( - when (value) { - is EntityID<*> -> value.value - else -> value - } - ) + override fun notNullValueToDB(value: EntityID): Any = idColumn.columnType.notNullValueToDB(value.value) - override fun nonNullValueToString(value: Any): String = idColumn.columnType.nonNullValueToString( - when (value) { - is EntityID<*> -> value.value - else -> value - } - ) + override fun nonNullValueToString(value: EntityID): String = idColumn.columnType.nonNullValueToString(value.value) @Suppress("UNCHECKED_CAST") override fun valueFromDB(value: Any): EntityID = EntityIDFunctionProvider.createEntityID( @@ -241,7 +227,7 @@ class EntityIDColumnType>( return when (other) { is EntityIDColumnType<*> -> idColumn == other.idColumn - is IColumnType -> idColumn.columnType == other + is IColumnType<*> -> idColumn.columnType == other else -> false } } @@ -254,7 +240,7 @@ class EntityIDColumnType>( /** * Numeric column for storing 1-byte integers. */ -class ByteColumnType : ColumnType() { +class ByteColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.byteType() override fun valueFromDB(value: Any): Byte = when (value) { @@ -272,7 +258,7 @@ class ByteColumnType : ColumnType() { * database's 2-byte integer type with a check constraint that ensures storage of only values * between 0 and [UByte.MAX_VALUE] inclusive. */ -class UByteColumnType : ColumnType() { +class UByteColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.ubyteType() override fun valueFromDB(value: Any): UByte { @@ -293,19 +279,13 @@ class UByteColumnType : ColumnType() { super.setParameter(stmt, index, v) } - override fun notNullValueToDB(value: Any): Any { - val v = when (value) { - is UByte -> value.toShort() - else -> value - } - return super.notNullValueToDB(v) - } + override fun notNullValueToDB(value: UByte): Any = value.toShort() } /** * Numeric column for storing 2-byte integers. */ -class ShortColumnType : ColumnType() { +class ShortColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.shortType() override fun valueFromDB(value: Any): Short = when (value) { is Short -> value @@ -321,7 +301,7 @@ class ShortColumnType : ColumnType() { * **Note:** If the database being used is not MySQL or MariaDB, this column will represent the database's 4-byte * integer type with a check constraint that ensures storage of only values between 0 and [UShort.MAX_VALUE] inclusive. */ -class UShortColumnType : ColumnType() { +class UShortColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.ushortType() override fun valueFromDB(value: Any): UShort { return when (value) { @@ -341,19 +321,13 @@ class UShortColumnType : ColumnType() { super.setParameter(stmt, index, v) } - override fun notNullValueToDB(value: Any): Any { - val v = when (value) { - is UShort -> value.toInt() - else -> value - } - return super.notNullValueToDB(v) - } + override fun notNullValueToDB(value: UShort): Any = value.toInt() } /** * Numeric column for storing 4-byte integers. */ -class IntegerColumnType : ColumnType() { +class IntegerColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.integerType() override fun valueFromDB(value: Any): Int = when (value) { is Int -> value @@ -370,7 +344,7 @@ class IntegerColumnType : ColumnType() { * 8-byte integer type with a check constraint that ensures storage of only values * between 0 and [UInt.MAX_VALUE] inclusive. */ -class UIntegerColumnType : ColumnType() { +class UIntegerColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.uintegerType() override fun valueFromDB(value: Any): UInt { return when (value) { @@ -390,19 +364,13 @@ class UIntegerColumnType : ColumnType() { super.setParameter(stmt, index, v) } - override fun notNullValueToDB(value: Any): Any { - val v = when (value) { - is UInt -> value.toLong() - else -> value - } - return super.notNullValueToDB(v) - } + override fun notNullValueToDB(value: UInt): Any = value.toLong() } /** * Numeric column for storing 8-byte integers. */ -class LongColumnType : ColumnType() { +class LongColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.longType() override fun valueFromDB(value: Any): Long = when (value) { is Long -> value @@ -415,7 +383,7 @@ class LongColumnType : ColumnType() { /** * Numeric column for storing unsigned 8-byte integers. */ -class ULongColumnType : ColumnType() { +class ULongColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.ulongType() override fun valueFromDB(value: Any): ULong { return when (value) { @@ -444,20 +412,16 @@ class ULongColumnType : ColumnType() { super.setParameter(stmt, index, v) } - override fun notNullValueToDB(value: Any): Any { - val v = when { - value is ULong && currentDialect is MysqlDialect -> value.toString() - value is ULong -> value.toLong() - else -> value - } - return super.notNullValueToDB(v) + override fun notNullValueToDB(value: ULong) = when { + currentDialect is MysqlDialect -> value.toString() + else -> value.toLong() } } /** * Numeric column for storing 4-byte (single precision) floating-point numbers. */ -class FloatColumnType : ColumnType() { +class FloatColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.floatType() override fun valueFromDB(value: Any): Float = when (value) { is Float -> value @@ -470,7 +434,7 @@ class FloatColumnType : ColumnType() { /** * Numeric column for storing 8-byte (double precision) floating-point numbers. */ -class DoubleColumnType : ColumnType() { +class DoubleColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.doubleType() override fun valueFromDB(value: Any): Double = when (value) { is Double -> value @@ -488,7 +452,7 @@ class DecimalColumnType( val precision: Int, /** Count of decimal digits in the fractional part. */ val scale: Int -) : ColumnType() { +) : ColumnType() { override fun sqlType(): String = "DECIMAL($precision, $scale)" override fun readObject(rs: ResultSet, index: Int): Any? { @@ -547,7 +511,7 @@ class DecimalColumnType( /** * Character column for storing single characters. */ -class CharacterColumnType : ColumnType() { +class CharacterColumnType : ColumnType() { override fun sqlType(): String = "CHAR" override fun valueFromDB(value: Any): Char = when (value) { is Char -> value @@ -556,8 +520,9 @@ class CharacterColumnType : ColumnType() { else -> error("Unexpected value of type Char: $value of ${value::class.qualifiedName}") } - override fun notNullValueToDB(value: Any): Any = value.toString() - override fun nonNullValueToString(value: Any): String = "'$value'" + override fun notNullValueToDB(value: Char): Any = value.toString() + + override fun nonNullValueToString(value: Char): String = "'$value'" } /** @@ -566,7 +531,7 @@ class CharacterColumnType : ColumnType() { abstract class StringColumnType( /** Returns the collate type used in by this column. */ val collate: String? = null -) : ColumnType() { +) : ColumnType() { /** Returns the specified [value] with special characters escaped. */ protected fun escape(value: String): String = value.map { charactersToEscape[it] ?: it }.joinToString("") @@ -576,15 +541,15 @@ abstract class StringColumnType( else -> escape(value) } - override fun valueFromDB(value: Any): Any = when (value) { + override fun valueFromDB(value: Any): String = when (value) { is Clob -> value.characterStream.readText() is ByteArray -> String(value) - else -> value + else -> value.toString() } - override fun nonNullValueToString(value: Any): String = buildString { + override fun nonNullValueToString(value: String): String = buildString { append('\'') - append(escape(value.toString())) + append(escape(value)) append('\'') } @@ -595,9 +560,7 @@ abstract class StringColumnType( other as StringColumnType - if (collate != other.collate) return false - - return true + return collate == other.collate } override fun hashCode(): Int { @@ -609,7 +572,6 @@ abstract class StringColumnType( companion object { private val charactersToEscape = mapOf( '\'' to "\'\'", -// '\"' to "\"\"", // no need to escape double quote as we put string in single quotes '\r' to "\\r", '\n' to "\\n" ) @@ -631,7 +593,7 @@ open class CharColumnType( } } - override fun validateValueBeforeUpdate(value: Any?) { + override fun validateValueBeforeUpdate(value: String?) { if (value is String) { val valueLength = value.codePointCount(0, value.length) require(valueLength <= colLength) { @@ -677,7 +639,7 @@ open class VarCharColumnType( } } - override fun validateValueBeforeUpdate(value: Any?) { + override fun validateValueBeforeUpdate(value: String?) { if (value is String) { val valueLength = value.codePointCount(0, value.length) require(valueLength <= colLength) { @@ -693,9 +655,7 @@ open class VarCharColumnType( other as VarCharColumnType - if (colLength != other.colLength) return false - - return true + return colLength == other.colLength } override fun hashCode(): Int { @@ -752,21 +712,18 @@ open class LargeTextColumnType( /** * Binary column for storing binary strings of variable and _unlimited_ length. */ -open class BasicBinaryColumnType : ColumnType() { +open class BasicBinaryColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.binaryType() override fun readObject(rs: ResultSet, index: Int): Any? = rs.getBytes(index) - override fun valueFromDB(value: Any): Any = when (value) { + override fun valueFromDB(value: Any): ByteArray = when (value) { is Blob -> value.binaryStream.use { it.readBytes() } is InputStream -> value.use { it.readBytes() } - else -> value + else -> error("Unexpected value $value of type ${value::class.qualifiedName}") } - override fun nonNullValueToString(value: Any): String = when (value) { - is ByteArray -> value.toString(Charsets.UTF_8) - else -> value.toString() - } + override fun nonNullValueToString(value: ByteArray): String = value.toString(Charsets.UTF_8) } /** @@ -778,7 +735,7 @@ open class BinaryColumnType( ) : BasicBinaryColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.binaryType(length) - override fun validateValueBeforeUpdate(value: Any?) { + override fun validateValueBeforeUpdate(value: ByteArray?) { if (value is ByteArray) { val valueLength = value.size require(valueLength <= length) { @@ -794,9 +751,7 @@ open class BinaryColumnType( other as BinaryColumnType - if (length != other.length) return false - - return true + return length == other.length } override fun hashCode(): Int { @@ -812,12 +767,13 @@ open class BinaryColumnType( class BlobColumnType( /** Returns whether an OID column should be used instead of BYTEA. This value only applies to PostgreSQL databases. */ val useObjectIdentifier: Boolean = false -) : ColumnType() { +) : ColumnType() { override fun sqlType(): String = when { useObjectIdentifier && currentDialect is PostgreSQLDialect -> "oid" useObjectIdentifier -> error("Storing BLOBs using OID columns is only supported by PostgreSQL") else -> currentDialect.dataTypeProvider.blobType() } + override fun valueFromDB(value: Any): ExposedBlob = when (value) { is ExposedBlob -> value is InputStream -> ExposedBlob(value) @@ -825,21 +781,7 @@ class BlobColumnType( else -> error("Unexpected value of type Blob: $value of ${value::class.qualifiedName}") } - override fun notNullValueToDB(value: Any): Any { - return if (value is Blob) { - value.binaryStream - } else { - value - } - } - - override fun nonNullValueToString(value: Any): String { - if (value !is ExposedBlob) { - error("Unexpected value of type Blob: $value of ${value::class.qualifiedName}") - } - - return currentDialect.dataTypeProvider.hexToDb(value.hexString()) - } + override fun nonNullValueToString(value: ExposedBlob): String = currentDialect.dataTypeProvider.hexToDb(value.hexString()) override fun readObject(rs: ResultSet, index: Int) = when { currentDialect is SQLServerDialect -> rs.getBytes(index)?.let(::ExposedBlob) @@ -859,7 +801,7 @@ class BlobColumnType( /** * Binary column for storing [UUID]. */ -class UUIDColumnType : ColumnType() { +class UUIDColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.uuidType() override fun valueFromDB(value: Any): UUID = when { @@ -870,16 +812,9 @@ class UUIDColumnType : ColumnType() { else -> error("Unexpected value of type UUID: $value of ${value::class.qualifiedName}") } - override fun notNullValueToDB(value: Any): Any = currentDialect.dataTypeProvider.uuidToDB(valueToUUID(value)) + override fun notNullValueToDB(value: UUID): Any = currentDialect.dataTypeProvider.uuidToDB(value) - override fun nonNullValueToString(value: Any): String = "'${valueToUUID(value)}'" - - private fun valueToUUID(value: Any): UUID = when (value) { - is UUID -> value - is String -> UUID.fromString(value) - is ByteArray -> ByteBuffer.wrap(value).let { UUID(it.long, it.long) } - else -> error("Unexpected value of type UUID: ${value.javaClass.canonicalName}") - } + override fun nonNullValueToString(value: UUID): String = "'$value'" override fun readObject(rs: ResultSet, index: Int): Any? = when (currentDialect) { is MariaDBDialect -> rs.getBytes(index) @@ -897,7 +832,7 @@ class UUIDColumnType : ColumnType() { /** * Boolean column for storing boolean values. */ -class BooleanColumnType : ColumnType() { +class BooleanColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.booleanType() override fun valueFromDB(value: Any): Boolean = when (value) { @@ -906,12 +841,11 @@ class BooleanColumnType : ColumnType() { else -> value.toString().toBoolean() } - override fun nonNullValueToString(value: Any): String = - currentDialect.dataTypeProvider.booleanToStatementString(value as Boolean) + override fun nonNullValueToString(value: Boolean): String = + currentDialect.dataTypeProvider.booleanToStatementString(value) - override fun notNullValueToDB(value: Any): Any = when { - value is Boolean && - (currentDialect is OracleDialect || currentDialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle) -> + override fun notNullValueToDB(value: Boolean): Any = when { + (currentDialect is OracleDialect || currentDialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle) -> nonNullValueToString(value) else -> value } @@ -929,7 +863,7 @@ class BooleanColumnType : ColumnType() { class EnumerationColumnType>( /** Returns the enum class used in this column type. */ val klass: KClass -) : ColumnType() { +) : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.integerType() private val enumConstants by lazy { klass.java.enumConstants!! } @@ -940,11 +874,7 @@ class EnumerationColumnType>( else -> error("$value of ${value::class.qualifiedName} is not valid for enum ${klass.simpleName}") } - override fun notNullValueToDB(value: Any): Int = when (value) { - is Int -> value - is Enum<*> -> value.ordinal - else -> error("$value of ${value::class.qualifiedName} is not valid for enum ${klass.simpleName}") - } + override fun notNullValueToDB(value: T): Int = value.ordinal override fun equals(other: Any?): Boolean { if (this === other) return true @@ -971,10 +901,12 @@ class EnumerationColumnType>( class EnumerationNameColumnType>( /** Returns the enum class used in this column type. */ val klass: KClass, - colLength: Int -) : VarCharColumnType(colLength) { + val colLength: Int +) : ColumnType() { private val enumConstants by lazy { klass.java.enumConstants!!.associateBy { it.name } } + override fun sqlType(): String = currentDialect.dataTypeProvider.varcharType(colLength) + @Suppress("UNCHECKED_CAST") override fun valueFromDB(value: Any): T = when (value) { is String -> { @@ -984,10 +916,21 @@ class EnumerationNameColumnType>( else -> error("$value of ${value::class.qualifiedName} is not valid for enum ${klass.qualifiedName}") } - override fun notNullValueToDB(value: Any): Any = when (value) { - is String -> super.notNullValueToDB(value) - is Enum<*> -> super.notNullValueToDB(value.name) - else -> error("$value of ${value::class.qualifiedName} is not valid for enum ${klass.qualifiedName}") + override fun notNullValueToDB(value: T): Any = value.name + + override fun nonNullValueToString(value: T): String = buildString { + append('\'') + append(escape(value.name)) + append('\'') + } + + override fun validateValueBeforeUpdate(value: T?) { + if (value != null) { + val valueLength = value.name.codePointCount(0, value.name.length) + require(valueLength <= colLength) { + "Value can't be stored to database column because exceeds length ($valueLength > $colLength)" + } + } } override fun equals(other: Any?): Boolean { @@ -1007,6 +950,16 @@ class EnumerationNameColumnType>( result = 31 * result + klass.hashCode() return result } + + private fun escape(value: String): String = value.map { charactersToEscape[it] ?: it }.joinToString("") + + companion object { + private val charactersToEscape = mapOf( + '\'' to "\'\'", + '\r' to "\\r", + '\n' to "\\n" + ) + } } /** @@ -1021,16 +974,29 @@ class CustomEnumerationColumnType>( val fromDb: (Any) -> T, /** Returns the function that converts an enumeration instance [T] to a value that will be stored to a database. */ val toDb: (T) -> Any -) : StringColumnType() { +) : ColumnType() { override fun sqlType(): String = sql ?: error("Column $name should exist in database") @Suppress("UNCHECKED_CAST") override fun valueFromDB(value: Any): T = if (value::class.isSubclassOf(Enum::class)) value as T else fromDb(value) - @Suppress("UNCHECKED_CAST") - override fun notNullValueToDB(value: Any): Any = toDb(value as T) + override fun notNullValueToDB(value: T): Any = toDb(value) + + override fun nonNullValueToString(value: T): String = buildString { + append('\'') + append(escape(value.toString())) + append('\'') + } + + private fun escape(value: String): String = value.map { charactersToEscape[it] ?: it }.joinToString("") - override fun nonNullValueToString(value: Any): String = super.nonNullValueToString(notNullValueToDB(value)) + companion object { + private val charactersToEscape = mapOf( + '\'' to "\'\'", + '\r' to "\\r", + '\n' to "\\n" + ) + } } // Array columns @@ -1038,12 +1004,12 @@ class CustomEnumerationColumnType>( /** * Array column for storing a collection of elements. */ -class ArrayColumnType( +class ArrayColumnType( /** Returns the base column type of this array column's individual elements. */ - val delegate: ColumnType, + val delegate: ColumnType, /** Returns the maximum amount of allowed elements in this array column. */ val maximumCardinality: Int? = null -) : ColumnType() { +) : ColumnType>() { override fun sqlType(): String = buildString { append(delegate.sqlType()) when { @@ -1056,37 +1022,23 @@ class ArrayColumnType( val delegateType: String get() = delegate.sqlType().substringBefore('(') - override fun valueFromDB(value: Any): Any = when { + override fun valueFromDB(value: Any): List = when { value is java.sql.Array -> (value.array as Array<*>).map { e -> e?.let { delegate.valueFromDB(it) } } - else -> value + else -> error("Unexpected value $value of type ${value::class.qualifiedName}") } - override fun notNullValueToDB(value: Any): Any = when { - value is List<*> -> value.map { e -> e?.let { delegate.notNullValueToDB(it) } }.toTypedArray() - else -> value - } + override fun notNullValueToDB(value: List): Any = value.map { e -> e?.let { delegate.notNullValueToDB(it) } }.toTypedArray() - override fun valueToString(value: Any?): String = when (value) { - is List<*> -> nonNullValueToString(value) - is Array<*> -> nonNullValueToString(value.toList()) - else -> super.valueToString(value) - } + override fun valueToString(value: List?): String = if (value != null) nonNullValueToString(value) else super.valueToString(value) - override fun nonNullValueToString(value: Any): String = when { - value is List<*> -> { - val prefix = if (currentDialect is H2Dialect) "ARRAY [" else "ARRAY[" - value.joinToString(",", prefix, "]") { delegate.valueToString(it) } - } - else -> super.nonNullValueToString(value) + override fun nonNullValueToString(value: List): String { + val prefix = if (currentDialect is H2Dialect) "ARRAY [" else "ARRAY[" + return value.joinToString(",", prefix, "]") { delegate.valueToString(it) } } - override fun nonNullValueAsDefaultString(value: Any): String = when (value) { - is List<*> -> { - val prefix = if (currentDialect is H2Dialect) "ARRAY [" else "ARRAY[" - value.joinToString(",", prefix, "]") { delegate.valueAsDefaultString(it) } - } - is Array<*> -> nonNullValueAsDefaultString(value.toList()) - else -> super.nonNullValueAsDefaultString(value) + override fun nonNullValueAsDefaultString(value: List): String { + val prefix = if (currentDialect is H2Dialect) "ARRAY [" else "ARRAY[" + return value.joinToString(",", prefix, "]") { delegate.valueAsDefaultString(it) } } override fun readObject(rs: ResultSet, index: Int): Any? = rs.getArray(index) @@ -1126,25 +1078,29 @@ interface JsonColumnMarker { @InternalApi fun resolveColumnType( klass: KClass, - defaultType: ColumnType? = null -): ColumnType = when (klass) { - Boolean::class -> BooleanColumnType() - Byte::class -> ByteColumnType() - UByte::class -> UByteColumnType() - Short::class -> ShortColumnType() - UShort::class -> UShortColumnType() - Int::class -> IntegerColumnType() - UInt::class -> UIntegerColumnType() - Long::class -> LongColumnType() - ULong::class -> ULongColumnType() - Float::class -> FloatColumnType() - Double::class -> DoubleColumnType() - String::class -> TextColumnType() - Char::class -> CharacterColumnType() - ByteArray::class -> BasicBinaryColumnType() - BigDecimal::class -> DecimalColumnType.INSTANCE - UUID::class -> UUIDColumnType() - else -> defaultType ?: error( + defaultType: ColumnType<*>? = null +): ColumnType { + val type = when (klass) { + Boolean::class -> BooleanColumnType() + Byte::class -> ByteColumnType() + UByte::class -> UByteColumnType() + Short::class -> ShortColumnType() + UShort::class -> UShortColumnType() + Int::class -> IntegerColumnType() + UInt::class -> UIntegerColumnType() + Long::class -> LongColumnType() + ULong::class -> ULongColumnType() + Float::class -> FloatColumnType() + Double::class -> DoubleColumnType() + String::class -> TextColumnType() + Char::class -> CharacterColumnType() + ByteArray::class -> BasicBinaryColumnType() + BigDecimal::class -> DecimalColumnType.INSTANCE + UUID::class -> UUIDColumnType() + else -> defaultType + } as? ColumnType + + return type ?: error( "A column type could not be associated with ${klass.qualifiedName}. Provide an explicit column type argument." ) } diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Expression.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Expression.kt index 5ca7084d68..a9ea4d4353 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Expression.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Expression.kt @@ -11,10 +11,10 @@ class QueryBuilder( val prepared: Boolean ) { private val internalBuilder = StringBuilder() - private val _args = mutableListOf>() + private val _args = mutableListOf, Any?>>() /** Returns the list of arguments used in this query. */ - val args: List> get() = _args + val args: List, Any?>> get() = _args operator fun invoke(body: QueryBuilder.() -> Unit): Unit = body() @@ -76,10 +76,10 @@ class QueryBuilder( } /** Adds the specified [argument] as a value of the specified [sqlType]. */ - fun registerArgument(sqlType: IColumnType, argument: T): Unit = registerArguments(sqlType, listOf(argument)) + fun registerArgument(sqlType: IColumnType<*>, argument: T): Unit = registerArguments(sqlType, listOf(argument)) /** Adds the specified sequence of [arguments] as values of the specified [sqlType]. */ - fun registerArguments(sqlType: IColumnType, arguments: Iterable) { + fun registerArguments(sqlType: IColumnType<*>, arguments: Iterable) { if (arguments is Collection && arguments.size <= 1) { // avoid potentially expensive valueToString call unless we need to sort values arguments.forEach { @@ -87,13 +87,13 @@ class QueryBuilder( _args.add(sqlType to it) append("?") } else { - append(sqlType.valueToString(it)) + append((sqlType as IColumnType).valueToString(it)) } } } else { fun toString(value: T) = when { prepared && value is String -> value - else -> sqlType.valueToString(value) + else -> (sqlType as IColumnType).valueToString(value) } arguments.map { it to toString(it) } @@ -166,5 +166,5 @@ abstract class Expression { */ abstract class ExpressionWithColumnType : Expression() { /** Returns the column type of this expression. Used for operations with literals. */ - abstract val columnType: IColumnType + abstract val columnType: IColumnType } diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt index 97faac13f0..573c6da0dc 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Function.kt @@ -9,7 +9,7 @@ import java.math.BigDecimal /** * Represents an SQL function. */ -abstract class Function(override val columnType: IColumnType) : ExpressionWithColumnType() +abstract class Function(override val columnType: IColumnType) : ExpressionWithColumnType() /** * Represents a custom SQL function. @@ -17,7 +17,7 @@ abstract class Function(override val columnType: IColumnType) : ExpressionWit open class CustomFunction( /** Returns the name of the function. */ val functionName: String, - columnType: IColumnType, + columnType: IColumnType, /** Returns the list of arguments of this function. */ vararg val expr: Expression<*> ) : Function(columnType) { @@ -34,7 +34,7 @@ open class CustomFunction( open class CustomOperator( /** Returns the name of the operator. */ val operatorName: String, - columnType: IColumnType, + columnType: IColumnType, /** Returns the left-hand side operand. */ val expr1: Expression<*>, /** Returns the right-hand side operand. */ @@ -84,7 +84,7 @@ class CharLength( class LowerCase( /** Returns the expression to convert. */ val expr: Expression -) : Function(TextColumnType()) { +) : Function(TextColumnType()) { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("LOWER(", expr, ")") } } @@ -94,7 +94,7 @@ class LowerCase( class UpperCase( /** Returns the expression to convert. */ val expr: Expression -) : Function(TextColumnType()) { +) : Function(TextColumnType()) { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("UPPER(", expr, ")") } } @@ -124,7 +124,7 @@ class GroupConcat( val distinct: Boolean, /** Returns the order in which the elements of each group are sorted. */ vararg val orderBy: Pair, SortOrder> -) : Function(TextColumnType()) { +) : Function(TextColumnType()) { override fun toQueryBuilder(queryBuilder: QueryBuilder) { currentDialect.functionProvider.groupConcat(this, queryBuilder) } @@ -138,7 +138,7 @@ class Substring( private val start: Expression, /** Returns the length of the substring. */ val length: Expression -) : Function(TextColumnType()) { +) : Function(TextColumnType()) { override fun toQueryBuilder(queryBuilder: QueryBuilder) { currentDialect.functionProvider.substring(expr, start, length, queryBuilder) } @@ -150,7 +150,7 @@ class Substring( class Trim( /** Returns the expression being trimmed. */ val expr: Expression -) : Function(TextColumnType()) { +) : Function(TextColumnType()) { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("TRIM(", expr, ")") } } @@ -167,14 +167,14 @@ class Locate(val expr: Expression, val substring: String) : Func /** * Represents an SQL function that returns the minimum value of [expr] across all non-null input values, or `null` if there are no non-null values. */ -class Min, in S : T?>( +class Min, S : T?>( /** Returns the expression from which the minimum value is obtained. */ - val expr: Expression, - columnType: IColumnType -) : Function(columnType), WindowFunction { + val expr: Expression, + columnType: IColumnType +) : Function(columnType), WindowFunction { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("MIN(", expr, ")") } - override fun over(): WindowFunctionDefinition { + override fun over(): WindowFunctionDefinition { return WindowFunctionDefinition(columnType, this) } } @@ -182,14 +182,14 @@ class Min, in S : T?>( /** * Represents an SQL function that returns the maximum value of [expr] across all non-null input values, or `null` if there are no non-null values. */ -class Max, in S : T?>( +class Max, S : T?>( /** Returns the expression from which the maximum value is obtained. */ - val expr: Expression, - columnType: IColumnType -) : Function(columnType), WindowFunction { + val expr: Expression, + columnType: IColumnType +) : Function(columnType), WindowFunction { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("MAX(", expr, ")") } - override fun over(): WindowFunctionDefinition { + override fun over(): WindowFunctionDefinition { return WindowFunctionDefinition(columnType, this) } } @@ -215,7 +215,7 @@ class Avg, in S : T?>( class Sum( /** Returns the expression from which the sum is calculated. */ val expr: Expression, - columnType: IColumnType + columnType: IColumnType ) : Function(columnType), WindowFunction { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { append("SUM(", expr, ")") } @@ -351,7 +351,7 @@ class VarSamp( sealed class NextVal( /** Returns the sequence from which the next value is obtained. */ val seq: Sequence, - columnType: IColumnType + columnType: IColumnType ) : Function(columnType) { override fun toQueryBuilder(queryBuilder: QueryBuilder) { @@ -393,30 +393,30 @@ class CaseWhen( /** Adds a conditional expression with a [result] if the expression evaluates to `true`. */ @Suppress("UNCHECKED_CAST") - fun When(cond: Expression, result: Expression): CaseWhen { + fun When(cond: Expression, result: Expression): CaseWhen { cases.add(cond to result) - return this as CaseWhen + return this } /** Adds an expression that will be used as the function result if all [cases] evaluate to `false`. */ - fun Else(e: Expression): ExpressionWithColumnType = CaseWhenElse(this, e) + fun Else(e: Expression): ExpressionWithColumnType = CaseWhenElse(this, e) } /** * Represents an SQL function that steps through conditions, and either returns a value when the first condition is met * or returns [elseResult] if all conditions are `false`. */ -class CaseWhenElse( +class CaseWhenElse( /** The conditions to check and their results if met. */ val caseWhen: CaseWhen, /** The result if none of the conditions checked are found to be `true`. */ - val elseResult: Expression -) : ExpressionWithColumnType(), ComplexExpression { + val elseResult: Expression +) : ExpressionWithColumnType(), ComplexExpression { - override val columnType: IColumnType = - (elseResult as? ExpressionWithColumnType)?.columnType - ?: caseWhen.cases.map { it.second }.filterIsInstance>().firstOrNull()?.columnType - ?: BooleanColumnType.INSTANCE + override val columnType: IColumnType = + (elseResult as? ExpressionWithColumnType)?.columnType + ?: caseWhen.cases.map { it.second }.filterIsInstance>().firstOrNull()?.columnType + ?: error("No column type has been found") override fun toQueryBuilder(queryBuilder: QueryBuilder) { queryBuilder { @@ -438,11 +438,11 @@ class CaseWhenElse( /** * Represents an SQL function that returns the first of its arguments that is not null. */ -class Coalesce( - private val expr: ExpressionWithColumnType, - private val alternate: Expression, - private vararg val others: Expression -) : Function(expr.columnType) { +class Coalesce( + private val expr: ExpressionWithColumnType, + private val alternate: Expression, + private vararg val others: Expression +) : Function(expr.columnType) { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { (listOf(expr, alternate) + others).appendTo( prefix = "COALESCE(", @@ -460,7 +460,7 @@ class Coalesce( class Cast( /** Returns the expression being casted. */ val expr: Expression<*>, - columnType: IColumnType + columnType: IColumnType ) : Function(columnType) { override fun toQueryBuilder(queryBuilder: QueryBuilder) { currentDialect.functionProvider.cast(expr, columnType, queryBuilder) diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Op.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Op.kt index cca68f99db..d9c34f530e 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Op.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Op.kt @@ -309,7 +309,7 @@ class PlusOp( /** The right-hand side operand. */ expr2: Expression, /** The column type of this expression. */ - columnType: IColumnType + columnType: IColumnType ) : CustomOperator("+", columnType, expr1, expr2) /** @@ -321,7 +321,7 @@ class MinusOp( /** The right-hand side operand. */ expr2: Expression, /** The column type of this expression. */ - columnType: IColumnType + columnType: IColumnType ) : CustomOperator("-", columnType, expr1, expr2) /** @@ -333,7 +333,7 @@ class TimesOp( /** The right-hand side operand. */ expr2: Expression, /** The column type of this expression. */ - columnType: IColumnType + columnType: IColumnType ) : CustomOperator("*", columnType, expr1, expr2) /** @@ -345,7 +345,7 @@ class DivideOp( /** The right-hand side operand. */ private val divisor: Expression, /** The column type of this expression. */ - columnType: IColumnType + columnType: IColumnType ) : CustomOperator("/", columnType, dividend, divisor) { companion object { fun DivideOp.withScale(scale: Int): DivideOp { @@ -356,7 +356,7 @@ class DivideOp( decimalLiteral(it.setScale(1)) // it is needed to treat dividend as decimal instead of integer in SQL } ?: dividend - return DivideOp(newExpression as Expression, divisor, decimalColumnType) + return DivideOp(newExpression as Expression, divisor, decimalColumnType as IColumnType) } } } @@ -369,7 +369,7 @@ class ModOp( val expr1: Expression, /** Returns the right-hand side operand. */ val expr2: Expression, - override val columnType: IColumnType + override val columnType: IColumnType ) : ExpressionWithColumnType() { override fun toQueryBuilder(queryBuilder: QueryBuilder) { @@ -431,7 +431,7 @@ class AndBitOp( /** The right-hand side operand. */ val expr2: Expression, /** The column type of this expression. */ - override val columnType: IColumnType + override val columnType: IColumnType ) : ExpressionWithColumnType() { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { when (val dialect = currentDialectIfAvailable) { @@ -462,7 +462,7 @@ class OrBitOp( /** The right-hand side operand. */ val expr2: Expression, /** The column type of this expression. */ - override val columnType: IColumnType + override val columnType: IColumnType ) : ExpressionWithColumnType() { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { when (val dialect = currentDialectIfAvailable) { @@ -494,7 +494,7 @@ class XorBitOp( /** The right-hand side operand. */ val expr2: Expression, /** The column type of this expression. */ - override val columnType: IColumnType + override val columnType: IColumnType ) : ExpressionWithColumnType() { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { when (val dialect = currentDialectIfAvailable) { @@ -638,7 +638,7 @@ class NotEqSubQueryOp(expr: Expression, query: AbstractQuery<*>) : SubQuer * Represents the specified [value] as an SQL literal, using the specified [columnType] to convert the value. */ class LiteralOp( - override val columnType: IColumnType, + override val columnType: IColumnType, /** Returns the value being used as a literal. */ val value: T ) : ExpressionWithColumnType() { @@ -692,7 +692,7 @@ fun decimalLiteral(value: BigDecimal): LiteralOp = LiteralOp(Decimal * * @throws IllegalStateException If no column type mapping is found and a [delegateType] is not provided. */ -inline fun arrayLiteral(value: List, delegateType: ColumnType? = null): LiteralOp> { +inline fun arrayLiteral(value: List, delegateType: ColumnType? = null): LiteralOp> { @OptIn(InternalApi::class) return LiteralOp(ArrayColumnType(delegateType ?: resolveColumnType(T::class)), value) } @@ -706,13 +706,13 @@ class QueryParameter( /** Returns the value being used as a query parameter. */ val value: T, /** Returns the column type of this expression. */ - val sqlType: IColumnType + val sqlType: IColumnType ) : Expression() { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { registerArgument(sqlType, value) } } /** Returns the specified [value] as a query parameter with the same type as [column]. */ -fun > idParam(value: EntityID, column: Column>): Expression> = +fun > idParam(value: EntityID, column: Column): Expression> = QueryParameter(value, EntityIDColumnType(column)) /** Returns the specified [value] as a boolean query parameter. */ @@ -771,7 +771,7 @@ fun blobParam(value: ExposedBlob, useObjectIdentifier: Boolean = false): Express * * @throws IllegalStateException If no column type mapping is found and a [delegateType] is not provided. */ -inline fun arrayParam(value: List, delegateType: ColumnType? = null): Expression> { +inline fun arrayParam(value: List, delegateType: ColumnType? = null): Expression> { @OptIn(InternalApi::class) return QueryParameter(value, ArrayColumnType(delegateType ?: resolveColumnType(T::class))) } @@ -785,7 +785,7 @@ inline fun arrayParam(value: List, delegateType: ColumnType class NoOpConversion( /** Returns the expression whose type is being changed. */ val expr: Expression, - override val columnType: IColumnType + override val columnType: IColumnType ) : ExpressionWithColumnType() { override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { +expr } } diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/SQLExpressionBuilder.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/SQLExpressionBuilder.kt index 2455a399cf..1c283d7548 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/SQLExpressionBuilder.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/SQLExpressionBuilder.kt @@ -64,16 +64,16 @@ fun Expression.locate(substring: String): Locate = Locate(th // General-Purpose Aggregate Functions /** Returns the minimum value of this expression across all non-null input values, or `null` if there are no non-null values. */ -fun , S : T?> ExpressionWithColumnType.min(): Min = Min(this, this.columnType) +fun , S : T?> ExpressionWithColumnType.min(): Min = Min(this, this.columnType) /** Returns the maximum value of this expression across all non-null input values, or `null` if there are no non-null values. */ -fun , S : T?> ExpressionWithColumnType.max(): Max = Max(this, this.columnType) +fun , S : T?> ExpressionWithColumnType.max(): Max = Max(this, this.columnType) /** Returns the average (arithmetic mean) value of this expression across all non-null input values, or `null` if there are no non-null values. */ fun , S : T?> ExpressionWithColumnType.avg(scale: Int = 2): Avg = Avg(this, scale) /** Returns the sum of this expression across all non-null input values, or `null` if there are no non-null values. */ -fun ExpressionWithColumnType.sum(): Sum = Sum(this, this.columnType) +fun ExpressionWithColumnType.sum(): Sum = Sum(this, this.columnType) /** Returns the number of input rows for which the value of this expression is not null. */ fun ExpressionWithColumnType<*>.count(): Count = Count(this) @@ -124,7 +124,7 @@ fun anyFrom(subQuery: AbstractQuery<*>): Op = AllAnyFromSubQueryOp(true, * * @throws IllegalStateException If no column type mapping is found and a [delegateType] is not provided. */ -inline fun anyFrom(array: Array, delegateType: ColumnType? = null): Op { +inline fun anyFrom(array: Array, delegateType: ColumnType? = null): Op { // emptyArray() without type info generates ARRAY[] @OptIn(InternalApi::class) val columnType = delegateType ?: resolveColumnType(T::class, if (array.isEmpty()) TextColumnType() else null) @@ -139,7 +139,7 @@ inline fun anyFrom(array: Array, delegateType: ColumnType? * * @throws IllegalStateException If no column type mapping is found and a [delegateType] is not provided. */ -inline fun anyFrom(array: List, delegateType: ColumnType? = null): Op { +inline fun anyFrom(array: List, delegateType: ColumnType? = null): Op { // emptyList() without type info generates ARRAY[] @OptIn(InternalApi::class) val columnType = delegateType ?: resolveColumnType(T::class, if (array.isEmpty()) TextColumnType() else null) @@ -163,7 +163,7 @@ fun allFrom(subQuery: AbstractQuery<*>): Op = AllAnyFromSubQueryOp(false, * * @throws IllegalStateException If no column type mapping is found and a [delegateType] is not provided. */ -inline fun allFrom(array: Array, delegateType: ColumnType? = null): Op { +inline fun allFrom(array: Array, delegateType: ColumnType? = null): Op { // emptyArray() without type info generates ARRAY[] @OptIn(InternalApi::class) val columnType = delegateType ?: resolveColumnType(T::class, if (array.isEmpty()) TextColumnType() else null) @@ -178,7 +178,7 @@ inline fun allFrom(array: Array, delegateType: ColumnType? * * @throws IllegalStateException If no column type mapping is found and a [delegateType] is not provided. */ -inline fun allFrom(array: List, delegateType: ColumnType? = null): Op { +inline fun allFrom(array: List, delegateType: ColumnType? = null): Op { // emptyList() without type info generates ARRAY[] @OptIn(InternalApi::class) val columnType = delegateType ?: resolveColumnType(T::class, if (array.isEmpty()) TextColumnType() else null) @@ -197,7 +197,7 @@ fun ?> allFrom(expression: Expression): Op = AllAnyFromExpr * @sample org.jetbrains.exposed.sql.tests.shared.types.ArrayColumnTypeTests.testSelectUsingArrayGet */ infix operator fun ?> ExpressionWithColumnType.get(index: Int): ArrayGet = - ArrayGet(this, index, (this.columnType as ArrayColumnType).delegate) + ArrayGet(this, index, (this.columnType as ArrayColumnType).delegate) /** * Returns a subarray of elements stored from between [lower] and [upper] bounds (inclusive), @@ -207,7 +207,7 @@ infix operator fun ?> ExpressionWithColumnType.get(index: Int) * @sample org.jetbrains.exposed.sql.tests.shared.types.ArrayColumnTypeTests.testSelectUsingArraySlice */ fun ?> ExpressionWithColumnType.slice(lower: Int? = null, upper: Int? = null): ArraySlice = - ArraySlice(this, lower, upper, this.columnType as ArrayColumnType) + ArraySlice(this, lower, upper, this.columnType) // Sequence Manipulation Functions @@ -220,14 +220,14 @@ fun Sequence.nextLongVal(): NextVal = NextVal.LongNextVal(this) // Value Expressions /** Specifies a conversion from one data type to another. */ -fun Expression<*>.castTo(columnType: IColumnType): ExpressionWithColumnType = Cast(this, columnType) +fun Expression<*>.castTo(columnType: IColumnType): ExpressionWithColumnType = Cast(this, columnType) // Misc. /** * Calls a custom SQL function with the specified [functionName] and passes this expression as its only argument. */ -fun ExpressionWithColumnType.function(functionName: String): CustomFunction = CustomFunction(functionName, columnType, this) +fun ExpressionWithColumnType.function(functionName: String): CustomFunction = CustomFunction(functionName, columnType, this) /** * Calls a custom SQL function with the specified [functionName], that returns a string, and passing [params] as its arguments. @@ -780,11 +780,11 @@ interface ISqlExpressionBuilder { // Conditional Expressions /** Returns the first of its arguments that is not null. */ - fun , R : T> coalesce( - expr: ExpressionWithColumnType, - alternate: A, - vararg others: A - ): Coalesce = Coalesce(expr, alternate, others = others) + fun coalesce( + expr: ExpressionWithColumnType, + alternate: Expression, + vararg others: Expression + ): Coalesce = Coalesce(expr, alternate, others = others) /** * Compares [value] against any chained conditional expressions. @@ -900,8 +900,8 @@ interface ISqlExpressionBuilder { is ULong -> ulongParam(value) is Float -> floatParam(value) is Double -> doubleParam(value) - is String -> QueryParameter(value, columnType) // String value should inherit from column - else -> QueryParameter(value, columnType) + is String -> QueryParameter(value, columnType as IColumnType) // String value should inherit from column + else -> QueryParameter(value, columnType as IColumnType) } as QueryParameter /** Returns the specified [value] as a literal of type [T]. */ @@ -920,7 +920,7 @@ interface ISqlExpressionBuilder { is Double -> doubleLiteral(value) is String -> stringLiteral(value) is ByteArray -> stringLiteral(value.toString(Charsets.UTF_8)) - else -> LiteralOp(columnType, value) + else -> LiteralOp(columnType as IColumnType, value) } as LiteralOp fun ExpressionWithColumnType.intToDecimal(): NoOpConversion = diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt index befcbc8bbc..f635286d24 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/SchemaUtils.kt @@ -204,11 +204,15 @@ object SchemaUtils { else -> processed.trim('\'') } } - column.columnType is ArrayColumnType && dialect is PostgreSQLDialect -> { + column.columnType is ArrayColumnType<*> && dialect is PostgreSQLDialect -> { (value as List<*>) .takeIf { it.isNotEmpty() } ?.run { - val delegate = column.withColumnType(column.columnType.delegate) + val delegate = Column( + table = column.table, + name = column.name, + columnType = column.columnType + ) val processed = map { if (delegate.columnType is StringColumnType) { "'$it'::text" diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt index ae77b71a6e..69adcc2dfb 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt @@ -515,7 +515,7 @@ open class Table(name: String = "") : ColumnSet(), DdlAware { // Column registration /** Adds a column of the specified [type] and with the specified [name] to the table. */ - fun registerColumn(name: String, type: IColumnType): Column = Column( + fun registerColumn(name: String, type: IColumnType): Column = Column( this, name, type @@ -826,7 +826,7 @@ open class Table(name: String = "") : ColumnSet(), DdlAware { * @param maximumCardinality The maximum amount of allowed elements. **Note** Providing an array size limit * when using the PostgreSQL dialect is allowed, but this value will be ignored by the database. */ - fun array(name: String, columnType: ColumnType, maximumCardinality: Int? = null): Column> = + fun array(name: String, columnType: ColumnType, maximumCardinality: Int? = null): Column> = registerColumn(name, ArrayColumnType(columnType.apply { nullable = true }, maximumCardinality)) /** @@ -844,9 +844,9 @@ open class Table(name: String = "") : ColumnSet(), DdlAware { * when using the PostgreSQL dialect is allowed, but this value will be ignored by the database. * @throws IllegalStateException If no column type mapping is found. */ - inline fun array(name: String, maximumCardinality: Int? = null): Column> { + inline fun array(name: String, maximumCardinality: Int? = null): Column> { @OptIn(InternalApi::class) - return array(name, resolveColumnType(T::class), maximumCardinality) + return array(name, resolveColumnType(E::class), maximumCardinality) } // Auto-generated values @@ -908,7 +908,7 @@ open class Table(name: String = "") : ColumnSet(), DdlAware { * without getting an error. * The value for the column can be set by creating a TRIGGER or with a DEFAULT clause, for example. */ - fun Column.databaseGenerated(): Column = apply { + fun Column.databaseGenerated(): Column = apply { isDatabaseGenerated = true } @@ -1011,7 +1011,7 @@ open class Table(name: String = "") : ColumnSet(), DdlAware { onUpdate: ReferenceOption? = null, fkName: String? = null ): Column { - val column = Column( + val column = Column( this, name, refColumn.columnType.cloneAsBaseType() @@ -1334,16 +1334,26 @@ open class Table(name: String = "") : ColumnSet(), DdlAware { } } - private fun IColumnType.cloneAsBaseType(): IColumnType = ((this as? AutoIncColumnType)?.delegate ?: this).clone() + private fun IColumnType.cloneAsBaseType(): IColumnType = ((this as? AutoIncColumnType)?.delegate ?: this).clone() private fun Column.cloneWithAutoInc(idSeqName: String?): Column = when (columnType) { is AutoIncColumnType -> this is ColumnType -> { val q = if (tableName.contains('.')) "\"" else "" val fallbackSeqName = "$q${tableName.replace("\"", "")}_${name}_seq$q" - this.withColumnType( - AutoIncColumnType(columnType, idSeqName, fallbackSeqName) - ) + Column( + table = this.table, + name = this.name, + columnType = AutoIncColumnType(columnType, idSeqName, fallbackSeqName) + ).also { + it.foreignKey = this.foreignKey + it.defaultValueFun = this.defaultValueFun + it.dbDefaultValue = this.dbDefaultValue + it.isDatabaseGenerated = this.isDatabaseGenerated + } +// this.withColumnType( +// AutoIncColumnType(columnType, idSeqName, fallbackSeqName) +// ) } else -> error("Unsupported column type for auto-increment $columnType") diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Transaction.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Transaction.kt index 730593125b..7ffc1e1dd4 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Transaction.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Transaction.kt @@ -156,7 +156,7 @@ open class Transaction( */ fun exec( @Language("sql") stmt: String, - args: Iterable> = emptyList(), + args: Iterable, Any?>> = emptyList(), explicitStatementType: StatementType? = null ) = exec(stmt, args, explicitStatementType) { } @@ -178,7 +178,7 @@ open class Transaction( */ fun exec( @Language("sql") stmt: String, - args: Iterable> = emptyList(), + args: Iterable, Any?>> = emptyList(), explicitStatementType: StatementType? = null, transform: (ResultSet) -> T? ): T? { @@ -206,7 +206,7 @@ open class Transaction( override fun prepareSQL(transaction: Transaction, prepared: Boolean): String = stmt - override fun arguments(): Iterable>> = listOf( + override fun arguments(): Iterable, Any?>>> = listOf( args.map { (columnType, value) -> columnType.apply { nullable = true } to value } diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/WindowFunction.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/WindowFunction.kt index b2a3d6e27c..1329e414fc 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/WindowFunction.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/WindowFunction.kt @@ -15,7 +15,7 @@ interface WindowFunction { /** Represents an SQL window function with window definition. */ @Suppress("TooManyFunctions") class WindowFunctionDefinition( - override val columnType: IColumnType, + override val columnType: IColumnType, /** Returns the function that definition is used for. */ private val function: WindowFunction ) : ExpressionWithColumnType() { diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/functions/array/ArrayFunctions.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/functions/array/ArrayFunctions.kt index 45d86f8ecc..6eca7728fe 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/functions/array/ArrayFunctions.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/functions/array/ArrayFunctions.kt @@ -14,12 +14,12 @@ import org.jetbrains.exposed.sql.vendors.h2Mode * Represents an SQL function that returns the array element stored at the one-based [index] position, * or `null` if the stored array itself is null. */ -class ArrayGet?>( +class ArrayGet?>( /** The array expression that is accessed. */ val expression: Expression, /** The one-based index position at which the stored array is accessed. */ val index: Int, - columnType: IColumnType + columnType: IColumnType ) : Function(columnType) { override fun toQueryBuilder(queryBuilder: QueryBuilder) { queryBuilder { @@ -32,14 +32,14 @@ class ArrayGet?>( * Represents an SQL function that returns a subarray of elements stored from between [lower] and [upper] bounds (inclusive), * or `null` if the stored array itself is null. */ -class ArraySlice?>( +class ArraySlice?>( /** The array expression from which the subarray is returned. */ val expression: Expression, /** The lower bounds (inclusive) of a subarray. If left `null`, the database will use the stored array's lower limit. */ val lower: Int?, /** The upper bounds (inclusive) of a subarray. If left `null`, the database will use the stored array's upper limit. */ val upper: Int?, - columnType: IColumnType + columnType: IColumnType ) : Function(columnType) { override fun toQueryBuilder(queryBuilder: QueryBuilder) { val functionProvider = when (currentDialect.h2Mode) { diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ops/AllAnyOps.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ops/AllAnyOps.kt index f04fbc7c11..0c170c9823 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ops/AllAnyOps.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ops/AllAnyOps.kt @@ -45,7 +45,7 @@ class AllAnyFromSubQueryOp( class AllAnyFromArrayOp( isAny: Boolean, array: List, - private val delegateType: ColumnType + private val delegateType: ColumnType ) : AllAnyFromBaseOp>(isAny, array) { override fun QueryBuilder.registerSubSearchArgument(subSearch: List) { registerArgument(ArrayColumnType(delegateType), subSearch) diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/BatchUpdateStatement.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/BatchUpdateStatement.kt index 5d6328120b..1b15c43bd7 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/BatchUpdateStatement.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/BatchUpdateStatement.kt @@ -50,7 +50,7 @@ open class BatchUpdateStatement(val table: IdTable<*>) : UpdateStatement(table, override fun PreparedStatementApi.executeInternal(transaction: Transaction): Int = if (data.size == 1) executeUpdate() else executeBatch().sum() - override fun arguments(): Iterable>> = data.map { (id, row) -> + override fun arguments(): Iterable, Any?>>> = data.map { (id, row) -> firstDataSet.map { it.first.columnType to row[it.first] } + (table.id.columnType to id) } } diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/DeleteStatement.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/DeleteStatement.kt index 4c5854df8e..94a5d2813c 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/DeleteStatement.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/DeleteStatement.kt @@ -32,7 +32,7 @@ open class DeleteStatement( override fun prepareSQL(transaction: Transaction, prepared: Boolean): String = transaction.db.dialect.functionProvider.delete(isIgnore, table, where?.let { QueryBuilder(prepared).append(it).toString() }, limit, transaction) - override fun arguments(): Iterable>> = QueryBuilder(true).run { + override fun arguments(): Iterable, Any?>>> = QueryBuilder(true).run { where?.toQueryBuilder(this) listOf(args) } diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/InsertSelectStatement.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/InsertSelectStatement.kt index fff8bd0d5e..d05a1c197e 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/InsertSelectStatement.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/InsertSelectStatement.kt @@ -29,7 +29,7 @@ open class InsertSelectStatement( override fun PreparedStatementApi.executeInternal(transaction: Transaction): Int? = executeUpdate() - override fun arguments(): Iterable>> = selectQuery.arguments() + override fun arguments(): Iterable, Any?>>> = selectQuery.arguments() override fun prepareSQL(transaction: Transaction, prepared: Boolean): String = transaction.db.dialect.functionProvider.insert(isIgnore, targets.single(), columns, selectQuery.prepareSQL(transaction, prepared), transaction) diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/InsertStatement.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/InsertStatement.kt index e9421004ac..aeb5a05081 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/InsertStatement.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/InsertStatement.kt @@ -208,7 +208,7 @@ open class InsertStatement( listOf(result).apply { field = this } } - override fun arguments(): List>> { + override fun arguments(): List, Any?>>> { return arguments?.map { args -> val builder = QueryBuilder(true) args.filter { (_, value) -> diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/Statement.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/Statement.kt index 70fb63b73f..1f5193e812 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/Statement.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/Statement.kt @@ -36,7 +36,7 @@ abstract class Statement(val type: StatementType, val targets: List>> + abstract fun arguments(): Iterable, Any?>>> /** * Uses a [transaction] connection and an [sql] string representation to return a precompiled SQL statement, @@ -77,7 +77,7 @@ abstract class Statement(val type: StatementType, val targets: List
+ contexts.forEachIndexed { _, context -> statement.fillParameters(context.args) // REVIEW if (contexts.size > 1 || isAlwaysBatch) statement.addBatch() @@ -103,7 +103,7 @@ abstract class Statement(val type: StatementType, val targets: List
, val args: Iterable>) { +class StatementContext(val statement: Statement<*>, val args: Iterable, Any?>>) { /** Returns the string representation of the SQL statement associated with this [StatementContext]. */ fun sql(transaction: Transaction) = statement.prepareSQL(transaction) } @@ -133,7 +133,7 @@ fun StatementContext.expandArgs(transaction: Transaction): String { append(sql.substring(lastPos, i)) lastPos = i + 1 val (col, value) = iterator.next() - append(col.valueToString(value)) + append((col as IColumnType).valueToString(value)) } char == '\'' || char == '\"' -> { when { diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateBuilder.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateBuilder.kt index caa284f73f..4633ce74c5 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateBuilder.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateBuilder.kt @@ -3,14 +3,7 @@ package org.jetbrains.exposed.sql.statements import org.jetbrains.exposed.dao.id.EntityID -import org.jetbrains.exposed.sql.Column -import org.jetbrains.exposed.sql.CompositeColumn -import org.jetbrains.exposed.sql.Expression -import org.jetbrains.exposed.sql.Op -import org.jetbrains.exposed.sql.Query -import org.jetbrains.exposed.sql.SqlExpressionBuilder -import org.jetbrains.exposed.sql.Table -import org.jetbrains.exposed.sql.wrapAsExpression +import org.jetbrains.exposed.sql.* import kotlin.internal.LowPriorityInOverloadResolution /** @@ -41,7 +34,7 @@ abstract class UpdateBuilder(type: StatementType, targets: List
) : @JvmName("setWithEntityIdValue") operator fun , ID : EntityID> set(column: Column, value: S) { - column.columnType.validateValueBeforeUpdate(value) +// column.columnType.validateValueBeforeUpdate(value) values[column] = value } @@ -50,7 +43,7 @@ abstract class UpdateBuilder(type: StatementType, targets: List
) : require(column.columnType.nullable || value != null) { "Trying to set null to not nullable column $column" } - column.columnType.validateValueBeforeUpdate(value) +// column.columnType.validateValueBeforeUpdate(value) values[column] = value } @@ -60,7 +53,7 @@ abstract class UpdateBuilder(type: StatementType, targets: List
) : "Trying to set null to not nullable column $column" } checkThatExpressionWasNotSetInPreviousBatch(column) - column.columnType.validateValueBeforeUpdate(value) +// column.columnType.validateValueBeforeUpdate(value) values[column] = value } @@ -80,7 +73,7 @@ abstract class UpdateBuilder(type: StatementType, targets: List
) : **/ open fun update(column: Column, value: Expression) { checkThatExpressionWasNotSetInPreviousBatch(column) - column.columnType.validateValueBeforeUpdate(value) +// column.columnType.validateValueBeforeUpdate(value) values[column] = value } diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt index 8c91ebe2f8..02846d1226 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt @@ -46,7 +46,7 @@ open class UpdateStatement(val targetsSet: ColumnSet, val limit: Int?, val where } } - override fun arguments(): Iterable>> = QueryBuilder(true).run { + override fun arguments(): Iterable, Any?>>> = QueryBuilder(true).run { val dialect = currentDialect when { targetsSet is Join && dialect is OracleDialect -> { @@ -72,9 +72,13 @@ open class UpdateStatement(val targetsSet: ColumnSet, val limit: Int?, val where if (args.isNotEmpty()) listOf(args) else emptyList() } - private fun QueryBuilder.registerWhereArg() { where?.toQueryBuilder(this) } + private fun QueryBuilder.registerWhereArg() { + where?.toQueryBuilder(this) + } - private fun QueryBuilder.registerUpdateArgs() { values.forEach { registerArgument(it.key, it.value) } } + private fun QueryBuilder.registerUpdateArgs() { + values.forEach { registerArgument(it.key, it.value) } + } private fun QueryBuilder.registerAdditionalArgs(join: Join) { join.joinParts.forEach { diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpsertStatement.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpsertStatement.kt index 92a28a068e..b8133d5c4b 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpsertStatement.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpsertStatement.kt @@ -1,7 +1,9 @@ package org.jetbrains.exposed.sql.statements import org.jetbrains.exposed.sql.* -import org.jetbrains.exposed.sql.vendors.* +import org.jetbrains.exposed.sql.vendors.H2Dialect +import org.jetbrains.exposed.sql.vendors.H2FunctionProvider +import org.jetbrains.exposed.sql.vendors.MysqlFunctionProvider /** * Represents the SQL statement that either inserts a new row into a table, or updates the existing row if insertion would violate a unique constraint. @@ -34,7 +36,7 @@ open class UpsertStatement( return functionProvider.upsert(table, arguments!!.first(), onUpdate, onUpdateExclude, where, transaction, keys = keys) } - override fun arguments(): List>> { + override fun arguments(): List, Any?>>> { return arguments?.map { args -> val builder = QueryBuilder(true) args.filter { (_, value) -> diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/api/PreparedStatementApi.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/api/PreparedStatementApi.kt index 96d0da1921..70fcdd804b 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/api/PreparedStatementApi.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/api/PreparedStatementApi.kt @@ -19,9 +19,9 @@ interface PreparedStatementApi { * Sets the value for each column or expression in [args] into the appropriate statement parameter and * returns the number of parameters filled. */ - fun fillParameters(args: Iterable>): Int { + fun fillParameters(args: Iterable, Any?>>): Int { args.forEachIndexed { index, (c, v) -> - c.setParameter(this, index + 1, c.valueToDB(v)) + c.setParameter(this, index + 1, (c as IColumnType).valueToDB(v)) } return args.count() + 1 @@ -59,7 +59,7 @@ interface PreparedStatementApi { operator fun set(index: Int, value: Any) /** Sets the statement parameter at the [index] position to SQL NULL, if allowed wih the specified [columnType]. */ - fun setNull(index: Int, columnType: IColumnType) + fun setNull(index: Int, columnType: IColumnType<*>) /** * Sets the statement parameter at the [index] position to the provided [inputStream], diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/FunctionProvider.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/FunctionProvider.kt index c300394f9a..d7062396c7 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/FunctionProvider.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/FunctionProvider.kt @@ -257,7 +257,7 @@ abstract class FunctionProvider { */ open fun cast( expr: Expression<*>, - type: IColumnType, + type: IColumnType<*>, builder: QueryBuilder ): Unit = builder { append("CAST(", expr, " AS ", type.sqlType(), ")") @@ -343,7 +343,7 @@ abstract class FunctionProvider { expression: Expression, vararg path: String, toScalar: Boolean, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) { throw UnsupportedByDialectException( @@ -365,7 +365,7 @@ abstract class FunctionProvider { target: Expression<*>, candidate: Expression<*>, path: String?, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) { throw UnsupportedByDialectException( @@ -387,7 +387,7 @@ abstract class FunctionProvider { expression: Expression<*>, vararg path: String, optional: String?, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) { throw UnsupportedByDialectException( diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/MysqlDialect.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/MysqlDialect.kt index b018811f26..ef7243735c 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/MysqlDialect.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/MysqlDialect.kt @@ -129,7 +129,7 @@ internal open class MysqlFunctionProvider : FunctionProvider() { expression: Expression, vararg path: String, toScalar: Boolean, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) = queryBuilder { if (toScalar) append("JSON_UNQUOTE(") @@ -142,7 +142,7 @@ internal open class MysqlFunctionProvider : FunctionProvider() { target: Expression<*>, candidate: Expression<*>, path: String?, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) = queryBuilder { append("JSON_CONTAINS(", target, ", ", candidate) @@ -156,7 +156,7 @@ internal open class MysqlFunctionProvider : FunctionProvider() { expression: Expression<*>, vararg path: String, optional: String?, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) { val oneOrAll = optional?.lowercase() @@ -186,7 +186,7 @@ internal open class MysqlFunctionProvider : FunctionProvider() { override fun sqlType(): String = "CHAR" } - override fun cast(expr: Expression<*>, type: IColumnType, builder: QueryBuilder) = when (type) { + override fun cast(expr: Expression<*>, type: IColumnType<*>, builder: QueryBuilder) = when (type) { is StringColumnType -> super.cast(expr, CharColumnType, builder) else -> super.cast(expr, type, builder) } diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/OracleDialect.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/OracleDialect.kt index 5f56a4e33e..23ee2e0aaa 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/OracleDialect.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/OracleDialect.kt @@ -173,7 +173,7 @@ internal object OracleFunctionProvider : FunctionProvider() { expression: Expression, vararg path: String, toScalar: Boolean, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) { if (path.size > 1) { @@ -191,7 +191,7 @@ internal object OracleFunctionProvider : FunctionProvider() { expression: Expression<*>, vararg path: String, optional: String?, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) { if (path.size > 1) { diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/PostgreSQL.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/PostgreSQL.kt index bee5325120..09cd4e2212 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/PostgreSQL.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/PostgreSQL.kt @@ -29,7 +29,7 @@ internal object PostgreSQLDataTypeProvider : DataTypeProvider() { e is LiteralOp<*> && e.columnType is BlobColumnType && e.columnType.useObjectIdentifier && (currentDialect as? H2Dialect) == null -> { "lo_from_bytea(0, ${super.processForDefaultValue(e)} :: bytea)" } - e is LiteralOp<*> && e.columnType is ArrayColumnType -> { + e is LiteralOp<*> && e.columnType is ArrayColumnType<*> -> { val processed = super.processForDefaultValue(e) processed .takeUnless { it == "ARRAY[]" } @@ -145,7 +145,7 @@ internal object PostgreSQLFunctionProvider : FunctionProvider() { expression: Expression, vararg path: String, toScalar: Boolean, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) = queryBuilder { append("${jsonType.sqlType()}_EXTRACT_PATH") @@ -159,7 +159,7 @@ internal object PostgreSQLFunctionProvider : FunctionProvider() { target: Expression<*>, candidate: Expression<*>, path: String?, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) { path?.let { @@ -178,7 +178,7 @@ internal object PostgreSQLFunctionProvider : FunctionProvider() { expression: Expression<*>, vararg path: String, optional: String?, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) { if (path.size > 1) { diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/SQLServerDialect.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/SQLServerDialect.kt index 6cc09d52bc..482fed6474 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/SQLServerDialect.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/SQLServerDialect.kt @@ -152,7 +152,7 @@ internal object SQLServerFunctionProvider : FunctionProvider() { expression: Expression, vararg path: String, toScalar: Boolean, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) { if (path.size > 1) { diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/SQLiteDialect.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/SQLiteDialect.kt index e244b7d883..e22e30b908 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/SQLiteDialect.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/SQLiteDialect.kt @@ -133,7 +133,7 @@ internal object SQLiteFunctionProvider : FunctionProvider() { expression: Expression, vararg path: String, toScalar: Boolean, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) = queryBuilder { append("JSON_EXTRACT(", expression, ", ") @@ -145,7 +145,7 @@ internal object SQLiteFunctionProvider : FunctionProvider() { expression: Expression<*>, vararg path: String, optional: String?, - jsonType: IColumnType, + jsonType: IColumnType<*>, queryBuilder: QueryBuilder ) { val transaction = TransactionManager.current() diff --git a/exposed-crypt/src/main/kotlin/org/jetbrains/exposed/crypt/EncryptedBinaryColumnType.kt b/exposed-crypt/src/main/kotlin/org/jetbrains/exposed/crypt/EncryptedBinaryColumnType.kt index e592ed1ee7..f60a2a1251 100644 --- a/exposed-crypt/src/main/kotlin/org/jetbrains/exposed/crypt/EncryptedBinaryColumnType.kt +++ b/exposed-crypt/src/main/kotlin/org/jetbrains/exposed/crypt/EncryptedBinaryColumnType.kt @@ -11,29 +11,18 @@ class EncryptedBinaryColumnType( private val encryptor: Encryptor, length: Int ) : BinaryColumnType(length) { - override fun nonNullValueToString(value: Any): String { + override fun nonNullValueToString(value: ByteArray): String { return super.nonNullValueToString(notNullValueToDB(value)) } - override fun notNullValueToDB(value: Any): Any { - if (value !is ByteArray) { - error("Unexpected value of type Byte: $value of ${value::class.qualifiedName}") - } - - return encryptor.encrypt(String(value)).toByteArray() - } + override fun notNullValueToDB(value: ByteArray): ByteArray = encryptor.encrypt(String(value)).toByteArray() - override fun valueFromDB(value: Any): Any { + override fun valueFromDB(value: Any): ByteArray { val encryptedByte = super.valueFromDB(value) - - if (encryptedByte !is ByteArray) { - error("Unexpected value of type Byte: $value of ${value::class.qualifiedName}") - } - return encryptor.decrypt(String(encryptedByte)).toByteArray() } - override fun validateValueBeforeUpdate(value: Any?) { + override fun validateValueBeforeUpdate(value: ByteArray?) { if (value != null) { super.validateValueBeforeUpdate(notNullValueToDB(value)) } diff --git a/exposed-crypt/src/main/kotlin/org/jetbrains/exposed/crypt/EncryptedVarCharColumnType.kt b/exposed-crypt/src/main/kotlin/org/jetbrains/exposed/crypt/EncryptedVarCharColumnType.kt index 6dbc6dbaff..37e5c7ce78 100644 --- a/exposed-crypt/src/main/kotlin/org/jetbrains/exposed/crypt/EncryptedVarCharColumnType.kt +++ b/exposed-crypt/src/main/kotlin/org/jetbrains/exposed/crypt/EncryptedVarCharColumnType.kt @@ -12,25 +12,20 @@ class EncryptedVarCharColumnType( private val encryptor: Encryptor, colLength: Int, ) : VarCharColumnType(colLength) { - override fun nonNullValueToString(value: Any): String { + override fun nonNullValueToString(value: String): String { return super.nonNullValueToString(notNullValueToDB(value)) } - override fun notNullValueToDB(value: Any): Any { - return encryptor.encrypt(value.toString()) + override fun notNullValueToDB(value: String): String { + return encryptor.encrypt(value) } - override fun valueFromDB(value: Any): Any { + override fun valueFromDB(value: Any): String { val encryptedStr = super.valueFromDB(value) - - if (encryptedStr !is String) { - error("Unexpected value of type String: $value of ${value::class.qualifiedName}") - } - return encryptor.decrypt(encryptedStr) } - override fun validateValueBeforeUpdate(value: Any?) { + override fun validateValueBeforeUpdate(value: String?) { if (value != null) { super.validateValueBeforeUpdate(notNullValueToDB(value)) } diff --git a/exposed-java-time/src/main/kotlin/org/jetbrains/exposed/sql/javatime/JavaDateColumnType.kt b/exposed-java-time/src/main/kotlin/org/jetbrains/exposed/sql/javatime/JavaDateColumnType.kt index 32922b7264..7649c653ca 100644 --- a/exposed-java-time/src/main/kotlin/org/jetbrains/exposed/sql/javatime/JavaDateColumnType.kt +++ b/exposed-java-time/src/main/kotlin/org/jetbrains/exposed/sql/javatime/JavaDateColumnType.kt @@ -120,43 +120,32 @@ private val LocalDate.millis get() = atStartOfDay(ZoneId.systemDefault()).toEpoc * @sample org.jetbrains.exposed.sql.javatime.date */ @Suppress("MagicNumber") -class JavaLocalDateColumnType : ColumnType(), IDateColumnType { +class JavaLocalDateColumnType : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = false override fun sqlType(): String = currentDialect.dataTypeProvider.dateType() - override fun nonNullValueToString(value: Any): String { - val instant = when (value) { - is String -> return value - is LocalDate -> Instant.from(value.atStartOfDay(ZoneId.systemDefault())) - is java.sql.Date -> Instant.ofEpochMilli(value.time) - is java.sql.Timestamp -> Instant.ofEpochSecond(value.time / 1000, value.nanos.toLong()) - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") - } - + override fun nonNullValueToString(value: LocalDate): String { + val instant = Instant.from(value.atStartOfDay(ZoneId.systemDefault())) return "'${DEFAULT_DATE_STRING_FORMATTER.format(instant)}'" } - override fun valueFromDB(value: Any): Any = when (value) { + override fun valueFromDB(value: Any): LocalDate? = when (value) { is LocalDate -> value is java.sql.Date -> longToLocalDate(value.time) is java.sql.Timestamp -> longToLocalDate(value.time) is Int -> longToLocalDate(value.toLong()) is Long -> longToLocalDate(value) - is String -> when (currentDialect) { - is SQLiteDialect -> LocalDate.parse(value) - else -> value - } + is String -> LocalDate.parse(value) else -> LocalDate.parse(value.toString()) } - override fun notNullValueToDB(value: Any) = when { - value is LocalDate && currentDialect is SQLiteDialect -> DEFAULT_DATE_STRING_FORMATTER.format(value) - value is LocalDate -> java.sql.Date(value.millis) - else -> value + override fun notNullValueToDB(value: LocalDate): Any = when { + currentDialect is SQLiteDialect -> DEFAULT_DATE_STRING_FORMATTER.format(value) + else -> java.sql.Date(value.millis) } - override fun nonNullValueAsDefaultString(value: Any): String = when (currentDialect) { + override fun nonNullValueAsDefaultString(value: LocalDate): String = when (currentDialect) { is PostgreSQLDialect -> "${nonNullValueToString(value)}::date" else -> super.nonNullValueAsDefaultString(value) } @@ -174,18 +163,12 @@ class JavaLocalDateColumnType : ColumnType(), IDateColumnType { * @sample org.jetbrains.exposed.sql.javatime.datetime */ @Suppress("MagicNumber") -class JavaLocalDateTimeColumnType : ColumnType(), IDateColumnType { +class JavaLocalDateTimeColumnType : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = true override fun sqlType(): String = currentDialect.dataTypeProvider.dateTimeType() - override fun nonNullValueToString(value: Any): String { - val instant = when (value) { - is String -> return value - is LocalDateTime -> Instant.from(value.atZone(ZoneId.systemDefault())) - is java.sql.Date -> Instant.ofEpochMilli(value.time) - is java.sql.Timestamp -> Instant.ofEpochSecond(value.time / 1000, value.nanos.toLong()) - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") - } + override fun nonNullValueToString(value: LocalDateTime): String { + val instant = Instant.from(value.atZone(ZoneId.systemDefault())) val dialect = currentDialect return when { @@ -199,7 +182,7 @@ class JavaLocalDateTimeColumnType : ColumnType(), IDateColumnType { } } - override fun valueFromDB(value: Any): Any = when (value) { + override fun valueFromDB(value: Any): LocalDateTime? = when (value) { is LocalDateTime -> value is java.sql.Date -> longToLocalDateTime(value.time) is java.sql.Timestamp -> longToLocalDateTime(value.time / 1000, value.nanos.toLong()) @@ -210,14 +193,13 @@ class JavaLocalDateTimeColumnType : ColumnType(), IDateColumnType { else -> valueFromDB(value.toString()) } - override fun notNullValueToDB(value: Any): Any = when { - value is LocalDateTime && currentDialect is SQLiteDialect -> + override fun notNullValueToDB(value: LocalDateTime): Any = when { + currentDialect is SQLiteDialect -> SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value.atZone(ZoneId.systemDefault())) - value is LocalDateTime -> { + else -> { val instant = value.atZone(ZoneId.systemDefault()).toInstant() java.sql.Timestamp(instant.toEpochMilli()).apply { nanos = instant.nano } } - else -> value } override fun readObject(rs: ResultSet, index: Int): Any? { @@ -228,18 +210,15 @@ class JavaLocalDateTimeColumnType : ColumnType(), IDateColumnType { } } - override fun nonNullValueAsDefaultString(value: Any): String = when (value) { - is LocalDateTime -> { - val dialect = currentDialect - when { - dialect is PostgreSQLDialect -> - "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value).trimEnd('0').trimEnd('.')}'::timestamp without time zone" - (dialect as? H2Dialect)?.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> - "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value).trimEnd('0').trimEnd('.')}'" - else -> super.nonNullValueAsDefaultString(value) - } + override fun nonNullValueAsDefaultString(value: LocalDateTime): String { + val dialect = currentDialect + return when { + dialect is PostgreSQLDialect -> + "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value).trimEnd('0').trimEnd('.')}'::timestamp without time zone" + (dialect as? H2Dialect)?.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> + "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value).trimEnd('0').trimEnd('.')}'" + else -> super.nonNullValueAsDefaultString(value) } - else -> super.nonNullValueAsDefaultString(value) } private fun longToLocalDateTime(millis: Long) = LocalDateTime.ofInstant(Instant.ofEpochMilli(millis), ZoneId.systemDefault()) @@ -256,27 +235,19 @@ class JavaLocalDateTimeColumnType : ColumnType(), IDateColumnType { * * @sample org.jetbrains.exposed.sql.javatime.time */ -class JavaLocalTimeColumnType : ColumnType(), IDateColumnType { +class JavaLocalTimeColumnType : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = true override fun sqlType(): String = currentDialect.dataTypeProvider.timeType() - override fun nonNullValueToString(value: Any): String { - val instant = when (value) { - is String -> return value - is LocalTime -> value - is java.sql.Time -> Instant.ofEpochMilli(value.time).atZone(ZoneId.systemDefault()) - is java.sql.Timestamp -> Instant.ofEpochMilli(value.time).atZone(ZoneId.systemDefault()) - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") - } - + override fun nonNullValueToString(value: LocalTime): String { val dialect = currentDialect val formatter = if (dialect is OracleDialect || dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle) { ORACLE_TIME_STRING_FORMATTER } else { DEFAULT_TIME_STRING_FORMATTER } - return "'${formatter.format(instant)}'" + return "'${formatter.format(value)}'" } override fun valueFromDB(value: Any): LocalTime = when (value) { @@ -297,19 +268,11 @@ class JavaLocalTimeColumnType : ColumnType(), IDateColumnType { else -> valueFromDB(value.toString()) } - override fun notNullValueToDB(value: Any): Any = when (value) { - is LocalTime -> java.sql.Time.valueOf(value) - else -> value - } + override fun notNullValueToDB(value: LocalTime): Any = java.sql.Time.valueOf(value) - override fun nonNullValueAsDefaultString(value: Any): String = when (value) { - is LocalTime -> { - when (currentDialect) { - is PostgreSQLDialect -> "${nonNullValueToString(value)}::time without time zone" - is MysqlDialect -> "'${MYSQL_TIME_AS_DEFAULT_STRING_FORMATTER.format(value)}'" - else -> super.nonNullValueAsDefaultString(value) - } - } + override fun nonNullValueAsDefaultString(value: LocalTime): String = when (currentDialect) { + is PostgreSQLDialect -> "${nonNullValueToString(value)}::time without time zone" + is MysqlDialect -> "'${MYSQL_TIME_AS_DEFAULT_STRING_FORMATTER.format(value)}'" else -> super.nonNullValueAsDefaultString(value) } @@ -325,28 +288,20 @@ class JavaLocalTimeColumnType : ColumnType(), IDateColumnType { * * @sample org.jetbrains.exposed.sql.javatime.timestamp */ -class JavaInstantColumnType : ColumnType(), IDateColumnType { +class JavaInstantColumnType : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = true override fun sqlType(): String = currentDialect.dataTypeProvider.dateTimeType() - override fun nonNullValueToString(value: Any): String { - val instant = when (value) { - is String -> return value - is Instant -> value - is java.sql.Timestamp -> value.toInstant() - is LocalDateTime -> value.atZone(ZoneId.systemDefault()).toInstant() - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") - } - + override fun nonNullValueToString(value: Instant): String { val dialect = currentDialect return when { dialect is OracleDialect || dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> - "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(instant)}'" + "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value)}'" dialect is MysqlDialect -> { val formatter = if (dialect.isFractionDateTimeSupported()) MYSQL_FRACTION_DATE_TIME_STRING_FORMATTER else MYSQL_DATE_TIME_STRING_FORMATTER - "'${formatter.format(instant)}'" + "'${formatter.format(value)}'" } - else -> "'${DEFAULT_DATE_TIME_STRING_FORMATTER.format(instant)}'" + else -> "'${DEFAULT_DATE_TIME_STRING_FORMATTER.format(value)}'" } } @@ -360,26 +315,21 @@ class JavaInstantColumnType : ColumnType(), IDateColumnType { return rs.getTimestamp(index) } - override fun notNullValueToDB(value: Any): Any = when { - value is Instant && currentDialect is SQLiteDialect -> + override fun notNullValueToDB(value: Instant): Any = when (currentDialect) { + is SQLiteDialect -> SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value) - value is Instant -> - java.sql.Timestamp.from(value) - else -> value + else -> java.sql.Timestamp.from(value) } - override fun nonNullValueAsDefaultString(value: Any): String = when (value) { - is Instant -> { - val dialect = currentDialect - when { - dialect is PostgreSQLDialect -> - "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value).trimEnd('0').trimEnd('.')}'::timestamp without time zone" - dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> - "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value).trimEnd('0').trimEnd('.')}'" - else -> super.nonNullValueAsDefaultString(value) - } + override fun nonNullValueAsDefaultString(value: Instant): String { + val dialect = currentDialect + return when { + dialect is PostgreSQLDialect -> + "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value).trimEnd('0').trimEnd('.')}'::timestamp without time zone" + dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> + "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value).trimEnd('0').trimEnd('.')}'" + else -> super.nonNullValueAsDefaultString(value) } - else -> super.nonNullValueAsDefaultString(value) } companion object { @@ -392,21 +342,16 @@ class JavaInstantColumnType : ColumnType(), IDateColumnType { * * @sample org.jetbrains.exposed.sql.javatime.timestampWithTimeZone */ -class JavaOffsetDateTimeColumnType : ColumnType(), IDateColumnType { +class JavaOffsetDateTimeColumnType : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = true override fun sqlType(): String = currentDialect.dataTypeProvider.timestampWithTimeZoneType() - override fun nonNullValueToString(value: Any): String = when (value) { - is OffsetDateTime -> { - when (currentDialect) { - is SQLiteDialect -> "'${value.format(SQLITE_OFFSET_DATE_TIME_FORMATTER)}'" - is MysqlDialect -> "'${value.format(MYSQL_OFFSET_DATE_TIME_FORMATTER)}'" - is OracleDialect -> "'${value.format(ORACLE_OFFSET_DATE_TIME_FORMATTER)}'" - else -> "'${value.format(DEFAULT_OFFSET_DATE_TIME_FORMATTER)}'" - } - } - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") + override fun nonNullValueToString(value: OffsetDateTime): String = when (currentDialect) { + is SQLiteDialect -> "'${value.format(SQLITE_OFFSET_DATE_TIME_FORMATTER)}'" + is MysqlDialect -> "'${value.format(MYSQL_OFFSET_DATE_TIME_FORMATTER)}'" + is OracleDialect -> "'${value.format(ORACLE_OFFSET_DATE_TIME_FORMATTER)}'" + else -> "'${value.format(DEFAULT_OFFSET_DATE_TIME_FORMATTER)}'" } override fun valueFromDB(value: Any): OffsetDateTime = when (value) { @@ -426,30 +371,22 @@ class JavaOffsetDateTimeColumnType : ColumnType(), IDateColumnType { else -> rs.getObject(index, OffsetDateTime::class.java) } - override fun notNullValueToDB(value: Any): Any = when (value) { - is OffsetDateTime -> { - when (currentDialect) { - is SQLiteDialect -> value.format(SQLITE_OFFSET_DATE_TIME_FORMATTER) - is MysqlDialect -> value.format(MYSQL_OFFSET_DATE_TIME_FORMATTER) - else -> value - } - } - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") + override fun notNullValueToDB(value: OffsetDateTime): Any = when (currentDialect) { + is SQLiteDialect -> value.format(SQLITE_OFFSET_DATE_TIME_FORMATTER) + is MysqlDialect -> value.format(MYSQL_OFFSET_DATE_TIME_FORMATTER) + else -> value } - override fun nonNullValueAsDefaultString(value: Any): String = when (value) { - is OffsetDateTime -> { - val dialect = currentDialect - when { - dialect is PostgreSQLDialect -> // +00 appended because PostgreSQL stores it in UTC time zone - "'${value.format(POSTGRESQL_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER)}+00'::timestamp with time zone" - dialect is H2Dialect && dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> - "'${value.format(POSTGRESQL_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER)}'" - dialect is MysqlDialect -> "'${value.format(MYSQL_FRACTION_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER)}'" - else -> super.nonNullValueAsDefaultString(value) - } + override fun nonNullValueAsDefaultString(value: OffsetDateTime): String { + val dialect = currentDialect + return when { + dialect is PostgreSQLDialect -> // +00 appended because PostgreSQL stores it in UTC time zone + "'${value.format(POSTGRESQL_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER)}+00'::timestamp with time zone" + dialect is H2Dialect && dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> + "'${value.format(POSTGRESQL_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER)}'" + dialect is MysqlDialect -> "'${value.format(MYSQL_FRACTION_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER)}'" + else -> super.nonNullValueAsDefaultString(value) } - else -> super.nonNullValueAsDefaultString(value) } companion object { @@ -462,20 +399,10 @@ class JavaOffsetDateTimeColumnType : ColumnType(), IDateColumnType { * * @sample org.jetbrains.exposed.sql.javatime.duration */ -class JavaDurationColumnType : ColumnType() { +class JavaDurationColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.longType() - override fun nonNullValueToString(value: Any): String { - val duration = when (value) { - is String -> return value - is Duration -> value - is Long -> Duration.ofNanos(value) - is Number -> Duration.ofNanos(value.toLong()) - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") - } - - return "'${duration.toNanos()}'" - } + override fun nonNullValueToString(value: Duration): String = "'${value.toNanos()}'" override fun valueFromDB(value: Any): Duration = when (value) { is Long -> Duration.ofNanos(value) @@ -489,12 +416,7 @@ class JavaDurationColumnType : ColumnType() { return rs.getLong(index).takeIf { rs.getObject(index) != null } } - override fun notNullValueToDB(value: Any): Any { - if (value is Duration) { - return value.toNanos() - } - return value - } + override fun notNullValueToDB(value: Duration): Any = value.toNanos() companion object { internal val INSTANCE = JavaDurationColumnType() diff --git a/exposed-java-time/src/test/kotlin/org/jetbrains/exposed/DefaultsTest.kt b/exposed-java-time/src/test/kotlin/org/jetbrains/exposed/DefaultsTest.kt index f0f57ea6e7..9493fb994c 100644 --- a/exposed-java-time/src/test/kotlin/org/jetbrains/exposed/DefaultsTest.kt +++ b/exposed-java-time/src/test/kotlin/org/jetbrains/exposed/DefaultsTest.kt @@ -295,7 +295,7 @@ class DefaultsTest : DatabaseTestsBase() { fun abs(value: Int) = object : ExpressionWithColumnType() { override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder { append("ABS($value)") } - override val columnType: IColumnType = IntegerColumnType() + override val columnType: IColumnType = IntegerColumnType() } val foo = object : IntIdTable("foo") { diff --git a/exposed-java-time/src/test/kotlin/org/jetbrains/exposed/sqlserver/SQLServerDefaultsTest.kt b/exposed-java-time/src/test/kotlin/org/jetbrains/exposed/sqlserver/SQLServerDefaultsTest.kt index c27b256c56..4280906c77 100644 --- a/exposed-java-time/src/test/kotlin/org/jetbrains/exposed/sqlserver/SQLServerDefaultsTest.kt +++ b/exposed-java-time/src/test/kotlin/org/jetbrains/exposed/sqlserver/SQLServerDefaultsTest.kt @@ -16,7 +16,7 @@ class SQLServerDefaultsTest : DatabaseTestsBase() { fun testDefaultExpressionsForTemporalTable() { fun databaseGeneratedTimestamp() = object : ExpressionWithColumnType() { override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder { +"DEFAULT" } - override val columnType: IColumnType = JavaLocalDateTimeColumnType() + override val columnType: IColumnType = JavaLocalDateTimeColumnType() } val temporalTable = object : UUIDTable("TemporalTable") { diff --git a/exposed-jdbc/src/main/kotlin/org/jetbrains/exposed/sql/statements/jdbc/JdbcPreparedStatementImpl.kt b/exposed-jdbc/src/main/kotlin/org/jetbrains/exposed/sql/statements/jdbc/JdbcPreparedStatementImpl.kt index f094ad5591..cf9e340aba 100644 --- a/exposed-jdbc/src/main/kotlin/org/jetbrains/exposed/sql/statements/jdbc/JdbcPreparedStatementImpl.kt +++ b/exposed-jdbc/src/main/kotlin/org/jetbrains/exposed/sql/statements/jdbc/JdbcPreparedStatementImpl.kt @@ -71,7 +71,7 @@ class JdbcPreparedStatementImpl( statement.setObject(index, value) } - override fun setNull(index: Int, columnType: IColumnType) { + override fun setNull(index: Int, columnType: IColumnType<*>) { if (columnType is BinaryColumnType || (columnType is BlobColumnType && !columnType.useObjectIdentifier)) { statement.setNull(index, Types.LONGVARBINARY) } else { diff --git a/exposed-jodatime/src/main/kotlin/org/jetbrains/exposed/sql/jodatime/DateColumnType.kt b/exposed-jodatime/src/main/kotlin/org/jetbrains/exposed/sql/jodatime/DateColumnType.kt index ff7f61ba12..b85c1fb84d 100644 --- a/exposed-jodatime/src/main/kotlin/org/jetbrains/exposed/sql/jodatime/DateColumnType.kt +++ b/exposed-jodatime/src/main/kotlin/org/jetbrains/exposed/sql/jodatime/DateColumnType.kt @@ -55,7 +55,7 @@ private fun dateTimeWithFractionFormat(fraction: Int): DateTimeFormatter { * * @sample org.jetbrains.exposed.sql.jodatime.datetime */ -class DateColumnType(val time: Boolean) : ColumnType(), IDateColumnType { +class DateColumnType(val time: Boolean) : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = time override fun sqlType(): String = if (time) { currentDialect.dataTypeProvider.dateTimeType() @@ -63,28 +63,19 @@ class DateColumnType(val time: Boolean) : ColumnType(), IDateColumnType { currentDialect.dataTypeProvider.dateType() } - override fun nonNullValueToString(value: Any): String { - if (value is String) return value - - val dateTime = when (value) { - is DateTime -> value - is java.sql.Date -> DateTime(value.time) - is java.sql.Timestamp -> DateTime(value.time) - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") - } - + override fun nonNullValueToString(value: DateTime): String { return if (time) { when { (currentDialect as? MysqlDialect)?.isFractionDateTimeSupported() == false -> - "'${MYSQL_DATE_TIME_STRING_FORMATTER.print(dateTime.toDateTime(DateTimeZone.getDefault()))}'" - else -> "'${DEFAULT_DATE_TIME_STRING_FORMATTER.print(dateTime.toDateTime(DateTimeZone.getDefault()))}'" + "'${MYSQL_DATE_TIME_STRING_FORMATTER.print(value.toDateTime(DateTimeZone.getDefault()))}'" + else -> "'${DEFAULT_DATE_TIME_STRING_FORMATTER.print(value.toDateTime(DateTimeZone.getDefault()))}'" } } else { - "'${DEFAULT_DATE_STRING_FORMATTER.print(dateTime)}'" + "'${DEFAULT_DATE_STRING_FORMATTER.print(value)}'" } } - override fun valueFromDB(value: Any): Any { + override fun valueFromDB(value: Any): DateTime? { val dateTime = when (value) { is DateTime -> value is java.sql.Date -> DateTime(value.time) @@ -118,37 +109,33 @@ class DateColumnType(val time: Boolean) : ColumnType(), IDateColumnType { } } - override fun notNullValueToDB(value: Any): Any { + override fun notNullValueToDB(value: DateTime): Any { val dialect = currentDialect return when { - value is DateTime && time && dialect is SQLiteDialect -> SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.print(value) - value is DateTime && time -> java.sql.Timestamp(value.millis) - value is DateTime && dialect is SQLiteDialect -> DEFAULT_DATE_STRING_FORMATTER.print(value) - value is DateTime -> java.sql.Date(value.millis) - else -> value + time && dialect is SQLiteDialect -> SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.print(value) + time -> java.sql.Timestamp(value.millis) + dialect is SQLiteDialect -> DEFAULT_DATE_STRING_FORMATTER.print(value) + else -> java.sql.Date(value.millis) } } - override fun nonNullValueAsDefaultString(value: Any): String = when (value) { - is DateTime -> { - val dialect = currentDialect - when { - dialect is PostgreSQLDialect -> { - if (time) { - "'${DEFAULT_DATE_TIME_STRING_FORMATTER.print(value).trimEnd('0').trimEnd('.')}'::timestamp without time zone" - } else { - "'${DEFAULT_DATE_STRING_FORMATTER.print(value)}'::date" - } - } - time && (dialect as? H2Dialect)?.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> - "'${DEFAULT_DATE_TIME_STRING_FORMATTER.print(value).trimEnd('0').trimEnd('.')}'" - time && dialect is MysqlDialect && dialect.isFractionDateTimeSupported() -> { - "'${MYSQL_FRACTION_DATE_TIME_AS_DEFAULT_FORMATTER.print(value)}'" + override fun nonNullValueAsDefaultString(value: DateTime): String { + val dialect = currentDialect + return when { + dialect is PostgreSQLDialect -> { + if (time) { + "'${DEFAULT_DATE_TIME_STRING_FORMATTER.print(value).trimEnd('0').trimEnd('.')}'::timestamp without time zone" + } else { + "'${DEFAULT_DATE_STRING_FORMATTER.print(value)}'::date" } - else -> super.nonNullValueAsDefaultString(value) } + time && (dialect as? H2Dialect)?.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> + "'${DEFAULT_DATE_TIME_STRING_FORMATTER.print(value).trimEnd('0').trimEnd('.')}'" + time && dialect is MysqlDialect && dialect.isFractionDateTimeSupported() -> { + "'${MYSQL_FRACTION_DATE_TIME_AS_DEFAULT_FORMATTER.print(value)}'" + } + else -> super.nonNullValueAsDefaultString(value) } - else -> super.nonNullValueAsDefaultString(value) } override fun equals(other: Any?): Boolean { @@ -173,21 +160,16 @@ class DateColumnType(val time: Boolean) : ColumnType(), IDateColumnType { * * @sample org.jetbrains.exposed.sql.jodatime.timestampWithTimeZone */ -class DateTimeWithTimeZoneColumnType : ColumnType(), IDateColumnType { +class DateTimeWithTimeZoneColumnType : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = true override fun sqlType(): String = currentDialect.dataTypeProvider.timestampWithTimeZoneType() - override fun nonNullValueToString(value: Any): String = when (value) { - is DateTime -> { - when (currentDialect) { - is SQLiteDialect -> "'${SQLITE_AND_MYSQL_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value)}'" - is MysqlDialect -> "'${SQLITE_AND_MYSQL_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value)}'" - is OracleDialect -> "'${ORACLE_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value)}'" - else -> "'${DEFAULT_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value)}'" - } - } - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") + override fun nonNullValueToString(value: DateTime): String = when (currentDialect) { + is SQLiteDialect -> "'${SQLITE_AND_MYSQL_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value)}'" + is MysqlDialect -> "'${SQLITE_AND_MYSQL_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value)}'" + is OracleDialect -> "'${ORACLE_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value)}'" + else -> "'${DEFAULT_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value)}'" } override fun valueFromDB(value: Any): DateTime = when (value) { @@ -209,30 +191,22 @@ class DateTimeWithTimeZoneColumnType : ColumnType(), IDateColumnType { else -> rs.getObject(index, java.time.OffsetDateTime::class.java) } - override fun notNullValueToDB(value: Any): Any = when (value) { - is DateTime -> { - when (currentDialect) { - is SQLiteDialect -> SQLITE_AND_MYSQL_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value) - is MysqlDialect -> SQLITE_AND_MYSQL_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value) - else -> java.sql.Timestamp(value.millis) - } - } - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") + override fun notNullValueToDB(value: DateTime): Any = when (currentDialect) { + is SQLiteDialect -> SQLITE_AND_MYSQL_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value) + is MysqlDialect -> SQLITE_AND_MYSQL_DATE_TIME_WITH_TIME_ZONE_FORMATTER.print(value) + else -> java.sql.Timestamp(value.millis) } - override fun nonNullValueAsDefaultString(value: Any): String = when (value) { - is DateTime -> { - val dialect = currentDialect - when { - dialect is PostgreSQLDialect -> - "'${DEFAULT_DATE_TIME_STRING_FORMATTER.print(value).trimEnd('0')}+00'::timestamp with time zone" - (dialect as? H2Dialect)?.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> - "'${DEFAULT_DATE_TIME_STRING_FORMATTER.print(value).trimEnd('0')}'" - dialect is MysqlDialect -> "'${MYSQL_FRACTION_DATE_TIME_AS_DEFAULT_FORMATTER.print(value)}'" - else -> super.nonNullValueAsDefaultString(value) - } + override fun nonNullValueAsDefaultString(value: DateTime): String { + val dialect = currentDialect + return when { + dialect is PostgreSQLDialect -> + "'${DEFAULT_DATE_TIME_STRING_FORMATTER.print(value).trimEnd('0')}+00'::timestamp with time zone" + (dialect as? H2Dialect)?.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> + "'${DEFAULT_DATE_TIME_STRING_FORMATTER.print(value).trimEnd('0')}'" + dialect is MysqlDialect -> "'${MYSQL_FRACTION_DATE_TIME_AS_DEFAULT_FORMATTER.print(value)}'" + else -> super.nonNullValueAsDefaultString(value) } - else -> super.nonNullValueAsDefaultString(value) } } diff --git a/exposed-jodatime/src/test/kotlin/org/jetbrains/exposed/JodaTimeDefaultsTest.kt b/exposed-jodatime/src/test/kotlin/org/jetbrains/exposed/JodaTimeDefaultsTest.kt index 5d54bdaa3d..d448b4471e 100644 --- a/exposed-jodatime/src/test/kotlin/org/jetbrains/exposed/JodaTimeDefaultsTest.kt +++ b/exposed-jodatime/src/test/kotlin/org/jetbrains/exposed/JodaTimeDefaultsTest.kt @@ -225,7 +225,7 @@ class JodaTimeDefaultsTest : JodaTimeBaseTest() { fun abs(value: Int) = object : ExpressionWithColumnType() { override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder { append("ABS($value)") } - override val columnType: IColumnType = IntegerColumnType() + override val columnType: IColumnType = IntegerColumnType() } val foo = object : IntIdTable("foo") { diff --git a/exposed-json/src/main/kotlin/org/jetbrains/exposed/sql/json/JsonColumnType.kt b/exposed-json/src/main/kotlin/org/jetbrains/exposed/sql/json/JsonColumnType.kt index ed3c3c1299..028b987be9 100644 --- a/exposed-json/src/main/kotlin/org/jetbrains/exposed/sql/json/JsonColumnType.kt +++ b/exposed-json/src/main/kotlin/org/jetbrains/exposed/sql/json/JsonColumnType.kt @@ -23,12 +23,12 @@ open class JsonColumnType( val serialize: (T) -> String, /** Decode a JSON String to an object of type [T]. */ val deserialize: (String) -> T -) : ColumnType(), JsonColumnMarker { +) : ColumnType(), JsonColumnMarker { override val usesBinaryFormat: Boolean = false override fun sqlType(): String = currentDialect.dataTypeProvider.jsonType() - override fun valueFromDB(value: Any): Any { + override fun valueFromDB(value: Any): T { return when { currentDialect is PostgreSQLDialect && value is PGobject -> deserialize(value.value!!) value is String -> deserialize(value) @@ -38,9 +38,9 @@ open class JsonColumnType( } @Suppress("UNCHECKED_CAST") - override fun notNullValueToDB(value: Any) = serialize(value as T) + override fun notNullValueToDB(value: T): Any = serialize(value as T) - override fun valueToString(value: Any?): String = when (value) { + override fun valueToString(value: T): String = when (value) { is Iterable<*> -> nonNullValueToString(value) else -> super.valueToString(value) } diff --git a/exposed-json/src/main/kotlin/org/jetbrains/exposed/sql/json/JsonFunctions.kt b/exposed-json/src/main/kotlin/org/jetbrains/exposed/sql/json/JsonFunctions.kt index 143ea819ed..5b98ec5e90 100644 --- a/exposed-json/src/main/kotlin/org/jetbrains/exposed/sql/json/JsonFunctions.kt +++ b/exposed-json/src/main/kotlin/org/jetbrains/exposed/sql/json/JsonFunctions.kt @@ -22,8 +22,8 @@ class Extract( /** Whether the extracted result should be a scalar or text value; if `false`, result will be a JSON object. */ val toScalar: Boolean, /** The column type of [expression] to check, if casting to JSONB is required. */ - val jsonType: IColumnType, - columnType: IColumnType + val jsonType: IColumnType<*>, + columnType: IColumnType ) : Function(columnType) { override fun toQueryBuilder(queryBuilder: QueryBuilder) = currentDialect.functionProvider.jsonExtract(expression, path = path, toScalar, jsonType, queryBuilder) diff --git a/exposed-kotlin-datetime/src/main/kotlin/org/jetbrains/exposed/sql/kotlin/datetime/KotlinDateColumnType.kt b/exposed-kotlin-datetime/src/main/kotlin/org/jetbrains/exposed/sql/kotlin/datetime/KotlinDateColumnType.kt index 894aa6222a..122050e70e 100644 --- a/exposed-kotlin-datetime/src/main/kotlin/org/jetbrains/exposed/sql/kotlin/datetime/KotlinDateColumnType.kt +++ b/exposed-kotlin-datetime/src/main/kotlin/org/jetbrains/exposed/sql/kotlin/datetime/KotlinDateColumnType.kt @@ -125,43 +125,32 @@ private val LocalDate.millis get() = this.atStartOfDayIn(TimeZone.currentSystemD * * @sample org.jetbrains.exposed.sql.kotlin.datetime.date */ -class KotlinLocalDateColumnType : ColumnType(), IDateColumnType { +class KotlinLocalDateColumnType : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = false override fun sqlType(): String = currentDialect.dataTypeProvider.dateType() - override fun nonNullValueToString(value: Any): String { - val instant = when (value) { - is String -> return value - is LocalDate -> Instant.fromEpochMilliseconds(value.atStartOfDayIn(DEFAULT_TIME_ZONE).toEpochMilliseconds()) - is java.sql.Date -> Instant.fromEpochMilliseconds(value.time) - is java.sql.Timestamp -> Instant.fromEpochSeconds(value.time / MILLIS_IN_SECOND, value.nanos.toLong()) - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") - } - + override fun nonNullValueToString(value: LocalDate): String { + val instant = Instant.fromEpochMilliseconds(value.atStartOfDayIn(DEFAULT_TIME_ZONE).toEpochMilliseconds()) return "'${DEFAULT_DATE_STRING_FORMATTER.format(instant.toJavaInstant())}'" } - override fun valueFromDB(value: Any): Any = when (value) { + override fun valueFromDB(value: Any): LocalDate = when (value) { is LocalDate -> value is java.sql.Date -> longToLocalDate(value.time) is java.sql.Timestamp -> longToLocalDate(value.time) is Int -> longToLocalDate(value.toLong()) is Long -> longToLocalDate(value) - is String -> when (currentDialect) { - is SQLiteDialect -> LocalDate.parse(value) - else -> value - } + is String -> LocalDate.parse(value) else -> LocalDate.parse(value.toString()) } - override fun notNullValueToDB(value: Any) = when { - value is LocalDate && currentDialect is SQLiteDialect -> DEFAULT_DATE_STRING_FORMATTER.format(value.toJavaLocalDate()) - value is LocalDate -> java.sql.Date(value.millis) - else -> value + override fun notNullValueToDB(value: LocalDate) = when { + currentDialect is SQLiteDialect -> DEFAULT_DATE_STRING_FORMATTER.format(value.toJavaLocalDate()) + else -> java.sql.Date(value.millis) } - override fun nonNullValueAsDefaultString(value: Any): String = when (currentDialect) { + override fun nonNullValueAsDefaultString(value: LocalDate): String = when (currentDialect) { is PostgreSQLDialect -> "${nonNullValueToString(value)}::date" else -> super.nonNullValueAsDefaultString(value) } @@ -178,19 +167,13 @@ class KotlinLocalDateColumnType : ColumnType(), IDateColumnType { * * @sample org.jetbrains.exposed.sql.kotlin.datetime.datetime */ -class KotlinLocalDateTimeColumnType : ColumnType(), IDateColumnType { +class KotlinLocalDateTimeColumnType : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = true override fun sqlType(): String = currentDialect.dataTypeProvider.dateTimeType() - override fun nonNullValueToString(value: Any): String { - val instant = when (value) { - is String -> return value - is LocalDateTime -> value.toInstant(DEFAULT_TIME_ZONE) - is java.sql.Date -> Instant.fromEpochMilliseconds(value.time) - is java.sql.Timestamp -> Instant.fromEpochSeconds(value.time / MILLIS_IN_SECOND, value.nanos.toLong()) - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") - } + override fun nonNullValueToString(value: LocalDateTime): String { + val instant = value.toInstant(DEFAULT_TIME_ZONE) val dialect = currentDialect return when { @@ -205,7 +188,7 @@ class KotlinLocalDateTimeColumnType : ColumnType(), IDateColumnType { } } - override fun valueFromDB(value: Any): Any = when (value) { + override fun valueFromDB(value: Any): LocalDateTime = when (value) { is LocalDateTime -> value is java.sql.Date -> longToLocalDateTime(value.time) is java.sql.Timestamp -> longToLocalDateTime(value.time / MILLIS_IN_SECOND, value.nanos.toLong()) @@ -217,14 +200,13 @@ class KotlinLocalDateTimeColumnType : ColumnType(), IDateColumnType { else -> valueFromDB(value.toString()) } - override fun notNullValueToDB(value: Any): Any = when { - value is LocalDateTime && currentDialect is SQLiteDialect -> + override fun notNullValueToDB(value: LocalDateTime): Any = when { + currentDialect is SQLiteDialect -> SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value.toJavaLocalDateTime().atZone(ZoneId.systemDefault())) - value is LocalDateTime -> { + else -> { val instant = value.toJavaLocalDateTime().atZone(ZoneId.systemDefault()).toInstant() java.sql.Timestamp(instant.toEpochMilli()).apply { nanos = instant.nano } } - else -> value } override fun readObject(rs: ResultSet, index: Int): Any? { @@ -235,19 +217,16 @@ class KotlinLocalDateTimeColumnType : ColumnType(), IDateColumnType { } } - override fun nonNullValueAsDefaultString(value: Any): String = when (value) { - is LocalDateTime -> { - val instant = value.toInstant(DEFAULT_TIME_ZONE).toJavaInstant() - val dialect = currentDialect - when { - dialect is PostgreSQLDialect -> - "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(instant).trimEnd('0').trimEnd('.')}'::timestamp without time zone" - dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> - "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(instant).trimEnd('0').trimEnd('.')}'" - else -> super.nonNullValueAsDefaultString(value) - } + override fun nonNullValueAsDefaultString(value: LocalDateTime): String { + val instant = value.toInstant(DEFAULT_TIME_ZONE).toJavaInstant() + val dialect = currentDialect + return when { + dialect is PostgreSQLDialect -> + "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(instant).trimEnd('0').trimEnd('.')}'::timestamp without time zone" + dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> + "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(instant).trimEnd('0').trimEnd('.')}'" + else -> super.nonNullValueAsDefaultString(value) } - else -> super.nonNullValueAsDefaultString(value) } private fun longToLocalDateTime(millis: Long) = Instant.fromEpochMilliseconds(millis).toLocalDateTime(DEFAULT_TIME_ZONE) @@ -263,12 +242,12 @@ class KotlinLocalDateTimeColumnType : ColumnType(), IDateColumnType { * * @sample org.jetbrains.exposed.sql.kotlin.datetime.time */ -class KotlinLocalTimeColumnType : ColumnType(), IDateColumnType { +class KotlinLocalTimeColumnType : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = true override fun sqlType(): String = currentDialect.dataTypeProvider.timeType() - override fun nonNullValueToString(value: Any): String { + override fun nonNullValueToString(value: LocalTime): String { val dialect = currentDialect val formatter = if (dialect is OracleDialect || dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle) { ORACLE_TIME_STRING_FORMATTER @@ -276,13 +255,7 @@ class KotlinLocalTimeColumnType : ColumnType(), IDateColumnType { DEFAULT_TIME_STRING_FORMATTER } - val instant = when (value) { - is String -> return value - is LocalTime -> value.toJavaLocalTime() - is java.sql.Time -> Instant.fromEpochMilliseconds(value.time).toJavaInstant() - is java.sql.Timestamp -> Instant.fromEpochSeconds(value.time / MILLIS_IN_SECOND, value.nanos.toLong()).toJavaInstant() - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") - } + val instant = value.toJavaLocalTime() return "'${formatter.format(instant)}'" } @@ -304,19 +277,11 @@ class KotlinLocalTimeColumnType : ColumnType(), IDateColumnType { else -> valueFromDB(value.toString()) } - override fun notNullValueToDB(value: Any): Any = when (value) { - is LocalTime -> java.sql.Time.valueOf(value.toJavaLocalTime()) - else -> value - } + override fun notNullValueToDB(value: LocalTime): Any = java.sql.Time.valueOf(value.toJavaLocalTime()) - override fun nonNullValueAsDefaultString(value: Any): String = when (value) { - is LocalTime -> { - when (currentDialect) { - is PostgreSQLDialect -> "${nonNullValueToString(value)}::time without time zone" - is MysqlDialect -> "'${MYSQL_TIME_AS_DEFAULT_STRING_FORMATTER.format(value.toJavaLocalTime())}'" - else -> super.nonNullValueAsDefaultString(value) - } - } + override fun nonNullValueAsDefaultString(value: LocalTime): String = when (currentDialect) { + is PostgreSQLDialect -> "${nonNullValueToString(value)}::time without time zone" + is MysqlDialect -> "'${MYSQL_TIME_AS_DEFAULT_STRING_FORMATTER.format(value.toJavaLocalTime())}'" else -> super.nonNullValueAsDefaultString(value) } @@ -332,18 +297,13 @@ class KotlinLocalTimeColumnType : ColumnType(), IDateColumnType { * * @sample org.jetbrains.exposed.sql.kotlin.datetime.timestamp */ -class KotlinInstantColumnType : ColumnType(), IDateColumnType { +class KotlinInstantColumnType : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = true override fun sqlType(): String = currentDialect.dataTypeProvider.dateTimeType() - override fun nonNullValueToString(value: Any): String { - val instant = when (value) { - is String -> return value - is Instant -> value.toJavaInstant() - is java.sql.Timestamp -> value.toInstant() - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") - } + override fun nonNullValueToString(value: Instant): String { + val instant = value.toJavaInstant() val dialect = currentDialect return when { @@ -367,26 +327,21 @@ class KotlinInstantColumnType : ColumnType(), IDateColumnType { return rs.getTimestamp(index) } - override fun notNullValueToDB(value: Any): Any = when { - value is Instant && currentDialect is SQLiteDialect -> + override fun notNullValueToDB(value: Instant): Any = when { + currentDialect is SQLiteDialect -> SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value.toJavaInstant()) - value is Instant -> - java.sql.Timestamp.from(value.toJavaInstant()) - else -> value + else -> java.sql.Timestamp.from(value.toJavaInstant()) } - override fun nonNullValueAsDefaultString(value: Any): String = when (value) { - is Instant -> { - val dialect = currentDialect - when { - dialect is PostgreSQLDialect -> - "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value.toJavaInstant()).trimEnd('0').trimEnd('.')}'::timestamp without time zone" - dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> - "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value.toJavaInstant()).trimEnd('0').trimEnd('.')}'" - else -> super.nonNullValueAsDefaultString(value) - } + override fun nonNullValueAsDefaultString(value: Instant): String { + val dialect = currentDialect + return when { + dialect is PostgreSQLDialect -> + "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value.toJavaInstant()).trimEnd('0').trimEnd('.')}'::timestamp without time zone" + dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> + "'${SQLITE_AND_ORACLE_DATE_TIME_STRING_FORMATTER.format(value.toJavaInstant()).trimEnd('0').trimEnd('.')}'" + else -> super.nonNullValueAsDefaultString(value) } - else -> super.nonNullValueAsDefaultString(value) } companion object { @@ -399,21 +354,16 @@ class KotlinInstantColumnType : ColumnType(), IDateColumnType { * * @sample org.jetbrains.exposed.sql.kotlin.datetime.timestampWithTimeZone */ -class KotlinOffsetDateTimeColumnType : ColumnType(), IDateColumnType { +class KotlinOffsetDateTimeColumnType : ColumnType(), IDateColumnType { override val hasTimePart: Boolean = true override fun sqlType(): String = currentDialect.dataTypeProvider.timestampWithTimeZoneType() - override fun nonNullValueToString(value: Any): String = when (value) { - is OffsetDateTime -> { - when (currentDialect) { - is SQLiteDialect -> "'${value.format(SQLITE_OFFSET_DATE_TIME_FORMATTER)}'" - is MysqlDialect -> "'${value.format(MYSQL_OFFSET_DATE_TIME_FORMATTER)}'" - is OracleDialect -> "'${value.format(ORACLE_OFFSET_DATE_TIME_FORMATTER)}'" - else -> "'${value.format(DEFAULT_OFFSET_DATE_TIME_FORMATTER)}'" - } - } - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") + override fun nonNullValueToString(value: OffsetDateTime): String = when (currentDialect) { + is SQLiteDialect -> "'${value.format(SQLITE_OFFSET_DATE_TIME_FORMATTER)}'" + is MysqlDialect -> "'${value.format(MYSQL_OFFSET_DATE_TIME_FORMATTER)}'" + is OracleDialect -> "'${value.format(ORACLE_OFFSET_DATE_TIME_FORMATTER)}'" + else -> "'${value.format(DEFAULT_OFFSET_DATE_TIME_FORMATTER)}'" } override fun valueFromDB(value: Any): OffsetDateTime = when (value) { @@ -433,30 +383,22 @@ class KotlinOffsetDateTimeColumnType : ColumnType(), IDateColumnType { else -> rs.getObject(index, OffsetDateTime::class.java) } - override fun notNullValueToDB(value: Any): Any = when (value) { - is OffsetDateTime -> { - when (currentDialect) { - is SQLiteDialect -> value.format(SQLITE_OFFSET_DATE_TIME_FORMATTER) - is MysqlDialect -> value.format(MYSQL_OFFSET_DATE_TIME_FORMATTER) - else -> value - } - } - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") + override fun notNullValueToDB(value: OffsetDateTime): Any = when (currentDialect) { + is SQLiteDialect -> value.format(SQLITE_OFFSET_DATE_TIME_FORMATTER) + is MysqlDialect -> value.format(MYSQL_OFFSET_DATE_TIME_FORMATTER) + else -> value } - override fun nonNullValueAsDefaultString(value: Any): String = when (value) { - is OffsetDateTime -> { - val dialect = currentDialect - when { - dialect is PostgreSQLDialect -> // +00 appended because PostgreSQL stores it in UTC time zone - "'${value.format(POSTGRESQL_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER).trimEnd('0').trimEnd('.')}+00'::timestamp with time zone" - dialect is H2Dialect && dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> - "'${value.format(POSTGRESQL_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER).trimEnd('0').trimEnd('.')}'" - dialect is MysqlDialect -> "'${value.format(MYSQL_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER)}'" - else -> super.nonNullValueAsDefaultString(value) - } + override fun nonNullValueAsDefaultString(value: OffsetDateTime): String { + val dialect = currentDialect + return when { + dialect is PostgreSQLDialect -> // +00 appended because PostgreSQL stores it in UTC time zone + "'${value.format(POSTGRESQL_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER).trimEnd('0').trimEnd('.')}+00'::timestamp with time zone" + dialect is H2Dialect && dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> + "'${value.format(POSTGRESQL_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER).trimEnd('0').trimEnd('.')}'" + dialect is MysqlDialect -> "'${value.format(MYSQL_OFFSET_DATE_TIME_AS_DEFAULT_FORMATTER)}'" + else -> super.nonNullValueAsDefaultString(value) } - else -> super.nonNullValueAsDefaultString(value) } companion object { @@ -469,18 +411,11 @@ class KotlinOffsetDateTimeColumnType : ColumnType(), IDateColumnType { * * @sample org.jetbrains.exposed.sql.kotlin.datetime.duration */ -class KotlinDurationColumnType : ColumnType() { +class KotlinDurationColumnType : ColumnType() { override fun sqlType(): String = currentDialect.dataTypeProvider.longType() - override fun nonNullValueToString(value: Any): String { - val duration = when (value) { - is String -> return value - is Duration -> value.inWholeNanoseconds - is Long -> value - is Number -> value.toLong() - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") - } - + override fun nonNullValueToString(value: Duration): String { + val duration = value.inWholeNanoseconds return "'$duration'" } @@ -497,12 +432,7 @@ class KotlinDurationColumnType : ColumnType() { return rs.getLong(index).takeIf { rs.getObject(index) != null } } - override fun notNullValueToDB(value: Any): Any { - if (value is Duration) { - return value.inWholeNanoseconds - } - return value - } + override fun notNullValueToDB(value: Duration): Any = value.inWholeNanoseconds companion object { internal val INSTANCE = KotlinDurationColumnType() diff --git a/exposed-kotlin-datetime/src/test/kotlin/org/jetbrains/exposed/sql/kotlin/datetime/sqlserver/SQLServerDefaultsTest.kt b/exposed-kotlin-datetime/src/test/kotlin/org/jetbrains/exposed/sql/kotlin/datetime/sqlserver/SQLServerDefaultsTest.kt index 03b58f2bcc..291d2a5d13 100644 --- a/exposed-kotlin-datetime/src/test/kotlin/org/jetbrains/exposed/sql/kotlin/datetime/sqlserver/SQLServerDefaultsTest.kt +++ b/exposed-kotlin-datetime/src/test/kotlin/org/jetbrains/exposed/sql/kotlin/datetime/sqlserver/SQLServerDefaultsTest.kt @@ -17,7 +17,7 @@ class SQLServerDefaultsTest : DatabaseTestsBase() { fun testDefaultExpressionsForTemporalTable() { fun databaseGeneratedTimestamp() = object : ExpressionWithColumnType() { override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder { +"DEFAULT" } - override val columnType: IColumnType = KotlinLocalDateTimeColumnType() + override val columnType: IColumnType = KotlinLocalDateTimeColumnType() } val temporalTable = object : UUIDTable("TemporalTable") { diff --git a/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CurrencyColumnType.kt b/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CurrencyColumnType.kt index 04ac179a0f..dbdd23e5c7 100644 --- a/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CurrencyColumnType.kt +++ b/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CurrencyColumnType.kt @@ -1,8 +1,10 @@ package org.jetbrains.exposed.sql.money import org.jetbrains.exposed.sql.Column +import org.jetbrains.exposed.sql.ColumnType import org.jetbrains.exposed.sql.Table import org.jetbrains.exposed.sql.VarCharColumnType +import org.jetbrains.exposed.sql.vendors.currentDialect import javax.money.CurrencyUnit import javax.money.Monetary @@ -12,23 +14,62 @@ import javax.money.Monetary * @author Vladislav Kisel */ @Suppress("MagicNumber") -class CurrencyColumnType : VarCharColumnType(3) { +class CurrencyColumnType : ColumnType() { - override fun notNullValueToDB(value: Any): Any { - return when (value) { - is String -> value - is CurrencyUnit -> value.currencyCode - else -> error("Unexpected value: $value of ${value::class.qualifiedName}") + override fun sqlType(): String = currentDialect.dataTypeProvider.varcharType(colLength) + + override fun validateValueBeforeUpdate(value: CurrencyUnit?) { + if (value is CurrencyUnit) { + val valueLength = value.currencyCode.codePointCount(0, value.currencyCode.length) + require(valueLength <= colLength) { + "Value can't be stored to database column because exceeds length ($valueLength > $colLength)" + } } } - override fun valueFromDB(value: Any): Any { + override fun valueFromDB(value: Any): CurrencyUnit { return when (value) { is CurrencyUnit -> value is String -> Monetary.getCurrency(value) else -> valueFromDB(value.toString()) } } + + override fun notNullValueToDB(value: CurrencyUnit): Any = value.currencyCode + + override fun nonNullValueToString(value: CurrencyUnit): String = buildString { + append('\'') + append(escape(value.currencyCode)) + append('\'') + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + if (!super.equals(other)) return false + + other as VarCharColumnType + + return colLength == other.colLength + } + + override fun hashCode(): Int { + var result = super.hashCode() + result = 31 * result + colLength + return result + } + + private fun escape(value: String): String = value.map { charactersToEscape[it] ?: it }.joinToString("") + + companion object { + private const val colLength = 3 + + private val charactersToEscape = mapOf( + '\'' to "\'\'", + '\r' to "\\r", + '\n' to "\\n" + ) + } } /** Creates a character column, with the specified [name], for storing currency (as javax.money.CurrencyUnit). */