Commit 2e7412a: first commit
matthewleighCERN committed Jan 14, 2025 (0 parents)

Showing 13 changed files with 942 additions and 0 deletions.
31 changes: 31 additions & 0 deletions .gitignore
@@ -0,0 +1,31 @@
# Directories
__pycache__
data/
logs/
outputs/
plots/
.vscode/
container/
wandb/
configs/sweeps/
slurmscripts/
user**/**
**user**
*.snakemake
*.pytest
.config/
.nv/

# Anything with cache in the name, pytest, ruff, or wandb...
*cache*

# File types
launch.json
.mypy*
*workspace
*.pyc
*.ipynb
*.out
*.sh
*.log
*.idea
13 changes: 13 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,13 @@
repos:
  - repo: https://github.com/charliermarsh/ruff-pre-commit
    rev: v0.9.1
    hooks:
      - id: ruff
        args: [--fix, --show-fixes, --exit-non-zero-on-fix]
      - id: ruff-format

  - repo: https://github.com/adrienverge/yamllint.git
    rev: v1.35.1
    hooks:
      - id: yamllint
        args: [-d, "{extends: relaxed, rules: {line-length: disable}}"]
18 changes: 18 additions & 0 deletions configs/hydra/default.yaml
@@ -0,0 +1,18 @@
# Enable colored logging
defaults:
  - override hydra_logging: colorlog
  - override job_logging: colorlog

# Allow hydra to change the current working directory when running the job
job:
  chdir: True

# Output directory, generated dynamically on each run from interpolated variables
run:
  dir: ${paths.full_path}

# Make hydra overwrite the log file instead of appending to it
job_logging:
  handlers:
    file:
      mode: w
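
The ${paths.full_path} used for the run directory is interpolated from values defined in configs/train.yaml below. A tiny standalone snippet (values copied from that config, not code from this commit) shows how OmegaConf resolves the nested interpolation:

# Illustrative only: resolve ${paths.full_path} the way hydra/OmegaConf would.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "output_dir": "/srv/beegfs/scratch/groups/rodem/nlp/",
    "project_name": "gpt",
    "network_name": "test",
    "paths": {"full_path": "${output_dir}/${project_name}/${network_name}/"},
})
print(cfg.paths.full_path)  # /srv/beegfs/scratch/groups/rodem/nlp//gpt/test/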
105 changes: 105 additions & 0 deletions configs/train.yaml
@@ -0,0 +1,105 @@
# @package _global_

# Order determines override priority (later entries overwrite earlier ones)
defaults:
  - _self_
  - hydra: default.yaml
  - experiment: null

seed: 42 # For reproducibility
project_name: gpt # Determines output directory path and wandb project
network_name: test # Used for both saving and wandb
output_dir: /srv/beegfs/scratch/groups/rodem/nlp/
ckpt_path: null # Checkpoint path to resume training
weight_ckpt_path: null # Checkpoint path to load weights (but not optimizers etc)

# Extra tweaks available with newer pytorch versions
precision: medium # Should use medium on Ampere GPUs
compile: null # Can be set to 'default' for faster compile times
tags: null # Extra tags passed to the logger

# COMPLETELY replaces all config info with what is contained in ${paths.full_path}
# This is ideal for resuming a job, logging to the same directory
# Will also resume the loggers and set the ckpt_path to the latest checkpoint
full_resume: False
ckpt_flag: last.ckpt # Name of the checkpoint file, can use wildcards

# Datamodule settings
datamodule:
  _target_: src.datamodules.text.TextModule
  train_path: ${root_dir}/data/wikitext-103/train.npy
  val_path: ${root_dir}/data/wikitext-103/valid.npy
  test_path: ${root_dir}/data/wikitext-103/test.npy
  max_seq_len: 512
  train_epoch_size: 10000
  val_epoch_size: 1000
  batch_size: 2
  num_workers: 2
  pin_memory: True

# Model settings
model:
  _target_: src.models.gpt.GPT
  vocab_size: 50257
  dim: 128
  num_layers: 4
  max_seq_len: ${datamodule.max_seq_len}
  final_norm: True
  layer_config:
    num_heads: 8
    drop: 0.1
    qk_norm: rms
    out_norm: none
    causal: True
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: True
    lr: 1e-4
  scheduler:
    _target_: src.schedulers.one_cycle
    _partial_: True
    max_steps: 10_001

# Trainer settings
trainer:
  _target_: lightning.Trainer
  max_epochs: 10
  enable_progress_bar: True
  gradient_clip_val: 1
  precision: 16-mixed
  check_val_every_n_epoch: 1
  accelerator: auto
  devices: 1
  num_nodes: 1
  default_root_dir: ${paths.full_path}

# Logger settings
logger:
  _target_: lightning.pytorch.loggers.wandb.WandbLogger
  offline: False
  id: null
  log_model: False
  tags: ${tags}
  project: ${project_name}
  name: ${network_name}
  save_dir: ${paths.full_path}
  resume: ${full_resume}

# Callbacks
callbacks:
  checkpoint_per_epoch:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${paths.full_path}/checkpoints
    filename: last
    enable_version_counter: False
    auto_insert_metric_name: False
  model_summary:
    _target_: lightning.pytorch.callbacks.RichModelSummary
    max_depth: 2
  lr_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step

# Interpolated paths
root_dir: ${oc.env:PROJECT_ROOT}
paths:
  full_path: ${output_dir}/${project_name}/${network_name}/
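
The optimizer and scheduler entries above use hydra's _partial_: True convention. A minimal standalone sketch (not code from this commit) of what that buys you: hydra returns a functools.partial, which the LightningModule can later complete with the model parameters inside configure_optimizers.

# Illustrative only: _partial_: True turns the config into a callable factory.
import hydra
import torch
from omegaconf import OmegaConf

opt_cfg = OmegaConf.create(
    {"_target_": "torch.optim.AdamW", "_partial_": True, "lr": 1e-4}
)
opt_fn = hydra.utils.instantiate(opt_cfg)  # functools.partial(AdamW, lr=0.0001)
layer = torch.nn.Linear(128, 128)          # stand-in for the GPT model
optimizer = opt_fn(layer.parameters())     # now a concrete AdamW instance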

50 changes: 50 additions & 0 deletions pyproject.toml
@@ -0,0 +1,50 @@
[project]
name = "nlp_tests"
version = "0.1"
description = "A simple library for testing homegrown NLP models"
license = {text = "MIT"}
requires-python = ">=3.10,<3.12"
dynamic = ["dependencies"]
authors = [
    {name = "Matthew Leigh", email = "[email protected]"},
]

[project.urls]
"Homepage" = "https://github.com/mattcleigh/nlp_tests"
"Issue Tracker" = "https://github.com/mattcleigh/nlp_tests/issues"

[tool.setuptools]
packages = ["jetssl"]

[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}

[tool.ruff]
line-length = 88
preview = true
target-version = "py311"
lint.select = ["ALL", "D212", "D417"]
lint.ignore = [
    "ANN002", "PTH123", "S602", "PLC0415", "ANN101", "ANN201", "PLR0911",
    "E402", "TRY003", "D401", "PLR0913", "PLR2004", "ANN001", "S102", "C901",
    "D101", "EXE002", "ANN204", "D205", "ISC001", "D105", "ARG002", "SLF001",
    "DOC501", "DTZ005", "FBT", "N802", "G004", "ANN401", "D102", "N812",
    "PLR6301", "RUF017", "PLR0914", "TD002", "ERA", "D104", "DTZ007", "CPY001",
    "BLE001", "FIX", "PLR0917", "T201", "PLR1702", "PLR0912", "S404", "ANN003",
    "D100", "S105", "EM", "D103", "INP", "N803", "N806", "PLW1514", "PD011",
    "B905", "ANN202", "COM", "PLR0915", "ARG001", "S311", "RUF015", "TD003",
    "DOC201", "PD901", "F811", "PLR6104", "TRY002", "S101", "DOC402", "D212",
]

[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.ruff.lint.flake8-pytest-style]
fixture-parentheses = false
mark-parentheses = false

[tool.pytest.ini_options]
log_cli = true
log_cli_level = "CRITICAL"
filterwarnings = ["ignore::DeprecationWarning"]
pythonpath = ["."]
30 changes: 30 additions & 0 deletions scripts/setup_wikitext.py
@@ -0,0 +1,30 @@
"""Download, tokenize, and save the WikiText-103 dataset."""

import numpy as np
import rootutils
import tiktoken
from datasets import load_dataset

root = rootutils.setup_root(search_from=__file__)


def main():
    ds = load_dataset("Salesforce/wikitext", "wikitext-103-v1")
    enc = tiktoken.get_encoding("gpt2")

    def prepare(x):
        return {"ids": enc.encode_ordinary(x["text"])}

    ds = ds.map(prepare, remove_columns=["text"], num_proc=4)

    # Loop over the splits, save as a single numpy array (only 1GB for train)
    for split, data in ds.items():
        print(f"Saving {split} split")
        file_name = root / "data" / "wikitext-103" / f"{split}.npy"
        file_name.parent.mkdir(parents=True, exist_ok=True)
        arr = np.hstack(data.with_format("numpy")["ids"])
        np.save(file_name, arr.astype(np.uint16))


if __name__ == "__main__":
    main()
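
To spot-check the output, the saved token stream can be memory-mapped and decoded back to text with the same tokenizer. This is a throwaway snippet, not part of the commit; the path is relative to the project root.

# Illustrative only: load the saved tokens lazily and decode a small sample.
import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")
tokens = np.load("data/wikitext-103/valid.npy", mmap_mode="r")
print(f"{len(tokens):,} tokens of dtype {tokens.dtype}")
print(enc.decode(tokens[:64].tolist()))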
80 changes: 80 additions & 0 deletions scripts/train.py
@@ -0,0 +1,80 @@
"""Basic training script."""

import logging

import hydra
import lightning.pytorch as pl
import rootutils
import torch as T
from omegaconf import DictConfig

root = rootutils.setup_root(search_from=__file__, pythonpath=True)


from src.hydra_utils import (
    log_hyperparameters,
    print_config,
    reload_original_config,
    save_config,
)

log = logging.getLogger(__name__)


@hydra.main(
    version_base=None, config_path=str(root / "configs"), config_name="train.yaml"
)
def main(cfg: DictConfig) -> None:
    """Main training script."""
    log.info("Setting up full job config")

    if cfg.full_resume:
        log.info("Attempting to resume previous job")
        old_cfg = reload_original_config(ckpt_flag=cfg.ckpt_flag)
        if old_cfg is not None:
            cfg = old_cfg
    print_config(cfg)

    log.info(f"Setting seed to: {cfg.seed}")
    pl.seed_everything(cfg.seed, workers=True)

    log.info(f"Setting matrix precision to: {cfg.precision}")
    T.set_float32_matmul_precision(cfg.precision)

    log.info("Instantiating the data module")
    datamodule = hydra.utils.instantiate(cfg.datamodule)
log.info("Instantiating the model")
if cfg.weight_ckpt_path:
log.info(f"Loading model weights from checkpoint: {cfg.ckpt_path}")
model_class = hydra.utils.get_class(cfg.model._target_)
model = model_class.load_from_checkpoint(cfg.ckpt_path, map_location="cpu")
else:
model = hydra.utils.instantiate(cfg.model)

if cfg.compile:
log.info(f"Compiling the model using torch 2.0: {cfg.compile}")
model = T.compile(model, mode=cfg.compile)

log.info("Instantiating all callbacks")
callbacks = hydra.utils.instantiate(cfg.callbacks)

log.info("Instantiating the logger")
logger = hydra.utils.instantiate(cfg.logger)

log.info("Instantiating the trainer")
trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

log.info("Logging all hyperparameters")
log_hyperparameters(cfg, model, trainer)
log.info(model)

log.info("Saving config so job can be resumed")
save_config(cfg)

log.info("Starting training!")
trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)


if __name__ == "__main__":
main()
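
The src.hydra_utils helpers are not part of this commit, so for reference here is one plausible shape for reload_original_config. The saved-config file name is an assumption, not the repo's actual implementation; the checkpoints subdirectory matches the ModelCheckpoint dirpath in configs/train.yaml.

# Illustrative sketch only: reload the config saved by a previous run and point
# ckpt_path at the newest checkpoint matching ckpt_flag so training resumes.
from pathlib import Path
from omegaconf import DictConfig, OmegaConf

def reload_original_config(ckpt_flag: str = "last.ckpt") -> DictConfig | None:
    run_dir = Path.cwd()  # hydra has already chdir'd into ${paths.full_path}
    cfg_file = run_dir / "full_config.yaml"  # file name is an assumption
    if not cfg_file.exists():
        return None
    cfg = OmegaConf.load(cfg_file)
    ckpt_dir = run_dir / "checkpoints"
    if ckpt_dir.is_dir():
        ckpts = sorted(ckpt_dir.glob(ckpt_flag), key=lambda p: p.stat().st_mtime)
        if ckpts:
            cfg.ckpt_path = str(ckpts[-1])
    return cfg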