Skip to content

Commit

Permalink
Merge branch 'Deep-MI:dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
taha-abdullah authored Sep 10, 2024
2 parents fc250d4 + 77231fb commit 1ed459e
Show file tree
Hide file tree
Showing 105 changed files with 1,354 additions and 1,409 deletions.
2 changes: 2 additions & 0 deletions .codespellignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
assertIn
mapp
padd
struc
TE
warmup
16 changes: 8 additions & 8 deletions .github/workflows/code-style.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }}
cancel-in-progress: true
on:
# pull_request:
# push:
# branches: [dev]
pull_request:
push:
branches: [dev]
workflow_dispatch:

jobs:
Expand All @@ -32,9 +32,9 @@ jobs:
with:
check_filenames: true
check_hidden: true
skip: './.git,./build,./.mypy_cache,./.pytest_cache'
skip: './build,./doc/images,./Tutorial,./.git,./.mypy_cache,./.pytest_cache'
ignore_words_file: ./.codespellignore
- name: Run pydocstyle
run: pydocstyle .
- name: Run bibclean
run: bibclean-check doc/references.bib
# - name: Run pydocstyle
# run: pydocstyle .
# - name: Run bibclean
# run: bibclean-check doc/references.bib
69 changes: 46 additions & 23 deletions CerebNet/apply_warp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import argparse
#!/bin/python

# Copyright 2022 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn
#
Expand All @@ -15,48 +15,64 @@
# limitations under the License.

# IMPORTS
import numpy as np
import argparse
from pathlib import Path
from typing import cast

import nibabel as nib
import numpy as np
from numpy import typing as npt

from os.path import join
from CerebNet.datasets import utils


def save_nii_image(img_data, save_path, header, affine):
def save_nii_image(
img_data: npt.ArrayLike,
save_path: Path | str,
header: nib.Nifti1Header | nib.Nifti2Header,
affine: npt.NDArray[float],
):
"""
Save an image data array as a NIfTI file.
Parameters
----------
img_data : ndarray
The image data to be saved.
save_path : str
save_path : Path, str
The path (including file name) where the image will be saved.
header : nibabel.Nifti1Header
header : nibabel.Nifti1Header, nibabel.Nifti2Header
The header information for the NIfTI file.
affine : ndarray
The affine matrix for the NIfTI file.
"""

if not isinstance(header, nib.Nifti1Header):
header = nib.Nifti1Header.from_header(header)
img_out = nib.Nifti1Image(img_data, header=header, affine=affine)
print(f"Saving {save_path}")
nib.save(img_out, save_path)


def main(img_path, lbl_path, warp_path, result_path, patch_size):
def main(
img_path: Path | str,
lbl_path: Path | str,
warp_path: Path | str,
result_path: Path | str,
patch_size,
):

