Project main rework #99

Merged · 8 commits · Dec 19, 2024
Changes from 6 commits
9 changes: 4 additions & 5 deletions docs/profiling_test.py
@@ -14,10 +14,9 @@
     datamodule_config,
     experiment_dictconfig,
 )
-from project.experiment import (
+from project.experiment import instantiate_datamodule, instantiate_trainer
+from project.main import (
     instantiate_algorithm,
-    instantiate_datamodule,
-    instantiate_trainer,
     setup_logging,
 )
 from project.utils.hydra_utils import resolve_dictconfig
@@ -121,8 +120,8 @@ def test_notebook_commands_dont_cause_errors(experiment_dictconfig: DictConfig):
     # _experiment = _setup_experiment(config)
     setup_logging(log_level=config.log_level)
     lightning.seed_everything(config.seed, workers=True)
-    _trainer = instantiate_trainer(config)
+    _trainer = instantiate_trainer(config.trainer)
     datamodule = instantiate_datamodule(config.datamodule)
-    _algorithm = instantiate_algorithm(config.algorithm, datamodule=datamodule)
+    _algorithm = instantiate_algorithm(config, datamodule=datamodule)

     # Note: Here we don't actually do anything with the objects.
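The call sites above capture the core of this rework: `instantiate_trainer` now receives only the `trainer` sub-config, while `instantiate_algorithm` receives the full config together with the datamodule. A minimal sketch of the new setup flow, assuming the signatures implied by these call sites (`experiment_dictconfig` stands in for a Hydra-composed `DictConfig`):

```python
# Sketch of the reworked setup flow, mirroring the test above. The exact
# signatures in project.experiment / project.main are inferred from this diff.
import lightning

from project.experiment import instantiate_datamodule, instantiate_trainer
from project.main import instantiate_algorithm, setup_logging
from project.utils.hydra_utils import resolve_dictconfig

config = resolve_dictconfig(experiment_dictconfig)  # DictConfig -> structured Config
setup_logging(log_level=config.log_level)
lightning.seed_everything(config.seed, workers=True)

trainer = instantiate_trainer(config.trainer)  # trainer sub-config only
datamodule = instantiate_datamodule(config.datamodule)
algorithm = instantiate_algorithm(config, datamodule=datamodule)  # full config
```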
13 changes: 0 additions & 13 deletions project/algorithms/__init__.py
@@ -1,13 +0,0 @@
-from .image_classifier import ImageClassifier
-from .jax_image_classifier import JaxImageClassifier
-from .jax_ppo import JaxRLExample
-from .no_op import NoOp
-from .text_classifier import TextClassifier
-
-__all__ = [
-    "ImageClassifier",
-    "JaxImageClassifier",
-    "NoOp",
-    "TextClassifier",
-    "JaxRLExample",
-]
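Removing these re-exports presumably means that `import project.algorithms` no longer imports every algorithm module (and their heavier dependencies, e.g. jax for `jax_ppo`) as a side effect; callers now import each class from its defining module, which is consistent with the fully-qualified `_target_` paths introduced in the datamodule configs below.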
33 changes: 33 additions & 0 deletions project/algorithms/jax_ppo.py
@@ -7,7 +7,9 @@
 from __future__ import annotations

 import contextlib
+import dataclasses
 import functools
+import operator
 from collections.abc import Callable, Sequence
 from logging import getLogger as get_logger
 from pathlib import Path
@@ -36,6 +38,8 @@
 from typing_extensions import TypeVar
 from xtils.jitpp import Static

+from project import experiment
+from project.configs.config import Config
 from project.trainers.jax_trainer import JaxCallback, JaxModule, JaxTrainer
 from project.utils.typing_utils.jax_typing_utils import field, jit

@@ -826,3 +830,32 @@ def on_train_epoch_start(self, trainer: JaxTrainer, module: JaxRLExample, ts: PP
         gif_path = Path(log_dir) / f"epoch_{ts.data_collection_state.global_step:05}.gif"
         module.visualize(ts=ts, gif_path=gif_path)
         jax.debug.print("Saved gif to {gif_path}", gif_path=gif_path)
+
+
+@experiment.evaluate.register
+def evaluate_ppo_example(
+    algorithm: JaxRLExample,
+    /,
+    *,
+    trainer: JaxTrainer,
+    train_results: tuple[PPOState, EvalMetrics],
+    config: Config,
+    datamodule: None = None,
+):
+    """Override for the `evaluate` function used by `main.py`, in the case of this algorithm."""
+    # todo: there isn't yet a `validate` method on the jax trainer.
+    assert isinstance(algorithm, JaxModule)
+    assert isinstance(trainer, JaxTrainer)
+    assert train_results is not None
+    metrics = train_results[1]
+
+    last_epoch_metrics = jax.tree.map(operator.itemgetter(-1), metrics)
+    assert isinstance(last_epoch_metrics, EvalMetrics)
+    # Average across eval seeds (we're doing evaluation in multiple environments in parallel with
+    # vmap).
+    last_epoch_average_cumulative_reward = last_epoch_metrics.cumulative_reward.mean().item()
+    return (
+        "-avg_cumulative_reward",
+        -last_epoch_average_cumulative_reward,  # need to return an "error" to minimize for HPO.
+        dataclasses.asdict(last_epoch_metrics),
+    )
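The `@experiment.evaluate.register` decorator above implies that `project.experiment.evaluate` dispatches on the type of its first argument, in the style of `functools.singledispatch`. A self-contained sketch of that pattern; the names `evaluate`, `MyRLAlgo`, and the returned values are illustrative, not the project's actual code:

```python
# Illustrative singledispatch pattern: a generic `evaluate` plus a per-type
# override registered via the annotation on the first positional argument.
import functools


@functools.singledispatch
def evaluate(algorithm, /, **kwargs) -> tuple[str, float, dict]:
    """Default evaluation: returns (metric_name, error_to_minimize, metrics)."""
    return "val/loss", 0.0, {}


class MyRLAlgo:
    """Stand-in for an algorithm that needs custom evaluation (like JaxRLExample)."""


@evaluate.register
def _evaluate_rl(algorithm: MyRLAlgo, /, **kwargs) -> tuple[str, float, dict]:
    # HPO frameworks minimize an "error", so a reward is returned negated.
    avg_reward = 123.4  # would come from the training results
    return "-avg_cumulative_reward", -avg_reward, {"cumulative_reward": avg_reward}


print(evaluate(object())[0])    # "val/loss" (default implementation)
print(evaluate(MyRLAlgo())[0])  # "-avg_cumulative_reward" (registered override)
```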
9 changes: 6 additions & 3 deletions project/algorithms/testsuites/lightning_module_tests.py
@@ -22,7 +22,8 @@

 from project.configs.config import Config
 from project.conftest import DEFAULT_SEED
-from project.experiment import instantiate_algorithm, instantiate_trainer, setup_logging
+from project.experiment import instantiate_trainer
+from project.main import instantiate_algorithm, setup_logging
 from project.trainers.jax_trainer import JaxTrainer
 from project.utils.hydra_utils import resolve_dictconfig
 from project.utils.typing_utils import PyTree, is_sequence_of
@@ -47,6 +48,8 @@ class LightningModuleTests(Generic[AlgorithmType], ABC):
     - Dataset splits: check some basic stats about the train/val/test inputs, are they somewhat similar?
     - Define the input as a space, check that the dataset samples are in that space and not too
       many samples are statistically OOD?
+    - Test to monitor distributed traffic out of this process?
+    - Dummy two-process tests (on CPU) to check before scaling up experiments?
     """

     # algorithm_config: ParametrizedFixture[str]
@@ -67,7 +70,7 @@ def trainer(
     ) -> lightning.Trainer | JaxTrainer:
         setup_logging(log_level=experiment_config.log_level)
         lightning.seed_everything(experiment_config.seed, workers=True)
-        return instantiate_trainer(experiment_config)
+        return instantiate_trainer(experiment_config.trainer)

     @pytest.fixture(scope="class")
     def algorithm(
@@ -79,7 +82,7 @@ def algorithm(
     ):
         """Fixture that creates the "algorithm" (a
         [LightningModule][lightning.pytorch.core.module.LightningModule])."""
-        algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
+        algorithm = instantiate_algorithm(experiment_config, datamodule=datamodule)
         if isinstance(trainer, lightning.Trainer) and isinstance(
             algorithm, lightning.LightningModule
         ):
1 change: 1 addition & 0 deletions project/algorithms/text_classifier.py
@@ -30,6 +30,7 @@ def __init__(
         init_seed: int = 42,
     ):
         super().__init__()
+        self.datamodule = datamodule
         self.network_config = network
         self.num_labels = datamodule.num_classes
         self.task_name = datamodule.task_name
2 changes: 1 addition & 1 deletion project/configs/config.py
@@ -27,7 +27,7 @@ class Config:
     It is suggested for this class to accept a `datamodule` and `network` as arguments. The
     instantiated datamodule and network will be passed to the algorithm's constructor.

-    For more info, see the [instantiate_algorithm][project.experiment.instantiate_algorithm] function.
+    For more info, see the [instantiate_algorithm][project.main.instantiate_algorithm] function.
     """

     datamodule: Any | None = None
3 changes: 2 additions & 1 deletion project/configs/config.yaml
@@ -1,7 +1,7 @@
 defaults:
   - base_config
   - _self_
-  - algorithm: ???
+  - algorithm: null
   - optional datamodule: null
   - trainer: default.yaml
   - hydra: default.yaml
@@ -12,4 +12,5 @@ defaults:
   # experiment configs allow for version control of specific hyperparameters
   # e.g. best hyperparameters for given model and datamodule
   - experiment: null
+# This is a good default name to use when you aren't doing a sweep. Otherwise it causes an error.
 # name: "${hydra:runtime.choices.algorithm}-${hydra:runtime.choices.network}-${hydra:runtime.choices.datamodule}"
4 changes: 0 additions & 4 deletions project/configs/datamodule/__init__.py
@@ -4,10 +4,6 @@

 logger = get_logger(__name__)

-
-# TODO: Make it possible to extend a structured base via yaml files as well as adding new fields
-# (for example, ImagetNet32DataModule has a new constructor argument which can't be set atm in the
-# config).
 datamodule_store = store(group="datamodule")

2 changes: 1 addition & 1 deletion project/configs/datamodule/cifar10.yaml
@@ -1,7 +1,7 @@
 defaults:
   - vision
   - _self_
-_target_: project.datamodules.CIFAR10DataModule
+_target_: project.datamodules.image_classification.cifar10.CIFAR10DataModule
 data_dir: ${constant:torchvision_dir,DATA_DIR}
 batch_size: 128
 train_transforms:
2 changes: 1 addition & 1 deletion project/configs/datamodule/fashion_mnist.yaml
@@ -1,4 +1,4 @@
 defaults:
   - mnist
   - _self_
-_target_: project.datamodules.FashionMNISTDataModule
+_target_: project.datamodules.image_classification.fashion_mnist.FashionMNISTDataModule
2 changes: 1 addition & 1 deletion project/configs/datamodule/glue_cola.yaml
@@ -1,4 +1,4 @@
-_target_: project.datamodules.text.TextClassificationDataModule
+_target_: project.datamodules.text.text_classification.TextClassificationDataModule
 data_dir: ${oc.env:SCRATCH,.}/data
 hf_dataset_path: glue
 task_name: cola
2 changes: 1 addition & 1 deletion project/configs/datamodule/imagenet.yaml
@@ -1,5 +1,5 @@
 defaults:
   - vision
   - _self_
-_target_: project.datamodules.ImageNetDataModule
+_target_: project.datamodules.image_classification.imagenet.ImageNetDataModule
 # todo: add good configuration options here.
2 changes: 1 addition & 1 deletion project/configs/datamodule/inaturalist.yaml
@@ -1,6 +1,6 @@
 defaults:
   - vision
   - _self_
-_target_: project.datamodules.INaturalistDataModule
+_target_: project.datamodules.image_classification.inaturalist.INaturalistDataModule
 version: "2021_train"
 target_type: "full"
2 changes: 1 addition & 1 deletion project/configs/datamodule/mnist.yaml
@@ -1,7 +1,7 @@
 defaults:
   - vision
   - _self_
-_target_: project.datamodules.MNISTDataModule
+_target_: project.datamodules.image_classification.mnist.MNISTDataModule
 data_dir: ${constant:torchvision_dir,DATA_DIR}
 normalize: True
 batch_size: 128
2 changes: 1 addition & 1 deletion project/configs/datamodule/vision.yaml
@@ -1,5 +1,5 @@
 # todo: This config should not show up as an option on the command-line.
-_target_: project.datamodules.VisionDataModule
+_target_: project.datamodules.vision.VisionDataModule
 data_dir: ${constant:DATA_DIR}
 num_workers: ${constant:NUM_WORKERS}
 val_split: 0.1 # NOTE: reduced from default of 0.2
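All of the datamodule configs above now spell out the fully-qualified module path in `_target_` rather than relying on re-exports from the `project.datamodules` package. Hydra resolves `_target_` by importing that exact dotted path at instantiation time; a short sketch of the mechanism, assuming the project package is importable:

```python
# How Hydra resolves a fully-qualified _target_ (config mirrors cifar10.yaml).
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "project.datamodules.image_classification.cifar10.CIFAR10DataModule",
        "batch_size": 128,
    }
)

# instantiate() imports the module, looks up the class, and calls it with the
# remaining keys as keyword arguments:
datamodule = instantiate(cfg)  # equivalent to CIFAR10DataModule(batch_size=128)
```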
2 changes: 1 addition & 1 deletion project/configs/experiment/example.yaml
@@ -16,7 +16,7 @@ defaults:
 # The parameters below will be merged with parameters from default configurations set above.
 # This allows you to overwrite only specified parameters

-# The name of the e
+# The name of the experiment (for logging)
 name: example

 seed: ${oc.env:SLURM_PROCID,42}
12 changes: 6 additions & 6 deletions project/conftest.py
@@ -93,13 +93,12 @@

 from project.configs.config import Config
 from project.datamodules.vision import VisionDataModule, num_cpus_on_node
-from project.experiment import (
+from project.experiment import instantiate_datamodule, instantiate_trainer
+from project.main import (
+    PROJECT_NAME,
     instantiate_algorithm,
-    instantiate_datamodule,
-    instantiate_trainer,
     setup_logging,
 )
-from project.main import PROJECT_NAME
 from project.trainers.jax_trainer import JaxTrainer
 from project.utils.env_vars import REPO_ROOTDIR
 from project.utils.hydra_utils import resolve_dictconfig
@@ -332,7 +331,7 @@ def algorithm(
 ):
     """Fixture that creates the "algorithm" (a
     [LightningModule][lightning.pytorch.core.module.LightningModule])."""
-    algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
+    algorithm = instantiate_algorithm(experiment_config, datamodule=datamodule)
     if isinstance(trainer, lightning.Trainer) and isinstance(algorithm, lightning.LightningModule):
         with trainer.init_module(), device:
             # A bit hacky, but we have to do this because the lightningmodule isn't associated
@@ -347,8 +346,9 @@ def trainer(
     experiment_config: Config,
 ) -> pl.Trainer | JaxTrainer:
     setup_logging(log_level=experiment_config.log_level)
+    # put here to copy what's done in main.py
     lightning.seed_everything(experiment_config.seed, workers=True)
-    return instantiate_trainer(experiment_config)
+    return instantiate_trainer(experiment_config.trainer)


 @pytest.fixture(scope="session")