Pull dataset from Studio if not available locally
If the following conditions are met, this will pull the dataset from Studio:
- The user is logged in to Studio.
- The dataset or the requested version does not exist locally.
- The user has not passed studio=False to from_dataset.

When these conditions hold, the dataset is pulled from Studio before
continuing further.

A test is added to check this behavior.
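
A minimal usage sketch of the new behavior (the dataset name "dogs" is
illustrative; assumes the user is logged in to Studio, e.g. via the
DVC_STUDIO_TOKEN environment variable):

    from datachain.lib.dc import DataChain

    # If "dogs" (or the requested version) is missing locally and a Studio
    # token can be resolved, the dataset is pulled from Studio first.
    chain = DataChain.from_dataset(name="dogs", version=1)

    # Opt out of the fallback: this raises DatasetNotFoundError (or
    # DatasetVersionNotFoundError) if the dataset is not available locally.
    local_only = DataChain.from_dataset(name="dogs", version=1, studio=False)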

Closes #874
amritghimire committed Feb 6, 2025
1 parent d0c5f94 commit 67e76b8
Showing 3 changed files with 85 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/datachain/lib/dc.py
@@ -481,6 +481,7 @@ def from_dataset(
         version: Optional[int] = None,
         session: Optional[Session] = None,
         settings: Optional[dict] = None,
+        studio: bool = True,
     ) -> "Self":
         """Get data from a saved Dataset. It returns the chain itself.

@@ -498,6 +499,7 @@ def from_dataset(
             version=version,
             session=session,
             indexing_column_types=File._datachain_column_types,
+            studio=studio,
         )
         telemetry.send_event_once("class", "datachain_init", name=name, version=version)
         if settings:
48 changes: 43 additions & 5 deletions src/datachain/query/dataset.py
@@ -37,13 +37,18 @@

 from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
 from datachain.catalog.catalog import clone_catalog_with_cache
+from datachain.config import Config
 from datachain.data_storage.schema import (
     PARTITION_COLUMN_ID,
     partition_col_names,
     partition_columns,
 )
-from datachain.dataset import DatasetStatus, RowDict
-from datachain.error import DatasetNotFoundError, QueryScriptCancelError
+from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict
+from datachain.error import (
+    DatasetNotFoundError,
+    DatasetVersionNotFoundError,
+    QueryScriptCancelError,
+)
 from datachain.func.base import Function
 from datachain.lib.udf import UDFAdapter, _get_cache
 from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
@@ -1081,6 +1086,7 @@ def __init__(
         session: Optional[Session] = None,
         indexing_column_types: Optional[dict[str, Any]] = None,
         in_memory: bool = False,
+        studio: bool = True,
     ) -> None:
         self.session = Session.get(session, catalog=catalog, in_memory=in_memory)
         self.catalog = catalog or self.session.catalog
@@ -1097,9 +1103,26 @@ def __init__(
         self.column_types: Optional[dict[str, Any]] = None

         self.name = name
-        ds = self.catalog.get_dataset(name)
-        self.version = version or ds.latest_version
-        self.feature_schema = ds.get_version(self.version).feature_schema
+        try:
+            ds = self.catalog.get_dataset(name)
+            self.version = version or ds.latest_version
+            self.feature_schema = ds.get_version(self.version).feature_schema
+
+        except (DatasetNotFoundError, DatasetVersionNotFoundError):
+            if not studio:
+                raise
+
+            token = os.environ.get("DVC_STUDIO_TOKEN") or Config().read().get(
+                "studio", {}
+            ).get("token")
+            if not token:
+                raise
+
+            # Pull only if studio token is set and studio flag is True.
+            ds = self.pull_dataset(name, version)
+            self.version = version or ds.latest_version
+            self.feature_schema = ds.get_version(self.version).feature_schema
+
         self.column_types = copy(ds.schema)
         if "sys__id" in self.column_types:
             self.column_types.pop("sys__id")
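
The fallback gate in the hunk above reduces to this simplified sketch (a
standalone illustration, not the actual code path, which also re-raises the
original lookup error when pulling is not possible):

    import os

    from datachain.config import Config

    def should_pull_from_studio(studio: bool) -> bool:
        # Pull from Studio only when the caller left studio=True and a token
        # can be resolved: the DVC_STUDIO_TOKEN environment variable first,
        # then the "studio" section of the DataChain config.
        if not studio:
            return False
        token = os.environ.get("DVC_STUDIO_TOKEN") or Config().read().get(
            "studio", {}
        ).get("token")
        return bool(token)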
@@ -1112,6 +1135,21 @@ def __iter__(self):
     def __or__(self, other):
         return self.union(other)

+    def pull_dataset(self, name: str, version: Optional[int] = None) -> "DatasetRecord":
+        print("Dataset not found in local catalog, trying to get from studio")
+
+        remote_ds_uri = f"{DATASET_PREFIX}{name}"
+        if version:
+            remote_ds_uri += f"@v{version}"
+
+        self.catalog.pull_dataset(
+            remote_ds_uri=remote_ds_uri,
+            local_ds_name=name,
+            local_ds_version=version,
+        )
+
+        return self.catalog.get_dataset(name)
+
     @staticmethod
     def get_table() -> "TableClause":
         table_name = "".join(
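
Assuming DATASET_PREFIX is the "ds://" dataset scheme defined in
datachain.dataset, the remote URI built by pull_dataset looks like this
(a standalone sketch of the same string logic):

    from typing import Optional

    DATASET_PREFIX = "ds://"  # assumed value of datachain.dataset.DATASET_PREFIX

    def remote_ds_uri(name: str, version: Optional[int] = None) -> str:
        uri = f"{DATASET_PREFIX}{name}"
        if version:
            uri += f"@v{version}"  # e.g. "ds://dogs@v1"
        return uri

    assert remote_ds_uri("dogs", 1) == "ds://dogs@v1"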
40 changes: 40 additions & 0 deletions tests/func/test_pull.py
@@ -10,6 +10,8 @@
 from datachain.config import Config, ConfigLevel
 from datachain.dataset import DatasetStatus
 from datachain.error import DataChainError, DatasetNotFoundError
+from datachain.lib.dc import DataChain
+from datachain.query.session import Session
 from datachain.utils import STUDIO_URL, JSONSerialize
 from tests.data import ENTRIES
 from tests.utils import assert_row_names, skip_if_not_sqlite, tree_from_path
@@ -267,6 +269,44 @@ def test_pull_dataset_success(
     }


+@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
+@skip_if_not_sqlite
+def test_datachain_from_dataset_pull(
+    mocker,
+    cloud_test_catalog,
+    remote_dataset_info,
+    dataset_export,
+    dataset_export_status,
+    dataset_export_data_chunk,
+):
+    # Check that DataChain pulls from Studio when the dataset is not available locally.
+    mocker.patch(
+        "datachain.catalog.catalog.DatasetRowsFetcher.should_check_for_status",
+        return_value=True,
+    )
+
+    catalog = cloud_test_catalog.catalog
+
+    # Make sure the dataset is not available locally at first.
+    with pytest.raises(DatasetNotFoundError):
+        catalog.get_dataset("dogs")
+
+    with Session("testSession", catalog=catalog):
+        ds = DataChain.from_dataset(
+            name="dogs",
+            version=1,
+            studio=True,
+        )
+
+        assert ds.dataset.name == "dogs"
+        assert ds.dataset.latest_version == 1
+        assert ds.dataset.status == DatasetStatus.COMPLETE
+
+    # Check that the dataset is available locally after pulling.
+    dataset = catalog.get_dataset("dogs")
+    assert dataset.name == "dogs"
+
+
 @pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
 @skip_if_not_sqlite
 def test_pull_dataset_wrong_dataset_uri_format(
