Commit fb43614
Parquet source: When pseudo_shuffle=True, limit the number of shards we read from (databricks#827)
dsmilkov authored Nov 7, 2023
1 parent 6189fc5 commit fb43614
Showing 6 changed files with 63 additions and 69 deletions.
3 changes: 2 additions & 1 deletion .vscode/extensions.json
@@ -18,7 +18,8 @@
     "ZixuanChen.vitest-explorer",
     "ryanluker.vscode-coverage-gutters",
     "bradlc.vscode-tailwindcss",
-    "svelte.svelte-vscode"
+    "svelte.svelte-vscode",
+    "ms-python.mypy-type-checker"
   ],
   // List of extensions recommended by VS Code that should not be recommended for users of this workspace.
   "unwantedRecommendations": []
1 change: 0 additions & 1 deletion .vscode/settings.json
@@ -61,7 +61,6 @@
   "eslint.workingDirectories": ["auto"],
   "eslint.validate": ["typescript", "svelte"],
   "python.envFile": "${workspaceFolder}/.venv",
-  "python.linting.mypyEnabled": true,
   "python.defaultInterpreterPath": "${workspaceFolder}/.venv/bin/python",
   "git.enableSmartCommit": true,
   "git.confirmSync": false,
12 changes: 7 additions & 5 deletions docs/datasets/dataset_load.md
@@ -117,17 +117,19 @@ use a glob pattern to load multiple files.
 The `ParquetSource` takes a few optional arguments related to sampling:
 
 - `sample_size`, the number of rows to sample.
-- `approximate_shuffle`, defaulting to `False`. When `False`, we take an entire pass over the
-  dataset with reservoir sampling. When `True`, we read a fraction of rows from the start of each
-  shard, to avoid shard skew, without doing a full pass over the entire dataset. This is useful when
-  your dataset is very large and consists of a large number of shards.
+- `pseudo_shuffle`, defaulting to `False`. When `False`, we take an entire pass over the dataset
+  with reservoir sampling. When `True`, we read a fraction of rows from the start of each shard, to
+  avoid shard skew, without doing a full pass over the entire dataset. This is useful when your
+  dataset is very large and consists of a large number of shards.
+- `pseudo_shuffle_num_shards`, the maximum number of shards to read from when `pseudo_shuffle` is
+  `True`. Defaults to `10`.
 - `seed`, the random seed to use for sampling.
 
 ```python
 source = ll.ParquetSource(
   filepaths=['s3://lilac-public-data/test-*.parquet'],
   sample_size=100,
-  approximate_shuffle=True)
+  pseudo_shuffle=True)
 config = ll.DatasetConfig(namespace='local', name='parquet-test', source=source)
 dataset = ll.create_dataset(config)
 ```
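For intuition about the `pseudo_shuffle=False` path the docs mention, here is a minimal sketch of one-pass reservoir sampling (Algorithm R). The function name and the plain-iterable interface are illustrative assumptions, not Lilac's internal API.

```python
import random
from typing import Iterable, Optional, TypeVar

T = TypeVar('T')

def reservoir_sample(stream: Iterable[T], k: int, seed: Optional[int] = None) -> list[T]:
  """Uniformly sample k items in a single pass over a stream of unknown length."""
  rng = random.Random(seed)
  reservoir: list[T] = []
  for i, item in enumerate(stream):
    if i < k:
      reservoir.append(item)  # Fill the reservoir with the first k items.
    else:
      j = rng.randint(0, i)  # Item i survives with probability k / (i + 1).
      if j < k:
        reservoir[j] = item
  return reservoir

assert len(reservoir_sample(range(1000), k=20, seed=42)) == 20
```

Every item ends up in the sample with equal probability, which is why a single pass suffices; the trade-off is that the pass must touch the whole dataset, which `pseudo_shuffle=True` avoids.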
22 changes: 14 additions & 8 deletions lilac/sources/parquet_source.py
@@ -34,11 +34,14 @@ class ParquetSource(Source):
   sample_size: Optional[int] = Field(
     title='Sample size', description='Number of rows to sample from the dataset', default=None
   )
-  approximate_shuffle: bool = Field(
+  pseudo_shuffle: bool = Field(
     default=False,
     description='If true, the reader will read a fraction of rows from each shard, '
     'avoiding a pass over the entire dataset.',
   )
+  pseudo_shuffle_num_shards: int = Field(
+    default=10, description='Number of shards to sample from when using pseudo shuffle.'
+  )
 
   _source_schema: Optional[SourceSchema] = None
   _readers: list[pa.RecordBatchReader] = []
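A quick illustration of how the new field interacts with the validator in the next hunk. This is a hypothetical snippet: it assumes `ParquetSource` can be constructed directly from `lilac.sources.parquet_source` and that pydantic wraps the validator's `ValueError` as usual.

```python
from pydantic import ValidationError

from lilac.sources.parquet_source import ParquetSource

# pseudo_shuffle without sample_size should be rejected by the field validator.
try:
  ParquetSource(filepaths=['data-*.parquet'], pseudo_shuffle=True)
except ValidationError as err:
  print(err)  # `pseudo_shuffle` requires `sample_size` to be set.

# With a sample size it validates, and at most 10 shards are read by default.
source = ParquetSource(filepaths=['data-*.parquet'], pseudo_shuffle=True, sample_size=100)
assert source.pseudo_shuffle_num_shards == 10
```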
@@ -60,23 +63,26 @@ def validate_sample_size(cls, sample_size: int) -> int:
       raise ValueError('sample_size must be greater than 0.')
     return sample_size
 
-  @field_validator('approximate_shuffle')
+  @field_validator('pseudo_shuffle')
   @classmethod
-  def validate_approximate_shuffle(cls, approximate_shuffle: bool, info: ValidationInfo) -> bool:
+  def validate_pseudo_shuffle(cls, pseudo_shuffle: bool, info: ValidationInfo) -> bool:
     """Validate shuffle before sampling."""
-    if approximate_shuffle and not info.data['sample_size']:
-      raise ValueError('`approximate_shuffle` requires `sample_size` to be set.')
-    return approximate_shuffle
+    if pseudo_shuffle and not info.data['sample_size']:
+      raise ValueError('`pseudo_shuffle` requires `sample_size` to be set.')
+    return pseudo_shuffle
 
   def _setup_sampling(self, duckdb_paths: list[str]) -> Schema:
     assert self._con, 'setup() must be called first.'
-    if self.approximate_shuffle:
-      assert self.sample_size, 'approximate_shuffle requires sample_size to be set.'
+    if self.pseudo_shuffle:
+      assert self.sample_size, 'pseudo_shuffle requires sample_size to be set.'
       # Find each individual file.
       glob_rows: list[tuple[str]] = self._con.execute(
         f'SELECT * FROM GLOB({duckdb_paths})'
       ).fetchall()
       duckdb_files: list[str] = list(set([row[0] for row in glob_rows]))
+      # Sub-sample shards so we don't open too many files.
+      num_shards = min(self.pseudo_shuffle_num_shards, len(duckdb_files))
+      duckdb_files = random.sample(duckdb_files, num_shards)
       batch_size = max(1, min(self.sample_size // len(duckdb_files), ROWS_PER_BATCH_READ))
       for duckdb_file in duckdb_files:
         # Since we are not fetching the entire results immediately, we need a seperate cursor
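To make the new sub-sampling arithmetic concrete, below is a self-contained sketch of the shard-selection and per-shard batch-size computation. The helper name, the seeded `random.Random`, and the `ROWS_PER_BATCH_READ` value are assumptions for illustration; the real logic lives in `_setup_sampling` above.

```python
import random
from typing import Optional

ROWS_PER_BATCH_READ = 50_000  # Assumed cap; the real constant is defined elsewhere in lilac.

def plan_pseudo_shuffle(
  files: list[str], sample_size: int, max_shards: int = 10, seed: Optional[int] = None
) -> tuple[list[str], int]:
  """Pick at most max_shards shards and split the sample evenly across them."""
  rng = random.Random(seed)
  num_shards = min(max_shards, len(files))
  chosen = rng.sample(files, num_shards)
  # Each chosen shard contributes an equal slice of the requested sample,
  # capped by the per-batch read limit and floored at one row per shard.
  batch_size = max(1, min(sample_size // len(chosen), ROWS_PER_BATCH_READ))
  return chosen, batch_size

files = [f'test-{i}.parquet' for i in range(40)]
shards, batch = plan_pseudo_shuffle(files, sample_size=100, seed=42)
assert len(shards) == 10 and batch == 10  # 100 rows split across 10 shards.
```

Capping the shard count is what this commit adds: with thousands of shards, reading from every file would open too many handles and shrink each shard's slice toward a single row.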
16 changes: 8 additions & 8 deletions lilac/sources/parquet_source_test.py
@@ -48,7 +48,7 @@ def test_single_shard_with_sampling(tmp_path: pathlib.Path) -> None:
     assert len(items) == min(sample_size, len(source_items))
 
 
-def test_single_shard_approximate_shuffle(tmp_path: pathlib.Path) -> None:
+def test_single_shard_pseudo_shuffle(tmp_path: pathlib.Path) -> None:
   source_items = [{'name': 'a', 'age': 1}, {'name': 'b', 'age': 2}, {'name': 'c', 'age': 3}]
   table = pa.Table.from_pylist(source_items)
 
@@ -57,7 +57,7 @@ def test_single_shard_approximate_shuffle(tmp_path: pathlib.Path) -> None:
 
   # Test sampling with different sample sizes, including sample size > num_items.
   for sample_size in range(1, 5):
-    source = ParquetSource(filepaths=[out_file], sample_size=sample_size, approximate_shuffle=True)
+    source = ParquetSource(filepaths=[out_file], sample_size=sample_size, pseudo_shuffle=True)
     source.setup()
     items = list(source.process())
     assert len(items) == min(sample_size, len(source_items))
@@ -103,30 +103,30 @@ def test_multi_shard_approx_shuffle(tmp_path: pathlib.Path) -> None:
   for sample_size in range(1, 5):
     source = ParquetSource(
       filepaths=[str(tmp_path / 'test-*.parquet')],
-      approximate_shuffle=True,
+      pseudo_shuffle=True,
       sample_size=sample_size,
     )
     source.setup()
     items = list(source.process())
     assert len(items) == min(sample_size, len(source_items))
 
 
-def test_uniform_shards_approximate_shuffle(tmp_path: pathlib.Path) -> None:
+def test_uniform_shards_pseudo_shuffle(tmp_path: pathlib.Path) -> None:
   source_items = [{'index': i} for i in range(100)]
   for i, chunk in enumerate(chunks(source_items, 10)):
     table = pa.Table.from_pylist(chunk)
     out_file = tmp_path / f'test-{i}.parquet'
     pq.write_table(table, out_file)
 
   source = ParquetSource(
-    filepaths=[str(tmp_path / 'test-*.parquet')], approximate_shuffle=True, sample_size=20
+    filepaths=[str(tmp_path / 'test-*.parquet')], pseudo_shuffle=True, sample_size=20
   )
   source.setup()
   items = list(source.process())
   assert len(items) == 20
 
 
-def test_nonuniform_shards_approximate_shuffle(tmp_path: pathlib.Path) -> None:
+def test_nonuniform_shards_pseudo_shuffle(tmp_path: pathlib.Path) -> None:
   source_items = [{'index': i} for i in range(100)]
   shard_sizes = [49, 1, 40, 10]
   for i, shard_size in enumerate(shard_sizes):
@@ -137,7 +137,7 @@ def test_nonuniform_shards_approximate_shuffle(tmp_path: pathlib.Path) -> None:
     pq.write_table(table, out_file)
 
   source = ParquetSource(
-    filepaths=[str(tmp_path / 'test-*.parquet')], approximate_shuffle=True, sample_size=20
+    filepaths=[str(tmp_path / 'test-*.parquet')], pseudo_shuffle=True, sample_size=20
   )
   source.setup()
   items = list(source.process())
@@ -165,7 +165,7 @@ def test_approx_shuffle_with_seed(tmp_path: pathlib.Path) -> None:
     pq.write_table(table, out_file)
 
   source = ParquetSource(
-    filepaths=[str(tmp_path / 'test-*.parquet')], approximate_shuffle=True, sample_size=20, seed=42
+    filepaths=[str(tmp_path / 'test-*.parquet')], pseudo_shuffle=True, sample_size=20, seed=42
   )
   source.setup()
   items = list(source.process())
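The assertion in this test is truncated above, but the test name suggests a determinism check. A hypothetical sketch of what seeded reproducibility might look like (assumed behavior, not the test's actual body):

```python
# Two sources built with the same seed should sample the same rows (assumption).
source_a = ParquetSource(
  filepaths=[str(tmp_path / 'test-*.parquet')], pseudo_shuffle=True, sample_size=20, seed=42
)
source_b = ParquetSource(
  filepaths=[str(tmp_path / 'test-*.parquet')], pseudo_shuffle=True, sample_size=20, seed=42
)
source_a.setup()
source_b.setup()
assert list(source_a.process()) == list(source_b.process())
```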