Move 'create_pre_udf_table' function to warehouse module (#187)
dreadatour authored Jul 31, 2024
1 parent 8f431dd commit b67cb70
Showing 4 changed files with 51 additions and 47 deletions.
21 changes: 21 additions & 0 deletions src/datachain/data_storage/sqlite.py
@@ -42,6 +42,7 @@
from sqlalchemy.dialects.sqlite import Insert
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql.elements import ColumnClause, ColumnElement, TextClause
+    from sqlalchemy.sql.selectable import Select
from sqlalchemy.types import TypeEngine


@@ -705,3 +706,23 @@ def export_dataset_table(
client_config=None,
) -> list[str]:
raise NotImplementedError("Exporting dataset table not implemented for SQLite")

+    def create_pre_udf_table(self, query: "Select") -> "Table":
+        """
+        Create a temporary table from a query for use in a UDF.
+        """
+        columns = [
+            sqlalchemy.Column(c.name, c.type)
+            for c in query.selected_columns
+            if c.name != "sys__id"
+        ]
+        table = self.create_udf_table(columns)
+
+        select_q = query.with_only_columns(
+            *[c for c in query.selected_columns if c.name != "sys__id"]
+        )
+        self.db.execute(
+            table.insert().from_select(list(select_q.selected_columns), select_q)
+        )
+
+        return table
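
For context, here is a minimal, self-contained sketch of the SQLAlchemy pattern the new method relies on: copy a query's columns (minus sys__id) into a fresh table, then fill it with a single INSERT ... FROM SELECT. The in-memory engine and the table and column names below are illustrative, not part of this commit.

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
metadata = sa.MetaData()
src = sa.Table(
    "signals_src",  # illustrative source table
    metadata,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("path", sa.String),
    sa.Column("size", sa.Integer),
)
metadata.create_all(engine)

query = sa.select(*src.c)

# Copy every selected column except sys__id into a new table, so the new
# table's own primary key can regenerate sys__id values on insert.
columns = [
    sa.Column(c.name, c.type)
    for c in query.selected_columns
    if c.name != "sys__id"
]
pre_udf = sa.Table(
    "udf_example",  # illustrative; datachain generates a name via udf_table_name()
    sa.MetaData(),
    sa.Column("sys__id", sa.Integer, primary_key=True),
    *columns,
)
pre_udf.create(engine)

# Populate it in a single INSERT ... FROM SELECT round trip.
select_q = query.with_only_columns(
    *[c for c in query.selected_columns if c.name != "sys__id"]
)
with engine.begin() as conn:
    conn.execute(
        pre_udf.insert().from_select(list(select_q.selected_columns), select_q)
    )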
26 changes: 24 additions & 2 deletions src/datachain/data_storage/warehouse.py
@@ -2,6 +2,8 @@
import json
import logging
import posixpath
+import random
+import string
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable, Iterator, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
@@ -24,6 +26,7 @@
if TYPE_CHECKING:
from sqlalchemy.sql._typing import _ColumnsClauseArgument
from sqlalchemy.sql.elements import ColumnElement
+    from sqlalchemy.sql.selectable import Select
from sqlalchemy.types import TypeEngine

from datachain.data_storage import AbstractIDGenerator, schema
@@ -252,6 +255,12 @@ def dataset_table_name(self, dataset_name: str, version: int) -> str:
prefix = self.DATASET_SOURCE_TABLE_PREFIX
return f"{prefix}{dataset_name}_{version}"

+    def temp_table_name(self) -> str:
+        return self.TMP_TABLE_NAME_PREFIX + _random_string(6)
+
+    def udf_table_name(self) -> str:
+        return self.UDF_TABLE_NAME_PREFIX + _random_string(6)

#
# Datasets
#
@@ -869,23 +878,29 @@ def update_node(self, node_id: int, values: dict[str, Any]) -> None:

def create_udf_table(
self,
-        name: str,
        columns: Sequence["sa.Column"] = (),
+        name: Optional[str] = None,
) -> "sa.Table":
"""
Create a temporary table for storing custom signals generated by a UDF.
SQLite TEMPORARY tables cannot be directly used as they are process-specific,
and UDFs are run in other processes when run in parallel.
"""
tbl = sa.Table(
-            name,
+            name or self.udf_table_name(),
sa.MetaData(),
sa.Column("sys__id", Int, primary_key=True),
*columns,
)
self.db.create_table(tbl, if_not_exists=True)
return tbl

+    @abstractmethod
+    def create_pre_udf_table(self, query: "Select") -> "Table":
+        """
+        Create a temporary table from a query for use in a UDF.
+        """

def is_temp_table_name(self, name: str) -> bool:
"""Returns if the given table name refers to a temporary
or no longer needed table."""
@@ -937,3 +952,10 @@ def changed_query(
& (tq.c.is_latest == true())
)
)


+def _random_string(length: int) -> str:
+    return "".join(
+        random.choice(string.ascii_letters + string.digits)  # noqa: S311
+        for i in range(length)
+    )
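
As a quick orientation, a hedged sketch of the new naming helpers and the reworked create_udf_table signature. It assumes warehouse is a concrete AbstractWarehouse implementation; the sample suffixes and the "udf_custom" name are made up, and the prefixes are assumed to resolve to strings like "tmp_" and "udf_" (as the table names in the test change further down suggest).

import sqlalchemy as sa

# Auto-generated names: a fixed prefix plus a six-character random suffix.
warehouse.temp_table_name()  # e.g. "tmp_x7Kp2Q" (illustrative value)
warehouse.udf_table_name()   # e.g. "udf_Ab3cDe" (illustrative value)

# create_udf_table now names the table itself unless a name is passed explicitly.
auto_named = warehouse.create_udf_table([sa.Column("score", sa.Float)])
explicit = warehouse.create_udf_table([sa.Column("score", sa.Float)], name="udf_custom")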
49 changes: 5 additions & 44 deletions src/datachain/query/dataset.py
@@ -262,9 +262,7 @@ def apply(self, query_generator, temp_tables: list[str]):
temp_tables.extend(self.dq.temp_table_names)

# creating temp table that will hold subtract results
-        temp_table_name = self.catalog.warehouse.TMP_TABLE_NAME_PREFIX + _random_string(
-            6
-        )
+        temp_table_name = self.catalog.warehouse.temp_table_name()
temp_tables.append(temp_table_name)

columns = [
@@ -448,9 +446,6 @@ def create_result_query(
to select
"""

-    def udf_table_name(self) -> str:
-        return self.catalog.warehouse.UDF_TABLE_NAME_PREFIX + _random_string(6)

def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
use_partitioning = self.partition_by is not None
batching = self.udf.properties.get_batching(use_partitioning)
@@ -574,9 +569,7 @@ def create_partitions_table(self, query: Select) -> "Table":
list_partition_by = [self.partition_by]

# create table with partitions
-        tbl = self.catalog.warehouse.create_udf_table(
-            self.udf_table_name(), partition_columns()
-        )
+        tbl = self.catalog.warehouse.create_udf_table(partition_columns())

# fill table with partitions
cols = [
@@ -638,37 +631,12 @@ def create_udf_table(self, query: Select) -> "Table":
for (col_name, col_type) in self.udf.output.items()
]

-        return self.catalog.warehouse.create_udf_table(
-            self.udf_table_name(), udf_output_columns
-        )
-
-    def create_pre_udf_table(self, query: Select) -> "Table":
-        columns = [
-            sqlalchemy.Column(c.name, c.type)
-            for c in query.selected_columns
-            if c.name != "sys__id"
-        ]
-        table = self.catalog.warehouse.create_udf_table(self.udf_table_name(), columns)
-        select_q = query.with_only_columns(
-            *[c for c in query.selected_columns if c.name != "sys__id"]
-        )
-
-        # if there is order by clause we need row_number to preserve order
-        # if there is no order by clause we still need row_number to generate
-        # unique ids as uniqueness is important for this table
-        select_q = select_q.add_columns(
-            f.row_number().over(order_by=select_q._order_by_clauses).label("sys__id")
-        )
-
-        self.catalog.warehouse.db.execute(
-            table.insert().from_select(list(select_q.selected_columns), select_q)
-        )
-        return table
+        return self.catalog.warehouse.create_udf_table(udf_output_columns)

def process_input_query(self, query: Select) -> tuple[Select, list["Table"]]:
if os.getenv("DATACHAIN_DISABLE_QUERY_CACHE", "") not in ("", "0"):
return query, []
-        table = self.create_pre_udf_table(query)
+        table = self.catalog.warehouse.create_pre_udf_table(query)
q: Select = sqlalchemy.select(*table.c)
if query._order_by_clauses:
# we are adding ordering only if it's explicitly added by user in
@@ -732,7 +700,7 @@ class RowGenerator(UDFStep):
def create_udf_table(self, query: Select) -> "Table":
warehouse = self.catalog.warehouse

-        table_name = self.udf_table_name()
+        table_name = self.catalog.warehouse.udf_table_name()
columns: tuple[Column, ...] = tuple(
Column(name, typ) for name, typ in self.udf.output.items()
)
@@ -1802,10 +1770,3 @@ def query_wrapper(dataset_query: DatasetQuery) -> DatasetQuery:

_send_result(dataset_query)
return dataset_query


-def _random_string(length: int) -> str:
-    return "".join(
-        random.choice(string.ascii_letters + string.digits)  # noqa: S311
-        for i in range(length)
-    )
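
Taken together, a hedged sketch of how a UDF step now obtains its tables, with all naming and creation delegated to the warehouse. Here step and query are illustrative stand-ins for a UDFStep instance and its input Select; the variable names are not taken from the diff.

import sqlalchemy

warehouse = step.catalog.warehouse

# Input rows are snapshotted once so parallel UDF workers read a stable table.
pre_udf_table = warehouse.create_pre_udf_table(query)

# Output signals go into an auto-named table; udf_table_name() picks the name.
output_table = warehouse.create_udf_table(
    [sqlalchemy.Column(name, typ) for name, typ in step.udf.output.items()]
)

# Scratch tables (e.g. for subtract results) are also named by the warehouse.
scratch_name = warehouse.temp_table_name()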
2 changes: 1 addition & 1 deletion tests/func/test_catalog.py
@@ -1113,7 +1113,7 @@ def test_garbage_collect(cloud_test_catalog, from_cli, capsys):
assert catalog.get_temp_table_names() == []
temp_tables = ["tmp_vc12F", "udf_jh653", "ds_shadow_12345", "old_ds_shadow"]
for t in temp_tables:
-        catalog.warehouse.create_udf_table(t)
+        catalog.warehouse.create_udf_table(name=t)
assert set(catalog.get_temp_table_names()) == set(temp_tables)
if from_cli:
garbage_collect(catalog)