Migrates to LightningCLI #265

Draft
wants to merge 19 commits into base: master
README.md (21 changes: 12 additions & 9 deletions)
@@ -79,7 +79,6 @@ Training is performed by the [`yoyodyne-train`](yoyodyne/train.py) script. One
must specify the following required arguments:

- `--model_dir`: path for model metadata and checkpoints
- `--experiment`: name of experiment (pick something unique)
- `--train`: path to TSV file containing training data
- `--val`: path to TSV file containing validation data

@@ -108,7 +107,6 @@ One must specify the following required arguments:

- `--arch`: architecture, matching the one used for training
- `--model_dir`: path for model metadata
- `--experiment`: name of experiment
- `--checkpoint`: path to checkpoint
- `--predict`: path to file containing data to be predicted
- `--output`: path for predictions
@@ -162,12 +160,15 @@ provide any symbols of the form `<...>`, `[...]`, or `{...}`.
Checkpointing is handled by
[Lightning](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html).
The path for model information, including checkpoints, is specified by a
combination of `--model_dir` and `--experiment`, such that we build the path
`model_dir/experiment/version_n`, where each run of an experiment with the same
`model_dir` and `experiment` is namespaced with a new version number. A version
stores everything needed to reload the model, including the hyperparameters
(`model_dir/experiment_name/version_n/hparams.yaml`) and the checkpoints
directory (`model_dir/experiment_name/version_n/checkpoints`).
single argument, `--model_dir`, such that we build the path `model_dir/version_n`,
where each run of an experiment with the same `model_dir` is namespaced with a
new version number. A version stores everything needed to reload the model,
including:

- the index (`model_dir/index.pkl`),
- the hyperparameters (`model_dir/lightning_logs/version_n/hparams.yaml`),
- the metrics (`model_dir/lightning_logs/version_n/metrics.csv`), and
- the checkpoints (`model_dir/lightning_logs/version_n/checkpoints`).
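
For example, the artifacts of the newest run can be located as follows (a
minimal sketch, not taken from this repository; `models/example` stands in for
an actual `--model_dir` value):

```python
# Minimal sketch: walk the layout described above with the standard library.
from pathlib import Path

model_dir = Path("models/example")  # hypothetical --model_dir value
index = model_dir / "index.pkl"
versions = sorted(
    (model_dir / "lightning_logs").glob("version_*"),
    key=lambda path: int(path.name.rsplit("_", 1)[1]),
)
newest = versions[-1]  # assumes at least one completed run
hparams = newest / "hparams.yaml"
metrics = newest / "metrics.csv"
checkpoints = sorted((newest / "checkpoints").glob("*.ckpt"))
```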

By default, each run initializes a new model from scratch, unless the
`--train_from` argument is specified. To continue training from a specific
@@ -287,7 +288,9 @@ A non-exhaustive list includes:
- Seeding:
- `--seed`
- [Weights & Biases](https://wandb.ai/site):
- `--log_wandb` (default: `False`): enables Weights & Biases tracking
- `--log_wandb` (default: `False`): enables Weights & Biases tracking; the
"project" name can be specified using the environmental variable
`$WANDB_PROJECT`.
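
For illustration, the variable can also be set from Python before training is
launched (a minimal sketch; the project name below is a placeholder):

```python
# Hedged sketch: pick the Weights & Biases project via the environment.
import os

os.environ["WANDB_PROJECT"] = "my-yoyodyne-project"  # placeholder name
```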

Additional training options are discussed below.

pyproject.toml (9 changes: 3 additions & 6 deletions)
@@ -11,7 +11,7 @@ exclude = ["examples*"]

[project]
name = "yoyodyne"
version = "0.2.14"
version = "0.3.0"
description = "Small-vocabulary neural sequence-to-sequence models"
readme = "README.md"
requires-python = ">= 3.9"
@@ -29,9 +29,7 @@ keywords = [
]
dependencies = [
"maxwell >= 0.2.5",
# TODO: allow >= 2.0.0 once we we migrate to lightning >= 2.0.0".
"numpy >= 1.26.0, < 2.0.0",
"lightning >= 1.7.0, < 2.0.0",
"lightning >= 2.4.0",
"torch >= 2.5.1",
"wandb >= 0.18.5",
]
@@ -49,8 +47,7 @@ classifiers = [
]

[project.scripts]
yoyodyne-predict = "yoyodyne.predict:main"
yoyodyne-train = "yoyodyne.train:main"
yoyodyne = "yoyodyne.cli:main"

[project.urls]
homepage = "https://github.com/CUNY-CL/yoyodyne"
requirements.txt (6 changes: 3 additions & 3 deletions)
@@ -1,9 +1,9 @@
black>=24.10.0
build>=1.2.1
flake8>=7.1.0
lightning>=1.7.0,<2.0.0
lightning>=2.4.0
maxwell>=0.2.5
numpy>=1.26.0,<2.0.0
numpy>=2.0.0
pandas>=2.2.2
pytest>=8.3.2
scipy>=1.13.1
@@ -12,4 +12,4 @@ torch>=2.5.1
tqdm>=4.66.6
twine>=5.1.1
wandb>=0.18.5
wheel>=0.40.0
wheel>=0.45.0
yoyodyne/cli.py (empty file added)
yoyodyne/data/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -7,6 +7,8 @@
from .datamodules import DataModule # noqa: F401
from .datasets import Dataset # noqa: F401
from .indexes import Index # noqa: F401
from .mappers import Mapper # noqa: F401
from .tsv import TsvParser # noqa: F401


def add_argparse_args(parser: argparse.ArgumentParser) -> None:
yoyodyne/data/collators.py (1 change: 1 addition & 0 deletions)
@@ -2,6 +2,7 @@

import argparse
import dataclasses

from typing import List

import torch
yoyodyne/data/datamodules.py (112 changes: 69 additions & 43 deletions)
@@ -6,47 +6,60 @@
from torch.utils import data

from .. import defaults, util
from . import collators, datasets, indexes, tsv
from . import collators, datasets, indexes, mappers, tsv


class DataModule(lightning.LightningDataModule):
"""Parses, indexes, collates and loads data.
"""Data module.

This class is initialized by the LightningCLI interface. It manages all
data loading steps.

Args:
model_dir: Path for checkpoints, indexes, and logs.
predict

The batch size tuner is permitted to mutate the `batch_size` argument.
"""

predict: Optional[str]
test: Optional[str]
train: Optional[str]
val: Optional[str]
parser: tsv.TsvParser
index: indexes.Index
separate_features: bool
batch_size: int
index: indexes.Index
collator: collators.Collator

def __init__(
self,
# Paths.
*,
train: Optional[str] = None,
val: Optional[str] = None,
predict: Optional[str] = None,
test: Optional[str] = None,
index_path: Optional[str] = None,
# TSV parsing arguments.
model_dir: str,
predict=None,
train=None,
test=None,
val=None,
# TSV parsing options.
source_col: int = defaults.SOURCE_COL,
features_col: int = defaults.FEATURES_COL,
target_col: int = defaults.TARGET_COL,
# String parsing arguments.
source_sep: str = defaults.SOURCE_SEP,
features_sep: str = defaults.FEATURES_SEP,
target_sep: str = defaults.TARGET_SEP,
# Modeling options.
separate_features: bool = False,
tie_embeddings: bool = defaults.TIE_EMBEDDINGS,
# Collator options.
# Other.
batch_size: int = defaults.BATCH_SIZE,
separate_features: bool = False,
max_source_length: int = defaults.MAX_SOURCE_LENGTH,
max_target_length: int = defaults.MAX_TARGET_LENGTH,
# Indexing.
index: Optional[indexes.Index] = None,
):
super().__init__()
self.train = train
self.val = val
self.predict = predict
self.test = test
self.parser = tsv.TsvParser(
source_col=source_col,
features_col=features_col,
@@ -56,14 +69,15 @@ def __init__(
target_sep=target_sep,
tie_embeddings=tie_embeddings,
)
self.tie_embeddings = tie_embeddings
self.train = train
self.val = val
self.predict = predict
self.test = test
self.batch_size = batch_size
self.separate_features = separate_features
self.index = index if index is not None else self._make_index()
self.batch_size = batch_size
# If the training data is specified, it is used to create (or recreate)
# the index; if not specified it is read from the model directory.
self.index = (
self._make_index(model_dir, tie_embeddings)
if self.train
else indexes.Index.read(model_dir)
)
self.collator = collators.Collator(
has_features=self.has_features,
has_target=self.has_target,
@@ -72,11 +86,16 @@ def __init__(
max_target_length=max_target_length,
)

def _make_index(self) -> indexes.Index:
# Computes index.
def _make_index(
self, model_dir: str, tie_embeddings: bool
) -> indexes.Index:
source_vocabulary: Set[str] = set()
features_vocabulary: Set[str] = set()
target_vocabulary: Set[str] = set()
features_vocabulary: Optional[Set[str]] = (
set() if self.has_features else None
)
target_vocabulary: Optional[Set[str]] = (
set() if self.has_target else None
)
if self.has_features:
if self.has_target:
for source, features, target in self.parser.samples(
@@ -96,16 +115,18 @@ def _make_index(self) -> indexes.Index:
else:
for source in self.parser.samples(self.train):
source_vocabulary.update(source)
return indexes.Index(
source_vocabulary=sorted(source_vocabulary),
index = indexes.Index(
source_vocabulary=source_vocabulary,
features_vocabulary=(
sorted(features_vocabulary) if features_vocabulary else None
),
target_vocabulary=(
sorted(target_vocabulary) if target_vocabulary else None
features_vocabulary if features_vocabulary else None
),
tie_embeddings=self.tie_embeddings,
target_vocabulary=target_vocabulary if target_vocabulary else None,
tie_embeddings=tie_embeddings,
)
index.write(model_dir)
return index

# Logging.

@staticmethod
def pprint(vocabulary: Iterable) -> str:
@@ -128,9 +149,7 @@ def log_vocabularies(self) -> None:
f"{self.pprint(self.index.target_vocabulary)}"
)

def write_index(self, model_dir: str, experiment: str) -> None:
"""Writes the index."""
self.index.write(model_dir, experiment)
# Properties.

@property
def has_features(self) -> bool:
@@ -149,13 +168,6 @@ def source_vocab_size(self) -> int:
self.index.source_vocab_size + self.index.features_vocab_size
)

def _dataset(self, path: str) -> datasets.Dataset:
return datasets.Dataset(
list(self.parser.samples(path)),
self.index,
self.parser,
)

# Required API.

def train_dataloader(self) -> data.DataLoader:
@@ -166,6 +178,7 @@ def train_dataloader(self) -> data.DataLoader:
batch_size=self.batch_size,
shuffle=True,
num_workers=1,
persistent_workers=True,
)

def val_dataloader(self) -> data.DataLoader:
@@ -174,7 +187,9 @@ def val_dataloader(self) -> data.DataLoader:
self._dataset(self.val),
collate_fn=self.collator,
batch_size=self.batch_size,
shuffle=False,
num_workers=1,
persistent_workers=True,
)

def predict_dataloader(self) -> data.DataLoader:
@@ -183,7 +198,9 @@ def predict_dataloader(self) -> data.DataLoader:
self._dataset(self.predict),
collate_fn=self.collator,
batch_size=self.batch_size,
shuffle=False,
num_workers=1,
persistent_workers=True,
)

def test_dataloader(self) -> data.DataLoader:
@@ -192,5 +209,14 @@ def test_dataloader(self) -> data.DataLoader:
self._dataset(self.test),
collate_fn=self.collator,
batch_size=self.batch_size,
shuffle=False,
num_workers=1,
persistent_workers=True,
)

def _dataset(self, path: str) -> datasets.Dataset:
return datasets.Dataset(
list(self.parser.samples(path)),
mappers.Mapper(self.index),
self.parser,
)
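
To make the new interface concrete, here is a minimal usage sketch (not code
from this pull request; the paths and batch size are hypothetical):
constructing the `DataModule` directly, outside of `LightningCLI`, builds the
index from the training TSV, writes it to `model_dir`, and exposes the usual
dataloaders.

```python
# Hedged usage sketch; the file paths below are placeholders.
from yoyodyne.data import DataModule

datamodule = DataModule(
    model_dir="models/example",  # index.pkl is written here
    train="data/train.tsv",
    val="data/dev.tsv",
    batch_size=32,
)
datamodule.log_vocabularies()  # logs the vocabularies built from the training data
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
```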