From 4cf539158e7b037293c8f8f55a57fa3be42514ac Mon Sep 17 00:00:00 2001 From: drewoldag <47493171+drewoldag@users.noreply.github.com> Date: Tue, 21 Jan 2025 16:33:00 -0800 Subject: [PATCH] Got resnet working - doing some more investigation to find bugs. --- src/kbmod_ml/data_sets/kbmod_stamps.py | 2 +- src/kbmod_ml/default_config.toml | 4 ++-- train_model.ipynb | 14 +++++++++++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/kbmod_ml/data_sets/kbmod_stamps.py b/src/kbmod_ml/data_sets/kbmod_stamps.py index b58640e..65aac08 100644 --- a/src/kbmod_ml/data_sets/kbmod_stamps.py +++ b/src/kbmod_ml/data_sets/kbmod_stamps.py @@ -28,7 +28,7 @@ def __init__(self, config, split: str): cols = [] - for c in ["mean", "median"]: + for c in ["mean"]: cols.append(coadd_type_to_column[c]) self.active_columns = np.array(cols) diff --git a/src/kbmod_ml/default_config.toml b/src/kbmod_ml/default_config.toml index 56a0326..996e140 100644 --- a/src/kbmod_ml/default_config.toml +++ b/src/kbmod_ml/default_config.toml @@ -1,9 +1,9 @@ [data_set] name = "kbmod_ml.data_sets.kbmod_stamps.KbmodStamps" -train_split = 0.8 +train_size = 0.8 -validation_split = 0.2 +validate_size = 0.2 [kbmod_ml] # The file name of the true positive samples diff --git a/train_model.ipynb b/train_model.ipynb index 2561ac4..01142f6 100644 --- a/train_model.ipynb +++ b/train_model.ipynb @@ -100,9 +100,21 @@ "metadata": {}, "outputs": [], "source": [ + "from fibad import Fibad\n", + "\n", + "fibad_instance = Fibad(config_file=\"./user_config.toml\")\n", + "\n", "# Change the model to the resnet50 model and attempt to train\n", "fibad_instance.config[\"model\"][\"name\"] = \"kbmod_ml.models.resnet50.RESNET50\"\n", - "fibad_instance.config[\"data_loader\"][\"batch_size\"] = 1\n", + "fibad_instance.config[\"train\"][\"epochs\"] = 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "fibad_instance.train()" ] },