Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/NRCan/geo-deep-learning
Browse files Browse the repository at this point in the history
…into geo-inference
  • Loading branch information
valhassan committed May 16, 2024
2 parents e7d610e + 41c7710 commit ce92822
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 0 deletions.
1 change: 1 addition & 0 deletions config/training/default_training.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ training:
max_used_perc:
state_dict_path:
state_dict_strict_load: True
script_model: False
compute_sampler_weights: False

# precision: 16
Expand Down
14 changes: 14 additions & 0 deletions tests/utils/test_script_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
import pytest
from utils.script_model import ScriptModel

def test_script_model():
model = torch.nn.Linear(3, 1)
script_model = ScriptModel(model,
input_shape=(1, 3),)

input_tensor = torch.rand((1, 3))
output = script_model.forward(input_tensor)

assert output.shape == (1, 1)
assert isinstance(output, torch.Tensor)
15 changes: 15 additions & 0 deletions train_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tiling_segmentation import Tiler
from utils import augmentation as aug
from dataset import create_dataset
from utils.script_model import ScriptModel
from utils.logger import InformationLogger, tsv_line, get_logger, set_tracker
from utils.loss import verify_weights, define_loss
from utils.metrics import create_metrics_dict, calculate_batch_metrics
Expand Down Expand Up @@ -553,6 +554,7 @@ def train(cfg: DictConfig) -> None:
train_state_dict_path = get_key_def('state_dict_path', cfg['training'], default=None, expected_type=str)
state_dict_strict = get_key_def('state_dict_strict_load', cfg['training'], default=True, expected_type=bool)
dropout_prob = get_key_def('factor', cfg['scheduler']['params'], default=None, expected_type=float)
scriptmodel = get_key_def('script_model', cfg['training'], default=False, expected_type=bool)
# if error
if train_state_dict_path and not Path(train_state_dict_path).is_file():
raise logging.critical(
Expand Down Expand Up @@ -792,6 +794,19 @@ def train(cfg: DictConfig) -> None:

cur_elapsed = time.time() - since
# logging.info(f'\nCurrent elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s')

# Script model
if scriptmodel:
model_to_script = ScriptModel(model,
device=device,
input_shape=(1, num_bands, patches_size, patches_size),
mean=mean,
std=std,
min=scale[0],
max=scale[1])

scripted_model = torch.jit.script(model_to_script)
scripted_model.save(output_path.joinpath('scripted_model.pt'))

# load checkpoint model and evaluate it on test dataset.
if int(cfg['general']['max_epochs']) > 0: # if num_epochs is set to 0, model is loaded to evaluate on test set
Expand Down
32 changes: 32 additions & 0 deletions utils/script_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch

class ScriptModel(torch.nn.Module):
def __init__(self,
model,
device = torch.device("cpu"),
input_shape = (1, 3, 512, 512),
mean = [0.405,0.432,0.397],
std = [0.164,0.173,0.153],
min = 0,
max = 255,
scaled_min = 0.0,
scaled_max = 1.0):
super().__init__()
self.device = device
self.mean = torch.tensor(mean).resize_(len(mean), 1)
self.std = torch.tensor(std).resize_(len(std), 1)
self.min = min
self.max = max
self.min_val = scaled_min
self.max_val = scaled_max

input_tensor = torch.rand(input_shape).to(self.device)
self.model_scripted = torch.jit.trace(model.eval(), input_tensor)

def forward(self, input):
shape = input.shape
B, C = shape[0], shape[1]
input = (self.max_val - self.min_val) * (input - self.min) / (self.max -self.min) + self.min_val
input = (input.view(B, C, -1) - self.mean) / self.std
input = input.view(shape)
return self.model_scripted(input.to(self.device))

0 comments on commit ce92822

Please sign in to comment.