Skip to content

Commit

Permalink
Update streaming kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
Skylion007 committed Jan 4, 2024
1 parent 4cbdec1 commit 6ecb798
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/benchmarks/bert/src/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class StreamingTextDataset(StreamingDataset):
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
`False``.
samples_per_epoch (int, optional): Provide this field iff you are weighting sub-datasets
epoch_size (int, optional): Provide this field iff you are weighting sub-datasets
proportionally. Defaults to ``None``.
predownload (int, optional): Target number of samples ahead to download the shards of while
iterating. Defaults to ``100_000``.
Expand All @@ -96,7 +96,7 @@ def __init__(self,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
samples_per_epoch: Optional[int] = None,
epoch_size: Optional[int] = None,
predownload: int = 100_000,
partition_algo: str = 'orig',
num_canonical_nodes: Optional[int] = None,
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(self,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip,
samples_per_epoch=samples_per_epoch,
epoch_size=epoch_size,
predownload=predownload,
partition_algo=partition_algo,
num_canonical_nodes=num_canonical_nodes,
Expand Down Expand Up @@ -275,7 +275,7 @@ def build_text_dataloader(
download_timeout=cfg.dataset.get('download_timeout', 60),
validate_hash=cfg.dataset.get('validate_hash', None),
keep_zip=cfg.dataset.get('keep_zip', False),
samples_per_epoch=cfg.dataset.get('samples_per_epoch', None),
epoch_size=cfg.dataset.get('epoch_size', None),
predownload=cfg.dataset.get('predownload', 100_000),
partition_algo=cfg.dataset.get('partition_algo', 'orig'),
num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', 128),
Expand Down

0 comments on commit 6ecb798

Please sign in to comment.