Clean up the code to avoid repeats
pomonam committed Mar 17, 2024
1 parent 3a01e26 commit d863fbd
Showing 35 changed files with 2,073 additions and 1,303 deletions.
3 changes: 2 additions & 1 deletion examples/_test_requirements.txt
@@ -1,2 +1,3 @@
scikit-learn
jupyter
jupyter
evaluate
19 changes: 10 additions & 9 deletions examples/cifar/pipeline.py
@@ -1,6 +1,6 @@
import copy
import math
from typing import Dict, List, Optional, Tuple
from typing import List, Optional

import datasets
import numpy as np
@@ -33,15 +33,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


def construct_resnet9() -> nn.Module:
# ResNet-9 architecture from: https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb.
def conv_bn(
channels_in: int,
channels_out: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
groups=1,
groups: int = 1,
) -> nn.Module:
assert groups == 1
return torch.nn.Sequential(
torch.nn.Conv2d(
channels_in,
@@ -74,9 +74,9 @@ def conv_bn(

def get_cifar10_dataset(
split: str,
do_corrupt: bool,
indices: List[int] = None,
data_dir: str = "data/",
corrupt_percentage: Optional[float] = None,
dataset_dir: str = "data/",
) -> datasets.Dataset:
assert split in ["train", "eval_train", "valid"]

@@ -99,16 +99,17 @@
)

dataset = torchvision.datasets.CIFAR10(
root=data_dir,
root=dataset_dir,
download=True,
train=split in ["train", "eval_train", "eval_train_with_aug"],
train=split in ["train", "eval_train"],
transform=transform_config,
)

if do_corrupt:
if corrupt_percentage is not None:
if split == "valid":
raise NotImplementedError("Performing corruption on the validation dataset is not supported.")
num_corrupt = math.ceil(len(dataset) * 0.1)
assert 0.0 < corrupt_percentage <= 1.0
num_corrupt = math.ceil(len(dataset) * corrupt_percentage)
original_targets = np.array(copy.deepcopy(dataset.targets[:num_corrupt]))
new_targets = torch.randint(
0,
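To make the new corrupt_percentage argument concrete, here is a minimal sketch of the label-corruption step it now controls. The relabeling loop is truncated in the hunk above, so the random-label draw below is an assumption for illustration, not the committed code.

import copy
import math

import numpy as np
import torch


def corrupt_labels(targets, corrupt_percentage: float, num_classes: int = 10):
    # Mirrors the checks shown in the diff: the fraction must lie in (0, 1],
    # and the number of corrupted examples is rounded up.
    assert 0.0 < corrupt_percentage <= 1.0
    num_corrupt = math.ceil(len(targets) * corrupt_percentage)
    original_targets = np.array(copy.deepcopy(targets[:num_corrupt]))
    # Assumed relabeling step: draw a random class for each corrupted example.
    new_targets = torch.randint(0, num_classes, (num_corrupt,)).numpy()
    targets[:num_corrupt] = list(new_targets)
    return targets, original_targets


# Example: corrupting 10% of the 50,000 CIFAR-10 training labels flips 5,000 entries.
labels, originals = corrupt_labels(list(np.random.randint(0, 10, size=50_000)), corrupt_percentage=0.1)
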
190 changes: 190 additions & 0 deletions examples/cifar/train.py
@@ -0,0 +1,190 @@
import argparse
import logging
import os
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from accelerate.utils import set_seed
from torch import nn
from torch.optim import lr_scheduler
from torch.utils import data
from tqdm import tqdm

from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def parse_args():
    parser = argparse.ArgumentParser(description="Train ResNet-9 model on CIFAR-10 dataset.")

    parser.add_argument(
        "--corrupt_percentage",
        type=float,
        default=None,
        help="Percentage of the training dataset to corrupt.",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="./data",
        help="A folder to download or load CIFAR-10 dataset.",
    )

    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=512,
        help="Batch size for the training dataloader.",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1024,
        help="Batch size for the evaluation dataloader.",
    )

    parser.add_argument(
        "--learning_rate",
        type=float,
        default=0.4,
        help="Initial learning rate to train the model.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.001,
        help="Weight decay to train the model.",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=25,
        help="Total number of epochs to train the model.",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=1004,
        help="A seed for reproducible training pipeline.",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="./checkpoints",
        help="A path to store the final checkpoint.",
    )

    args = parser.parse_args()

    if args.checkpoint_dir is not None:
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    return args


def train(
    dataset: data.Dataset,
    batch_size: int,
    num_train_epochs: int,
    learning_rate: float,
    weight_decay: float,
    disable_tqdm: bool = False,
) -> nn.Module:
    train_dataloader = data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    model = construct_resnet9().to(DEVICE)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    iters_per_epoch = len(train_dataloader)
    lr_peak_epoch = num_train_epochs // 4
    lr_schedule = np.interp(
        np.arange((num_train_epochs + 1) * iters_per_epoch),
        [0, lr_peak_epoch * iters_per_epoch, num_train_epochs * iters_per_epoch],
        [0, 1, 0],
    )
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule.__getitem__)

    model.train()
    for epoch in range(num_train_epochs):
        total_loss = 0.0
        with tqdm(train_dataloader, unit="batch", disable=disable_tqdm) as tepoch:
            for batch in tepoch:
                tepoch.set_description(f"Epoch {epoch}")
                model.zero_grad()
                inputs, labels = batch
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                optimizer.step()
                scheduler.step()
                total_loss += loss.detach().float()
                tepoch.set_postfix(loss=total_loss.item() / len(train_dataloader))
    return model


def evaluate(model: nn.Module, dataset: data.Dataset, batch_size: int) -> Tuple[float, float]:
    dataloader = data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )

    model.eval()
    total_loss, total_correct = 0.0, 0
    for batch in dataloader:
        with torch.no_grad():
            inputs, labels = batch
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels, reduction="sum")
            total_loss += loss.detach().float()
            total_correct += outputs.detach().argmax(1).eq(labels).sum()

    return total_loss.item() / len(dataloader.dataset), total_correct.item() / len(dataloader.dataset)


def main():
    args = parse_args()
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()

    if args.seed is not None:
        set_seed(args.seed)

    train_dataset = get_cifar10_dataset(split="train", corrupt_percentage=args.corrupt_percentage, dataset_dir=args.dataset_dir)
    model = train(
        dataset=train_dataset,
        batch_size=args.train_batch_size,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
    )

    eval_train_dataset = get_cifar10_dataset(split="eval_train", dataset_dir=args.dataset_dir)
    train_loss, train_acc = evaluate(model=model, dataset=eval_train_dataset, batch_size=args.eval_batch_size)
    logger.info(f"Train loss: {train_loss}, Train Accuracy: {train_acc}")

    eval_dataset = get_cifar10_dataset(split="valid", dataset_dir=args.dataset_dir)
    eval_loss, eval_acc = evaluate(model=model, dataset=eval_dataset, batch_size=args.eval_batch_size)
    logger.info(f"Evaluation loss: {eval_loss}, Evaluation Accuracy: {eval_acc}")

    if args.checkpoint_dir is not None:
        model_name = "model"
        if args.corrupt_percentage is not None:
            model_name += "_corrupt_" + str(args.corrupt_percentage)
        torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, f"{model_name}.pth"))


if __name__ == "__main__":
    main()
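
For reference, a minimal sketch of reloading the checkpoint that main() writes and re-running the held-out evaluation. It assumes the default --checkpoint_dir and --dataset_dir values and an uncorrupted run, so the saved file is ./checkpoints/model.pth.

import torch

from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset
from examples.cifar.train import DEVICE, evaluate

# Rebuild the architecture and load the weights saved by main().
model = construct_resnet9().to(DEVICE)
model.load_state_dict(torch.load("./checkpoints/model.pth", map_location=DEVICE))

# Re-run the held-out evaluation with the same pipeline helpers.
valid_dataset = get_cifar10_dataset(split="valid", dataset_dir="./data")
valid_loss, valid_acc = evaluate(model=model, dataset=valid_dataset, batch_size=1024)
print(f"valid loss: {valid_loss:.4f}, valid accuracy: {valid_acc:.4f}")
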
7 changes: 3 additions & 4 deletions examples/glue/pipeline.py
@@ -21,14 +21,13 @@
}


def construct_bert(data_name) -> nn.Module:
def construct_bert(data_name: str = "sst2") -> nn.Module:
config = AutoConfig.from_pretrained(
"bert-base-cased",
num_labels=2,
finetuning_task=data_name,
trust_remote_code=True,
)

return AutoModelForSequenceClassification.from_pretrained(
"bert-base-cased",
from_tf=False,
@@ -42,14 +41,14 @@ def get_glue_dataset(
data_name: str,
split: str,
indices: List[int] = None,
data_path: str = "data/",
dataset_dir: str = "data/",
) -> Dataset:
assert split in ["train", "eval_train", "valid"]

raw_datasets = load_dataset(
path="glue",
name=data_name,
data_dir=data_path,
# data_dir=dataset_dir,
)
label_list = raw_datasets["train"].features["label"].names
num_labels = len(label_list)
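A small usage sketch of the updated GLUE helpers after the rename from data_path to dataset_dir. The examples.glue.pipeline import path is assumed by analogy with the cifar example above.

from examples.glue.pipeline import construct_bert, get_glue_dataset

# construct_bert now takes the task name with an "sst2" default.
model = construct_bert(data_name="sst2")

# get_glue_dataset keeps its behavior but accepts `dataset_dir` instead of `data_path`
# (note that the directory is currently not forwarded to load_dataset in the hunk above).
train_dataset = get_glue_dataset(data_name="sst2", split="train", dataset_dir="data/")
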