From f962977a5bc973ae6a52e129e987b566035091e3 Mon Sep 17 00:00:00 2001 From: valhassan Date: Tue, 5 Mar 2024 15:33:05 -0500 Subject: [PATCH 1/8] Fix import statement for ruamel.yaml.comments --- inference_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference_segmentation.py b/inference_segmentation.py index 9e75fd71..0ff0007d 100644 --- a/inference_segmentation.py +++ b/inference_segmentation.py @@ -16,7 +16,7 @@ from scipy.special import softmax from collections import OrderedDict from fiona.crs import to_string -from ruamel_yaml.comments import CommentedSeq +from ruamel.yaml.comments import CommentedSeq from tqdm import tqdm from rasterio import features from rasterio.windows import Window From 62bdf79052c12018cfaf7df8f84c7d0792efe407 Mon Sep 17 00:00:00 2001 From: valhassan Date: Tue, 7 May 2024 14:12:19 -0400 Subject: [PATCH 2/8] Add ScriptModel class to utils --- utils/script_model.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 utils/script_model.py diff --git a/utils/script_model.py b/utils/script_model.py new file mode 100644 index 00000000..14139ab4 --- /dev/null +++ b/utils/script_model.py @@ -0,0 +1,32 @@ +import torch + +class ScriptModel(torch.nn.Module): + def __init__(self, + model, + device = torch.device("cpu"), + input_shape = (1, 3, 512, 512), + mean = [0.405,0.432,0.397], + std = [0.164,0.173,0.153], + min = 0, + max = 255, + scaled_min = 0.0, + scaled_max = 1.0): + super().__init__() + self.device = device + self.mean = torch.tensor(mean).resize_(len(mean), 1) + self.std = torch.tensor(std).resize_(len(std), 1) + self.min = min + self.max = max + self.min_val = scaled_min + self.max_val = scaled_max + + input_tensor = torch.rand(input_shape).to(self.device) + self.model_scripted = torch.jit.trace(model.eval(), input_tensor) + + def forward(self, input): + shape = input.shape + B, C = shape[0], shape[1] + input = (self.max_val - self.min_val) * (input - self.min) / (self.max -self.min) + self.min_val + input = (input.view(B, C, -1) - self.mean) / self.std + input = input.view(shape) + return self.model_scripted(input.to(self.device)) \ No newline at end of file From 7ddcd41d54cbff73620a8fb308bb899fa63d1b66 Mon Sep 17 00:00:00 2001 From: valhassan Date: Tue, 7 May 2024 14:12:39 -0400 Subject: [PATCH 3/8] Add script model functionality --- train_segmentation.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/train_segmentation.py b/train_segmentation.py index 664abc82..cf2a7ca0 100644 --- a/train_segmentation.py +++ b/train_segmentation.py @@ -20,6 +20,7 @@ from tiling_segmentation import Tiler from utils import augmentation as aug from dataset import create_dataset +from utils.script_model import ScriptModel from utils.logger import InformationLogger, tsv_line, get_logger, set_tracker from utils.loss import verify_weights, define_loss from utils.metrics import create_metrics_dict, calculate_batch_metrics @@ -553,6 +554,7 @@ def train(cfg: DictConfig) -> None: train_state_dict_path = get_key_def('state_dict_path', cfg['training'], default=None, expected_type=str) state_dict_strict = get_key_def('state_dict_strict_load', cfg['training'], default=True, expected_type=bool) dropout_prob = get_key_def('factor', cfg['scheduler']['params'], default=None, expected_type=float) + scriptmodel = get_key_def('script_model', cfg['training'], default=False, expected_type=bool) # if error if train_state_dict_path and not Path(train_state_dict_path).is_file(): raise logging.critical( @@ -792,6 +794,19 @@ def train(cfg: DictConfig) -> None: cur_elapsed = time.time() - since # logging.info(f'\nCurrent elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s') + + # Script model + if scriptmodel: + model_to_script = ScriptModel(model, + device=device, + input_shape=(1, num_bands, patches_size, patches_size), + mean=mean, + std=std, + min=scale[0], + max=scale[1]) + + scripted_model = torch.jit.script(model_to_script) + scripted_model.save(output_path.joinpath('scripted_model.pt')) # load checkpoint model and evaluate it on test dataset. if int(cfg['general']['max_epochs']) > 0: # if num_epochs is set to 0, model is loaded to evaluate on test set From 3703bb80e3788664672c6dff28bd2c1ee96fa4fe Mon Sep 17 00:00:00 2001 From: valhassan Date: Tue, 7 May 2024 14:13:05 -0400 Subject: [PATCH 4/8] Update gdl_config_template.yaml and default_training.yaml --- config/gdl_config_template.yaml | 2 +- config/training/default_training.yaml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/config/gdl_config_template.yaml b/config/gdl_config_template.yaml index 29d8089a..dcea42be 100644 --- a/config/gdl_config_template.yaml +++ b/config/gdl_config_template.yaml @@ -34,7 +34,7 @@ general: raw_data_dir: data raw_data_csv: tests/tiling/tiling_segmentation_binary_ci.csv tiling_data_dir: ${general.raw_data_dir}/patches # where the patches will be saved - save_weights_dir: saved_model/${general.project_name} + save_weights_dir: print_config: True # save the config in the log folder mode: {verify, tiling, train, inference, evaluate} diff --git a/config/training/default_training.yaml b/config/training/default_training.yaml index 8ae94077..308a019a 100644 --- a/config/training/default_training.yaml +++ b/config/training/default_training.yaml @@ -13,6 +13,7 @@ training: max_used_perc: state_dict_path: state_dict_strict_load: True + script_model: False compute_sampler_weights: False # precision: 16 From bac46a9e03fa01c1d0dfbaf24079e3d02f86cee9 Mon Sep 17 00:00:00 2001 From: valhassan Date: Wed, 8 May 2024 12:46:49 -0400 Subject: [PATCH 5/8] Add mkl dependency to environment.yml --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 157cf02c..f1945ddd 100644 --- a/environment.yml +++ b/environment.yml @@ -13,6 +13,7 @@ dependencies: - pytest>=7.1 - python>=3.10 - pytorch==1.12 + - mkl==2024.0 - rich>=11.1 - ruamel_yaml>=0.15 - scikit-image>=0.18 From e3454c711f5aaf5df3a2e01e88ce0ffbb539fff5 Mon Sep 17 00:00:00 2001 From: valhassan Date: Wed, 8 May 2024 13:13:11 -0400 Subject: [PATCH 6/8] Fix import statement for ruamel_yaml.comments --- inference_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference_segmentation.py b/inference_segmentation.py index 0ff0007d..9e75fd71 100644 --- a/inference_segmentation.py +++ b/inference_segmentation.py @@ -16,7 +16,7 @@ from scipy.special import softmax from collections import OrderedDict from fiona.crs import to_string -from ruamel.yaml.comments import CommentedSeq +from ruamel_yaml.comments import CommentedSeq from tqdm import tqdm from rasterio import features from rasterio.windows import Window From c650f897012d43796802d28a9487226d8903eab1 Mon Sep 17 00:00:00 2001 From: valhassan Date: Thu, 9 May 2024 15:31:35 -0400 Subject: [PATCH 7/8] Update save_weights_dir in gdl_config_template.yaml --- config/gdl_config_template.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/gdl_config_template.yaml b/config/gdl_config_template.yaml index dcea42be..29d8089a 100644 --- a/config/gdl_config_template.yaml +++ b/config/gdl_config_template.yaml @@ -34,7 +34,7 @@ general: raw_data_dir: data raw_data_csv: tests/tiling/tiling_segmentation_binary_ci.csv tiling_data_dir: ${general.raw_data_dir}/patches # where the patches will be saved - save_weights_dir: + save_weights_dir: saved_model/${general.project_name} print_config: True # save the config in the log folder mode: {verify, tiling, train, inference, evaluate} From e67cf747a7579b850060af4892a41f58f2d757bf Mon Sep 17 00:00:00 2001 From: valhassan Date: Mon, 13 May 2024 14:59:35 -0400 Subject: [PATCH 8/8] Add unit test for ScriptModel --- tests/utils/test_script_model.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 tests/utils/test_script_model.py diff --git a/tests/utils/test_script_model.py b/tests/utils/test_script_model.py new file mode 100644 index 00000000..3973e531 --- /dev/null +++ b/tests/utils/test_script_model.py @@ -0,0 +1,14 @@ +import torch +import pytest +from utils.script_model import ScriptModel + +def test_script_model(): + model = torch.nn.Linear(3, 1) + script_model = ScriptModel(model, + input_shape=(1, 3),) + + input_tensor = torch.rand((1, 3)) + output = script_model.forward(input_tensor) + + assert output.shape == (1, 1) + assert isinstance(output, torch.Tensor)