diff --git a/finetuna/finetuner_utils/trainer.py b/finetuna/finetuner_utils/trainer.py
index 727c597..3fe5652 100644
--- a/finetuna/finetuner_utils/trainer.py
+++ b/finetuna/finetuner_utils/trainer.py
@@ -72,8 +72,8 @@ def __init__(self, config_yml=None, checkpoint_path=None, cutoff=6, max_neighbor
             del config["task"]["relax_dataset"]
 
         # Calculate the edge indices on the fly
-        config["model"]["otf_graph"] = True
-
+        self.otf_graph = True
+        config["model"]["otf_graph"] = self.otf_graph
         # Save config so obj can be transported over network (pkl)
         config = update_config(config)
         self.config = copy.deepcopy(config)
@@ -169,23 +169,85 @@ def save(
         checkpoint["config"]["normalizer"] = self.normalizer
         torch.save(checkpoint, checkpoint_path)
         return checkpoint_path
+
+    def _compute_loss(self, out, batch):
+        batch_size = batch.natoms.numel()
+        fixed = batch.fixed
+        mask = fixed == 0
+
+        loss = []
+        for loss_fn in self.loss_fns:
+            target_name, loss_info = loss_fn
+
+            target = batch[target_name]
+            pred = out[target_name]
+            natoms = batch.natoms
+            natoms = torch.repeat_interleave(natoms, natoms)
+
+            if (
+                self.output_targets[target_name]["level"] == "atom"
+                and self.output_targets[target_name]["train_on_free_atoms"]
+            ):
+                target = target[mask]
+                pred = pred[mask]
+                natoms = natoms[mask]
+
+            num_atoms_in_batch = natoms.numel()
+            if self.normalizers.get(target_name, False):
+                target = self.normalizers[target_name].norm(target)
+
+            ### reshape accordingly: num_atoms_in_batch, -1 or num_systems_in_batch, -1
+            if self.output_targets[target_name]["level"] == "atom":
+                target = target.view(num_atoms_in_batch, -1)
+            else:
+                target = target.view(batch_size, -1)
+
+            mult = loss_info["coefficient"]
+            loss.append(
+                mult
+                * loss_info["fn"](
+                    pred,
+                    target,
+                    natoms=natoms,
+                    batch_size=batch_size,
+                )
+            )
 
-    def train(self, disable_eval_tqdm=False):
-        eval_every = self.config["optim"].get("eval_every", None)
-        if eval_every is None:
-            eval_every = len(self.train_loader)
-        checkpoint_every = self.config["optim"].get("checkpoint_every", eval_every)
-        primary_metric = self.config["task"].get(
+        # Sanity check to make sure the compute graph is correct.
+        for lc in loss:
+            assert hasattr(lc, "grad_fn")
+
+        loss = sum(loss)
+        return loss
+
+    def train(self, disable_eval_tqdm: bool = False) -> None:
+        # ensure_fitted(self._unwrapped_model, warn=True)
+
+        eval_every = self.config["optim"].get(
+            "eval_every", len(self.train_loader)
+        )
+        checkpoint_every = self.config["optim"].get(
+            "checkpoint_every", eval_every
+        )
+        primary_metric = self.evaluation_metrics.get(
             "primary_metric", self.evaluator.task_primary_metric[self.name]
         )
-        self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
+        if (
+            not hasattr(self, "primary_metric")
+            or self.primary_metric != primary_metric
+        ):
+            self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
+        else:
+            primary_metric = self.primary_metric
         self.metrics = {}
 
         # Calculate start_epoch from step instead of loading the epoch number
         # to prevent inconsistencies due to different batch size in checkpoint.
         start_epoch = self.step // len(self.train_loader)
 
-        for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]):
+        for epoch_int in range(
+            start_epoch, self.config["optim"]["max_epochs"]
+        ):
             self.train_sampler.set_epoch(epoch_int)
             skip_steps = self.step % len(self.train_loader)
             train_loader_iter = iter(self.train_loader)
@@ -198,32 +260,10 @@ def train(self, disable_eval_tqdm=False):
                 # Get a batch.
                 batch = next(train_loader_iter)
 
-                if self.config["optim"]["optimizer"] == "LBFGS":
-
-                    def closure():
-                        self.optimizer.zero_grad()
-                        with torch.cuda.amp.autocast(enabled=self.scaler is not None):
-                            out = self._forward(batch)
-                            loss = self._compute_loss(out, batch)
-                        loss.backward()
-                        return loss
-
-                    self.optimizer.step(closure)
-
-                    self.optimizer.zero_grad()
-                    with torch.cuda.amp.autocast(enabled=self.scaler is not None):
-                        out = self._forward(batch)
-                        loss = self._compute_loss(out, batch)
-
-                else:
-                    # Forward, loss, backward.
-                    with torch.cuda.amp.autocast(enabled=self.scaler is not None):
-                        out = self._forward(batch)
-                        loss = self._compute_loss(out, batch)
-                    loss = self.scaler.scale(loss) if self.scaler else loss
-                    self._backward(loss)
-
-                scale = self.scaler.get_scale() if self.scaler else 1.0
+                # Forward, loss, backward.
+                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
+                    out = self._forward(batch)
+                    loss = self._compute_loss(out, batch)
 
                 # Compute metrics.
                 self.metrics = self._compute_metrics(
@@ -233,9 +273,12 @@ def closure():
                     self.metrics,
                 )
                 self.metrics = self.evaluator.update(
-                    "loss", loss.item() / scale, self.metrics
+                    "loss", loss.item(), self.metrics
                 )
 
+                loss = self.scaler.scale(loss) if self.scaler else loss
+                self._backward(loss)
+
                 # Log metrics.
                 log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                 log_dict.update(
@@ -248,9 +291,10 @@ def closure():
                 if (
                     self.step % self.config["cmd"]["print_every"] == 0
                     and distutils.is_master()
-                    and not self.is_hpo
                 ):
-                    log_str = ["{}: {:.2e}".format(k, v) for k, v in log_dict.items()]
+                    log_str = [
+                        "{}: {:.2e}".format(k, v) for k, v in log_dict.items()
+                    ]
                     logging.info(", ".join(log_str))
                     self.metrics = {}
 
@@ -261,16 +305,16 @@ def closure():
                         split="train",
                     )
 
-                if checkpoint_every != -1 and self.step % checkpoint_every == 0:
-                    self.save(checkpoint_file="checkpoint.pt", training_state=True)
+                if (
+                    checkpoint_every != -1
+                    and self.step % checkpoint_every == 0
+                ):
+                    self.save(
+                        checkpoint_file="checkpoint.pt", training_state=True
+                    )
 
                 # Evaluate on val set every `eval_every` iterations.
                 if self.step % eval_every == 0:
-                    if self.test_loader is not None:
-                        test_metrics = self.validate(
-                            split="test",
-                            disable_tqdm=disable_eval_tqdm,
-                        )
                     if self.val_loader is not None:
                         val_metrics = self.validate(
                             split="val",
@@ -281,13 +325,6 @@ def closure():
                             val_metrics,
                             disable_eval_tqdm=disable_eval_tqdm,
                         )
-                        if self.is_hpo:
-                            self.hpo_update(
-                                self.epoch,
-                                self.step,
-                                self.metrics,
-                                val_metrics,
-                            )
 
                     if self.config["task"].get("eval_relaxations", False):
                         if "relax_dataset" not in self.config["task"]:
@@ -297,35 +334,6 @@ def closure():
                         else:
                             self.run_relaxations()
 
-                if self.config["optim"].get("print_loss_and_lr", False):
-                    if self.step % eval_every == 0 or not self.config["optim"].get(
-                        "print_only_on_eval", True
-                    ):
-                        if self.val_loader is not None:
-                            print(
-                                "epoch: "
-                                + "{:.1f}".format(self.epoch)
-                                + ", \tstep: "
-                                + str(self.step)
-                                + ", \tloss: "
-                                + str(loss.detach().item())
-                                + ", \tlr: "
-                                + str(self.scheduler.get_lr())
-                                + ", \tval: "
-                                + str(val_metrics["loss"]["metric"])
-                            )
-                        else:
-                            print(
-                                "epoch: "
-                                + "{:.1f}".format(self.epoch)
-                                + ", \tstep: "
-                                + str(self.step)
-                                + ", \tloss: "
-                                + str(loss.detach().item())
-                                + ", \tlr: "
-                                + str(self.scheduler.get_lr())
-                            )
-
                 if self.scheduler.scheduler_type == "ReduceLROnPlateau":
                     if (
                         self.step % eval_every == 0
@@ -341,24 +349,205 @@ def closure():
                     else:
                         self.scheduler.step()
 
-                break_below_lr = (
-                    self.config["optim"].get("break_below_lr", None) is not None
-                ) and (self.scheduler.get_lr() < self.config["optim"]["break_below_lr"])
-                if break_below_lr:
-                    break
-            if break_below_lr:
-                break
-
             torch.cuda.empty_cache()
 
         if checkpoint_every == -1:
             self.save(checkpoint_file="checkpoint.pt", training_state=True)
 
         self.train_dataset.close_db()
-        if "val_dataset" in self.config:
+        if self.config.get("val_dataset", False):
             self.val_dataset.close_db()
-        if "test_dataset" in self.config:
+        if self.config.get("test_dataset", False):
             self.test_dataset.close_db()
+    # def train(self, disable_eval_tqdm=False):
+    #     eval_every = self.config["optim"].get("eval_every", None)
+    #     if eval_every is None:
+    #         eval_every = len(self.train_loader)
+    #     checkpoint_every = self.config["optim"].get("checkpoint_every", eval_every)
+    #     primary_metric = self.config["task"].get(
+    #         "primary_metric", self.evaluator.task_primary_metric[self.name]
+    #     )
+    #     self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
+    #     self.metrics = {}
+
+    #     # Calculate start_epoch from step instead of loading the epoch number
+    #     # to prevent inconsistencies due to different batch size in checkpoint.
+    #     start_epoch = self.step // len(self.train_loader)
+
+    #     for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]):
+    #         self.train_sampler.set_epoch(epoch_int)
+    #         skip_steps = self.step % len(self.train_loader)
+    #         train_loader_iter = iter(self.train_loader)
+
+    #         for i in range(skip_steps, len(self.train_loader)):
+    #             self.epoch = epoch_int + (i + 1) / len(self.train_loader)
+    #             self.step = epoch_int * len(self.train_loader) + i + 1
+    #             self.model.train()
+
+    #             # Get a batch.
+    #             batch = next(train_loader_iter)
+
+    #             if self.config["optim"]["optimizer"] == "LBFGS":
+
+    #                 def closure():
+    #                     self.optimizer.zero_grad()
+    #                     with torch.cuda.amp.autocast(enabled=self.scaler is not None):
+    #                         out = self._forward(batch)
+    #                         loss = self._compute_loss(out, batch)
+    #                     loss.backward()
+    #                     return loss
+
+    #                 self.optimizer.step(closure)
+
+    #                 self.optimizer.zero_grad()
+    #                 with torch.cuda.amp.autocast(enabled=self.scaler is not None):
+    #                     out = self._forward(batch)
+    #                     loss = self._compute_loss(out, batch)
+
+    #             else:
+    #                 # Forward, loss, backward.
+    #                 with torch.cuda.amp.autocast(enabled=self.scaler is not None):
+    #                     out = self._forward(batch)
+    #                     loss = self._compute_loss(out, batch)
+    #                 loss = self.scaler.scale(loss) if self.scaler else loss
+    #                 self._backward(loss)
+
+    #             scale = self.scaler.get_scale() if self.scaler else 1.0
+
+    #             # Compute metrics.
+    #             self.metrics = self._compute_metrics(
+    #                 out,
+    #                 batch,
+    #                 self.evaluator,
+    #                 self.metrics,
+    #             )
+    #             self.metrics = self.evaluator.update(
+    #                 "loss", loss.item() / scale, self.metrics
+    #             )
+
+    #             # Log metrics.
+    #             log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
+    #             log_dict.update(
+    #                 {
+    #                     "lr": self.scheduler.get_lr(),
+    #                     "epoch": self.epoch,
+    #                     "step": self.step,
+    #                 }
+    #             )
+    #             if (
+    #                 self.step % self.config["cmd"]["print_every"] == 0
+    #                 and distutils.is_master()
+    #                 and not self.is_hpo
+    #             ):
+    #                 log_str = ["{}: {:.2e}".format(k, v) for k, v in log_dict.items()]
+    #                 logging.info(", ".join(log_str))
+    #                 self.metrics = {}
+
+    #             if self.logger is not None:
+    #                 self.logger.log(
+    #                     log_dict,
+    #                     step=self.step,
+    #                     split="train",
+    #                 )
+
+    #             if checkpoint_every != -1 and self.step % checkpoint_every == 0:
+    #                 self.save(checkpoint_file="checkpoint.pt", training_state=True)
+
+    #             # Evaluate on val set every `eval_every` iterations.
+    #             if self.step % eval_every == 0:
+    #                 if self.test_loader is not None:
+    #                     test_metrics = self.validate(
+    #                         split="test",
+    #                         disable_tqdm=disable_eval_tqdm,
+    #                     )
+    #                 if self.val_loader is not None:
+    #                     val_metrics = self.validate(
+    #                         split="val",
+    #                         disable_tqdm=disable_eval_tqdm,
+    #                     )
+    #                     self.update_best(
+    #                         primary_metric,
+    #                         val_metrics,
+    #                         disable_eval_tqdm=disable_eval_tqdm,
+    #                     )
+    #                     if self.is_hpo:
+    #                         self.hpo_update(
+    #                             self.epoch,
+    #                             self.step,
+    #                             self.metrics,
+    #                             val_metrics,
+    #                         )
+
+    #                 if self.config["task"].get("eval_relaxations", False):
+    #                     if "relax_dataset" not in self.config["task"]:
+    #                         logging.warning(
+    #                             "Cannot evaluate relaxations, relax_dataset not specified"
+    #                         )
+    #                     else:
+    #                         self.run_relaxations()
+
+    #             if self.config["optim"].get("print_loss_and_lr", False):
+    #                 if self.step % eval_every == 0 or not self.config["optim"].get(
+    #                     "print_only_on_eval", True
+    #                 ):
+    #                     if self.val_loader is not None:
+    #                         print(
+    #                             "epoch: "
+    #                             + "{:.1f}".format(self.epoch)
+    #                             + ", \tstep: "
+    #                             + str(self.step)
+    #                             + ", \tloss: "
+    #                             + str(loss.detach().item())
+    #                             + ", \tlr: "
+    #                             + str(self.scheduler.get_lr())
+    #                             + ", \tval: "
+    #                             + str(val_metrics["loss"]["metric"])
+    #                         )
+    #                     else:
+    #                         print(
+    #                             "epoch: "
+    #                             + "{:.1f}".format(self.epoch)
+    #                             + ", \tstep: "
+    #                             + str(self.step)
+    #                             + ", \tloss: "
+    #                             + str(loss.detach().item())
+    #                             + ", \tlr: "
+    #                             + str(self.scheduler.get_lr())
+    #                         )
+
+    #             if self.scheduler.scheduler_type == "ReduceLROnPlateau":
+    #                 if (
+    #                     self.step % eval_every == 0
+    #                     and self.config["optim"].get("scheduler_loss", None) == "train"
+    #                 ):
+    #                     self.scheduler.step(
+    #                         metrics=loss.detach().item(),
+    #                     )
+    #                 elif self.step % eval_every == 0 and self.val_loader is not None:
+    #                     self.scheduler.step(
+    #                         metrics=val_metrics[primary_metric]["metric"],
+    #                     )
+    #                 else:
+    #                     self.scheduler.step()
+
+    #             break_below_lr = (
+    #                 self.config["optim"].get("break_below_lr", None) is not None
+    #             ) and (self.scheduler.get_lr() < self.config["optim"]["break_below_lr"])
+    #             if break_below_lr:
+    #                 break
+    #         if break_below_lr:
+    #             break
+
+    #         torch.cuda.empty_cache()
+
+    #     if checkpoint_every == -1:
+    #         self.save(checkpoint_file="checkpoint.pt", training_state=True)
+
+    #     self.train_dataset.close_db()
+    #     if "val_dataset" in self.config:
+    #         self.val_dataset.close_db()
+    #     if "test_dataset" in self.config:
+    #         self.test_dataset.close_db()
 
     # def load_loss(self):
     #     self.loss_fn = {}