update trainer
LoryWangxx committed Feb 22, 2024
1 parent fa96565 commit 89108a1
Showing 2 changed files with 77 additions and 66 deletions.
138 changes: 74 additions & 64 deletions finetuna/finetuner_utils/trainer.py
@@ -1,7 +1,8 @@
from ocpmodels.trainers.forces_trainer import ForcesTrainer
from ocpmodels.trainers.ocp_trainer import OCPTrainer
from ocpmodels.datasets.lmdb_dataset import data_list_collater
from ocpmodels.common.utils import setup_imports, setup_logging
from ocpmodels.common.utils import setup_imports, setup_logging, update_config
from ocpmodels.common import distutils
from ocpmodels.common.registry import registry
import logging
import yaml
from ocpmodels.preprocessing import AtomsToGraphs
@@ -17,82 +18,91 @@
)


class Trainer(ForcesTrainer):
def __init__(self, config_yml=None, checkpoint=None, cutoff=6, max_neighbors=50):
class Trainer(OCPTrainer):
def __init__(self, config_yml=None, checkpoint_path=None, cutoff=6, max_neighbors=50):
setup_imports()
setup_logging()

# Either the config path or the checkpoint path needs to be provided
assert config_yml or checkpoint is not None
assert config_yml or checkpoint_path is not None

checkpoint = None
if config_yml is not None:
if isinstance(config_yml, str):
config = yaml.safe_load(open(config_yml, "r"))

if "includes" in config:
for include in config["includes"]:
# Change the path based on absolute path of config_yml
path = os.path.join(config_yml.split("configs")[0], include)
include_config = yaml.safe_load(open(path, "r"))
config.update(include_config)
config, duplicates_warning, duplicates_error = load_config(
config_yml
)
if len(duplicates_warning) > 0:
logging.warning(
f"Overwritten config parameters from included configs "
f"(non-included parameters take precedence): {duplicates_warning}"
)
if len(duplicates_error) > 0:
raise ValueError(
f"Conflicting (duplicate) parameters in simultaneously "
f"included configs: {duplicates_error}"
)
else:
config = config_yml
else:
# Loads the config from the checkpoint directly
config = torch.load(checkpoint, map_location=torch.device("cpu"))["config"]

# Load the trainer based on the dataset used
if config["task"]["dataset"] == "trajectory_lmdb":
config["trainer"] = "forces"
else:
config["trainer"] = "energy"

# Only keeps the train data that might have normalizer values
# if isinstance(config["dataset"], list):
# config["dataset"] = config["dataset"][0]
# elif isinstance(config["dataset"], dict):
# config["dataset"] = config["dataset"].get("train", None)
else:
# Loads the config from the checkpoint directly (always on CPU).
checkpoint = torch.load(
checkpoint_path, map_location=torch.device("cpu")
)
config = checkpoint["config"]

# if trainer is not None:
# config["trainer"] = trainer
# else:
config["trainer"] = config.get("trainer", "ocp")

if "model_attributes" in config:
config["model_attributes"]["name"] = config.pop("model")
config["model"] = config["model_attributes"]

# Calculate the edge indices on the fly
self.otf_graph = True
config["model"]["otf_graph"] = self.otf_graph
# for checkpoints with relaxation datasets defined, remove to avoid
# unnecessarily trying to load that dataset
if "relax_dataset" in config["task"]:
del config["task"]["relax_dataset"]

# delete scale file entry in config before loading (remove me if this causes problems in the future)
config.get("model", {}).pop("scale_file", None)
# Calculate the edge indices on the fly
config["model"]["otf_graph"] = True

# Save config so obj can be transported over network (pkl)
config = update_config(config)
self.config = copy.deepcopy(config)
self.config["checkpoint"] = checkpoint

if "normalizer" not in config:
if config["dataset"] is not None:
del config["dataset"]["src"]
config["normalizer"] = config["dataset"]

identifier = ""
if hasattr(config.get("logger", {}), "get"):
identifier = config.get("logger", {}).get("identifier", "")

self.config["checkpoint"] = checkpoint_path
del config["dataset"]["src"]
super().__init__(
task=config["task"],
model=config["model"],
dataset=None,
dataset=[config["dataset"]],
outputs=config["outputs"],
loss_fns=config["loss_fns"],
eval_metrics=config["eval_metrics"],
optimizer=config["optim"],
identifier=identifier,
normalizer=config["normalizer"],
identifier="",
slurm=config.get("slurm", {}),
local_rank=config.get("local_rank", 0),
logger=config.get("logger", None),
print_every=config.get("print_every", 1),
is_debug=config.get("is_debug", True),
cpu=config.get("cpu", True),
amp=config.get("amp", False),
)

# if loading a model with added blocks for training from the checkpoint, set strict loading to False
if self.config["model"] in ["adapter_gemnet_t", "adapter_gemnet_oc"]:
self.model.load_state_dict.__func__.__defaults__ = (False,)

# load checkpoint
if checkpoint is not None:
if checkpoint_path is not None:
try:
self.load_checkpoint(checkpoint)
self.load_checkpoint(checkpoint_path)
except NotImplementedError:
logging.warning("Unable to load checkpoint!")

@@ -350,22 +360,22 @@ def closure():
if "test_dataset" in self.config:
self.test_dataset.close_db()

def load_loss(self):
self.loss_fn = {}
self.loss_fn["energy"] = self.config["optim"].get("loss_energy", "mae")
self.loss_fn["force"] = self.config["optim"].get("loss_force", "mae")
for loss, loss_name in self.loss_fn.items():
if loss_name in ["l1", "mae"]:
self.loss_fn[loss] = nn.L1Loss()
elif loss_name == "mse":
self.loss_fn[loss] = nn.MSELoss()
elif loss_name == "l2mae":
self.loss_fn[loss] = L2MAELoss()
elif loss_name == "rell2mae":
self.loss_fn[loss] = RelativeL2MAELoss()
elif loss_name == "atomwisel2":
self.loss_fn[loss] = AtomwiseL2LossNoBatch()
else:
raise NotImplementedError(f"Unknown loss function name: {loss_name}")
if distutils.initialized():
self.loss_fn[loss] = DDPLoss(self.loss_fn[loss])
# def load_loss(self):
# self.loss_fn = {}
# self.loss_fn["energy"] = self.config["optim"].get("loss_energy", "mae")
# self.loss_fn["force"] = self.config["optim"].get("loss_force", "mae")
# for loss, loss_name in self.loss_fn.items():
# if loss_name in ["l1", "mae"]:
# self.loss_fn[loss] = nn.L1Loss()
# elif loss_name == "mse":
# self.loss_fn[loss] = nn.MSELoss()
# elif loss_name == "l2mae":
# self.loss_fn[loss] = L2MAELoss()
# elif loss_name == "rell2mae":
# self.loss_fn[loss] = RelativeL2MAELoss()
# elif loss_name == "atomwisel2":
# self.loss_fn[loss] = AtomwiseL2LossNoBatch()
# else:
# raise NotImplementedError(f"Unknown loss function name: {loss_name}")
# if distutils.initialized():
# self.loss_fn[loss] = DDPLoss(self.loss_fn[loss])
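The net effect of the trainer.py changes is that the constructor now targets the newer OCPTrainer config layout: when only a checkpoint is supplied, the config embedded in it is converted in place before being handed to the base class, and with the custom load_loss override commented out, loss construction is presumably left to OCPTrainer via the loss_fns entry passed above. A minimal sketch of that checkpoint-only path, assuming an old-style checkpoint whose config still carries model_attributes (the file name is a placeholder, not from the commit):

import torch
from ocpmodels.common.utils import update_config

checkpoint = torch.load("gemnet_oc_base.pt", map_location="cpu")  # placeholder path
config = checkpoint["config"]
config["trainer"] = config.get("trainer", "ocp")

# Old-style checkpoints store the model name separately from its attributes.
if "model_attributes" in config:
    config["model_attributes"]["name"] = config.pop("model")
    config["model"] = config["model_attributes"]

config["task"].pop("relax_dataset", None)        # skip loading any relaxation dataset
config.get("model", {}).pop("scale_file", None)  # drop the stale scale-file entry
config["model"]["otf_graph"] = True              # compute edge indices on the fly

# update_config maps the old task/dataset layout onto the new-style keys
# (outputs, loss_fns, eval_metrics) that OCPTrainer.__init__ expects.
config = update_config(config)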
5 changes: 3 additions & 2 deletions finetuna/ml_potentials/finetuner_calc.py
@@ -107,12 +107,13 @@ def load_trainer(self):
"""
# make a copy of the config dict so the trainer doesn't edit the original
config_dict = copy.deepcopy(self.mlp_params)

print(config_dict["dataset"])
print("---------------------------")
# initialize trainer
sys.stdout = open(os.devnull, "w")
self.trainer = Trainer(
config_yml=config_dict,
checkpoint=self.checkpoint_path,
checkpoint_path=self.checkpoint_path,
cutoff=self.cutoff,
max_neighbors=self.max_neighbors,
)
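Because the keyword argument was renamed from checkpoint to checkpoint_path, any other code that constructs the Trainer directly needs the same rename. A short usage sketch with placeholder paths (not from the commit):

from finetuna.finetuner_utils.trainer import Trainer

trainer = Trainer(
    config_yml="configs/finetuner/gemnet_oc.yml",     # placeholder path
    checkpoint_path="checkpoints/gemnet_oc_base.pt",  # renamed from checkpoint=
    cutoff=6,
    max_neighbors=50,
)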
