From 5f3943cc9bfa3aaf4a9eb04f905e78440ab74cd2 Mon Sep 17 00:00:00 2001 From: valhassan Date: Thu, 24 Oct 2024 08:14:05 -0400 Subject: [PATCH] Refactor input_channels assignment in SegmentationDOFA.export_model() --- geo_deep_learning/tasks_with_models/segmentation_dofa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/geo_deep_learning/tasks_with_models/segmentation_dofa.py b/geo_deep_learning/tasks_with_models/segmentation_dofa.py index bb5969d0..bb4c632f 100644 --- a/geo_deep_learning/tasks_with_models/segmentation_dofa.py +++ b/geo_deep_learning/tasks_with_models/segmentation_dofa.py @@ -80,7 +80,7 @@ def on_train_end(self): self.export_model(best_model_path, best_model_export_path, self.trainer.datamodule) def export_model(self, checkpoint_path: str, export_path: str, datamodule: LightningDataModule): - input_channels = self.hparams["init_args"]["in_channels"] + input_channels = self.hparams["in_channels"] map_location = "cuda" if self.device.type == "cpu": map_location = "cpu"