* feat: Segmentation report now supports multiclass (#95)
* feat: Segmentation report now supports multiclass

* docs: Update CHANGELOG

* build: Update version

* build: Update version

* fix: Setting right vmin vmax in matplotlib plots
AlessandroPolidori authored Jan 23, 2024
1 parent 81d6e4f commit 0c069ac
Showing 4 changed files with 37 additions and 11 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,12 @@
 # Changelog
 All notable changes to this project will be documented in this file.
 
+### [1.5.3]
+
+#### Fixed
+
+- Fix multiclass segmentation analysis report.
+
 ### [1.5.2]
 
 #### Fixed
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "quadra"
-version = "1.5.2"
+version = "1.5.3"
 description = "Deep Learning experiment orchestration library"
 authors = [
     "Federico Belotti <[email protected]>",
2 changes: 1 addition & 1 deletion quadra/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "1.5.2"
+__version__ = "1.5.3"
 
 
 def get_version():
38 changes: 29 additions & 9 deletions quadra/utils/evaluation.py
@@ -13,7 +13,7 @@
 import yaml
 from segmentation_models_pytorch.losses import DiceLoss
 from segmentation_models_pytorch.losses.constants import BINARY_MODE, MULTICLASS_MODE
-from skimage.measure import label, regionprops
+from skimage.measure import label, regionprops  # pylint: disable=no-name-in-module
 
 from quadra.utils.logger import get_logger
 from quadra.utils.visualization import UnNormalize, create_grid_figure
@@ -72,7 +72,7 @@ def score_dice_smp(y_pred: torch.Tensor, y_true: torch.Tensor, mode: str = "binary"
     Args:
         y_pred: 1xCxHxW one channel for each class
-        y_true: 1x1xHxW true mask with value in [0, ..., num_classes]
+        y_true: 1x1xHxW true mask with value in [0, ..., n_classes]
         mode: "binary" or "multiclass"
 
     Returns:
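Only the docstring changes in this hunk, but for context: judging by the imports above, `score_dice_smp` presumably builds on smp's `DiceLoss`. Below is a hedged sketch of computing a dice score for the documented shapes; the 1 - loss construction and the toy tensors are illustrative assumptions, not the repo's exact implementation:

```python
import torch
from segmentation_models_pytorch.losses import DiceLoss
from segmentation_models_pytorch.losses.constants import MULTICLASS_MODE

n_classes = 4
y_pred = torch.randn(1, n_classes, 16, 16)            # 1xCxHxW logits, one channel per class
y_true = torch.randint(0, n_classes, (1, 1, 16, 16))  # 1x1xHxW labels in [0, ..., n_classes)

# DiceLoss returns 1 - dice, so the score is its complement; multiclass mode
# expects integer targets without the channel axis, hence the squeeze.
dice_loss = DiceLoss(mode=MULTICLASS_MODE, from_logits=True)
score = 1.0 - dice_loss(y_pred, y_true.squeeze(1).long())
print(float(score))
```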
@@ -93,6 +93,8 @@ def calculate_mask_based_metrics(
     threshold: float = 0.5,
     show_orj_predictions: bool = False,
     metric: Callable = score_dice,
+    multilabel: bool = False,
+    n_classes: Optional[int] = None,
 ) -> Tuple[
     Dict[str, float],
     Dict[str, List[np.ndarray]],
@@ -108,6 +110,8 @@
         threshold: Threshold to apply. Defaults to 0.5.
         show_orj_predictions: Flag to show original predictions. Defaults to False.
         metric: Metric to use for comparison. Defaults to `score_dice`.
+        multilabel: True if segmentation is multiclass.
+        n_classes: Number of classes. If multilabel is False, this should be None.
 
     Returns:
         dict: Dictionary with metrics.
@@ -118,7 +122,21 @@
     thresh_preds = th_thresh_preds.squeeze(0).numpy()
     dice_scores = metric(th_thresh_preds, th_masks, reduction=None).numpy()
     result = {}
-    tp, fp, fn, tn = smp.metrics.get_stats(th_thresh_preds.long(), th_masks.long(), mode="binary")
+    if multilabel:
+        if n_classes is None:
+            raise ValueError("n_classes arg shouldn't be None when multilabel is True")
+        preds_multilabel = (
+            torch.nn.functional.one_hot(th_preds.to(torch.int64), num_classes=n_classes).squeeze(1).permute(0, 3, 1, 2)
+        )
+        masks_multilabel = (
+            torch.nn.functional.one_hot(th_masks.to(torch.int64), num_classes=n_classes).squeeze(1).permute(0, 3, 1, 2)
+        ).to(preds_multilabel.device)
+        # get_stats multiclass, not considering background channel
+        tp, fp, fn, tn = smp.metrics.get_stats(
+            preds_multilabel[:, 1:, :, :].long(), masks_multilabel[:, 1:, :, :].long(), mode="multilabel"
+        )
+    else:
+        tp, fp, fn, tn = smp.metrics.get_stats(th_thresh_preds.long(), th_masks.long(), mode="binary")
 per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
 dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
 result["F1_image"] = round(float(smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro-imagewise").item()), 4)
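The one-hot reshuffle above is the crux of the multiclass path: 1x1xHxW integer label maps become 1xCxHxW binary channel stacks so that `smp.metrics.get_stats` can run in multilabel mode, and channel 0 is sliced away so the (typically dominant) background does not inflate the scores. A self-contained sketch of the same transformation, with toy shapes assumed for illustration:

```python
import torch
import segmentation_models_pytorch as smp

n_classes = 3
# 1x1xHxW integer label maps with values in [0, n_classes), as in the diff
preds = torch.randint(0, n_classes, (1, 1, 8, 8))
masks = preds.clone()  # identical on purpose, so the metrics should be perfect

# one_hot appends the class axis last (1x1xHxW -> 1x1xHxWxC);
# squeeze(1) and permute(0, 3, 1, 2) rearrange it to 1xCxHxW
preds_ml = torch.nn.functional.one_hot(preds.long(), num_classes=n_classes).squeeze(1).permute(0, 3, 1, 2)
masks_ml = torch.nn.functional.one_hot(masks.long(), num_classes=n_classes).squeeze(1).permute(0, 3, 1, 2)

# Drop channel 0 (background) before computing the stats, as the diff does
tp, fp, fn, tn = smp.metrics.get_stats(preds_ml[:, 1:].long(), masks_ml[:, 1:].long(), mode="multilabel")
print(smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro"))  # tensor(1.) for identical inputs
```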
@@ -239,18 +257,18 @@ def create_mask_report(
     th_masks = output["mask"]
     th_preds = output["mask_pred"]
     th_labels = output["label"]
 
+    n_classes = th_preds.shape[1]
     # TODO: Apply sigmoid is a wrong name now
     if apply_sigmoid:
-        if th_preds.shape[1] == 1:
+        if n_classes == 1:
             th_preds = torch.nn.Sigmoid()(th_preds)
             th_thresh_preds = (th_preds > threshold).float()
         else:
             th_preds = torch.nn.Softmax(dim=1)(th_preds)
             th_thresh_preds = torch.argmax(th_preds, dim=1).float().unsqueeze(1)
-            th_preds = th_preds[:, 1].unsqueeze(1)
             # Compute labels from the given masks since by default they are all 0
             th_labels = th_masks.max(dim=2)[0].max(dim=2)[0].squeeze(dim=1)
             show_orj_predictions = False
 
     mean = np.asarray(mean)
     std = np.asarray(std)
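The restructured block above boils down to this: one output channel keeps the old sigmoid-plus-threshold behaviour, while several channels go through softmax and a per-pixel argmax, so `th_thresh_preds` becomes an integer label map instead of a 0/1 mask. A compact sketch of the two branches, with input shapes assumed for illustration:

```python
import torch

th_preds = torch.randn(2, 4, 8, 8)  # NxCxHxW logits; C > 1 means multiclass here
threshold = 0.5
n_classes = th_preds.shape[1]

if n_classes == 1:
    # Binary: per-pixel probability, hard-thresholded to a Nx1xHxW mask in {0., 1.}
    th_thresh_preds = (torch.sigmoid(th_preds) > threshold).float()
else:
    # Multiclass: per-pixel class distribution, argmax -> Nx1xHxW label map in [0, n_classes)
    probs = torch.softmax(th_preds, dim=1)
    th_thresh_preds = torch.argmax(probs, dim=1).float().unsqueeze(1)

print(th_thresh_preds.shape, th_thresh_preds.unique())  # torch.Size([2, 1, 8, 8]), typically labels 0..3
```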
@@ -268,7 +286,7 @@
     binary_labels = labels == 0
 
     row_names = ["Input", "Mask", "Pred", f"Pred>{threshold}"]
-    bounds = [(0, 255), (0.0, 1.0), (0.0, 1.0), (0.0, 1.0)]
+    bounds = [(0, 255), (0.0, float(n_classes - 1)), (0.0, 1.0), (0.0, float(n_classes - 1))]
     if not show_orj_predictions:
         row_names.pop(2)
         bounds.pop(2)
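The `bounds` change is the vmin/vmax fix from the commit message: the Mask and Pred>threshold rows now contain labels in [0, n_classes - 1], and plotting them with vmax=1.0 would render every class above 1 identically. The matplotlib call below is only an assumed illustration of the idea, since `create_grid_figure`'s internals are not part of this diff:

```python
import matplotlib.pyplot as plt
import numpy as np

n_classes = 4
label_map = np.random.randint(0, n_classes, (8, 8))  # toy argmax output

fig, (ax_bad, ax_good) = plt.subplots(1, 2)
# Old bounds: vmax=1.0 clips classes 2 and 3 to the same colour as class 1
ax_bad.imshow(label_map, vmin=0.0, vmax=1.0)
# New bounds: span the full label range so each class keeps a distinct colour
ax_good.imshow(label_map, vmin=0.0, vmax=float(n_classes - 1))
plt.show()
```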
@@ -302,7 +320,7 @@
     for k, v in indexes.items():
         file_path = os.path.join(report_path, f"{stage}_{name}_{k}_results.png")
         images_to_show = [images[v], masks[v], preds[v], thresh_preds[v]]
-        if not show_orj_predictions:
+        if not show_orj_predictions or n_classes > 1:
             images_to_show.pop(2)
         create_grid_figure(
             images_to_show,
@@ -319,10 +337,12 @@
     result, fg, fb, area_graph = calculate_mask_based_metrics(
         images=images,
         th_masks=th_masks,
-        th_preds=th_preds,
+        th_preds=th_thresh_preds,
         threshold=threshold,
         show_orj_predictions=show_orj_predictions,
         metric=metric,
+        multilabel=bool(n_classes > 1),
+        n_classes=n_classes,
     )
 
     if len(fg["image"]) > 0:
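One wiring detail worth noting: multiclass support is inferred rather than configured. The report derives `n_classes` from the channel count of the model output, passes the already argmax-ed `th_thresh_preds` as the predictions to score, and sets `multilabel` accordingly, so `calculate_mask_based_metrics` never has to re-threshold multiclass predictions itself.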