"""
Load, warp, crop, and save both an image and its corresponding label based on a given warp field.
Parameters
----------
img_path : str
img_path : Path, str
Path to the T1-weighted MRI image to be warped.
lbl_path : str
lbl_path : Path, str
Path to the label image corresponding to the T1 image, to be warped similarly.
warp_path : str
warp_path : Path, str
Path to the warp field file used to warp the images.
result_path : str
result_path : Path, str
Directory path where the warped and cropped images will be saved.
patch_size : tuple of int
The dimensions (height, width, depth) cropped images after warping.
Expand All @@ -65,9 +81,16 @@ def main(img_path, lbl_path, warp_path, result_path, patch_size):
img, img_file = utils.load_reorient_rescale_image(img_path)

lbl_file = nib.load(lbl_path)
label = np.asarray(lbl_file.get_fdata(), dtype=np.int16)

warp_field = np.asarray(nib.load(warp_path).get_fdata())
# if not isinstance(lbl_file, nib.analyze.SpatialImage):
if not isinstance(lbl_file, nib.Nifti1Image | nib.Nifti2Image):
raise ValueError(f"{lbl_file} is not a valid file format!")
lbl_header = cast(nib.Nifti1Header | nib.Nifti2Header, lbl_file.header)
label = np.asarray(lbl_file.dataobj, dtype=np.int16)

warp_file = nib.load(warp_path)
if not isinstance(warp_file, nib.analyze.SpatialImage):
raise ValueError(f"{warp_file} is not a valid file format!")
warp_field = np.asarray(warp_file.dataobj, dtype=float)
img = utils.map_size(img, base_shape=warp_field.shape[:3])
label = utils.map_size(label, base_shape=warp_field.shape[:3])
warped_img = utils.apply_warp_field(warp_field, img, interpol_order=3)
Expand All @@ -80,28 +103,28 @@ def main(img_path, lbl_path, warp_path, result_path, patch_size):

img_file.header['dim'][1:4] = patch_size
img_file.set_data_dtype(img.dtype)
lbl_file.header['dim'][1:4] = patch_size
lbl_header['dim'][1:4] = patch_size
save_nii_image(img,
join(result_path, "T1_warped_cropped.nii.gz"),
Path(result_path) / "T1_warped_cropped.nii.gz",
header=img_file.header,
affine=img_file.affine)
save_nii_image(label,
join(result_path, "label_warped_cropped.nii.gz"),
header=lbl_file.header,
Path(result_path) / "label_warped_cropped.nii.gz",
header=lbl_header,
affine=lbl_file.affine)


def make_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--img_path",
help="path to T1 image",
type=str)
type=Path)
parser.add_argument("--lbl_path",
help="path to label image",
type=str)
type=Path)
parser.add_argument("--result_path",
help="folder to store the results",
type=str)
type=Path)

parser.add_argument("--warp_filename",
help="Warp field file",
Expand All @@ -113,7 +136,7 @@ def make_parser() -> argparse.ArgumentParser:
if __name__ == '__main__':
parser = make_parser()
args = parser.parse_args()
warp_path = str(join(args.result_path, args.warp_filename))
warp_path = Path(args.result_path) / args.warp_filename
main(
args.img_path,
args.lbl_path,
Expand Down
1 change: 1 addition & 0 deletions CerebNet/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# IMPORTS
from CerebNet.config.cerebnet import get_cfg_cerebnet
from CerebNet.config.dataset import get_cfg_dataset

__all__ = [
"cerebnet",
"dataset",
Expand Down
2 changes: 1 addition & 1 deletion CerebNet/config/cerebnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
# Data Augmentation options
# ---------------------------------------------------------------------------- #

# Augmentation for traning
# Augmentation for training
_C.AUGMENTATION = CN()

# list of augmentations to use for training
Expand Down
18 changes: 8 additions & 10 deletions CerebNet/data_loader/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,19 @@
# IMPORTS
import numbers
import random
from typing import Optional

import numpy as np
import torch
from numpy import random as npr
from scipy.ndimage import gaussian_filter, affine_transform
from scipy.ndimage import affine_transform, gaussian_filter
from scipy.stats import median_abs_deviation
from torchvision import transforms

from CerebNet.data_loader.data_utils import FLIPPED_LABELS



# Transformations for training
class ToTensor(object):
class ToTensor:
"""
Convert ndarrays in sample to Tensors.
"""
Expand Down Expand Up @@ -66,7 +64,7 @@ def _apply_img(self, img):
return super()._apply_img(img.transpose((2, 0, 1)))


class RandomAffine(object):
class RandomAffine:
"""
Apply a random affine transformation to
images, label and weight
Expand Down Expand Up @@ -99,15 +97,15 @@ def _get_random_affine(self):
degrees = (-self.degree, self.degree)
else:
assert (
isinstance(self.degree, (tuple, list)) and len(self.degree) == 2
isinstance(self.degree, tuple | list) and len(self.degree) == 2
), "degrees should be a list or tuple and it must be of length 2."
if isinstance(self.translate, numbers.Number):
if not (0.0 <= self.translate <= 1.0):
raise ValueError("translation values should be between 0 and 1")
translate = (self.translate, self.translate)
else:
assert (
isinstance(self.translate, (tuple, list)) and len(self.translate) == 2
isinstance(self.translate, tuple | list) and len(self.translate) == 2
), "translate should be a list or tuple and it must be of length 2."
for t in self.translate:
if not (0.0 <= t <= 1.0):
Expand Down Expand Up @@ -159,7 +157,7 @@ def __call__(self, sample):
return sample


class RandomFlip(object):
class RandomFlip:
"""
Random horizontal flipping.
"""
Expand Down Expand Up @@ -196,7 +194,7 @@ class RandomBiasField:
def __init__(
self,
cfg,
seed: Optional[int] = None,
seed: int | None = None,
):
"""
Initialize the RandomBiasField object with configuration and optional seed.
Expand Down Expand Up @@ -287,7 +285,7 @@ def __call__(self, sample):
return sample


class RandomLabelsToImage(object):
class RandomLabelsToImage:
"""
Generate image from segmentation
using the dataset intensity priors.
Expand Down
2 changes: 1 addition & 1 deletion CerebNet/data_loader/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def uncrop_volume(vol, uncrop_shape, roi):


