diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt index 1e4172071..bd623046f 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt @@ -15,7 +15,7 @@ import kotlin.reflect.KType * * NOTE: All date and timestamp-related types are converted to String to avoid java.sql.* types. */ -public class H2(public val dialect: DbType = MySql) : DbType("h2") { +public open class H2(public val dialect: DbType = MySql) : DbType("h2") { init { require(dialect::class != H2::class) { "H2 database could not be specified with H2 dialect!" } } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index 75a101e8e..7cee658f9 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -113,6 +113,7 @@ public data class DbConnectionConfig(val url: String, val user: String = "", val * @param [tableName] the name of the table to read data from. * @param [limit] the maximum number of rows to retrieve from the table. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the DataFrame containing the data from the SQL table. */ public fun DataFrame.Companion.readSqlTable( @@ -120,9 +121,10 @@ public fun DataFrame.Companion.readSqlTable( tableName: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readSqlTable(connection, tableName, limit, inferNullability) + return readSqlTable(connection, tableName, limit, inferNullability, dbType) } } @@ -133,6 +135,7 @@ public fun DataFrame.Companion.readSqlTable( * @param [tableName] the name of the table to read data from. * @param [limit] the maximum number of rows to retrieve from the table. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the DataFrame containing the data from the SQL table. * * @see DriverManager.getConnection @@ -142,12 +145,13 @@ public fun DataFrame.Companion.readSqlTable( tableName: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame { val url = connection.metaData.url - val dbType = extractDBTypeFromConnection(connection) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) val selectAllQuery = if (limit > 0) { - dbType.sqlQueryLimit("SELECT * FROM $tableName", limit) + determinedDbType.sqlQueryLimit("SELECT * FROM $tableName", limit) } else { "SELECT * FROM $tableName" } @@ -157,7 +161,7 @@ public fun DataFrame.Companion.readSqlTable( st.executeQuery(selectAllQuery).use { rs -> val tableColumns = getTableColumnsMetadata(rs) - return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability) + return fetchAndConvertDataFromResultSet(tableColumns, rs, determinedDbType, limit, inferNullability) } } } @@ -172,6 +176,7 @@ public fun DataFrame.Companion.readSqlTable( * @param [sqlQuery] the SQL query to execute. * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the DataFrame containing the result of the SQL query. */ public fun DataFrame.Companion.readSqlQuery( @@ -179,9 +184,10 @@ public fun DataFrame.Companion.readSqlQuery( sqlQuery: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readSqlQuery(connection, sqlQuery, limit, inferNullability) + return readSqlQuery(connection, sqlQuery, limit, inferNullability, dbType) } } @@ -195,6 +201,7 @@ public fun DataFrame.Companion.readSqlQuery( * @param [sqlQuery] the SQL query to execute. * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the DataFrame containing the result of the SQL query. * * @see DriverManager.getConnection @@ -204,22 +211,23 @@ public fun DataFrame.Companion.readSqlQuery( sqlQuery: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame { require(isValid(sqlQuery)) { "SQL query should start from SELECT and contain one query for reading data without any manipulation. " + "Also it should not contain any separators like `;`." } - val dbType = extractDBTypeFromConnection(connection) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) - val internalSqlQuery = if (limit > 0) dbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery + val internalSqlQuery = if (limit > 0) determinedDbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery logger.debug { "Executing SQL query: $internalSqlQuery" } connection.createStatement().use { st -> st.executeQuery(internalSqlQuery).use { rs -> val tableColumns = getTableColumnsMetadata(rs) - return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability) + return fetchAndConvertDataFromResultSet(tableColumns, rs, determinedDbType, limit, inferNullability) } } } @@ -233,12 +241,14 @@ public fun DataFrame.Companion.readSqlQuery( * It should not contain `;` symbol. * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the DataFrame containing the result of the SQL query. */ public fun DbConnectionConfig.readDataFrame( sqlQueryOrTableName: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame = when { isSqlQuery(sqlQueryOrTableName) -> DataFrame.readSqlQuery( @@ -246,6 +256,7 @@ public fun DbConnectionConfig.readDataFrame( sqlQueryOrTableName, limit, inferNullability, + dbType, ) isSqlTableName(sqlQueryOrTableName) -> DataFrame.readSqlTable( @@ -253,6 +264,7 @@ public fun DbConnectionConfig.readDataFrame( sqlQueryOrTableName, limit, inferNullability, + dbType, ) else -> throw IllegalArgumentException( @@ -280,12 +292,14 @@ private fun isSqlTableName(sqlQueryOrTableName: String): Boolean { * It should not contain `;` symbol. * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the DataFrame containing the result of the SQL query. */ public fun Connection.readDataFrame( sqlQueryOrTableName: String, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame = when { isSqlQuery(sqlQueryOrTableName) -> DataFrame.readSqlQuery( @@ -293,6 +307,7 @@ public fun Connection.readDataFrame( sqlQueryOrTableName, limit, inferNullability, + dbType, ) isSqlTableName(sqlQueryOrTableName) -> DataFrame.readSqlTable( @@ -300,6 +315,7 @@ public fun Connection.readDataFrame( sqlQueryOrTableName, limit, inferNullability, + dbType, ) else -> throw IllegalArgumentException( @@ -386,6 +402,7 @@ public fun ResultSet.readDataFrame( * that the [ResultSet] belongs to. * @param [limit] the maximum number of rows to read from the [ResultSet][java.sql.ResultSet]. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the DataFrame generated from the [ResultSet][java.sql.ResultSet] data. * * [java.sql.ResultSet]: https://docs.oracle.com/javase/8/docs/api/java/sql/ResultSet.html @@ -395,10 +412,11 @@ public fun DataFrame.Companion.readResultSet( connection: Connection, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): AnyFrame { - val dbType = extractDBTypeFromConnection(connection) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) - return readResultSet(resultSet, dbType, limit, inferNullability) + return readResultSet(resultSet, determinedDbType, limit, inferNullability) } /** @@ -416,6 +434,7 @@ public fun DataFrame.Companion.readResultSet( * that the [ResultSet] belongs to. * @param [limit] the maximum number of rows to read from the [ResultSet][java.sql.ResultSet]. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the DataFrame generated from the [ResultSet][java.sql.ResultSet] data. * * [java.sql.ResultSet]: https://docs.oracle.com/javase/8/docs/api/java/sql/ResultSet.html @@ -424,7 +443,8 @@ public fun ResultSet.readDataFrame( connection: Connection, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, -): AnyFrame = DataFrame.Companion.readResultSet(this, connection, limit, inferNullability) + dbType: DbType? = null, +): AnyFrame = DataFrame.Companion.readResultSet(this, connection, limit, inferNullability, dbType) /** * Reads all non-system tables from a database and returns them @@ -434,6 +454,7 @@ public fun ResultSet.readDataFrame( * @param [limit] the maximum number of rows to read from each table. * @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database. */ public fun DataFrame.Companion.readAllSqlTables( @@ -441,9 +462,10 @@ public fun DataFrame.Companion.readAllSqlTables( catalogue: String? = null, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): Map { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readAllSqlTables(connection, catalogue, limit, inferNullability) + return readAllSqlTables(connection, catalogue, limit, inferNullability, dbType) } } @@ -455,6 +477,7 @@ public fun DataFrame.Companion.readAllSqlTables( * @param [limit] the maximum number of rows to read from each table. * @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs. * @param [inferNullability] indicates how the column nullability should be inferred. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database. * * @see DriverManager.getConnection @@ -464,9 +487,10 @@ public fun DataFrame.Companion.readAllSqlTables( catalogue: String? = null, limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, + dbType: DbType? = null, ): Map { val metaData = connection.metaData - val dbType = extractDBTypeFromConnection(connection) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) // exclude a system and other tables without data, but it looks like it is supported badly for many databases val tables = metaData.getTables(catalogue, null, null, arrayOf("TABLE")) @@ -474,8 +498,8 @@ public fun DataFrame.Companion.readAllSqlTables( val dataFrames = mutableMapOf() while (tables.next()) { - val table = dbType.buildTableMetadata(tables) - if (!dbType.isSystemTable(table)) { + val table = determinedDbType.buildTableMetadata(tables) + if (!determinedDbType.isSystemTable(table)) { // we filter here a second time because of specific logic with SQLite and possible issues with future databases val tableName = when { catalogue != null && table.schemaName != null -> "$catalogue.${table.schemaName}.${table.name}" @@ -488,7 +512,7 @@ public fun DataFrame.Companion.readAllSqlTables( // could be Dialect/Database specific logger.debug { "Reading table: $tableName" } - val dataFrame = readSqlTable(connection, tableName, limit, inferNullability) + val dataFrame = readSqlTable(connection, tableName, limit, inferNullability, dbType) dataFrames += tableName to dataFrame logger.debug { "Finished reading table: $tableName" } } @@ -502,11 +526,12 @@ public fun DataFrame.Companion.readAllSqlTables( * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. * @param [tableName] the name of the SQL table for which to retrieve the schema. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the [DataFrameSchema] object representing the schema of the SQL table */ -public fun DataFrame.Companion.getSchemaForSqlTable(dbConfig: DbConnectionConfig, tableName: String): DataFrameSchema { +public fun DataFrame.Companion.getSchemaForSqlTable(dbConfig: DbConnectionConfig, tableName: String, dbType: DbType? = null,): DataFrameSchema { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return getSchemaForSqlTable(connection, tableName) + return getSchemaForSqlTable(connection, tableName, dbType) } } @@ -515,20 +540,21 @@ public fun DataFrame.Companion.getSchemaForSqlTable(dbConfig: DbConnectionConfig * * @param [connection] the database connection. * @param [tableName] the name of the SQL table for which to retrieve the schema. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the schema of the SQL table as a [DataFrameSchema] object. * * @see DriverManager.getConnection */ -public fun DataFrame.Companion.getSchemaForSqlTable(connection: Connection, tableName: String): DataFrameSchema { - val dbType = extractDBTypeFromConnection(connection) +public fun DataFrame.Companion.getSchemaForSqlTable(connection: Connection, tableName: String, dbType: DbType? = null,): DataFrameSchema { + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) val sqlQuery = "SELECT * FROM $tableName" - val selectFirstRowQuery = dbType.sqlQueryLimit(sqlQuery, limit = 1) + val selectFirstRowQuery = determinedDbType.sqlQueryLimit(sqlQuery, limit = 1) connection.createStatement().use { st -> st.executeQuery(selectFirstRowQuery).use { rs -> val tableColumns = getTableColumnsMetadata(rs) - return buildSchemaByTableColumns(tableColumns, dbType) + return buildSchemaByTableColumns(tableColumns, determinedDbType) } } } @@ -538,11 +564,12 @@ public fun DataFrame.Companion.getSchemaForSqlTable(connection: Connection, tabl * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. * @param [sqlQuery] the SQL query to execute and retrieve the schema from. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the schema of the SQL query as a [DataFrameSchema] object. */ -public fun DataFrame.Companion.getSchemaForSqlQuery(dbConfig: DbConnectionConfig, sqlQuery: String): DataFrameSchema { +public fun DataFrame.Companion.getSchemaForSqlQuery(dbConfig: DbConnectionConfig, sqlQuery: String, dbType: DbType? = null,): DataFrameSchema { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return getSchemaForSqlQuery(connection, sqlQuery) + return getSchemaForSqlQuery(connection, sqlQuery, dbType) } } @@ -551,17 +578,18 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(dbConfig: DbConnectionConfig * * @param [connection] the database connection. * @param [sqlQuery] the SQL query to execute and retrieve the schema from. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the schema of the SQL query as a [DataFrameSchema] object. * * @see DriverManager.getConnection */ -public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQuery: String): DataFrameSchema { - val dbType = extractDBTypeFromConnection(connection) +public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQuery: String, dbType: DbType? = null,): DataFrameSchema { + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) connection.createStatement().use { st -> st.executeQuery(sqlQuery).use { rs -> val tableColumns = getTableColumnsMetadata(rs) - return buildSchemaByTableColumns(tableColumns, dbType) + return buildSchemaByTableColumns(tableColumns, determinedDbType) } } } @@ -570,13 +598,14 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQ * Retrieves the schema of an SQL query result or the SQL table using the provided database configuration. * * @param [sqlQueryOrTableName] the SQL query to execute and retrieve the schema from. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the schema of the SQL query as a [DataFrameSchema] object. */ -public fun DbConnectionConfig.getDataFrameSchema(sqlQueryOrTableName: String): DataFrameSchema = +public fun DbConnectionConfig.getDataFrameSchema(sqlQueryOrTableName: String, dbType: DbType? = null,): DataFrameSchema = when { - isSqlQuery(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName) + isSqlQuery(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName, dbType) - isSqlTableName(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName) + isSqlTableName(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName, dbType) else -> throw IllegalArgumentException( "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!", @@ -587,13 +616,14 @@ public fun DbConnectionConfig.getDataFrameSchema(sqlQueryOrTableName: String): D * Retrieves the schema of an SQL query result or the SQL table using the provided database configuration. * * @param [sqlQueryOrTableName] the SQL query to execute and retrieve the schema from. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return the schema of the SQL query as a [DataFrameSchema] object. */ -public fun Connection.getDataFrameSchema(sqlQueryOrTableName: String): DataFrameSchema = +public fun Connection.getDataFrameSchema(sqlQueryOrTableName: String, dbType: DbType? = null,): DataFrameSchema = when { - isSqlQuery(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName) + isSqlQuery(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName, dbType) - isSqlTableName(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName) + isSqlTableName(sqlQueryOrTableName) -> DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName, dbType) else -> throw IllegalArgumentException( "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!", @@ -606,7 +636,7 @@ public fun Connection.getDataFrameSchema(sqlQueryOrTableName: String): DataFrame * NOTE: This function will not close connection and result set and not retrieve data from the result set. * * @param [resultSet] the [ResultSet] obtained from executing a database query. - * @param [dbType] the type of database that the [ResultSet] belongs to. + * @param [dbType] the type of database that the [ResultSet] belongs to, could be a custom object, provided by user. * @return the schema of the [ResultSet] as a [DataFrameSchema] object. */ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbType: DbType): DataFrameSchema { @@ -619,48 +649,21 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp * * NOTE: This function will not close connection and result set and not retrieve data from the result set. * - * @param [dbType] the type of database that the [ResultSet] belongs to. + * @param [dbType] the type of database that the [ResultSet] belongs to, could be a custom object, provided by user. * @return the schema of the [ResultSet] as a [DataFrameSchema] object. */ public fun ResultSet.getDataFrameSchema(dbType: DbType): DataFrameSchema = DataFrame.getSchemaForResultSet(this, dbType) -/** - * Retrieves the schema from [ResultSet]. - * - * NOTE: [connection] is required to extract the database type. - * This function will not close connection and result set and not retrieve data from the result set. - * - * @param [resultSet] the [ResultSet] obtained from executing a database query. - * @param [connection] the connection to the database (it's required to extract the database type). - * @return the schema of the [ResultSet] as a [DataFrameSchema] object. - */ -public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, connection: Connection): DataFrameSchema { - val dbType = extractDBTypeFromConnection(connection) - - val tableColumns = getTableColumnsMetadata(resultSet) - return buildSchemaByTableColumns(tableColumns, dbType) -} - -/** - * Retrieves the schema from [ResultSet]. - * - * NOTE: This function will not close connection and result set and not retrieve data from the result set. - * - * @param [connection] the connection to the database (it's required to extract the database type). - * @return the schema of the [ResultSet] as a [DataFrameSchema] object. - */ -public fun ResultSet.getDataFrameSchema(connection: Connection): DataFrameSchema = - DataFrame.getSchemaForResultSet(this, connection) - /** * Retrieves the schemas of all non-system tables in the database using the provided database configuration. * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return a map of [String, DataFrameSchema] objects representing the table name and its schema for each non-system table. */ -public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DbConnectionConfig): Map { +public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DbConnectionConfig, dbType: DbType? = null,): Map { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return getSchemaForAllSqlTables(connection) + return getSchemaForAllSqlTables(connection, dbType) } } @@ -668,11 +671,12 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DbConnectionCo * Retrieves the schemas of all non-system tables in the database using the provided database connection. * * @param [connection] the database connection. + * @param [dbType] the type of database, could be a custom object, provided by user, optional, default is `null`. * @return a map of [String, DataFrameSchema] objects representing the table name and its schema for each non-system table. */ -public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): Map { +public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection, dbType: DbType? = null,): Map { val metaData = connection.metaData - val dbType = extractDBTypeFromConnection(connection) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) val tableTypes = arrayOf("TABLE") // exclude a system and other tables without data @@ -681,11 +685,11 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): val dataFrameSchemas = mutableMapOf() while (tables.next()) { - val jdbcTable = dbType.buildTableMetadata(tables) - if (!dbType.isSystemTable(jdbcTable)) { + val jdbcTable = determinedDbType.buildTableMetadata(tables) + if (!determinedDbType.isSystemTable(jdbcTable)) { // we filter her a second time because of specific logic with SQLite and possible issues with future databases val tableName = jdbcTable.name - val dataFrameSchema = getSchemaForSqlTable(connection, tableName) + val dataFrameSchema = getSchemaForSqlTable(connection, tableName, determinedDbType) dataFrameSchemas += tableName to dataFrameSchema } } @@ -826,6 +830,7 @@ private fun fetchAndConvertDataFromResultSet( } val dataFrame = data.mapIndexed { index, values -> + // TODO: add override handlers from dbType to intercept the final parcing before column creation val correctedValues = if (kotlinTypesForSqlColumns[index]!!.classifier == Array::class) { handleArrayValues(values) } else { diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt index c83d59158..5b60ce827 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt @@ -166,6 +166,8 @@ class JdbcTest { val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName) dataSchema.columns.size shouldBe 2 dataSchema.columns["characterCol"]!!.type shouldBe typeOf() + + connection.createStatement().execute("DROP TABLE EmptyTestTable") } @Test @@ -299,6 +301,8 @@ class JdbcTest { schema.columns["realCol"]!!.type shouldBe typeOf() schema.columns["doublePrecisionCol"]!!.type shouldBe typeOf() schema.columns["decFloatCol"]!!.type shouldBe typeOf() + + connection.createStatement().execute("DROP TABLE $tableName") } @Test @@ -441,7 +445,7 @@ class JdbcTest { rs.beforeFirst() - val dataSchema1 = DataFrame.getSchemaForResultSet(rs, connection) + val dataSchema1 = DataFrame.getSchemaForResultSet(rs, H2(MySql)) dataSchema1.columns.size shouldBe 3 dataSchema1.columns["name"]!!.type shouldBe typeOf() } @@ -493,7 +497,7 @@ class JdbcTest { rs.beforeFirst() - val dataSchema1 = rs.getDataFrameSchema(connection) + val dataSchema1 = rs.getDataFrameSchema(H2(MySql)) dataSchema1.columns.size shouldBe 3 dataSchema1.columns["name"]!!.type shouldBe typeOf() } @@ -613,6 +617,7 @@ class JdbcTest { """ DataFrame.readSqlQuery(connection, selectFromWeirdTableSQL).rowsCount() shouldBe 0 + connection.createStatement().execute("DROP TABLE \"ALTER\"") } @Test @@ -967,4 +972,127 @@ class JdbcTest { } exception.message shouldBe "H2 database could not be specified with H2 dialect!" } + + // helper object created for API testing purposes + object CustomDB: H2(MySql) + + @Test + fun `read from table from custom database`() { + val tableName = "Customer" + val df = DataFrame.readSqlTable(connection, tableName, dbType = CustomDB).cast() + + df.rowsCount() shouldBe 4 + df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 + df[0][1] shouldBe "John" + + val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName, dbType = CustomDB) + dataSchema.columns.size shouldBe 3 + dataSchema.columns["name"]!!.type shouldBe typeOf() + + val dbConfig = DbConnectionConfig(url = URL) + val df2 = DataFrame.readSqlTable(dbConfig, tableName, dbType = CustomDB).cast() + + df2.rowsCount() shouldBe 4 + df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 + df2[0][1] shouldBe "John" + + val dataSchema1 = DataFrame.getSchemaForSqlTable(dbConfig, tableName, dbType = CustomDB) + dataSchema1.columns.size shouldBe 3 + dataSchema1.columns["name"]!!.type shouldBe typeOf() + } + + @Test + fun `read from query from custom database`() { + @Language("SQL") + val sqlQuery = + """ + SELECT c.name as customerName, SUM(s.amount) as totalSalesAmount + FROM Sale s + INNER JOIN Customer c ON s.customerId = c.id + WHERE c.age > 35 + GROUP BY s.customerId, c.name + """.trimIndent() + + val df = DataFrame.readSqlQuery(connection, sqlQuery, dbType = CustomDB).cast() + + df.rowsCount() shouldBe 2 + df.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 + df[0][0] shouldBe "John" + + val dataSchema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery, dbType = CustomDB) + dataSchema.columns.size shouldBe 2 + dataSchema.columns["name"]!!.type shouldBe typeOf() + + val dbConfig = DbConnectionConfig(url = URL) + val df2 = DataFrame.readSqlQuery(dbConfig, sqlQuery, dbType = CustomDB).cast() + + df2.rowsCount() shouldBe 2 + df2.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 + df2[0][0] shouldBe "John" + + val dataSchema1 = DataFrame.getSchemaForSqlQuery(dbConfig, sqlQuery, dbType = CustomDB) + dataSchema1.columns.size shouldBe 2 + dataSchema1.columns["name"]!!.type shouldBe typeOf() + } + + @Test + fun `read from all tables from custom database`() { + val dataFrameMap = DataFrame.readAllSqlTables(connection, dbType = CustomDB) + dataFrameMap.containsKey("Customer") shouldBe true + dataFrameMap.containsKey("Sale") shouldBe true + + val dataframes = dataFrameMap.values.toList() + + val customerDf = dataframes[0].cast() + + customerDf.rowsCount() shouldBe 4 + customerDf.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 + customerDf[0][1] shouldBe "John" + + val saleDf = dataframes[1].cast() + + saleDf.rowsCount() shouldBe 4 + saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 + (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 + + val dataFrameSchemaMap = DataFrame.getSchemaForAllSqlTables(connection, dbType = CustomDB) + dataFrameSchemaMap.containsKey("Customer") shouldBe true + dataFrameSchemaMap.containsKey("Sale") shouldBe true + + val dataSchemas = dataFrameSchemaMap.values.toList() + + val customerDataSchema = dataSchemas[0] + customerDataSchema.columns.size shouldBe 3 + customerDataSchema.columns["name"]!!.type shouldBe typeOf() + + val saleDataSchema = dataSchemas[1] + saleDataSchema.columns.size shouldBe 3 + // TODO: fix nullability + saleDataSchema.columns["amount"]!!.type shouldBe typeOf() + + val dbConfig = DbConnectionConfig(url = URL) + val dataframes2 = DataFrame.readAllSqlTables(dbConfig, dbType = CustomDB).values.toList() + + val customerDf2 = dataframes2[0].cast() + + customerDf2.rowsCount() shouldBe 4 + customerDf2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 + customerDf2[0][1] shouldBe "John" + + val saleDf2 = dataframes2[1].cast() + + saleDf2.rowsCount() shouldBe 4 + saleDf2.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 + (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 + + val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig, dbType = CustomDB).values.toList() + + val customerDataSchema1 = dataSchemas1[0] + customerDataSchema1.columns.size shouldBe 3 + customerDataSchema1.columns["name"]!!.type shouldBe typeOf() + + val saleDataSchema1 = dataSchemas1[1] + saleDataSchema1.columns.size shouldBe 3 + saleDataSchema1.columns["amount"]!!.type shouldBe typeOf() + } }