Skip to content

Commit

Permalink
Move 'join' SQL implementation to warehouse (#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour authored Sep 16, 2024
1 parent 78ee1ba commit 404021f
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 22 deletions.
19 changes: 19 additions & 0 deletions src/datachain/data_storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
from sqlalchemy.dialects.sqlite import Insert
from sqlalchemy.engine.base import Engine
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql._typing import _FromClauseArgument, _OnClauseArgument
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.selectable import Join
from sqlalchemy.types import TypeEngine

from datachain.lib.file import File
Expand Down Expand Up @@ -788,6 +790,23 @@ def copy_table(
if progress_cb:
progress_cb(len(batch_ids))

def join(
    self,
    left: "_FromClauseArgument",
    right: "_FromClauseArgument",
    onclause: "_OnClauseArgument",
    inner: bool = True,
) -> "Join":
    """Build a SQL JOIN clause between two table-like objects.

    An inner join is produced by default; with ``inner=False`` an
    outer join is built instead.
    """
    # Our API speaks in terms of "inner", while SQLAlchemy's flag is
    # the inverse ("isouter") — translate between the two here.
    outer_join = not inner
    return sqlalchemy.join(left, right, onclause, isouter=outer_join)

def create_pre_udf_table(self, query: "Select") -> "Table":
"""
Create a temporary table from a query for use in a UDF.
Expand Down
22 changes: 19 additions & 3 deletions src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@
from datachain.utils import sql_escape_like

if TYPE_CHECKING:
from sqlalchemy.sql._typing import _ColumnsClauseArgument
from sqlalchemy.sql.selectable import Select
from sqlalchemy.sql._typing import (
_ColumnsClauseArgument,
_FromClauseArgument,
_OnClauseArgument,
)
from sqlalchemy.sql.selectable import Join, Select
from sqlalchemy.types import TypeEngine

from datachain.data_storage import AbstractIDGenerator, schema
Expand Down Expand Up @@ -894,6 +898,18 @@ def copy_table(
Copy the results of a query into a table.
"""

@abstractmethod
def join(
    self,
    left: "_FromClauseArgument",
    right: "_FromClauseArgument",
    onclause: "_OnClauseArgument",
    inner: bool = True,
) -> "Join":
    """
    Join two tables together.

    Concrete warehouses build the dialect-specific JOIN clause:
    ``inner=True`` requests an inner join, ``inner=False`` an outer
    join (see e.g. the SQLite implementation, which maps ``inner``
    onto SQLAlchemy's ``isouter`` flag).
    """

@abstractmethod
def create_pre_udf_table(self, query: "Select") -> "Table":
"""
Expand Down Expand Up @@ -922,7 +938,7 @@ def cleanup_tables(self, names: Iterable[str]) -> None:
are cleaned up as soon as they are no longer needed.
"""
with tqdm(desc="Cleanup", unit=" tables") as pbar:
for name in names:
for name in set(names):
self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
pbar.update(1)

Expand Down
57 changes: 38 additions & 19 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from sqlalchemy.sql.expression import label
from sqlalchemy.sql.schema import TableClause
from sqlalchemy.sql.selectable import Select
from tqdm import tqdm

from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog
Expand Down Expand Up @@ -899,12 +898,36 @@ def q(*columns):

@frozen
class SQLJoin(Step):
catalog: "Catalog"
query1: "DatasetQuery"
query2: "DatasetQuery"
predicates: Union[JoinPredicateType, tuple[JoinPredicateType, ...]]
inner: bool
rname: str

def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
    """Return a subquery for one side of the join.

    Plain queries are used directly; queries that themselves contain a
    join or union are first materialized into a temporary table, whose
    name is recorded in ``temp_tables`` for later cleanup.
    """
    query = dq.apply_steps().select()
    temp_tables.extend(dq.temp_table_names)

    has_nested_set_op = any(
        isinstance(step, (SQLJoin, SQLUnion)) for step in dq.steps
    )
    if not has_nested_set_op:
        return query.subquery(dq.table.name)

    warehouse = self.catalog.warehouse

    # Mirror the query's columns, coercing anything that is not
    # already a Column into one so the temp table can be created.
    columns = []
    for col in query.subquery().columns:
        if isinstance(col, Column):
            columns.append(col)
        else:
            columns.append(Column(col.name, col.type))

    temp_table = warehouse.create_dataset_rows_table(
        warehouse.temp_table_name(),
        columns=columns,
    )
    temp_tables.append(temp_table.name)

    warehouse.copy_table(temp_table, query)

    return temp_table.select().subquery(dq.table.name)

def validate_expression(self, exp: "ClauseElement", q1, q2):
"""
Checking if columns used in expression actually exist in left / right
Expand Down Expand Up @@ -937,10 +960,8 @@ def validate_expression(self, exp: "ClauseElement", q1, q2):
def apply(
self, query_generator: QueryGenerator, temp_tables: list[str]
) -> StepResult:
q1 = self.query1.apply_steps().select().subquery(self.query1.table.name)
temp_tables.extend(self.query1.temp_table_names)
q2 = self.query2.apply_steps().select().subquery(self.query2.table.name)
temp_tables.extend(self.query2.temp_table_names)
q1 = self.get_query(self.query1, temp_tables)
q2 = self.get_query(self.query2, temp_tables)

q1_columns = list(q1.c)
q1_column_names = {c.name for c in q1_columns}
Expand All @@ -951,7 +972,12 @@ def apply(
continue

if c.name in q1_column_names:
c = c.label(self.rname.format(name=c.name))
new_name = self.rname.format(name=c.name)
new_name_idx = 0
while new_name in q1_column_names:
new_name_idx += 1
new_name = self.rname.format(name=f"{c.name}_{new_name_idx}")
c = c.label(new_name)
q2_columns.append(c)

res_columns = q1_columns + q2_columns
Expand Down Expand Up @@ -979,16 +1005,14 @@ def apply(
self.validate_expression(join_expression, q1, q2)

def q(*columns):
join_query = sqlalchemy.join(
join_query = self.catalog.warehouse.join(
q1,
q2,
join_expression,
isouter=not self.inner,
inner=self.inner,
)

res = sqlalchemy.select(*columns).select_from(join_query)
subquery = res.subquery()
return sqlalchemy.select(*subquery.c).select_from(subquery)
return sqlalchemy.select(*columns).select_from(join_query)
# return sqlalchemy.select(*subquery.c).select_from(subquery)

return step_result(
q,
Expand Down Expand Up @@ -1511,7 +1535,7 @@ def join(
if isinstance(predicates, (str, ColumnClause, ColumnElement))
else tuple(predicates)
)
new_query.steps = [SQLJoin(left, right, predicates, inner, rname)]
new_query.steps = [SQLJoin(self.catalog, left, right, predicates, inner, rname)]
return new_query

@detach
Expand Down Expand Up @@ -1687,12 +1711,7 @@ def save(

dr = self.catalog.warehouse.dataset_rows(dataset)

with tqdm(desc="Saving", unit=" rows") as pbar:
self.catalog.warehouse.copy_table(
dr.get_table(),
query.select(),
progress_cb=pbar.update,
)
self.catalog.warehouse.copy_table(dr.get_table(), query.select())

self.catalog.metastore.update_dataset_status(
dataset, DatasetStatus.COMPLETE, version=version
Expand Down
117 changes: 117 additions & 0 deletions tests/func/test_dataset_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,123 @@ def test_union(cloud_test_catalog):
assert count == 6


@pytest.mark.parametrize(
    "cloud_type,version_aware",
    [("s3", True)],
    indirect=True,
)
@pytest.mark.parametrize("inner", [True, False])
def test_union_join(cloud_test_catalog, inner):
    """A union of two datasets can be joined against another query."""
    catalog = cloud_test_catalog.catalog
    catalog.index([str(cloud_test_catalog.src_uri)])

    src = cloud_test_catalog.src_uri
    catalog.create_dataset_from_sources("dogs", [f"{src}/dogs/*"], recursive=True)
    catalog.create_dataset_from_sources("cats", [f"{src}/cats/*"], recursive=True)

    dogs = DatasetQuery(name="dogs", version=1, catalog=catalog)
    cats = DatasetQuery(name="cats", version=1, catalog=catalog)

    # Value filled in for signals missing on the outer side of the join.
    default = Int.default_value(catalog.warehouse.db.dialect)

    @udf((), {"sig1": Int})
    def signals1():
        return (1,)

    @udf((), {"sig2": Int})
    def signals2():
        return (2,)

    dogs_sig1 = dogs.add_signals(signals1)
    dogs_sig2 = dogs.add_signals(signals2)
    cats_sig1 = cats.add_signals(signals1)

    joined = (dogs_sig1 | cats_sig1).join(dogs_sig2, C.path, inner=inner)
    signals = list(joined.select("path", "sig1", "sig2").order_by("path"))

    dog_rows = [
        ("dogs/dog1", 1, 2),
        ("dogs/dog2", 1, 2),
        ("dogs/dog3", 1, 2),
        ("dogs/others/dog4", 1, 2),
    ]
    cat_rows = [
        ("cats/cat1", 1, default),
        ("cats/cat2", 1, default),
    ]

    # Inner join drops the cats (no match on the right); outer keeps them.
    expected = dog_rows if inner else cat_rows + dog_rows
    assert signals == expected


@pytest.mark.parametrize(
    "cloud_type,version_aware",
    [("s3", True)],
    indirect=True,
)
@pytest.mark.parametrize("inner1", [True, False])
@pytest.mark.parametrize("inner2", [True, False])
@pytest.mark.parametrize("inner3", [True, False])
def test_multiple_join(cloud_test_catalog, inner1, inner2, inner3):
    """Chained joins over a union produce correct rows for every
    inner/outer combination."""
    catalog = cloud_test_catalog.catalog
    catalog.index([str(cloud_test_catalog.src_uri)])

    src = cloud_test_catalog.src_uri
    catalog.create_dataset_from_sources("dogs", [f"{src}/dogs/*"], recursive=True)
    catalog.create_dataset_from_sources("cats", [f"{src}/cats/*"], recursive=True)

    dogs = DatasetQuery(name="dogs", version=1, catalog=catalog)
    cats = DatasetQuery(name="cats", version=1, catalog=catalog)

    # Value filled in for signals missing on the outer side of a join.
    default = Int.default_value(catalog.warehouse.db.dialect)

    @udf((), {"sig1": Int})
    def signals1():
        return (1,)

    @udf((), {"sig2": Int})
    def signals2():
        return (2,)

    dogs_and_cats = dogs | cats
    dogs_with_sig = dogs.add_signals(signals1)
    cats_with_sig = cats.add_signals(signals2)
    joined_dogs = dogs_and_cats.join(dogs_with_sig, C.path, inner=inner1)
    joined_cats = dogs_and_cats.join(cats_with_sig, C.path, inner=inner2)
    joined = joined_dogs.join(joined_cats, C.path, inner=inner3)

    joined_signals = list(joined.select("path", "sig1", "sig2").order_by("path"))

    cat_rows = [
        ("cats/cat1", default, 2),
        ("cats/cat2", default, 2),
    ]
    dog_rows = [
        ("dogs/dog1", 1, default),
        ("dogs/dog2", 1, default),
        ("dogs/dog3", 1, default),
        ("dogs/others/dog4", 1, default),
    ]

    if inner1 and inner2 and inner3:
        # All-inner chain finds no row carrying both signals.
        assert joined_signals == []
    elif inner1:
        assert joined_signals == dog_rows
    elif inner2 and inner3:
        assert joined_signals == cat_rows
    else:
        assert joined_signals == cat_rows + dog_rows


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
Expand Down

0 comments on commit 404021f

Please sign in to comment.