From d839bdf42ec6c1de25a3467a72d96510d22c8925 Mon Sep 17 00:00:00 2001 From: Rudolph Pienaar Date: Fri, 19 Apr 2024 18:09:58 -0400 Subject: [PATCH] Almost final verification code --- spleenseg/core/neuralnet.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/spleenseg/core/neuralnet.py b/spleenseg/core/neuralnet.py index 8be7147..d412116 100644 --- a/spleenseg/core/neuralnet.py +++ b/spleenseg/core/neuralnet.py @@ -42,8 +42,8 @@ # from monai.data.dataset import Dataset from monai.data.utils import decollate_batch from monai.data.meta_tensor import MetaTensor +from monai.handlers.utils import from_engine -# from monai.config.deviceconfig import print_config # from monai.apps.utils import download_and_extract import torch @@ -52,7 +52,7 @@ # import shutil # import glob # import pudb -from typing import Any +from typing import Any, Sequence import numpy as np from spleenseg.transforms import transforms @@ -296,7 +296,7 @@ def train( ) def inference_metricsProcess(self) -> float: - metric: float = self.network.dice_metric.aggregate().item() + metric: float = self.network.dice_metric.aggregate().item() # type: ignore self.trainingLog.metric_per_epoch.append(metric) self.network.dice_metric.reset() if metric > self.trainingLog.best_metric: @@ -408,6 +408,12 @@ def plot_bestModel( ) return 0.0 + def bestModel_runOverValidationSpace(self): + self.network.model.load_state_dict( + torch.load(str(self.trainingParams.modelPth)) + ) + self.slidingWindowInference_do(self.validationSpace, self.plot_bestModel) + def diceMetric_onValidationSpacing( self, sample: dict[str, MetaTensor | torch.Tensor], @@ -415,16 +421,25 @@ def diceMetric_onValidationSpacing( index: int, result: torch.Tensor, ) -> float: + metric: float = -1.0 sample["pred"] = result sample = [ self.f_outputPost(i) for i in decollate_batch(sample) # type: ignore[arg-type] ] - self.network.model.load_state_dict( - torch.load(str(self.trainingParams.modelPth)) + predictions: torch.Tensor + labels: torch.Tensor + predictions, labels = from_engine(["pred", "label"])(sample) + Dm: torch.Tensor = self.network.dice_metric( + y_pred=predictions, # type: ignore + y=labels, # type: ignore ) - self.slidingWindowInference_do(self.validationSpace, self.plot_bestModel) - return 0.0 + print(f"Best prediction dice metric: {Dm}") + if space.loader.batch_size: + if index == len(space.cache) // space.loader.batch_size: + metric = self.network.dice_metric.aggregate().item() + print(f"metric on original image spacing: {metric}") + return metric def bestModel_evaluateImageSpacings(self, validationTransforms: Compose): self.network.model.load_state_dict( @@ -437,3 +452,6 @@ def bestModel_evaluateImageSpacings(self, validationTransforms: Compose): transforms.f_labelAsDiscreted(), ] ) + self.slidingWindowInference_do( + self.validationSpace, self.diceMetric_onValidationSpacing + )