From 4b10b28a233cca82a848187cf9d6cdd772c6eee4 Mon Sep 17 00:00:00 2001
From: Amir Mann
Date: Fri, 3 Jan 2025 21:59:09 +0200
Subject: [PATCH] Add fine-tuning option and replace legacy bf with os.path

---
 train/training_loop.py | 37 +++++++++++++++++++++----------------
 utils/parser_util.py   |  2 ++
 2 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/train/training_loop.py b/train/training_loop.py
index b3cc24ba..21dc8df2 100644
--- a/train/training_loop.py
+++ b/train/training_loop.py
@@ -5,7 +5,6 @@
 from types import SimpleNamespace
 import numpy as np
 
-import blobfile as bf
 import torch
 from torch.optim import AdamW
 
@@ -18,6 +17,7 @@
 from data_loaders.humanml.networks.evaluator_wrapper import EvaluatorMDMWrapper
 from eval import eval_humanml, eval_humanact12_uestc
 from data_loaders.get_data import get_dataset_loader
+from utils.misc import load_model_wo_clip
 
 
 # For ImageNet experiments, this was a good default value.
@@ -41,6 +41,7 @@ def __init__(self, args, train_platform, model, diffusion, data):
         self.log_interval = args.log_interval
         self.save_interval = args.save_interval
         self.resume_checkpoint = args.resume_checkpoint
+        self.fine_tunning = args.fine_tunning
         self.use_fp16 = False  # deprecating this option
         self.fp16_scale_growth = 1e-3  # deprecating this option
         self.weight_decay = args.weight_decay
@@ -104,20 +105,26 @@ def _load_and_sync_parameters(self):
         resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
 
         if resume_checkpoint:
-            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
-            logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
-            self.model.load_state_dict(
-                dist_util.load_state_dict(
-                    resume_checkpoint, map_location=dist_util.dev()
+            if not self.fine_tunning:
+                self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
+                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
+                self.model.load_state_dict(
+                    dist_util.load_state_dict(
+                        resume_checkpoint, map_location=dist_util.dev()
+                    )
                 )
-            )
+            else:
+                logger.log(f"loading model (for fine-tuning) from checkpoint: {resume_checkpoint}...")
+                state_dict = torch.load(resume_checkpoint, map_location='cpu')
+                load_model_wo_clip(self.model, state_dict)
+                self.model.to(dist_util.dev())
 
     def _load_optimizer_state(self):
         main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
-        opt_checkpoint = bf.join(
-            bf.dirname(main_checkpoint), f"opt{self.resume_step:09}.pt"
+        opt_checkpoint = os.path.join(
+            os.path.dirname(main_checkpoint), f"opt{self.resume_step:09}.pt"
         )
-        if bf.exists(opt_checkpoint):
+        if os.path.exists(opt_checkpoint):
             logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
             state_dict = dist_util.load_state_dict(
                 opt_checkpoint, map_location=dist_util.dev()
@@ -127,7 +134,7 @@ def _load_optimizer_state(self):
 
     def run_loop(self):
         for epoch in range(self.num_epochs):
-            print(f'Starting epoch {epoch}')
+            print(f'Starting epoch {epoch} / {self.num_epochs}')
             for motion, cond in tqdm(self.data):
                 if not (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps):
                     break
@@ -274,15 +281,13 @@ def save_checkpoint(params):
 
             logger.log(f"saving model...")
             filename = self.ckpt_file_name()
-            with bf.BlobFile(bf.join(self.save_dir, filename), "wb") as f:
+            with open(os.path.join(self.save_dir, filename), "wb") as f:
                 torch.save(state_dict, f)
 
         save_checkpoint(self.mp_trainer.master_params)
 
-        with bf.BlobFile(
-            bf.join(self.save_dir, f"opt{(self.step+self.resume_step):09d}.pt"),
-            "wb",
-        ) as f:
+        opt_cp_path = os.path.join(self.save_dir, f"opt{(self.step+self.resume_step):09d}.pt")
+        with open(opt_cp_path, "wb") as f:
             torch.save(self.opt.state_dict(), f)
 
 
diff --git a/utils/parser_util.py b/utils/parser_util.py
index 97bba12f..c198ebb9 100644
--- a/utils/parser_util.py
+++ b/utils/parser_util.py
@@ -136,6 +136,8 @@ def add_training_options(parser):
                        help="Limit for the maximal number of frames. In HumanML3D and KIT this field is ignored.")
     group.add_argument("--resume_checkpoint", default="", type=str,
                        help="If not empty, will start from the specified checkpoint (path to model###.pt file).")
+    group.add_argument("--fine_tunning", action='store_true',
+                       help="If set, will not load the CLIP weights and the optimizer state from the checkpoint, allowing fine-tuning from a model checkpoint.")
 
 
 def add_sampling_options(parser):
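
The fine-tuning branch relies on a load_model_wo_clip helper imported from utils.misc, which is not part of this patch. The sketch below shows the behavior the patch assumes that helper has: a non-strict state-dict load that only tolerates missing CLIP keys, since the checkpoints saved by save_checkpoint above are written without CLIP weights and CLIP is loaded separately at model construction. The function body here is an illustrative assumption, not the repository's actual implementation.

    # Assumed behavior of utils.misc.load_model_wo_clip (illustrative sketch only).
    def load_model_wo_clip(model, state_dict):
        # Non-strict load: checkpoints are saved without CLIP weights,
        # so only 'clip_model.*' keys are expected to be missing.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        assert len(unexpected_keys) == 0
        assert all(k.startswith('clip_model.') for k in missing_keys)

With this in place, fine-tuning is started by passing --fine_tunning together with --resume_checkpoint pointing at a pretrained model###.pt file; because resume_step is left at zero in that branch, the optimizer state is not restored and training restarts from step 0 with a fresh optimizer.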