Skip to content

Commit

Permalink
fix iterable dataset indices order (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy authored May 22, 2024
1 parent 2a3b018 commit 3b61899
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 123 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ runs/
.pytest_cache/
docs/build/
docs/source/examples/
docs/source/tutorials/
eggs/
*.egg*
lightning_logs/
Expand Down
14 changes: 13 additions & 1 deletion cellarium/ml/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: BSD-3-Clause


from typing import Literal

import lightning.pytorch as pl
import torch
from anndata import AnnData
Expand Down Expand Up @@ -32,6 +34,7 @@ class CellariumAnnDataDataModule(pl.LightningDataModule):
... "var_names_g": AnnDataField(attr="var_names"),
... },
... batch_size=5000,
... iteration_strategy="cache_efficient",
... shuffle=True,
... seed=0,
... drop_last=True,
Expand All @@ -51,6 +54,10 @@ class CellariumAnnDataDataModule(pl.LightningDataModule):
:class:`cellarium.ml.utilities.data.AnnDataField`.
batch_size:
How many samples per batch to load.
iteration_strategy:
Strategy to use for iterating through the dataset. Options are ``same_order`` and ``cache_efficient``.
``same_order`` will iterate through the dataset in the same order independent of the number of replicas
and workers. ``cache_efficient`` will try to minimize the number of anndata files fetched by each worker.
shuffle:
If ``True``, the data is reshuffled at every epoch.
seed:
Expand Down Expand Up @@ -82,6 +89,7 @@ def __init__(
# IterableDistributedAnnDataCollectionDataset args
batch_keys: dict[str, AnnDataField] | None = None,
batch_size: int = 1,
iteration_strategy: Literal["same_order", "cache_efficient"] = "cache_efficient",
shuffle: bool = False,
seed: int = 0,
drop_last: bool = False,
Expand All @@ -100,6 +108,7 @@ def __init__(
# IterableDistributedAnnDataCollectionDataset args
self.batch_keys = batch_keys or {}
self.batch_size = batch_size
self.iteration_strategy = iteration_strategy
self.shuffle = shuffle
self.seed = seed
self.drop_last = drop_last
Expand All @@ -114,14 +123,15 @@ def setup(self, stage: str | None = None) -> None:
setup is called from every process across all the nodes. Setting state here is recommended.
.. note::
:attr:`val_dataset` is not shuffled.
:attr:`val_dataset` is not shuffled and uses the ``same_order`` iteration strategy.
"""
if stage == "fit":
self.train_dataset = IterableDistributedAnnDataCollectionDataset(
dadc=self.dadc,
batch_keys=self.batch_keys,
batch_size=self.batch_size,
iteration_strategy=self.iteration_strategy,
shuffle=self.shuffle,
seed=self.seed,
drop_last=self.drop_last,
Expand All @@ -133,6 +143,7 @@ def setup(self, stage: str | None = None) -> None:
dadc=self.dadc,
batch_keys=self.batch_keys,
batch_size=self.batch_size,
iteration_strategy="same_order",
shuffle=False,
seed=self.seed,
drop_last=False,
Expand All @@ -146,6 +157,7 @@ def setup(self, stage: str | None = None) -> None:
dadc=self.dadc,
batch_keys=self.batch_keys,
batch_size=self.batch_size,
iteration_strategy=self.iteration_strategy,
shuffle=self.shuffle,
seed=self.seed,
drop_last=self.drop_last,
Expand Down
Loading

0 comments on commit 3b61899

Please sign in to comment.