diff --git a/src/datachain/data_storage/sqlite.py b/src/datachain/data_storage/sqlite.py
index 4b17eddb5..a280552d0 100644
--- a/src/datachain/data_storage/sqlite.py
+++ b/src/datachain/data_storage/sqlite.py
@@ -42,6 +42,7 @@
     from sqlalchemy.dialects.sqlite import Insert
     from sqlalchemy.schema import SchemaItem
     from sqlalchemy.sql.elements import ColumnClause, ColumnElement, TextClause
+    from sqlalchemy.sql.selectable import Select
     from sqlalchemy.types import TypeEngine
@@ -705,3 +706,23 @@ def export_dataset_table(
         client_config=None,
     ) -> list[str]:
         raise NotImplementedError("Exporting dataset table not implemented for SQLite")
+
+    def create_pre_udf_table(self, query: "Select") -> "Table":
+        """
+        Create a temporary table from a query for use in a UDF.
+        """
+        columns = [
+            sqlalchemy.Column(c.name, c.type)
+            for c in query.selected_columns
+            if c.name != "sys__id"
+        ]
+        table = self.create_udf_table(columns)
+
+        select_q = query.with_only_columns(
+            *[c for c in query.selected_columns if c.name != "sys__id"]
+        )
+        self.db.execute(
+            table.insert().from_select(list(select_q.selected_columns), select_q)
+        )
+
+        return table
diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py
index 652344a9c..b8ccb43e9 100644
--- a/src/datachain/data_storage/warehouse.py
+++ b/src/datachain/data_storage/warehouse.py
@@ -2,6 +2,8 @@
 import json
 import logging
 import posixpath
+import random
+import string
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Iterator, Sequence
 from typing import TYPE_CHECKING, Any, Optional, Union
@@ -24,6 +26,7 @@
 if TYPE_CHECKING:
     from sqlalchemy.sql._typing import _ColumnsClauseArgument
     from sqlalchemy.sql.elements import ColumnElement
+    from sqlalchemy.sql.selectable import Select
     from sqlalchemy.types import TypeEngine
 
     from datachain.data_storage import AbstractIDGenerator, schema
@@ -252,6 +255,12 @@ def dataset_table_name(self, dataset_name: str, version: int) -> str:
         prefix = self.DATASET_SOURCE_TABLE_PREFIX
         return f"{prefix}{dataset_name}_{version}"
 
+    def temp_table_name(self) -> str:
+        return self.TMP_TABLE_NAME_PREFIX + _random_string(6)
+
+    def udf_table_name(self) -> str:
+        return self.UDF_TABLE_NAME_PREFIX + _random_string(6)
+
     #
     # Datasets
     #
@@ -869,8 +878,8 @@ def update_node(self, node_id: int, values: dict[str, Any]) -> None:
 
     def create_udf_table(
         self,
-        name: str,
         columns: Sequence["sa.Column"] = (),
+        name: Optional[str] = None,
     ) -> "sa.Table":
         """
         Create a temporary table for storing custom signals generated by a UDF.
@@ -878,7 +887,7 @@ def create_udf_table(
         and UDFs are run in other processes when run in parallel.
         """
         tbl = sa.Table(
-            name,
+            name or self.udf_table_name(),
             sa.MetaData(),
             sa.Column("sys__id", Int, primary_key=True),
             *columns,
@@ -886,6 +895,12 @@ def create_udf_table(
         self.db.create_table(tbl, if_not_exists=True)
         return tbl
 
+    @abstractmethod
+    def create_pre_udf_table(self, query: "Select") -> "Table":
+        """
+        Create a temporary table from a query for use in a UDF.
+        """
+
     def is_temp_table_name(self, name: str) -> bool:
         """Returns if the given table name refers to a temporary
         or no longer needed table."""
@@ -937,3 +952,10 @@ def changed_query(
                 & (tq.c.is_latest == true())
             )
         )
+
+
+def _random_string(length: int) -> str:
+    return "".join(
+        random.choice(string.ascii_letters + string.digits)  # noqa: S311
+        for i in range(length)
+    )
diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py
index 0c99ea637..9a0504408 100644
--- a/src/datachain/query/dataset.py
+++ b/src/datachain/query/dataset.py
@@ -262,9 +262,7 @@ def apply(self, query_generator, temp_tables: list[str]):
         temp_tables.extend(self.dq.temp_table_names)
 
         # creating temp table that will hold subtract results
-        temp_table_name = self.catalog.warehouse.TMP_TABLE_NAME_PREFIX + _random_string(
-            6
-        )
+        temp_table_name = self.catalog.warehouse.temp_table_name()
         temp_tables.append(temp_table_name)
 
         columns = [
@@ -448,9 +446,6 @@ def create_result_query(
         to select
         """
 
-    def udf_table_name(self) -> str:
-        return self.catalog.warehouse.UDF_TABLE_NAME_PREFIX + _random_string(6)
-
     def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
         use_partitioning = self.partition_by is not None
         batching = self.udf.properties.get_batching(use_partitioning)
@@ -574,9 +569,7 @@ def create_partitions_table(self, query: Select) -> "Table":
             list_partition_by = [self.partition_by]
 
         # create table with partitions
-        tbl = self.catalog.warehouse.create_udf_table(
-            self.udf_table_name(), partition_columns()
-        )
+        tbl = self.catalog.warehouse.create_udf_table(partition_columns())
 
         # fill table with partitions
         cols = [
@@ -638,37 +631,12 @@ def create_udf_table(self, query: Select) -> "Table":
             for (col_name, col_type) in self.udf.output.items()
         ]
 
-        return self.catalog.warehouse.create_udf_table(
-            self.udf_table_name(), udf_output_columns
-        )
-
-    def create_pre_udf_table(self, query: Select) -> "Table":
-        columns = [
-            sqlalchemy.Column(c.name, c.type)
-            for c in query.selected_columns
-            if c.name != "sys__id"
-        ]
-        table = self.catalog.warehouse.create_udf_table(self.udf_table_name(), columns)
-        select_q = query.with_only_columns(
-            *[c for c in query.selected_columns if c.name != "sys__id"]
-        )
-
-        # if there is order by clause we need row_number to preserve order
-        # if there is no order by clause we still need row_number to generate
-        # unique ids as uniqueness is important for this table
-        select_q = select_q.add_columns(
-            f.row_number().over(order_by=select_q._order_by_clauses).label("sys__id")
-        )
-
-        self.catalog.warehouse.db.execute(
-            table.insert().from_select(list(select_q.selected_columns), select_q)
-        )
-        return table
+        return self.catalog.warehouse.create_udf_table(udf_output_columns)
 
     def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
         if os.getenv("DATACHAIN_DISABLE_QUERY_CACHE", "") not in ("", "0"):
             return query, []
-        table = self.create_pre_udf_table(query)
+        table = self.catalog.warehouse.create_pre_udf_table(query)
         q: Select = sqlalchemy.select(*table.c)
         if query._order_by_clauses:
             # we are adding ordering only if it's explicitly added by user in
@@ -732,7 +700,7 @@ class RowGenerator(UDFStep):
 
     def create_udf_table(self, query: Select) -> "Table":
         warehouse = self.catalog.warehouse
-        table_name = self.udf_table_name()
+        table_name = self.catalog.warehouse.udf_table_name()
         columns: tuple[Column, ...] = tuple(
             Column(name, typ) for name, typ in self.udf.output.items()
         )
@@ -1802,10 +1770,3 @@ def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:
 
     _send_result(dataset_query)
     return dataset_query
-
-
-def _random_string(length: int) -> str:
-    return "".join(
-        random.choice(string.ascii_letters + string.digits)  # noqa: S311
-        for i in range(length)
-    )
diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py
index 4db8660ca..1ebe85556 100644
--- a/tests/func/test_catalog.py
+++ b/tests/func/test_catalog.py
@@ -1113,7 +1113,7 @@ def test_garbage_collect(cloud_test_catalog, from_cli, capsys):
     assert catalog.get_temp_table_names() == []
     temp_tables = ["tmp_vc12F", "udf_jh653", "ds_shadow_12345", "old_ds_shadow"]
    for t in temp_tables:
-        catalog.warehouse.create_udf_table(t)
+        catalog.warehouse.create_udf_table(name=t)
     assert set(catalog.get_temp_table_names()) == set(temp_tables)
     if from_cli:
         garbage_collect(catalog)
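Taken together, the diff moves temporary-table naming and pre-UDF table creation from the query layer into the warehouse layer: temp_table_name(), udf_table_name(), and create_pre_udf_table() now live on the warehouse, and create_udf_table() takes its columns first with the name optional. Below is a minimal sketch of how the updated interface might be exercised; it assumes a `catalog` with a SQLite warehouse is already available, and the "signal" column is purely illustrative (neither is part of this diff).

import sqlalchemy as sa

# Assumed setup: `catalog` is an initialized DataChain catalog; we only use
# its warehouse here.
warehouse = catalog.warehouse

# The name is now optional; when omitted, a random "udf_<6 chars>" name is
# generated via udf_table_name().
tbl = warehouse.create_udf_table([sa.Column("signal", sa.Float)])
assert tbl.name.startswith(warehouse.UDF_TABLE_NAME_PREFIX)

# An explicit name can still be supplied; since columns now comes first, the
# updated test passes it as a keyword: create_udf_table(name=t).
named = warehouse.create_udf_table([sa.Column("signal", sa.Float)], name="udf_example")

# Pre-UDF input caching also goes through the warehouse: the query's rows
# (minus sys__id) are copied into a fresh temporary table.
cached = warehouse.create_pre_udf_table(sa.select(*tbl.c))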