From a925653a0a64db94711d9b281736b62dc8524507 Mon Sep 17 00:00:00 2001
From: Dave Berenbaum
Date: Fri, 20 Sep 2024 14:39:30 -0400
Subject: [PATCH] `IndexedFile` -> `ArrowRow` (#445)

* indexedfile -> arrowvfile

* ArrowVFile -> ArrowRow
---
 docs/references/file.md           |  4 ++--
 src/datachain/__init__.py         |  4 ++--
 src/datachain/lib/arrow.py        | 11 ++++++++---
 src/datachain/lib/dc.py           |  4 ++--
 src/datachain/lib/file.py         | 28 +++++++++++++++++++++++-----
 tests/unit/lib/test_arrow.py      | 17 ++++++++---------
 tests/unit/test_module_exports.py |  4 ++--
 7 files changed, 47 insertions(+), 25 deletions(-)

diff --git a/docs/references/file.md b/docs/references/file.md
index ef3366ba5..b293014e4 100644
--- a/docs/references/file.md
+++ b/docs/references/file.md
@@ -7,6 +7,8 @@ automatically when creating a `DataChain` from files, like in
 classes include various metadata fields about the underlying file as well as methods to
 read from the files and otherwise work with the file contents.
 
+::: datachain.lib.file.ArrowRow
+
 ::: datachain.lib.file.ExportPlacement
 
 ::: datachain.lib.file.File
@@ -15,8 +17,6 @@ read from the files and otherwise work with the file contents.
 
 ::: datachain.lib.file.ImageFile
 
-::: datachain.lib.file.IndexedFile
-
 ::: datachain.lib.file.TarVFile
 
 ::: datachain.lib.file.TextFile
diff --git a/src/datachain/__init__.py b/src/datachain/__init__.py
index 8de05340c..f0d90a92d 100644
--- a/src/datachain/__init__.py
+++ b/src/datachain/__init__.py
@@ -1,10 +1,10 @@
 from datachain.lib.data_model import DataModel, DataType, is_chain_type
 from datachain.lib.dc import C, Column, DataChain, Sys
 from datachain.lib.file import (
+    ArrowRow,
     File,
     FileError,
     ImageFile,
-    IndexedFile,
     TarVFile,
     TextFile,
 )
@@ -16,6 +16,7 @@
 __all__ = [
     "AbstractUDF",
     "Aggregator",
+    "ArrowRow",
     "C",
     "Column",
     "DataChain",
@@ -26,7 +27,6 @@
     "FileError",
     "Generator",
     "ImageFile",
-    "IndexedFile",
     "Mapper",
     "ModelStore",
     "Session",
diff --git a/src/datachain/lib/arrow.py b/src/datachain/lib/arrow.py
index f6db7dadb..c9f2b3119 100644
--- a/src/datachain/lib/arrow.py
+++ b/src/datachain/lib/arrow.py
@@ -4,11 +4,11 @@
 from typing import TYPE_CHECKING, Optional
 
 import pyarrow as pa
-from pyarrow.dataset import dataset
+from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm
 
 from datachain.lib.data_model import dict_to_data_model
-from datachain.lib.file import File, IndexedFile
+from datachain.lib.file import ArrowRow, File
 from datachain.lib.model_store import ModelStore
 from datachain.lib.udf import Generator
 
@@ -84,7 +84,12 @@ def process(self, file: File):
                 vals_dict[field] = val
             vals = [self.output_schema(**vals_dict)]
             if self.source:
-                yield [IndexedFile(file=file, index=index), *vals]
+                kwargs: dict = self.kwargs
+                # Can't serialize CsvFileFormat; may lose formatting options.
+                if isinstance(kwargs.get("format"), CsvFileFormat):
+                    kwargs["format"] = "csv"
+                arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
+                yield [arrow_file, *vals]
             else:
                 yield vals
             index += 1
diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index f73bb86d0..84cab6f29 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -26,8 +26,8 @@
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
+from datachain.lib.file import ArrowRow, File, get_file_type
 from datachain.lib.file import ExportPlacement as FileExportPlacement
-from datachain.lib.file import File, IndexedFile, get_file_type
 from datachain.lib.listing import (
     is_listing_dataset,
     is_listing_expired,
@@ -1614,7 +1614,7 @@ def parse_tabular(
             for name, info in output.model_fields.items()
         }
         if source:
-            output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
+            output = {"source": ArrowRow} | output  # type: ignore[assignment,operator]
         return self.gen(
             ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
         )
diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py
index 368044384..41cd6369f 100644
--- a/src/datachain/lib/file.py
+++ b/src/datachain/lib/file.py
@@ -17,6 +17,7 @@
 
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from PIL import Image
+from pyarrow.dataset import dataset
 from pydantic import Field, field_validator
 
 if TYPE_CHECKING:
@@ -439,14 +440,31 @@ def save(self, destination: str):
         self.read().save(destination)
 
 
-class IndexedFile(DataModel):
-    """Metadata indexed from tabular files.
-
-    Includes `file` and `index` signals.
-    """
+class ArrowRow(DataModel):
+    """`DataModel` for reading row from Arrow-supported file."""
 
     file: File
     index: int
+    kwargs: dict
+
+    @contextmanager
+    def open(self):
+        """Stream row contents from indexed file."""
+        if self.file._caching_enabled:
+            self.file.ensure_cached()
+            path = self.file.get_local_path()
+            ds = dataset(path, **self.kwargs)
+
+        else:
+            path = self.file.get_path()
+            ds = dataset(path, filesystem=self.file.get_fs(), **self.kwargs)
+
+        return ds.take([self.index]).to_reader()
+
+    def read(self):
+        """Returns row contents as dict."""
+        with self.open() as record_batch:
+            return record_batch.to_pylist()[0]
 
 
 def get_file_type(type_: Literal["binary", "text", "image"] = "binary") -> type[File]:
diff --git a/tests/unit/lib/test_arrow.py b/tests/unit/lib/test_arrow.py
index daaf23f22..4d1414b91 100644
--- a/tests/unit/lib/test_arrow.py
+++ b/tests/unit/lib/test_arrow.py
@@ -13,7 +13,7 @@
     schema_to_output,
 )
 from datachain.lib.data_model import dict_to_data_model
-from datachain.lib.file import File, IndexedFile
+from datachain.lib.file import ArrowRow, File
 from datachain.lib.hf import HFClassLabel
 
 
@@ -33,10 +33,11 @@ def test_arrow_generator(tmp_path, catalog, cache):
     objs = list(func.process(stream))
 
     assert len(objs) == len(ids)
-    for index, (o, id, text) in enumerate(zip(objs, ids, texts)):
-        assert isinstance(o[0], IndexedFile)
-        assert isinstance(o[0].file, File)
-        assert o[0].index == index
+    for o, id, text in zip(objs, ids, texts):
+        assert isinstance(o[0], ArrowRow)
+        file_vals = o[0].read()
+        assert file_vals["id"] == id
+        assert file_vals["text"] == text
         assert o[1] == id
         assert o[2] == text
 
@@ -78,10 +79,8 @@ def test_arrow_generator_output_schema(tmp_path, catalog):
     objs = list(func.process(stream))
 
     assert len(objs) == len(ids)
-    for index, (o, id, text, dict) in enumerate(zip(objs, ids, texts, dicts)):
-        assert isinstance(o[0], IndexedFile)
-        assert isinstance(o[0].file, File)
-        assert o[0].index == index
+    for o, id, text, dict in zip(objs, ids, texts, dicts):
+        assert isinstance(o[0], ArrowRow)
         assert o[1].id == id
         assert o[1].text == text
         assert o[1].dict.a == dict["a"]
diff --git a/tests/unit/test_module_exports.py b/tests/unit/test_module_exports.py
index 4f3e4567f..9939658d3 100644
--- a/tests/unit/test_module_exports.py
+++ b/tests/unit/test_module_exports.py
@@ -13,6 +13,7 @@ def test_module_exports():
     from datachain import (
         AbstractUDF,
         Aggregator,
+        ArrowRow,
         C,
         Column,
         DataChain,
@@ -22,7 +23,6 @@
         FileError,
         Generator,
         ImageFile,
-        IndexedFile,
         Mapper,
         Session,
         TarVFile,
@@ -67,6 +67,7 @@ def monkey_import_importerror(
        from datachain import (
            AbstractUDF,
            Aggregator,
+           ArrowRow,
            C,
            Column,
            DataChain,
@@ -76,7 +77,6 @@
            FileError,
            Generator,
            ImageFile,
-           IndexedFile,
            Mapper,
            Session,
            TarVFile,
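
Usage sketch (illustrative, not part of the patch): after this change, `parse_tabular(source=True)` attaches a `source` signal typed `ArrowRow` instead of `IndexedFile`, and each row can be re-read from the underlying file via `ArrowRow.read()`, as the updated tests do with `o[0].read()`. The file path and column printing below are assumptions for illustration only, not values taken from the patch.

    from datachain import ArrowRow, DataChain

    # Hypothetical CSV path; format="csv" is forwarded to pyarrow.dataset.dataset()
    # by ArrowGenerator and recorded in ArrowRow.kwargs so the row can be re-read later.
    chain = DataChain.from_storage("data/example.csv").parse_tabular(
        format="csv",
        source=True,  # adds a "source" signal typed ArrowRow
    )

    for row in chain.collect("source"):
        assert isinstance(row, ArrowRow)
        # read() re-opens the underlying file with the stored kwargs and returns
        # the row at `row.index` as a dict keyed by column name.
        print(row.index, row.read())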