def get_binary_map(lbl_map, class_names):
bin_map = np.logical_or.reduce(list(map(lambda l: lbl_map == l, class_names)))
bin_map = np.logical_or.reduce(list(map(lambda lb: lbl_map == lb, class_names)))
return bin_map


Expand Down
35 changes: 15 additions & 20 deletions CerebNet/data_loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,30 @@
# limitations under the License.

# IMPORTS
from typing import Tuple, Literal, TypeVar, Dict
from numbers import Number
from typing import Literal, TypeVar

import h5py
import nibabel as nib
import torch
import numpy as np
import torch
from numpy import typing as npt
import h5py
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose

from FastSurferCNN.utils import logging, Plane
from CerebNet.data_loader import data_utils as utils
from CerebNet.data_loader.augmentation import ToTensor
from CerebNet.datasets.load_data import SubjectLoader
from CerebNet.datasets.utils import bounding_volume_offset, crop_transform
from FastSurferCNN.data_loader.data_utils import (
get_thick_slices,
transform_axial,
transform_sagittal,
)

from CerebNet.data_loader import data_utils as utils
from CerebNet.data_loader.augmentation import ToTensor
from CerebNet.datasets.load_data import SubjectLoader
from CerebNet.datasets.utils import crop_transform, bounding_volume_offset
from FastSurferCNN.utils import Plane, logging

ROIKeys = Literal["source_shape", "offsets", "target_shape"]
LocalizerROI = Dict[ROIKeys, Tuple[int, ...]]
LocalizerROI = dict[ROIKeys, tuple[int, ...]]

NT = TypeVar("NT", bound=Number)

Expand Down Expand Up @@ -118,11 +117,7 @@ def __init__(self, dataset_path, cfg, transforms, load_aux_data):
del self.dataset["subject"]

logger.info(
"Successfully loaded {} slices in {} plane from {}".format(
self.count,
cfg.DATA.PLANE,
dataset_path,
)
f"Successfully loaded {self.count} slices in {cfg.DATA.PLANE} plane from {dataset_path}"
)

logger.info(
Expand Down Expand Up @@ -242,7 +237,7 @@ def __init__(
self,
img_org: nib.analyze.SpatialImage,
brain_seg: nib.analyze.SpatialImage,
patch_size: Tuple[int, ...],
patch_size: tuple[int, ...],
slice_thickness: int,
primary_slice: str,
):
Expand Down Expand Up @@ -298,10 +293,10 @@ def __init__(
"coronal": img,
"sagittal": transform_sagittal(img),
}
for plane, data in data.items():
for plane, data_i in data.items():
# data is transformed to 'plane'-direction in axis 2
thick_slices = get_thick_slices(
data, self.slice_thickness
data_i, self.slice_thickness
) # [H, W, n_slices, C]
# it seems x and y are flipped with respect to expectations here
self.images_per_plane[plane] = np.transpose(
Expand All @@ -315,7 +310,7 @@ def locate_mask_bbox(self, mask: npt.NDArray[bool]):
bbox of min0, min1, ..., max0, max1, ...
"""
# filter disconnected components
from skimage.measure import regionprops, label
from skimage.measure import label, regionprops

label_image = label(mask, connectivity=3)
regions = regionprops(label_image)
Expand All @@ -341,7 +336,7 @@ def plane(self) -> Plane:
"""Returns the active plane"""
return self._plane

def __getitem__(self, index: int) -> Tuple[Plane, np.ndarray]:
def __getitem__(self, index: int) -> tuple[Plane, np.ndarray]:
"""Get the plane and data belonging to indices given."""

if not (0 <= index < self.images_per_plane[self.plane].shape[0]):
Expand Down
5 changes: 2 additions & 3 deletions CerebNet/data_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
# limitations under the License.

# IMPORTS
from torchvision import transforms
from torch.utils.data import DataLoader

from FastSurferCNN.utils import logging
from torchvision import transforms

from CerebNet.data_loader import dataset as dset
from CerebNet.data_loader.augmentation import ToTensor, get_transform
from FastSurferCNN.utils import logging

logger = logging.get_logger(__name__)

Expand Down
Loading

0 comments on commit 1ed459e

Please sign in to comment.