From 4f741e6402b45da6a8cdf9ffa694c8baf810cf0c Mon Sep 17 00:00:00 2001 From: Ronan Lamy Date: Tue, 10 Sep 2024 17:03:31 +0100 Subject: [PATCH] Remove Entry class and use File instead --- src/datachain/asyn.py | 13 +++----- src/datachain/client/azure.py | 14 +------- src/datachain/client/fsspec.py | 14 ++++---- src/datachain/client/gcs.py | 15 ++------- src/datachain/client/hf.py | 10 ------ src/datachain/client/local.py | 15 ++------- src/datachain/client/s3.py | 28 ++++++---------- src/datachain/data_storage/sqlite.py | 5 +++ src/datachain/data_storage/warehouse.py | 17 +++------- src/datachain/lib/listing.py | 3 +- src/datachain/listing.py | 14 ++++---- src/datachain/node.py | 43 ------------------------- tests/data.py | 22 ++++++------- tests/func/test_catalog.py | 2 +- tests/func/test_pull.py | 3 +- tests/unit/test_listing.py | 5 +-- 16 files changed, 59 insertions(+), 164 deletions(-) diff --git a/src/datachain/asyn.py b/src/datachain/asyn.py index e4b8d8255..4e42f4a8c 100644 --- a/src/datachain/asyn.py +++ b/src/datachain/asyn.py @@ -1,14 +1,8 @@ import asyncio -from collections.abc import Awaitable, Coroutine, Iterable +from collections.abc import AsyncIterable, Awaitable, Coroutine, Iterable, Iterator from concurrent.futures import ThreadPoolExecutor from heapq import heappop, heappush -from typing import ( - Any, - Callable, - Generic, - Optional, - TypeVar, -) +from typing import Any, Callable, Generic, Optional, TypeVar from fsspec.asyn import get_loop @@ -16,6 +10,7 @@ InputT = TypeVar("InputT", contravariant=True) # noqa: PLC0105 ResultT = TypeVar("ResultT", covariant=True) # noqa: PLC0105 +T = TypeVar("T") class AsyncMapper(Generic[InputT, ResultT]): @@ -226,7 +221,7 @@ async def _break_iteration(self) -> None: self._push_result(self._next_yield, None) -def iter_over_async(ait, loop): +def iter_over_async(ait: AsyncIterable[T], loop) -> Iterator[T]: """Wrap an asynchronous iterator into a synchronous one""" ait = ait.__aiter__() diff --git a/src/datachain/client/azure.py b/src/datachain/client/azure.py index 7eb28047f..4421945c6 100644 --- a/src/datachain/client/azure.py +++ b/src/datachain/client/azure.py @@ -4,7 +4,6 @@ from tqdm import tqdm from datachain.lib.file import File -from datachain.node import Entry from .fsspec import DELIMITER, Client, ResultQueue @@ -14,17 +13,6 @@ class AzureClient(Client): PREFIX = "az://" protocol = "az" - def convert_info(self, v: dict[str, Any], path: str) -> Entry: - version_id = v.get("version_id") - return Entry.from_file( - path=path, - etag=v.get("etag", "").strip('"'), - version=version_id or "", - is_latest=version_id is None or bool(v.get("is_current_version")), - last_modified=v["last_modified"], - size=v.get("size", ""), - ) - def info_to_file(self, v: dict[str, Any], path: str) -> File: version_id = v.get("version_id") return File( @@ -57,7 +45,7 @@ async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> Non continue info = (await self.fs._details([b]))[0] entries.append( - self.convert_info(info, self.rel_path(info["name"])) + self.info_to_file(info, self.rel_path(info["name"])) ) if entries: await result_queue.put(entries) diff --git a/src/datachain/client/fsspec.py b/src/datachain/client/fsspec.py index f49480cd9..cbbf8f521 100644 --- a/src/datachain/client/fsspec.py +++ b/src/datachain/client/fsspec.py @@ -29,7 +29,7 @@ from datachain.cache import DataChainCache, UniqueId from datachain.client.fileslice import FileSlice, FileWrapper from datachain.error import ClientError as DataChainClientError -from 
datachain.node import Entry +from datachain.lib.file import File from datachain.nodes_fetcher import NodesFetcher from datachain.nodes_thread_pool import NodeChunk from datachain.storage import StorageURI @@ -45,7 +45,7 @@ DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$") -ResultQueue = asyncio.Queue[Optional[Sequence[Entry]]] +ResultQueue = asyncio.Queue[Optional[Sequence[File]]] def _is_win_local_path(uri: str) -> bool: @@ -188,7 +188,7 @@ def url(self, path: str, expires: int = 3600, **kwargs) -> str: async def get_current_etag(self, uid: UniqueId) -> str: info = await self.fs._info(self.get_full_path(uid.path)) - return self.convert_info(info, "").etag + return self.info_to_file(info, "").etag async def get_size(self, path: str) -> int: return await self.fs._size(path) @@ -198,7 +198,7 @@ async def get_file(self, lpath, rpath, callback): async def scandir( self, start_prefix: str, method: str = "default" - ) -> AsyncIterator[Sequence[Entry]]: + ) -> AsyncIterator[Sequence[File]]: try: impl = getattr(self, f"_fetch_{method}") except AttributeError: @@ -264,7 +264,7 @@ async def _fetch_default( ) -> None: await self._fetch_nested(start_prefix, result_queue) - async def _fetch_dir(self, prefix, pbar, result_queue) -> set[str]: + async def _fetch_dir(self, prefix, pbar, result_queue: ResultQueue) -> set[str]: path = f"{self.name}/{prefix}" infos = await self.ls_dir(path) files = [] @@ -277,7 +277,7 @@ async def _fetch_dir(self, prefix, pbar, result_queue) -> set[str]: if info["type"] == "directory": subdirs.add(subprefix) else: - files.append(self.convert_info(info, subprefix)) + files.append(self.info_to_file(info, subprefix)) if files: await result_queue.put(files) found_count = len(subdirs) + len(files) @@ -303,7 +303,7 @@ def get_full_path(self, rel_path: str) -> str: return f"{self.PREFIX}{self.name}/{rel_path}" @abstractmethod - def convert_info(self, v: dict[str, Any], parent: str) -> Entry: ... + def info_to_file(self, v: dict[str, Any], parent: str) -> File: ... 
def fetch_nodes( self, diff --git a/src/datachain/client/gcs.py b/src/datachain/client/gcs.py index 95de0e5ac..33516c013 100644 --- a/src/datachain/client/gcs.py +++ b/src/datachain/client/gcs.py @@ -10,7 +10,6 @@ from tqdm import tqdm from datachain.lib.file import File -from datachain.node import Entry from .fsspec import DELIMITER, Client, ResultQueue @@ -108,19 +107,9 @@ async def _get_pages(self, path: str, page_queue: PageQueue) -> None: finally: await page_queue.put(None) - def _entry_from_dict(self, d: dict[str, Any]) -> Entry: + def _entry_from_dict(self, d: dict[str, Any]) -> File: info = self.fs._process_object(self.name, d) - return self.convert_info(info, self.rel_path(info["name"])) - - def convert_info(self, v: dict[str, Any], path: str) -> Entry: - return Entry.from_file( - path=path, - etag=v.get("etag", ""), - version=v.get("generation", ""), - is_latest=not v.get("timeDeleted"), - last_modified=self.parse_timestamp(v["updated"]), - size=v.get("size", ""), - ) + return self.info_to_file(info, self.rel_path(info["name"])) def info_to_file(self, v: dict[str, Any], path: str) -> File: return File( diff --git a/src/datachain/client/hf.py b/src/datachain/client/hf.py index 01c84d9f0..68b02cfa6 100644 --- a/src/datachain/client/hf.py +++ b/src/datachain/client/hf.py @@ -5,7 +5,6 @@ from huggingface_hub import HfFileSystem from datachain.lib.file import File -from datachain.node import Entry from .fsspec import Client @@ -22,15 +21,6 @@ def create_fs(cls, **kwargs) -> HfFileSystem: return cast(HfFileSystem, super().create_fs(**kwargs)) - def convert_info(self, v: dict[str, Any], path: str) -> Entry: - return Entry.from_file( - path=path, - size=v["size"], - version=v["last_commit"].oid, - etag=v.get("blob_id", ""), - last_modified=v["last_commit"].date, - ) - def info_to_file(self, v: dict[str, Any], path: str) -> File: return File( path=path, diff --git a/src/datachain/client/local.py b/src/datachain/client/local.py index c59743ab9..6bb9df872 100644 --- a/src/datachain/client/local.py +++ b/src/datachain/client/local.py @@ -7,8 +7,8 @@ from fsspec.implementations.local import LocalFileSystem +from datachain.cache import UniqueId from datachain.lib.file import File -from datachain.node import Entry from datachain.storage import StorageURI from .fsspec import Client @@ -114,9 +114,9 @@ def from_source( use_symlinks=use_symlinks, ) - async def get_current_etag(self, uid) -> str: + async def get_current_etag(self, uid: UniqueId) -> str: info = self.fs.info(self.get_full_path(uid.path)) - return self.convert_info(info, "").etag + return self.info_to_file(info, "").etag async def get_size(self, path: str) -> int: return self.fs.size(path) @@ -136,15 +136,6 @@ def get_full_path(self, rel_path): full_path += "/" return full_path - def convert_info(self, v: dict[str, Any], path: str) -> Entry: - return Entry.from_file( - path=path, - etag=v["mtime"].hex(), - is_latest=True, - last_modified=datetime.fromtimestamp(v["mtime"], timezone.utc), - size=v.get("size", ""), - ) - def info_to_file(self, v: dict[str, Any], path: str) -> File: return File( source=self.uri, diff --git a/src/datachain/client/s3.py b/src/datachain/client/s3.py index e859c5431..37de24442 100644 --- a/src/datachain/client/s3.py +++ b/src/datachain/client/s3.py @@ -1,12 +1,11 @@ import asyncio -from typing import Any, cast +from typing import Any, Optional, cast from botocore.exceptions import NoCredentialsError from s3fs import S3FileSystem from tqdm import tqdm from datachain.lib.file import File -from datachain.node 
import Entry from .fsspec import DELIMITER, Client, ResultQueue @@ -111,8 +110,9 @@ async def _fetch_default( ) -> None: await self._fetch_flat(start_prefix, result_queue) - def _entry_from_boto(self, v, bucket, versions=False): - return Entry.from_file( + def _entry_from_boto(self, v, bucket, versions=False) -> File: + return File( + source=self.uri, path=v["Key"], etag=v.get("ETag", "").strip('"'), version=ClientS3.clean_s3_version(v.get("VersionId", "")), @@ -125,8 +125,8 @@ async def _fetch_dir( self, prefix, pbar, - result_queue, - ): + result_queue: ResultQueue, + ) -> set[str]: if prefix: prefix = prefix.lstrip(DELIMITER) + DELIMITER files = [] @@ -141,7 +141,7 @@ async def _fetch_dir( if info["type"] == "directory": subdirs.add(subprefix) else: - files.append(self.convert_info(info, subprefix)) + files.append(self.info_to_file(info, subprefix)) pbar.update() found = True if not found: @@ -152,18 +152,8 @@ async def _fetch_dir( return subdirs @staticmethod - def clean_s3_version(ver): - return ver if ver != "null" else "" - - def convert_info(self, v: dict[str, Any], path: str) -> Entry: - return Entry.from_file( - path=path, - etag=v.get("ETag", "").strip('"'), - version=ClientS3.clean_s3_version(v.get("VersionId", "")), - is_latest=v.get("IsLatest", True), - last_modified=v.get("LastModified", ""), - size=v["size"], - ) + def clean_s3_version(ver: Optional[str]) -> str: + return ver if (ver is not None and ver != "null") else "" def info_to_file(self, v: dict[str, Any], path: str) -> File: return File( diff --git a/src/datachain/data_storage/sqlite.py b/src/datachain/data_storage/sqlite.py index ae41b9dc5..99a7ac2f1 100644 --- a/src/datachain/data_storage/sqlite.py +++ b/src/datachain/data_storage/sqlite.py @@ -43,6 +43,8 @@ from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.types import TypeEngine + from datachain.lib.file import File + logger = logging.getLogger("datachain") @@ -708,6 +710,9 @@ def merge_dataset_rows( self.db.execute(insert_query) + def prepare_entries(self, entries: "Iterable[File]") -> Iterable[dict[str, Any]]: + return (e.model_dump() for e in entries) + def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None: rows = list(rows) if not rows: diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py index c210a621e..8b17f975b 100644 --- a/src/datachain/data_storage/warehouse.py +++ b/src/datachain/data_storage/warehouse.py @@ -20,7 +20,7 @@ from datachain.data_storage.schema import convert_rows_custom_column_types from datachain.data_storage.serializer import Serializable from datachain.dataset import DatasetRecord -from datachain.node import DirType, DirTypeGroup, Entry, Node, NodeWithPath, get_path +from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path from datachain.sql.functions import path as pathfunc from datachain.sql.types import Int, SQLType from datachain.storage import StorageURI @@ -34,6 +34,7 @@ from datachain.data_storage import AbstractIDGenerator, schema from datachain.data_storage.db_engine import DatabaseEngine from datachain.data_storage.schema import DataTable + from datachain.lib.file import File try: import numpy as np @@ -410,17 +411,9 @@ def dataset_stats( ((nrows, *rest),) = self.db.execute(query) return nrows, rest[0] if rest else 0 - def prepare_entries( - self, uri: str, entries: Iterable[Entry] - ) -> list[dict[str, Any]]: - """ - Prepares bucket listing entry (row) for inserting into database - """ - - def _prepare_entry(entry: Entry): 
- return attrs.asdict(entry) | {"source": uri} - - return [_prepare_entry(e) for e in entries] + @abstractmethod + def prepare_entries(self, entries: "Iterable[File]") -> Iterable[dict[str, Any]]: + """Convert File entries so they can be passed on to `insert_rows()`""" @abstractmethod def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None: diff --git a/src/datachain/lib/listing.py b/src/datachain/lib/listing.py index 8c1c611b2..d2357048f 100644 --- a/src/datachain/lib/listing.py +++ b/src/datachain/lib/listing.py @@ -30,8 +30,7 @@ def list_func() -> Iterator[File]: config = client_config or {} client, path = Client.parse_url(uri, None, **config) # type: ignore[arg-type] for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()): - for entry in entries: - yield entry.to_file(client.uri) + yield from entries return list_func diff --git a/src/datachain/listing.py b/src/datachain/listing.py index 4f1b3b907..ef5f83f01 100644 --- a/src/datachain/listing.py +++ b/src/datachain/listing.py @@ -9,7 +9,8 @@ from sqlalchemy.sql import func from tqdm import tqdm -from datachain.node import DirType, Entry, Node, NodeWithPath +from datachain.lib.file import File +from datachain.node import DirType, Node, NodeWithPath from datachain.sql.functions import path as pathfunc from datachain.utils import suffix_to_number @@ -80,16 +81,13 @@ async def _fetch(self, start_prefix: str, method: str) -> None: finally: fetch_listing.insert_entries_done() - def insert_entry(self, entry: Entry) -> None: - self.warehouse.insert_rows( - self.dataset_rows.get_table(), - self.warehouse.prepare_entries(self.client.uri, [entry]), - ) + def insert_entry(self, entry: File) -> None: + self.insert_entries([entry]) - def insert_entries(self, entries: Iterable[Entry]) -> None: + def insert_entries(self, entries: Iterable[File]) -> None: self.warehouse.insert_rows( self.dataset_rows.get_table(), - self.warehouse.prepare_entries(self.client.uri, entries), + self.warehouse.prepare_entries(entries), ) def insert_entries_done(self) -> None: diff --git a/src/datachain/node.py b/src/datachain/node.py index a9fa61596..6ef281357 100644 --- a/src/datachain/node.py +++ b/src/datachain/node.py @@ -4,7 +4,6 @@ import attrs from datachain.cache import UniqueId -from datachain.lib.file import File from datachain.storage import StorageURI from datachain.utils import TIME_ZERO, time_to_str @@ -139,48 +138,6 @@ def parent(self): return split[0] -@attrs.define -class Entry: - path: str = "" - etag: str = "" - version: str = "" - is_latest: bool = True - last_modified: Optional[datetime] = None - size: int = 0 - location: Optional[str] = None - - @classmethod - def from_file(cls, path: str, **kwargs) -> "Entry": - return cls(path=path, **kwargs) - - @property - def full_path(self) -> str: - return self.path - - @property - def name(self): - return self.path.rsplit("/", 1)[-1] - - @property - def parent(self): - split = self.path.rsplit("/", 1) - if len(split) <= 1: - return "" - return split[0] - - def to_file(self, source: str) -> File: - return File( - source=source, - path=self.path, - size=self.size, - version=self.version, - etag=self.etag, - is_latest=self.is_latest, - last_modified=self.last_modified, - location=self.location, - ) - - def get_path(parent: str, name: str): return f"{parent}/{name}" if parent else name diff --git a/tests/data.py b/tests/data.py index 30d58e13f..7faf758e9 100644 --- a/tests/data.py +++ b/tests/data.py @@ -1,12 +1,12 @@ from datetime import datetime, timezone -from 
datachain.node import Entry +from datachain.lib.file import File utc = timezone.utc TIME_ZERO = datetime.fromtimestamp(0, tz=utc) ENTRIES = [ - Entry.from_file( + File( path="description", etag="60a7605e934638ab9113e0f9cf852239", version="7e589b7d-382c-49a5-931f-2b999c930c5e", @@ -14,7 +14,7 @@ last_modified=datetime(2023, 2, 27, 18, 28, 54, tzinfo=utc), size=13, ), - Entry.from_file( + File( path="cats/cat1", etag="4a4be40c96ac6314e91d93f38043a634", version="309eb4a4-bba9-47c1-afcd-d7c51110af6f", @@ -22,7 +22,7 @@ last_modified=datetime(2023, 2, 27, 18, 28, 54, tzinfo=utc), size=4, ), - Entry.from_file( + File( path="cats/cat2", etag="0268c692ff940a830e1e7296aa48c176", version="f9d168d3-6d1b-47ef-8f6a-81fce48de141", @@ -30,7 +30,7 @@ last_modified=datetime(2023, 2, 27, 18, 28, 54, tzinfo=utc), size=4, ), - Entry.from_file( + File( path="dogs/dog1", etag="8fdb60801e9d39a5286aa01dd1f4f4f3", version="b9c31cf7-d011-466a-bf16-cf9da0cb422a", @@ -38,7 +38,7 @@ last_modified=datetime(2023, 2, 27, 18, 28, 54, tzinfo=utc), size=4, ), - Entry.from_file( + File( path="dogs/dog2", etag="2d50c921b22aa164a56c68d71eeb4100", version="3a8bb6d9-38db-47a8-8bcb-8972ea95aa20", @@ -46,7 +46,7 @@ last_modified=datetime(2023, 2, 27, 18, 28, 54, tzinfo=utc), size=3, ), - Entry.from_file( + File( path="dogs/dog3", etag="33c6c2397a1b079e903c474df792d0e2", version="ee49e963-36a8-492a-b03a-e801b93afb40", @@ -54,7 +54,7 @@ last_modified=datetime(2023, 2, 27, 18, 28, 54, tzinfo=utc), size=4, ), - Entry.from_file( + File( path="dogs/others/dog4", etag="a5e1a5d93ff242b745f5cf87aeb726d5", version="c5969421-6900-4060-bc39-d54f4a49b9fc", @@ -69,7 +69,7 @@ # dogs/others # dogs/ INVALID_ENTRIES = [ - Entry.from_file( + File( path="dogs/others/", etag="68b329da9893e34099c7d8ad5cb9c940", version="85969421-6900-4060-bc39-d54f4a49b9ab", @@ -77,7 +77,7 @@ last_modified=datetime(2023, 2, 27, 18, 28, 54, tzinfo=utc), size=4, ), - Entry.from_file( + File( path="dogs/others", etag="68b329da9893e34099c7d8ad5cb9c940", version="85969421-6900-4060-bc39-d54f4a49b9ab", @@ -85,7 +85,7 @@ last_modified=datetime(2023, 2, 27, 18, 28, 54, tzinfo=utc), size=4, ), - Entry.from_file( + File( path="dogs/", etag="68b329da9893e34099c7d8ad5cb9c940", version="85969421-6900-4060-bc39-d54f4a49b9ab", diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py index 33c444bfb..944bbbce9 100644 --- a/tests/func/test_catalog.py +++ b/tests/func/test_catalog.py @@ -84,7 +84,7 @@ def fake_index(catalog): def test_find(catalog, fake_index): src_uri = fake_index dirs = ["cats/", "dogs/", "dogs/others/"] - expected_paths = dirs + [entry.full_path for entry in ENTRIES] + expected_paths = dirs + [entry.path for entry in ENTRIES] assert set(catalog.find([src_uri])) == { f"{src_uri}/{path}" for path in expected_paths } diff --git a/tests/func/test_pull.py b/tests/func/test_pull.py index 7c78bd3f9..6058ced59 100644 --- a/tests/func/test_pull.py +++ b/tests/func/test_pull.py @@ -2,7 +2,6 @@ import json from datetime import datetime -import attrs import lz4.frame import pandas as pd import pytest @@ -16,7 +15,7 @@ @pytest.fixture def dog_entries(): - return [attrs.asdict(e) for e in ENTRIES if e.name.startswith("dog")] + return [e.model_dump() for e in ENTRIES if e.name.startswith("dog")] @pytest.fixture diff --git a/tests/unit/test_listing.py b/tests/unit/test_listing.py index 241796cb3..c86146c5e 100644 --- a/tests/unit/test_listing.py +++ b/tests/unit/test_listing.py @@ -5,6 +5,7 @@ from datachain.catalog import Catalog from datachain.catalog.catalog import 
DataSource +from datachain.lib.file import File from datachain.lib.listing import ( LISTING_TTL, is_listing_dataset, @@ -13,7 +14,7 @@ listing_uri_from_name, parse_listing_uri, ) -from datachain.node import DirType, Entry, get_path +from datachain.node import DirType from tests.utils import skip_if_not_sqlite TREE = { @@ -33,7 +34,7 @@ def _tree_to_entries(tree: dict, path=""): yield from _tree_to_entries(v, dir_path) else: for fname in v: - yield Entry.from_file(get_path(path, fname)) + yield File(path=posixpath.join(path, fname)) @pytest.fixture
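
Reviewer note (illustrative only, not part of the patch): the change collapses the old two-step conversion (clients built an Entry, then Entry.to_file(source) produced a File) into a single info_to_file() call per client, with `source` filled in where the File is created. That is also why prepare_entries() loses its `uri` parameter and becomes abstract: each warehouse now serializes the pydantic File models itself, SQLite via model_dump(). Below is a minimal sketch of the new flow under those assumptions; ExampleClient and its URI are hypothetical, and the File fields used are only the ones appearing in the diff above.

    from typing import Any

    from datachain.lib.file import File


    class ExampleClient:
        # Hypothetical stand-in; the real clients (s3, gcs, azure, local, hf)
        # each implement info_to_file() as shown in the diff above.
        uri = "example://bucket"

        def info_to_file(self, v: dict[str, Any], path: str) -> File:
            # Single conversion point: raw fsspec info dict -> File model,
            # with `source` set here instead of in the removed Entry.to_file().
            return File(
                source=self.uri,
                path=path,
                etag=v.get("etag", ""),
                size=v.get("size", 0),
            )


    # Warehouses consume the File objects directly, mirroring the new
    # SQLiteWarehouse.prepare_entries() from the diff:
    files = [ExampleClient().info_to_file({"etag": "abc", "size": 3}, "dogs/dog1")]
    rows = (f.model_dump() for f in files)  # dicts ready for insert_rows()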