Skip to content

Commit

Permalink
HypVINN/run_prediction.py
Browse files Browse the repository at this point in the history
- revert changes to the docstring of main
- move replacement of strings into a decorator function
- fix indentation errors in doc

FastSurferCNN/checkpoint.py
- import Scheduler from torch instead of string-declaring (which does not work with pipe for Union)

FastSurferCNN/parser_defaults.py
- revert <TYPE> | None to Optional[<TYPE>]
- add # noqa: UP0007 to ignore this ruff rule
- add documentation for this
- remove fields import
- add Optional import
- revert parser == None removal
- reformat for number of characters per line
  • Loading branch information
dkuegler committed Sep 4, 2024
1 parent d83d172 commit 751c40d
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 81 deletions.
9 changes: 7 additions & 2 deletions FastSurferCNN/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections.abc import MutableSequence
from functools import lru_cache
from pathlib import Path
from typing import Literal, TypedDict, cast, overload
from typing import Literal, TypedDict, cast, overload, TYPE_CHECKING

import requests
import torch
Expand All @@ -27,7 +27,12 @@
from FastSurferCNN.utils import Plane, logging
from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT

Scheduler = "torch.optim.lr_scheduler"
if TYPE_CHECKING:
from torch.optim import lr_scheduler as Scheduler
else:
class Scheduler:
...

LOGGER = logging.getLogger(__name__)

# Defaults
Expand Down
134 changes: 61 additions & 73 deletions FastSurferCNN/utils/parser_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
import argparse
import types
from collections.abc import Iterable, Mapping
from dataclasses import Field, dataclass, fields
from dataclasses import Field, dataclass
from pathlib import Path
from typing import Literal, Protocol, TypeVar, get_args, get_origin
from typing import Literal, Optional, Protocol, TypeVar, get_args, get_origin

