diff --git a/train_segmentation.py b/train_segmentation.py index cf2a7ca0..3376e68c 100644 --- a/train_segmentation.py +++ b/train_segmentation.py @@ -797,13 +797,15 @@ def train(cfg: DictConfig) -> None: # Script model if scriptmodel: + logging.info(f'\nScripting model...') model_to_script = ScriptModel(model, device=device, + num_classes=num_classes, input_shape=(1, num_bands, patches_size, patches_size), mean=mean, std=std, - min=scale[0], - max=scale[1]) + scaled_min=scale[0], + scaled_max=scale[1]) scripted_model = torch.jit.script(model_to_script) scripted_model.save(output_path.joinpath('scripted_model.pt'))