Storage options del vectors #33

Merged · 5 commits · Oct 31, 2023
19 changes: 16 additions & 3 deletions deltatorch/deltadataset.py
@@ -6,7 +6,6 @@
 from typing import Optional, Callable, List, Tuple, Dict, Any

 import numpy as np
-import pyarrow as pa
 import pyarrow.dataset as ds
 import torch.distributed
 from PIL import Image
@@ -40,6 +39,7 @@ def __init__(
         shuffle: bool = False,
         batch_size: int = 32,
         drop_last: bool = False,
+        storage_options: Optional[Dict[str, str]] = None,
     ) -> None:
         super().__init__()
         self.path = path
@@ -55,6 +55,7 @@ def __init__(
         self.drop_last = drop_last
         self.path = path
         self.batch_size = batch_size
+        self.storage_options = storage_options
         self.init_boundaries(path)

     @abstractmethod
@@ -127,7 +128,9 @@ def count(self):
             return self.count_with_partition_filters(_delta_table)
         else:
             _add_actions = _delta_table.get_add_actions().to_pandas()
-            return _add_actions["num_records"].sum()
+            num_records = _add_actions["num_records"].sum()
+            del _delta_table
+            return num_records

     def count_with_partition_filters(self, _delta_table):
         _cnt = 0
@@ -142,7 +145,17 @@ def count_with_partition_filters(self, _delta_table):
         return _cnt

     def create_delta_table(self):
-        return DeltaTable(self.path, version=self.version)
+        delta_table = DeltaTable(
+            self.path, version=self.version, storage_options=self.storage_options
+        )
+        conf = delta_table.metadata().configuration
+        if conf:
+            deletion_vectors = conf.get("delta.enableDeletionVectors", None)
+            if deletion_vectors == "true":
+                raise Exception(
+                    "Tables with enabled Deletion Vectors are not supported."
+                )
+        return delta_table

     def __iter__(self):
         return self.process_data()
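For context: `storage_options` is forwarded verbatim to deltalake's `DeltaTable`, and the new guard rejects any table whose metadata enables deletion vectors. A minimal standalone sketch of both behaviors follows; the S3 URI and credential values are hypothetical placeholders, while `DeltaTable(..., storage_options=...)` and `metadata().configuration` are the same deltalake APIs used in the hunk above.

from deltalake import DeltaTable

# Hypothetical credentials for an S3-hosted table; deltalake passes these
# key/value pairs through to its object-store backend unchanged.
storage_options = {
    "AWS_ACCESS_KEY_ID": "my-key-id",          # placeholder
    "AWS_SECRET_ACCESS_KEY": "my-secret-key",  # placeholder
    "AWS_REGION": "us-east-1",                 # placeholder
}

dt = DeltaTable("s3://my-bucket/my-table", storage_options=storage_options)

# The same check the PR adds: fail fast on deletion-vector tables instead
# of silently reading rows that should have been deleted.
conf = dt.metadata().configuration
if conf and conf.get("delta.enableDeletionVectors") == "true":
    raise Exception("Tables with enabled Deletion Vectors are not supported.")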
9 changes: 5 additions & 4 deletions deltatorch/id_based_deltadataset.py
@@ -1,10 +1,8 @@
 import logging
 import random
-from typing import List, Optional, Tuple, Any
+from typing import List, Optional, Tuple, Any, Dict

 import pyarrow.compute as pc
-from pyarrow.dataset import Expression
-from deltalake import DeltaTable
 from torch.utils.data import get_worker_info

 from deltatorch import DeltaIterableDataset
@@ -28,6 +26,7 @@ def __init__(
         shuffle: bool = False,
         batch_size: int = 32,
         drop_last: bool = False,
+        storage_options: Optional[Dict[str, str]] = None,
     ):
         super().__init__(
             path,
@@ -41,6 +40,7 @@ def __init__(
             shuffle,
             batch_size,
             drop_last,
+            storage_options,
         )
         self.id_field = id_field

@@ -68,7 +68,7 @@ def process_data(self):
             pc.field(self.id_field) < pc.scalar(iter_end)
         )

-        delta_table = DeltaTable(self.path, version=self.version)
+        delta_table = self.create_delta_table()
         scanner = delta_table.to_pyarrow_dataset().scanner(
             columns=self.arrow_fields, filter=_filter
         )
@@ -83,3 +83,4 @@ def process_data(self):
                     item, self.field_specs
                 )
                 yield item
+        del delta_table
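For reference, the per-worker id-range filter that `process_data` hands to the scanner is an ordinary pyarrow.compute expression. A standalone sketch, with "id" standing in for `self.id_field` and placeholder shard boundaries (in deltatorch these come from the rank/worker split set up in `init_boundaries`):

import pyarrow.compute as pc

# Hypothetical shard boundaries for one worker.
iter_start, iter_end = 0, 1_000

# Half-open id range [iter_start, iter_end) as a dataset filter,
# mirroring the expression built in process_data() above.
id_filter = (pc.field("id") >= pc.scalar(iter_start)) & (
    pc.field("id") < pc.scalar(iter_end)
)

# Then, as in the diff:
# scanner = delta_table.to_pyarrow_dataset().scanner(
#     columns=["id", "text"], filter=id_filter
# )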
6 changes: 4 additions & 2 deletions deltatorch/pytorch.py
@@ -1,6 +1,5 @@
-from typing import List, Optional, Tuple, Any
+from typing import List, Optional, Tuple, Any, Dict

-from pyarrow.dataset import Expression
 from torch.utils.data import DataLoader

 from .deltadataset import FieldSpec
@@ -20,6 +19,7 @@ def create_pytorch_dataloader(
     num_workers: int = 2,
     shuffle: bool = False,
     drop_last: bool = False,
+    storage_options: Optional[Dict[str, str]] = None,
     **pytorch_dataloader_kwargs
 ):
     """Create a PyTorch DataLoader.
Expand Down Expand Up @@ -50,6 +50,7 @@ def create_pytorch_dataloader(
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: ``False``)
:param storage_options: a dictionary of the options to use for the storage backend
:param pytorch_dataloader_kwargs: arguments for `torch.utils.data.DataLoader`,
exclude these arguments: ``batch_size``, ``num_workers``, ``shuffle``,
``drop_last``.
@@ -69,6 +70,7 @@ def create_pytorch_dataloader(
         shuffle,
         batch_size,
         drop_last,
+        storage_options,
     )

     return DataLoader(
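Taken together, the public API change is a single extra keyword on `create_pytorch_dataloader`. A usage sketch, assuming the function's existing `path`/`id_field`/`fields` parameters; the table URI, column names, and credentials are hypothetical placeholders:

from deltatorch import create_pytorch_dataloader, FieldSpec

# Build a DataLoader over a cloud-hosted Delta table. The storage_options
# dict is threaded through the dataset down to deltalake's DeltaTable.
dataloader = create_pytorch_dataloader(
    "s3://my-bucket/my-table",                        # placeholder URI
    id_field="id",                                    # placeholder id column
    fields=[FieldSpec("text"), FieldSpec("label")],   # placeholder columns
    batch_size=32,
    storage_options={
        "AWS_ACCESS_KEY_ID": "my-key-id",             # placeholder
        "AWS_SECRET_ACCESS_KEY": "my-secret-key",     # placeholder
    },
)

# Iterate batches as with any PyTorch DataLoader.
for batch in dataloader:
    ...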