from FastSurferCNN.utils import PLANES, Plane
from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one
Expand Down Expand Up @@ -71,14 +71,12 @@ def __arg(
"""
Create stub function, which sets default settings for argparse arguments.
The positional and keyword arguments function as if they were directly passed to
parser.add_arguments().
The positional and keyword arguments function as if they were directly passed to parser.add_arguments().
The result will be a stub function, which has as first argument a parser (or other
object with an add_argument method) to which the argument is added. The stub
function also accepts positional and keyword arguments, which overwrite the default
arguments. Additionally, these specific values can be callables, which will be
called upon the default values (to alter the default value).
The result will be a stub function, which has as first argument a parser (or other object with an add_argument
method) to which the argument is added. The stub function also accepts positional and keyword arguments, which
overwrite the default arguments. Additionally, these specific values can be callables, which will be called upon the
default values (to alter the default value).
This function is private for this module.
"""
Expand Down Expand Up @@ -126,7 +124,7 @@ def _stub(parser: CanAddArguments | type[dict], *flags, **kwargs):
_flags = flags if len(flags) != 0 else default_flags
if hasattr(parser, "add_argument"):
return parser.add_argument(*_flags, **kwargs)
elif isinstance(parser, dict):
elif parser is dict or isinstance(parser, dict):
return {"flag": _flags[0], "flags": _flags, **kwargs}
else:
raise ValueError(
Expand All @@ -145,6 +143,13 @@ def _stub(parser: CanAddArguments | type[dict], *flags, **kwargs):
class SubjectDirectoryConfig:
"""
This class describes the 'minimal' parameters used by SubjectList.
Notes
-----
Important:
Data Types of fields should stay `Optional[<TYPE>]` and not be replaced by `<TYPE> | None`, so the Parser can use
the type in argparse as the value for `type` of `parser.add_argument()` (`Optional` is a callable, while `Union` is
not).
"""
orig_name: str = field(
help="Name of T1 full head MRI. Absolute path if single image else common "
Expand All @@ -154,63 +159,57 @@ class SubjectDirectoryConfig:
)
pred_name: str = field(
default="mri/aparc.DKTatlas+aseg.deep.mgz",
help="Name of intermediate DL-based segmentation file (similar to aparc+aseg). "
"When using FastSurfer, this segmentation is already conformed, since "
"inference is always based on a conformed image. Absolute path if single "
"image else common image name. Default: mri/aparc.DKTatlas+aseg.deep.mgz",
help="Name of intermediate DL-based segmentation file (similar to aparc+aseg). When using FastSurfer, this "
"segmentation is already conformed, since inference is always based on a conformed image. Absolute path "
"if single image else common image name. Default: mri/aparc.DKTatlas+aseg.deep.mgz",
)
conf_name: str = field(
default="mri/orig.mgz",
help="Name under which the conformed input image will be saved, in the same "
"directory as the segmentation (the input image is always conformed "
"first, if it is not already conformed). The original input image is "
"saved in the output directory as $id/mri/orig/001.mgz. Default: "
"mri/orig.mgz.",
help="Name under which the conformed input image will be saved, in the same directory as the segmentation (the "
"input image is always conformed first, if it is not already conformed). The original input image is "
"saved in the output directory as $id/mri/orig/001.mgz. Default: mri/orig.mgz.",
flags=("--conformed_name",),
)
in_dir: Path | None = field(

in_dir: Optional[Path] = field( # noqa: UP007
flags=("--in_dir",),
default=None,
help="Directory in which input volume(s) are located. Optional, if full path "
"is defined for --t1.",
help="Directory in which input volume(s) are located. Optional, if full path is defined for --t1.",
)
csv_file: Path | None = field(
csv_file: Optional[Path] = field( # noqa: UP007
flags=("--csv_file",),
default=None,
help="Csv-file with subjects to analyze (alternative to --tag)",
)
sid: str | None = field(
sid: Optional[str] = field( # noqa: UP007
flags=("--sid",),
default=None,
help="Optional: directly set the subject id to use. Can be used for single "
"subject input. For multi-subject processing, use remove suffix if sid is "
"not second to last element of input file passed to --t1",
help="Optional: directly set the subject id to use. Can be used for single subject input. For multi-subject "
"processing, use remove suffix if sid is not second to last element of input file passed to --t1",
)
search_tag: str = field(
flags=("--tag",),
default="*",
help="Search tag to process only certain subjects. If a single image should be "
"analyzed, set the tag with its id. Default: processes all.",
help="Search tag to process only certain subjects. If a single image should be analyzed, set the tag with its "
"id. Default: processes all.",
)
brainmask_name: str = field(
default="mri/mask.mgz",
help="Name under which the brainmask image will be saved, in the same "
"directory as the segmentation. The brainmask is created from the "
"aparc_aseg segmentation (dilate 5, erode 4, largest component). Default: "
help="Name under which the brainmask image will be saved, in the same directory as the segmentation. The "
"brainmask is created from the aparc_aseg segmentation (dilate 5, erode 4, largest component). Default: "
"`mri/mask.mgz`.",
flags=("--brainmask_name",),
)
remove_suffix: str = field(
flags=("--remove_suffix",),
default="",
help="Optional: remove suffix from path definition of input file to yield "
"correct subject name (e.g. /ses-x/anat/ for BIDS or /mri/ for FreeSurfer "
"input). Default: do not remove anything.",
help="Optional: remove suffix from path definition of input file to yield correct subject name (e.g. "
"/ses-x/anat/ for BIDS or /mri/ for FreeSurfer input). Default: do not remove anything.",
)
out_dir: Path | None = field(
out_dir: Optional[Path] = field( # noqa: UP007
default=None,
help="Directory in which evaluation results should be written. Will be created "
"if it does not exist. Optional if full path is defined for --pred_name.",
help="Directory in which evaluation results should be written. Will be created if it does not exist. Optional "
"if full path is defined for --pred_name.",
)


Expand All @@ -230,45 +229,41 @@ class SubjectDirectoryConfig:
type=str,
dest="norm_name",
default="mri/norm.mgz",
help="Name under which the bias field corrected image is stored. Default: "
"mri/norm.mgz.",
help="Name under which the bias field corrected image is stored. Default: mri/norm.mgz.",
),
"brainmask_name": __arg("--brainmask_name", dc=SubjectDirectoryConfig),
"aseg_name": __arg(
"--aseg_name",
type=str,
dest="aseg_name",
default="mri/aseg.auto_noCCseg.mgz",
help="Name under which the reduced aseg segmentation will be saved, in the "
"same directory as the aparc-aseg segmentation (labels of full aparc "
"segmentation are reduced to aseg). Default: mri/aseg.auto_noCCseg.mgz.",
help="Name under which the reduced aseg segmentation will be saved, in the same directory as the aparc-aseg "
"segmentation (labels of full aparc segmentation are reduced to aseg). Default: "
"mri/aseg.auto_noCCseg.mgz.",
),
"seg_log": __arg(
"--seg_log",
type=str,
dest="log_name",
default="",
help="Absolute path to file in which run logs will be saved. If not set, logs "
"will not be saved.",
help="Absolute path to file in which run logs will be saved. If not set, logs will not be saved.",
),
"device": __arg(
"--device",
default="auto",
help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or "
"specify a certain gpu (e.g. cuda:1), default: auto",
help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or specify a certain gpu (e.g. cuda:1), "
"Default: auto",
),
"viewagg_device": __arg(
"--viewagg_device",
dest="viewagg_device",
type=str,
default="auto",
help="Define the device, where the view aggregation should be run. By default, "
"the program checks if you have enough memory to run the view aggregation "
"on the gpu (cuda). The total memory is considered for this decision. If "
"this fails, or you actively overwrote the check with setting "
"> --viewagg_device cpu <, view agg is run on the cpu. Equivalently, if "
"you define > --viewagg_device cuda <, view agg will be run on the gpu "
"(no memory check will be done).",
help="Define the device, where the view aggregation should be run. By default, the program checks if you have "
"enough memory to run the view aggregation on the gpu (cuda). The total memory is considered for this "
"decision. If this fails, or you actively overwrote the check with setting > --viewagg_device cpu <, view "
"agg is run on the cpu. Equivalently, if you define > --viewagg_device cuda <, view agg will be run on "
"the gpu (no memory check will be done).",
),
"in_dir": __arg("--in_dir", dc=SubjectDirectoryConfig, fieldname="in_dir"),
"tag": __arg(
Expand All @@ -290,29 +285,26 @@ class SubjectDirectoryConfig:
type=str,
dest="qc_log",
default="",
help="Absolute path to file in which a list of subjects that failed QC check "
"(when processing multiple subjects) will be saved. If not set, the file "
"will not be saved.",
help="Absolute path to file in which a list of subjects that failed QC check (when processing multiple "
"subjects) will be saved. If not set, the file will not be saved.",
),
"vox_size": __arg(
"--vox_size",
type=__vox_size,
default="min",
dest="vox_size",
help="Choose the primary voxelsize to process, must be either a number between "
"0 and 1 (below 0.7 is experimental) or 'min' (default). A number forces "
"processing at that specific voxel size, 'min' determines the voxel size "
"from the image itself (conforming to the minimum voxel size, or 1 if the "
help="Choose the primary voxelsize to process, must be either a number between 0 and 1 (below 0.7 is "
"experimental) or 'min' (default). A number forces processing at that specific voxel size, 'min' "
"determines the voxel size from the image itself (conforming to the minimum voxel size, or 1 if the "
"minimum voxel size is above 0.95mm). ",
),
"conform_to_1mm_threshold": __arg(
"--conform_to_1mm_threshold",
type=__conform_to_one,
default=0.95,
dest="conform_to_1mm_threshold",
help="The voxelsize threshold, above which images will be conformed to 1mm "
"isotropic, if the --vox_size argument is also 'min' (the --vox_size "
"default setting). Contrary to conform.py, the default behavior of "
help="The voxelsize threshold, above which images will be conformed to 1mm isotropic, if the --vox_size "
"argument is also 'min' (the --vox_size default setting). Contrary to conform.py, the default behavior of "
"%(prog)s is to resample all images above 0.95mm to 1mm.",
),
"lut": __arg(
Expand All @@ -332,16 +324,14 @@ class SubjectDirectoryConfig:
dest="threads",
default=get_num_threads(),
type=int,
help=f"Number of threads to use (defaults to number of hardware threads: "
f"{get_num_threads()})",
help=f"Number of threads to use (defaults to number of hardware threads: {get_num_threads()})",
),
"async_io": __arg(
"--async_io",
dest="async_io",
action="store_true",
help="Allow asynchronous file operations (default: off). Note, this may impact "
"the order of messages in the log, but speed up the segmentation "
"specifically for slow file systems.",
help="Allow asynchronous file operations (default: off). Note, this may impact the order of messages in the "
"log, but speed up the segmentation specifically for slow file systems.",
),
}

Expand Down Expand Up @@ -403,11 +393,9 @@ def add_plane_flags(
The parser to add flags to.
configtype : Literal["checkpoint", "config"]
The type of files (for help text and prefix from "checkpoint" and "config".
"checkpoint" will lead to flags like "--ckpt_{plane}", "config" to
"--cfg_{plane}".
"checkpoint" will lead to flags like "--ckpt_{plane}", "config" to "--cfg_{plane}".
files : Mapping[Plane, Path | str]
A dictionary of plane to filename. Relative files are assumed to be relative to
the FastSurfer root directory.
A dictionary of plane to filename. Relative files are assumed to be relative to the FastSurfer root directory.
defaults_path : Path, str
A path to the file to load defaults from.
Expand Down
23 changes: 17 additions & 6 deletions HypVINN/run_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,18 @@ def option_parse() -> argparse.ArgumentParser:
return parser


def _update_docstring(**kwargs):
"""
Make custom replacements in the docstring.
"""

def stub(f):
f.__doc__ = f.__doc__.format(**kwargs)
return f
return stub


@_update_docstring(HYPVINN_SEG_NAME=HYPVINN_SEG_NAME, HYPVINN_MASK_NAME=HYPVINN_MASK_NAME)
def main(
out_dir: Path,
t2: Path | None,
Expand Down Expand Up @@ -207,10 +219,10 @@ def main(
The path to the coronal configuration file.
cfg_sag : Path
The path to the sagittal configuration file.
hypo_segfile : str, default is in HYPVINN_SEG_NAME as specified in config.
The name of the hypothalamus segmentation file. Default is in HYPVINN_SEG_NAME.
hypo_maskfile : str, default is in HYPVINN_MASK_NAME
The name of the hypothalamus mask file. Default is in HYPVINN_MASK_NAME.
hypo_segfile : str, default="{HYPVINN_SEG_NAME}"
The name of the hypothalamus segmentation file. Default is {HYPVINN_SEG_NAME}.
hypo_maskfile : str, default="{HYPVINN_MASK_NAME}"
The name of the hypothalamus mask file. Default is {HYPVINN_MASK_NAME}.
allow_root : bool, default=False
Whether to allow running as root user. Default is False.
qc_snapshots : bool, optional
Expand Down Expand Up @@ -466,8 +478,7 @@ def load_volumes(
-------
tuple
A tuple containing the following elements:
- modalities: A dictionary with keys 't1' and/or 't2' and values
being the corresponding loaded and rescaled images.
- modalities: A dictionary of `ndarrays` of rescaled images for keys 't1' and/or 't2'.
- affine: The affine transformation of the loaded image(s).
- header: The header of the loaded image(s).
- zoom: The zoom level of the loaded image(s).
Expand Down

0 comments on commit 751c40d

Please sign in to comment.