From 89108a19aa30aed36534e6fd22fa82a650980316 Mon Sep 17 00:00:00 2001
From: LoryWang
Date: Thu, 22 Feb 2024 19:21:19 +0000
Subject: [PATCH] update trainer

---
 finetuna/finetuner_utils/trainer.py      | 138 ++++++++++++-----------
 finetuna/ml_potentials/finetuner_calc.py |   2 +-
 2 files changed, 75 insertions(+), 65 deletions(-)

diff --git a/finetuna/finetuner_utils/trainer.py b/finetuna/finetuner_utils/trainer.py
index 75c4541b..727c597c 100644
--- a/finetuna/finetuner_utils/trainer.py
+++ b/finetuna/finetuner_utils/trainer.py
@@ -1,7 +1,8 @@
-from ocpmodels.trainers.forces_trainer import ForcesTrainer
+from ocpmodels.trainers.ocp_trainer import OCPTrainer
 from ocpmodels.datasets.lmdb_dataset import data_list_collater
-from ocpmodels.common.utils import setup_imports, setup_logging
+from ocpmodels.common.utils import load_config, setup_imports, setup_logging, update_config
 from ocpmodels.common import distutils
+from ocpmodels.common.registry import registry
 import logging
 import yaml
 from ocpmodels.preprocessing import AtomsToGraphs
@@ -17,72 +18,81 @@
 )
 
 
-class Trainer(ForcesTrainer):
-    def __init__(self, config_yml=None, checkpoint=None, cutoff=6, max_neighbors=50):
+class Trainer(OCPTrainer):
+    def __init__(self, config_yml=None, checkpoint_path=None, cutoff=6, max_neighbors=50):
         setup_imports()
         setup_logging()
 
         # Either the config path or the checkpoint path needs to be provided
-        assert config_yml or checkpoint is not None
+        assert config_yml or checkpoint_path is not None
 
+        checkpoint = None
         if config_yml is not None:
             if isinstance(config_yml, str):
-                config = yaml.safe_load(open(config_yml, "r"))
-
-                if "includes" in config:
-                    for include in config["includes"]:
-                        # Change the path based on absolute path of config_yml
-                        path = os.path.join(config_yml.split("configs")[0], include)
-                        include_config = yaml.safe_load(open(path, "r"))
-                        config.update(include_config)
+                config, duplicates_warning, duplicates_error = load_config(
+                    config_yml
+                )
+                if len(duplicates_warning) > 0:
+                    logging.warning(
+                        f"Overwritten config parameters from included configs "
+                        f"(non-included parameters take precedence): {duplicates_warning}"
+                    )
+                if len(duplicates_error) > 0:
+                    raise ValueError(
+                        f"Conflicting (duplicate) parameters in simultaneously "
+                        f"included configs: {duplicates_error}"
+                    )
             else:
                 config = config_yml
-        else:
-            # Loads the config from the checkpoint directly
-            config = torch.load(checkpoint, map_location=torch.device("cpu"))["config"]
 
-            # Load the trainer based on the dataset used
-            if config["task"]["dataset"] == "trajectory_lmdb":
-                config["trainer"] = "forces"
-            else:
-                config["trainer"] = "energy"
+            # Only keep the train dataset, which might have normalizer values
+            # if isinstance(config["dataset"], list):
+            #     config["dataset"] = config["dataset"][0]
+            # elif isinstance(config["dataset"], dict):
+            #     config["dataset"] = config["dataset"].get("train", None)
+        else:
+            # Loads the config from the checkpoint directly (always on CPU)
+            checkpoint = torch.load(
+                checkpoint_path, map_location=torch.device("cpu")
+            )
+            config = checkpoint["config"]
+
+        # if trainer is not None:
+        #     config["trainer"] = trainer
+        # else:
+        config["trainer"] = config.get("trainer", "ocp")
+
+        if "model_attributes" in config:
             config["model_attributes"]["name"] = config.pop("model")
             config["model"] = config["model_attributes"]
 
-        # Calculate the edge indices on the fly
-        self.otf_graph = True
-        config["model"]["otf_graph"] = self.otf_graph
+        # For checkpoints with relaxation datasets defined, remove them to avoid
+        # unnecessarily trying to load that dataset
+        if "relax_dataset" in config["task"]:
+            del config["task"]["relax_dataset"]
 
-        # delete scale file entry in config before loading (remove me if this causes problems in the future)
-        config.get("model", {}).pop("scale_file", None)
+        # Calculate the edge indices on the fly
+        config["model"]["otf_graph"] = True
 
         # Save config so obj can be transported over network (pkl)
+        config = update_config(config)
         self.config = copy.deepcopy(config)
-        self.config["checkpoint"] = checkpoint
-
-        if "normalizer" not in config:
-            if config["dataset"] is not None:
-                del config["dataset"]["src"]
-                config["normalizer"] = config["dataset"]
-
-        identifier = ""
-        if hasattr(config.get("logger", {}), "get"):
-            identifier = config.get("logger", {}).get("identifier", "")
-
+        self.config["checkpoint"] = checkpoint_path
+        del config["dataset"]["src"]
 
         super().__init__(
             task=config["task"],
             model=config["model"],
-            dataset=None,
+            dataset=[config["dataset"]],
+            outputs=config["outputs"],
+            loss_fns=config["loss_fns"],
+            eval_metrics=config["eval_metrics"],
             optimizer=config["optim"],
-            identifier=identifier,
-            normalizer=config["normalizer"],
+            identifier="",
             slurm=config.get("slurm", {}),
             local_rank=config.get("local_rank", 0),
-            logger=config.get("logger", None),
-            print_every=config.get("print_every", 1),
             is_debug=config.get("is_debug", True),
             cpu=config.get("cpu", True),
+            amp=config.get("amp", False),
         )
 
         # if loading a model with added blocks for training from the checkpoint, set strict loading to False
@@ -90,9 +100,9 @@ def __init__(self, config_yml=None, checkpoint=None, cutoff=6, max_neighbors=50)
         self.model.load_state_dict.__func__.__defaults__ = (False,)
 
         # load checkpoint
-        if checkpoint is not None:
+        if checkpoint_path is not None:
             try:
-                self.load_checkpoint(checkpoint)
+                self.load_checkpoint(checkpoint_path)
             except NotImplementedError:
                 logging.warning("Unable to load checkpoint!")
 
@@ -350,22 +360,22 @@ def closure():
         if "test_dataset" in self.config:
             self.test_dataset.close_db()
 
-    def load_loss(self):
-        self.loss_fn = {}
-        self.loss_fn["energy"] = self.config["optim"].get("loss_energy", "mae")
-        self.loss_fn["force"] = self.config["optim"].get("loss_force", "mae")
-        for loss, loss_name in self.loss_fn.items():
-            if loss_name in ["l1", "mae"]:
-                self.loss_fn[loss] = nn.L1Loss()
-            elif loss_name == "mse":
-                self.loss_fn[loss] = nn.MSELoss()
-            elif loss_name == "l2mae":
-                self.loss_fn[loss] = L2MAELoss()
-            elif loss_name == "rell2mae":
-                self.loss_fn[loss] = RelativeL2MAELoss()
-            elif loss_name == "atomwisel2":
-                self.loss_fn[loss] = AtomwiseL2LossNoBatch()
-            else:
-                raise NotImplementedError(f"Unknown loss function name: {loss_name}")
-            if distutils.initialized():
-                self.loss_fn[loss] = DDPLoss(self.loss_fn[loss])
+    # def load_loss(self):
+    #     self.loss_fn = {}
+    #     self.loss_fn["energy"] = self.config["optim"].get("loss_energy", "mae")
+    #     self.loss_fn["force"] = self.config["optim"].get("loss_force", "mae")
+    #     for loss, loss_name in self.loss_fn.items():
+    #         if loss_name in ["l1", "mae"]:
+    #             self.loss_fn[loss] = nn.L1Loss()
+    #         elif loss_name == "mse":
+    #             self.loss_fn[loss] = nn.MSELoss()
+    #         elif loss_name == "l2mae":
+    #             self.loss_fn[loss] = L2MAELoss()
+    #         elif loss_name == "rell2mae":
+    #             self.loss_fn[loss] = RelativeL2MAELoss()
+    #         elif loss_name == "atomwisel2":
+    #             self.loss_fn[loss] = AtomwiseL2LossNoBatch()
+    #         else:
+    #             raise NotImplementedError(f"Unknown loss function name: {loss_name}")
+    #         if distutils.initialized():
+    #             self.loss_fn[loss] = DDPLoss(self.loss_fn[loss])
diff --git a/finetuna/ml_potentials/finetuner_calc.py b/finetuna/ml_potentials/finetuner_calc.py
index 6599dd3a..8d760bcb 100644
--- a/finetuna/ml_potentials/finetuner_calc.py
+++ b/finetuna/ml_potentials/finetuner_calc.py
@@ -107,12 +107,12 @@ def load_trainer(self):
         """
         # make a copy of the config dict so the trainer doesn't edit the original
         config_dict = copy.deepcopy(self.mlp_params)
 
         # initialize trainer
         sys.stdout = open(os.devnull, "w")
         self.trainer = Trainer(
             config_yml=config_dict,
-            checkpoint=self.checkpoint_path,
+            checkpoint_path=self.checkpoint_path,
             cutoff=self.cutoff,
             max_neighbors=self.max_neighbors,
         )
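
A minimal usage sketch of the updated constructor, assuming a local OCP
checkpoint file at "checkpoints/gemnet_oc_base_s2ef.pt" (the path is
illustrative; the keyword arguments are the ones this patch introduces):

    from finetuna.finetuner_utils.trainer import Trainer

    # With config_yml=None, the config is recovered from the checkpoint file
    # itself (loaded on CPU), and load_checkpoint() runs inside __init__.
    trainer = Trainer(
        checkpoint_path="checkpoints/gemnet_oc_base_s2ef.pt",  # illustrative
        cutoff=6,
        max_neighbors=50,
    )

Note the rename from checkpoint to checkpoint_path: callers such as
FinetunerCalc.load_trainer() in the second hunk must pass the new keyword.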