Commit 2e7412a (0 parents)
Showing 13 changed files with 942 additions and 0 deletions.
.gitignore
@@ -0,0 +1,31 @@
# Directories
__pycache__
data/
logs/
outputs/
plots/
.vscode/
container/
wandb/
configs/sweeps/
slurmscripts/
user**/**
**user**
*.snakemake
*.pytest
.config/
.nv/

# Anything with cache in the name, pytest, ruff, or wandb...
*cache*

# File types
launch.json
.mypy*
*workspace
*.pyc
*.ipynb
*.out
*.sh
*.log
*.idea
.pre-commit-config.yaml
@@ -0,0 +1,13 @@
repos:
  - repo: https://github.com/charliermarsh/ruff-pre-commit
    rev: v0.9.1
    hooks:
      - id: ruff
        args: [--fix, --show-fixes, --exit-non-zero-on-fix]
      - id: ruff-format

  - repo: https://github.com/adrienverge/yamllint.git
    rev: v1.35.1
    hooks:
      - id: yamllint
        args: [-d, "{extends: relaxed, rules: {line-length: disable}}"]
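
With this file at the repository root, `pre-commit install` registers the hooks with git once, and `pre-commit run --all-files` applies ruff, ruff-format, and yamllint to the whole tree. The `--exit-non-zero-on-fix` flag makes the ruff hook fail whenever it rewrites a file, so autofixed changes must be re-staged before the commit goes through.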
configs/hydra/default.yaml
@@ -0,0 +1,18 @@
# Enable color logging
defaults:
  - override hydra_logging: colorlog
  - override job_logging: colorlog

# Allow hydra to change the current working directory when the job runs
job:
  chdir: True

# Output directory built from interpolated variables, generated dynamically on each run
run:
  dir: ${paths.full_path}

# Make hydra overwrite the log file instead of appending to it
job_logging:
  handlers:
    file:
      mode: w
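
Taken together, this gives every run colored logs, a fresh log file, and a working directory equal to the run's output path. A minimal sketch of what `job.chdir: True` means for an entry point (standard Hydra API; the config names simply mirror this commit's layout):

import os

import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="configs", config_name="train.yaml")
def main(cfg: DictConfig) -> None:
    # With hydra.job.chdir=True, Hydra has already switched into hydra.run.dir
    # (here ${paths.full_path}), so relative writes land in the run directory.
    print("Working directory:", os.getcwd())


if __name__ == "__main__":
    main()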
configs/train.yaml
@@ -0,0 +1,105 @@
# @package _global_

# Order indicates overwriting
defaults:
  - _self_
  - hydra: default.yaml
  - experiment: null

seed: 42 # For reproducibility
project_name: gpt # Determines output directory path and wandb project
network_name: test # Used for both saving and wandb
output_dir: /srv/beegfs/scratch/groups/rodem/nlp/
ckpt_path: null # Checkpoint path to resume training
weight_ckpt_path: null # Checkpoint path to load weights (but not optimizers etc.)

# Extra tweaks available with the new pytorch version
precision: medium # Should use medium if on ampere gpus
compile: null # Can set to default for faster compiles
tags: null # Extra tags passed to the logger

# COMPLETELY replaces all config info with what is contained in ${paths.full_path}
# This is ideal for resuming a job, logging to the same directory
# Will also resume the loggers and set the ckpt_path to the latest
full_resume: False
ckpt_flag: last.ckpt # Name of the checkpoint file, can use wildcards

# Datamodule settings
datamodule:
  _target_: src.datamodules.text.TextModule
  train_path: ${root_dir}/data/wikitext-103/train.npy
  val_path: ${root_dir}/data/wikitext-103/valid.npy
  test_path: ${root_dir}/data/wikitext-103/test.npy
  max_seq_len: 512
  train_epoch_size: 10000
  val_epoch_size: 1000
  batch_size: 2
  num_workers: 2
  pin_memory: True

# Model settings
model:
  _target_: src.models.gpt.GPT
  vocab_size: 50257
  dim: 128
  num_layers: 4
  max_seq_len: ${datamodule.max_seq_len}
  final_norm: True
  layer_config:
    num_heads: 8
    drop: 0.1
    qk_norm: rms
    out_norm: none
    causal: True
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: True
    lr: 1e-4
  scheduler:
    _target_: src.schedulers.one_cycle
    _partial_: True
    max_steps: 10_001

# Trainer settings
trainer:
  _target_: lightning.Trainer
  max_epochs: 10
  enable_progress_bar: True
  gradient_clip_val: 1
  precision: 16-mixed
  check_val_every_n_epoch: 1
  accelerator: auto
  devices: 1
  num_nodes: 1
  default_root_dir: ${paths.full_path}

# Logger settings
logger:
  _target_: lightning.pytorch.loggers.wandb.WandbLogger
  offline: False
  id: null
  log_model: False
  tags: ${tags}
  project: ${project_name}
  name: ${network_name}
  save_dir: ${paths.full_path}
  resume: ${full_resume}

# Callbacks (grouped under one key so the train script can instantiate cfg.callbacks)
callbacks:
  checkpoint_per_epoch:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${paths.full_path}/checkpoints
    filename: last
    enable_version_counter: False
    auto_insert_metric_name: False
  model_summary:
    _target_: lightning.pytorch.callbacks.RichModelSummary
    max_depth: 2
  lr_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step

# Interpolated paths (full_path lives under paths so ${paths.full_path} resolves)
root_dir: ${oc.env:PROJECT_ROOT}
paths:
  full_path: ${output_dir}/${project_name}/${network_name}/
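
A rough sketch of how the `_target_` and `_partial_` fields above are consumed, using the standard `hydra.utils.instantiate` API (composing outside `@hydra.main` and a set PROJECT_ROOT env var are assumptions here, not part of this commit):

import hydra
from hydra import compose, initialize

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="train.yaml")

datamodule = hydra.utils.instantiate(cfg.datamodule)  # builds the TextModule
model = hydra.utils.instantiate(cfg.model)            # builds the GPT module
# cfg.model.optimizer carries _partial_: True, so GPT.__init__ receives a
# functools.partial(torch.optim.AdamW, lr=1e-4) it can bind later, e.g.
# opt = self.optimizer(self.parameters()) inside configure_optimizers().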
pyproject.toml
@@ -0,0 +1,50 @@
[project]
name = "nlp_tests"
version = "0.1"
description = "A simple library for testing homegrown NLP models"
license = {text = "MIT"}
requires-python = ">=3.10,<3.12"
dynamic = ["dependencies"]
authors = [
    {name = "Matthew Leigh", email = "[email protected]"}
]

[project.urls]
"Homepage" = "https://github.com/mattcleigh/nlp_tests"
"Issue Tracker" = "https://github.com/mattcleigh/nlp_tests/issues"

[tool.setuptools]
packages = ["src"] # The code imports from src; the original "jetssl" looks copied from another project

[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}

[tool.ruff]
line-length = 88
preview = true
target-version = "py311"
lint.select = ["ALL", "D212", "D417"]
lint.ignore = [
    "ANN002", "PTH123", "S602", "PLC0415", "ANN101", "ANN201", "PLR0911",
    "E402", "TRY003", "D401", "PLR0913", "PLR2004", "ANN001", "S102", "C901",
    "D101", "EXE002", "ANN204", "D205", "ISC001", "D105", "ARG002", "SLF001",
    "DOC501", "DTZ005", "FBT", "N802", "G004", "ANN401", "D102", "N812",
    "PLR6301", "RUF017", "PLR0914", "TD002", "ERA", "D104", "DTZ007", "CPY001",
    "BLE001", "FIX", "PLR0917", "T201", "PLR1702", "PLR0912", "S404", "ANN003",
    "D100", "S105", "EM", "D103", "INP", "N803", "N806", "PLW1514", "PD011",
    "B905", "ANN202", "COM", "PLR0915", "ARG001", "S311", "RUF015", "TD003",
    "DOC201", "PD901", "F811", "PLR6104", "TRY002", "S101", "DOC402", "D212",
]

[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.ruff.lint.flake8-pytest-style]
fixture-parentheses = false
mark-parentheses = false

[tool.pytest.ini_options]
log_cli = true
log_cli_level = "CRITICAL"
filterwarnings = ["ignore::DeprecationWarning"]
pythonpath = ["."]
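
Because `dependencies` is declared dynamic, an editable install (`pip install -e .`) pulls the requirement list from requirements.txt at build time via the `[tool.setuptools.dynamic]` table above, so the pinned list lives in one file.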
Data preparation script (filename not shown in the diff)
@@ -0,0 +1,30 @@
"""Download, tokenize, and save the WikiText-103 dataset.""" | ||
|
||
import numpy as np | ||
import rootutils | ||
import tiktoken | ||
from datasets import load_dataset | ||
|
||
root = rootutils.setup_root(search_from=__file__) | ||
|
||
|
||
def main(): | ||
ds = load_dataset("Salesforce/wikitext", "wikitext-103-v1") | ||
enc = tiktoken.get_encoding("gpt2") | ||
|
||
def prepare(x): | ||
return {"ids": enc.encode_ordinary(x["text"])} | ||
|
||
ds = ds.map(prepare, remove_columns=["text"], num_proc=4) | ||
|
||
# Loop over the splits, save as a single numpy array (only 1GB for train) | ||
for split, data in ds.items(): | ||
print(f"Saving {split} split") | ||
file_name = root / "data" / "wikitext-103" / f"{split}.npy" | ||
file_name.parent.mkdir(parents=True, exist_ok=True) | ||
arr = np.hstack(data.with_format("numpy")["ids"]) | ||
np.save(file_name, arr.astype(np.uint16)) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
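
The `uint16` cast is lossless here: the GPT-2 vocabulary has 50,257 entries, so every token id fits in 16 bits at half the size of `int32`. A quick round-trip check, assuming the script above has already produced the valid split:

import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")
tokens = np.load("data/wikitext-103/valid.npy")
print(tokens.shape, tokens.dtype)        # (n_tokens,) uint16
print(enc.decode(tokens[:64].tolist()))  # decode the first tokens back to text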
Training script (filename not shown in the diff)
@@ -0,0 +1,80 @@
"""Basic training script.""" | ||
|
||
import logging | ||
|
||
import hydra | ||
import lightning.pytorch as pl | ||
import rootutils | ||
import torch as T | ||
from omegaconf import DictConfig | ||
|
||
root = rootutils.setup_root(search_from=__file__, pythonpath=True) | ||
|
||
|
||
from src.hydra_utils import ( | ||
log_hyperparameters, | ||
print_config, | ||
reload_original_config, | ||
save_config, | ||
) | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
@hydra.main( | ||
version_base=None, config_path=str(root / "configs"), config_name="train.yaml" | ||
) | ||
def main(cfg: DictConfig) -> None: | ||
"""Main training script.""" | ||
log.info("Setting up full job config") | ||
|
||
if cfg.full_resume: | ||
log.info("Attempting to resume previous job") | ||
old_cfg = reload_original_config(ckpt_flag=cfg.ckpt_flag) | ||
if old_cfg is not None: | ||
cfg = old_cfg | ||
print_config(cfg) | ||
|
||
log.info(f"Setting seed to: {cfg.seed}") | ||
pl.seed_everything(cfg.seed, workers=True) | ||
|
||
log.info(f"Setting matrix precision to: {cfg.precision}") | ||
T.set_float32_matmul_precision(cfg.precision) | ||
|
||
log.info("Instantiating the data module") | ||
datamodule = hydra.utils.instantiate(cfg.datamodule) | ||
|
||
log.info("Instantiating the model") | ||
if cfg.weight_ckpt_path: | ||
log.info(f"Loading model weights from checkpoint: {cfg.ckpt_path}") | ||
model_class = hydra.utils.get_class(cfg.model._target_) | ||
model = model_class.load_from_checkpoint(cfg.ckpt_path, map_location="cpu") | ||
else: | ||
model = hydra.utils.instantiate(cfg.model) | ||
|
||
if cfg.compile: | ||
log.info(f"Compiling the model using torch 2.0: {cfg.compile}") | ||
model = T.compile(model, mode=cfg.compile) | ||
|
||
log.info("Instantiating all callbacks") | ||
callbacks = hydra.utils.instantiate(cfg.callbacks) | ||
|
||
log.info("Instantiating the logger") | ||
logger = hydra.utils.instantiate(cfg.logger) | ||
|
||
log.info("Instantiating the trainer") | ||
trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) | ||
|
||
log.info("Logging all hyperparameters") | ||
log_hyperparameters(cfg, model, trainer) | ||
log.info(model) | ||
|
||
log.info("Saving config so job can be resumed") | ||
save_config(cfg) | ||
|
||
log.info("Starting training!") | ||
trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
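
A typical launch then uses Hydra's override syntax, e.g. `python train.py network_name=my_run model.dim=256 trainer.max_epochs=20` (the script's actual path is not shown in this diff); setting `full_resume=True` restarts from the saved config and the checkpoint matching `ckpt_flag` in the same output directory.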