Skip to content

Commit

Permalink
[FEAT]: add to_arrow_iter (#2681)
Browse files Browse the repository at this point in the history
closes #2679
  • Loading branch information
universalmind303 authored Aug 19, 2024
1 parent ae13e22 commit 774a5d6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
21 changes: 21 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,27 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
row = {key: value[i] for (key, value) in pydict.items()}
yield row

@DataframePublicAPI
def to_arrow_iter(self, results_buffer_size: Optional[int] = 1) -> Iterator["pyarrow.Table"]:
    """Return an iterator of pyarrow Tables for this DataFrame.

    If the DataFrame has already been executed, yields a single Table built
    from the materialized result; otherwise executes the plan in a streaming
    fashion, yielding one Table per partition as it becomes available.

    Args:
        results_buffer_size: bound on the runner's results buffer while
            streaming (passed through to ``run_iter_tables``); must be > 0,
            or ``None``. Defaults to 1.
            NOTE(review): presumably ``None`` means an unbounded buffer —
            confirm against the runner's contract.

    Raises:
        ValueError: if ``results_buffer_size`` is not ``None`` and not > 0.
    """
    if results_buffer_size is not None and results_buffer_size <= 0:
        raise ValueError(f"Provided `results_buffer_size` value must be > 0, received: {results_buffer_size}")
    if self._result is not None:
        # Already executed: serve the precomputed result as one Table
        # instead of re-running the plan.
        yield self.to_arrow()
    else:
        # Not yet executed: stream partitions out of the runner and
        # convert each to Arrow lazily.
        context = get_context()
        partitions_iter = context.runner().run_iter_tables(self._builder, results_buffer_size)

        for partition in partitions_iter:
            yield partition.to_arrow()

@DataframePublicAPI
def iter_partitions(
self, results_buffer_size: Optional[int] = 1
Expand Down
6 changes: 6 additions & 0 deletions tests/table/test_from_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,9 @@ def __iter__(self):
table = daft.from_arrow(my_iter)
tbl = table.to_pydict()
assert tbl == {"text": ["foo1", "bar2", "foo2", "bar2", "foo3", "bar3"]}


def test_to_arrow_iterator() -> None:
    """to_arrow_iter yields pyarrow Tables that, concatenated, equal the source data."""
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    df = daft.from_pydict(data)

    # Exhaust the iterator: every chunk must be a pyarrow Table.
    tables = list(df.to_arrow_iter())
    assert tables, "iterator yielded no tables"
    assert all(isinstance(t, pa.Table) for t in tables)

    # Concatenating all streamed chunks must reproduce the original data,
    # not just type-check the first chunk.
    assert pa.concat_tables(tables).to_pydict() == data

0 comments on commit 774a5d6

Please sign in to comment.