From ba0deab0b778e7f3b81a5c2c46846f07b1e9d7a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Tue, 3 Sep 2024 13:02:33 +0200 Subject: [PATCH 01/31] Fix the run time calculation (rounded to integer hours in log file) --- recon_surf/recon-surf.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/recon_surf/recon-surf.sh b/recon_surf/recon-surf.sh index d3eba901..19feb9dc 100755 --- a/recon_surf/recon-surf.sh +++ b/recon_surf/recon-surf.sh @@ -1077,8 +1077,7 @@ fi # Collect info EndTime=$(date) tSecEnd=$(date '+%s') -tRunHours=$(($((tSecEnd - tSecStart))/3600)) -tRunHours=$(printf %6.3f "$tRunHours") +tRunHours=$(printf %6.3f "$(bc <<< "($tSecEnd - $tSecStart) / 3600")") { echo "" From 9022bf8f6dd3050c1f5bcf087c1477b6cacd6522 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 14:51:17 +0200 Subject: [PATCH 02/31] bump version to 2.3 (and lapy) --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6bd077a5..c6d16506 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = 'setuptools.build_meta' [project] name = 'fastsurfer' -version = '2.3.0-dev' +version = '2.3.0' description = 'A fast and accurate deep-learning based neuroimaging pipeline' readme = 'README.md' license = {file = 'LICENSE'} @@ -33,7 +33,7 @@ classifiers = [ ] dependencies = [ 'h5py>=3.7', - 'lapy>=1.0.1', + 'lapy>=1.1.0', 'matplotlib>=3.7.1', 'nibabel>=5.1.0', 'numpy>=1.25,<2', From 5d659e1bdceefd10f27fae55771194494bcafc39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Tue, 3 Sep 2024 17:55:02 +0200 Subject: [PATCH 03/31] Fix the run time calculation (rounded to integer hours in log file) Fixes error introduced in 2f350e9965dc19fa9bcc1b53ee0934ed7c54ea53 --- FastSurferCNN/run_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/FastSurferCNN/run_model.py b/FastSurferCNN/run_model.py index e4894726..4d7dcc89 100644 --- a/FastSurferCNN/run_model.py +++ b/FastSurferCNN/run_model.py @@ -58,12 +58,10 @@ def make_parser() -> argparse.ArgumentParser: return parser - def main(args): """ First sets variables and then runs the trainer model. 
""" - args = setup_options() cfg = get_config(args) if args.aug is not None: From 5d55dd3b1635a3f9787c6dbd41075fb95aec9bc1 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 12:22:35 +0200 Subject: [PATCH 04/31] fix ruff warnings and formating --- .../Tutorial_FastSurferCNN_QuickSeg.ipynb | 136 +++++++++--------- recon_surf/N4_bias_correct.py | 39 ++--- recon_surf/align_points.py | 20 +-- recon_surf/align_seg.py | 55 +++---- recon_surf/create_annotation.py | 89 ++++++------ recon_surf/fs_balabels.py | 34 ++--- recon_surf/image_io.py | 29 ++-- recon_surf/lta.py | 38 ++--- recon_surf/map_surf_label.py | 60 ++++---- recon_surf/paint_cc_into_pred.py | 9 +- recon_surf/rewrite_mc_surface.py | 7 +- recon_surf/rewrite_oriented_surface.py | 3 +- recon_surf/rotate_sphere.py | 40 +++--- recon_surf/sample_parc.py | 17 ++- recon_surf/smooth_aparc.py | 22 +-- recon_surf/spherically_project.py | 29 ++-- recon_surf/spherically_project_wrapper.py | 10 +- .../utils/extract_recon_surf_time_info.py | 4 +- 18 files changed, 319 insertions(+), 322 deletions(-) diff --git a/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb b/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb index db463305..e12a8fb8 100644 --- a/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb +++ b/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb @@ -71,115 +71,115 @@ " Cloning https://github.com/Deep-MI/LaPy.git to /tmp/pip-install-pha61_pf/lapy_2347d8ba3da148d7acf6917844aea139\n", " Running command git clone --filter=blob:none --quiet https://github.com/Deep-MI/LaPy.git /tmp/pip-install-pha61_pf/lapy_2347d8ba3da148d7acf6917844aea139\n", " Resolved https://github.com/Deep-MI/LaPy.git to commit f75628053399480e4ef741e9f0d96656833168dd\n", - " Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n", - " Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n", - " Installing backend dependencies ... \u001B[?25l\u001B[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", "Collecting absl-py==1.2.0 (from -r /content/fastsurfer//requirements.txt (line 9))\n", " Downloading absl_py-1.2.0-py3-none-any.whl (123 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m123.4/123.4 kB\u001B[0m \u001B[31m12.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting cachetools==5.2.0 (from -r /content/fastsurfer//requirements.txt (line 11))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m123.4/123.4 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting cachetools==5.2.0 (from -r /content/fastsurfer//requirements.txt (line 11))\n", " Downloading cachetools-5.2.0-py3-none-any.whl (9.3 kB)\n", "Collecting certifi==2022.6.15 (from -r /content/fastsurfer//requirements.txt (line 13))\n", " Downloading certifi-2022.6.15-py3-none-any.whl (160 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m160.2/160.2 kB\u001B[0m \u001B[31m15.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting charset-normalizer==2.1.0 (from -r /content/fastsurfer//requirements.txt (line 15))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m160.2/160.2 kB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting charset-normalizer==2.1.0 (from -r /content/fastsurfer//requirements.txt (line 15))\n", " Downloading charset_normalizer-2.1.0-py3-none-any.whl (39 kB)\n", "Requirement already satisfied: click==8.1.3 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 17)) (8.1.3)\n", "Requirement already satisfied: cycler==0.11.0 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 19)) (0.11.0)\n", "Requirement already satisfied: deprecated==1.2.13 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 21)) (1.2.13)\n", "Collecting fonttools==4.34.4 (from -r /content/fastsurfer//requirements.txt (line 23))\n", " Downloading fonttools-4.34.4-py3-none-any.whl (944 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m944.1/944.1 kB\u001B[0m \u001B[31m54.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting google-auth==2.9.1 (from -r /content/fastsurfer//requirements.txt (line 25))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m944.1/944.1 kB\u001b[0m \u001b[31m54.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting google-auth==2.9.1 (from -r /content/fastsurfer//requirements.txt (line 25))\n", " Downloading google_auth-2.9.1-py2.py3-none-any.whl (167 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m167.8/167.8 kB\u001B[0m \u001B[31m19.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting google-auth-oauthlib==0.4.6 (from -r /content/fastsurfer//requirements.txt (line 29))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m167.8/167.8 kB\u001b[0m \u001b[31m19.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting google-auth-oauthlib==0.4.6 (from -r /content/fastsurfer//requirements.txt (line 29))\n", " Downloading google_auth_oauthlib-0.4.6-py2.py3-none-any.whl (18 kB)\n", "Collecting grpcio==1.47.0 (from -r /content/fastsurfer//requirements.txt (line 
31))\n", " Downloading grpcio-1.47.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m4.5/4.5 MB\u001B[0m \u001B[31m103.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting h5py==3.7.0 (from -r /content/fastsurfer//requirements.txt (line 33))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m103.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting h5py==3.7.0 (from -r /content/fastsurfer//requirements.txt (line 33))\n", " Downloading h5py-3.7.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (4.5 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m4.5/4.5 MB\u001B[0m \u001B[31m90.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting humanize==4.2.3 (from -r /content/fastsurfer//requirements.txt (line 35))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m90.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting humanize==4.2.3 (from -r /content/fastsurfer//requirements.txt (line 35))\n", " Downloading humanize-4.2.3-py3-none-any.whl (102 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m102.6/102.6 kB\u001B[0m \u001B[31m9.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting idna==3.3 (from -r /content/fastsurfer//requirements.txt (line 37))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m102.6/102.6 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting idna==3.3 (from -r /content/fastsurfer//requirements.txt (line 37))\n", " Downloading idna-3.3-py3-none-any.whl (61 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m61.2/61.2 kB\u001B[0m \u001B[31m8.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting imageio==2.19.5 (from -r /content/fastsurfer//requirements.txt (line 39))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.2/61.2 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting imageio==2.19.5 (from -r /content/fastsurfer//requirements.txt (line 39))\n", " Downloading imageio-2.19.5-py3-none-any.whl (3.4 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m3.4/3.4 MB\u001B[0m \u001B[31m113.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting importlib-metadata==4.12.0 (from -r /content/fastsurfer//requirements.txt (line 41))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m113.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting importlib-metadata==4.12.0 (from -r /content/fastsurfer//requirements.txt (line 41))\n", " Downloading importlib_metadata-4.12.0-py3-none-any.whl (21 kB)\n", "Requirement already satisfied: joblib==1.2.0 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 43)) (1.2.0)\n", "Requirement already satisfied: kiwisolver==1.4.4 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 45)) (1.4.4)\n", "Collecting markdown==3.4.1 (from -r /content/fastsurfer//requirements.txt 
(line 49))\n", " Downloading Markdown-3.4.1-py3-none-any.whl (93 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m93.3/93.3 kB\u001B[0m \u001B[31m5.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting matplotlib==3.5.1 (from -r /content/fastsurfer//requirements.txt (line 51))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m93.3/93.3 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting matplotlib==3.5.1 (from -r /content/fastsurfer//requirements.txt (line 51))\n", " Downloading matplotlib-3.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.9 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m11.9/11.9 MB\u001B[0m \u001B[31m81.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting networkx==2.8.5 (from -r /content/fastsurfer//requirements.txt (line 53))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.9/11.9 MB\u001b[0m \u001b[31m81.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting networkx==2.8.5 (from -r /content/fastsurfer//requirements.txt (line 53))\n", " Downloading networkx-2.8.5-py3-none-any.whl (2.0 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m2.0/2.0 MB\u001B[0m \u001B[31m77.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting nibabel==3.2.2 (from -r /content/fastsurfer//requirements.txt (line 55))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m77.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nibabel==3.2.2 (from -r /content/fastsurfer//requirements.txt (line 55))\n", " Downloading nibabel-3.2.2-py3-none-any.whl (3.3 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m3.3/3.3 MB\u001B[0m \u001B[31m76.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting numpy==1.23.5 (from -r /content/fastsurfer//requirements.txt (line 59))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.3/3.3 MB\u001b[0m \u001b[31m76.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting numpy==1.23.5 (from -r /content/fastsurfer//requirements.txt (line 59))\n", " Downloading numpy-1.23.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m17.1/17.1 MB\u001B[0m \u001B[31m60.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting oauthlib==3.2.0 (from -r /content/fastsurfer//requirements.txt (line 76))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m17.1/17.1 MB\u001b[0m \u001b[31m60.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting oauthlib==3.2.0 (from -r /content/fastsurfer//requirements.txt (line 76))\n", " Downloading oauthlib-3.2.0-py3-none-any.whl (151 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m151.5/151.5 kB\u001B[0m \u001B[31m18.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting packaging==21.3 (from -r /content/fastsurfer//requirements.txt (line 78))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m151.5/151.5 kB\u001b[0m \u001b[31m18.4 
MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting packaging==21.3 (from -r /content/fastsurfer//requirements.txt (line 78))\n", " Downloading packaging-21.3-py3-none-any.whl (40 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m40.8/40.8 kB\u001B[0m \u001B[31m5.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting pandas==1.4.3 (from -r /content/fastsurfer//requirements.txt (line 83))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pandas==1.4.3 (from -r /content/fastsurfer//requirements.txt (line 83))\n", " Downloading pandas-1.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.6 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m11.6/11.6 MB\u001B[0m \u001B[31m75.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting pillow==9.2.0 (from -r /content/fastsurfer//requirements.txt (line 85))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.6/11.6 MB\u001b[0m \u001b[31m75.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pillow==9.2.0 (from -r /content/fastsurfer//requirements.txt (line 85))\n", " Downloading Pillow-9.2.0-cp310-cp310-manylinux_2_28_x86_64.whl (3.2 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m3.2/3.2 MB\u001B[0m \u001B[31m62.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hRequirement already satisfied: plotly==5.9.0 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 92)) (5.9.0)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.2/3.2 MB\u001b[0m \u001b[31m62.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: plotly==5.9.0 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 92)) (5.9.0)\n", "Collecting protobuf==3.19.4 (from -r /content/fastsurfer//requirements.txt (line 94))\n", " Downloading protobuf-3.19.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.1/1.1 MB\u001B[0m \u001B[31m83.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting pyasn1==0.4.8 (from -r /content/fastsurfer//requirements.txt (line 96))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m83.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pyasn1==0.4.8 (from -r /content/fastsurfer//requirements.txt (line 96))\n", " Downloading pyasn1-0.4.8-py2.py3-none-any.whl (77 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m77.1/77.1 kB\u001B[0m \u001B[31m477.8 kB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting pyasn1-modules==0.2.8 (from -r /content/fastsurfer//requirements.txt (line 100))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.1/77.1 kB\u001b[0m \u001b[31m477.8 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pyasn1-modules==0.2.8 (from -r /content/fastsurfer//requirements.txt (line 100))\n", " Downloading pyasn1_modules-0.2.8-py2.py3-none-any.whl (155 
kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m155.3/155.3 kB\u001B[0m \u001B[31m22.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hRequirement already satisfied: pyparsing==3.0.9 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 102)) (3.0.9)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m155.3/155.3 kB\u001b[0m \u001b[31m22.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: pyparsing==3.0.9 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 102)) (3.0.9)\n", "Requirement already satisfied: python-dateutil==2.8.2 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 106)) (2.8.2)\n", "Collecting pytz==2022.1 (from -r /content/fastsurfer//requirements.txt (line 111))\n", " Downloading pytz-2022.1-py2.py3-none-any.whl (503 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m503.5/503.5 kB\u001B[0m \u001B[31m33.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting pywavelets==1.3.0 (from -r /content/fastsurfer//requirements.txt (line 113))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m503.5/503.5 kB\u001b[0m \u001b[31m33.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pywavelets==1.3.0 (from -r /content/fastsurfer//requirements.txt (line 113))\n", " Downloading PyWavelets-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.9 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m6.9/6.9 MB\u001B[0m \u001B[31m101.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hRequirement already satisfied: pyyaml==6.0 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 115)) (6.0)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.9/6.9 MB\u001b[0m \u001b[31m101.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: pyyaml==6.0 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 115)) (6.0)\n", "Collecting requests==2.28.1 (from -r /content/fastsurfer//requirements.txt (line 119))\n", " Downloading https://download.pytorch.org/whl/requests-2.28.1-py3-none-any.whl (62 kB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m62.8/62.8 kB\u001B[0m \u001B[31m3.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hRequirement already satisfied: requests-oauthlib==1.3.1 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 124)) (1.3.1)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.8/62.8 kB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: requests-oauthlib==1.3.1 in /usr/local/lib/python3.10/dist-packages (from -r /content/fastsurfer//requirements.txt (line 124)) (1.3.1)\n", "Collecting rsa==4.8 (from -r /content/fastsurfer//requirements.txt (line 126))\n", " Downloading rsa-4.8-py3-none-any.whl (39 kB)\n", "Collecting scikit-image==0.19.2 (from -r /content/fastsurfer//requirements.txt (line 128))\n", " Downloading scikit_image-0.19.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 
(14.0 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m14.0/14.0 MB\u001B[0m \u001B[31m107.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting scikit-learn==1.1.2 (from -r /content/fastsurfer//requirements.txt (line 130))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.0/14.0 MB\u001b[0m \u001b[31m107.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting scikit-learn==1.1.2 (from -r /content/fastsurfer//requirements.txt (line 130))\n", " Downloading scikit_learn-1.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30.5 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m30.5/30.5 MB\u001B[0m \u001B[31m43.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25hCollecting scipy==1.8.0 (from -r /content/fastsurfer//requirements.txt (line 132))\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m30.5/30.5 MB\u001b[0m \u001b[31m43.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting scipy==1.8.0 (from -r /content/fastsurfer//requirements.txt (line 132))\n", " Downloading scipy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (42.3 MB)\n", - "\u001B[2K \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m42.3/42.3 MB\u001B[0m \u001B[31m10.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n", - "\u001B[?25h\u001B[31mERROR: Ignored the following versions that require a different python version: 1.6.2 Requires-Python >=3.7,<3.10; 1.6.3 Requires-Python >=3.7,<3.10; 1.7.0 Requires-Python >=3.7,<3.10; 1.7.0rc1 Requires-Python >=3.7,<3.10; 1.7.0rc2 Requires-Python >=3.7,<3.10; 1.7.1 Requires-Python >=3.7,<3.10; 2.5.2 Requires-Python !=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,<3.9dev,>=2.7\u001B[0m\u001B[31m\n", - "\u001B[0m\u001B[31mERROR: Could not find a version that satisfies the requirement simpleitk==2.1.1 (from versions: 1.0.1, 1.2.0, 2.1.0, 2.1.1.1, 2.1.1.2, 2.2.0, 2.2.1)\u001B[0m\u001B[31m\n", - "\u001B[0m\u001B[31mERROR: No matching distribution found for simpleitk==2.1.1\u001B[0m\u001B[31m\n", - "\u001B[0mInstalling required packages\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.3/42.3 MB\u001b[0m \u001b[31m10.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: Ignored the following versions that require a different python version: 1.6.2 Requires-Python >=3.7,<3.10; 1.6.3 Requires-Python >=3.7,<3.10; 1.7.0 Requires-Python >=3.7,<3.10; 1.7.0rc1 Requires-Python >=3.7,<3.10; 1.7.0rc2 Requires-Python >=3.7,<3.10; 1.7.1 Requires-Python >=3.7,<3.10; 2.5.2 Requires-Python !=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,<3.9dev,>=2.7\u001b[0m\u001b[31m\n", + "\u001b[0m\u001b[31mERROR: Could not find a version that satisfies the requirement simpleitk==2.1.1 (from versions: 1.0.1, 1.2.0, 2.1.0, 2.1.1.1, 2.1.1.2, 2.2.0, 2.2.1)\u001b[0m\u001b[31m\n", + "\u001b[0m\u001b[31mERROR: No matching distribution found for simpleitk==2.1.1\u001b[0m\u001b[31m\n", + "\u001b[0mInstalling required packages\n", "----------------------------------------------\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: torchio==0.18.83 in /usr/local/lib/python3.10/dist-packages (0.18.83)\n", @@ -218,7 +218,7 @@ "#@title Here we first setup the environment by downloading the open source deep-mi/fastsurfer project and the 
required packages\n", "import os\n", "import sys\n", - "from os.path import exists, basename, splitext\n", + "from os.path import basename, exists, splitext\n", "\n", "print(\"Starting setup. This could take a few minutes\")\n", "print(\"----------------------------------------------\")\n", @@ -292,6 +292,7 @@ "%cd \"{SETUP_DIR}\"\n", "\n", "from google.colab import files\n", + "\n", "uploaded = files.upload()\n", "\n", "img = SETUP_DIR + list(uploaded.keys())[0]" @@ -475,6 +476,7 @@ "#@title Click this run button, if you would prefer to download the segmentation in nifti-format\n", "import nibabel as nib\n", "from google.colab import files\n", + "\n", "# conversion to nifti\n", "data = nib.load(f'{SETUP_DIR}fastsurfer_seg/Tutorial/mri/aparc.DKTatlas+aseg.deep.mgz')\n", "img_nifti = nib.Nifti1Image(data.get_fdata(), data.affine, header=nib.Nifti1Header())\n", @@ -494,6 +496,7 @@ "source": [ "#@title If you chose the example subject (Alternative in Click 1), click the run button if you want to download the input image as well\n", "from google.colab import files\n", + "\n", "files.download(f\"{SETUP_DIR}140_orig.mgz\")" ] }, @@ -520,6 +523,7 @@ "#@title Click this run button, if you would prefer to download the image in nifti-format\n", "import nibabel as nib\n", "from google.colab import files\n", + "\n", "# conversion to nifti\n", "data = nib.load(f\"{SETUP_DIR}140_orig.mgz\")\n", "img_nifti = nib.Nifti1Image(data.get_fdata(), data.affine, header=nib.Nifti1Header())\n", @@ -610,12 +614,13 @@ "source": [ "#@title Click the run buttion to plot some slices from the segmented brain\n", "%matplotlib inline\n", - "import nibabel as nib\n", "import matplotlib.pyplot as plt\n", - "import torch\n", + "import nibabel as nib\n", "import numpy as np\n", + "import torch\n", "from skimage import color\n", "from torchvision import utils\n", + "\n", "plt.style.use('seaborn-v0_8-whitegrid')\n", "\n", "def plot_predictions(image, pred):\n", @@ -673,9 +678,10 @@ "outputs": [], "source": [ "#@title Select and visualize your structures of interest in 3D by using the dropdown menu and clicking \"Run Interact\". 
\n", - "from ipywidgets import widgets\n", "import matplotlib.pyplot as plt\n", "import nibabel as nib\n", + "from ipywidgets import widgets\n", + "\n", "#from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n", "from skimage import measure\n", "\n", diff --git a/recon_surf/N4_bias_correct.py b/recon_surf/N4_bias_correct.py index ba9cb9ec..f1a17356 100644 --- a/recon_surf/N4_bias_correct.py +++ b/recon_surf/N4_bias_correct.py @@ -19,17 +19,18 @@ import argparse import logging import sys +from collections.abc import Callable from pathlib import Path -from typing import Optional, cast, Literal, TypeVar, Callable +from typing import Literal, TypeVar, cast -# Group 2: External modules -import SimpleITK as sitk +# Group 2: Internal modules +import image_io as iio + +# Group 3: External modules import numpy as np +import SimpleITK as sitk from numpy import typing as npt -# Group 3: Internal modules -import image_io as iio - HELPTEXT = """ Script to call SITK N4 Bias Correction @@ -234,7 +235,7 @@ def options_parse(): def itk_n4_bfcorrection( itk_image: sitk.Image, - itk_mask: Optional[sitk.Image] = None, + itk_mask: sitk.Image | None = None, shrink: int = 4, levels: int = 4, numiter: int = 50, @@ -302,9 +303,9 @@ def itk_n4_bfcorrection( def normalize_wm_mask_ball( itk_image: sitk.Image, - itk_mask: Optional[sitk.Image] = None, + itk_mask: sitk.Image | None = None, radius: float = 50., - centroid: Optional[np.ndarray] = None, + centroid: np.ndarray | None = None, target_wm: float = 110., target_bg: float = 3. ) -> sitk.Image: @@ -372,7 +373,7 @@ def get_distance(axis): def normalize_wm_aseg( itk_image: sitk.Image, - itk_mask: Optional[sitk.Image], + itk_mask: sitk.Image | None, itk_aseg: sitk.Image, target_wm: float = 110., target_bg: float = 3. @@ -429,7 +430,7 @@ def normalize_wm_aseg( def normalize_img( itk_image: sitk.Image, - itk_mask: Optional[sitk.Image], + itk_mask: sitk.Image | None, source_intensity: tuple[float, float], target_intensity: tuple[float, float] ) -> sitk.Image: @@ -501,7 +502,7 @@ def read_talairach_xfm(fname: Path | str) -> np.ndarray: # advance transform_iter to linear header _ = next(ln for ln in transform_iter if ln.lower().startswith("linear_")) # return the next 3 lines in transform_lines - transform_lines = (ln for ln, _ in zip(transform_iter, range(3))) + transform_lines = (ln for ln, _ in zip(transform_iter, range(3), strict=False)) tal_str = [ln.replace(";", " ") for ln in transform_lines] tal = np.genfromtxt(tal_str) tal = np.vstack([tal, [0, 0, 0, 1]]) @@ -511,8 +512,8 @@ def read_talairach_xfm(fname: Path | str) -> np.ndarray: return tal except StopIteration: _logger.error(msg := f"Could not find 'linear_' in {fname}.") - raise ValueError(msg) - except (Exception, StopIteration) as e: + raise ValueError(msg) from None + except Exception as e: err = ValueError(f"Could not find taiairach transform in {fname}.") _logger.exception(err) raise err from e @@ -567,7 +568,7 @@ def print_options(options: dict): _logger.info(m.format(**options)) -def get_image_mean(image: sitk.Image, mask: Optional[sitk.Image] = None) -> float: +def get_image_mean(image: sitk.Image, mask: sitk.Image | None = None) -> float: """ Get the mean of a sitk Image. 
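The surrounding hunks in recon_surf/N4_bias_correct.py replace typing.Optional[X] annotations with the PEP 604 union form "X | None" and regroup the imports, as ruff's pyupgrade/isort rules suggest for Python 3.10+ code. A minimal standalone sketch of the annotation pattern; the function name is illustrative only and not part of the FastSurfer sources:

    import SimpleITK as sitk

    # Pre-ruff style:   from typing import Optional; mask: Optional[sitk.Image] = None
    # PEP 604 style (Python 3.10+): the union needs no typing import.
    def apply_mask(image: sitk.Image, mask: sitk.Image | None = None) -> sitk.Image:
        """Return the masked image, or the input unchanged when no mask is given."""
        if mask is None:
            return image
        return sitk.Mask(image, mask)
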
@@ -622,13 +623,13 @@ def main( rescalevol: LiteralSkipRescaling | Path = SKIP_RESCALING, dtype: str = "keep", threads: int = 1, - mask: Optional[Path] = None, - aseg: Optional[Path] = None, + mask: Path | None = None, + aseg: Path | None = None, shrink: int = 4, levels: int = 4, numiter: int = 50, thres: float = 0.0, - tal: Optional[Path] = None, + tal: Path | None = None, verbosity: int = -1, ) -> int | str: if rescalevol == "skip rescaling" and outvol == DO_NOT_SAVE: @@ -653,7 +654,7 @@ def main( has_mask = bool(mask) if has_mask: logger.debug(f"reading mask {mask}") - itk_mask: Optional[sitk.Image] = iio.readITKimage( + itk_mask: sitk.Image | None = iio.readITKimage( str(mask), sitk.sitkUInt8, with_header=False diff --git a/recon_surf/align_points.py b/recon_surf/align_points.py index e94a8f1c..e549466d 100755 --- a/recon_surf/align_points.py +++ b/recon_surf/align_points.py @@ -22,12 +22,12 @@ # - find_affine # IMPORTS + import numpy as np from numpy import typing as npt -from typing import Tuple -def rmat2angles(R: npt.NDArray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: +def rmat2angles(R: npt.NDArray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Extract rotation angles (alpha,beta,gamma) in FreeSurfer format (mris_register) from a rotation matrix. @@ -104,9 +104,7 @@ def find_rotation(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: """ if p_mov.shape != p_dst.shape: raise ValueError( - "Shape of points should be identical, but mov = {}, dst = {} expecting Nx3".format( - p_mov.shape, p_dst.shape - ) + f"Shape of points should be identical, but mov = {p_mov.shape}, dst = {p_dst.shape} expecting Nx3" ) # average SSD # dd = p_mov-p_dst @@ -148,9 +146,7 @@ def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: """ if p_mov.shape != p_dst.shape: raise ValueError( - "Shape of points should be identical, but mov = {}, dst = {} expecting Nx3".format( - p_mov.shape, p_dst.shape - ) + f"Shape of points should be identical, but mov = {p_mov.shape}, dst = {p_dst.shape} expecting Nx3" ) # average SSD # translate points to be centered around origin centroid_mov = np.mean(p_mov, axis=0) @@ -168,9 +164,9 @@ def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: T[:m, m] = t # compute disteances dd = p_mov - p_dst - print("Initial avg SSD: {}".format(np.sum(dd * dd) / p_mov.shape[0])) + print(f"Initial avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}") dd = (np.transpose(R @ np.transpose(p_mov)) + t) - p_dst - print("Final avg SSD: {}".format(np.sum(dd * dd) / p_mov.shape[0])) + print(f"Final avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}") # return T, R, t return T @@ -199,9 +195,7 @@ def find_affine(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: """ if p_mov.shape != p_dst.shape: raise ValueError( - "Shape of points should be identical, but mov = {}, dst = {} expecting Nx3".format( - p_mov.shape, p_dst.shape - ) + f"Shape of points should be identical, but mov = {p_mov.shape}, dst = {p_dst.shape} expecting Nx3" ) # average SSD n = len(p_mov) # Solve overdetermined system for the three rows of diff --git a/recon_surf/align_seg.py b/recon_surf/align_seg.py index 98b5e5ed..49d30869 100755 --- a/recon_surf/align_seg.py +++ b/recon_surf/align_seg.py @@ -17,15 +17,14 @@ # IMPORTS import optparse -from typing import Optional, Tuple -import numpy as np -from numpy import typing as npt import sys -import SimpleITK as sitk -import image_io as iio + import align_points as align +import image_io as iio import lta as lta - +import numpy as np +import SimpleITK as 
sitk +from numpy import typing as npt HELPTEXT = """ @@ -46,8 +45,8 @@ Description: For each common segmentation ID in the two inputs, the centroid coordinate is -computed. The point pairs are then aligned by finding the optimal translation and rotation -(rigid) or affine. The output is a FreeSurfer LTA registration file. +computed. The point pairs are then aligned by finding the optimal translation +and rotation (rigid) or affine. The output is a FreeSurfer LTA registration file. Original Author: Martin Reuter Date: Aug-24-2022 @@ -61,7 +60,8 @@ ) h_outlta = "path to output transform lta file" h_flipped = "register to left-right flipped as target aparc+aseg (cortical needed)" -h_midslice = "Optional, only for flipped. Slice where the midplane should be. Defaults to middle of image (width-1)/2." +h_midslice = "Optional, only for flipped. Slice where the midplane should be." \ + " Defaults to middle of image (width-1)/2." def options_parse(): @@ -94,12 +94,17 @@ def options_parse(): or options.outlta is None ): sys.exit( - "\nERROR: Please specify srcseg and trgseg (or flipped) as well as output lta file\n Use --help to see all options.\n" + "\nERROR: Please specify srcseg and trgseg (or flipped)" \ + " as well as output lta file\n Use --help to see all options.\n" ) return options -def get_seg_centroids(seg_mov: sitk.Image, seg_dst: sitk.Image, label_ids: Optional[npt.NDArray[int]] = []) -> Tuple[npt.NDArray, npt.NDArray]: +def get_seg_centroids( + seg_mov: sitk.Image, + seg_dst: sitk.Image, + label_ids: npt.NDArray[int] | None = None +) -> tuple[npt.NDArray, npt.NDArray]: """ Extract the centroids of the segmentation labels for mov and dst in RAS coords. @@ -119,7 +124,7 @@ def get_seg_centroids(seg_mov: sitk.Image, seg_dst: sitk.Image, label_ids: Optio centroids_dst List of centroids of target segmentation. """ - if not label_ids: + if label_ids is not None: # use all joint labels except -1 and 0: nda1 = sitk.GetArrayFromImage(seg_mov) nda2 = sitk.GetArrayFromImage(seg_dst) @@ -156,7 +161,7 @@ def get_seg_centroids(seg_mov: sitk.Image, seg_dst: sitk.Image, label_ids: Optio def align_seg_centroids( seg_mov: sitk.Image, seg_dst: sitk.Image, - label_ids: Optional[npt.NDArray[int]] = [], + label_ids: npt.NDArray[int] | None = None, affine: bool = False ) -> npt.NDArray: """ @@ -169,7 +174,7 @@ def align_seg_centroids( seg_dst : sitk.Image Path to trg segmentation. label_ids : Optional[npt.NDArray[int]] - List of label ids to align. Defaults to []. + List of label ids to align. Defaults to None. affine : bool True if affine should be returned. False if rigid should be returned. Defaults to False. @@ -219,7 +224,7 @@ def get_vox2ras(img:sitk.Image) -> npt.NDArray: vox2ras[0:3,3] = img.GetOrigin() * np.array([-1, -1, 1]) return vox2ras -def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDArray: +def align_flipped(seg: sitk.Image, mid_slice: float | None = None) -> npt.NDArray: """ Registrate Left - right (make upright). 
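A recurring change in recon_surf/align_seg.py is dropping the empty-list default (label_ids=[]) in favour of None and resolving the default inside the function, the usual Python idiom for avoiding shared mutable default arguments. A generic sketch of that idiom, independent of the FastSurfer functions and using made-up names:

    import numpy as np

    def select_labels(segmentation: np.ndarray, label_ids: list[int] | None = None) -> np.ndarray:
        """Boolean mask of the requested labels; default to every non-background label."""
        if label_ids is None:
            # No explicit selection given: use all labels present except background (0).
            label_ids = [int(i) for i in np.unique(segmentation) if i != 0]
        return np.isin(segmentation, label_ids)
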
@@ -323,9 +328,9 @@ def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDA # compute vox2ras matrix from image information vox2ras = get_vox2ras(seg) - print("vox2ras:\n {}".format(vox2ras)) + print(f"vox2ras:\n {vox2ras}") ras2vox = np.linalg.inv(vox2ras) - print("ras2vox:\n {}".format(ras2vox)) + print(f"ras2vox:\n {ras2vox}") # find mid slice of image (usually 127.5 for 256 width) # instead we could also fix this to be 128 independent of width @@ -334,7 +339,7 @@ def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDA middle = 0.5*(seg.GetWidth()-1.0) else: middle = mid_slice - print("Mid slice will be at: {}".format(middle)) + print(f"Mid slice will be at: {middle}") # negate right-left by flipping across middle of image (as make_upright would do it) centroids_flipped = centroids.copy() @@ -350,7 +355,7 @@ def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDA from scipy.linalg import sqrtm Tsqrt = np.real(sqrtm(T)) - print("Matrix sqrt diff: {}".format(np.linalg.norm(T - (Tsqrt @ Tsqrt)))) + print(f"Matrix sqrt diff: {np.linalg.norm(T - (Tsqrt @ Tsqrt))}") # convert vox2vox to ras2ras: Tsqrt = vox2ras @ Tsqrt @ ras2vox @@ -366,23 +371,23 @@ def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDA print() print("Align Segmentations Parameters:") print() - print("- src seg {}".format(options.srcseg)) + print(f"- src seg {options.srcseg}") if options.trgseg is not None: - print("- trg seg {}".format(options.trgseg)) + print(f"- trg seg {options.trgseg}") if options.flipped: print("- registering with left-right flipped image") if options.affine: print("- affine registration") else: print("- rigid registration") - print("- out lta {}".format(options.outlta)) + print(f"- out lta {options.outlta}") - print("\nreading src {}".format(options.srcseg)) + print(f"\nreading src {options.srcseg}") srcseg, srcheader = iio.readITKimage( options.srcseg, sitk.sitkInt16, with_header=True ) if options.trgseg is not None: - print("reading trg {} ...".format(options.trgseg)) + print(f"reading trg {options.trgseg} ...") trgseg, trgheader = iio.readITKimage( options.trgseg, sitk.sitkInt16, with_header=True ) @@ -394,7 +399,7 @@ def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDA trgheader = srcheader # write transform lta - print("writing: {}".format(options.outlta)) + print(f"writing: {options.outlta}") lta.writeLTA( options.outlta, T, options.srcseg, srcheader, options.trgseg, trgheader ) diff --git a/recon_surf/create_annotation.py b/recon_surf/create_annotation.py index 4e8345ba..bdab05bf 100755 --- a/recon_surf/create_annotation.py +++ b/recon_surf/create_annotation.py @@ -19,13 +19,12 @@ # IMPORTS import optparse import os.path -from typing import Union, Optional, Tuple, List -import numpy as np -from numpy import typing as npt import sys -import nibabel.freesurfer.io as fs -from map_surf_label import mapSurfLabel, getSurfCorrespondence +import nibabel.freesurfer.io as fs +import numpy as np +from map_surf_label import getSurfCorrespondence, mapSurfLabel +from numpy import typing as npt HELPTEXT = """ @@ -122,7 +121,8 @@ def options_parse(): or options.outannot is None ): sys.exit( - "\nERROR: Please specify all parameters!\n Use --help to see all options.\n" + "\nERROR: Please specify all parameters!\n" \ + " Use --help to see all options.\n" ) if ( options.trgsphere is not None @@ -132,13 +132,16 @@ def options_parse(): ): if options.trgsphere is None or options.srcsphere 
is None: sys.exit( - "\nERROR: Please specify at least src and trg sphere when mapping!\n Use --help to see all options.\n" + "\nERROR: Please specify at least src and trg sphere when mapping!\n" \ + " Use --help to see all options.\n" ) if (options.trgdir is not None and options.trgsid is None) or ( options.trgdir is None and options.trgsid is not None ): sys.exit( - "\nERROR: Please specify both trgdir and trgsid when outputting mapped labels!\n Use --help to see all options.\n" + "\nERROR: Please specify both trgdir and trgsid" \ + " when outputting mapped labels!\n" \ + " Use --help to see all options.\n" ) return options @@ -153,9 +156,9 @@ def map_multiple_labels( trg_sphere_name: str, trg_white_name: str, trg_sid: str, - out_dir: Optional[str] = None, + out_dir: str | None = None, stop_missing: bool = True -) -> Tuple[npt.ArrayLike, npt.ArrayLike]: +) -> tuple[npt.ArrayLike, npt.ArrayLike]: """ Map a list of labels from one surface (e.g. fsavaerage sphere.reg) to another. @@ -201,7 +204,7 @@ def map_multiple_labels( all_labels = [] all_values = [] # read target surf info (for label writing) - print("Reading in trg white surface: {} ...".format(trg_white_name)) + print(f"Reading in trg white surface: {trg_white_name} ...") trg_white = fs.read_geometry(trg_white_name, read_metadata=False)[0] out_label_name = None for l_name in src_labels: @@ -220,10 +223,10 @@ def map_multiple_labels( else: if stop_missing: raise ValueError( - "ERROR: Label file missing {}\n".format(src_label_name) + f"ERROR: Label file missing {src_label_name}\n" ) else: - print("\nWARNING: Label file missing {}\n".format(src_label_name)) + print(f"\nWARNING: Label file missing {src_label_name}\n") ll = [] vv = [] all_labels.append(ll) @@ -235,7 +238,7 @@ def read_multiple_labels( hemi: str, input_dir: str, label_names: npt.ArrayLike -) -> Tuple[ List[npt.NDArray], List[npt.NDArray]]: +) -> tuple[ list[npt.NDArray], list[npt.NDArray]]: """ Read multiple label files from input_dir. @@ -262,7 +265,7 @@ def read_multiple_labels( if os.path.exists(label_file): ll, vv = fs.read_label(label_file, read_scalars=True) else: - print("\nWARNING: Label file missing {}\n".format(label_file)) + print(f"\nWARNING: Label file missing {label_file}\n") ll = [] vv = [] all_labels.append(ll) @@ -272,9 +275,9 @@ def read_multiple_labels( def build_annot(all_labels: npt.ArrayLike, all_values: npt.ArrayLike, - col_ids: npt.ArrayLike, trg_white: Union[str, npt.NDArray], - cortex_label_name: Optional[str] = None - ) -> Tuple[npt.NDArray, npt.NDArray]: + col_ids: npt.ArrayLike, trg_white: str | npt.NDArray, + cortex_label_name: str | None = None + ) -> tuple[npt.NDArray, npt.NDArray]: """ Create an annotation from multiple labels. @@ -314,9 +317,7 @@ def build_annot(all_labels: npt.ArrayLike, all_values: npt.ArrayLike, # print("counter={}".format(counter)) if len(label) == 0: print( - "\nWARNING: Label with id {} missing, skipping ...\n".format( - col_ids[counter] - ) + f"\nWARNING: Label with id {col_ids[counter]} missing, skipping ...\n" ) counter = counter + 1 continue @@ -338,7 +339,7 @@ def build_annot(all_labels: npt.ArrayLike, all_values: npt.ArrayLike, return annot_ids, annot_vals -def read_colortable(colortab_name: str) -> Tuple[npt.ArrayLike, List[str], npt.ArrayLike]: +def read_colortable(colortab_name: str) -> tuple[npt.ArrayLike, list[str], npt.ArrayLike]: """ Read the colortable of given name. 
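The remaining hunks in recon_surf/create_annotation.py, like much of this patch, rewrite str.format() calls as f-strings and spell out the zip() strictness that newer ruff rules ask for on Python 3.10+. A small illustration of both rewrites on made-up example data, not taken from the FastSurfer code:

    labels = ["BA1", "BA2"]
    values = [0.91, 0.87]

    # Old style that gets rewritten:  print("label {} has value {}".format(name, val))
    # strict=False keeps the pre-3.10 behaviour of stopping silently at the shorter input.
    for name, val in zip(labels, values, strict=False):
        print(f"label {name} has value {val:.2f}")
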
@@ -369,7 +370,7 @@ def write_annot( label_names: npt.ArrayLike, colortab_name: str, out_annot: str, - append: Union[None, str] = "" + append: None | str = "" ) -> None: """ Combine the colortable with the annotations ids to write an annotation file. @@ -396,13 +397,11 @@ def write_annot( offset = 0 if col_names[0] == "unknown": offset = 1 - for name_tab, name_list in zip(col_names[offset:], label_names): + for name_tab, name_list in zip(col_names[offset:], label_names, strict=False): if name_tab + append != name_list: # print("Name in colortable and in label lists disagree: {} != {}".format(name_tab+append,name_list)) raise ValueError( - "Error: name in colortable and in label lists disagree: {} != {}".format( - name_tab + append, name_list - ) + f"Error: name in colortable and in label lists disagree: {name_tab + append} != {name_list}" ) # fill_ctab computes the last column (R+G*2^8+B*2^16) if offset == 0: @@ -438,28 +437,28 @@ def create_annotation(options, verbose: bool = True) -> None: print("Map BA Labels Parameters:") print() if verbose: - print("- hemi: {}".format(options.hemi)) - print("- color table: {}".format(options.colortab)) - print("- label dir: {}".format(options.labeldir)) - print("- white: {}".format(options.white)) - print("- out annot: {}".format(options.outannot)) + print(f"- hemi: {options.hemi}") + print(f"- color table: {options.colortab}") + print(f"- label dir: {options.labeldir}") + print(f"- white: {options.white}") + print(f"- out annot: {options.outannot}") if options.cortex is not None: - print("- cortex mask: {}".format(options.cortex)) + print(f"- cortex mask: {options.cortex}") if options.append is not None: if options.append[0] != ".": options.append = "." + options.append - print("- append {} to label names".format(options.append)) + print(f"- append {options.append} to label names") if options.trgsphere is not None: print("Mapping labels from another subject:") - print("- src sphere: {}".format(options.srcsphere)) - print("- trg sphere: {}".format(options.trgsphere)) + print(f"- src sphere: {options.srcsphere}") + print(f"- trg sphere: {options.trgsphere}") if options.trgdir is not None: print("And will write mapped labels:") - print("- trg dir: {}".format(options.trgdir)) - print("- trg sid: {}".format(options.trgsid)) + print(f"- trg dir: {options.trgdir}") + print(f"- trg sid: {options.trgsid}") print() # read label names from color table - print("Reading in colortable: {} ...".format(options.colortab)) + print(f"Reading in colortable: {options.colortab} ...") ids, names, cols = read_colortable(options.colortab) if names[0] == "unknown": ids = ids[1:] @@ -467,19 +466,17 @@ def create_annotation(options, verbose: bool = True) -> None: cols = cols[1:] # although we do not care about color at this stage at all if options.append is not None: names = [x + options.append for x in names] - print("Merging these labels into annot:\n{}\n".format(names)) + print(f"Merging these labels into annot:\n{names}\n") # if reading multiple label files if options.trgsphere is None: - print("Reading multiple labels from {} ...".format(options.labeldir)) + print(f"Reading multiple labels from {options.labeldir} ...") all_labels, all_values = read_multiple_labels( options.hemi, options.labeldir, names ) else: # if mapping multiple label files print( - "Mapping multiple labels from {} to {} ...".format( - options.labeldir, options.trgdir - ) + f"Mapping multiple labels from {options.labeldir} to {options.trgdir} ..." 
) all_labels, all_values = map_multiple_labels( options.hemi, @@ -492,12 +489,12 @@ def create_annotation(options, verbose: bool = True) -> None: options.trgdir, ) # merge labels into annot - print("Creating annotation on {}".format(options.white)) + print(f"Creating annotation on {options.white}") annot_ids, annot_vals = build_annot( all_labels, all_values, ids, options.white, options.cortex ) # write annot - print("Writing annotation to {}".format(options.outannot)) + print(f"Writing annotation to {options.outannot}") write_annot(annot_ids, names, options.colortab, options.outannot, options.append) print("...done\n") diff --git a/recon_surf/fs_balabels.py b/recon_surf/fs_balabels.py index 0fbdbb7b..1ede0139 100755 --- a/recon_surf/fs_balabels.py +++ b/recon_surf/fs_balabels.py @@ -18,15 +18,15 @@ # IMPORTS import optparse -import os.path import os -from typing import Tuple, List -import numpy as np +import os.path import sys + +import numpy as np from create_annotation import ( + build_annot, map_multiple_labels, read_colortable, - build_annot, write_annot, ) @@ -105,10 +105,10 @@ def options_parse(): def read_colortables( - colnames: List[str], - colappend: List[str], + colnames: list[str], + colappend: list[str], drop_unknown: bool = True -) -> Tuple[List, List, List]: +) -> tuple[list, list, list]: """ Read multiple colortables and appends extensions, drops unknown by default. @@ -136,7 +136,7 @@ def read_colortables( all_ids = [] all_cols = [] for coltab in colnames: - print("Reading in colortable: {} ...".format(coltab)) + print(f"Reading in colortable: {coltab} ...") ids, names, cols = read_colortable(coltab) if drop_unknown and names[0] == "unknown": ids = ids[1:] @@ -156,10 +156,10 @@ def read_colortables( stream = os.popen("date") output = stream.read() - print + print() print("#--------------------------------------------") print("#@# BA_exvivo Labels " + output) - print + print() # Command line options and error checking done here options = options_parse() @@ -196,9 +196,7 @@ def read_colortables( white = os.path.join(options.sd, options.sid, "surf", hemi + ".white") cortex = os.path.join(options.sd, options.sid, "label", hemi + ".cortex.label") print( - "Mapping multiple labels from {} to {} for {} ...\n".format( - labeldir, trgdir, hemi - ) + f"Mapping multiple labels from {labeldir} to {trgdir} for {hemi} ...\n" ) all_labels, all_values = map_multiple_labels( hemi, @@ -217,7 +215,7 @@ def read_colortables( for annot in annotnames: # print("Debug length labelids pos {}".format(len(label_ids[pos]))) stop = start + len(label_ids[pos]) - print("\nCreating {} annotation on {}".format(annot, white)) + print(f"\nCreating {annot} annotation on {white}") # print("Debug info start: {}, stop: {}".format(start,stop)) annot_ids, annot_vals = build_annot( all_labels[start:stop], @@ -227,7 +225,7 @@ def read_colortables( cortex, ) # write annot - print("Writing BA_exvivo annotation to {}\n".format(annot)) + print(f"Writing BA_exvivo annotation to {annot}\n") annotout = os.path.join( options.sd, options.sid, "label", hemi + "." + annot + ".annot" ) @@ -241,10 +239,8 @@ def read_colortables( options.sd, options.sid, "stats", hemi + "." 
+ annot + ".stats" ) ctab = os.path.join(options.sd, options.sid, "label", annot + ".ctab") - cmd = "mris_anatomical_stats -mgz -f {} -b -a {} -c {} \ - {} {} white".format( - stats, annotout, ctab, options.sid, hemi - ) + cmd = f"mris_anatomical_stats -mgz -f {stats} -b -a {annotout} -c {ctab} \ + {options.sid} {hemi} white" print("Debug cmd: " + cmd) stream = os.popen(cmd) print(stream.read()) diff --git a/recon_surf/image_io.py b/recon_surf/image_io.py index 44e53100..3d8a5cbf 100644 --- a/recon_surf/image_io.py +++ b/recon_surf/image_io.py @@ -17,17 +17,18 @@ # IMPORTS -import numpy as np import sys -import SimpleITK as sitk +from typing import Any, overload + import nibabel as nib +import numpy as np +import SimpleITK as sitk from nibabel.freesurfer.mghformat import MGHHeader -from typing import Union, Any, Optional, Tuple, overload def mgh_from_sitk( sitk_img: sitk.Image, - orig_mgh_header: Optional[nib.freesurfer.mghformat.MGHHeader] = None + orig_mgh_header: nib.freesurfer.mghformat.MGHHeader | None = None ) -> nib.MGHImage: """Convert sitk image to mgh image. @@ -109,7 +110,7 @@ def sitk_from_mgh(img: nib.MGHImage) -> sitk.Image: @overload def readITKimage( filename: str, - vox_type: Optional[Any] = None, + vox_type: Any | None = None, with_header: False = False ) -> sitk.Image: ... @@ -118,17 +119,17 @@ def readITKimage( @overload def readITKimage( filename: str, - vox_type: Optional[Any] = None, + vox_type: Any | None = None, with_header: True = True -) -> Tuple[sitk.Image, Any]: +) -> tuple[sitk.Image, Any]: ... def readITKimage( filename: str, - vox_type: Optional[Any] = None, + vox_type: Any | None = None, with_header: bool = False -) -> Union[sitk.Image, Tuple[sitk.Image, Any]]: +) -> sitk.Image | tuple[sitk.Image, Any]: """Read the itk image. Parameters @@ -166,9 +167,7 @@ def readITKimage( itkimage = sitk.Cast(itkimage, vox_type) else: sys.exit( - "read ERROR: {} image type not supported (only: .mgz, .nii, .nii.gz).\n".format( - filename - ) + f"read ERROR: {filename} image type not supported (only: .mgz, .nii, .nii.gz).\n" ) if with_header: return itkimage, header @@ -179,7 +178,7 @@ def readITKimage( def writeITKimage( img: sitk.Image, filename: str, - header: Optional[nib.freesurfer.mghformat.MGHHeader] = None + header: nib.freesurfer.mghformat.MGHHeader | None = None ) -> None: """ Writes the given ITK image to a file. @@ -205,9 +204,7 @@ def writeITKimage( nib.save(mgh_image, filename) else: sys.exit( - "write ERROR: {} image type not supported (only: .mgz, .nii, .nii.gz).\n".format( - filename - ) + f"write ERROR: {filename} image type not supported (only: .mgz, .nii, .nii.gz).\n" ) diff --git a/recon_surf/lta.py b/recon_surf/lta.py index 3a6f32f1..6825e381 100755 --- a/recon_surf/lta.py +++ b/recon_surf/lta.py @@ -50,18 +50,18 @@ def writeLTA( ValueError Header format missing field (Source or Destination). 
""" - from datetime import datetime import getpass + from datetime import datetime fields = ("dims", "delta", "Mdc", "Pxyz_c") for field in fields: if field not in src_header: raise ValueError( - "writeLTA Error: src_header format missing field: {}".format(field) + f"writeLTA Error: src_header format missing field: {field}" ) if field not in dst_header: raise ValueError( - "writeLTA Error: dst_header format missing field: {}".format(field) + f"writeLTA Error: dst_header format missing field: {field}" ) src_dims = str(src_header["dims"][0:3]).replace("[", "").replace("]", "") @@ -75,9 +75,9 @@ def writeLTA( dst_c = dst_header["Pxyz_c"] f = open(filename, "w") - f.write("# transform file {}\n".format(filename)) + f.write(f"# transform file {filename}\n") f.write( - "# created by {} on {}\n\n".format(getpass.getuser(), datetime.now().ctime()) + f"# created by {getpass.getuser()} on {datetime.now().ctime()}\n\n" ) f.write("type = 1 # LINEAR_RAS_TO_RAS\n") f.write("nxforms = 1\n") @@ -88,20 +88,20 @@ def writeLTA( f.write("\n") f.write("src volume info\n") f.write("valid = 1 # volume info valid\n") - f.write("filename = {}\n".format(src_fname)) - f.write("volume = {}\n".format(src_dims)) - f.write("voxelsize = {}\n".format(src_vsize)) - f.write("xras = {}\n".format(src_v2r[0, :]).replace("[", "").replace("]", "")) - f.write("yras = {}\n".format(src_v2r[1, :]).replace("[", "").replace("]", "")) - f.write("zras = {}\n".format(src_v2r[2, :]).replace("[", "").replace("]", "")) - f.write("cras = {}\n".format(src_c).replace("[", "").replace("]", "")) + f.write(f"filename = {src_fname}\n") + f.write(f"volume = {src_dims}\n") + f.write(f"voxelsize = {src_vsize}\n") + f.write(f"xras = {src_v2r[0, :]}\n".replace("[", "").replace("]", "")) + f.write(f"yras = {src_v2r[1, :]}\n".replace("[", "").replace("]", "")) + f.write(f"zras = {src_v2r[2, :]}\n".replace("[", "").replace("]", "")) + f.write(f"cras = {src_c}\n".replace("[", "").replace("]", "")) f.write("dst volume info\n") f.write("valid = 1 # volume info valid\n") - f.write("filename = {}\n".format(dst_fname)) - f.write("volume = {}\n".format(dst_dims)) - f.write("voxelsize = {}\n".format(dst_vsize)) - f.write("xras = {}\n".format(dst_v2r[0, :]).replace("[", "").replace("]", "")) - f.write("yras = {}\n".format(dst_v2r[1, :]).replace("[", "").replace("]", "")) - f.write("zras = {}\n".format(dst_v2r[2, :]).replace("[", "").replace("]", "")) - f.write("cras = {}\n".format(dst_c).replace("[", "").replace("]", "")) + f.write(f"filename = {dst_fname}\n") + f.write(f"volume = {dst_dims}\n") + f.write(f"voxelsize = {dst_vsize}\n") + f.write(f"xras = {dst_v2r[0, :]}\n".replace("[", "").replace("]", "")) + f.write(f"yras = {dst_v2r[1, :]}\n".replace("[", "").replace("]", "")) + f.write(f"zras = {dst_v2r[2, :]}\n".replace("[", "").replace("]", "")) + f.write(f"cras = {dst_c}\n".replace("[", "").replace("]", "")) f.close() diff --git a/recon_surf/map_surf_label.py b/recon_surf/map_surf_label.py index 52f0d317..4f9c51e6 100755 --- a/recon_surf/map_surf_label.py +++ b/recon_surf/map_surf_label.py @@ -18,14 +18,13 @@ # IMPORTS import optparse -from typing import Union, Optional, Tuple -import numpy as np -import numpy.typing as npt import sys + import nibabel.freesurfer.io as fs +import numpy as np +import numpy.typing as npt from sklearn.neighbors import KDTree - HELPTEXT = """ Script to map surface labels across surfaces based on given sphere registrations @@ -128,14 +127,10 @@ def writeSurfLabel( values = np.zeros(label.shape) if values.size != label.size: raise 
ValueError( - "writeLabel Error: label and values should have same sizes {}!={}".format( - label.size, values.size - ) + f"writeLabel Error: label and values should have same sizes {label.size}!={values.size}" ) coords = surf[label, :] - header = "#!ascii label , from subject {} vox2ras=TkReg \n{}".format( - sid, label.size - ) + header = f"#!ascii label , from subject {sid} vox2ras=TkReg \n{label.size}" data = np.column_stack([label, coords, values]) np.savetxt( filename, @@ -147,10 +142,10 @@ def writeSurfLabel( def getSurfCorrespondence( - src_sphere: Union[str, Tuple, np.ndarray], - trg_sphere: Union[str, Tuple, np.ndarray], - tree: Optional[KDTree] = None -) -> Tuple[np.ndarray, np.ndarray, KDTree]: + src_sphere: str | tuple | np.ndarray, + trg_sphere: str | tuple | np.ndarray, + tree: KDTree | None = None +) -> tuple[np.ndarray, np.ndarray, KDTree]: """ For each vertex in src_sphere find the closest vertex in trg_sphere. @@ -201,10 +196,10 @@ def getSurfCorrespondence( def mapSurfLabel( src_label_name: str, out_label_name: str, - trg_surf: Union[str, np.ndarray], + trg_surf: str | np.ndarray, trg_sid: str, rev_mapping: np.ndarray -) -> Tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray]: """ Map a label from src surface according to the correspondence. @@ -237,18 +232,16 @@ def mapSurfLabel( ValueError If label and trg vertices are not of same sizes. """ - print("Mapping label: {} ...".format(src_label_name)) + print(f"Mapping label: {src_label_name} ...") src_label, src_values = fs.read_label(src_label_name, read_scalars=True) smax = max(np.max(src_label), np.max(rev_mapping)) + 1 tmax = rev_mapping.size if isinstance(trg_surf, str): - print("Reading in surface: {} ...".format(trg_surf)) + print(f"Reading in surface: {trg_surf} ...") trg_surf = fs.read_geometry(trg_surf, read_metadata=False)[0] if trg_surf.shape[0] != tmax: raise ValueError( - "mapSurfLabel Error: label and trg vertices should have same sizes {}!={}".format( - tmax, trg_surf.shape[0] - ) + f"mapSurfLabel Error: label and trg vertices should have same sizes {tmax}!={trg_surf.shape[0]}" ) inside = np.zeros(smax, dtype=bool) inside[src_label] = True @@ -268,12 +261,12 @@ def mapSurfLabel( print() print("Map Surface Labels Parameters:") print() - print("- src label {}".format(options.srclabel)) - print("- src sphere {}".format(options.srcsphere)) - print("- trg sphere {}".format(options.trgsphere)) - print("- trg surf {}".format(options.trgsurf)) - print("- trg sid {}".format(options.trgsid)) - print("- out label {}".format(options.outlabel)) + print(f"- src label {options.srclabel}") + print(f"- src sphere {options.srcsphere}") + print(f"- trg sphere {options.trgsphere}") + print(f"- trg surf {options.trgsurf}") + print(f"- trg sid {options.trgsid}") + print(f"- out label {options.outlabel}") # for example: # src_label_name = "fsaverage/label/lh.BA1_exvivo.label" @@ -283,11 +276,16 @@ def mapSurfLabel( # trg_white_name = "OAS1_0111_MR1/surf/lh.white" # trg_sid = "OAS1_0111_MR1" - # ./map_surf_label.py --srclabel fsaverage/label/lh.BA1_exvivo.label --srcsphere fsaverage/surf/lh.sphere.reg --trgsphere OAS1_0111_MR1/surf/lh.sphere72_my4.reg --trgsurf OAS1_0111_MR1/surf/lh.white --trgsid OAS1_0111_MR1 --outlabel lh.BA1_exvivo_my.label + # ./map_surf_label.py --srclabel fsaverage/label/lh.BA1_exvivo.label \ + # --srcsphere fsaverage/surf/lh.sphere.reg \ + # --trgsphere OAS1_0111_MR1/surf/lh.sphere72_my4.reg \ + # --trgsurf OAS1_0111_MR1/surf/lh.white \ + # --trgsid OAS1_0111_MR1 \ + # --outlabel 
lh.BA1_exvivo_my.label - print("Reading in src sphere: {} ...".format(options.srcsphere)) + print(f"Reading in src sphere: {options.srcsphere} ...") src_sphere = fs.read_geometry(options.srcsphere, read_metadata=False)[0] - print("Reading in trg sphere: {} ...".format(options.trgsphere)) + print(f"Reading in trg sphere: {options.trgsphere} ...") trg_sphere = fs.read_geometry(options.trgsphere, read_metadata=False)[0] # get reverse mapping (trg->src) for sampling print("Computing reverse mapping ...") @@ -298,7 +296,7 @@ def mapSurfLabel( mapSurfLabel( options.srclabel, options.outlabel, options.trgsurf, options.trgsid, rev_mapping ) - print("Output label {} written".format(options.outlabel)) + print(f"Output label {options.outlabel} written") print("...done\n") diff --git a/recon_surf/paint_cc_into_pred.py b/recon_surf/paint_cc_into_pred.py index 1d41ae57..ec649a86 100644 --- a/recon_surf/paint_cc_into_pred.py +++ b/recon_surf/paint_cc_into_pred.py @@ -15,10 +15,11 @@ # IMPORTS -import sys import argparse -import numpy as np +import sys + import nibabel as nib +import numpy as np from numpy import typing as npt HELPTEXT = """ @@ -106,12 +107,12 @@ def paint_in_cc(pred: npt.ArrayLike, aseg_cc: npt.ArrayLike) -> npt.ArrayLike: # Command Line options are error checking done here options = argument_parse() - print("Reading inputs: {} {}...".format(options.input_cc, options.input_pred)) + print(f"Reading inputs: {options.input_cc} {options.input_pred}...") aseg_image = np.asanyarray(nib.load(options.input_cc).dataobj) prediction = nib.load(options.input_pred) pred_with_cc = paint_in_cc(np.asanyarray(prediction.dataobj), aseg_image) - print("Writing segmentation with corpus callosum to: {}".format(options.output)) + print(f"Writing segmentation with corpus callosum to: {options.output}") pred_with_cc_fin = nib.MGHImage(pred_with_cc, prediction.affine, prediction.header) pred_with_cc_fin.to_filename(options.output) diff --git a/recon_surf/rewrite_mc_surface.py b/recon_surf/rewrite_mc_surface.py index bd13429e..5e4298a6 100644 --- a/recon_surf/rewrite_mc_surface.py +++ b/recon_surf/rewrite_mc_surface.py @@ -14,8 +14,9 @@ # IMPORTS -import sys import optparse +import sys + import nibabel.freesurfer.io as fs from nibabel import load as nibload @@ -86,7 +87,7 @@ def resafe_surface(insurf: str, outsurf: str, pretess: str) -> None: surf_out = options.output_surf vol_in = options.in_pretess - print("Reading in surface: {} ...".format(surf_in)) + print(f"Reading in surface: {surf_in} ...") resafe_surface(surf_in, surf_out, vol_in) - print("Outputting surface as: {}".format(surf_out)) + print(f"Outputting surface as: {surf_out}") sys.exit(0) diff --git a/recon_surf/rewrite_oriented_surface.py b/recon_surf/rewrite_oriented_surface.py index fc9c4a37..f1a1951b 100644 --- a/recon_surf/rewrite_oriented_surface.py +++ b/recon_surf/rewrite_oriented_surface.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import argparse import shutil + # IMPORTS import sys -import argparse from pathlib import Path import lapy diff --git a/recon_surf/rotate_sphere.py b/recon_surf/rotate_sphere.py index b944b376..3c7b4b72 100644 --- a/recon_surf/rotate_sphere.py +++ b/recon_surf/rotate_sphere.py @@ -20,12 +20,10 @@ import optparse import sys +import align_points as align +import nibabel.freesurfer.io as fs import numpy as np from numpy import typing as npt -import nibabel.freesurfer.io as fs - -import align_points as align - HELPTEXT = """ @@ -95,7 +93,9 @@ def options_parse(): or options.out is None ): sys.exit( - "\nERROR: Please specify src and target sphere and parcellation files as well as output txt file\n Use --help to see all options.\n" + "\nERROR: Please specify src and target sphere and parcellation" \ + " files as well as output txt file\n" \ + " Use --help to see all options.\n" ) return options @@ -106,7 +106,7 @@ def align_aparc_centroids( labels_mov: npt.ArrayLike, v_dst: npt.ArrayLike, labels_dst: npt.ArrayLike, - label_ids: npt.ArrayLike = [] + label_ids: npt.ArrayLike = None ) -> np.ndarray: """ Align centroid of aparc parcels on the sphere (Attention mapping back to sphere!). @@ -122,7 +122,7 @@ def align_aparc_centroids( labels_dst : npt.ArrayLike Labels of aparc parcelation for rotation destination. label_ids : npt.ArrayLike - Ids of the centroid to be aligned. Defaults to []. + Ids of the centroid to be aligned. Defaults to None. Returns ------- @@ -135,7 +135,7 @@ def align_aparc_centroids( # lids=np.array([8,9,22,24,31]) # lids=np.array([8,22,24]) - if not label_ids: + if label_ids is not None: # use all joint labels except -1 and 0: lids = np.intersect1d(labels_mov, labels_dst) lids = lids[(lids > 0)] @@ -168,30 +168,30 @@ def align_aparc_centroids( print() print("Rotate Sphere Parameters:") print() - print("- src sphere {}".format(options.srcsphere)) - print("- src aparc: {}".format(options.srcaparc)) - print("- trg sphere {}".format(options.trgsphere)) - print("- trg aparc: {}".format(options.trgaparc)) - print("- out txt {}".format(options.out)) + print(f"- src sphere {options.srcsphere}") + print(f"- src aparc: {options.srcaparc}") + print(f"- trg sphere {options.trgsphere}") + print(f"- trg aparc: {options.trgaparc}") + print(f"- out txt {options.out}") # read image (only nii supported) and convert to float32 - print("\nreading {}".format(options.srcsphere)) + print(f"\nreading {options.srcsphere}") srcsphere = fs.read_geometry(options.srcsphere, read_metadata=True) - print("reading annotation: {} ...".format(options.srcaparc)) + print(f"reading annotation: {options.srcaparc} ...") srcaparc = fs.read_annot(options.srcaparc) - print("reading {}".format(options.trgsphere)) + print(f"reading {options.trgsphere}") trgsphere = fs.read_geometry(options.trgsphere, read_metadata=True) - print("reading annotation: {} ...".format(options.trgaparc)) + print(f"reading annotation: {options.trgaparc} ...") trgaparc = fs.read_annot(options.trgaparc) R = align_aparc_centroids(srcsphere[0], srcaparc[0], trgsphere[0], trgaparc[0]) alpha, beta, gamma = align.rmat2angles(R) - print("\nalpha {:.1f} beta {:.1f} gamma {:.1f}\n".format(alpha, beta, gamma)) + print(f"\nalpha {alpha:.1f} beta {beta:.1f} gamma {gamma:.1f}\n") # write angles - print("writing: {}".format(options.out)) + print(f"writing: {options.out}") f = open(options.out, "w") - f.write("{:.1f} {:.1f} {:.1f}\n".format(alpha, beta, gamma)) + f.write(f"{alpha:.1f} {beta:.1f} {gamma:.1f}\n") f.close() print("...done\n") diff --git 
a/recon_surf/sample_parc.py b/recon_surf/sample_parc.py index 3d784ad2..c58515ef 100644 --- a/recon_surf/sample_parc.py +++ b/recon_surf/sample_parc.py @@ -20,16 +20,14 @@ import optparse import sys -import numpy as np -import nibabel.freesurfer.io as fs import nibabel as nib +import nibabel.freesurfer.io as fs +import numpy as np +from lapy import TriaMesh from scipy import sparse from scipy.sparse.csgraph import connected_components -from lapy import TriaMesh - from smooth_aparc import smooth_aparc - HELPTEXT = """ Script to sample labels from image to surface and clean up. @@ -167,7 +165,7 @@ def find_all_islands(surf, annot): lmax = np.bincount(ll).argmax() v = lidx[(ll != lmax)] if v.size > 0: - print("Found disconnected islands ({} vertices total) for label {}!".format(v.size, lid)) + print(f"Found disconnected islands ({v.size} vertices total) for label {lid}!") vidx = np.concatenate((vidx,v)) return vidx @@ -191,7 +189,7 @@ def sample_nearest_nonzero(img, vox_coords, radius=3.0): """ # check for isotropic voxels voxsize = img.header.get_zooms() - print("Check isotropic vox sizes: {}".format(voxsize)) + print(f"Check isotropic vox sizes: {voxsize}") assert (np.max(np.abs(voxsize - voxsize[0])) < 0.001), 'Voxels not isotropic!' data = np.asarray(img.dataobj) @@ -327,7 +325,7 @@ def sample_img(surf, img, cortex=None, projmm=0.0, radius=None): return samplesfull # here we need to do the hard work of searching in a windows # for non-zero samples - print("sample_img: found {} zero samples, searching radius ...".format(zeros.size)) + print(f"sample_img: found {zeros.size} zero samples, searching radius ...") z_nn = x_nn[zeros] z_samples = sample_nearest_nonzero(img, z_nn, radius=radius) samples_nn[zeros] = z_samples @@ -427,7 +425,8 @@ def sample_parc (surf, seg, imglut, surflut, outaparc, cortex=None, projmm=0.0, # Command Line options are error checking done here options = options_parse() - sample_parc(options.insurf, options.inseg, options.seglut, options.surflut, options.outaparc, options.incort, options.projmm, options.radius) + sample_parc(options.insurf, options.inseg, options.seglut, options.surflut, + options.outaparc, options.incort, options.projmm, options.radius) sys.exit(0) diff --git a/recon_surf/smooth_aparc.py b/recon_surf/smooth_aparc.py index 43a063e8..3c5053e6 100644 --- a/recon_surf/smooth_aparc.py +++ b/recon_surf/smooth_aparc.py @@ -19,12 +19,12 @@ # IMPORTS import optparse import sys -import numpy as np + import nibabel.freesurfer.io as fs +import numpy as np from numpy import typing as npt from scipy import sparse - HELPTEXT = """ Script to fill holes and smooth aparc labels. @@ -131,7 +131,7 @@ def mode_filter( adjM: sparse.csr_matrix, labels: npt.NDArray[str], fillonlylabel = None, - novote: npt.ArrayLike = [] + novote: npt.ArrayLike = None ) -> npt.NDArray[int]: """ Apply mode filter (smoothing) to integer labels on mesh vertices. @@ -148,7 +148,7 @@ def mode_filter( fillonlylabel : int Label to fill exclusively. Defaults to None to smooth all labels. novote : npt.ArrayLike - Label ids that should not vote. Defaults to []. + Label ids that should not vote. Defaults to None. 
Returns ------- @@ -223,7 +223,7 @@ def mode_filter( # get rid of entries that should not vote # since we have only rows that were non-uniform, they should not become empty # rows may become unform: we still need to vote below to update this label - if novote: + if novote is not None: rr = np.isin(nlabels.data, novote) nlabels.data[rr] = 0 nlabels.eliminate_zeros() @@ -338,7 +338,9 @@ def smooth_aparc(surf, labels, cortex = None): if ids.size == idssize: # no more improvement, strange could be an island in the cortex label that cannot be filled print( - "Warning: Cannot improve but still have holes. Maybe there is an island in the cortex label that cannot be filled with real labels." + "Warning: Cannot improve, but still have holes." \ + " Maybe there is an island in the cortex " \ + " label that cannot be filled with real labels?" ) fillids = np.where(labels == fillonlylabel)[0] labels[fillids] = 0 @@ -385,16 +387,16 @@ def main( Surface filepath and name of destination. """ # read input files - print("Reading in surface: {} ...".format(insurfname)) + print(f"Reading in surface: {insurfname} ...") surf = fs.read_geometry(insurfname, read_metadata=True) - print("Reading in annotation: {} ...".format(inaparcname)) + print(f"Reading in annotation: {inaparcname} ...") aparc = fs.read_annot(inaparcname) - print("Reading in cortex label: {} ...".format(incortexname)) + print(f"Reading in cortex label: {incortexname} ...") cortex = fs.read_label(incortexname) # set labels (n) and triangles (n x 3) labels = aparc[0] slabels = smooth_aparc(surf, labels, cortex) - print("Outputting fixed annot: {}".format(outaparcname)) + print(f"Outputting fixed annot: {outaparcname}") fs.write_annot(outaparcname, slabels, aparc[1], aparc[2]) diff --git a/recon_surf/spherically_project.py b/recon_surf/spherically_project.py index ddb58ce4..6a88b8b8 100644 --- a/recon_surf/spherically_project.py +++ b/recon_surf/spherically_project.py @@ -14,13 +14,14 @@ # IMPORTS +import math import optparse import sys + import nibabel.freesurfer.io as fs import numpy as np -import math -from lapy.diffgeo import tria_mean_curvature_flow from lapy import TriaMesh +from lapy.diffgeo import tria_mean_curvature_flow from lapy.solver import Solver HELPTEXT = """ @@ -182,12 +183,12 @@ def get_flipped_area(tria): l31 = abs(cmax3[1] - cmin3[1]) if l11 < l21 or l11 < l31: print("ERROR: direction 1 should be (anterior -posterior) but is not!") - print(" debug info: {} {} {} ".format(l11, l21, l31)) + print(f" debug info: {l11} {l21} {l31} ") # sys.exit(1) raise ValueError("Direction 1 should be anterior - posterior") # only flip direction if necessary - print("ev1 min: {} max {} ".format(cmin1, cmax1)) + print(f"ev1 min: {cmin1} max {cmax1} ") # axis 1 = y is aligned with this function (for brains in FS space) v1 = cmax1 - cmin1 if cmax1[1] < cmin1[1]: @@ -209,7 +210,7 @@ def get_flipped_area(tria): if l33 < l23: print("WARNING: direction 3 wants to swap with 2, but cannot") - print("ev2 min: {} max {} ".format(cmin2, cmax2)) + print(f"ev2 min: {cmin2} max {cmax2} ") # axis 2 = z is aligned with this function (for brains in FS space) v2 = cmax2 - cmin2 if cmax2[2] < cmin2[2]: @@ -217,7 +218,7 @@ def get_flipped_area(tria): print("inverting direction 2 (superior - inferior)") l2 = abs(cmax2[2] - cmin2[2]) - print("ev3 min: {} max {} ".format(cmin3, cmax3)) + print(f"ev3 min: {cmin3} max {cmax3} ") # axis 0 = x is aligned with this function (for brains in FS space) v3 = cmax3 - cmin3 if cmax3[0] < cmin3[0]: @@ -229,13 +230,13 @@ def 
get_flipped_area(tria): v2 = v2 * (1.0 / np.sqrt(np.sum(v2 * v2))) v3 = v3 * (1.0 / np.sqrt(np.sum(v3 * v3))) spatvol = abs(np.dot(v1, np.cross(v2, v3))) - print("spat vol: {}".format(spatvol)) + print(f"spat vol: {spatvol}") mvol = tria.volume() - print("orig mesh vol {}".format(mvol)) + print(f"orig mesh vol {mvol}") bvol = l1 * l2 * l3 - print("box {}, {}, {} volume: {} ".format(l1, l2, l3, bvol)) - print("box coverage: {}".format(bvol / mvol)) + print(f"box {l1}, {l2}, {l3} volume: {bvol} ") + print(f"box coverage: {bvol / mvol}") # we map evN to -1..0..+1 (keep zero level fixed) # I have the feeling that this helps a little with the stretching @@ -275,14 +276,14 @@ def get_flipped_area(tria): trianew = TriaMesh(vn, tria.t) svol = trianew.area() / (4.0 * math.pi * 10000) - print("sphere area fraction: {} ".format(svol)) + print(f"sphere area fraction: {svol} ") flippedarea = get_flipped_area(trianew) / (4.0 * math.pi * 10000) if flippedarea > 0.95: print("ERROR: global normal flip, exiting ..") raise ValueError("global normal flip") - print("flipped area fraction: {} ".format(flippedarea)) + print(f"flipped area fraction: {flippedarea} ") if svol < 0.99: print("ERROR: sphere area fraction should be above .99, exiting ..") @@ -332,9 +333,9 @@ def spherically_project_surface( surf_to_project = options.input_surf projected_surf = options.output_surf - print("Reading in surface: {} ...".format(surf_to_project)) + print(f"Reading in surface: {surf_to_project} ...") # switching cholmod off will be slower, but does not require scikit sparse cholmod spherically_project_surface(surf_to_project, projected_surf, use_cholmod=False) - print("Outputting spherically projected surface: {}".format(projected_surf)) + print(f"Outputting spherically projected surface: {projected_surf}") sys.exit(0) diff --git a/recon_surf/spherically_project_wrapper.py b/recon_surf/spherically_project_wrapper.py index 82035bf3..1cfe8a80 100644 --- a/recon_surf/spherically_project_wrapper.py +++ b/recon_surf/spherically_project_wrapper.py @@ -14,10 +14,10 @@ # IMPORTS -import shlex import argparse +import shlex +from subprocess import PIPE, Popen from typing import Any -from subprocess import Popen, PIPE def setup_options(): @@ -101,14 +101,12 @@ def spherical_wrapper(command1: str, command2: str, **kwargs: Any) -> int: Return code of command1. If command1 failed return code of command2. 
""" # First try to run standard spherical project - print("Running command: {}".format(command1)) + print(f"Running command: {command1}") code_1 = call(command1, **kwargs) if code_1 != 0: print( - "Command {} failed.\nRunning fallback command: {}".format( - command1, command2 - ) + f"Command {command1} failed.\nRunning fallback command: {command2}" ) code_1 = call(command2, **kwargs) diff --git a/recon_surf/utils/extract_recon_surf_time_info.py b/recon_surf/utils/extract_recon_surf_time_info.py index dd8897b9..aebee170 100644 --- a/recon_surf/utils/extract_recon_surf_time_info.py +++ b/recon_surf/utils/extract_recon_surf_time_info.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 +import argparse +import locale from datetime import datetime, timedelta from pathlib import Path import dateutil.parser -import argparse import yaml -import locale def get_recon_all_stage_duration(line: str, previous_datetime_str: str) -> float: From a951fd74e71a418c2a16aad6a67512d8e809fa95 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 12:37:47 +0200 Subject: [PATCH 05/31] sort include and other formatting fixes --- Docker/build.py | 45 ++++++++++++++++++++++--------------------- Docker/install_env.py | 8 +++----- pyproject.toml | 2 -- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/Docker/build.py b/Docker/build.py index 6467c7bf..e3301db5 100755 --- a/Docker/build.py +++ b/Docker/build.py @@ -18,12 +18,12 @@ # June 27th 2023 import argparse +import logging import os import subprocess +from collections.abc import Sequence from pathlib import Path -from typing import Tuple, Literal, Sequence, Optional, Dict, get_args, cast, List, Callable, Union -import logging - +from typing import Literal, cast, get_args logger = logging.getLogger(__name__) @@ -59,7 +59,7 @@ class DEFAULTS: # torch 1.12.0 comes compiled with cu113, cu116, rocm5.0 and rocm5.1.1 # torch 2.0.1 comes compiled with cu117, cu118, and rocm5.4.2 # torch 2.4 comes compiled with cu118, cu121, cu124 and rocm6.1 - MapDeviceType: Dict[AllDeviceType, DeviceType] = dict( + MapDeviceType: dict[AllDeviceType, DeviceType] = dict( ((d, d) for d in get_args(DeviceType)), rocm="rocm6.1", cuda="cu124", @@ -114,7 +114,7 @@ class CacheSpec: _type: CacheType _params: dict - CACHE_PARAMETERS: Dict[CacheType, Tuple[List[str], List[str]]] = { + CACHE_PARAMETERS: dict[CacheType, tuple[list[str], list[str]]] = { "inline": ([], []), "registry": ( ["ref", "compression", "compression-level", "force-compression", @@ -150,7 +150,7 @@ def type(self) -> CacheType: @type.setter def type(self, _type: str): - from typing import get_args, cast + from typing import cast, get_args if _type.lower() in get_args(CacheType): self._type = cast(CacheType, _type) else: @@ -205,7 +205,7 @@ def make_parser() -> argparse.ArgumentParser: type=target, choices=get_args(Target), metavar="target", - help=f"""target to build (from list of targets below, defaults to runtime):
+ help="""target to build (from list of targets below, defaults to runtime):
- build_conda: "finished" conda build image
- build_freesurfer: "finished" freesurfer build image
- runtime: final fastsurfer runtime image""", @@ -317,7 +317,7 @@ def make_parser() -> argparse.ArgumentParser: def red(skk): - return "\033[91m {}\033[00m" .format(skk) + return f"\033[91m {skk}\033[00m" def get_builder( @@ -326,8 +326,8 @@ def get_builder( require_builder_type: bool = False, ) -> tuple[bool, str]: """Get the builder to build the fastsurfer image.""" - from subprocess import PIPE from re import compile + from subprocess import PIPE buildx_binfo = Popen(["docker", "buildx", "ls"], stdout=PIPE, stderr=PIPE).finish() header, *lines = buildx_binfo.out_str("utf-8").strip().split("\n") @@ -373,7 +373,7 @@ def get_builder( def docker_build_image( image_name: str, dockerfile: Path, - working_directory: Optional[Path] = None, + working_directory: Path | None = None, context: Path | str = ".", dry_run: bool = False, attestation: bool = False, @@ -431,7 +431,7 @@ def docker_build_image( raise ValueError(f"Invalid Value for 'action' {action}, must be load or push.") def to_pair(key, values): - if isinstance(values, Sequence) and isinstance(values, (str, bytes)): + if isinstance(values, Sequence) and isinstance(values, str | bytes): values = [values] key_dashed = key.replace("_", "-") # concatenate the --key_dashed value pairs @@ -484,10 +484,10 @@ def is_inline_cache(cache_kw): require_container, ) if has_storage and action == "load": - image_type = f"docker" + image_type = "docker" elif action == "push": # with containerd storage driver or pushing to registry - image_type = f"image" + image_type = "image" # both support attestation no problem elif action == "export": experimental = ". No image will be imported. This features is experimental." @@ -513,7 +513,7 @@ def is_inline_cache(cache_kw): # Future Alternative: save the image to preserve the manifest files to file else: # no attestation, docker builder supports this format - image_type = f"docker" + image_type = "docker" args.extend(["--output", f"type={image_type},name={image_name}"]) if not bool(import_after_args): @@ -565,7 +565,7 @@ def forward_output_to_logger(process): def singularity_build_image( image_name: str, singularity_image: Path, - working_directory: Optional[Path] = None, + working_directory: Path | None = None, dry_run: bool = False, ): """ @@ -608,17 +608,18 @@ def singularity_build_image( def main( device: DeviceType, - cache: Optional[CacheSpec] = None, + cache: CacheSpec | None = None, target: Target = "runtime", debug: bool = False, - image_tag: Optional[str] = None, + image_tag: str | None = None, dry_run: bool = False, tag_dev: bool = True, - fastsurfer_home: Optional[Path] = None, + fastsurfer_home: Path | None = None, **keywords, ) -> int | str: - from FastSurferCNN.version import has_git, main as version - kwargs: Dict[str, Union[str, List[str]]] = {} + from FastSurferCNN.version import has_git + from FastSurferCNN.version import main as version + kwargs: dict[str, str | list[str]] = {} if cache is not None: if not isinstance(cache, CacheSpec): cache = CacheSpec(cache) @@ -640,7 +641,7 @@ def main( kwargs["target"] = target kwargs["build_arg"] = [f"DEVICE={DEFAULTS.MapDeviceType.get(device, 'cpu')}"] if debug: - kwargs["build_arg"].append(f"DEBUG=true") + kwargs["build_arg"].append("DEBUG=true") build_arg_list = [ "build_base_image", "runtime_base_image", @@ -676,7 +677,7 @@ def main( if ret_version != 0: return f"Creating the version file failed with message: {ret_version}" - with open(build_filename, "r") as build_file: + with open(build_filename) as build_file: from FastSurferCNN.version import 
parse_build_file build_info = parse_build_file(build_file) diff --git a/Docker/install_env.py b/Docker/install_env.py index 3d03d38d..6107e4b9 100644 --- a/Docker/install_env.py +++ b/Docker/install_env.py @@ -7,7 +7,6 @@ import os.path import re - logger = logging.getLogger(__name__) @@ -61,18 +60,17 @@ def make_parser() -> argparse.ArgumentParser: def main(args): """Function to split a conda env file for pytorch cuda and cpu versions.""" - from operator import xor - mode = getattr(args, 'mode') + mode = args.mode if mode is None: return "ERROR: No mode set." yaml_in = getattr(args, 'yaml_in', None) if yaml_in is None or not os.path.exists(yaml_in): return f"ERROR: yaml environment file {yaml_in} is not valid!" - with open(yaml_in, "r") as f_yaml: + with open(yaml_in) as f_yaml: lines = f_yaml.readlines() - out_file = getattr(args, 'yaml_out') + out_file = args.yaml_out out_file_pointer = open(out_file, "w") if out_file else None # filter yaml file for pip content kwargs = {"sep": "", "end": "", "file": out_file_pointer} diff --git a/pyproject.toml b/pyproject.toml index c6d16506..86d32b53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,8 +115,6 @@ extend-exclude = [ "checkpoints", "doc", "env", - "images", - "setup.py", ] [tool.ruff.lint] From 555c25765a0c26d735b5bb3766363ff3a39e2f87 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 13:05:02 +0200 Subject: [PATCH 06/31] hypvinn fix formatting and ruff warnings --- HypVINN/config/hypvinn.py | 3 +- HypVINN/config/hypvinn_files.py | 1 - HypVINN/data_loader/data_utils.py | 6 ++-- HypVINN/data_loader/dataset.py | 12 +++---- HypVINN/inference.py | 21 ++++++------ HypVINN/models/networks.py | 28 +++++++++------- HypVINN/run_prediction.py | 48 ++++++++++++++------------- HypVINN/utils/__init__.py | 2 +- HypVINN/utils/img_processing_utils.py | 4 +-- HypVINN/utils/mode_config.py | 6 ++-- HypVINN/utils/preproc.py | 5 +-- HypVINN/utils/stats_utils.py | 2 +- HypVINN/utils/visualization_utils.py | 31 ++++++++--------- 13 files changed, 88 insertions(+), 81 deletions(-) diff --git a/HypVINN/config/hypvinn.py b/HypVINN/config/hypvinn.py index 8df5807e..9af9c0f1 100644 --- a/HypVINN/config/hypvinn.py +++ b/HypVINN/config/hypvinn.py @@ -100,7 +100,8 @@ _C.MODEL.MULTI_SMOOTH = False # Brach weights can be aleatory set to zero _C.MODEL.HETERO_INPUT = False -# Flag for replicating any given modality into the two branches. This branch require that the hetero_input also set to TRUE +# Flag for replicating any given modality into the two branches. 
+# This branch require that the hetero_input also set to TRUE _C.MODEL.DUPLICATE_INPUT = False # ---------------------------------------------------------------------------- # # Training options diff --git a/HypVINN/config/hypvinn_files.py b/HypVINN/config/hypvinn_files.py index db444260..5e1555c8 100644 --- a/HypVINN/config/hypvinn_files.py +++ b/HypVINN/config/hypvinn_files.py @@ -15,7 +15,6 @@ # IMPORTS from FastSurferCNN.utils.checkpoint import FASTSURFER_ROOT - HYPVINN_LUT = FASTSURFER_ROOT / "HypVINN/config/HypVINN_ColorLUT.txt" HYPVINN_STATS_NAME = "hypothalamus.HypVINN.stats" diff --git a/HypVINN/data_loader/data_utils.py b/HypVINN/data_loader/data_utils.py index c7caee71..8bda9e0b 100644 --- a/HypVINN/data_loader/data_utils.py +++ b/HypVINN/data_loader/data_utils.py @@ -20,10 +20,12 @@ from FastSurferCNN.data_loader.conform import getscale, scalecrop from HypVINN.config.hypvinn_global_var import ( - hyposubseg_labels, SAG2FULL_MAP, HYPVINN_CLASS_NAMES, FS_CLASS_NAMES, + FS_CLASS_NAMES, + HYPVINN_CLASS_NAMES, + SAG2FULL_MAP, + hyposubseg_labels, ) - ## # Helper Functions ## diff --git a/HypVINN/data_loader/dataset.py b/HypVINN/data_loader/dataset.py index 776ddb56..01c22a84 100644 --- a/HypVINN/data_loader/dataset.py +++ b/HypVINN/data_loader/dataset.py @@ -13,15 +13,13 @@ # limitations under the License. import numpy as np -from numpy import typing as npt import torch +from numpy import typing as npt from torch.utils.data import Dataset - -from HypVINN.data_loader.data_utils import transform_axial2sagittal,transform_axial2coronal -from FastSurferCNN.data_loader.data_utils import get_thick_slices - import FastSurferCNN.utils.logging as logging +from FastSurferCNN.data_loader.data_utils import get_thick_slices +from HypVINN.data_loader.data_utils import transform_axial2coronal, transform_axial2sagittal from HypVINN.utils import ModalityDict, ModalityMode logger = logging.get_logger(__name__) @@ -121,8 +119,8 @@ def __init__( if ((cfg.MODEL.MULTI_AUTO_W or cfg.MODEL.MULTI_AUTO_W_CHANNELS) and (self.mode == 't1t2' or cfg.MODEL.DUPLICATE_INPUT)) : logger.info( - f"For inference T1 block weight and the T2 block are set to " - f"the weights learn during training" + "For inference T1 block weight and the T2 block are set to " + "the weights learn during training" ) else: logger.info( diff --git a/HypVINN/inference.py b/HypVINN/inference.py index 07953dae..806a2f29 100644 --- a/HypVINN/inference.py +++ b/HypVINN/inference.py @@ -15,19 +15,19 @@ from time import time from typing import Optional -import torch import numpy as np +import torch import yacs.config -from tqdm import tqdm from torch.utils.data import DataLoader from torchvision import transforms +from tqdm import tqdm import FastSurferCNN.utils.logging as logging -from FastSurferCNN.utils.common import find_device from FastSurferCNN.data_loader.augmentation import ToTensorTest, ZeroPad2DTest -from HypVINN.models.networks import build_model +from FastSurferCNN.utils.common import find_device from HypVINN.data_loader.data_utils import hypo_map_prediction_sagittal2full from HypVINN.data_loader.dataset import HypVINNDataset +from HypVINN.models.networks import build_model from HypVINN.utils import ModalityMode logger = logging.get_logger(__name__) @@ -180,7 +180,7 @@ def load_checkpoint(self, ckpt: str): The path to the checkpoint file. The checkpoint file should be a .pth file containing a state dictionary of a model. 
""" - logger.info("Loading checkpoint {}".format(ckpt)) + logger.info(f"Loading checkpoint {ckpt}") # WARNING: weights_only=False can cause unsafe code execution, but here the # checkpoint can be considered to be from a safe source model_state = torch.load(ckpt, map_location=self.device, weights_only=False) @@ -271,8 +271,8 @@ def get_max_size(self): Returns ------- int or tuple - The maximum size. If the width and height of the output tensor are equal, it returns the width. Otherwise, it - returns both the width and height. + The maximum size. If the width and height of the output tensor are equal, + it returns the width. Otherwise, it returns both the width and height. """ if self.cfg.MODEL.OUT_TENSOR_WIDTH == self.cfg.MODEL.OUT_TENSOR_HEIGHT: return self.cfg.MODEL.OUT_TENSOR_WIDTH @@ -292,7 +292,8 @@ def get_device(self): """ return self.device,self.viewagg_device - #TODO check is possible to modify to CerebNet inference mode from RAS directly to LIA (CerebNet.Inference._predict_single_subject) + #TODO check is possible to modify to CerebNet inference mode from RAS directly to LIA + # (CerebNet.Inference._predict_single_subject) @torch.no_grad() def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale: float = None) -> torch.Tensor: """ @@ -318,7 +319,7 @@ def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale: float self.model.eval() start_index = 0 - for batch_idx, batch in tqdm(enumerate(val_loader), total=len(val_loader)): + for _batch_idx, batch in tqdm(enumerate(val_loader), total=len(val_loader)): images = batch["image"].to(self.device) scale_factors = batch["scale_factor"].to(self.device) @@ -341,7 +342,7 @@ def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale: float pred_prob[start_index:start_index + pred.shape[0],:, :, :] += torch.mul(pred, 0.2) start_index += pred.shape[0] - logger.info("---> {} Model Testing Done.".format(self.cfg.DATA.PLANE)) + logger.info(f"---> {self.cfg.DATA.PLANE} Model Testing Done.") return pred_prob diff --git a/HypVINN/models/networks.py b/HypVINN/models/networks.py index ca2cfbfd..75f2a038 100644 --- a/HypVINN/models/networks.py +++ b/HypVINN/models/networks.py @@ -15,23 +15,23 @@ # IMPORTS -from typing import Dict -import yacs.config -from torch import Tensor, nn +import numpy as np import torch -import FastSurferCNN.models.sub_module as sm +import yacs.config +from torch import nn + import FastSurferCNN.models.interpolation_layer as il +import FastSurferCNN.models.sub_module as sm from FastSurferCNN.models.networks import FastSurferCNNBase -import numpy as np class HypVINN(FastSurferCNNBase): """ HypVINN class that extends the FastSurferCNNBase class. - This class represents a HypVINN model. It includes methods for initializing the model, setting up the layers, - and performing forward propagation. + This class represents a HypVINN model. It includes methods for initializing the model, + setting up the layers, and performing forward propagation. Attributes ---------- @@ -42,9 +42,11 @@ class HypVINN(FastSurferCNNBase): out_tensor_shape : tuple The shape of the output tensor. interpolation_mode : str - The interpolation mode to use when resizing the images. This can be 'nearest', 'bilinear', 'bicubic', or 'area'. + The interpolation mode to use when resizing the images. This can be 'nearest', 'bilinear', + 'bicubic', or 'area'. crop_position : str - The position to crop the images from. This can be 'center', 'top_left', 'top_right', 'bottom_left', or 'bottom_right'. 
+ The position to crop the images from. This can be 'center', 'top_left', 'top_right', + 'bottom_left', or 'bottom_right'. m1_inp_block : InputDenseBlock The input block for the first modality. m2_inp_block : InputDenseBlock @@ -89,7 +91,7 @@ def __init__(self, params, padded_size=256): params["num_channels"] = params["num_filters_interpol"] - super(HypVINN, self).__init__(params) + super().__init__(params) # Flex options self.height = params["height"] @@ -172,7 +174,8 @@ def forward(self, x: torch.Tensor, scale_factor: torch.Tensor, weight_factor: to weight_factor : torch.Tensor The weight factor for the two modalities. It should have a shape of (batch_size, 2). scale_factor_out : torch.Tensor, optional - The scale factor for the output images. If not provided, it defaults to the scale factor of the input images. + The scale factor for the output images. If not provided, it defaults to the scale factor + of the input images. Returns ------- @@ -206,7 +209,8 @@ def forward(self, x: torch.Tensor, scale_factor: torch.Tensor, weight_factor: to # Shared latent space skip_encoder_0 = mw1 * skip_encoder_01 + mw2 * skip_encoder_02 - encoder_output0, rescale_factor = self.interpol1(skip_encoder_0, scale_factor) # instead of maxpool = encoder_output_0 + # instead of maxpool = encoder_output_0: + encoder_output0, rescale_factor = self.interpol1(skip_encoder_0, scale_factor) # FastSurferCNN Base decoder_output1 = super().forward(encoder_output0, scale_factor=scale_factor) diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index d6945914..dfe487fd 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # IMPORTS -from typing import TYPE_CHECKING, Optional, cast, Literal import argparse from pathlib import Path from time import time +from typing import TYPE_CHECKING, Literal, cast import numpy as np -from numpy import typing as npt import torch +from numpy import typing as npt if TYPE_CHECKING: import yacs.config @@ -30,9 +30,8 @@ get_checkpoints, load_checkpoint_config_defaults, ) -from FastSurferCNN.utils.common import assert_no_root, SerialExecutor - -from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME, HYPVINN_MASK_NAME +from FastSurferCNN.utils.common import SerialExecutor, assert_no_root +from HypVINN.config.hypvinn_files import HYPVINN_MASK_NAME, HYPVINN_SEG_NAME from HypVINN.data_loader.data_utils import hypo_map_label2subseg, rescale_image from HypVINN.inference import Inference from HypVINN.utils import ModalityDict, ModalityMode, ViewOperations @@ -52,7 +51,7 @@ ## -def optional_path(a: Path | str) -> Optional[Path]: +def optional_path(a: Path | str) -> Path | None: """ Convert a string to a Path object or None. @@ -163,8 +162,8 @@ def option_parse() -> argparse.ArgumentParser: def main( out_dir: Path, - t2: Optional[Path], - orig_name: Optional[Path], + t2: Path | None, + orig_name: Path | None, sid: str, ckpt_ax: Path, ckpt_cor: Path, @@ -183,7 +182,7 @@ def main( device: str = "auto", viewagg_device: str = "auto", ) -> int | str: - f""" + """ Main function of the hypothalamus segmentation module. Parameters @@ -208,10 +207,10 @@ def main( The path to the coronal configuration file. cfg_sag : Path The path to the sagittal configuration file. - hypo_segfile : str, default="{HYPVINN_SEG_NAME}" - The name of the hypothalamus segmentation file. Default is {HYPVINN_SEG_NAME}. 
- hypo_maskfile : str, default="{HYPVINN_MASK_NAME}" - The name of the hypothalamus mask file. Default is {HYPVINN_MASK_NAME}. + hypo_segfile : str, default is in HYPVINN_SEG_NAME as specified in config. + The name of the hypothalamus segmentation file. Default is in HYPVINN_SEG_NAME. + hypo_maskfile : str, default is in HYPVINN_MASK_NAME + The name of the hypothalamus mask file. Default is in HYPVINN_MASK_NAME. allow_root : bool, default=False Whether to allow running as root user. Default is False. qc_snapshots : bool, optional @@ -236,7 +235,7 @@ def main( 0, if successful, an error message describing the cause for the failure otherwise. """ - from concurrent.futures import ProcessPoolExecutor, Future + from concurrent.futures import Future, ProcessPoolExecutor if threads != 1: pool = ProcessPoolExecutor(threads) else: @@ -301,7 +300,7 @@ def main( cfgs = (cfg_ax, cfg_cor, cfg_sag) ckpts = (ckpt_ax, ckpt_cor, ckpt_sag) - for plane, _cfg_file, _ckpt_file in zip(PLANES, cfgs, ckpts): + for plane, _cfg_file, _ckpt_file in zip(PLANES, cfgs, ckpts, strict=False): logger.info(f"{plane} model configuration from {_cfg_file}") view_ops[plane] = { "cfg": set_up_cfgs(_cfg_file, subject_dir, batch_size), @@ -374,7 +373,7 @@ def main( ) logger.info(f"Prediction successfully saved in {time_needed} seconds.") if qc_snapshots: - qc_future: Optional[Future] = pool.submit( + qc_future: Future | None = pool.submit( plot_qc_images, subject_qc_dir=subject_dir / "qc_snapshots", orig_path=orig_path, @@ -438,8 +437,8 @@ def prepare_checkpoints(ckpt_ax, ckpt_cor, ckpt_sag): def load_volumes( mode: ModalityMode, - t1_path: Optional[Path] = None, - t2_path: Optional[Path] = None, + t1_path: Path | None = None, + t2_path: Path | None = None, ) -> tuple[ ModalityDict, npt.NDArray[float], @@ -450,8 +449,9 @@ def load_volumes( """ Load the volumes of T1 and T2 images. - This function loads the T1 and T2 images, checks their compatibility based on the mode, and returns the loaded - volumes along with their affine transformations, headers, zoom levels, and sizes. + This function loads the T1 and T2 images, checks their compatibility based + on the mode, and returns the loaded volumes along with their affine + transformations, headers, zoom levels, and sizes. Parameters ---------- @@ -466,7 +466,8 @@ def load_volumes( ------- tuple A tuple containing the following elements: - - modalities: A dictionary with keys 't1' and/or 't2' and values being the corresponding loaded and rescaled images. + - modalities: A dictionary with keys 't1' and/or 't2' and values + being the corresponding loaded and rescaled images. - affine: The affine transformation of the loaded image(s). - header: The header of the loaded image(s). - zoom: The zoom level of the loaded image(s). @@ -475,7 +476,8 @@ def load_volumes( Raises ------ RuntimeError - If the mode is inconsistent with the provided image paths, or if the number of dimensions of the data is invalid. + If the mode is inconsistent with the provided image paths, + or if the number of dimensions of the data is invalid. ValueError If the mode is invalid, or if a header is missing. AssertionError @@ -489,7 +491,7 @@ def load_volumes( t1_zoom = () t2_zoom = () affine: npt.NDArray[float] = np.ndarray([0]) - header: Optional["FileBasedHeader"] = None + header: FileBasedHeader | None = None zoom: tuple[float, float, float] = (0.0, 0.0, 0.0) size: tuple[int, ...] 
= (0, 0, 0) diff --git a/HypVINN/utils/__init__.py b/HypVINN/utils/__init__.py index 2f2b06fe..476f7923 100644 --- a/HypVINN/utils/__init__.py +++ b/HypVINN/utils/__init__.py @@ -4,7 +4,7 @@ from FastSurferCNN.utils import Plane -ViewOperations = dict[Plane, Optional[dict[Literal["cfg", "ckpt"], Any]]] +ViewOperations = dict[Plane, dict[Literal["cfg", "ckpt"], Any] | None] ModalityMode = Literal["t1", "t2", "t1t2"] ModalityDict = dict[Literal["t1", "t2"], ndarray] RegistrationMode = Literal["robust", "coreg", "none"] diff --git a/HypVINN/utils/img_processing_utils.py b/HypVINN/utils/img_processing_utils.py index f96adc29..e6243c74 100644 --- a/HypVINN/utils/img_processing_utils.py +++ b/HypVINN/utils/img_processing_utils.py @@ -14,11 +14,11 @@ from pathlib import Path +import nibabel as nib import numpy as np from numpy import typing as npt -import nibabel as nib -from skimage.measure import label from scipy import ndimage +from skimage.measure import label import FastSurferCNN.utils.logging as logging from HypVINN.data_loader.data_utils import hypo_map_subseg_2_fsseg diff --git a/HypVINN/utils/mode_config.py b/HypVINN/utils/mode_config.py index 6f140265..e94980bd 100644 --- a/HypVINN/utils/mode_config.py +++ b/HypVINN/utils/mode_config.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from pathlib import Path -from typing import Optional from FastSurferCNN.utils import logging from HypVINN.utils import ModalityMode @@ -23,8 +21,8 @@ def get_hypinn_mode( - t1_path: Optional[Path] = None, - t2_path: Optional[Path] = None, + t1_path: Path | None = None, + t2_path: Path | None = None, ) -> ModalityMode: """ Determine the input mode for HypVINN based on the existence of T1 and T2 files. diff --git a/HypVINN/utils/preproc.py b/HypVINN/utils/preproc.py index c4dbbc19..55f9fccf 100644 --- a/HypVINN/utils/preproc.py +++ b/HypVINN/utils/preproc.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import time from pathlib import Path -import os from typing import cast import nibabel as nib @@ -65,9 +65,10 @@ def t1_to_t2_registration( If mri_coreg, mri_vol2vol, or mri_robust_register fails to run or if they cannot be found. """ + import shutil + from FastSurferCNN.utils.run_tools import Popen from FastSurferCNN.utils.threads import get_num_threads - import shutil if threads <= 0: threads = get_num_threads() diff --git a/HypVINN/utils/stats_utils.py b/HypVINN/utils/stats_utils.py index 5fee8ca5..792cce93 100644 --- a/HypVINN/utils/stats_utils.py +++ b/HypVINN/utils/stats_utils.py @@ -48,8 +48,8 @@ def compute_stats( """ from collections import namedtuple - from FastSurferCNN.utils.checkpoint import FASTSURFER_ROOT from FastSurferCNN.segstats import main + from FastSurferCNN.utils.checkpoint import FASTSURFER_ROOT from HypVINN.config.hypvinn_files import HYPVINN_STATS_NAME from HypVINN.config.hypvinn_global_var import FS_CLASS_NAMES diff --git a/HypVINN/utils/visualization_utils.py b/HypVINN/utils/visualization_utils.py index a43aaf9b..63fd9e79 100644 --- a/HypVINN/utils/visualization_utils.py +++ b/HypVINN/utils/visualization_utils.py @@ -11,17 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os.path from pathlib import Path -import numpy as np -import nibabel as nib import matplotlib.pyplot as plt +import nibabel as nib +import numpy as np +#from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT from HypVINN.config.hypvinn_files import HYPVINN_LUT -from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT -_doc_HYPVINN_LUT = os.path.relpath(HYPVINN_LUT, FASTSURFER_ROOT) +#_doc_HYPVINN_LUT = os.path.relpath(HYPVINN_LUT, FASTSURFER_ROOT) def remove_values_from_list(the_list, val): @@ -44,14 +43,14 @@ def remove_values_from_list(the_list, val): def get_lut(lookup_table_path: Path = HYPVINN_LUT): - f""" + """ Retrieve a color lookup table (LUT) from a file. This function reads a file and constructs a lookup table (LUT) from it. Parameters ---------- - lookup_table_path: Path, default="{_doc_HYPVINN_LUT}" + lookup_table_path: Path, defaults to local LUT" The path to the file from which the LUT will be constructed. Returns @@ -61,7 +60,7 @@ def get_lut(lookup_table_path: Path = HYPVINN_LUT): """ from collections import OrderedDict lut = OrderedDict() - with open(lookup_table_path, "r") as f: + with open(lookup_table_path) as f: for line in f: if line[0] == "#" or line[0] == "\n": pass @@ -73,14 +72,14 @@ def get_lut(lookup_table_path: Path = HYPVINN_LUT): def map_hyposeg2label(hyposeg: np.ndarray, lut_file: Path = HYPVINN_LUT): - f""" + """ Map a HypVINN segmentation to a continuous label space using a lookup table. Parameters ---------- hyposeg : np.ndarray The original segmentation map. - lut_file : Path, default="{_doc_HYPVINN_LUT}" + lut_file : Path, defaults to local LUT" The path to the lookup table file. Returns @@ -171,8 +170,10 @@ def plot_coronal_predictions(cmap, images_batch=None, pred_batch=None, img_per_r pred = torch.from_numpy(pred_batch.copy()) pred = torch.unsqueeze(pred, 1) pred_grid = utils.make_grid(pred.cpu(), nrow=img_per_row)[0] # dont take the channels axis from grid - # pred_grid=color.label2rgb(pred_grid.numpy(),grid.numpy().transpose(1, 2, 0),alpha=0.6,bg_label=0,colors=DEFAULT_COLORS) - # pred_grid = color.label2rgb(pred_grid.numpy(), grid.numpy().transpose(1, 2, 0), alpha=0.6, bg_label=0,bg_color=None,colors=DEFAULT_COLORS) + # pred_grid=color.label2rgb(pred_grid.numpy(),grid.numpy().transpose(1, 2, 0), \ + # alpha=0.6,bg_label=0,colors=DEFAULT_COLORS) + # pred_grid = color.label2rgb(pred_grid.numpy(), grid.numpy().transpose(1, 2, 0), \ + # alpha=0.6, bg_label=0,bg_color=None,colors=DEFAULT_COLORS) alphas = np.ones(pred_grid.numpy().shape) * 0.8 alphas[pred_grid.numpy() == 0] = 0 @@ -245,7 +246,7 @@ def plot_qc_images( padd: int = 45, lut_file: Path = HYPVINN_LUT, slice_step: int = 2): - f""" + """ Plot the quality control images for the subject. Parameters @@ -258,15 +259,15 @@ def plot_qc_images( The path to the predicted image. padd : int, default=45 The padding value for cropping the images and segmentations. - lut_file : Path, default="{_doc_HYPVINN_LUT}" + lut_file : Path, defaults to local LUT" The path to the lookup table file. slice_step : int, default=2 The step size for selecting indices from the predicted segmentation. 
""" from scipy import ndimage - from HypVINN.data_loader.data_utils import transform_axial2coronal, hypo_map_subseg_2_fsseg from HypVINN.config.hypvinn_files import HYPVINN_QC_IMAGE_NAME + from HypVINN.data_loader.data_utils import hypo_map_subseg_2_fsseg, transform_axial2coronal subject_qc_dir.mkdir(exist_ok=True, parents=True) From 4bba70604c9bacb669a6bc0b4f042789f000a57f Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 14:34:08 +0200 Subject: [PATCH 07/31] include sort, fixes, formatting --- CerebNet/apply_warp.py | 6 ++-- CerebNet/config/__init__.py | 1 + CerebNet/data_loader/augmentation.py | 18 +++++------ CerebNet/data_loader/data_utils.py | 2 +- CerebNet/data_loader/dataset.py | 35 +++++++++----------- CerebNet/data_loader/loader.py | 5 ++- CerebNet/datasets/generate_hdf5.py | 18 ++++------- CerebNet/datasets/load_data.py | 6 ++-- CerebNet/datasets/utils.py | 46 +++++++++++++------------- CerebNet/datasets/wm_merge_clean.py | 13 ++++---- CerebNet/inference.py | 48 ++++++++++++++-------------- CerebNet/models/networks.py | 7 ++-- CerebNet/models/sub_module.py | 18 +++++------ CerebNet/run_prediction.py | 14 ++++---- CerebNet/utils/checkpoint.py | 10 ++---- CerebNet/utils/load_config.py | 6 ++-- CerebNet/utils/lr_scheduler.py | 19 ++++++----- CerebNet/utils/meters.py | 23 ++++--------- CerebNet/utils/metrics.py | 6 ++-- CerebNet/utils/misc.py | 16 +++++----- 20 files changed, 144 insertions(+), 173 deletions(-) diff --git a/CerebNet/apply_warp.py b/CerebNet/apply_warp.py index fb5fbb85..22f209cc 100644 --- a/CerebNet/apply_warp.py +++ b/CerebNet/apply_warp.py @@ -1,4 +1,7 @@ import argparse +from os.path import join + +import nibabel as nib # Copyright 2022 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn # @@ -13,12 +16,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - # IMPORTS import numpy as np -import nibabel as nib -from os.path import join from CerebNet.datasets import utils diff --git a/CerebNet/config/__init__.py b/CerebNet/config/__init__.py index 926ad07f..38198f41 100644 --- a/CerebNet/config/__init__.py +++ b/CerebNet/config/__init__.py @@ -15,6 +15,7 @@ # IMPORTS from CerebNet.config.cerebnet import get_cfg_cerebnet from CerebNet.config.dataset import get_cfg_dataset + __all__ = [ "cerebnet", "dataset", diff --git a/CerebNet/data_loader/augmentation.py b/CerebNet/data_loader/augmentation.py index c67b2fde..67df0876 100644 --- a/CerebNet/data_loader/augmentation.py +++ b/CerebNet/data_loader/augmentation.py @@ -16,21 +16,19 @@ # IMPORTS import numbers import random -from typing import Optional import numpy as np import torch from numpy import random as npr -from scipy.ndimage import gaussian_filter, affine_transform +from scipy.ndimage import affine_transform, gaussian_filter from scipy.stats import median_abs_deviation from torchvision import transforms from CerebNet.data_loader.data_utils import FLIPPED_LABELS - # Transformations for training -class ToTensor(object): +class ToTensor: """ Convert ndarrays in sample to Tensors. 
""" @@ -66,7 +64,7 @@ def _apply_img(self, img): return super()._apply_img(img.transpose((2, 0, 1))) -class RandomAffine(object): +class RandomAffine: """ Apply a random affine transformation to images, label and weight @@ -99,7 +97,7 @@ def _get_random_affine(self): degrees = (-self.degree, self.degree) else: assert ( - isinstance(self.degree, (tuple, list)) and len(self.degree) == 2 + isinstance(self.degree, tuple | list) and len(self.degree) == 2 ), "degrees should be a list or tuple and it must be of length 2." if isinstance(self.translate, numbers.Number): if not (0.0 <= self.translate <= 1.0): @@ -107,7 +105,7 @@ def _get_random_affine(self): translate = (self.translate, self.translate) else: assert ( - isinstance(self.translate, (tuple, list)) and len(self.translate) == 2 + isinstance(self.translate, tuple | list) and len(self.translate) == 2 ), "translate should be a list or tuple and it must be of length 2." for t in self.translate: if not (0.0 <= t <= 1.0): @@ -159,7 +157,7 @@ def __call__(self, sample): return sample -class RandomFlip(object): +class RandomFlip: """ Random horizontal flipping. """ @@ -196,7 +194,7 @@ class RandomBiasField: def __init__( self, cfg, - seed: Optional[int] = None, + seed: int | None = None, ): """ Initialize the RandomBiasField object with configuration and optional seed. @@ -287,7 +285,7 @@ def __call__(self, sample): return sample -class RandomLabelsToImage(object): +class RandomLabelsToImage: """ Generate image from segmentation using the dataset intensity priors. diff --git a/CerebNet/data_loader/data_utils.py b/CerebNet/data_loader/data_utils.py index f8cf8915..3c5c57e1 100644 --- a/CerebNet/data_loader/data_utils.py +++ b/CerebNet/data_loader/data_utils.py @@ -402,7 +402,7 @@ def uncrop_volume(vol, uncrop_shape, roi): def get_binary_map(lbl_map, class_names): - bin_map = np.logical_or.reduce(list(map(lambda l: lbl_map == l, class_names))) + bin_map = np.logical_or.reduce(list(map(lambda lb: lbl_map == lb, class_names))) return bin_map diff --git a/CerebNet/data_loader/dataset.py b/CerebNet/data_loader/dataset.py index f83a61af..8f195d60 100644 --- a/CerebNet/data_loader/dataset.py +++ b/CerebNet/data_loader/dataset.py @@ -13,31 +13,30 @@ # limitations under the License. 
# IMPORTS -from typing import Tuple, Literal, TypeVar, Dict from numbers import Number +from typing import Literal, TypeVar +import h5py import nibabel as nib -import torch import numpy as np +import torch from numpy import typing as npt -import h5py from torch.utils.data.dataset import Dataset from torchvision.transforms import Compose -from FastSurferCNN.utils import logging, Plane +from CerebNet.data_loader import data_utils as utils +from CerebNet.data_loader.augmentation import ToTensor +from CerebNet.datasets.load_data import SubjectLoader +from CerebNet.datasets.utils import bounding_volume_offset, crop_transform from FastSurferCNN.data_loader.data_utils import ( get_thick_slices, transform_axial, transform_sagittal, ) - -from CerebNet.data_loader import data_utils as utils -from CerebNet.data_loader.augmentation import ToTensor -from CerebNet.datasets.load_data import SubjectLoader -from CerebNet.datasets.utils import crop_transform, bounding_volume_offset +from FastSurferCNN.utils import Plane, logging ROIKeys = Literal["source_shape", "offsets", "target_shape"] -LocalizerROI = Dict[ROIKeys, Tuple[int, ...]] +LocalizerROI = dict[ROIKeys, tuple[int, ...]] NT = TypeVar("NT", bound=Number) @@ -118,11 +117,7 @@ def __init__(self, dataset_path, cfg, transforms, load_aux_data): del self.dataset["subject"] logger.info( - "Successfully loaded {} slices in {} plane from {}".format( - self.count, - cfg.DATA.PLANE, - dataset_path, - ) + f"Successfully loaded {self.count} slices in {cfg.DATA.PLANE} plane from {dataset_path}" ) logger.info( @@ -242,7 +237,7 @@ def __init__( self, img_org: nib.analyze.SpatialImage, brain_seg: nib.analyze.SpatialImage, - patch_size: Tuple[int, ...], + patch_size: tuple[int, ...], slice_thickness: int, primary_slice: str, ): @@ -298,10 +293,10 @@ def __init__( "coronal": img, "sagittal": transform_sagittal(img), } - for plane, data in data.items(): + for plane, data_i in data.items(): # data is transformed to 'plane'-direction in axis 2 thick_slices = get_thick_slices( - data, self.slice_thickness + data_i, self.slice_thickness ) # [H, W, n_slices, C] # it seems x and y are flipped with respect to expectations here self.images_per_plane[plane] = np.transpose( @@ -315,7 +310,7 @@ def locate_mask_bbox(self, mask: npt.NDArray[bool]): bbox of min0, min1, ..., max0, max1, ... """ # filter disconnected components - from skimage.measure import regionprops, label + from skimage.measure import label, regionprops label_image = label(mask, connectivity=3) regions = regionprops(label_image) @@ -341,7 +336,7 @@ def plane(self) -> Plane: """Returns the active plane""" return self._plane - def __getitem__(self, index: int) -> Tuple[Plane, np.ndarray]: + def __getitem__(self, index: int) -> tuple[Plane, np.ndarray]: """Get the plane and data belonging to indices given.""" if not (0 <= index < self.images_per_plane[self.plane].shape[0]): diff --git a/CerebNet/data_loader/loader.py b/CerebNet/data_loader/loader.py index 58bf8e4d..cd120e56 100644 --- a/CerebNet/data_loader/loader.py +++ b/CerebNet/data_loader/loader.py @@ -13,13 +13,12 @@ # limitations under the License. 
# IMPORTS -from torchvision import transforms from torch.utils.data import DataLoader - -from FastSurferCNN.utils import logging +from torchvision import transforms from CerebNet.data_loader import dataset as dset from CerebNet.data_loader.augmentation import ToTensor, get_transform +from FastSurferCNN.utils import logging logger = logging.get_logger(__name__) diff --git a/CerebNet/datasets/generate_hdf5.py b/CerebNet/datasets/generate_hdf5.py index 8566caed..dde5f9d2 100644 --- a/CerebNet/datasets/generate_hdf5.py +++ b/CerebNet/datasets/generate_hdf5.py @@ -14,10 +14,10 @@ # IMPORTS +import time import warnings from collections import defaultdict -from os.path import join, isfile -import time +from os.path import isfile, join import h5py import numpy as np @@ -65,7 +65,7 @@ def _read_warp_dict(self): """ subj2warps = defaultdict(list) all_imgs = [] - with open(join(self.cfg.REG_DATA_DIR, self.cfg.REG_DATA_CSV), "r") as f: + with open(join(self.cfg.REG_DATA_DIR, self.cfg.REG_DATA_CSV)) as f: for line in f.readlines(): line = line.strip() ids = line.split(",") @@ -84,7 +84,7 @@ def _read_warp_dict(self): all_imgs.append(img_path) subj2warps[i].append((img_path, lbl_path)) else: - warnings.warn(f"Warp field at {img_path} not found.") + warnings.warn(f"Warp field at {img_path} not found.", stacklevel=2) return subj2warps def create_hdf5_dataset( @@ -123,9 +123,7 @@ def create_hdf5_dataset( # try: start = time.time() print( - "Volume Nr: {}/{} Processing MRI Data from {}".format( - idx + 1, len(subjects_list), current_subject - ) + f"Volume Nr: {idx + 1}/{len(subjects_list)} Processing MRI Data from {current_subject}" ) in_data = self.subj_loader.load_subject( @@ -145,9 +143,7 @@ def create_hdf5_dataset( end = time.time() - start print("Number of Cerebellum classes", len(np.unique(in_data["label"]))) print( - "Volume: {} Finished Data Reading and Appending in {:.3f} seconds.".format( - idx + 1, end - ) + f"Volume: {idx + 1} Finished Data Reading and Appending in {end:.3f} seconds." 
) # except Exception as e: @@ -156,4 +152,4 @@ def create_hdf5_dataset( self._save_hdf5_file(datasets, dataset_name) end_d = time.time() - start_d - print("Successfully written {} in {:.3f} seconds.".format(dataset_name, end_d)) + print(f"Successfully written {dataset_name} in {end_d:.3f} seconds.") diff --git a/CerebNet/datasets/load_data.py b/CerebNet/datasets/load_data.py index 500362fb..d0a26dfb 100644 --- a/CerebNet/datasets/load_data.py +++ b/CerebNet/datasets/load_data.py @@ -14,8 +14,8 @@ # IMPORTS -from os.path import join, isfile from functools import partial +from os.path import isfile, join import numpy as np @@ -91,9 +91,9 @@ def _load_volumes(self, subject_path, store_talairach=False): img_meta_data = {} orig, _ = utils.load_reorient_rescale_image(orig_path) - print("Orig image {}".format(orig_path)) + print(f"Orig image {orig_path}") - print("Loading from {}".format(subseg_path)) + print(f"Loading from {subseg_path}") subseg_file = utils.load_reorient(subseg_path) cereb_subseg = np.asarray(subseg_file.get_fdata(), dtype=np.int16) img_meta_data["affine"] = subseg_file.affine diff --git a/CerebNet/datasets/utils.py b/CerebNet/datasets/utils.py index b5a93e05..882610bf 100644 --- a/CerebNet/datasets/utils.py +++ b/CerebNet/datasets/utils.py @@ -14,13 +14,14 @@ # IMPORTS -from typing import Tuple, Union, Sequence, Optional, TypeVar, TypedDict, Iterable, Type +from collections.abc import Sequence from pathlib import Path +from typing import TypedDict, TypeVar import nibabel as nib import numpy as np -from numpy import typing as npt import torch +from numpy import typing as npt from FastSurferCNN.data_loader.conform import getscale, scalecrop @@ -112,7 +113,7 @@ def map_size(arr, base_shape, return_border=False): _pad = [] _unpad_borders = [] - for i, j in zip(arr.shape, base_shape): + for i, j in zip(arr.shape, base_shape, strict=False): delta = i - j left = delta // 2 if delta > 0: # crop @@ -186,10 +187,10 @@ def map_size_leg(arr, base_shape, return_border=False): def bounding_volume_offset( - img: Union[np.ndarray, Sequence[int]], - target_img_size: Tuple[int, ...], - image_shape: Optional[Tuple[int, ...]] = None, -) -> Tuple[int, ...]: + img: np.ndarray | Sequence[int], + target_img_size: tuple[int, ...], + image_shape: tuple[int, ...] 
| None = None, +) -> tuple[int, ...]: """Find the center of the non-zero values in img and returns offsets so this center is in the center of a bounding volume of size target_img_size.""" if isinstance(img, np.ndarray): @@ -201,10 +202,10 @@ def bounding_volume_offset( bbox = img center = ( (_max + _min) / 2 - for _min, _max in zip(bbox[: len(bbox) // 2], bbox[len(bbox) // 2 :]) + for _min, _max in zip(bbox[: len(bbox) // 2], bbox[len(bbox) // 2 :], strict=False) ) offset = tuple( - max(0, int(round(c - ts / 2))) for c, ts in zip(center, target_img_size) + max(0, int(round(c - ts / 2))) for c, ts in zip(center, target_img_size, strict=False) ) img_shape = ( image_shape @@ -216,7 +217,7 @@ def bounding_volume_offset( if img_shape is not None: offset = tuple( min(max(0, o), imgs - ts) - for o, ts, imgs in zip(offset, target_img_size, img_shape) + for o, ts, imgs in zip(offset, target_img_size, img_shape, strict=False) ) if any(o < 0 for o in offset): raise RuntimeError( @@ -357,15 +358,16 @@ def read_lta(file: Path | str) -> LTADict: """Read the LTA info.""" import re from functools import partial + import numpy as np parameter_pattern = re.compile("^\s*([^=]+)\s*=\s*([^#]*)\s*(#.*)") vol_info_pattern = re.compile("^(.*) volume info$") shape_pattern = re.compile("^(\s*\d+)+$") matrix_pattern = re.compile("^(-?\d+\.\S+\s+)+$") - _Type = TypeVar("_Type", bound=Type) + _Type = TypeVar("_Type", bound=type) - def _vector(_a: str, dtype: Type[_Type] = float, count: int = -1) -> list[_Type]: + def _vector(_a: str, dtype: type[_Type] = float, count: int = -1) -> list[_Type]: return np.fromstring(_a, dtype=dtype, count=count, sep=" ").tolist() parameters = { @@ -384,7 +386,7 @@ def _vector(_a: str, dtype: Type[_Type] = float, count: int = -1) -> list[_Type] **{f"{c}ras": partial(_vector, dtype=float) for c in "xyzc"} } - with open(file, "r") as f: + with open(file) as f: lines = f.readlines() items = [] @@ -415,9 +417,9 @@ def _vector(_a: str, dtype: Type[_Type] = float, count: int = -1) -> list[_Type] shape_lines = list(map(tuple, shape_lines)) lta = dict(items) if lta["nxforms"] != len(shape_lines): - raise IOError("Inconsistent lta format: nxforms inconsistent with shapes.") + raise OSError("Inconsistent lta format: nxforms inconsistent with shapes.") if len(shape_lines) > 1 and np.any(np.not_equal([shape_lines[0]], shape_lines[1:])): - raise IOError(f"Inconsistent lta format: shapes inconsistent {shape_lines}") + raise OSError(f"Inconsistent lta format: shapes inconsistent {shape_lines}") lta_matrix = np.asarray(matrix_lines).reshape((-1,) + shape_lines[0].shape) lta["lta"] = lta_matrix return lta @@ -486,7 +488,7 @@ def _crop_transform_make_indices(image_shape, offsets, target_shape): paddings = [] any_pad = False for offset, t_shape, i_shape in zip( - offsets, target_shape, image_shape[batch_dims:] + offsets, target_shape, image_shape[batch_dims:], strict=False ): crop_end = min(offset + t_shape, i_shape) indices.append(slice(max(0, offset), crop_end)) @@ -544,9 +546,9 @@ def _crop_transform_pad_fn(image, pad_tuples, pad): def crop_transform( image: AT, - offsets: Optional[Sequence[int]] = None, - target_shape: Optional[Sequence[int]] = None, - out: Optional[AT] = None, + offsets: Sequence[int] | None = None, + target_shape: Sequence[int] | None = None, + out: AT | None = None, pad: int = 0, ) -> AT: """ @@ -609,13 +611,13 @@ def crop_transform( if target_shape is None: raise ValueError("Either target_shape or offsets must be defined!") _target_shape = image.shape[: -len(target_shape)] + 
tuple(target_shape) - offsets = tuple(int((i - t) / 2) for t, i in zip(_target_shape, image.shape)) + offsets = tuple(int((i - t) / 2) for t, i in zip(_target_shape, image.shape, strict=False)) len_off = len(offsets) else: len_off = len(offsets) if target_shape is None: _target_shape = image.shape[:-len_off] + tuple( - i - 2 * o for i, o in zip(image.shape[-len_off:], offsets) + i - 2 * o for i, o in zip(image.shape[-len_off:], offsets, strict=False) ) elif len_off != len(target_shape): raise ValueError( @@ -624,7 +626,7 @@ def crop_transform( else: _target_shape = tuple( i if t == -1 else t - for i, t in zip(image.shape[-len_off:], target_shape) + for i, t in zip(image.shape[-len_off:], target_shape, strict=False) ) _target_shape = image.shape[:-len_off] + _target_shape diff --git a/CerebNet/datasets/wm_merge_clean.py b/CerebNet/datasets/wm_merge_clean.py index 4024e78a..766280c5 100644 --- a/CerebNet/datasets/wm_merge_clean.py +++ b/CerebNet/datasets/wm_merge_clean.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable from numbers import Number # Copyright 2022 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn @@ -13,18 +14,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - # IMPORTS from os.path import join -from typing import TypeVar, Iterable +from typing import TypeVar +import nibabel as nib import numpy as np from numpy import typing as npt -import nibabel as nib from scipy import ndimage +from skimage.measure import label, regionprops from skimage.morphology import binary_dilation -from skimage.measure import regionprops, label NT = TypeVar("NT", bound=Number) @@ -124,7 +123,7 @@ def add_cereb_wm(cereb_subseg, aseg, manual_cereb): :return: """ # to capture small holes - struc = ndimage.generate_binary_structure(3, 1) + # struc = ndimage.generate_binary_structure(3, 1) l_wm_fs = aseg == 7 r_wm_fs = aseg == 46 @@ -244,7 +243,7 @@ def add_cereb_wm(cereb_subseg, aseg, manual_cereb): ) print("4.Filling any remaining holes") - filled_bin_mask = ndimage.binary_fill_holes(dropped_comp_img != 0, structure=struc) + #filled_bin_mask = ndimage.binary_fill_holes(dropped_comp_img != 0, structure=struc) remaining_holes_mask = ( dropped_comp_img != 0 ) # np.logical_xor(filled_bin_mask, dropped_comp_img != 0) diff --git a/CerebNet/inference.py b/CerebNet/inference.py index d597cd95..9ffb0f16 100644 --- a/CerebNet/inference.py +++ b/CerebNet/inference.py @@ -14,30 +14,31 @@ # IMPORTS import time -from pathlib import Path -from typing import Dict, List, Tuple, Optional, TYPE_CHECKING from concurrent.futures import Future, ThreadPoolExecutor +from pathlib import Path +from typing import TYPE_CHECKING import nibabel as nib import numpy as np +import pandas as pd import torch from torch.utils.data import DataLoader from tqdm import tqdm -from FastSurferCNN.utils import logging, Plane, PLANES -from FastSurferCNN.utils.threads import get_num_threads -from FastSurferCNN.utils.mapper import JsonColorLookupTable, TSVLookupTable, Mapper -from FastSurferCNN.utils.common import ( - find_device, - SubjectList, - SubjectDirectory, - SerialExecutor, -) from CerebNet.data_loader.augmentation import ToTensorTest from CerebNet.data_loader.dataset import SubjectDataset from CerebNet.datasets.utils import crop_transform from CerebNet.models.networks import build_model from CerebNet.utils import checkpoint as cp +from FastSurferCNN.utils import PLANES, 
Plane, logging +from FastSurferCNN.utils.common import ( + SerialExecutor, + SubjectDirectory, + SubjectList, + find_device, +) +from FastSurferCNN.utils.mapper import JsonColorLookupTable, Mapper, TSVLookupTable +from FastSurferCNN.utils.threads import get_num_threads if TYPE_CHECKING: import yacs.config @@ -159,7 +160,7 @@ def __del__(self): if self.pool is not None: self.pool.shutdown(True) - def _load_model(self, cfg) -> Dict[Plane, torch.nn.Module]: + def _load_model(self, cfg) -> dict[Plane, torch.nn.Module]: """Loads the three models per plane.""" def __load_model(cfg: "yacs.config.CfgNode", plane: Plane) -> torch.nn.Module: @@ -181,12 +182,12 @@ def __load_model(cfg: "yacs.config.CfgNode", plane: Plane) -> torch.nn.Module: from functools import partial _load_model_func = partial(__load_model, cfg) - return dict(zip(PLANES, self.pool.map(_load_model_func, PLANES))) + return dict(zip(PLANES, self.pool.map(_load_model_func, PLANES), strict=False)) @torch.no_grad() def _predict_single_subject( self, subject_dataset: SubjectDataset - ) -> Dict[Plane, List[torch.Tensor]]: + ) -> dict[Plane, list[torch.Tensor]]: """Predict the classes based on a SubjectDataset.""" img_loader = DataLoader( subject_dataset, batch_size=self.batch_size, shuffle=False @@ -219,8 +220,8 @@ def _predict_single_subject( return prediction_logits def _post_process_preds( - self, preds: Dict[Plane, List[torch.Tensor]] - ) -> Dict[Plane, torch.Tensor]: + self, preds: dict[Plane, list[torch.Tensor]] + ) -> dict[Plane, torch.Tensor]: """ Permutes axes, so it has consistent sagittal, coronal, axial, channels format. Also maps classes of sagittal predictions into the global label space. @@ -251,9 +252,10 @@ def _convert(plane: Plane) -> torch.Tensor: return {plane: _convert(plane) for plane in preds.keys()} - def _view_aggregation(self, logits: Dict[Plane, torch.Tensor]) -> torch.Tensor: + def _view_aggregation(self, logits: dict[Plane, torch.Tensor]) -> torch.Tensor: """ - Aggregate the view (axial, coronal, sagittal) into one volume and get the class of the largest probability. (argmax) + Aggregate the view (axial, coronal, sagittal) into one volume and get the + class of the largest probability. 
(argmax) Args: logits: dictionary of per plane predicted logits (axial, coronal, sagittal) @@ -269,12 +271,12 @@ def _view_aggregation(self, logits: Dict[Plane, torch.Tensor]) -> torch.Tensor: def _calc_segstats( self, seg_data: np.ndarray, norm_data: np.ndarray, vox_vol: float - ) -> "pandas.DataFrame": + ) -> "pd.DataFrame": """ Computes volume and volume similarity """ - def _get_ids_startswith(_label_map: Dict[int, str], prefix: str) -> List[int]: + def _get_ids_startswith(_label_map: dict[int, str], prefix: str) -> list[int]: return [ id for id, name in _label_map.items() @@ -301,7 +303,7 @@ def _get_ids_startswith(_label_map: Dict[int, str], prefix: str) -> List[int]: seg_data, norm_data, norm_data, - list(filter(lambda l: l != 0, label_map.keys())), + list(filter(lambda lb: lb != 0, label_map.keys())), vox_vol=vox_vol, threads=self.threads, patch_size=32, @@ -319,8 +321,6 @@ def _get_ids_startswith(_label_map: Dict[int, str], prefix: str) -> List[int]: # noinspection PyTypeChecker table[i]["StructName"] = "Merged-Label-" + str(_id) - import pandas as pd - dataframe = pd.DataFrame(table, index=np.arange(len(table))) dataframe = dataframe[dataframe["NVoxels"] != 0].sort_values("SegId") dataframe.index = np.arange(1, len(dataframe) + 1) @@ -361,7 +361,7 @@ def _save_cerebnet_seg( def _get_subject_dataset( self, subject: SubjectDirectory - ) -> Tuple[Optional[np.ndarray], Optional[Path], SubjectDataset]: + ) -> tuple[np.ndarray | None, Path | None, SubjectDataset]: """ Load and prepare input files asynchronously, then locate the cerebellum and provide a localized patch. diff --git a/CerebNet/models/networks.py b/CerebNet/models/networks.py index fb8af95f..d80c3061 100644 --- a/CerebNet/models/networks.py +++ b/CerebNet/models/networks.py @@ -14,14 +14,13 @@ # IMPORTS -from typing import Mapping +from collections.abc import Mapping import torch import torch.nn as nn -from FastSurferCNN.utils import logging from CerebNet.models import sub_module as sm - +from FastSurferCNN.utils import logging logger = logging.get_logger(__name__) @@ -42,7 +41,7 @@ def __init__(self, params): """ Create the FastSurferCNN model. """ - super(FastSurferCNN, self).__init__() + super().__init__() # Parameters for the Descending Arm self.encode1 = sm.CompetitiveEncoderBlockInput(params) diff --git a/CerebNet/models/sub_module.py b/CerebNet/models/sub_module.py index 3833fae2..864803c2 100644 --- a/CerebNet/models/sub_module.py +++ b/CerebNet/models/sub_module.py @@ -53,7 +53,7 @@ def __init__(self, params, outblock=False, discriminator_block=False): discriminator_block : bool, default=False Flag indicating if the block is discriminator block or not. """ - super(CompetitiveDenseBlock, self).__init__() + super().__init__() # Padding to get output tensor of same dimensions padding_h = int((params["kernel_h"] - 1) / 2) @@ -177,7 +177,7 @@ def __init__(self, params): params : dict Dictionary with parameters specifying block architecture. """ - super(CompetitiveDenseBlockInput, self).__init__() + super().__init__() # Padding to get output tensor of same dimensions padding_h = int((params["kernel_h"] - 1) / 2) @@ -276,7 +276,7 @@ def __init__(self, params): params : dict Parameters like number of channels, stride etc. """ - super(CompetitiveEncoderBlock, self).__init__(params) + super().__init__(params) self.maxpool = nn.MaxPool2d( kernel_size=params["pool"], stride=params["stride_pool"], @@ -303,7 +303,7 @@ def forward(self, x): indicies : Tensor Maxpool indices. 
""" - out_block = super(CompetitiveEncoderBlock, self).forward( + out_block = super().forward( x ) # To be concatenated as Skip Connection out_encoder, indices = self.maxpool( @@ -326,7 +326,7 @@ def __init__(self, params): params : dict Parameters like number of channels, stride etc. """ - super(CompetitiveEncoderBlockInput, self).__init__( + super().__init__( params ) # The init of CompetitiveDenseBlock takes in params self.maxpool = nn.MaxPool2d( @@ -355,7 +355,7 @@ def forward(self, x): Tensor The indices of the maxpool operation. """ - out_block = super(CompetitiveEncoderBlockInput, self).forward( + out_block = super().forward( x ) # To be concatenated as Skip Connection out_encoder, indices = self.maxpool( @@ -381,7 +381,7 @@ def __init__(self, params, outblock=False): Flag, indicating if last block of network before classifier is created. """ - super(CompetitiveDecoderBlock, self).__init__(params, outblock=outblock) + super().__init__(params, outblock=outblock) self.unpool = nn.MaxUnpool2d( kernel_size=params["pool"], stride=params["stride_pool"] ) @@ -413,7 +413,7 @@ def forward(self, x, out_block, indices): out_block = torch.unsqueeze(out_block, 4) concat = torch.cat((unpool, out_block), dim=4) # Competitive Concatenation concat_max, _ = torch.max(concat, 4) - out_block = super(CompetitiveDecoderBlock, self).forward(concat_max) + out_block = super().forward(concat_max) return out_block @@ -428,7 +428,7 @@ def __init__(self, params): Classifier Block initialization :param dict params: parameters like number of channels, stride etc. """ - super(ClassifierBlock, self).__init__() + super().__init__() self.conv = nn.Conv2d( params["num_channels"], params["num_classes"], diff --git a/CerebNet/run_prediction.py b/CerebNet/run_prediction.py index 410adcd5..07029271 100644 --- a/CerebNet/run_prediction.py +++ b/CerebNet/run_prediction.py @@ -15,19 +15,19 @@ # limitations under the License. 
# IMPORTS -import sys import argparse +import sys from pathlib import Path -from FastSurferCNN.utils import logging, parser_defaults, Plane, PLANES +from CerebNet.inference import Inference +from CerebNet.utils.checkpoint import YAML_DEFAULT as CHECKPOINT_PATHS_FILE +from CerebNet.utils.load_config import get_config +from FastSurferCNN.utils import PLANES, Plane, logging, parser_defaults from FastSurferCNN.utils.checkpoint import ( get_checkpoints, load_checkpoint_config_defaults, ) -from FastSurferCNN.utils.common import assert_no_root, SubjectList -from CerebNet.inference import Inference -from CerebNet.utils.checkpoint import YAML_DEFAULT as CHECKPOINT_PATHS_FILE -from CerebNet.utils.load_config import get_config +from FastSurferCNN.utils.common import SubjectList, assert_no_root logger = logging.get_logger(__name__) DEFAULT_CEREBELLUM_STATSFILE = Path("stats/cerebellum.CerebNet.stats") @@ -138,7 +138,7 @@ def main(args: argparse.Namespace) -> int | str: # Set up logging from FastSurferCNN.utils.logging import setup_logging - setup_logging(getattr(args, "log_name")) + setup_logging(args.log_name) subjects_kwargs = {} cereb_statsfile = getattr(args, "cereb_statsfile", None) diff --git a/CerebNet/utils/checkpoint.py b/CerebNet/utils/checkpoint.py index e8cdf1ca..2c83469d 100644 --- a/CerebNet/utils/checkpoint.py +++ b/CerebNet/utils/checkpoint.py @@ -15,18 +15,12 @@ # IMPORTS from typing import TYPE_CHECKING + if TYPE_CHECKING: import yacs -from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT from FastSurferCNN.utils import logging -from FastSurferCNN.utils.checkpoint import ( - load_from_checkpoint, - create_checkpoint_dir, - get_checkpoint, - get_checkpoint_path, - save_checkpoint, -) +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT # DEFAULTS YAML_DEFAULT = FASTSURFER_ROOT / "CerebNet/config/checkpoint_paths.yaml" diff --git a/CerebNet/utils/load_config.py b/CerebNet/utils/load_config.py index ac54f0ba..4041b406 100644 --- a/CerebNet/utils/load_config.py +++ b/CerebNet/utils/load_config.py @@ -39,11 +39,11 @@ def get_config(args) -> "yacs.config.CfgNode": if hasattr(args, "out_dir"): cfg.LOG_DIR = str(args.out_dir) - path_ax, path_sag, path_cor = [ + path_ax, path_sag, path_cor = ( getattr(args, name) for name in ["ckpt_ax", "ckpt_sag", "ckpt_cor"] - ] + ) - for plane, path in zip(PLANES, (path_ax, path_cor, path_sag)): + for plane, path in zip(PLANES, (path_ax, path_cor, path_sag), strict=False): setattr(cfg.TEST, f"{plane.upper()}_CHECKPOINT_PATH", str(path)) # overwrite the batch size if it is passed as a parameter diff --git a/CerebNet/utils/lr_scheduler.py b/CerebNet/utils/lr_scheduler.py index 5e3eab3d..f7b67a72 100644 --- a/CerebNet/utils/lr_scheduler.py +++ b/CerebNet/utils/lr_scheduler.py @@ -16,7 +16,6 @@ # IMPORTS import math import numbers -from typing import List import typing as _T import torch @@ -61,7 +60,7 @@ def __init__(self, optimizer, *args, T_0=10, Tmult=1, lr_restart=None, **kwargs) map(lambda group: group["initial_lr"], optimizer.param_groups) ) - super(ReduceLROnPlateauWithRestarts, self).__init__(optimizer, *args, **kwargs) + super().__init__(optimizer, *args, **kwargs) self.T_0 = T_0 self.Tmult = Tmult self.lr_restart = lr_restart @@ -88,7 +87,7 @@ def step(self, metrics, epoch=None): https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html """ self.Tcur += 1 - super(ReduceLROnPlateauWithRestarts, self).step(metrics, epoch) + super().step(metrics, epoch) if self.Tcur >= self.T_i: self._reset_lr() 
self._last_lr = [group["lr"] for group in self.optimizer.param_groups] @@ -105,7 +104,7 @@ def _reset_lr(self): self._reset() for i, param_group in enumerate(self.optimizer.param_groups): - old_lr = float(param_group["lr"]) + # old_lr = float(param_group["lr"]) lr_r = ( self.lr_restart[i] if isinstance(self.lr_restart, _T.Sequence) @@ -119,8 +118,8 @@ def _reset_lr(self): param_group["lr"] = new_lr if self.verbose: logger.info( - "Epoch {:5d}: restarting learning rate with " - "{:.4e} for group {}.".format(self.last_epoch, new_lr, i) + f"Epoch {self.last_epoch:5d}: restarting learning rate with " + f"{new_lr:.4e} for group {i}." ) @@ -144,7 +143,7 @@ def __init__( self.warmup_method = warmup_method super().__init__(optimizer, last_epoch) - def get_lr(self) -> List[float]: + def get_lr(self) -> list[float]: """ Get the learning rates at the current epoch. """ @@ -164,7 +163,7 @@ def get_lr(self) -> List[float]: for base_lr in self.base_lrs ] - def _compute_values(self) -> List[float]: + def _compute_values(self) -> list[float]: # The new interface return self.get_lr() @@ -195,7 +194,7 @@ def _get_warmup_factor_at_iter( alpha = iter / warmup_iters return warmup_factor * (1 - alpha) + alpha else: - raise ValueError("Unknown warmup method: {}".format(method)) + raise ValueError(f"Unknown warmup method: {method}") class CosineLR: @@ -255,7 +254,7 @@ class CosineAnnealingWarmRestartsDecay(scheduler.CosineAnnealingWarmRestarts): decay factor for where the learning rate restarts at. """ def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1): - super(CosineAnnealingWarmRestartsDecay, self).__init__( + super().__init__( optimizer, T_0, T_mult=T_mult, eta_min=eta_min, last_epoch=last_epoch ) pass diff --git a/CerebNet/utils/meters.py b/CerebNet/utils/meters.py index 67d7ebdd..ddeecb5b 100644 --- a/CerebNet/utils/meters.py +++ b/CerebNet/utils/meters.py @@ -14,17 +14,15 @@ # IMPORTS -from functools import partial -import numpy as np import matplotlib.pyplot as plt +import numpy as np import torch -from FastSurferCNN.utils import logging - -from CerebNet.utils.metrics import DiceScore, hd, dice_score, volume_similarity +from CerebNet.data_loader.data_utils import GRAY_MATTER, VERMIS_NAMES +from CerebNet.utils.metrics import DiceScore, dice_score, hd, volume_similarity from CerebNet.utils.misc import plot_confusion_matrix, plot_predictions -from CerebNet.data_loader.data_utils import VERMIS_NAMES, GRAY_MATTER +from FastSurferCNN.utils import logging logger = logging.get_logger(__name__) @@ -93,7 +91,7 @@ def _get_binray_map(self, lbl_map, class_names): bin_map : np.array Binary map where True represents class and False represents its absence. 
""" - bin_map = np.logical_or.reduce(list(map(lambda l: lbl_map == l, class_names))) + bin_map = np.logical_or.reduce(list(map(lambda lb: lbl_map == lb, class_names))) return bin_map def metrics_per_class(self, pred, gt): @@ -283,15 +281,8 @@ def log_iter(self, cur_iter, cur_epoch): ) logger.info( - "{} Epoch [{}/{}] Iter [{}/{}] [Dice Score: {:.4f}] [{}]".format( - self.mode, - cur_epoch + 1, - self.total_epochs, - cur_iter + 1, - self.total_iter_num, - dice_score_per_class[1:].mean(), - out_losses, - ) + f"{self.mode} Epoch [{cur_epoch + 1}/{self.total_epochs}] Iter [{cur_iter + 1}/{self.total_iter_num}]" \ + f" [Dice Score: {dice_score_per_class[1:].mean():.4f}] [{out_losses}]" ) def log_lr(self, lr, step=None): diff --git a/CerebNet/utils/metrics.py b/CerebNet/utils/metrics.py index b76d1d20..f100bb96 100644 --- a/CerebNet/utils/metrics.py +++ b/CerebNet/utils/metrics.py @@ -18,8 +18,8 @@ import torch from scipy.ndimage import _ni_support from scipy.ndimage.morphology import ( - distance_transform_edt, binary_erosion, + distance_transform_edt, generate_binary_structure, ) @@ -84,9 +84,7 @@ def _check_output_type(self, output): """ if not (isinstance(output, tuple)): raise TypeError( - "Output should be a tuple consisting of torch.Tensors, but given {}".format( - type(output) - ) + f"Output should be a tuple consisting of torch.Tensors, but given {type(output)}" ) def _update_union_intersection(self, batch_output, labels_batch): diff --git a/CerebNet/utils/misc.py b/CerebNet/utils/misc.py index 6e7ed790..7cbe08e3 100644 --- a/CerebNet/utils/misc.py +++ b/CerebNet/utils/misc.py @@ -14,17 +14,17 @@ # IMPORTS -import os import glob import math +import os from itertools import product -import torch -from torchvision import utils import numpy as np -from skimage import color +import torch from matplotlib import pyplot as plt from mpl_toolkits.axes_grid1 import make_axes_locatable +from skimage import color +from torchvision import utils from FastSurferCNN.utils import logging @@ -209,7 +209,7 @@ def set_summary_path(cfg): cfg.EXPR_NUM = str(find_latest_experiment(os.path.join(cfg.LOG_DIR, "summary")) + 1) if cfg.TRAIN.RESUME and cfg.TRAIN.RESUME_EXPR_NUM > 0: cfg.EXPR_NUM = cfg.TRAIN.RESUME_EXPR_NUM - cfg.SUMMARY_PATH = check_path(os.path.join(summary_path, "{}".format(cfg.EXPR_NUM))) + cfg.SUMMARY_PATH = check_path(os.path.join(summary_path, f"{cfg.EXPR_NUM}")) def load_classwise_weights(cfg): @@ -237,7 +237,7 @@ def update_results_dir(cfg): """ cfg.EXPR_NUM = str(find_latest_experiment(cfg.TEST.RESULTS_DIR) + 1) cfg.TEST.RESULTS_DIR = check_path( - os.path.join(cfg.TEST.RESULTS_DIR, "{}".format(cfg.EXPR_NUM)) + os.path.join(cfg.TEST.RESULTS_DIR, f"{cfg.EXPR_NUM}") ) @@ -250,7 +250,7 @@ def update_split_path(cfg): cfg : yacs.config.CfgNode Configuration node. 
""" - from os.path import split, join + from os.path import join, split split_num = cfg.SPLIT_NUM keys = [ @@ -273,8 +273,8 @@ def visualize_batch(img, label, idx): :param batch_dict: :return: """ - from skimage import color import matplotlib.pyplot as plt + from skimage import color plt.imshow(img[idx, 3].cpu().numpy(), cmap="gray") plt.imshow(color.label2rgb(label[idx].cpu().numpy(), bg_label=0), alpha=0.4) From 2a5ee8d0981fd657314e140448253d66e88c42b1 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 16:43:04 +0200 Subject: [PATCH 08/31] fastsurfercnn fix ruff warnings and cleanup --- FastSurferCNN/config/defaults.py | 3 +- FastSurferCNN/data_loader/augmentation.py | 43 +++--- FastSurferCNN/data_loader/conform.py | 53 ++++--- FastSurferCNN/data_loader/data_utils.py | 62 ++++---- FastSurferCNN/data_loader/dataset.py | 66 +++----- FastSurferCNN/data_loader/loader.py | 6 +- FastSurferCNN/download_checkpoints.py | 20 +-- FastSurferCNN/generate_hdf5.py | 40 ++--- FastSurferCNN/inference.py | 31 ++-- FastSurferCNN/models/interpolation_layer.py | 50 +++--- FastSurferCNN/models/losses.py | 20 +-- FastSurferCNN/models/networks.py | 19 +-- FastSurferCNN/models/optimizer.py | 1 - FastSurferCNN/models/sub_module.py | 39 +++-- FastSurferCNN/mri_brainvol_stats.py | 2 +- FastSurferCNN/mri_segstats.py | 19 +-- FastSurferCNN/quick_qc.py | 14 +- FastSurferCNN/reduce_to_aseg.py | 10 +- FastSurferCNN/run_model.py | 4 +- FastSurferCNN/run_prediction.py | 64 ++++---- FastSurferCNN/segstats.py | 159 ++++++++++---------- FastSurferCNN/train.py | 40 ++--- FastSurferCNN/utils/arg_types.py | 8 +- FastSurferCNN/utils/brainvolstats.py | 91 ++++++----- FastSurferCNN/utils/checkpoint.py | 25 +-- FastSurferCNN/utils/common.py | 39 ++--- FastSurferCNN/utils/dataclasses.py | 20 +-- FastSurferCNN/utils/logging.py | 5 +- FastSurferCNN/utils/lr_scheduler.py | 3 +- FastSurferCNN/utils/mapper.py | 117 +++++++------- FastSurferCNN/utils/meters.py | 24 ++- FastSurferCNN/utils/metrics.py | 14 +- FastSurferCNN/utils/misc.py | 3 +- FastSurferCNN/utils/parser_defaults.py | 25 ++- FastSurferCNN/utils/run_tools.py | 15 +- FastSurferCNN/version.py | 33 ++-- 36 files changed, 555 insertions(+), 632 deletions(-) diff --git a/FastSurferCNN/config/defaults.py b/FastSurferCNN/config/defaults.py index 4008fbef..c82098e3 100644 --- a/FastSurferCNN/config/defaults.py +++ b/FastSurferCNN/config/defaults.py @@ -100,7 +100,8 @@ # Flag to disable or enable Early Stopping _C.TRAIN.EARLY_STOPPING = True -# Mode for early stopping (min = stop when metric is no longer decreasing, max = stop when mwtric is no longer increasing) +# Mode for early stopping (min = stop when metric is no longer decreasing, +# max = stop when mwtric is no longer increasing) _C.TRAIN.EARLY_STOPPING_MODE = "min" # Patience = Number of epochs to wait before stopping diff --git a/FastSurferCNN/data_loader/augmentation.py b/FastSurferCNN/data_loader/augmentation.py index 4e8a5407..32259c93 100644 --- a/FastSurferCNN/data_loader/augmentation.py +++ b/FastSurferCNN/data_loader/augmentation.py @@ -15,7 +15,8 @@ # IMPORTS from numbers import Number, Real -from typing import Union, Tuple, Any, Dict +from typing import Any + import numpy as np import numpy.typing as npt import torch @@ -24,7 +25,7 @@ ## # Transformations for evaluation ## -class ToTensorTest(object): +class ToTensorTest: """ Convert np.ndarrays in sample to Tensors. 
@@ -61,7 +62,7 @@ def __call__(self, img: npt.NDArray) -> np.ndarray: return img -class ZeroPad2DTest(object): +class ZeroPad2DTest: """ Pad the input with zeros to get output size. @@ -81,7 +82,7 @@ class ZeroPad2DTest(object): """ def __init__( self, - output_size: Union[Number, Tuple[Number, Number]], + output_size: Number | tuple[Number, Number], pos: str = 'top_left' ): """ @@ -147,7 +148,7 @@ def __call__(self, img: npt.NDArray) -> np.ndarray: ## # Transformations for training ## -class ToTensor(object): +class ToTensor: """ Convert ndarrays in sample to Tensors. @@ -157,7 +158,7 @@ class ToTensor(object): Convert image. """ - def __call__(self, sample: npt.NDArray) -> Dict[str, Any]: + def __call__(self, sample: npt.NDArray) -> dict[str, Any]: """ Convert the image to float within range [0, 1] and make it torch compatible. @@ -196,7 +197,7 @@ def __call__(self, sample: npt.NDArray) -> Dict[str, Any]: } -class ZeroPad2D(object): +class ZeroPad2D: """ Pad the input with zeros to get output size. @@ -216,8 +217,8 @@ class ZeroPad2D(object): """ def __init__( self, - output_size: Union[Number, Tuple[Number, Number]], - pos: Union[None, str] = 'top_left' + output_size: Number | tuple[Number, Number], + pos: None | str = 'top_left' ): """ Initialize position and output_size (as Tuple[float]). @@ -261,7 +262,7 @@ def _pad(self, image: npt.NDArray) -> np.ndarray: return padded_img - def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: """ Pad the image, label and weights. @@ -289,7 +290,7 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: return {"img": img, "label": label, "weight": weight, "scale_factor": sf} -class AddGaussianNoise(object): +class AddGaussianNoise: """ Add gaussian noise to sample. @@ -319,7 +320,7 @@ def __init__(self, mean: Real = 0, std: Real = 0.1): self.std = std self.mean = mean - def __call__(self, sample: Dict[str, Real]) -> Dict[str, Real]: + def __call__(self, sample: dict[str, Real]) -> dict[str, Real]: """ Add gaussian noise to scalefactor. @@ -344,7 +345,7 @@ def __call__(self, sample: Dict[str, Real]) -> Dict[str, Real]: return {"img": img, "label": label, "weight": weight, "scale_factor": sf} -class AugmentationPadImage(object): +class AugmentationPadImage: """ Pad Image with either zero padding or reflection padding of img, label and weight. @@ -364,8 +365,8 @@ class AugmentationPadImage(object): """ def __init__( self, - pad_size: Tuple[Tuple[int, int], - Tuple[int, int]] = ((16, 16), (16, 16)), + pad_size: tuple[tuple[int, int], + tuple[int, int]] = ((16, 16), (16, 16)), pad_type: str = "edge" ): """ @@ -378,7 +379,7 @@ def __init__( pad_type : str The type of padding to be applied. """ - assert isinstance(pad_size, (int, tuple)) + assert isinstance(pad_size, int | tuple) if isinstance(pad_size, int): @@ -391,7 +392,7 @@ def __init__( self.pad_type = pad_type - def __call__(self, sample: Dict[str, Number]): + def __call__(self, sample: dict[str, Number]): """ Pad zeroes of sample image, label and weight. @@ -414,12 +415,12 @@ def __call__(self, sample: Dict[str, Number]): return {"img": img, "label": label, "weight": weight, "scale_factor": sf} -class AugmentationRandomCrop(object): +class AugmentationRandomCrop: """ Randomly Crop Image to given size. """ - def __init__(self, output_size: Union[int, Tuple], crop_type: str = 'Random'): + def __init__(self, output_size: int | tuple, crop_type: str = 'Random'): """Construct object. 
Attributes @@ -429,7 +430,7 @@ def __init__(self, output_size: Union[int, Tuple], crop_type: str = 'Random'): crop_type The type of crop to be performed. """ - assert isinstance(output_size, (int, tuple)) + assert isinstance(output_size, int | tuple) if isinstance(output_size, int): self.output_size = (output_size, output_size) @@ -439,7 +440,7 @@ def __init__(self, output_size: Union[int, Tuple], crop_type: str = 'Random'): self.crop_type = crop_type - def __call__(self, sample: Dict[str, Number]) -> Dict[str, Number]: + def __call__(self, sample: dict[str, Number]) -> dict[str, Number]: """ Crops the augmentation. diff --git a/FastSurferCNN/data_loader/conform.py b/FastSurferCNN/data_loader/conform.py index 4e8b3224..877f7c63 100644 --- a/FastSurferCNN/data_loader/conform.py +++ b/FastSurferCNN/data_loader/conform.py @@ -16,21 +16,28 @@ # IMPORTS import argparse -from enum import Enum import logging import sys -from typing import Optional, Type, Tuple, Union, Iterable, cast +from collections.abc import Iterable +from enum import Enum +from typing import cast +import nibabel as nib import numpy as np import numpy.typing as npt -import nibabel as nib from FastSurferCNN.utils.arg_types import ( - vox_size as __vox_size, - target_dtype as __target_dtype, - float_gt_zero_and_le_one as __conform_to_one_mm, VoxSizeOption, ) +from FastSurferCNN.utils.arg_types import ( + float_gt_zero_and_le_one as __conform_to_one_mm, +) +from FastSurferCNN.utils.arg_types import ( + target_dtype as __target_dtype, +) +from FastSurferCNN.utils.arg_types import ( + vox_size as __vox_size, +) HELPTEXT = """ Script to conform an MRI brain image to UCHAR, RAS orientation, @@ -194,9 +201,9 @@ def map_image( img: nib.analyze.SpatialImage, out_affine: np.ndarray, out_shape: tuple[int, ...] | np.ndarray | Iterable[int], - ras2ras: Optional[np.ndarray] = None, + ras2ras: np.ndarray | None = None, order: int = 1, - dtype: Optional[Type] = None + dtype: type | None = None ) -> np.ndarray: """ Map image to new voxel space (RAS orientation). @@ -222,8 +229,8 @@ def map_image( np.ndarray Mapped image data array. """ - from scipy.ndimage import affine_transform from numpy.linalg import inv + from scipy.ndimage import affine_transform if ras2ras is None: ras2ras = np.eye(4) @@ -276,7 +283,7 @@ def getscale( dst_max: float, f_low: float = 0.0, f_high: float = 0.999 -) -> Tuple[float, float]: +) -> tuple[float, float]: """ Get offset and scale of image intensities to robustly rescale to dst_min..dst_max. @@ -521,11 +528,11 @@ def conform( img: nib.analyze.SpatialImage, order: int = 1, conform_vox_size: VoxSizeOption = 1.0, - dtype: Optional[Type] = None, - conform_to_1mm_threshold: Optional[float] = None, + dtype: type | None = None, + conform_to_1mm_threshold: float | None = None, criteria: set[Criteria] = DEFAULT_CRITERIA, ) -> nib.MGHImage: - f"""Python version of mri_convert -c. + """Python version of mri_convert -c. mri_convert -c by default turns image intensity values into UCHAR, reslices images to standard position, fills up slices to standard @@ -547,7 +554,7 @@ def conform( conform_to_1mm_threshold : Optional[float] The threshold above which the image is conformed to 1mm (default: ignore). - criteria : set[Criteria], default={DEFAULT_CRITERIA} + criteria : set[Criteria], default in DEFAULT_CRITERIA Whether to force the conforming to include a LIA data layout, an image size requirement and/or a voxel size requirement. 
@@ -607,7 +614,7 @@ def conform( h1["Mdc"] = np.linalg.inv(mdc_affine) print(h1.get_zooms()) - h1["fov"] = max(i * v for i, v in zip(h1.get_data_shape(), h1.get_zooms())) + h1["fov"] = max(i * v for i, v in zip(h1.get_data_shape(), h1.get_zooms(), strict=False)) center = np.asarray(img.shape[:3], dtype=float) / 2.0 h1["Pxyz_c"] = img.affine.dot(np.hstack((center, [1.0])))[:3] @@ -707,12 +714,12 @@ def is_conform( conform_vox_size: VoxSizeOption = 1.0, eps: float = 1e-06, check_dtype: bool = True, - dtype: Optional[Type] = None, + dtype: type | None = None, verbose: bool = True, - conform_to_1mm_threshold: Optional[float] = None, + conform_to_1mm_threshold: float | None = None, criteria: set[Criteria] = DEFAULT_CRITERIA, ) -> bool: - f""" + """ Check if an image is already conformed or not. Dimensions: 256x256x256, Voxel size: 1x1x1, LIA orientation, and data type UCHAR. @@ -740,7 +747,7 @@ def is_conform( are displayed. conform_to_1mm_threshold : float, optional Above this threshold the image is conformed to 1mm (default or None: ignore). - criteria : set[Criteria], default={DEFAULT_CRITERIA} + criteria : set[Criteria], default in DEFAULT_CRITERIA An enum/set of criteria to check. Returns @@ -855,8 +862,8 @@ def get_primary_dirs(a): return np.argmax(abs(a), axis=0) def get_conformed_vox_img_size( img: nib.analyze.SpatialImage, conform_vox_size: VoxSizeOption, - conform_to_1mm_threshold: Optional[float] = None -) -> Tuple[float, int]: + conform_to_1mm_threshold: float | None = None +) -> tuple[float, int]: """ Extract the voxel size and the image size. @@ -897,8 +904,8 @@ def get_conformed_vox_img_size( def check_affine_in_nifti( - img: Union[nib.Nifti1Image, nib.Nifti2Image], - logger: Optional[logging.Logger] = None + img: nib.Nifti1Image | nib.Nifti2Image, + logger: logging.Logger | None = None ) -> bool: """ Check the affine in nifti Image. 
diff --git a/FastSurferCNN/data_loader/data_utils.py b/FastSurferCNN/data_loader/data_utils.py index dd11302b..438f6e05 100644 --- a/FastSurferCNN/data_loader/data_utils.py +++ b/FastSurferCNN/data_loader/data_utils.py @@ -14,27 +14,28 @@ # IMPORTS +from collections.abc import Mapping from pathlib import Path -from typing import Optional, Tuple, Union, Mapping, cast, Iterable +from typing import cast +import nibabel as nib import numpy as np -from numpy import typing as npt +import pandas as pd +import scipy.ndimage.morphology as morphology import torch -from skimage.measure import label, regionprops +from nibabel.filebasedimages import FileBasedHeader as _Header +from numpy import typing as npt from scipy.ndimage import ( - binary_erosion, binary_closing, + binary_erosion, filters, - uniform_filter, generate_binary_structure, + uniform_filter, ) -import scipy.ndimage.morphology as morphology -import nibabel as nib -from nibabel.filebasedimages import FileBasedHeader as _Header -import pandas as pd +from skimage.measure import label, regionprops +from FastSurferCNN.data_loader.conform import check_affine_in_nifti, conform, is_conform from FastSurferCNN.utils import logging -from FastSurferCNN.data_loader.conform import is_conform, conform, check_affine_in_nifti from FastSurferCNN.utils.arg_types import VoxSizeOption ## @@ -160,8 +161,8 @@ def load_image( """ try: img = cast(nib.analyze.SpatialImage, nib.load(file, **kwargs)) - except (IOError, FileNotFoundError) as e: - raise IOError( + except (OSError, FileNotFoundError) as e: + raise OSError( f"Failed loading the {name} '{file}' with error: {e.args[0]}" ) from e data = np.asarray(img.dataobj) @@ -212,7 +213,6 @@ def load_maybe_conform( dst_file = file else: # the image is not conformed to 1mm, do this now. - from nibabel.filebasedimages import FileBasedHeader as _Header fileext = [ ext for ext in SUPPORTED_OUTPUT_FILE_FORMATS @@ -258,7 +258,7 @@ def save_image( affine_info: npt.NDArray[float], img_array: np.ndarray, save_as: str | Path, - dtype: Optional[npt.DTypeLike] = None + dtype: npt.DTypeLike | None = None ) -> None: """ Save an image (nibabel MGHImage), according to the desired output file format. @@ -395,7 +395,7 @@ def filter_blank_slices_thick( label_vol: npt.NDArray, weight_vol: npt.NDArray, threshold: int = 50 -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Filter blank slices from the volume using the label volume. @@ -435,7 +435,7 @@ def create_weight_mask( mapped_aseg: npt.NDArray, max_weight: int = 5, max_edge_weight: int = 5, - max_hires_weight: Optional[int] = None, + max_hires_weight: int | None = None, ctx_thresh: int = 33, mean_filter: bool = False, cortex_mask: bool = True, @@ -642,7 +642,7 @@ def read_classes_from_lut(lut_file: str | Path): def map_label2aparc_aseg( mapped_aseg: torch.Tensor, - labels: Union[torch.Tensor, npt.NDArray] + labels: torch.Tensor | npt.NDArray ) -> torch.Tensor: """ Perform look-up table mapping from sequential label space to LUT space. @@ -890,8 +890,8 @@ def split_cortex_labels(aparc: npt.NDArray) -> np.ndarray: def unify_lateralized_labels( - lut: Union[str, pd.DataFrame], - combi: Tuple[str, str] = ("Left-", "Right-") + lut: str | pd.DataFrame, + combi: tuple[str, str] = ("Left-", "Right-") ) -> Mapping: """ Generate lookup dictionary of left-right labels. 
@@ -923,9 +923,9 @@ def unify_lateralized_labels( def get_labels_from_lut( - lut: Union[str, pd.DataFrame], - label_extract: Tuple[str, str] = ("Left-", "ctx-rh") -) -> Tuple[np.ndarray, np.ndarray]: + lut: str | pd.DataFrame, + label_extract: tuple[str, str] = ("Left-", "ctx-rh") +) -> tuple[np.ndarray, np.ndarray]: """ Extract labels from the lookup tables. @@ -960,9 +960,9 @@ def map_aparc_aseg2label( labels: npt.NDArray, labels_sag: npt.NDArray, sagittal_lut_dict: Mapping, - aseg_nocc: Optional[npt.NDArray] = None, + aseg_nocc: npt.NDArray | None = None, processing: str = "aparc" -) -> Tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray]: """ Perform look-up table mapping of aparc.DKTatlas+aseg.mgz data to label space. @@ -1011,16 +1011,14 @@ def map_aparc_aseg2label( assert not np.any( 251 <= aseg - ), "Error: CC classes (251-255) still exist in aseg {}".format(np.unique(aseg)) + ), f"Error: CC classes (251-255) still exist in aseg {np.unique(aseg)}" assert np.any(aseg == 3) and np.any( aseg == 42 - ), "Error: no cortical marker detected {}".format(np.unique(aseg)) + ), f"Error: no cortical marker detected {np.unique(aseg)}" assert set(labels).issuperset( np.unique(aseg) - ), "Error: segmentation image contains classes not listed in the labels: \n{}\n{}".format( - np.unique(aseg), labels - ) + ), f"Error: segmentation image contains classes not listed in the labels: \n{np.unique(aseg)}\n{labels}" h, w, d = aseg.shape lut_aseg = np.zeros(max(labels) + 1, dtype="int") @@ -1088,7 +1086,7 @@ def sagittal_coronal_remap_lookup(x: int) -> int: def infer_mapping_from_lut( num_classes_full: int, - lut: Union[str, pd.DataFrame] + lut: str | pd.DataFrame ) -> np.ndarray: """ Guess the mapping from a lookup table. @@ -1123,7 +1121,7 @@ def infer_mapping_from_lut( def map_prediction_sagittal2full( prediction_sag: npt.NDArray, num_classes: int = 51, - lut: Optional[str] = None + lut: str | None = None ) -> np.ndarray: """ Remap the prediction on the sagittal network to full label space used by coronal and axial networks. @@ -1165,7 +1163,7 @@ def map_prediction_sagittal2full( # Clean up and class separation def bbox_3d( img: npt.NDArray -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ Extract the three-dimensional bounding box coordinates. 
diff --git a/FastSurferCNN/data_loader/dataset.py b/FastSurferCNN/data_loader/dataset.py index 5efe6ee3..d18506e6 100644 --- a/FastSurferCNN/data_loader/dataset.py +++ b/FastSurferCNN/data_loader/dataset.py @@ -14,7 +14,7 @@ # IMPORTS import time -from typing import Optional, Tuple, Dict +from typing import Optional import h5py import numpy as np @@ -66,16 +66,16 @@ def __init__( if self.plane == "sagittal": orig_data = du.transform_sagittal(orig_data) self.zoom = orig_zoom[::-1][:2] - logger.info("Loading Sagittal with input voxelsize {}".format(self.zoom)) + logger.info(f"Loading Sagittal with input voxelsize {self.zoom}") elif self.plane == "axial": orig_data = du.transform_axial(orig_data) self.zoom = orig_zoom[::-1][:2] - logger.info("Loading Axial with input voxelsize {}".format(self.zoom)) + logger.info(f"Loading Axial with input voxelsize {self.zoom}") else: self.zoom = orig_zoom[:2] - logger.info("Loading Coronal with input voxelsize {}".format(self.zoom)) + logger.info(f"Loading Coronal with input voxelsize {self.zoom}") # Create thick slices orig_thick = du.get_thick_slices(orig_data, self.slice_thickness) @@ -101,7 +101,7 @@ def _get_scale_factor(self) -> npt.NDArray[float]: return scale - def __getitem__(self, index: int) -> Dict: + def __getitem__(self, index: int) -> dict: """ Return a single image and its scale factor. @@ -181,38 +181,28 @@ def __init__( logger.info(f"Processing images of size {size}.") img_dset = list(hf[f"{size}"]["orig_dataset"]) logger.info( - "Processed origs of size {} in {:.3f} seconds".format( - size, time.time() - start - ) + f"Processed origs of size {size} in {time.time() - start:.3f} seconds" ) self.images.extend(img_dset) self.labels.extend(list(hf[f"{size}"]["aseg_dataset"])) logger.info( - "Processed asegs of size {} in {:.3f} seconds".format( - size, time.time() - start - ) + f"Processed asegs of size {size} in {time.time() - start:.3f} seconds" ) self.weights.extend(list(hf[f"{size}"]["weight_dataset"])) self.zooms.extend(list(hf[f"{size}"]["zoom_dataset"])) logger.info( - "Processed zooms of size {} in {:.3f} seconds".format( - size, time.time() - start - ) + f"Processed zooms of size {size} in {time.time() - start:.3f} seconds" ) logger.info( - "Processed weights of size {} in {:.3f} seconds".format( - size, time.time() - start - ) + f"Processed weights of size {size} in {time.time() - start:.3f} seconds" ) self.subjects.extend(list(hf[f"{size}"]["subject"])) logger.info( - "Processed subjects of size {} in {:.3f} seconds".format( - size, time.time() - start - ) + f"Processed subjects of size {size} in {time.time() - start:.3f} seconds" ) logger.info(f"Number of slices for size {size} is {len(img_dset)}") - except KeyError as e: + except KeyError: print( f"KeyError: Unable to open object (object {size} does not exist)" ) @@ -222,9 +212,8 @@ def __init__( self.transforms = transforms logger.info( - "Successfully loaded {} data from {} with plane {} in {:.3f} seconds".format( - self.count, dataset_path, cfg.DATA.PLANE, time.time() - start - ) + f"Successfully loaded {self.count} data from {dataset_path} with plane {cfg.DATA.PLANE}" \ + f" in {time.time() - start:.3f} seconds" ) def get_subject_names(self): @@ -313,7 +302,7 @@ def unify_imgs( img: npt.NDArray, label: npt.NDArray, weight: npt.NDArray - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Pad img, label and weight. 
@@ -433,38 +422,28 @@ def __init__(self, dataset_path, cfg, transforms=None): logger.info(f"Processing images of size {size}.") img_dset = list(hf[f"{size}"]["orig_dataset"]) logger.info( - "Processed origs of size {} in {:.3f} seconds".format( - size, time.time() - start - ) + f"Processed origs of size {size} in {time.time() - start:.3f} seconds" ) self.images.extend(img_dset) self.labels.extend(list(hf[f"{size}"]["aseg_dataset"])) logger.info( - "Processed asegs of size {} in {:.3f} seconds".format( - size, time.time() - start - ) + f"Processed asegs of size {size} in {time.time() - start:.3f} seconds" ) self.weights.extend(list(hf[f"{size}"]["weight_dataset"])) logger.info( - "Processed weights of size {} in {:.3f} seconds".format( - size, time.time() - start - ) + f"Processed weights of size {size} in {time.time() - start:.3f} seconds" ) self.zooms.extend(list(hf[f"{size}"]["zoom_dataset"])) logger.info( - "Processed zooms of size {} in {:.3f} seconds".format( - size, time.time() - start - ) + f"Processed zooms of size {size} in {time.time() - start:.3f} seconds" ) self.subjects.extend(list(hf[f"{size}"]["subject"])) logger.info( - "Processed subjects of size {} in {:.3f} seconds".format( - size, time.time() - start - ) + f"Processed subjects of size {size} in {time.time() - start:.3f} seconds" ) logger.info(f"Number of slices for size {size} is {len(img_dset)}") - except KeyError as e: + except KeyError: print( f"KeyError: Unable to open object (object {size} does not exist)" ) @@ -473,9 +452,8 @@ def __init__(self, dataset_path, cfg, transforms=None): self.count = len(self.images) self.transforms = transforms logger.info( - "Successfully loaded {} data from {} with plane {} in {:.3f} seconds".format( - self.count, dataset_path, cfg.DATA.PLANE, time.time() - start - ) + f"Successfully loaded {self.count} data from {dataset_path} with plane {cfg.DATA.PLANE}" \ + f" in {time.time() - start:.3f} seconds" ) def get_subject_names(self): diff --git a/FastSurferCNN/data_loader/loader.py b/FastSurferCNN/data_loader/loader.py index 1ae09d9a..73adb259 100644 --- a/FastSurferCNN/data_loader/loader.py +++ b/FastSurferCNN/data_loader/loader.py @@ -12,12 +12,12 @@ # limitations under the License. # IMPORTS -from torchvision import transforms -from torch.utils.data import DataLoader import yacs.config +from torch.utils.data import DataLoader +from torchvision import transforms from FastSurferCNN.data_loader import dataset as dset -from FastSurferCNN.data_loader.augmentation import ToTensor, ZeroPad2D, AddGaussianNoise +from FastSurferCNN.data_loader.augmentation import AddGaussianNoise, ToTensor, ZeroPad2D from FastSurferCNN.utils import logging logger = logging.getLogger(__name__) diff --git a/FastSurferCNN/download_checkpoints.py b/FastSurferCNN/download_checkpoints.py index 54e5377b..fa34d1e4 100644 --- a/FastSurferCNN/download_checkpoints.py +++ b/FastSurferCNN/download_checkpoints.py @@ -15,19 +15,18 @@ # limitations under the License. 
import argparse -from functools import lru_cache -from typing import Optional +from CerebNet.utils.checkpoint import ( + YAML_DEFAULT as CEREBNET_YAML, +) from FastSurferCNN.utils import PLANES +from FastSurferCNN.utils.checkpoint import ( + YAML_DEFAULT as VINN_YAML, +) from FastSurferCNN.utils.checkpoint import ( check_and_download_ckpts, get_checkpoints, load_checkpoint_config_defaults, - YAML_DEFAULT as VINN_YAML, -) - -from CerebNet.utils.checkpoint import ( - YAML_DEFAULT as CEREBNET_YAML, ) from HypVINN.utils.checkpoint import ( YAML_DEFAULT as HYPVINN_YAML, @@ -35,15 +34,12 @@ class ConfigCache: - @lru_cache def vinn_url(self): return load_checkpoint_config_defaults("url", filename=VINN_YAML) - @lru_cache def cerebnet_url(self): return load_checkpoint_config_defaults("url", filename=CEREBNET_YAML) - @lru_cache def hypvinn_url(self): return load_checkpoint_config_defaults("url", filename=HYPVINN_YAML) @@ -108,7 +104,7 @@ def main( hypvinn: bool, all: bool, files: list[str], - url: Optional[str] = None, + url: str | None = None, ) -> int | str: if not vinn and not files and not cerebnet and not hypvinn and not all: return ("Specify either files to download or --vinn, --cerebnet, " @@ -159,7 +155,7 @@ def main( if __name__ == "__main__": import sys - from logging import basicConfig, INFO + from logging import INFO, basicConfig basicConfig(stream=sys.stdout, level=INFO) args = make_arguments() diff --git a/FastSurferCNN/generate_hdf5.py b/FastSurferCNN/generate_hdf5.py index 2b13a430..37383642 100644 --- a/FastSurferCNN/generate_hdf5.py +++ b/FastSurferCNN/generate_hdf5.py @@ -17,14 +17,12 @@ # IMPORTS import time from collections import defaultdict -from os.path import dirname, join +from os.path import join from pathlib import Path -from typing import Dict, Tuple import h5py import nibabel as nib import numpy as np -from numpy import ndarray from numpy import typing as npt from FastSurferCNN.data_loader.data_utils import ( @@ -106,7 +104,7 @@ class H5pyDataset: Create a hdf5 file """ - def __init__(self, params: Dict, processing: str = "aparc"): + def __init__(self, params: dict, processing: str = "aparc"): """ Construct H5pyDataset object. @@ -161,7 +159,7 @@ def __init__(self, params: Dict, processing: str = "aparc"): self.lateralization = unify_lateralized_labels(self.lut, params["combi"]) if params["csv_file"] is not None: - with open(params["csv_file"], "r") as s_dirs: + with open(params["csv_file"]) as s_dirs: self.subject_dirs = [line.strip() for line in s_dirs.readlines()] else: @@ -172,7 +170,7 @@ def __init__(self, params: Dict, processing: str = "aparc"): def _load_volumes( self, subject_path: str - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Tuple]: + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, tuple]: """ Load the given image and segmentation and gets the zoom values. @@ -196,9 +194,7 @@ def _load_volumes( """ # Load the orig and extract voxel spacing information (x, y, and z dim) LOGGER.info( - "Processing intensity image {} and ground truth segmentation {}".format( - self.orig_name, self.aparc_name - ) + f"Processing intensity image {self.orig_name} and ground truth segmentation {self.aparc_name}" ) orig = nib.load(join(subject_path, self.orig_name)) # Load the segmentation ground truth @@ -219,7 +215,7 @@ def _load_volumes( def transform( self, plane: str, imgs: npt.NDArray, zoom: npt.NDArray - ) -> Tuple[npt.NDArray, npt.NDArray]: + ) -> tuple[npt.NDArray, npt.NDArray]: """ Transform the image and zoom along the given axis. 
@@ -268,7 +264,7 @@ def _pad_image(self, img: npt.NDArray, max_out: int) -> np.ndarray: """ # Get correct size = max along shape h, w, d = img.shape - LOGGER.info("Padding image from {0} to {1}x{1}x{1}".format(img.shape, max_out)) + LOGGER.info(f"Padding image from {img.shape} to {max_out}x{max_out}x{max_out}") padded_img = np.zeros((max_out, max_out, max_out), dtype=img.dtype) padded_img[0:h, 0:w, 0:d] = img return padded_img @@ -287,12 +283,10 @@ def create_hdf5_dataset(self, blt: int): for idx, current_subject in enumerate(self.subject_dirs): try: - start = time.time() + # start = time.time() LOGGER.info( - "Volume Nr: {} Processing MRI Data from {}/{}".format( - idx + 1, current_subject, self.orig_name - ) + f"Volume Nr: {idx + 1} Processing MRI Data from {current_subject}/{self.orig_name}" ) orig, aseg, aseg_nocc, zoom = self._load_volumes(current_subject) @@ -331,14 +325,8 @@ def create_hdf5_dataset(self, blt: int): ) print( - "Created weights with max_w {}, gradient {}," - " edge_w {}, hires_w {}, gm_mask {}".format( - self.max_weight, - self.gradient, - self.edge_weight, - self.hires_weight, - self.gm_mask, - ) + f"Created weights with max_w {self.max_weight}, gradient {self.gradient}," + f" edge_w {self.edge_weight}, hires_w {self.hires_weight}, gm_mask {self.gm_mask}" ) # transform volumes to correct shape @@ -368,7 +356,7 @@ def create_hdf5_dataset(self, blt: int): ) except Exception as e: - LOGGER.info("Volume: {} Failed Reading Data. Error: {}".format(idx, e)) + LOGGER.info(f"Volume: {idx} Failed Reading Data. Error: {e}") continue for key, data_dict in data_per_size.items(): @@ -388,9 +376,7 @@ def create_hdf5_dataset(self, blt: int): end_d = time.time() - start_d LOGGER.info( - "Successfully written {} in {:.3f} seconds.".format( - self.dataset_name, end_d - ) + f"Successfully written {self.dataset_name} in {end_d:.3f} seconds." ) diff --git a/FastSurferCNN/inference.py b/FastSurferCNN/inference.py index f99fea0a..9ba0306a 100644 --- a/FastSurferCNN/inference.py +++ b/FastSurferCNN/inference.py @@ -16,7 +16,7 @@ # IMPORTS import time -from typing import Dict, Optional, Tuple, Union +from typing import Optional import numpy as np import torch @@ -75,8 +75,8 @@ class Inference: Run the loaded model """ - permute_order: Dict[str, Tuple[int, int, int, int]] - device: Optional[torch.device] + permute_order: dict[str, tuple[int, int, int, int]] + device: torch.device | None default_device: torch.device def __init__( @@ -84,7 +84,7 @@ def __init__( cfg: yacs.config.CfgNode, device: torch.device, ckpt: str = "", - lut: Union[None, str, np.ndarray, DataFrame] = None, + lut: None | str | np.ndarray | DataFrame = None, ): """ Construct Inference object. @@ -171,7 +171,7 @@ def set_cfg(self, cfg: yacs.config.CfgNode): """ self.cfg = cfg - def to(self, device: Optional[torch.device] = None): + def to(self, device: torch.device | None = None): """ Move and/or cast the parameters and buffers. @@ -188,7 +188,7 @@ def to(self, device: Optional[torch.device] = None): self.device = _device self.model.to(device=_device) - def load_checkpoint(self, ckpt: Union[str, os.PathLike]): + def load_checkpoint(self, ckpt: str | os.PathLike): """ Load the checkpoint and set device and model. @@ -197,7 +197,7 @@ def load_checkpoint(self, ckpt: Union[str, os.PathLike]): ckpt : Union[str, os.PathLike] String or os.PathLike object containing the name to the checkpoint file. 
""" - logger.info("Loading checkpoint {}".format(ckpt)) + logger.info(f"Loading checkpoint {ckpt}") self.model = self._model_not_init # If device is None, the model has never been loaded (still in random initial configuration) @@ -323,7 +323,7 @@ def eval( val_loader: DataLoader, *, out_scale: Optional = None, - out: Optional[torch.Tensor] = None, + out: torch.Tensor | None = None, ) -> torch.Tensor: """Perform prediction and inplace-aggregate views into pred_prob. @@ -363,11 +363,13 @@ def eval( if out is None: out = init_pred.detach().clone() + log_batch_idx = None with logging_redirect_tqdm(): try: for batch_idx, batch in tqdm( enumerate(val_loader), total=len(val_loader), unit="batch" ): + log_batch_idx = batch_idx # move data to the model device images, scale_factors = batch["image"].to(self.device), batch[ "scale_factor" @@ -399,14 +401,12 @@ def eval( except: logger.exception( - "Exception in batch {} of {} inference.".format(batch_idx, plane) + f"Exception in batch {log_batch_idx} of {plane} inference." ) raise else: logger.info( - "Inference on {} batches for {} successful".format( - batch_idx + 1, plane - ) + f"Inference on {batch_idx + 1} batches for {plane} successful" ) return out @@ -418,12 +418,13 @@ def run( img_filename: str, orig_data: npt.NDArray, orig_zoom: npt.NDArray, - out: Optional[torch.Tensor] = None, - out_res: Optional[int] = None, + out: torch.Tensor | None = None, + out_res: int | None = None, batch_size: int = None, ) -> torch.Tensor: """ - Run the loaded model on the data (T1) from orig_data and img_filename (for messages only) with scale factors orig_zoom. + Run the loaded model on the data (T1) from orig_data and + img_filename (for messages only) with scale factors orig_zoom. Parameters ---------- diff --git a/FastSurferCNN/models/interpolation_layer.py b/FastSurferCNN/models/interpolation_layer.py index ccafd4da..dbe37d6f 100644 --- a/FastSurferCNN/models/interpolation_layer.py +++ b/FastSurferCNN/models/interpolation_layer.py @@ -25,7 +25,7 @@ LOGGER = _getLogger(__name__) -T_Scale = _T.TypeVar("T_Scale", _T.List[float], Tensor) +T_Scale = _T.TypeVar("T_Scale", list[float], Tensor) T_ScaleAll = _T.TypeVar("T_ScaleAll", _T.Sequence[float], Tensor, np.ndarray, float) @@ -57,7 +57,7 @@ class _ZoomNd(nn.Module): def __init__( self, - target_shape: _T.Optional[_T.Sequence[int]], + target_shape: _T.Sequence[int] | None, interpolation_mode: str = "nearest", ): """ @@ -72,22 +72,22 @@ def __init__( Interpolation mode as in `torch.nn.interpolate` (default: 'neareast'). """ - super(_ZoomNd, self).__init__() + super().__init__() self._mode = interpolation_mode if not hasattr(self, "_N"): self._N = -1 - self._target_shape: _T.Tuple[int, ...] = tuple() + self._target_shape: tuple[int, ...] = tuple() self.target_shape = target_shape @property - def target_shape(self) -> _T.Tuple[int, ...]: + def target_shape(self) -> tuple[int, ...]: """ Return the target shape. """ return self._target_shape @target_shape.setter - def target_shape(self, target_shape: _T.Optional[_T.Sequence[int]]) -> None: + def target_shape(self, target_shape: _T.Sequence[int] | None) -> None: """ Validate and set the target_shape. """ @@ -112,7 +112,7 @@ def target_shape(self, target_shape: _T.Optional[_T.Sequence[int]]) -> None: def forward( self, input_tensor: Tensor, scale_factors: T_ScaleAll, rescale: bool = False - ) -> _T.Tuple[Tensor, _T.List[T_Scale]]: + ) -> tuple[Tensor, list[T_Scale]]: """ Zoom the `input_tensor` with `scale_factors`. 
@@ -130,8 +130,8 @@ def forward( or a (cascaded) sequence of floats or ints) or a float. If it is a float, all axis and all images of the batch are treated the same (zoomed by the float). Else, it will be interpreted as a multidimensional image: The first dimension corresponds to and must be equal to the batch size of the image. The second - dimension is optional and may contain different values for the _scale_limits factor per axis. In consequence, - this dimension can have 1 or {dim} values. + dimension is optional and may contain different values for the _scale_limits factor per axis. + In consequence, this dimension can have 1 or {dim} values. rescale : bool, default="False" (Default value = False). @@ -161,7 +161,7 @@ def forward( ) scales_chunks = list( - zip(*self._fix_scale_factors(scale_factors, input_tensor.shape[0])) + zip(*self._fix_scale_factors(scale_factors, input_tensor.shape[0]), strict=False) ) if len(scales_chunks) == 0: raise ValueError( @@ -173,7 +173,7 @@ def forward( # Pytorch Tensor shape BxCxHxW --> loop over batches, interpolate single images, concatenate output at end for tensor, scale_f, num in zip( - torch.split(input_tensor, chunks, dim=0), scales, chunks + torch.split(input_tensor, chunks, dim=0), scales, chunks, strict=False ): if rescale: if isinstance(scale_f, list): @@ -188,7 +188,7 @@ def forward( def _fix_scale_factors( self, scale_factors: T_ScaleAll, batch_size: int - ) -> _T.Iterable[_T.Tuple[T_Scale, int]]: + ) -> _T.Iterable[tuple[T_Scale, int]]: """ Check and fix the conformity of scale_factors. @@ -209,7 +209,7 @@ def _fix_scale_factors( ValueError Scale_factors is neither a _T.Iterable nor a Number. """ - if isinstance(scale_factors, (Tensor, np.ndarray)): + if isinstance(scale_factors, Tensor | np.ndarray): batch_size_sf = scale_factors.shape[0] elif isinstance(scale_factors, _T.Iterable): scale_factors = list(scale_factors) @@ -223,13 +223,13 @@ def _fix_scale_factors( "scale_factors is a Sequence, but not the same length as the batch-size." ) num = 0 - last_sf: _T.Optional[T_Scale] = None + last_sf: T_Scale | None = None # Loop over batches for i, sf in enumerate(scale_factors): if isinstance(sf, Number): sf = [sf] * self._N else: - if isinstance(sf, (np.ndarray, Tensor)): + if isinstance(sf, np.ndarray | Tensor): if isinstance(sf, Tensor) and sf.dim() == 0: sf_dim = 1 sf = [sf] * self._N @@ -252,7 +252,7 @@ def _fix_scale_factors( f"scale factors, but only 1 or {self._N} are valid: {sf}." ) - if last_sf is not None and any(l != t for l, t in zip(last_sf, sf)): + if last_sf is not None and any(ln != t for ln, t in zip(last_sf, sf, strict=False)): yield last_sf, num # reset the counter num = 0 @@ -267,7 +267,7 @@ def _fix_scale_factors( "scale_factors is not the correct type, must be sequence of floats or float." ) - def _interpolate(self, *args) -> _T.Tuple[Tensor, T_Scale]: + def _interpolate(self, *args) -> tuple[Tensor, T_Scale]: """ Abstract method. @@ -284,7 +284,7 @@ def _calculate_crop_pad( scale_factor: T_Scale, dim: int, alignment: str, - ) -> _T.Tuple[slice, T_Scale, _T.Tuple[int, int], int]: + ) -> tuple[slice, T_Scale, tuple[int, int], int]: """ Return start- and end- coordinate given sizes, the updated scale factor. 
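The forward and _fix_scale_factors hunks above add strict=False to every zip call flagged by ruff B905; a standalone sketch of what the flag controls (Python 3.10+):

# zip() silently truncates to the shortest input; Python 3.10 added `strict`
# so length mismatches can raise instead of being ignored.
assert list(zip([1, 2, 3], "ab")) == [(1, "a"), (2, "b")]
try:
    list(zip([1, 2, 3], "ab", strict=True))
except ValueError:
    print("length mismatch detected")
# strict=False keeps the old truncating behaviour, but states it explicitly.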
@@ -380,7 +380,7 @@ class Zoom2d(_ZoomNd): def __init__( self, - target_shape: _T.Optional[_T.Sequence[int]], + target_shape: _T.Sequence[int] | None, interpolation_mode: str = "nearest", crop_position: str = "top_left", ): @@ -401,7 +401,7 @@ def __init__( raise ValueError(f"invalid interpolation_mode, got {interpolation_mode}") self._N = 2 - super(Zoom2d, self).__init__(target_shape, interpolation_mode) + super().__init__(target_shape, interpolation_mode) self.crop_position = crop_position @property @@ -435,8 +435,8 @@ def crop_position(self, crop_position: str) -> None: def _interpolate( self, data: Tensor, - scale_factor: _T.Union[Tensor, np.ndarray, _T.Sequence[float]], - ) -> _T.Tuple[Tensor, T_Scale]: + scale_factor: Tensor | np.ndarray | _T.Sequence[float], + ) -> tuple[Tensor, T_Scale]: """ Crop, interpolate and pad the tensor according to the scale_factor. @@ -512,7 +512,7 @@ class Zoom3d(_ZoomNd): def __init__( self, - target_shape: _T.Optional[_T.Sequence[int]], + target_shape: _T.Sequence[int] | None, interpolation_mode: str = "nearest", crop_position: str = "front_top_left", ): @@ -536,7 +536,7 @@ def __init__( raise ValueError(f"invalid interpolation_mode, got {interpolation_mode}") self._N = 3 - super(Zoom3d, self).__init__(target_shape, interpolation_mode) + super().__init__(target_shape, interpolation_mode) self.crop_position = crop_position @property @@ -573,7 +573,7 @@ def crop_position(self, crop_position: str) -> None: self._crop_position = crop_position def _interpolate( - self, data: Tensor, scale_factor: _T.Union[Tensor, np.ndarray, _T.Sequence[int]] + self, data: Tensor, scale_factor: Tensor | np.ndarray | _T.Sequence[int] ): """ Crop, interpolate and pad the tensor according to the scale_factor. diff --git a/FastSurferCNN/models/losses.py b/FastSurferCNN/models/losses.py index 7b25820d..5e8ba9d4 100644 --- a/FastSurferCNN/models/losses.py +++ b/FastSurferCNN/models/losses.py @@ -15,14 +15,14 @@ # IMPORTS +from numbers import Real + import torch import yacs.config - from torch import Tensor, nn from torch.nn import functional as F from torch.nn.modules.loss import _Loss -from numbers import Real -from typing import Optional, Tuple, Union + class DiceLoss(_Loss): """ @@ -38,8 +38,8 @@ def forward( self, output: Tensor, target: Tensor, - weights: Optional[int] = None, - ignore_index: Optional[int] = None, + weights: int | None = None, + ignore_index: int | None = None, ) -> torch.Tensor: """ Calulate the DiceLoss. @@ -108,7 +108,7 @@ class CrossEntropy2D(nn.Module): Returns calculated cross entropy. """ - def __init__(self, weight: Optional[Tensor] = None, reduction: str = "none"): + def __init__(self, weight: Tensor | None = None, reduction: str = "none"): """ Construct CrossEntropy2D object. @@ -119,7 +119,7 @@ def __init__(self, weight: Optional[Tensor] = None, reduction: str = "none"): reduction : str Specifies the reduction to apply to the output, as in nn.CrossEntropyLoss. Defaults to 'None'. """ - super(CrossEntropy2D, self).__init__() + super().__init__() self.nll_loss = nn.CrossEntropyLoss(weight=weight, reduction=reduction) print( f"Initialized {self.__class__.__name__} with weight: {weight} and reduction: {reduction}" @@ -159,7 +159,7 @@ def __init__(self, weight_dice: Real = 1, weight_ce: Real = 1): weight_ce : Real Weight for cross entropy loss. Defaults to 1. 
""" - super(CombinedLoss, self).__init__() + super().__init__() self.cross_entropy_loss = CrossEntropy2D() self.dice_loss = DiceLoss() self.weight_dice = weight_dice @@ -167,7 +167,7 @@ def __init__(self, weight_dice: Real = 1, weight_ce: Real = 1): def forward( self, inputx: Tensor, target: Tensor, weight: Tensor - ) -> Tuple[Tensor, Tensor, Tensor]: + ) -> tuple[Tensor, Tensor, Tensor]: """ Calculate the total loss, dice loss and cross entropy value for the given input. @@ -210,7 +210,7 @@ def forward( def get_loss_func( cfg: yacs.config.CfgNode, -) -> Union[CombinedLoss, CrossEntropy2D, DiceLoss]: +) -> CombinedLoss | CrossEntropy2D | DiceLoss: """ Give a default object of the loss function. diff --git a/FastSurferCNN/models/networks.py b/FastSurferCNN/models/networks.py index 0c83b79a..2ef064ec 100644 --- a/FastSurferCNN/models/networks.py +++ b/FastSurferCNN/models/networks.py @@ -13,10 +13,11 @@ # limitations under the License. # IMPORTS -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING import numpy as np from torch import Tensor, nn + if TYPE_CHECKING: import yacs.config @@ -63,7 +64,7 @@ def __init__(self, params: dict, padded_size: int = 256): padded_size : int, default = 256 Size of image when padded (Default value = 256). """ - super(FastSurferCNNBase, self).__init__() + super().__init__() # Parameters for the Descending Arm self.encode1 = sm.CompetitiveEncoderBlockInput(params) @@ -94,8 +95,8 @@ def __init__(self, params: dict, padded_size: int = 256): def forward( self, x: Tensor, - scale_factor: Optional[Tensor] = None, - scale_factor_out: Optional[Tensor] = None, + scale_factor: Tensor | None = None, + scale_factor_out: Tensor | None = None, ) -> Tensor: """ Feedforward through graph. @@ -167,7 +168,7 @@ def __init__(self, params: dict, padded_size: int): padded_size : int Size of image when padded. """ - super(FastSurferCNN, self).__init__(params) + super().__init__(params) params["num_channels"] = params["num_filters"] self.classifier = sm.ClassifierBlock(params) @@ -184,8 +185,8 @@ def __init__(self, params: dict, padded_size: int): def forward( self, x: Tensor, - scale_factor: Optional[Tensor] = None, - scale_factor_out: Optional[Tensor] = None, + scale_factor: Tensor | None = None, + scale_factor_out: Tensor | None = None, ) -> Tensor: """ Feedforward through graph. @@ -264,7 +265,7 @@ def __init__(self, params: dict, padded_size: int = 256): """ num_c = params["num_channels"] params["num_channels"] = params["num_filters_interpol"] - super(FastSurferVINN, self).__init__(params) + super().__init__(params) # Flex options self.height = params["height"] @@ -329,7 +330,7 @@ def __init__(self, params: dict, padded_size: int = 256): nn.init.constant_(m.bias, 0) def forward( - self, x: Tensor, scale_factor: Tensor, scale_factor_out: Optional[Tensor] = None + self, x: Tensor, scale_factor: Tensor, scale_factor_out: Tensor | None = None ) -> Tensor: """ Feedforward through graph. diff --git a/FastSurferCNN/models/optimizer.py b/FastSurferCNN/models/optimizer.py index 0afe8c17..6cb049e9 100644 --- a/FastSurferCNN/models/optimizer.py +++ b/FastSurferCNN/models/optimizer.py @@ -13,7 +13,6 @@ # limitations under the License. 
# IMPORTS -from typing import Union import torch import yacs diff --git a/FastSurferCNN/models/sub_module.py b/FastSurferCNN/models/sub_module.py index 67c687d9..5814b1e0 100644 --- a/FastSurferCNN/models/sub_module.py +++ b/FastSurferCNN/models/sub_module.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Tuple # IMPORTS import torch @@ -40,7 +39,7 @@ class InputDenseBlock(nn.Module): Feedforward through graph. """ - def __init__(self, params: Dict): + def __init__(self, params: dict): """ Construct InputDenseBlock object. @@ -49,7 +48,7 @@ def __init__(self, params: Dict): params : Dict Parameters in dictionary format. """ - super(InputDenseBlock, self).__init__() + super().__init__() # Padding to get output tensor of same dimensions padding_h = int((params["kernel_h"] - 1) / 2) padding_w = int((params["kernel_w"] - 1) / 2) @@ -177,7 +176,7 @@ class CompetitiveDenseBlock(nn.Module): Feedforward through graph. """ - def __init__(self, params: Dict, outblock: bool = False): + def __init__(self, params: dict, outblock: bool = False): """ Construct CompetitiveDenseBlock object. @@ -188,7 +187,7 @@ def __init__(self, params: Dict, outblock: bool = False): outblock : bool Flag indicating if last block (Default value = False). """ - super(CompetitiveDenseBlock, self).__init__() + super().__init__() # Padding to get output tensor of same dimensions padding_h = int((params["kernel_h"] - 1) / 2) @@ -322,7 +321,7 @@ class CompetitiveDenseBlockInput(nn.Module): Feedforward through graph. """ - def __init__(self, params: Dict): + def __init__(self, params: dict): """ Construct CompetitiveDenseBlockInput object. @@ -331,7 +330,7 @@ def __init__(self, params: Dict): params : Dict Dictionary with parameters specifying block architecture. """ - super(CompetitiveDenseBlockInput, self).__init__() + super().__init__() # Padding to get output tensor of same dimensions padding_h = int((params["kernel_h"] - 1) / 2) @@ -496,7 +495,7 @@ class CompetitiveEncoderBlock(CompetitiveDenseBlock): Feed forward trough graph. """ - def __init__(self, params: Dict): + def __init__(self, params: dict): """ Construct CompetitiveEncoderBlock object. @@ -505,14 +504,14 @@ def __init__(self, params: Dict): params : Dict Parameters like number of channels, stride etc. """ - super(CompetitiveEncoderBlock, self).__init__(params) + super().__init__(params) self.maxpool = nn.MaxPool2d( kernel_size=params["pool"], stride=params["stride_pool"], return_indices=True, ) # For Unpooling later on with the indices - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]: """ Feed forward trough Encoder Block. @@ -533,7 +532,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: indicies : Tensor Maxpool indices. """ - out_block = super(CompetitiveEncoderBlock, self).forward( + out_block = super().forward( x ) # To be concatenated as Skip Connection out_encoder, indices = self.maxpool( @@ -547,7 +546,7 @@ class CompetitiveEncoderBlockInput(CompetitiveDenseBlockInput): Encoder Block = CompetitiveDenseBlockInput + Max Pooling. """ - def __init__(self, params: Dict): + def __init__(self, params: dict): """ Construct CompetitiveEncoderBlockInput object. @@ -556,7 +555,7 @@ def __init__(self, params: Dict): params : Dict Parameters like number of channels, stride etc. 
""" - super(CompetitiveEncoderBlockInput, self).__init__( + super().__init__( params ) # The init of CompetitiveDenseBlock takes in params self.maxpool = nn.MaxPool2d( @@ -565,7 +564,7 @@ def __init__(self, params: Dict): return_indices=True, ) # For Unpooling later on with the indices - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]: """ Feed forward trough Encoder Block. @@ -586,7 +585,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: Tensor The indices of the maxpool operation. """ - out_block = super(CompetitiveEncoderBlockInput, self).forward( + out_block = super().forward( x ) # To be concatenated as Skip Connection out_encoder, indices = self.maxpool( @@ -600,7 +599,7 @@ class CompetitiveDecoderBlock(CompetitiveDenseBlock): Decoder Block = (Unpooling + Skip Connection) --> Dense Block. """ - def __init__(self, params: Dict, outblock: bool = False): + def __init__(self, params: dict, outblock: bool = False): """ Construct CompetitiveDecoderBlock object. @@ -612,7 +611,7 @@ def __init__(self, params: Dict, outblock: bool = False): Flag, indicating if last block of network before classifier is created.(Default value = False) """ - super(CompetitiveDecoderBlock, self).__init__(params, outblock=outblock) + super().__init__(params, outblock=outblock) self.unpool = nn.MaxUnpool2d( kernel_size=params["pool"], stride=params["stride_pool"] ) @@ -641,7 +640,7 @@ def forward(self, x: Tensor, out_block: Tensor, indices: Tensor) -> Tensor: """ unpool = self.unpool(x, indices) concat_max = torch.maximum(unpool, out_block) - out_block = super(CompetitiveDecoderBlock, self).forward(concat_max) + out_block = super().forward(concat_max) return out_block @@ -674,7 +673,7 @@ def __init__(self, params: dict): params : dict Parameters like number of channels, stride etc. """ - super(OutputDenseBlock, self).__init__() + super().__init__() # Padding to get output tensor of same dimensions padding_h = int((params["kernel_h"] - 1) / 2) @@ -793,7 +792,7 @@ def __init__(self, params: dict): params : dict Parameters like number of channels, stride etc. """ - super(ClassifierBlock, self).__init__() + super().__init__() self.conv = nn.Conv2d( params["num_channels"], params["num_classes"], diff --git a/FastSurferCNN/mri_brainvol_stats.py b/FastSurferCNN/mri_brainvol_stats.py index e90b3081..6a37b7aa 100644 --- a/FastSurferCNN/mri_brainvol_stats.py +++ b/FastSurferCNN/mri_brainvol_stats.py @@ -20,8 +20,8 @@ from os import environ as env from pathlib import Path -from FastSurferCNN.segstats import HelpFormatter, main, VERSION from FastSurferCNN.mri_segstats import print_and_exit +from FastSurferCNN.segstats import VERSION, HelpFormatter, main DEFAULT_MEASURES_STRINGS = [ (False, "BrainSeg"), diff --git a/FastSurferCNN/mri_segstats.py b/FastSurferCNN/mri_segstats.py index bf25a71a..25a7bb16 100644 --- a/FastSurferCNN/mri_segstats.py +++ b/FastSurferCNN/mri_segstats.py @@ -17,16 +17,17 @@ # IMPORTS import argparse -from itertools import pairwise, chain +from collections.abc import Iterable, Sequence +from itertools import chain, pairwise from pathlib import Path -from typing import TypeVar, Sequence, Any, Iterable +from typing import Any, TypeVar from FastSurferCNN.segstats import ( - main, + VERSION, HelpFormatter, add_two_help_messages, - VERSION, empty, + main, ) _T = TypeVar("_T") @@ -79,7 +80,7 @@ def __init__( help: str | None = None, metavar: str | tuple[str, ...] 
| None = None, ) -> None: - super(_ExtendConstAction, self).__init__( + super().__init__( option_strings=option_strings, dest=dest, nargs=0, @@ -129,11 +130,11 @@ def add_etiv_measures(args: argparse.Namespace) -> None: measures = [m for m in measures if m[1] == ETIV_RATIO_KEY] for k, v in ETIV_RATIOS.items(): - for is_imported, m in measures: + for _is_imported, m in measures: if m == v or m.startswith(v + "("): measures.append((False, k)) continue - setattr(args, "measures", measures) + args.measures = measures def _update_what_to_import(args: argparse.Namespace) -> argparse.Namespace: """ @@ -175,7 +176,7 @@ def help_text(keys: Iterable[str]) -> Iterable[str]: def help_add_measures(message: str, keys: list[str]) -> str: if help_text: _keys = (k.split(' ')[0] for k in keys) - keys = [f"{k}: {text}" for k, text in zip(keys, help_text(_keys))] + keys = [f"{k}: {text}" for k, text in zip(keys, help_text(_keys), strict=False)] return "
- ".join([message] + list(keys)) add_two_help_messages(parser) @@ -540,6 +541,6 @@ def _extend_arg(name: str, flag: str = None): args = make_arguments().parse_args() parse_actions = getattr(args, "parse_actions", []) - for i, parse_action in sorted(parse_actions, key=lambda x: x[0], reverse=True): + for _i, parse_action in sorted(parse_actions, key=lambda x: x[0], reverse=True): parse_action(args) sys.exit(main(args)) diff --git a/FastSurferCNN/quick_qc.py b/FastSurferCNN/quick_qc.py index 313580ff..04a5179c 100644 --- a/FastSurferCNN/quick_qc.py +++ b/FastSurferCNN/quick_qc.py @@ -90,8 +90,8 @@ def check_volume(asegdkt_segfile:np.ndarray, voxvol: float, thres: float = 0.70) print("Checking total volume ...") mask = asegdkt_segfile > 0 total_vol = np.sum(mask) * voxvol / 1000000 - print("Voxel size in mm3: {}".format(voxvol)) - print("Total segmentation volume in liter: {}".format(np.round(total_vol, 2))) + print(f"Voxel size in mm3: {voxvol}") + print(f"Total segmentation volume in liter: {np.round(total_vol, 2)}") if total_vol < thres: return False @@ -101,7 +101,7 @@ def check_volume(asegdkt_segfile:np.ndarray, voxvol: float, thres: float = 0.70) def get_region_bg_intersection_mask( seg_array, region_labels=VENT_LABELS, bg_label=BG_LABEL ): - f""" + """ Return a mask of the intersection between the voxels of a given region and background voxels. This is obtained by dilating the region by 1 voxel and computing the intersection with the @@ -115,8 +115,7 @@ def get_region_bg_intersection_mask( Segmentation array. region_labels : dict, default= Dictionary whose values correspond to the desired region's labels (see Note). - bg_label : int, default={BG_LABEL} - (Default value = {BG_LABEL}). + bg_label : int, default as in BG_LABEL. Returns ------- @@ -185,9 +184,8 @@ def get_ventricle_bg_intersection_volume(seg_array, voxvol): # Ventricle-BG intersection volume check: print("Estimating ventricle-background intersection volume...") print( - "Ventricle-background intersection volume in mm3: {:.2f}".format( - get_ventricle_bg_intersection_volume(inseg_data, inseg_voxvol) - ) + f"Ventricle-background intersection volume in mm3:" \ + f" {get_ventricle_bg_intersection_volume(inseg_data, inseg_voxvol):.2f}" ) # Total volume check: diff --git a/FastSurferCNN/reduce_to_aseg.py b/FastSurferCNN/reduce_to_aseg.py index c7c3cc69..a9666021 100644 --- a/FastSurferCNN/reduce_to_aseg.py +++ b/FastSurferCNN/reduce_to_aseg.py @@ -153,7 +153,7 @@ def create_mask(aseg_data, dnum, enum): # extract largest component labels = label(datab) assert labels.max() != 0 # assume at least 1 real connected component - print(" Found {} connected component(s)!".format(labels.max())) + print(f" Found {labels.max()} connected component(s)!") if labels.max() > 1: print(" Selecting largest component!") @@ -221,7 +221,7 @@ def flip_wm_islands(aseg_data : np.ndarray) -> np.ndarray: flip_data = aseg_data.copy() flip_data[rhswap] = lh_wm flip_data[lhswap] = rh_wm - print("FlipWM: rh {} and lh {} flipped.".format(rhswap.sum(), lhswap.sum())) + print(f"FlipWM: rh {rhswap.sum()} and lh {lhswap.sum()} flipped.") return flip_data @@ -230,7 +230,7 @@ def flip_wm_islands(aseg_data : np.ndarray) -> np.ndarray: # Command Line options are error checking done here options = options_parse() - print("Reading in aparc+aseg: {} ...".format(options.input_seg)) + print(f"Reading in aparc+aseg: {options.input_seg} ...") inseg = nib.load(options.input_seg) inseg_data = np.asanyarray(inseg.dataobj) inseg_header = inseg.header @@ -242,7 +242,7 @@ def 
flip_wm_islands(aseg_data : np.ndarray) -> np.ndarray: # get mask if options.output_mask: bm = create_mask(copy.deepcopy(inseg_data), 5, 4) - print("Outputting mask: {}".format(options.output_mask)) + print(f"Outputting mask: {options.output_mask}") mask = nib.MGHImage(bm, inseg_affine, inseg_header) mask.to_filename(options.output_mask) @@ -256,7 +256,7 @@ def flip_wm_islands(aseg_data : np.ndarray) -> np.ndarray: if options.fix_wm: aseg = flip_wm_islands(aseg) - print("Outputting aseg: {}".format(options.output_seg)) + print(f"Outputting aseg: {options.output_seg}") aseg_fin = nib.MGHImage(aseg, inseg_affine, inseg_header) aseg_fin.to_filename(options.output_seg) diff --git a/FastSurferCNN/run_model.py b/FastSurferCNN/run_model.py index 4d7dcc89..add33bca 100644 --- a/FastSurferCNN/run_model.py +++ b/FastSurferCNN/run_model.py @@ -76,9 +76,9 @@ def main(args): if cfg.TRAIN.RESUME and cfg.TRAIN.RESUME_EXPR_NUM != "Default": cfg.EXPR_NUM = cfg.TRAIN.RESUME_EXPR_NUM - cfg.SUMMARY_PATH = misc.check_path(join(summary_path, "{}".format(cfg.EXPR_NUM))) + cfg.SUMMARY_PATH = misc.check_path(join(summary_path, f"{cfg.EXPR_NUM}")) cfg.CONFIG_LOG_PATH = misc.check_path( - join(cfg.LOG_DIR, "config", "{}".format(cfg.EXPR_NUM)) + join(cfg.LOG_DIR, "config", f"{cfg.EXPR_NUM}") ) with open(join(cfg.CONFIG_LOG_PATH, "config.yaml"), "w") as json_file: diff --git a/FastSurferCNN/run_prediction.py b/FastSurferCNN/run_prediction.py index e6799ec6..6af3535f 100644 --- a/FastSurferCNN/run_prediction.py +++ b/FastSurferCNN/run_prediction.py @@ -29,9 +29,10 @@ import argparse import copy import sys -from concurrent.futures import Executor, ThreadPoolExecutor, Future +from collections.abc import Iterator, Sequence +from concurrent.futures import Executor, Future, ThreadPoolExecutor from pathlib import Path -from typing import Any, Iterator, Literal, Optional, Sequence +from typing import Any, Literal import nibabel as nib import numpy as np @@ -42,29 +43,29 @@ from FastSurferCNN.data_loader import conform as conf from FastSurferCNN.data_loader import data_utils as du from FastSurferCNN.inference import Inference -from FastSurferCNN.utils import logging, parser_defaults, Plane, PLANES +from FastSurferCNN.quick_qc import check_volume +from FastSurferCNN.utils import PLANES, Plane, logging, parser_defaults from FastSurferCNN.utils.arg_types import VoxSizeOption from FastSurferCNN.utils.checkpoint import ( get_checkpoints, load_checkpoint_config_defaults, ) -from FastSurferCNN.utils.load_config import load_config from FastSurferCNN.utils.common import ( SerialExecutor, - find_device, + SubjectDirectory, + SubjectList, assert_no_root, + find_device, handle_cuda_memory_exception, - SubjectList, - SubjectDirectory, pipeline, ) -from FastSurferCNN.utils.parser_defaults import SubjectDirectoryConfig -from FastSurferCNN.quick_qc import check_volume +from FastSurferCNN.utils.load_config import load_config ## # Global Variables ## -from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT, SubjectDirectoryConfig + LOGGER = logging.getLogger(__name__) CHECKPOINT_PATHS_FILE = FASTSURFER_ROOT / "FastSurferCNN/config/checkpoint_paths.yaml" @@ -103,9 +104,9 @@ def set_up_cfgs( def args2cfg( - cfg_ax: Optional[str] = None, - cfg_cor: Optional[str] = None, - cfg_sag: Optional[str] = None, + cfg_ax: str | None = None, + cfg_cor: str | None = None, + cfg_sag: str | None = None, batch_size: int = 1, ) -> tuple[ yacs.config.CfgNode, yacs.config.CfgNode, 
yacs.config.CfgNode, yacs.config.CfgNode @@ -139,8 +140,8 @@ def args2cfg( # returns the first non-None cfg try: cfg_fin = next(filter(None, cfgs)) - except StopIteration: - raise RuntimeError("No valid configuration passed!") + except StopIteration as err: + raise RuntimeError("No valid configuration passed!") from err return (cfg_fin,) + cfgs @@ -192,7 +193,7 @@ class RunModelOnData: current_plane: Plane models: dict[Plane, Inference] view_ops: dict[Plane, dict[str, Any]] - conform_to_1mm_threshold: Optional[float] + conform_to_1mm_threshold: float | None device: torch.device viewagg_device: torch.device _pool: Executor @@ -200,12 +201,12 @@ class RunModelOnData: def __init__( self, lut: Path, - ckpt_ax: Optional[Path] = None, - ckpt_sag: Optional[Path] = None, - ckpt_cor: Optional[Path] = None, - cfg_ax: Optional[Path] = None, - cfg_sag: Optional[Path] = None, - cfg_cor: Optional[Path] = None, + ckpt_ax: Path | None = None, + ckpt_sag: Path | None = None, + ckpt_cor: Path | None = None, + cfg_ax: Path | None = None, + cfg_sag: Path | None = None, + cfg_cor: Path | None = None, device: str = "auto", viewagg_device: str = "auto", threads: int = 1, @@ -246,11 +247,11 @@ def __init__( try: self.lut = du.read_classes_from_lut(lut) - except FileNotFoundError: + except FileNotFoundError as err: raise ValueError( f"Could not find the ColorLUT in {lut}, please make sure the " f"--lut argument is valid." - ) + ) from err self.labels = self.lut["ID"].values self.torch_labels = torch.from_numpy(self.lut["ID"].values) self.names = ["SubjectName", "Average", "Subcortical", "Cortical"] @@ -418,7 +419,7 @@ def save_img( save_as: str | Path, data: np.ndarray | torch.Tensor, orig: nib.analyze.SpatialImage, - dtype: Optional[type] = None, + dtype: type | None = None, ) -> None: """ Save image as a file. 
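The args2cfg and RunModelOnData hunks above re-raise with `from err` (ruff B904); a reduced sketch of the chaining pattern with a hypothetical helper:

# `from err` records the original exception as __cause__ instead of the
# implicit "During handling of the above exception..." context.
def first_valid(cfgs):
    try:
        return next(filter(None, cfgs))
    except StopIteration as err:
        raise RuntimeError("No valid configuration passed!") from err

# first_valid([None, None]) raises RuntimeError whose __cause__ is the StopIteration.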
@@ -532,8 +533,7 @@ def pipeline_conform_and_save_orig( yield subject, self.conform_and_save_orig(subject) else: # pipeline the same - for data in pipeline(self.pool, self.conform_and_save_orig, subjects): - yield data + yield from pipeline(self.pool, self.conform_and_save_orig, subjects) def make_parser(): @@ -614,11 +614,11 @@ def main( log_name: str = "", allow_root: bool = False, conf_name: str = "mri/orig.mgz", - in_dir: Optional[Path] = None, - sid: Optional[str] = None, - search_tag: Optional[str] = None, - csv_file: Optional[str | Path] = None, - lut: Optional[Path | str] = None, + in_dir: Path | None = None, + sid: str | None = None, + search_tag: str | None = None, + csv_file: str | Path | None = None, + lut: Path | str | None = None, remove_suffix: str = "", brainmask_name: str = "mri/mask.mgz", aseg_name: str = "mri/aseg.auto_noCC.mgz", diff --git a/FastSurferCNN/segstats.py b/FastSurferCNN/segstats.py index 0be39ef2..98319cb7 100644 --- a/FastSurferCNN/segstats.py +++ b/FastSurferCNN/segstats.py @@ -18,30 +18,23 @@ # IMPORTS import argparse import logging +from collections.abc import Callable, Container, Iterable, Iterator, Sequence, Sized +from concurrent.futures import Executor, ThreadPoolExecutor from functools import partial, reduce from itertools import product from numbers import Number from pathlib import Path from typing import ( - Any, - Callable, - cast, - Iterable, IO, + Any, Literal, - Optional, - overload, - Sequence, - Sized, - Type, TypedDict, TypeVar, - Container, - Iterator, + cast, + overload, ) -from concurrent.futures import Executor, ThreadPoolExecutor - +import nibabel as nib import numpy as np import pandas as pd from numpy import typing as npt @@ -49,6 +42,7 @@ from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as robust_threshold from FastSurferCNN.utils.arg_types import int_ge_zero as id_type from FastSurferCNN.utils.arg_types import int_gt_zero as patch_size_type +from FastSurferCNN.utils.brainvolstats import Manager from FastSurferCNN.utils.parser_defaults import add_arguments from FastSurferCNN.utils.threads import get_num_threads @@ -90,9 +84,9 @@ SlicingTuple = tuple[slice, ...] 
SlicingSequence = Sequence[slice] VirtualLabel = dict[int, Sequence[int]] -_GlobalStats = tuple[int, int, Optional[_NumberType], Optional[_NumberType], - Optional[float], Optional[float], float, npt.NDArray[bool]] -SubparserCallback = Type[argparse.ArgumentParser.add_subparsers] +_GlobalStats = tuple[int, int, _NumberType | None, _NumberType | None, + float | None, float | None, float, npt.NDArray[bool]] +SubparserCallback = type[argparse.ArgumentParser.add_subparsers] class _RequiredPVStats(TypedDict): @@ -154,15 +148,15 @@ def _fill_text(self, text: str, width: int, indent: str) -> str: """ cond_len, texts = self._itemized_lines(text) lines = (super(HelpFormatter, self)._fill_text(t[p:], width, indent + " " * p) - for t, (c, p) in zip(texts, cond_len)) - return "\n".join("- " + t[p:] if c else t for t, (c, p) in zip(lines, cond_len)) + for t, (c, p) in zip(texts, cond_len, strict=False)) + return "\n".join("- " + t[p:] if c else t for t, (c, p) in zip(lines, cond_len, strict=False)) def _itemized_lines(self, text): texts = text.split(self._linebreak_sub()) item = self._item_symbol() il = len(item) cond_len = [(c, il if c else 0) for c in map(lambda t: t[:il] == item, texts)] - texts = [t[p:] for t, (c, p) in zip(texts, cond_len)] + texts = [t[p:] for t, (c, p) in zip(texts, cond_len, strict=False)] return cond_len, texts def _split_lines(self, text: str, width: int) -> list[str]: @@ -182,13 +176,13 @@ def _split_lines(self, text: str, width: int) -> list[str]: The list of lines. """ def indent_list(items: list[str]) -> list[str]: - return ["- " + items[0]] + [" " + l for l in items[1:]] + return ["- " + items[0]] + [" " + ln for ln in items[1:]] cond_len, texts = self._itemized_lines(text) from itertools import chain lines = (super(HelpFormatter, self)._split_lines(tex, width - p) - for tex, (c, p) in zip(texts, cond_len)) - lines = ((indent_list(lst) if c[0] else lst) for lst, c in zip(lines, cond_len)) + for tex, (c, p) in zip(texts, cond_len, strict=False)) + lines = ((indent_list(lst) if c[0] else lst) for lst, c in zip(lines, cond_len, strict=False)) return list(chain.from_iterable(lines)) @@ -755,9 +749,10 @@ def main(args: argparse.Namespace) -> Literal[0] | str: Either as a successful return code or a string with an error message. 
""" from time import perf_counter_ns - from FastSurferCNN.utils.common import assert_no_root - from FastSurferCNN.utils.brainvolstats import Manager, read_volume_file, ImageTuple + from FastSurferCNN.data_loader.data_utils import read_classes_from_lut + from FastSurferCNN.utils.brainvolstats import ImageTuple, Manager, read_volume_file + from FastSurferCNN.utils.common import assert_no_root start = perf_counter_ns() getattr(args, "allow_root", False) or assert_no_root() @@ -783,8 +778,8 @@ def main(args: argparse.Namespace) -> Literal[0] | str: require_pvfile=not legacy_freesurfer, ) if legacy_freesurfer and not measure_only and pvfile is None: - return (f"No files are defined via -pv/--pvfile or -norm/--normfile: " - f"This is only supported for header only in legacy mode.") + return ("No files are defined via -pv/--pvfile or -norm/--normfile: " + "This is only supported for header only in legacy mode.") if measurefile: manager_kwargs["measurefile"] = measurefile except ValueError as e: @@ -837,10 +832,10 @@ def main(args: argparse.Namespace) -> Literal[0] | str: norm, norm_data = _norm check_shape_affine(seg, norm, "segmentation", "norm") - except (IOError, RuntimeError, FileNotFoundError) as e: + except (OSError, RuntimeError, FileNotFoundError) as e: return e.args[0] - lut: Optional[pd.DataFrame] = None + lut: pd.DataFrame | None = None if lut_file: try: lut = read_lut(lut_file) @@ -934,7 +929,7 @@ def main(args: argparse.Namespace) -> Literal[0] | str: if save_maps: table, maps = out dtypes = [np.int16] + [np.float32] * 4 - for name, dtype in zip(names, dtypes): + for name, dtype in zip(names, dtypes, strict=False): if not bool(file := getattr(args, name, "")) or file == Path(): # skip "fullview"-files that are not defined continue @@ -1035,7 +1030,7 @@ def infer_merged_labels( def table_to_dataframe( table: list[PVStats], report_empty: bool = True, - must_keep_ids: Optional[Container[int]] = None, + must_keep_ids: Container[int] | None = None, ) -> pd.DataFrame: """ Convert the list of PVStats dictionaries into a dataframe. @@ -1068,7 +1063,7 @@ def table_to_dataframe( def update_structnames( table: list[PVStats], lut: pd.DataFrame, - merged_labels: Optional[dict[_IntType, Sequence[_IntType]]] = None + merged_labels: dict[_IntType, Sequence[_IntType]] | None = None ) -> None: """ Update StructNames from `lut` and `merged_labels` in `table`. 
@@ -1115,11 +1110,11 @@ def write_statsfile( segstatsfile: Path | str, dataframe: pd.DataFrame, vox_vol: float, - exclude: Optional[Sequence[int | str]] = None, - segfile: Optional[Path | str] = None, - normfile: Optional[Path | str] = None, - pvfile: Optional[Path | str] = None, - lut: Optional[Path | str] = None, + exclude: Sequence[int | str] | None = None, + segfile: Path | str | None = None, + normfile: Path | str | None = None, + pvfile: Path | str | None = None, + lut: Path | str | None = None, report_empty: bool = False, extra_header: Sequence[str] = (), norm_name: str = "norm", @@ -1178,6 +1173,7 @@ def _system_info(file: IO) -> None: """ import os import sys + from FastSurferCNN.version import read_and_close_version file.write( "# generating_program segstats.py\n" @@ -1201,7 +1197,7 @@ def _system_info(file: IO) -> None: try: file.write(f"# user {getuser()}\n") except KeyError: - file.write(f"# user UNKNOWN\n") + file.write("# user UNKNOWN\n") def _extra_header(file: IO, lines_extra_header: Iterable[str]) -> None: """ @@ -1221,12 +1217,12 @@ def _extra_header(file: IO, lines_extra_header: Iterable[str]) -> None: warn_msg_sent or warn( f"extra_header[{i}] includes embedded newline characters. " - "Replacing all newline characters with ." + "Replacing all newline characters with .", stacklevel=2 ) warn_msg_sent = True file.write(f"# {line}\n") - def _file_annotation(file: IO, name: str, path_to_annotate: Optional[Path]) -> None: + def _file_annotation(file: IO, name: str, path_to_annotate: Path | None) -> None: """ Write the annotation to file/path to a file. """ @@ -1242,7 +1238,7 @@ def _extra_parameters( _voxvol: float, _exclude: Sequence[int | str], _report_empty: bool = False, - _lut: Optional[Path] = None, + _lut: Path | None = None, _leg_freesurfer: bool = False, ) -> None: """ @@ -1313,7 +1309,7 @@ def _table_body(file: IO, _dataframe: pd.DataFrame) -> None: """Write the volume stats from _dataframe to a file.""" columns = [col for col in COLUMNS if col in _dataframe.columns] fmt = " ".join(_column_format(k) for k in columns) - for index, row in _dataframe.iterrows(): + for _index, row in _dataframe.iterrows(): data = [row[k] for k in columns] file.write(fmt.format(*data) + "\n") @@ -1331,7 +1327,7 @@ def _table_body(file: IO, _dataframe: pd.DataFrame) -> None: with open(segstatsfile, "w") as fp: _title(fp) _system_info(fp) - fp.write(f"# anatomy_type volume\n#\n") + fp.write("# anatomy_type volume\n#\n") _extra_header(fp, extra_header) _file_annotation(fp, "SegVolFile", segfile) @@ -1392,7 +1388,7 @@ def preproc_image( def seg_borders( _array: _ArrayType, label: np.integer | bool, - out: Optional[npt.NDArray[bool]] = None, + out: npt.NDArray[bool] | None = None, cmp_dtype: npt.DTypeLike = "int8", ) -> npt.NDArray[bool]: """ @@ -1436,9 +1432,9 @@ def seg_borders( def borders( _array: _ArrayType, labels: Iterable[np.integer] | bool, - max_label: Optional[np.integer] = None, + max_label: np.integer | None = None, six_connected: bool = True, - out: Optional[npt.NDArray[bool]] = None, + out: npt.NDArray[bool] | None = None, ) -> npt.NDArray[bool]: """ Handle to fast border computation. 
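The warn call above gains an explicit stacklevel (ruff B028); a small sketch, with a made-up helper and message, of what the argument changes:

# With stacklevel=2 the warning is attributed to the caller of the helper,
# not to the line inside the helper that emits it.
import warnings

def check_extra_header(line):
    if "\n" in line:
        warnings.warn("extra_header includes embedded newline characters", stacklevel=2)

check_extra_header("first\nsecond")  # the warning is reported at this call site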
@@ -1496,7 +1492,6 @@ def cmp(a, b): labels = [0] + labels lookup[labels] = np.arange(len(labels), dtype=lookup.dtype) _array = lookup[_array] - logical_or = np.logical_or # pad array by 1 voxel of zeros all around padded = np.pad(_array, 1) @@ -1517,7 +1512,7 @@ def indexer(axis: int, is_mid: bool) -> tuple[SlicingTuple, SlicingTuple]: # ((False, True), (True, False), (False, False), (False, True)) for each dim # is_mid=False: padded values already dropped indexes = (indexer(i, is_mid=False) for i in range(dim)) - nbr_same = [(nbr_[i], nbr_[j]) for (i, j), nbr_ in zip(indexes, nbr_same)] + nbr_same = [(nbr_[i], nbr_[j]) for (i, j), nbr_ in zip(indexes, nbr_same, strict=False)] from itertools import chain nbr_same = list(chain.from_iterable(nbr_same)) else: @@ -1577,8 +1572,8 @@ def _slice(start_end: npt.NDArray[int]) -> slice: _start, _end = start_end return slice(_start.item(), None if _end.item() == 0 else _end.item()) # make grown patch and grown patch to patch - padded_slicer = tuple(slice(s.item(), e.item()) for s, e in zip(_start, _stop)) - unpadded_slicer = tuple(map(_slice, zip(start - _start, stop - _stop))) + padded_slicer = tuple(slice(s.item(), e.item()) for s, e in zip(_start, _stop, strict=False)) + unpadded_slicer = tuple(map(_slice, zip(start - _start, stop - _stop, strict=False))) return padded_slicer, unpadded_slicer @@ -1586,7 +1581,7 @@ def uniform_filter( data: _ArrayType, filter_size: int, fillval: float = 0., - slicer_patch: Optional[SlicingTuple] = None, + slicer_patch: SlicingTuple | None = None, ) -> _ArrayType: """ Apply a uniform filter (with kernel size `filter_size`) to `input`. @@ -1636,8 +1631,8 @@ def pv_calc( patch_size: int = 32, vox_vol: float = 1.0, eps: float = 1e-6, - robust_percentage: Optional[float] = None, - merged_labels: Optional[VirtualLabel] = None, + robust_percentage: float | None = None, + merged_labels: VirtualLabel | None = None, threads: int | Executor = -1, return_maps: False = False, legacy_freesurfer: bool = False, @@ -1654,8 +1649,8 @@ def pv_calc( patch_size: int = 32, vox_vol: float = 1.0, eps: float = 1e-6, - robust_percentage: Optional[float] = None, - merged_labels: Optional[VirtualLabel] = None, + robust_percentage: float | None = None, + merged_labels: VirtualLabel | None = None, threads: int | Executor = -1, return_maps: True = True, legacy_freesurfer: bool = False, @@ -1666,7 +1661,7 @@ def pv_calc( def pv_calc( seg: npt.NDArray[_IntType], pv_guide: np.ndarray, - norm: Optional[np.ndarray], + norm: np.ndarray | None, labels: npt.ArrayLike, patch_size: int = 32, vox_vol: float = 1.0, @@ -1753,8 +1748,8 @@ def pv_calc( f"are {seg.shape} and {norm.shape}!" 
) - mins, maxes, voxel_counts, robust_voxel_counts = [{} for _ in range(4)] - borders, sums, sums_2, volumes = [{} for _ in range(4)] + mins, maxes, voxel_counts, robust_voxel_counts = ({} for _ in range(4)) + borders, sums, sums_2, volumes = ({} for _ in range(4)) if isinstance(merged_labels, dict) and len(merged_labels) > 0: _more_labels = list(merged_labels.values()) @@ -1789,7 +1784,7 @@ def pv_calc( raise ValueError("Zero is not a valid number of threads.") elif isinstance(threads, int) and threads > 0: nthreads = threads - elif isinstance(threads, (Executor, int)): + elif isinstance(threads, Executor | int): nthreads: int = get_num_threads() else: raise TypeError("threads must be int or concurrent.futures.Executor object.") @@ -1821,7 +1816,7 @@ def pv_calc( # un_global_crop border here any_border = np.any(list(borders.values()), axis=0) pad_width = np.asarray( - [(slc.start, shp - slc.stop) for slc, shp in zip(global_crop, seg.shape)], + [(slc.start, shp - slc.stop) for slc, shp in zip(global_crop, seg.shape, strict=False)], dtype=int, ) any_border = np.pad(any_border, pad_width) @@ -1874,7 +1869,7 @@ def get_std(lab: _IntType, nvox: int) -> float: stds = {lab: get_std(lab, nvox) for lab, nvox in robust_vc_it if nvox > eps} - for lab, this in zip(labels, table): + for lab, this in zip(labels, table, strict=False): this.update( Mean=means.get(lab, 0.0), StdDev=stds.get(lab, 0.0), @@ -1904,10 +1899,10 @@ def calculate_merged_labels( voxel_counts: dict[_IntType, int], robust_voxel_counts: dict[_IntType, int], volumes: dict[_IntType, float], - mins: Optional[dict[_IntType, float]] = None, - maxes: Optional[dict[_IntType, float]] = None, - sums: Optional[dict[_IntType, float]] = None, - sums_of_squares: Optional[dict[_IntType, float]] = None, + mins: dict[_IntType, float] | None = None, + maxes: dict[_IntType, float] | None = None, + sums: dict[_IntType, float] | None = None, + sums_of_squares: dict[_IntType, float] | None = None, eps: float = 1e-6, ) -> Iterator[PVStats]: """ @@ -1945,18 +1940,18 @@ def num_robust_voxels(lab): def aggregate(source, merge_labels, f: Callable[..., np.ndarray] = np.sum): """aggregate labels `merge_labels` from `source` with function `f`""" - _data = [source.get(l, 0) for l in merge_labels if num_robust_voxels(l) > eps] + _data = [source.get(lb, 0) for lb in merge_labels if num_robust_voxels(lb) > eps] return f(_data).item() def aggregate_std(sums, sums2, merge_labels, nvox): """aggregate std of labels `merge_labels` from `source`""" - s2 = [(s := sums.get(l, 0)) * s / r for l in group - if (r := num_robust_voxels(l)) > eps] + s2 = [(s := sums.get(lb, 0)) * s / r for lb in group + if (r := num_robust_voxels(lb)) > eps] return np.sqrt((aggregate(sums2, merge_labels) - np.sum(s2)) / nvox).item() for lab, group in merged_labels.items(): stats = {"SegId": lab} - if all(l not in robust_voxel_counts for l in group): + if all(lb not in robust_voxel_counts for lb in group): logging.getLogger(__name__).warning( f"None of the labels {group} for merged label {lab} exist in the " f"segmentation." 
@@ -1994,8 +1989,8 @@ def global_stats( lab: _IntType, norm: npt.NDArray[_NumberType] | None, seg: npt.NDArray[_IntType], - out: Optional[npt.NDArray[bool]] = None, - robust_percentage: Optional[float] = None, + out: npt.NDArray[bool] | None = None, + robust_percentage: float | None = None, ) -> tuple[_IntType, _GlobalStats]: """ Compute Label, Number of voxels, 'robust' number of voxels, norm minimum, maximum, @@ -2025,7 +2020,7 @@ def global_stats( sum_of_intensity_squares, and border with respect to the label. """ - def __compute_borders(out: Optional[np.ndarray]) -> np.ndarray: + def __compute_borders(out: np.ndarray | None) -> np.ndarray: # compute/update the border if out is None: out = seg_borders(label_mask, True, cmp_dtype="int8").astype(bool) @@ -2099,14 +2094,14 @@ def _slice(patch_start, _patch_size, image_stop): return slice(patch_start, min(patch_start + _patch_size, image_stop)) # create slices for current patch context (constrained by the global_crop) - patch = [_slice(pc, patch_size, s.stop) for pc, s in zip(patch_corner, global_crop)] + patch = [_slice(pc, patch_size, s.stop) for pc, s in zip(patch_corner, global_crop, strict=False)] # crop patch context to the image content return crop_patch_to_mask(mask, sub_patch=patch) def crop_patch_to_mask( mask: npt.NDArray[_NumberType], - sub_patch: Optional[SlicingSequence] = None, + sub_patch: SlicingSequence | None = None, ) -> tuple[bool, SlicingSequence]: """ Crop the patch to regions of the mask that are non-zero. @@ -2164,7 +2159,7 @@ def _move_slice(the_slice: slice, offset: int) -> slice: return slice(the_slice.start + offset, the_slice.stop + offset) target_slicer = [_move_slice(ts, sc.start) for ts, sc in zip(_target_slicer, - slicer_context)] + slicer_context, strict=False)] return _target_slicer[0].start != _target_slicer[0].stop, target_slicer @@ -2175,11 +2170,11 @@ def pv_calc_patch( seg: npt.NDArray[_IntType], pv_guide: npt.NDArray, border: npt.NDArray[bool], - full_pv: Optional[npt.NDArray[float]] = None, - full_ipv: Optional[npt.NDArray[float]] = None, - full_nbr_label: Optional[npt.NDArray[_IntType]] = None, - full_seg_mean: Optional[npt.NDArray[float]] = None, - full_nbr_mean: Optional[npt.NDArray[float]] = None, + full_pv: npt.NDArray[float] | None = None, + full_ipv: npt.NDArray[float] | None = None, + full_nbr_label: npt.NDArray[_IntType] | None = None, + full_seg_mean: npt.NDArray[float] | None = None, + full_nbr_mean: npt.NDArray[float] | None = None, eps: float = 1e-6, legacy_freesurfer: bool = False, ) -> dict[_IntType, float]: @@ -2245,11 +2240,11 @@ def pv_calc_patch( slicer_large_to_small = tuple( slice(l2p.start - s2p.start, None if l2p.stop == s2p.stop else l2p.stop - s2p.stop) - for s2p, l2p in zip(slicer_small_to_patch, slicer_large_to_patch)) + for s2p, l2p in zip(slicer_small_to_patch, slicer_large_to_patch, strict=False)) patch_in_gc = tuple( slice(p.start - gc.start, p.stop - gc.start) - for p, gc in zip(slicer_patch, global_crop)) + for p, gc in zip(slicer_patch, global_crop, strict=False)) label_lookup = np.unique(seg[slicer_small_patch]) maxlabels = label_lookup[-1] + 1 @@ -2506,5 +2501,5 @@ def patch_neighbors( from os import environ as env if (sd := env.get("SUBJECTS_DIR")) is not None: - setattr(opts, "out_dir", sd) + opts.out_dir = sd sys.exit(main(opts)) diff --git a/FastSurferCNN/train.py b/FastSurferCNN/train.py index 8e30ebd4..2a2f43c8 100644 --- a/FastSurferCNN/train.py +++ b/FastSurferCNN/train.py @@ -18,7 +18,6 @@ import pprint import time from collections import defaultdict 
-from typing import Union import numpy as np import torch @@ -173,9 +172,7 @@ def train( train_meter.log_epoch(epoch) logger.info( - "Training epoch {} finished in {:.04f} seconds".format( - epoch, time.time() - epoch_start - ) + f"Training epoch {epoch} finished in {time.time() - epoch_start:.04f} seconds" ) @torch.no_grad() @@ -267,9 +264,7 @@ def eval( val_meter.log_epoch(epoch) logger.info( - "Validation epoch {} finished in {:.04f} seconds".format( - epoch, time.time() - val_start - ) + f"Validation epoch {epoch} finished in {time.time() - val_start:.04f} seconds" ) # Get final measures and log them @@ -282,21 +277,12 @@ def eval( # Log metrics logger.info( - "[Epoch {} stats]: SF: {}, MIoU: {:.4f}; " - "Mean Recall: {:.4f}; " - "Mean Precision: {:.4f}; " - "Avg loss total: {:.4f}; " - "Avg loss dice: {:.4f}; " - "Avg loss ce: {:.4f}".format( - epoch, - key, - np.mean(ious), - np.mean(accs[key] / per_cls_counts_gt[key]), - np.mean(accs[key] / per_cls_counts_pred[key]), - val_loss_total[key], - val_loss_dice[key], - val_loss_ce[key], - ) + f"[Epoch {epoch} stats]: SF: {key}, MIoU: {np.mean(ious):.4f}; " + f"Mean Recall: {np.mean(accs[key] / per_cls_counts_gt[key]):.4f}; " + f"Mean Precision: {np.mean(accs[key] / per_cls_counts_pred[key]):.4f}; " + f"Avg loss total: {val_loss_total[key]:.4f}; " + f"Avg loss dice: {val_loss_dice[key]:.4f}; " + f"Avg loss ce: {val_loss_ce[key]:.4f}" ) logger.info(self.a.format(*self.class_names)) @@ -344,7 +330,7 @@ def run(self): logger.info(f"Resume training from epoch {start_epoch}") except Exception as e: print( - "No model to restore. Resuming training from Epoch 0. {}".format(e) + f"No model to restore. Resuming training from Epoch 0. {e}" ) else: logger.info("Training from scratch") @@ -352,9 +338,7 @@ def run(self): best_miou = 0 logger.info( - "{} parameters in total".format( - sum(x.numel() for x in self.model.parameters()) - ) + f"{sum(x.numel() for x in self.model.parameters())} parameters in total" ) # Create tensorboard summary writer @@ -381,9 +365,9 @@ def run(self): writer=writer, ) - logger.info("Summary path {}".format(self.cfg.SUMMARY_PATH)) + logger.info(f"Summary path {self.cfg.SUMMARY_PATH}") # Perform the training loop. - logger.info("Start epoch: {}".format(start_epoch + 1)) + logger.info(f"Start epoch: {start_epoch + 1}") for epoch in range(start_epoch, self.cfg.TRAIN.NUM_EPOCHS): self.train(train_loader, optimizer, scheduler, train_meter, epoch=epoch) diff --git a/FastSurferCNN/utils/arg_types.py b/FastSurferCNN/utils/arg_types.py index faeb7fa2..2896c279 100644 --- a/FastSurferCNN/utils/arg_types.py +++ b/FastSurferCNN/utils/arg_types.py @@ -13,12 +13,12 @@ # limitations under the License. import argparse -from typing import Literal, Optional, Union +from typing import Literal import nibabel as nib import numpy as np -VoxSizeOption = Union[float, Literal["min"]] +VoxSizeOption = float | Literal["min"] def vox_size(a: str) -> VoxSizeOption: @@ -51,7 +51,7 @@ def vox_size(a: str) -> VoxSizeOption: ) from None -def float_gt_zero_and_le_one(a: str) -> Optional[float]: +def float_gt_zero_and_le_one(a: str) -> float | None: """ Check whether a parameters are a float between 0 and one. @@ -121,7 +121,7 @@ def target_dtype(a: str) -> str: raise argparse.ArgumentTypeError(f"Invalid dtype {a}. {msg}") -def int_gt_zero(value: Union[str, int]) -> int: +def int_gt_zero(value: str | int) -> int: """ Convert to positive integers. 
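The arg_types.py hunks above, like most of this patch, move to builtin generics and PEP 604 unions (ruff UP006/UP007); a compact sketch, with one hypothetical stub signature, of the before/after annotations:

# Valid on Python >= 3.10: builtin generics and `X | Y` unions replace
# typing.Dict/Tuple/Optional/Union. Stub bodies are for illustration only.
from typing import Literal

VoxSizeOption = float | Literal["min"]                    # was Union[float, Literal["min"]]

def float_gt_zero_and_le_one(a: str) -> float | None:     # was Optional[float]
    ...

def load_volumes(path: str) -> tuple[dict, list[int]]:    # hypothetical; was Tuple[Dict, List[int]]
    ...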
diff --git a/FastSurferCNN/utils/brainvolstats.py b/FastSurferCNN/utils/brainvolstats.py index 40451bb5..b8a2b4d1 100644 --- a/FastSurferCNN/utils/brainvolstats.py +++ b/FastSurferCNN/utils/brainvolstats.py @@ -1,20 +1,29 @@ import abc import logging import re -from concurrent.futures import Executor +from collections.abc import Callable, Iterable, Sequence +from concurrent.futures import Executor, Future from contextlib import contextmanager from pathlib import Path -from typing import (TYPE_CHECKING, Sequence, cast, Literal, Iterable, Callable, Union, - Optional, overload, TextIO, Protocol, TypeVar, Generic, Type) -from concurrent.futures import Future +from typing import ( + TYPE_CHECKING, + Generic, + Literal, + Protocol, + TextIO, + TypeVar, + Union, + cast, + overload, +) import numpy as np if TYPE_CHECKING: - from numpy import typing as npt import lapy import nibabel as nib import pandas as pd + from numpy import typing as npt from CerebNet.datasets.utils import LTADict @@ -55,7 +64,7 @@ def __call__(self, file: Path, blocking: True = True) -> T_BufferType: ... @overload def __call__(self, file: Path, blocking: False) -> None: ... - def __call__(self, file: Path, b: bool = True) -> Optional[T_BufferType]: ... + def __call__(self, file: Path, b: bool = True) -> T_BufferType | None: ... class _DefaultFloat(float): @@ -77,12 +86,12 @@ def read_measure_file(path: Path) -> dict[str, MeasureTuple]: {'': ('', '', , '')}. """ if not path.exists(): - raise IOError(f"Measures could not be imported from {path}, " + raise OSError(f"Measures could not be imported from {path}, " f"the file does not exist.") - with open(path, "r") as fp: + with open(path) as fp: lines = list(fp.readlines()) - vox_line = list(filter(lambda l: l.startswith("# VoxelVolume_mm3 "), lines)) - lines = filter(lambda l: l.startswith("# Measure "), lines) + vox_line = list(filter(lambda ln: ln.startswith("# VoxelVolume_mm3 "), lines)) + lines = filter(lambda ln: ln.startswith("# Measure "), lines) def to_measure(line: str) -> tuple[str, MeasureTuple]: data_tup = line.removeprefix("# Measure ").strip() @@ -119,9 +128,9 @@ def read_volume_file(path: Path) -> ImageTuple: raise RuntimeError( f"Loading the file '{path}' for Measure was invalid, no SpatialImage." 
) - except (IOError, FileNotFoundError) as e: + except (OSError, FileNotFoundError) as e: args = e.args[0] - raise IOError(f"Failed loading the file '{path}' with error: {args}") from e + raise OSError(f"Failed loading the file '{path}' with error: {args}") from e data = np.asarray(img.dataobj) return img, data @@ -143,9 +152,9 @@ def read_mesh_file(path: Path) -> "lapy.TriaMesh": try: import lapy mesh = lapy.TriaMesh.read_fssurf(str(path)) - except (IOError, FileNotFoundError) as e: + except (OSError, FileNotFoundError) as e: args = e.args[0] - raise IOError( + raise OSError( f"Failed loading the file '{path}' with error: {args}") from e return mesh @@ -191,8 +200,8 @@ def read_xfm_transform_file(path: Path) -> "npt.NDArray[float]": lines = f.readlines() try: - transf_start = [l.lower().startswith("linear_") for l in lines].index(True) + 1 - tal_str = [l.replace(";", " ") for l in lines[transf_start:transf_start + 3]] + transf_start = [ln.lower().startswith("linear_") for ln in lines].index(True) + 1 + tal_str = [ln.replace(";", " ") for ln in lines[transf_start:transf_start + 3]] tal = np.genfromtxt(tal_str) tal = np.vstack([tal, [0, 0, 0, 1]]) @@ -403,11 +412,11 @@ def kwerror(i, args, msg) -> RuntimeError: ) _kwargs = {} _kwmode = False - for i, (arg, default_key) in enumerate(zip(args, _pargs)): + for i, (arg, default_key) in enumerate(zip(args, _pargs, strict=False)): if (hit := self.__PATTERN.match(arg)) is None: # non-keyword mode if _kwmode: - raise kwerror(i, args, f"non-keyword after keyword") + raise kwerror(i, args, "non-keyword after keyword") _kwargs[default_key] = arg else: # keyword mode @@ -485,7 +494,7 @@ def __init__( ): self._file = file self._callback = read_hook - self._data: Optional[T_BufferType] = None + self._data: T_BufferType | None = None self.__buffer = None super().__init__(name, description, unit) @@ -549,8 +558,8 @@ def __init__( name: str = "N/A", description: str = "N/A", unit: UnitString = "unitless", - read_file: Optional[ReadFileHook[dict[str, MeasureTuple]]] = None, - vox_vol: Optional[float] = None, + read_file: ReadFileHook[dict[str, MeasureTuple]] | None = None, + vox_vol: float | None = None, ): self._key: str = key super().__init__( @@ -560,7 +569,7 @@ def __init__( unit, self.read_file if read_file is None else read_file, ) - self._vox_vol: Optional[float] = vox_vol + self._vox_vol: float | None = vox_vol def _compute(self) -> int | float: """ @@ -656,7 +665,7 @@ def __init__( name: str, description: str, unit: UnitString, - read_mesh: Optional[ReadFileHook["lapy.TriaMesh"]] = None, + read_mesh: ReadFileHook["lapy.TriaMesh"] | None = None, ): super().__init__( surface_file, @@ -807,10 +816,10 @@ def __init__( name: str, description: str, unit: UnitString = "unitless", - read_file: Optional[ReadFileHook[ImageTuple]] = None, + read_file: ReadFileHook[ImageTuple] | None = None, ): if callable(classes_or_cond): - self._classes: Optional[ClassesType] = None + self._classes: ClassesType | None = None self._cond: _ToBoolCallback = classes_or_cond else: if len(classes_or_cond) == 0: @@ -886,7 +895,7 @@ def __init__( threshold: float = 0.5, # sign: MaskSign = "abs", frame: int = 0, # erode: int = 0, invert: bool = False, - read_file: Optional[ReadFileHook[ImageTuple]] = None, + read_file: ReadFileHook[ImageTuple] | None = None, ): self._threshold: float = threshold # self._sign: MaskSign = sign @@ -946,7 +955,7 @@ def __init__( name: str, description: str, unit: str, - read_lta: Optional[ReadFileHook["npt.NDArray[float]"]] = None, + read_lta: 
ReadFileHook["npt.NDArray[float]"] | None = None, ): super().__init__( lta_file, @@ -985,7 +994,7 @@ def __init__( name: str, description: str, unit: str, - read_lta: Optional[ReadFileHook["LTADict"]] = None, + read_lta: ReadFileHook["LTADict"] | None = None, etiv_scale_factor: float | None = None, ): if etiv_scale_factor is None: @@ -1022,7 +1031,7 @@ def __init__( description: str, unit: str = "from parents", operation: DerivedAggOperation = "sum", - measure_host: Optional[dict[str, AbstractMeasure]] = None, + measure_host: dict[str, AbstractMeasure] | None = None, ): """ Create the Measure, which depends on other measures, called parent measures. @@ -1058,7 +1067,7 @@ def to_tuple( else: factor, measure = 1., value - if not isinstance(measure, (str, AbstractMeasure)): + if not isinstance(measure, str | AbstractMeasure): raise ValueError(f"Expected a str or AbstractMeasure, not " f"{type(measure).__name__}!") if not isinstance(factor, float): @@ -1303,7 +1312,7 @@ def __call__( def format_measure(key: str, data: MeasureTuple) -> str: - value = data[2] if isinstance(data[2], int) else ("%.6f" % data[2]) + value = data[2] if isinstance(data[2], int) else f"{data[2]:.6f}" return f"# Measure {key}, {data[0]}, {data[1]}, {value}, {data[3]}" @@ -1339,12 +1348,12 @@ class Manager(dict[str, AbstractMeasure]): def __init__( self, measures: Sequence[tuple[bool, str]], - measurefile: Optional[Path] = None, - segfile: Optional[Path] = None, + measurefile: Path | None = None, + segfile: Path | None = None, on_missing: Literal["fail", "skip", "fill"] = "fail", - executor: Optional[Executor] = None, + executor: Executor | None = None, legacy_freesurfer: bool = False, - aseg_replace: Optional[Path] = None, + aseg_replace: Path | None = None, ): """ @@ -1365,7 +1374,7 @@ def __init__( legacy_freesurfer : bool, default=False FreeSurfer compatibility mode. """ - from concurrent.futures import ThreadPoolExecutor, Future + from concurrent.futures import Future, ThreadPoolExecutor from copy import deepcopy def _check_measures(x): @@ -1447,7 +1456,7 @@ def assert_measure_need_subject(self) -> None: AssertionError """ any_computed = False - for key, measure in self.items(): + for _key, measure in self.items(): if isinstance(measure, DerivedMeasure): pass elif isinstance(measure, ImportedMeasure): @@ -1759,7 +1768,7 @@ def make_read_hook( The function returns None or the output of the wrapped function. """ - def read_wrapper(file: Path, blocking: bool = True) -> Optional[T_BufferType]: + def read_wrapper(file: Path, blocking: bool = True) -> T_BufferType | None: out = self._cache.get(file, None) if out is None: # not already in cache @@ -1795,7 +1804,7 @@ def update_measures(self) -> dict[str, float | int]: m.update({key: self[key]() for key in self._exported_measures}) return m - def print_measures(self, file: Optional[TextIO] = None) -> None: + def print_measures(self, file: TextIO | None = None) -> None: """ Print the measures to stdout or file. 
@@ -2024,7 +2033,7 @@ def default(self, key: str) -> AbstractMeasure: return DerivedMeasure( ["lhCortex", "rhCortex"], "CortexVol", - f"Total cortical gray matter volume", + "Total cortical gray matter volume", measure_host=self, ) elif key == "CorpusCallosumVol": @@ -2346,7 +2355,7 @@ def get_virtual_labels(self, label_pool: Iterable[int]) -> dict[int, list[int]]: """ lbls = (this.labels() for this in self.values() if isinstance(this, PVMeasure)) no_duplicate_dict = {self.__to_lookup(labs): labs for labs in lbls} - return dict(zip(label_pool, no_duplicate_dict.values())) + return dict(zip(label_pool, no_duplicate_dict.values(), strict=False)) @staticmethod def __to_lookup(labels: Sequence[int]) -> str: diff --git a/FastSurferCNN/utils/checkpoint.py b/FastSurferCNN/utils/checkpoint.py index 6663a9e0..967f2841 100644 --- a/FastSurferCNN/utils/checkpoint.py +++ b/FastSurferCNN/utils/checkpoint.py @@ -14,16 +14,17 @@ # IMPORTS import os +from collections.abc import MutableSequence from functools import lru_cache from pathlib import Path -from typing import MutableSequence, Optional, Union, Literal, TypedDict, cast, overload +from typing import Literal, TypedDict, cast, overload import requests import torch import yacs.config import yaml -from FastSurferCNN.utils import logging, Plane +from FastSurferCNN.utils import Plane, logging from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT Scheduler = "torch.optim.lr_scheduler" @@ -61,15 +62,15 @@ def load_checkpoint_config(filename: Path | str = YAML_DEFAULT) -> CheckpointCon if not filename.absolute(): filename = FASTSURFER_ROOT / filename - with open(filename, "r") as file: + with open(filename) as file: data = yaml.load(file, Loader=yaml.FullLoader) required_fields = ("url", "checkpoint") checks = [k not in data for k in required_fields] if any(checks): - missing = tuple(k for k, c in zip(required_fields, checks) if c) + missing = tuple(k for k, c in zip(required_fields, checks, strict=False) if c) message = f"The file {filename} is not valid, missing key(s): {missing}" - raise IOError(message) + raise OSError(message) if isinstance(data["url"], str): data["url"] = [data["url"]] else: @@ -93,7 +94,7 @@ def load_checkpoint_config_defaults( filename: str | Path = YAML_DEFAULT, ) -> list[str]: ... - +@lru_cache def load_checkpoint_config_defaults( configtype: CheckpointConfigFields, filename: str | Path = YAML_DEFAULT, @@ -124,7 +125,7 @@ def load_checkpoint_config_defaults( return load_checkpoint_config(filename)[configtype] -def create_checkpoint_dir(expr_dir: Union[os.PathLike], expr_num: int): +def create_checkpoint_dir(expr_dir: os.PathLike, expr_num: int): """ Create the checkpoint dir if not exists. @@ -163,13 +164,13 @@ def get_checkpoint(ckpt_dir: str, epoch: int) -> str: Standardizes checkpoint name. """ checkpoint_dir = os.path.join( - ckpt_dir, "Epoch_{:05d}_training_state.pkl".format(epoch) + ckpt_dir, f"Epoch_{epoch:05d}_training_state.pkl" ) return checkpoint_dir def get_checkpoint_path( - log_dir: Path | str, resume_experiment: Union[str, int, None] = None + log_dir: Path | str, resume_experiment: str | int | None = None ) -> MutableSequence[Path]: """ Find the paths to checkpoints from the experiment directory. 
@@ -200,8 +201,8 @@ def get_checkpoint_path( def load_from_checkpoint( checkpoint_path: str | Path, model: torch.nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[Scheduler] = None, + optimizer: torch.optim.Optimizer | None = None, + scheduler: Scheduler | None = None, fine_tune: bool = False, drop_classifier: bool = False, ): @@ -259,7 +260,7 @@ def save_checkpoint( cfg: yacs.config.CfgNode, model: torch.nn.Module, optimizer: torch.optim.Optimizer, - scheduler: Optional[Scheduler] = None, + scheduler: Scheduler | None = None, best: bool = False, ) -> None: """ diff --git a/FastSurferCNN/utils/common.py b/FastSurferCNN/utils/common.py index d2577b85..e11e76b6 100644 --- a/FastSurferCNN/utils/common.py +++ b/FastSurferCNN/utils/common.py @@ -14,19 +14,11 @@ # IMPORTS import os -from collections import namedtuple +from collections.abc import Callable, Iterable, Iterator from concurrent.futures import Executor, Future -from dataclasses import dataclass from pathlib import Path from typing import ( Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Tuple, TypeVar, ) @@ -126,6 +118,7 @@ def assert_no_root() -> bool: """ if os.name == "posix" and os.getuid() == 0: import sys + import __main__ sys.exit( @@ -185,7 +178,7 @@ def pipeline( iterable: Iterable[_Ti], *, pipeline_size: int = 1, -) -> Iterator[Tuple[_Ti, _T]]: +) -> Iterator[tuple[_Ti, _T]]: """ Pipeline a function to be executed in the pool. @@ -229,7 +222,7 @@ def pipeline( def iterate( pool: Executor, func: Callable[[_Ti], _T], iterable: Iterable[_Ti], -) -> Iterator[Tuple[_Ti, _T]]: +) -> Iterator[tuple[_Ti, _T]]: """ Iterate over iterable, yield pairs of elements and func(element). @@ -674,18 +667,18 @@ class SubjectList: Represent a list of subjects. """ - _subjects: List[Path] + _subjects: list[Path] _orig_name_: str _conf_name_: str _segfile_: str - _flags: Dict[str, dict] + _flags: dict[str, dict] DEFAULT_FLAGS = {k: v(dict) for k, v in parser_defaults.ALL_FLAGS.items()} def __init__( self, args: SubjectDirectoryConfig, - flags: Optional[dict[str, dict]] = None, + flags: dict[str, dict] | None = None, **assign, ): """ @@ -778,10 +771,10 @@ def __init__( self._out_segfile = getattr(self, "_segfile_", None) if self._out_segfile is None: raise RuntimeError( - f"The segmentation output file is not set, it should be either " - f"'segfile' (which gets populated from args.segfile), or a keyword " - f"argument to __init__, e.g. `SubjectList(args, subseg='subseg_param', " - f"out_filename='subseg')`." + "The segmentation output file is not set, it should be either " + "'segfile' (which gets populated from args.segfile), or a keyword " + "argument to __init__, e.g. `SubjectList(args, subseg='subseg_param', " + "out_filename='subseg')`." ) # if out_dir is not set, fall back to in_dir by default @@ -796,12 +789,12 @@ def __init__( raise RuntimeError(msg.format(**self._flags)) # 1. 
are we doing a csv file of subjects - if getattr(args, "csv_file") is not None: - with open(args.csv_file, "r") as s_dirs: + if args.csv_file is not None: + with open(args.csv_file) as s_dirs: self._subjects = [Path(line.strip()) for line in s_dirs.readlines()] if any(not d.is_absolute() for d in self._subjects): msg = f"At least one path in {args.csv_file} was relative, but the " - if getattr(args, "in_dir") is None: + if args.in_dir is None: raise RuntimeError( "{}in_dir was not in args (no {in_dir[flag]} flag).".format( msg, **self._flags @@ -950,7 +943,7 @@ def _not_abs(subj_attr): __init__.__doc__ = __init__.__doc__.format(**DEFAULT_FLAGS) @property - def flags(self) -> Dict[str, Dict]: + def flags(self) -> dict[str, dict]: """ Give the flags. @@ -1090,7 +1083,7 @@ def map( self, fn: Callable[..., _T], *iterables: Iterable[Any], - timeout: Optional[float] = None, + timeout: float | None = None, chunksize: int = -1, ) -> Iterator[_T]: """ diff --git a/FastSurferCNN/utils/dataclasses.py b/FastSurferCNN/utils/dataclasses.py index c78abdfc..2cc54548 100644 --- a/FastSurferCNN/utils/dataclasses.py +++ b/FastSurferCNN/utils/dataclasses.py @@ -1,20 +1,22 @@ -from typing import Mapping, TypeVar, overload, Any, Callable, Optional - +from collections.abc import Callable, Mapping from dataclasses import ( - field as _field, + KW_ONLY, + MISSING, + Field, + FrozenInstanceError, + InitVar, asdict, astuple, dataclass, fields, - Field, - FrozenInstanceError, is_dataclass, - InitVar, make_dataclass, - MISSING, - KW_ONLY, replace, ) +from dataclasses import ( + field as _field, +) +from typing import Any, TypeVar, overload __all__ = [ "field", @@ -116,7 +118,7 @@ def field( elif metadata is None: metadata = {} else: - raise TypeError(f"Invalid type of metadata, must be a Mapping!") + raise TypeError("Invalid type of metadata, must be a Mapping!") if help: if not isinstance(help, str): raise TypeError("help must be a str!") diff --git a/FastSurferCNN/utils/logging.py b/FastSurferCNN/utils/logging.py index 3b98752b..2aab4c99 100644 --- a/FastSurferCNN/utils/logging.py +++ b/FastSurferCNN/utils/logging.py @@ -13,10 +13,7 @@ # limitations under the License. # IMPORTS -from logging import * -from logging import DEBUG, INFO, FileHandler, StreamHandler, basicConfig -from logging import getLogger -from logging import getLogger as get_logger +from logging import INFO, FileHandler, StreamHandler, basicConfig from pathlib import Path as _Path from sys import stdout as _stdout diff --git a/FastSurferCNN/utils/lr_scheduler.py b/FastSurferCNN/utils/lr_scheduler.py index 490a1080..9298f821 100644 --- a/FastSurferCNN/utils/lr_scheduler.py +++ b/FastSurferCNN/utils/lr_scheduler.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union import torch.optim @@ -22,7 +21,7 @@ def get_lr_scheduler( optimzer: torch.optim.Optimizer, cfg: yacs.config.CfgNode -) -> Union[None, scheduler.StepLR, scheduler.CosineAnnealingWarmRestarts]: +) -> None | scheduler.StepLR | scheduler.CosineAnnealingWarmRestarts: """ Give a schedular for left-right scheduling. 
diff --git a/FastSurferCNN/utils/mapper.py b/FastSurferCNN/utils/mapper.py index 4603b693..ba3098a1 100644 --- a/FastSurferCNN/utils/mapper.py +++ b/FastSurferCNN/utils/mapper.py @@ -20,37 +20,23 @@ import json import os.path +from collections.abc import Callable, Collection, Hashable, Iterable, Iterator, Mapping, Sequence from functools import partial, partialmethod, reduce from numbers import Integral, Number from typing import ( Any, - Callable, - Collection, - Container, - Dict, Generic, - Hashable, - Iterable, - Iterator, - List, Literal, - Mapping, - Optional, - Sequence, - Set, TextIO, - Tuple, TypeVar, - Union, cast, - overload, ) import numpy as np import pandas import torch -from matplotlib.pyplot import get_cmap from matplotlib.colors import Colormap +from matplotlib.pyplot import get_cmap from numpy import typing as npt from FastSurferCNN.utils import logging @@ -70,7 +56,7 @@ NT = TypeVar("NT", bound=Number) AT = TypeVar("AT", npt.NDArray[Number], torch.Tensor) -ColorTuple = Tuple[float, float, float] +ColorTuple = tuple[float, float, float] ColormapGenerator = Callable[[int], npt.NDArray[float]] logger = logging.getLogger(__name__) @@ -137,15 +123,15 @@ class Mapper(Generic[KT, VT]): Map from one label space to a generic 'label'-space. """ - _map_dict: Dict[KT, npt.NDArray[VT]] - _label_shape: Tuple[int, ...] + _map_dict: dict[KT, npt.NDArray[VT]] + _label_shape: tuple[int, ...] _map_np: npt.NDArray[VT] _map_torch: torch.Tensor - _max_label: Optional[int] + _max_label: int | None _name: str def __init__( - self, mappings: Mapping[KT, Union[VT, npt.NDArray[VT]]], name: str = "undefined" + self, mappings: Mapping[KT, VT | npt.NDArray[VT]], name: str = "undefined" ): """ Construct `Mapper` object from a mappings dictionary. @@ -196,7 +182,7 @@ def name(self, name: str): self._name = name @property - def source_space(self) -> Set[KT]: + def source_space(self) -> set[KT]: """ Return a set of labels the mapper accepts. """ @@ -205,7 +191,8 @@ def source_space(self) -> Set[KT]: @property def target_space(self) -> Collection[VT]: """ - Return the set of labels the mapper converts to as a set of python-natives (if possible), arrays expanded to tuples. + Return the set of labels the mapper converts to as a set of python-natives (if possible), + arrays expanded to tuples. """ return self._map_dict.values() @@ -249,7 +236,7 @@ def update( __iadd__ = partialmethod(update, overwrite=True) - def map(self, image: AT, out: Optional[AT] = None) -> AT: + def map(self, image: AT, out: AT | None = None) -> AT: """ Forward map the labels from prediction to internal space. @@ -337,7 +324,7 @@ def map(self, image: AT, out: Optional[AT] = None) -> AT: return out return to_same_type(mapped, type_hint=image) - def _map_py(self, image: AT, out: Optional[AT] = None) -> AT: + def _map_py(self, image: AT, out: AT | None = None) -> AT: """ Map internally by python, for example for strings. @@ -378,8 +365,8 @@ def _internal_map(img, o): return out def __call__( - self, image: AT, label_image: Union[npt.NDArray[KT], torch.Tensor] - ) -> Tuple[AT, Union[npt.NDArray, torch.Tensor]]: + self, image: AT, label_image: npt.NDArray[KT] | torch.Tensor + ) -> tuple[AT, npt.NDArray | torch.Tensor]: """ Transform a dataset from prediction to internal space for sets of image and segmentation. 
@@ -409,7 +396,7 @@ def reversed_dict(self) -> Mapping[VT, KT]: a = self._map_dict[src] if not isinstance(a, Hashable): a = tuple( - a.tolist() if isinstance(a, (np.ndarray, torch.Tensor)) else a + a.tolist() if isinstance(a, np.ndarray | torch.Tensor) else a ) rev_mappings.setdefault(a, src) return rev_mappings @@ -432,7 +419,7 @@ def __getitem__(self, item: KT) -> VT: """ return self._map_dict[item] - def __iter__(self) -> Iterator[Tuple[KT, VT]]: + def __iter__(self) -> Iterator[tuple[KT, VT]]: """ Create an iterator for the Mapper object. """ @@ -492,7 +479,7 @@ def chain( @classmethod def make_classmapper( cls, - mappings: Dict[int, int], + mappings: dict[int, int], keep_labels: Sequence[int] = tuple(), compress_out_space: bool = False, name: str = "undefined", @@ -535,7 +522,7 @@ def make_classmapper( (v, i) for i, v in enumerate(sorted(set(mappings.values()))) ) - def relabel(old_label_in: int, old_label_out: int) -> Tuple[int, int]: + def relabel(old_label_in: int, old_label_out: int) -> tuple[int, int]: return old_label_in, target_labels[old_label_out] mappings = dict(map(relabel, mappings.items())) @@ -547,7 +534,7 @@ def _map_logits( logits: AT, axis: int = -1, reverse: bool = False, - out: Optional[AT] = None, + out: AT | None = None, mode: Literal["logit", "prob"] = "logit", ) -> AT: """ @@ -585,7 +572,7 @@ def _map_logits( unique_target_classes = np.unique( list(self._map_dict.values()), return_counts=True ) - cls_cts = {cls: cts for cls, cts in zip(*unique_target_classes) if cts > 1} + cls_cts = {cls: cts for cls, cts in zip(*unique_target_classes, strict=False) if cts > 1} mappings = ((v, k) for k, v in mappings) # swap source and target mappings else: cls_cts = {} @@ -632,17 +619,17 @@ class ColorLookupTable(Generic[KT]): This class provides utility in creating color palettes from colormaps. """ - _color_palette: Optional[npt.NDArray[float]] - _colormap: Union[str, Colormap, ColormapGenerator] - _classes: Optional[List[KT]] + _color_palette: npt.NDArray[float] | None + _colormap: str | Colormap | ColormapGenerator + _classes: list[KT] | None _name: str def __init__( self, - classes: Optional[Iterable[KT]] = None, - color_palette: Union[Dict[KT, npt.ArrayLike], npt.ArrayLike, None] = None, - colormap: Union[str, Colormap, ColormapGenerator] = "gist_ncar", - name: Optional[str] = None, + classes: Iterable[KT] | None = None, + color_palette: dict[KT, npt.ArrayLike] | npt.ArrayLike | None = None, + colormap: str | Colormap | ColormapGenerator = "gist_ncar", + name: str | None = None, ): """ Construct a LookupTable object. @@ -690,14 +677,14 @@ def name(self, name: str): self._name = name @property - def classes(self) -> Optional[List[KT]]: + def classes(self) -> list[KT] | None: """ Return the classes. """ return self._classes @classes.setter - def classes(self, classes: Optional[Iterable[KT]]): + def classes(self, classes: Iterable[KT] | None): """ Set the classes and generates a color palette for the given classes. @@ -718,7 +705,7 @@ def classes(self, classes: Optional[Iterable[KT]]): self._color_palette = get_cmap(self._colormap, num)(list(range(num))) @property - def color_palette(self) -> Optional[npt.NDArray[float]]: + def color_palette(self) -> npt.NDArray[float] | None: """ Return the color palette if it exists. 
""" @@ -726,7 +713,7 @@ def color_palette(self) -> Optional[npt.NDArray[float]]: @color_palette.setter def color_palette( - self, color_palette: Union[Dict[KT, npt.ArrayLike], npt.ArrayLike, None] + self, color_palette: dict[KT, npt.ArrayLike] | npt.ArrayLike | None ): """ Set (or reset) the color palette of the LookupTable. @@ -751,7 +738,7 @@ def color_palette( self._color_palette = color_palette - def __getitem__(self, key: KT) -> Tuple[int, KT, Tuple[int, int, int, int], Any]: + def __getitem__(self, key: KT) -> tuple[int, KT, tuple[int, int, int, int], Any]: """ Return index, key, colors and additional values for the key. @@ -770,14 +757,14 @@ def __getitem__(self, key: KT) -> Tuple[int, KT, Tuple[int, int, int, int], Any] def getitem_by_index( self, index: int - ) -> Tuple[int, KT, Tuple[int, int, int, int], Any]: + ) -> tuple[int, KT, tuple[int, int, int, int], Any]: """ Return index, key, colors and additional values for the key. """ color = self.get_color_by_index(index, 255) return index, self._classes[index], color, None - def get_color_by_index(self, index: int, base: NT = 1.0) -> Tuple[NT, NT, NT, NT]: + def get_color_by_index(self, index: int, base: NT = 1.0) -> tuple[NT, NT, NT, NT]: """ Return the color (r, g, b, a) tuple associated with the index in the passed base. """ @@ -804,7 +791,7 @@ def colormap(self) -> Mapper[KT, ColorTuple]: if self._color_palette is None: raise RuntimeError("No color_palette set") return Mapper( - dict(zip(self.classes, self.color_palette)), name="color-" + self.name + dict(zip(self.classes, self.color_palette, strict=False)), name="color-" + self.name ) def labelname2index(self) -> Mapper[KT, int]: @@ -814,7 +801,7 @@ def labelname2index(self) -> Mapper[KT, int]: This is the inverse of ColorLookupTable.classes. """ return Mapper( - dict(zip(self._classes, range(len(self._classes)))), + dict(zip(self._classes, range(len(self._classes)), strict=False)), name="index-" + self.name, ) @@ -843,9 +830,9 @@ class JsonColorLookupTable(ColorLookupTable[KT]): def __init__( self, file_or_buffer, - color_palette: Union[Dict[KT, npt.ArrayLike], npt.ArrayLike, None] = None, - colormap: Union[str, Colormap, ColormapGenerator] = "gist_ncar", - name: Optional[str] = None, + color_palette: dict[KT, npt.ArrayLike] | npt.ArrayLike | None = None, + colormap: str | Colormap | ColormapGenerator = "gist_ncar", + name: str | None = None, ) -> None: """ Construct a JsonLookupTable object from `file_or_buffer` passed. @@ -868,9 +855,9 @@ def __init__( self._data = json.loads(file_or_buffer) if name is None: name = "unnamed json buffer string" - elif isinstance(file_or_buffer, (str, os.PathLike)): + elif isinstance(file_or_buffer, str | os.PathLike): if os.path.exists(file_or_buffer): - with open(file_or_buffer, "r") as file: + with open(file_or_buffer) as file: self._data = json.load(file) else: raise ValueError(f"The file {file_or_buffer} does not exist") @@ -899,11 +886,11 @@ def __init__( raise KeyError( f"Duplicate classes in source_space: {unique_classes[counts>1]}" ) - super(JsonColorLookupTable, self).__init__( + super().__init__( classes=classes, color_palette=color_palette, colormap=colormap, name=name ) - def _get_labels(self) -> Union[Dict[KT, Any], Iterable[KT]]: + def _get_labels(self) -> dict[KT, Any] | Iterable[KT]: """ Return labels. 
""" @@ -920,12 +907,12 @@ def dataframe(self) -> pandas.DataFrame: if isinstance(self._data, dict) and "labels" in self._data: return pandas.DataFrame.from_dict(self._data["labels"]) - def __getitem__(self, key: KT) -> Tuple[int, KT, Tuple[int, int, int, int], Any]: + def __getitem__(self, key: KT) -> tuple[int, KT, tuple[int, int, int, int], Any]: """ Index by the index position, unless either key or value are int. """ labels = self._get_labels() - index, key, color, _other = super(JsonColorLookupTable, self).__getitem__(key) + index, key, color, _other = super().__getitem__(key) if isinstance(labels, dict): _other = labels[str(key)] return index, key, color, _other @@ -948,7 +935,7 @@ def labelname2id(self) -> Mapper[KT, Any]: if not isinstance(labels, dict): raise RuntimeError("The json file contained no values.") return Mapper( - dict(zip(self._classes, labels.values())), name="value-" + self.name + dict(zip(self._classes, labels.values(), strict=False)), name="value-" + self.name ) @@ -962,7 +949,7 @@ class TSVLookupTable(ColorLookupTable[str]): def __init__( self, file_or_buffer, - name: Optional[str] = None, + name: str | None = None, header: bool = False, add_background: bool = True, ) -> None: @@ -1011,7 +998,7 @@ def __init__( ) if not (self._data.index == 0).any(): df = pandas.DataFrame.from_dict( - {k: [v] for k, v in zip(names.keys(), [0, "Unknown", 0, 0, 0, 0])} + {k: [v] for k, v in zip(names.keys(), [0, "Unknown", 0, 0, 0, 0], strict=False)} ) self._data = pandas.concat([df, self._data]) classes = self._data["Label name"].tolist() @@ -1019,13 +1006,13 @@ def __init__( color_palette = np.asarray( [tuple(int(row[k].item()) for k in channels) for row in self._data.iloc] ) - super(TSVLookupTable, self).__init__( + super().__init__( classes=classes, color_palette=color_palette, name=name ) def getitem_by_index( self, index: int - ) -> Tuple[int, str, Tuple[int, int, int, int], int]: + ) -> tuple[int, str, tuple[int, int, int, int], int]: """ Find the Entry associated by a No. @@ -1046,7 +1033,7 @@ def getitem_by_index( int The data index associated with the entry. """ - index, key, color, _ = super(TSVLookupTable, self).getitem_by_index(index) + index, key, color, _ = super().getitem_by_index(index) return index, key, color, self._data.iloc[index].name def dataframe(self) -> pandas.DataFrame: @@ -1071,5 +1058,5 @@ class and data index. If no value is associated. """ return Mapper( - dict(zip(self._classes, self._data.index)), name="value-" + self.name + dict(zip(self._classes, self._data.index, strict=False)), name="value-" + self.name ) diff --git a/FastSurferCNN/utils/meters.py b/FastSurferCNN/utils/meters.py index e0a7b902..c9bc3b9a 100644 --- a/FastSurferCNN/utils/meters.py +++ b/FastSurferCNN/utils/meters.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional +from typing import Any import matplotlib.pyplot as plt # IMPORTS import numpy as np -import torch import yacs.config from FastSurferCNN.utils import logging @@ -37,11 +36,11 @@ def __init__( cfg: yacs.config.CfgNode, mode: str, global_step: int, - total_iter: Optional[int] = None, - total_epoch: Optional[int] = None, - class_names: Optional[Any] = None, - device: Optional[Any] = None, - writer: Optional[Any] = None, + total_iter: int | None = None, + total_epoch: int | None = None, + class_names: Any | None = None, + device: Any | None = None, + writer: Any | None = None, ): """ Construct a Meter object. @@ -150,14 +149,9 @@ def log_iter(self, cur_iter: int, cur_epoch: int): """ if (cur_iter + 1) % self._cfg.TRAIN.LOG_INTERVAL == 0: logger.info( - "{} Epoch [{}/{}] Iter [{}/{}] with loss {:.4f}".format( - self.mode, - cur_epoch + 1, - self.total_epochs, - cur_iter + 1, - self.total_iter_num, - np.array(self.batch_losses).mean(), - ) + f"{self.mode} Epoch [{cur_epoch + 1}/{self.total_epochs}]" \ + f" Iter [{cur_iter + 1}/{self.total_iter_num}]" \ + f" with loss {np.array(self.batch_losses).mean():.4f}" ) def log_epoch(self, cur_epoch: int): diff --git a/FastSurferCNN/utils/metrics.py b/FastSurferCNN/utils/metrics.py index 0b0cadf2..75d04b29 100644 --- a/FastSurferCNN/utils/metrics.py +++ b/FastSurferCNN/utils/metrics.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Tuple +from typing import Any import numpy as np @@ -26,7 +26,7 @@ def iou_score( pred_cls: torch.Tensor, true_cls: torch.Tensor, nclass: int = 79 -) -> Tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray]: """ Compute the intersection-over-union score. @@ -64,7 +64,7 @@ def iou_score( def precision_recall( pred_cls: torch.Tensor, true_cls: torch.Tensor, nclass: int = 79 -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Calculate recall (TP/(TP + FN) and precision (TP/(TP+FP) per class. @@ -125,7 +125,7 @@ class DiceScore: def __init__( self, num_classes: int, - device: Optional[str] = None, + device: str | None = None, output_transform=lambda y_pred, y: (y_pred.data.max(1)[1], y), ): """ @@ -152,9 +152,7 @@ def _check_output_type(self, output): """ if not (isinstance(output, tuple)): raise TypeError( - "Output should a tuple consist of of torch.Tensors, but given {}".format( - type(output) - ) + f"Output should a tuple consist of of torch.Tensors, but given {type(output)}" ) def _update_union_intersection_matrix( @@ -195,7 +193,7 @@ def _update_union_intersection( self.intersection[i, i] += torch.sum(torch.mul(gt, pred)) self.union[i, i] += torch.sum(gt) + torch.sum(pred) - def update(self, output: Tuple[Any, Any], cnf_mat: bool): + def update(self, output: tuple[Any, Any], cnf_mat: bool): """ Update the intersection. 
diff --git a/FastSurferCNN/utils/misc.py b/FastSurferCNN/utils/misc.py index 310caeaa..ca7b3e2c 100644 --- a/FastSurferCNN/utils/misc.py +++ b/FastSurferCNN/utils/misc.py @@ -15,7 +15,6 @@ # IMPORTS import os from itertools import product -from typing import List import matplotlib.figure import matplotlib.pyplot as plt @@ -83,7 +82,7 @@ def plot_predictions( def plot_confusion_matrix( cm: npt.NDArray, - classes: List[str], + classes: list[str], title: str = "Confusion matrix", cmap: plt.cm.ColormapRegistry = plt.cm.Blues, file_save_name: str = "temp.pdf", diff --git a/FastSurferCNN/utils/parser_defaults.py b/FastSurferCNN/utils/parser_defaults.py index 9b90a4cc..58f0c48a 100644 --- a/FastSurferCNN/utils/parser_defaults.py +++ b/FastSurferCNN/utils/parser_defaults.py @@ -26,14 +26,13 @@ >>> # 'dest': 'root', 'help': 'Allow execution as root user.'} """ -import argparse import types -from dataclasses import dataclass, Field, fields +from collections.abc import Iterable, Mapping +from dataclasses import Field, dataclass from pathlib import Path -from typing import (Dict, Iterable, Literal, Mapping, Protocol, Type, TypeVar, Union, - Optional, get_origin, get_args) +from typing import Literal, Protocol, TypeVar, get_args, get_origin -from FastSurferCNN.utils import Plane, PLANES +from FastSurferCNN.utils import PLANES, Plane from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one from FastSurferCNN.utils.arg_types import unquote_str from FastSurferCNN.utils.arg_types import vox_size as __vox_size @@ -46,7 +45,7 @@ "checkpoint": "{} checkpoint to load", "config": "Path to the {} config file", } -VoxSize = Union[Literal["min"], float] +VoxSize = Literal["min"] | float class CanAddArguments(Protocol): @@ -63,7 +62,7 @@ def add_argument(self, *args, **kwargs): def __arg( *default_flags: str, - dcf: Optional[Field] = None, + dcf: Field | None = None, dc=None, fieldname: str = "", **default_kwargs, @@ -113,7 +112,7 @@ def __arg( else: default_kwargs.setdefault(kw, default) - def _stub(parser: Union[CanAddArguments, Type[Dict]], *flags, **kwargs): + def _stub(parser: CanAddArguments | type[dict], *flags, **kwargs): # prefer the value passed to the "new" call for kw, arg in kwargs.items(): if callable(arg) and kw in default_kwargs: @@ -126,7 +125,7 @@ def _stub(parser: Union[CanAddArguments, Type[Dict]], *flags, **kwargs): _flags = flags if len(flags) != 0 else default_flags if hasattr(parser, "add_argument"): return parser.add_argument(*_flags, **kwargs) - elif parser == dict or isinstance(parser, dict): + elif isinstance(parser, dict): return {"flag": _flags[0], "flags": _flags, **kwargs} else: raise ValueError( @@ -168,18 +167,18 @@ class SubjectDirectoryConfig: "mri/orig.mgz.", flags=("--conformed_name",), ) - in_dir: Optional[Path] = field( + in_dir: Path | None = field( flags=("--in_dir",), default=None, help="Directory in which input volume(s) are located. Optional, if full path " "is defined for --t1.", ) - csv_file: Optional[Path] = field( + csv_file: Path | None = field( flags=("--csv_file",), default=None, help="Csv-file with subjects to analyze (alternative to --tag)", ) - sid: Optional[str] = field( + sid: str | None = field( flags=("--sid",), default=None, help="Optional: directly set the subject id to use. Can be used for single " @@ -207,7 +206,7 @@ class SubjectDirectoryConfig: "correct subject name (e.g. /ses-x/anat/ for BIDS or /mri/ for FreeSurfer " "input). 
Default: do not remove anything.", ) - out_dir: Optional[Path] = field( + out_dir: Path | None = field( default=None, help="Directory in which evaluation results should be written. Will be created " "if it does not exist. Optional if full path is defined for --pred_name.", diff --git a/FastSurferCNN/utils/run_tools.py b/FastSurferCNN/utils/run_tools.py index a206420c..2c9a72cc 100644 --- a/FastSurferCNN/utils/run_tools.py +++ b/FastSurferCNN/utils/run_tools.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from concurrent.futures import Executor, Future import subprocess +from collections.abc import Generator, Sequence from concurrent.futures import Executor, Future from dataclasses import dataclass -from functools import partialmethod -from typing import Generator, Optional, Sequence, Callable, Any, Collection, Iterable from datetime import datetime +from functools import partialmethod # TODO: python3.9+ # from collections.abc import Generator @@ -34,7 +33,7 @@ class MessageBuffer: out: bytes = b"" err: bytes = b"" - retcode: Optional[int] = None + retcode: int | None = None runtime: float = 0. def __add__(self, other: "MessageBuffer") -> "MessageBuffer": @@ -71,7 +70,7 @@ class Popen(subprocess.Popen): """ Extension of subprocess.Popen for convenience. """ - _starttime: Optional[datetime] = None + _starttime: datetime | None = None def __init__(self, *args, **kwargs): self._starttime = datetime.now() @@ -168,8 +167,8 @@ def finish(self, timeout: float = None) -> MessageBuffer: if i > 0: self.kill() raise RuntimeError( - "The process {} did not stop properly in Popen.finish, " - "abandoning.".format(self) + f"The process {self} did not stop properly in Popen.finish, " + "abandoning." ) i += 1 if i == 0: @@ -235,6 +234,6 @@ def __init__(self, args: Sequence[str], *_args, **kwargs): } flags = "".join(k for k, v in all_flags.items() if getattr(sys.flags, v) == 1) flags = [] if len(flags) == 0 else ["-" + flags] - super(PyPopen, self).__init__( + super().__init__( [sys.executable] + flags + list(args), *_args, **kwargs ) diff --git a/FastSurferCNN/version.py b/FastSurferCNN/version.py index bcd27d7e..51ffc09c 100644 --- a/FastSurferCNN/version.py +++ b/FastSurferCNN/version.py @@ -4,9 +4,10 @@ import re import shutil import subprocess +from collections.abc import Sequence +from concurrent.futures import Future, ThreadPoolExecutor from pathlib import Path -from typing import Any, cast, get_args, Literal, Optional, TypedDict, Sequence, TextIO -from concurrent.futures import ThreadPoolExecutor, Future +from typing import Any, Literal, TextIO, TypedDict, cast, get_args class DEFAULTS: @@ -125,7 +126,7 @@ def make_parser(): "--file", default=None, type=argparse.FileType("w"), - help=f"File to write version info to (default: write to stdout).", + help="File to write version info to (default: write to stdout).", ) parser.add_argument( "--prefer_cache", @@ -154,7 +155,7 @@ def print_build_file( git_status: str = "", checkpoints: str = "", pypackages: str = "", - file: Optional[TextIO] = None, + file: TextIO | None = None, ) -> None: """ Format and print the build file. 
@@ -207,9 +208,9 @@ def print_header(section_name: str) -> None: def main( sections: str = "", - project_file: Optional[TextIO] = None, - build_cache: Optional[TextIO | bool] = None, - file: Optional[TextIO] = None, + project_file: TextIO | None = None, + build_cache: TextIO | bool | None = None, + file: TextIO | None = None, prefer_cache: bool = False, ) -> str | int: """ @@ -280,7 +281,7 @@ def main( if sections == "all": sections = "+checkpoints+git+pip" - from FastSurferCNN.utils.run_tools import Popen, PyPopen, MessageBuffer + from FastSurferCNN.utils.run_tools import MessageBuffer, Popen, PyPopen build_cache_required = prefer_cache kw_root = {"cwd": DEFAULTS.PROJECT_ROOT, "stdout": subprocess.PIPE} @@ -334,7 +335,7 @@ def calculate_md5_for_checkpoints() -> "MessageBuffer": try: version = futures.pop("version").result() - except IOError: + except OSError: version = build_cache["version"] def __future_or_cache( @@ -401,7 +402,7 @@ def get_default_version_info() -> VersionDict: } -def parse_build_file(build_file: Optional[TextIO]) -> VersionDict: +def parse_build_file(build_file: TextIO | None) -> VersionDict: """Read and parse a build file (same as output of `main`). Read and parse a file with version information in the format that is also the @@ -426,8 +427,8 @@ def parse_build_file(build_file: Optional[TextIO]) -> VersionDict: file_cache: VersionDict = {} if build_file is None: try: - build_file = open(DEFAULTS.BUILD_TXT, "r") - except FileNotFoundError as e: + build_file = open(DEFAULTS.BUILD_TXT) + except FileNotFoundError: return get_default_version_info() file_cache["content"] = "".join(build_file.readlines()) if not build_file.closed: @@ -456,7 +457,7 @@ def parse_build_file(build_file: Optional[TextIO]) -> VersionDict: else: file_cache["version_tag"] = file_cache["version"] - def get_section_name_by_header(header: str) -> Optional[str]: + def get_section_name_by_header(header: str) -> str | None: for name, info in DEFAULTS.VERSION_SECTIONS.items(): if info[1] == header: return name @@ -508,7 +509,7 @@ def read_version_from_project_file(project_file: TextIO) -> str: return version -def filter_git_status(git_process: "FastSurferCNN.utils.run_tools.Popen") -> str: +def filter_git_status(git_process) -> str: """ Filter the output of a running git status process. @@ -531,7 +532,7 @@ def filter_git_status(git_process: "FastSurferCNN.utils.run_tools.Popen") -> str ) -def read_and_close_version(project_file: Optional[TextIO] = None) -> str: +def read_and_close_version(project_file: TextIO | None = None) -> str: """ Read and close the version from the pyproject file. Also fill default. 
@@ -552,7 +553,7 @@ def read_and_close_version(project_file: Optional[TextIO] = None) -> str: See also FastSurferCNN.version.read_version_from_project_file """ if project_file is None: - project_file = open(DEFAULTS.PROJECT_TOML, "r") + project_file = open(DEFAULTS.PROJECT_TOML) try: version = read_version_from_project_file(project_file) finally: From 792a281e5ead32207f0839d22e5dcc8e39f2c090 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 16:47:37 +0200 Subject: [PATCH 09/31] test dir, fix ruff stuff --- test/quick_test/test_errors.py | 13 +++++++------ test/quick_test/test_file_existence.py | 12 +++++++----- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/test/quick_test/test_errors.py b/test/quick_test/test_errors.py index 01611b5e..c1686e42 100644 --- a/test/quick_test/test_errors.py +++ b/test/quick_test/test_errors.py @@ -1,9 +1,10 @@ +import argparse import sys -import yaml import unittest -import argparse from pathlib import Path +import yaml + class TestErrors(unittest.TestCase): """ @@ -23,7 +24,7 @@ def setUpClass(cls): """ # Open the error_file_path and read the errors and whitelist into arrays - with open(cls.error_file_path, 'r') as file: + with open(cls.error_file_path) as file: data = yaml.safe_load(file) cls.errors = data.get('errors', []) cls.whitelist = data.get('whitelist', []) @@ -34,7 +35,7 @@ def setUpClass(cls): print(cls.log_directory) cls.log_files = [file for file in cls.log_directory.iterdir() if file.suffix == '.log'] except FileNotFoundError: - raise FileNotFoundError(f"Log directory not found at path: {cls.log_directory}") + raise FileNotFoundError(f"Log directory not found at path: {cls.log_directory}") from None def test_find_errors_in_logs(self): """ @@ -64,13 +65,13 @@ def test_find_errors_in_logs(self): files_with_errors[rel_path] = lines_with_errors self.error_flag = True except FileNotFoundError: - raise FileNotFoundError(f"Log file not found at path: {log_file}") + raise FileNotFoundError(f"Log file not found at path: {log_file}") from None continue # Print the lines and context with errors for each file for file, lines in files_with_errors.items(): print(f"\nFile {file}, in line {files_with_errors[file][0][0]}:") - for line_number, line in lines: + for _line_number, line in lines: print(*line, sep = "") # Assert that there are no lines with any of the keywords diff --git a/test/quick_test/test_file_existence.py b/test/quick_test/test_file_existence.py index 77abcb09..0d44176b 100644 --- a/test/quick_test/test_file_existence.py +++ b/test/quick_test/test_file_existence.py @@ -1,9 +1,10 @@ +import argparse import sys -import yaml import unittest -import argparse from pathlib import Path +import yaml + class TestFileExistence(unittest.TestCase): """ @@ -39,7 +40,8 @@ def test_file_existence(self): """ Test method to check the existence of files in the folder. - This method gets a list of all files in the folder recursively and checks if each file specified in the YAML file exists in the folder. + This method gets a list of all files in the folder recursively and checks + if each file specified in the YAML file exists in the folder. """ # Check if each file in the YAML file exists in the folder @@ -57,8 +59,8 @@ def test_file_existence(self): """ Main entry point of the script. - This block checks if there are any command line arguments, assigns the first argument to the error_file_path class variable, - and runs the unittest main function. 
+ This block checks if there are any command line arguments, assigns the first argument + to the error_file_path class variable, and runs the unittest main function. """ parser = argparse.ArgumentParser(description="Test for file existence based on a YAML file.") From ea27fceffaf9d0bab09a0c555b3a48df0a1a8912 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 16:52:59 +0200 Subject: [PATCH 10/31] run on PRs and pushes in dev --- .github/workflows/code-style.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/code-style.yml b/.github/workflows/code-style.yml index 85e4ec0f..2284d735 100644 --- a/.github/workflows/code-style.yml +++ b/.github/workflows/code-style.yml @@ -3,9 +3,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} cancel-in-progress: true on: -# pull_request: -# push: -# branches: [dev] + pull_request: + push: + branches: [dev] workflow_dispatch: jobs: @@ -34,7 +34,7 @@ jobs: check_hidden: true skip: './.git,./build,./.mypy_cache,./.pytest_cache' ignore_words_file: ./.codespellignore - - name: Run pydocstyle - run: pydocstyle . - - name: Run bibclean - run: bibclean-check doc/references.bib +# - name: Run pydocstyle +# run: pydocstyle . +# - name: Run bibclean +# run: bibclean-check doc/references.bib From 02062f9e80f66453ea22ff8fe339360788a3e5ec Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 21:56:47 +0200 Subject: [PATCH 11/31] correct spelling --- .codespellignore | 2 ++ .github/workflows/code-style.yml | 2 +- CerebNet/config/cerebnet.py | 2 +- CerebNet/models/sub_module.py | 2 +- CerebNet/utils/lr_scheduler.py | 6 +++--- Docker/Dockerfile | 4 ++-- Docker/build.py | 2 +- Docker/install_fs_pruned.sh | 4 ++-- FastSurferCNN/data_loader/conform.py | 2 +- FastSurferCNN/data_loader/data_utils.py | 2 +- FastSurferCNN/data_loader/dataset.py | 2 +- FastSurferCNN/generate_hdf5.py | 6 +++--- FastSurferCNN/models/losses.py | 6 +++--- FastSurferCNN/models/sub_module.py | 18 +++++++++--------- FastSurferCNN/mri_brainvol_stats.py | 2 +- FastSurferCNN/mri_segstats.py | 2 +- FastSurferCNN/segstats.py | 6 +++--- FastSurferCNN/utils/arg_types.py | 2 +- FastSurferCNN/utils/brainvolstats.py | 6 +++--- FastSurferCNN/utils/lr_scheduler.py | 8 ++++---- HypVINN/config/hypvinn.py | 6 +++--- HypVINN/data_loader/dataset.py | 2 +- HypVINN/utils/visualization_utils.py | 2 +- Tutorial/Complete_FastSurfer_Tutorial.ipynb | 2 +- doc/overview/CODE_OF_CONDUCT.md | 2 +- doc/overview/CONTRIBUTING.md | 4 ++-- doc/overview/EDITING.md | 8 ++++---- doc/sphinx_ext/fix_links/resolve.py | 4 ++-- recon_surf/N4_bias_correct.py | 2 +- recon_surf/README.md | 2 +- recon_surf/align_seg.py | 5 +---- recon_surf/recon-surf.sh | 2 +- recon_surf/sample_parc.py | 10 +++++----- recon_surf/smooth_aparc.py | 6 +++--- 34 files changed, 71 insertions(+), 72 deletions(-) diff --git a/.codespellignore b/.codespellignore index f5c72f40..df781460 100644 --- a/.codespellignore +++ b/.codespellignore @@ -1,4 +1,6 @@ +assertIn mapp padd struc +TE warmup diff --git a/.github/workflows/code-style.yml b/.github/workflows/code-style.yml index 2284d735..f8bb9568 100644 --- a/.github/workflows/code-style.yml +++ b/.github/workflows/code-style.yml @@ -32,7 +32,7 @@ jobs: with: check_filenames: true check_hidden: true - skip: './.git,./build,./.mypy_cache,./.pytest_cache' + skip: './build,./doc/images,./Tutorial,./.git,./.mypy_cache,./.pytest_cache' ignore_words_file: ./.codespellignore # - name: Run pydocstyle # run: 
pydocstyle . diff --git a/CerebNet/config/cerebnet.py b/CerebNet/config/cerebnet.py index 77cf81ab..74be2a5c 100644 --- a/CerebNet/config/cerebnet.py +++ b/CerebNet/config/cerebnet.py @@ -116,7 +116,7 @@ # Data Augmentation options # ---------------------------------------------------------------------------- # -# Augmentation for traning +# Augmentation for training _C.AUGMENTATION = CN() # list of augmentations to use for training diff --git a/CerebNet/models/sub_module.py b/CerebNet/models/sub_module.py index 864803c2..5ad5e190 100644 --- a/CerebNet/models/sub_module.py +++ b/CerebNet/models/sub_module.py @@ -300,7 +300,7 @@ def forward(self, x): Original feature map. out_block : Tensor Maxpooled feature map. - indicies : Tensor + indices : Tensor Maxpool indices. """ out_block = super().forward( diff --git a/CerebNet/utils/lr_scheduler.py b/CerebNet/utils/lr_scheduler.py index f7b67a72..e78297f1 100644 --- a/CerebNet/utils/lr_scheduler.py +++ b/CerebNet/utils/lr_scheduler.py @@ -44,7 +44,7 @@ def __init__(self, optimizer, *args, T_0=10, Tmult=1, lr_restart=None, **kwargs) Args: ...: same as ReduceLROnPlateau T_0 (optional): number of epochs until first restart (default: 10) - Tmult (optiona): multiplicative factor for future restarts (default: 1) + Tmult (optional): multiplicative factor for future restarts (default: 1) lr_restart (optinoal): multiplicative factor for learning rate adjustment at restart. """ # from torch.optim.lr_scheduler._LRSchduler @@ -72,12 +72,12 @@ def __init__(self, optimizer, *args, T_0=10, Tmult=1, lr_restart=None, **kwargs) def step(self, metrics, epoch=None): """ - Perfroms an optimization step. + Performs an optimization step. Parameters ---------- metrics : float - The value of matrics= used to determine learning rate adjustments. + The value of metrics is used to determine learning rate adjustments. epoch : int, default=None Number of epochs. diff --git a/Docker/Dockerfile b/Docker/Dockerfile index 8e4f854e..b239c7e8 100644 --- a/Docker/Dockerfile +++ b/Docker/Dockerfile @@ -40,7 +40,7 @@ # - build_freesurfer: # Build the freesurfer build image only. # - build_common: -# Build the basic image with the python enviroment (hardware/driver-agnostic) +# Build the basic image with the python environment (hardware/driver-agnostic) # - build_conda: # Build the python environment image with cuda/rocm/cpu support @@ -51,7 +51,7 @@ ARG BUILD_BASE_IMAGE=ubuntu:22.04 # BUILDKIT_SBOM:SCAN_CONTEXT enables buildkit to provide and scan build images # this is active by default to provide proper SBOM manifests, however, it may also # include parts that are not part of the distributed image (specifically build image -# parts installed in the build image, but not transfered to the runtime image such as +# parts installed in the build image, but not transferred to the runtime image such as # git, wget, the miniconda installer, etc.) 
ARG BUILDKIT_SBOM_SCAN_CONTEXT=true diff --git a/Docker/build.py b/Docker/build.py index e3301db5..f54e436c 100755 --- a/Docker/build.py +++ b/Docker/build.py @@ -353,7 +353,7 @@ def get_builder( # see if there is an alternative builder named "fastsurfer*" for builder in builders.keys(): if builder.startswith("fastsurfer") and builders[builder] == builder_type: - # set the default_builder to this (prefered) builder + # set the default_builder to this (preferred) builder alternative_builder = builder break # update is_correct_type diff --git a/Docker/install_fs_pruned.sh b/Docker/install_fs_pruned.sh index e6d4e1a9..2dd37a33 100755 --- a/Docker/install_fs_pruned.sh +++ b/Docker/install_fs_pruned.sh @@ -7,7 +7,7 @@ # In order to update to a new FreeSurfer version you need to update the fslink and then build a # docker with this setup. Run it and whenever it crashes/exits, find the missing file (binary, # atlas, datafile, or dependency) and add it here or if a dependency is missing install it in the -# docker and rebuild and re-run. Repeat until recon-surf finishes sucessfullly. Then repeat with +# docker and rebuild and re-run. Repeat until recon-surf finishes successfully. Then repeat with # all supported recon-surf flags (--hires, --fsaparc etc.). @@ -97,7 +97,7 @@ function run_parallel () } -# get Freesurfer and upack (some of it) +# get FreeSurfer and unpack (some of it) echo "Downloading FS and unpacking portions ..." wget --no-check-certificate -qO- $fslink | tar zxv --no-same-owner -C $where \ --exclude='freesurfer/average/*.gca' \ diff --git a/FastSurferCNN/data_loader/conform.py b/FastSurferCNN/data_loader/conform.py index 877f7c63..6c90ccac 100644 --- a/FastSurferCNN/data_loader/conform.py +++ b/FastSurferCNN/data_loader/conform.py @@ -159,7 +159,7 @@ def options_parse(): dest="force_lia", action="store_false", help="Ignore the reordering of data into LIA (without interpolation). " - "Superceeds --no_strict_lia", + "Supersedes --no_strict_lia", ) advanced.add_argument( "--no_iso_vox", diff --git a/FastSurferCNN/data_loader/data_utils.py b/FastSurferCNN/data_loader/data_utils.py index 438f6e05..1aadd7d9 100644 --- a/FastSurferCNN/data_loader/data_utils.py +++ b/FastSurferCNN/data_loader/data_utils.py @@ -988,7 +988,7 @@ def map_aparc_aseg2label( np.ndarray Mapped aseg for coronal and axial. np.ndarray - Mapped aseg for sagital. + Mapped aseg for sagittal. """ # If corpus callosum is not removed yet, do it now if aseg_nocc is not None: diff --git a/FastSurferCNN/data_loader/dataset.py b/FastSurferCNN/data_loader/dataset.py index d18506e6..ce7a54a7 100644 --- a/FastSurferCNN/data_loader/dataset.py +++ b/FastSurferCNN/data_loader/dataset.py @@ -48,7 +48,7 @@ def __init__( Parameters ---------- orig_data : npt.NDArray - Orignal Data. + Original Data. orig_zoom : npt.NDArray Original zoomfactors. cfg : yacs.config.CfgNode diff --git a/FastSurferCNN/generate_hdf5.py b/FastSurferCNN/generate_hdf5.py index 37383642..6359242c 100644 --- a/FastSurferCNN/generate_hdf5.py +++ b/FastSurferCNN/generate_hdf5.py @@ -93,9 +93,9 @@ class H5pyDataset: Methods ------- __init__ - Consturctor + Constructor _load_volumes - load image and segmentation volume + Load image and segmentation volume transform Transform image along axis _pad_image @@ -233,7 +233,7 @@ def transform( npt.NDArray Transformed image. npt.NDArray - Transformed zoom facors. + Transformed zoom factors.
""" for i in range(len(imgs)): if self.plane == "sagittal": diff --git a/FastSurferCNN/models/losses.py b/FastSurferCNN/models/losses.py index 5e8ba9d4..83797f44 100644 --- a/FastSurferCNN/models/losses.py +++ b/FastSurferCNN/models/losses.py @@ -31,7 +31,7 @@ class DiceLoss(_Loss): Methods ------- forward - Calulate the DiceLoss. + Calculate the DiceLoss. """ def forward( @@ -42,7 +42,7 @@ def forward( ignore_index: int | None = None, ) -> torch.Tensor: """ - Calulate the DiceLoss. + Calculate the DiceLoss. Parameters ---------- @@ -178,7 +178,7 @@ def forward( target : Tensor A Tensor of shape N x H x W of integers containing the target. weight : Tensor - A Tensor of shape N x H x W of floats containg the weights. + A Tensor of shape N x H x W of floats containing the weights. Returns ------- diff --git a/FastSurferCNN/models/sub_module.py b/FastSurferCNN/models/sub_module.py index 5814b1e0..02bd700e 100644 --- a/FastSurferCNN/models/sub_module.py +++ b/FastSurferCNN/models/sub_module.py @@ -385,7 +385,7 @@ def __init__(self, params: dict): def forward(self, x: Tensor) -> Tensor: """ - Feed forward trough CompetitiveDenseBlockInput. + Feed forward through CompetitiveDenseBlockInput. in -> BN -> {Conv -> BN -> PReLU} -> {Conv -> BN -> Maxout -> PReLU} -> {Conv -> BN} -> out @@ -492,7 +492,7 @@ class CompetitiveEncoderBlock(CompetitiveDenseBlock): Methods ------- forward - Feed forward trough graph. + Feed forward through graph. """ def __init__(self, params: dict): @@ -513,7 +513,7 @@ def __init__(self, params: dict): def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]: """ - Feed forward trough Encoder Block. + Feed forward through Encoder Block. * CompetitiveDenseBlock * Max Pooling (+ retain indices) @@ -529,7 +529,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]: Original feature map. out_block : Tensor Maxpooled feature map. - indicies : Tensor + indices : Tensor Maxpool indices. """ out_block = super().forward( @@ -566,7 +566,7 @@ def __init__(self, params: dict): def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]: """ - Feed forward trough Encoder Block. + Feed forward through Encoder Block. * CompetitiveDenseBlockInput * Max Pooling (+ retain indices) @@ -618,7 +618,7 @@ def __init__(self, params: dict, outblock: bool = False): def forward(self, x: Tensor, out_block: Tensor, indices: Tensor) -> Tensor: """ - Feed forward trough Decoder block. + Feed forward through Decoder block. * Unpooling of feature maps from lower block * Maxout combination of unpooled map + skip connection @@ -661,7 +661,7 @@ class OutputDenseBlock(nn.Module): Methods ------- forward - Feed forward trough graph. + Feed forward through graph. """ def __init__(self, params: dict): @@ -727,7 +727,7 @@ def __init__(self, params: dict): def forward(self, x: Tensor, out_block: Tensor) -> Tensor: """ - Feed forward trough Output block. + Feed forward through Output block. * Maxout combination of unpooled map from previous block + skip connection * Forwarding toward CompetitiveDenseBlock @@ -802,7 +802,7 @@ def __init__(self, params: dict): def forward(self, x: Tensor) -> Tensor: """ - Feed forward trough classifier. + Feed forward through classifier. 
Parameters ---------- diff --git a/FastSurferCNN/mri_brainvol_stats.py b/FastSurferCNN/mri_brainvol_stats.py index 6a37b7aa..1b1541cb 100644 --- a/FastSurferCNN/mri_brainvol_stats.py +++ b/FastSurferCNN/mri_brainvol_stats.py @@ -130,7 +130,7 @@ def make_arguments() -> argparse.ArgumentParser: type=Path, dest="pvfile", help="Path to image used to compute the partial volume effects. This file is " - "only used in the FastSurfer algoritms (--no_legacy).", + "only used in the FastSurfer algorithms (--no_legacy).", ) return parser diff --git a/FastSurferCNN/mri_segstats.py b/FastSurferCNN/mri_segstats.py index 25a7bb16..cfd7bfdd 100644 --- a/FastSurferCNN/mri_segstats.py +++ b/FastSurferCNN/mri_segstats.py @@ -211,7 +211,7 @@ def help_add_measures(message: str, keys: list[str]) -> str: type=Path, metavar="file", dest="segstatsfile", - help="Specifiy the output summary statistics file.", + help="Specify the output summary statistics file.", ) parser.add_argument( "--pv", diff --git a/FastSurferCNN/segstats.py b/FastSurferCNN/segstats.py index 98319cb7..e253c289 100644 --- a/FastSurferCNN/segstats.py +++ b/FastSurferCNN/segstats.py @@ -706,7 +706,7 @@ def infer_labels_excludeid( labels : npt.NDArray[int] The array of all labels to calculate partial volumes for. exclude_id : list[int] - A list of labels exlicitly excluded from the output table. + A list of labels explicitly excluded from the output table. """ explicit_ids = False if __ids := getattr(args, "ids", None): @@ -849,7 +849,7 @@ def main(args: argparse.Namespace) -> Literal[0] | str: return exception.args[0] if measure_only: - # in this mode, we do not output a data tabel anyways, so no need to compute + # in this mode, we do not output a data table anyways, so no need to compute # all these PV values. labels, exclude_id = np.zeros((0,), dtype=int), [] else: @@ -1450,7 +1450,7 @@ def borders( List of labels for which borders will be computed. If labels is True, _array is treated as a binary mask. max_label : int, optional - The maximum label ot consider. If None, the maximum label in the array is used. + The maximum label to consider. If None, the maximum label in the array is used. six_connected : bool, default=True If True, 6-connected borders (must share a face) are computed, otherwise 26-connected borders (must share a vertex) are computed. diff --git a/FastSurferCNN/utils/arg_types.py b/FastSurferCNN/utils/arg_types.py index 2896c279..49fa0f62 100644 --- a/FastSurferCNN/utils/arg_types.py +++ b/FastSurferCNN/utils/arg_types.py @@ -39,7 +39,7 @@ def vox_size(a: str) -> VoxSizeOption: Raises ------ argparse.ArgumentTypeError - If the arguemnt is not "min", "auto" or convertible to a float between 0 and 1. + If the argument is not "min", "auto" or convertible to a float between 0 and 1. """ if a.lower() in ["auto", "min"]: return "min" diff --git a/FastSurferCNN/utils/brainvolstats.py b/FastSurferCNN/utils/brainvolstats.py index b8a2b4d1..a0bd21f9 100644 --- a/FastSurferCNN/utils/brainvolstats.py +++ b/FastSurferCNN/utils/brainvolstats.py @@ -1934,7 +1934,7 @@ def default(self, key: str) -> AbstractMeasure: The volume of the corpus callosum in the segmentation. - `lhWM-hypointensities`, and `rhWM-hypointensities` The volume of unlateralized the white matter hypointensities in the - segmentation, but lateralized by neigboring voxels + segmentation, but lateralized by neighboring voxels (FreeSurfer uses talairach coordinates to re-lateralize). 
- `lhCerebralWhiteMatter`, `rhCerebralWhiteMatter`, and `CerebralWhiteMatter` The volume of the cerebral white matter in the segmentation (including corpus @@ -1952,7 +1952,7 @@ def default(self, key: str) -> AbstractMeasure: - `VentricleChoroidVol` The volume of the choroid plexus and inferiar and lateral ventricles and CSF. - `BrainSeg` - The volume of all brains structres in the segmentation. + The volume of all brain structures in the segmentation. - `BrainSegNotVent`, and `BrainSegNotVentSurf` The brain segmentation volume without ventricles. - `Cerebellum` @@ -2171,7 +2171,7 @@ def mask_77_lat(arr): ) elif key in "BrainSeg": # 0 => BrainSegVol: - # FS7 (does mot use ribbon any more, just ) + # FS7 (does not use ribbon any more, just ) # not background, in aseg ctab, not Brain stem, not optic chiasm, # aseg undefined in aseg ctab and not cortex or WM (L/R Cerebral # Ctx/WM) diff --git a/FastSurferCNN/utils/lr_scheduler.py b/FastSurferCNN/utils/lr_scheduler.py index 9298f821..00f1aeba 100644 --- a/FastSurferCNN/utils/lr_scheduler.py +++ b/FastSurferCNN/utils/lr_scheduler.py @@ -20,14 +20,14 @@ def get_lr_scheduler( - optimzer: torch.optim.Optimizer, cfg: yacs.config.CfgNode + optimizer: torch.optim.Optimizer, cfg: yacs.config.CfgNode ) -> None | scheduler.StepLR | scheduler.CosineAnnealingWarmRestarts: """ Give a schedular for left-right scheduling. Parameters ---------- - optimzer : torch.optim.Optimizer + optimizer : torch.optim.Optimizer Optimizer for the scheduler. cfg : yacs.config.CfgNode Configuration node. @@ -45,13 +45,13 @@ def get_lr_scheduler( scheduler_type = cfg.OPTIMIZER.LR_SCHEDULER if scheduler_type == "step_lr": return scheduler.StepLR( - optimizer=optimzer, + optimizer=optimizer, step_size=cfg.OPTIMIZER.STEP_SIZE, gamma=cfg.OPTIMIZER.GAMMA, ) elif scheduler_type == "cosineWarmRestarts": return scheduler.CosineAnnealingWarmRestarts( - optimizer=optimzer, + optimizer=optimizer, T_0=cfg.OPTIMIZER.T_ZERO, T_mult=cfg.OPTIMIZER.T_MULT, eta_min=cfg.OPTIMIZER.ETA_MIN, diff --git a/HypVINN/config/hypvinn.py b/HypVINN/config/hypvinn.py index 9af9c0f1..caeeee35 100644 --- a/HypVINN/config/hypvinn.py +++ b/HypVINN/config/hypvinn.py @@ -93,15 +93,15 @@ # Options for addition instead of Maxout _C.MODEL.ADDITION = False -#Options for multi modalitie +# Options for multi modalitie _C.MODEL.MULTI_AUTO_W = False # weight per modalitiy _C.MODEL.MULTI_AUTO_W_CHANNELS = False #weight per channel # Flag, for smoothing testing (double number of feature maps before the input interpolation block) _C.MODEL.MULTI_SMOOTH = False -# Brach weights can be aleatory set to zero +# Branch weights can be aleatory set to zero _C.MODEL.HETERO_INPUT = False # Flag for replicating any given modality into the two branches. 
-# This branch require that the hetero_input also set to TRUE +# This branch requires that the hetero_input also set to TRUE _C.MODEL.DUPLICATE_INPUT = False # ---------------------------------------------------------------------------- # # Training options diff --git a/HypVINN/data_loader/dataset.py b/HypVINN/data_loader/dataset.py index 01c22a84..15db9f36 100644 --- a/HypVINN/data_loader/dataset.py +++ b/HypVINN/data_loader/dataset.py @@ -82,7 +82,7 @@ def __init__( self.plane = cfg.DATA.PLANE #Inference Mode self.mode = mode - #set thickness base on train paramters + #set thickness base on train parameters if cfg.MODEL.MODE in ["t1", "t2"]: self.slice_thickness = cfg.MODEL.NUM_CHANNELS//2 else: diff --git a/HypVINN/utils/visualization_utils.py b/HypVINN/utils/visualization_utils.py index 63fd9e79..baa762ff 100644 --- a/HypVINN/utils/visualization_utils.py +++ b/HypVINN/utils/visualization_utils.py @@ -214,7 +214,7 @@ def select_index_to_plot(hyposeg, slice_step=2): idx_only_third_ventricle = [] for i in idx_with_third_ventricle: label = np.unique(hyposeg[i]) - # Background is allways at position 0 + # Background is always at position 0 if label[1] == 10: idx_only_third_ventricle.append(i) # Remove slices with only third ventricle from the total diff --git a/Tutorial/Complete_FastSurfer_Tutorial.ipynb b/Tutorial/Complete_FastSurfer_Tutorial.ipynb index 2115a5a6..d2d1249a 100644 --- a/Tutorial/Complete_FastSurfer_Tutorial.ipynb +++ b/Tutorial/Complete_FastSurfer_Tutorial.ipynb @@ -2945,7 +2945,7 @@ "id": "qysXQQtefugX" }, "source": [ - "Investiage outlier subject with freeview to see what went wrong.\n", + "Investigate outlier subject with freeview to see what went wrong.\n", "\n", "```bash\n", "outlier=032\n", diff --git a/doc/overview/CODE_OF_CONDUCT.md b/doc/overview/CODE_OF_CONDUCT.md index b83acd71..de97d7a9 100644 --- a/doc/overview/CODE_OF_CONDUCT.md +++ b/doc/overview/CODE_OF_CONDUCT.md @@ -6,7 +6,7 @@ In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, -level of experience, education, socio-economic status, nationality, personal +level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards diff --git a/doc/overview/CONTRIBUTING.md b/doc/overview/CONTRIBUTING.md index aaf0ce49..e8384973 100644 --- a/doc/overview/CONTRIBUTING.md +++ b/doc/overview/CONTRIBUTING.md @@ -72,7 +72,7 @@ This is the preferred way, but only possible if you are the sole develop or your 10. Switch into dev branch (`git checkout dev`) 11. Update your dev branch (`git pull upstream dev`) -12. Switch into your feature (`git chekcout my-new-feature`) +12. Switch into your feature (`git checkout my-new-feature`) 13. Rebase your branch onto dev (`git rebase dev`), resolve conflicts and continue until complete 14. Force push the updated feature branch to your gihub (`git push -f origin my-new-feature`) @@ -81,7 +81,7 @@ Instead you need to merge upstream dev into your branch: 10. Switch into dev branch (`git checkout dev`) 11. Update your dev branch (`git pull upstream dev`) -12. Switch into your feature (`git chekcout my-new-feature`) +12. Switch into your feature (`git checkout my-new-feature`) 13. 
Merge dev into your feature (`git merge dev`), resolve conflicts and commit 14. Push to origin (`git push origin my-new-feature`) diff --git a/doc/overview/EDITING.md b/doc/overview/EDITING.md index bf7203c2..32b64662 100644 --- a/doc/overview/EDITING.md +++ b/doc/overview/EDITING.md @@ -64,7 +64,7 @@ that was provided from the first run. This can help brighten up some regions and You can manually edit ```aparc.DKTatlas+aseg.deep.mgz```. This is similar to aseg edits in FreeSurfer. You can fill-in undersegmented regions (with the correct segmentation ID). To re-create the aseg and mask run the following command before continuing with other modules: -- Step 1: Assuming that you have run the full fastsurfer pipeline once as described in method_1 and succesfully produced segmentations and surfaces +- Step 1: Assuming that you have run the full fastsurfer pipeline once as described in method_1 and successfully produced segmentations and surfaces - Step 2: Execute this command where reduce_to_aseg.py is located ```bash python3 reduce_to_aseg.py -i sid/mri/aparc.DKTatlas+aseg.deep.edited.mgz \ @@ -72,7 +72,7 @@ You can manually edit ```aparc.DKTatlas+aseg.deep.mgz```. This is similar to ase --outmask sid/mri/mask.mgz \ --fixwm ``` - Assuming you have edited ```aparc.DKTatlas+aseg.deep.edited.mgz``` in freeview, step_2 will produce two files i.e ```aseg.auto_noCCseg.mgz``` and ```mask.mgz ``` in the specified output folder. The ouput files can be loaded in freeview as a load volume. Edit-->load volume + Assuming you have edited ```aparc.DKTatlas+aseg.deep.edited.mgz``` in freeview, step_2 will produce two files i.e ```aseg.auto_noCCseg.mgz``` and ```mask.mgz ``` in the specified output folder. The output files can be loaded in freeview as a load volume. Edit-->load volume - Step 3: For this step you would have to copy segmentation files produced in step_1, edited file ```aparc.DKTatlas+aseg.deep.edited.mgz``` and re-created file produced in step_2 in new output directory beforehand. @@ -99,7 +99,7 @@ You can manually edit ```aparc.DKTatlas+aseg.deep.mgz```. This is similar to ase ## 3. Brainmask Edits: When surfaces go out too far, e.g. they grab dura, you can tighten the mask directly, just edit ```mask.mgz```and start the *surface module*. -- Step 1: Assuming that you have run the full fastsurfer pipeline once as described in method_1 and succesfully produced segmentations and surfaces +- Step 1: Assuming that you have run the full fastsurfer pipeline once as described in method_1 and successfully produced segmentations and surfaces - Step 2: Edit ```mask.mgz``` file in freeview - Step 3: Run the pipeline again in order to get the surfaces but before running the pipeline again do not forget to copy all the segmented files in to new input and output directory. Note: The files in output folder should be pasted in the subjectX folder, the name of subjectX should be the same as it was used in step_1 otherwise it would raise an error of missing files even though the segmentation files exists in output folder. @@ -121,5 +121,5 @@ When surfaces go out too far, e.g. they grab dura, you can tighten the mask dire Note: ```t1-weighted-nii.gz``` would be the original input mri image. - We hope that this will help with (some of) your editing needs. If more edits become availble we will update this file. + We hope that this will help with (some of) your editing needs. If more edits become available we will update this file. Thanks for using FastSurfer. 
diff --git a/doc/sphinx_ext/fix_links/resolve.py b/doc/sphinx_ext/fix_links/resolve.py index 570e881d..87691523 100644 --- a/doc/sphinx_ext/fix_links/resolve.py +++ b/doc/sphinx_ext/fix_links/resolve.py @@ -81,7 +81,7 @@ def resolve_xref( app : sphinx.application.Sphinx env : sphinx.environment.BuildEnvironment node : sphinx.addnodes.pending_xref - contnode : docutils.noes.Element + contnode : docutils.nodes.Element Returns ------- @@ -96,7 +96,7 @@ def resolve_xref( if attr not in node.attributes: logger.debug( - f"[fix_links] Skipping replacement of {node.attibutes} (no {attr})", + f"[fix_links] Skipping replacement of {node.attributes} (no {attr})", location=loc(node), ) return diff --git a/recon_surf/N4_bias_correct.py b/recon_surf/N4_bias_correct.py index f1a17356..3873cfb9 100644 --- a/recon_surf/N4_bias_correct.py +++ b/recon_surf/N4_bias_correct.py @@ -725,7 +725,7 @@ def main( if aseg: # has aseg # used to be 110, but we found experimentally, that freesurfer wm-normalized - # intensity insde the WM mask is closer to 105 (but also quite inconsistent). + # intensity inside the WM mask is closer to 105 (but also quite inconsistent). # So when we have a WM mask, we need to use 105 and not 110 as for the # percentile approach above. target_wm = 105. diff --git a/recon_surf/README.md b/recon_surf/README.md index 6e812bd4..989ac105 100644 --- a/recon_surf/README.md +++ b/recon_surf/README.md @@ -69,7 +69,7 @@ available in the subjectX/mri directory (e.g. `/home/user/my_fastsurfeer_analysi ## Example 2: recon-surf inside Singularity Singularity can be used instead of Docker to run the full pipeline or individual modules. In this example we change the entrypoint to `recon-surf.sh` instead of the standard -`run_fastsurfer.sh`. Usually it is recomended to just use the default, so this is for expert users who may want to try out specific flags that are not passed to the wrapper. +`run_fastsurfer.sh`. Usually it is recommended to just use the default, so this is for expert users who may want to try out specific flags that are not passed to the wrapper. Given you already ran the segmentation pipeline, and want to just run the surface pipeline on top of it (i.e. on a different cluster), the following command can be used: ```bash diff --git a/recon_surf/align_seg.py b/recon_surf/align_seg.py index 49d30869..b272cf3c 100755 --- a/recon_surf/align_seg.py +++ b/recon_surf/align_seg.py @@ -226,9 +226,7 @@ def get_vox2ras(img:sitk.Image) -> npt.NDArray: def align_flipped(seg: sitk.Image, mid_slice: float | None = None) -> npt.NDArray: """ - Registrate Left - right (make upright). - - Register cortial lables + Register selected cortical labels with left-right flipped (make upright). Parameters ---------- @@ -237,7 +235,6 @@ def align_flipped(seg: sitk.Image, mid_slice: float | None = None) -> npt.NDArra mid_slice : Optional[float] Where the mid slice will be in upright space. Defaults to (width-1)/2. 
- Returns ------- Tsqrt diff --git a/recon_surf/recon-surf.sh b/recon_surf/recon-surf.sh index 19feb9dc..37cefb89 100755 --- a/recon_surf/recon-surf.sh +++ b/recon_surf/recon-surf.sh @@ -562,7 +562,7 @@ for hemi in lh rh; do echo "#!/bin/bash" > "$CMDF" -# ============================= TESSELATE - SMOOTH ===================================================== +# ============================= TESSELLATE - SMOOTH ===================================================== { echo "echo \" \"" diff --git a/recon_surf/sample_parc.py b/recon_surf/sample_parc.py index c58515ef..17d43d48 100644 --- a/recon_surf/sample_parc.py +++ b/recon_surf/sample_parc.py @@ -143,13 +143,13 @@ def find_all_islands(surf, annot): surf[1] is the np.array of (m, 3) triangle indices. annot : np.ndarray Annotation as an int array of (n,) with label ids for each vertex. - This is for example the first element of the tupel returned by + This is for example the first element of the tuple returned by nibabel fs.read_annot. Returns ------- vidx : np.ndarray (i,) - Arrray listing vertex indices of island vertices, empty if no islands + Array listing vertex indices of island vertices, empty if no islands (components disconnetcted from largest label region) are found. """ # construct adjaceny matrix without edges across regions: @@ -203,7 +203,7 @@ def sample_nearest_nonzero(img, vox_coords, radius=3.0): # the nearest neighbor voxel, instead of at the float vox coordinates # create box with 2*rvox+1 side length to fully contain ball - # and get coordiante offsets with zero at center + # and get coordinate offsets with zero at center ri = np.floor(rvox).astype(int) ll = np.arange(-ri,ri+1) xv, yv, zv = np.meshgrid(ll, ll, ll) @@ -300,7 +300,7 @@ def sample_img(surf, img, cortex=None, projmm=0.0, radius=None): T.orient_() # compute sample coordinates projmm mm along the surface normal - # in surface RAS coordiante system: + # in surface RAS coordinate system: x = T.v + projmm * T.vertex_normals() # mask cortex xx = x[mask] @@ -341,7 +341,7 @@ def replace_labels(img_labels, img_lut, surf_lut): Parameters ---------- img_labels : np.ndarray(n,) - Array with imgage label ids. + Array with image label ids. img_lut : str Filename for image label look up table. surf_lut : str diff --git a/recon_surf/smooth_aparc.py b/recon_surf/smooth_aparc.py index 3c5053e6..b1f1437e 100644 --- a/recon_surf/smooth_aparc.py +++ b/recon_surf/smooth_aparc.py @@ -222,7 +222,7 @@ def mode_filter( # Only after fixing the rows above, we can # get rid of entries that should not vote # since we have only rows that were non-uniform, they should not become empty - # rows may become unform: we still need to vote below to update this label + # rows may become uniform: we still need to vote below to update this label if novote is not None: rr = np.isin(nlabels.data, novote) nlabels.data[rr] = 0 @@ -262,7 +262,7 @@ def smooth_aparc(surf, labels, cortex = None): Parameters ---------- surf : nibabel surface - Suface filepath and name of source. + Surface filepath and name of source. labels : np.array[int] Labels at each vertex (int). cortex : np.array[int] @@ -378,7 +378,7 @@ def main( Parameters ---------- insurfname : str - Suface filepath and name of source. + Surface filepath and name of source. inaparcname : str Annotation filepath and name of source. 
incortexname : str From bcd45487ae3adb5e79a8888269267d89300e475c Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 22:00:27 +0200 Subject: [PATCH 12/31] fix spelling --- Docker/install_fs_pruned.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Docker/install_fs_pruned.sh b/Docker/install_fs_pruned.sh index 2dd37a33..ddc1980d 100755 --- a/Docker/install_fs_pruned.sh +++ b/Docker/install_fs_pruned.sh @@ -7,7 +7,7 @@ # In order to update to a new FreeSurfer version you need to update the fslink and then build a # docker with this setup. Run it and whenever it crashes/exits, find the missing file (binary, # atlas, datafile, or dependency) and add it here or if a dependency is missing install it in the -# docker and rebuild and re-run. Repeat until recon-surf finishes sucessfully. Then repeat with +# docker and rebuild and re-run. Repeat until recon-surf finishes successfully. Then repeat with # all supported recon-surf flags (--hires, --fsaparc etc.). From 63d5685d479f65bca3f72e51655daf9727bb3b54 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 22:38:21 +0200 Subject: [PATCH 13/31] add back (locally) unused imports in logging --- FastSurferCNN/utils/logging.py | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/FastSurferCNN/utils/logging.py b/FastSurferCNN/utils/logging.py index 2aab4c99..32ab7fb7 100644 --- a/FastSurferCNN/utils/logging.py +++ b/FastSurferCNN/utils/logging.py @@ -13,7 +13,7 @@ # limitations under the License. # IMPORTS -from logging import INFO, FileHandler, StreamHandler, basicConfig +from logging import getLogger, DEBUG, INFO, FileHandler, StreamHandler, basicConfig from pathlib import Path as _Path from sys import stdout as _stdout diff --git a/pyproject.toml b/pyproject.toml index 86d32b53..7bdd5995 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ select = [ [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] "Tutorial/*.ipynb" = ["E501"] # exclude "Line too long" +"FastSurferCNN/utils/logging.py" =["F401"] # exclude "Imported but unused" [tool.pytest.ini_options] minversion = '6.0' From f0fe277bbf8f7b8c74dea100960549e5c70b5617 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 22:42:57 +0200 Subject: [PATCH 14/31] fix import order --- FastSurferCNN/utils/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FastSurferCNN/utils/logging.py b/FastSurferCNN/utils/logging.py index 32ab7fb7..6427ae72 100644 --- a/FastSurferCNN/utils/logging.py +++ b/FastSurferCNN/utils/logging.py @@ -13,7 +13,7 @@ # limitations under the License. # IMPORTS -from logging import getLogger, DEBUG, INFO, FileHandler, StreamHandler, basicConfig +from logging import DEBUG, INFO, FileHandler, StreamHandler, basicConfig, getLogger from pathlib import Path as _Path from sys import stdout as _stdout From e8a445e2ec3f64bd1701151f585f39056ae64284 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 22:49:51 +0200 Subject: [PATCH 15/31] add back Logger and alias get_logger --- FastSurferCNN/utils/logging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/FastSurferCNN/utils/logging.py b/FastSurferCNN/utils/logging.py index 6427ae72..eae72067 100644 --- a/FastSurferCNN/utils/logging.py +++ b/FastSurferCNN/utils/logging.py @@ -13,7 +13,8 @@ # limitations under the License. 
# IMPORTS -from logging import DEBUG, INFO, FileHandler, StreamHandler, basicConfig, getLogger +from logging import DEBUG, INFO, FileHandler, Logger, StreamHandler, basicConfig, getLogger +from logging import getLogger as get_logger from pathlib import Path as _Path from sys import stdout as _stdout From 977d7d6b757a4c3f76f8b11993518b920bd17343 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 22:57:17 +0200 Subject: [PATCH 16/31] fix doc string bg_label description --- FastSurferCNN/quick_qc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/FastSurferCNN/quick_qc.py b/FastSurferCNN/quick_qc.py index 04a5179c..d74a20de 100644 --- a/FastSurferCNN/quick_qc.py +++ b/FastSurferCNN/quick_qc.py @@ -115,7 +115,8 @@ def get_region_bg_intersection_mask( Segmentation array. region_labels : dict, default= Dictionary whose values correspond to the desired region's labels (see Note). - bg_label : int, default as in BG_LABEL. + bg_label : int, default as in + Label id of the background. Returns ------- From 5d133e6b68e74f282b615ad8329e2a914c8ea8f3 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 23:11:30 +0200 Subject: [PATCH 17/31] fix string split --- FastSurferCNN/utils/parser_defaults.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FastSurferCNN/utils/parser_defaults.py b/FastSurferCNN/utils/parser_defaults.py index 58f0c48a..8b6a5255 100644 --- a/FastSurferCNN/utils/parser_defaults.py +++ b/FastSurferCNN/utils/parser_defaults.py @@ -129,7 +129,7 @@ def _stub(parser: CanAddArguments | type[dict], *flags, **kwargs): return {"flag": _flags[0], "flags": _flags, **kwargs} else: raise ValueError( - f"Unclear parameter, should be dict or argparse.ArgumentParser, not " + "Unclear parameter, should be dict or argparse.ArgumentParser, not " \ f"{type(parser).__name__}." ) From a7c84b26ecef4321dba3fafe54782685a56659ee Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 23:36:55 +0200 Subject: [PATCH 18/31] add back import of locally unused argparse in parser_defaults --- FastSurferCNN/utils/parser_defaults.py | 3 ++- pyproject.toml | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/FastSurferCNN/utils/parser_defaults.py b/FastSurferCNN/utils/parser_defaults.py index 8b6a5255..6b54ab45 100644 --- a/FastSurferCNN/utils/parser_defaults.py +++ b/FastSurferCNN/utils/parser_defaults.py @@ -26,6 +26,7 @@ >>> # 'dest': 'root', 'help': 'Allow execution as root user.'} """ +import argparse import types from collections.abc import Iterable, Mapping from dataclasses import Field, dataclass @@ -129,7 +130,7 @@ def _stub(parser: CanAddArguments | type[dict], *flags, **kwargs): return {"flag": _flags[0], "flags": _flags, **kwargs} else: raise ValueError( - "Unclear parameter, should be dict or argparse.ArgumentParser, not " \ + f"Unclear parameter, should be dict or argparse.ArgumentParser, not " f"{type(parser).__name__}." 
) diff --git a/pyproject.toml b/pyproject.toml index 7bdd5995..882d2d43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,7 +131,8 @@ select = [ [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] "Tutorial/*.ipynb" = ["E501"] # exclude "Line too long" -"FastSurferCNN/utils/logging.py" =["F401"] # exclude "Imported but unused" +"FastSurferCNN/utils/logging.py" = ["F401"] # exclude "Imported but unused" +"FastSurferCNN/utils/parser_defaults.py" = ["F401"] # exclude "Imported but unused" [tool.pytest.ini_options] minversion = '6.0' From c79651053df2dc75e13e8ce964346e4f9d07e01d Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Tue, 3 Sep 2024 23:50:18 +0200 Subject: [PATCH 19/31] add back fields --- FastSurferCNN/utils/parser_defaults.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FastSurferCNN/utils/parser_defaults.py b/FastSurferCNN/utils/parser_defaults.py index 6b54ab45..7713c0ce 100644 --- a/FastSurferCNN/utils/parser_defaults.py +++ b/FastSurferCNN/utils/parser_defaults.py @@ -29,7 +29,7 @@ import argparse import types from collections.abc import Iterable, Mapping -from dataclasses import Field, dataclass +from dataclasses import Field, dataclass, fields from pathlib import Path from typing import Literal, Protocol, TypeVar, get_args, get_origin From 0b8c54da11bbb03f45ce96d3fdb55ccfcee10541 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 4 Sep 2024 09:28:06 +0200 Subject: [PATCH 20/31] fix typo MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: David Kügler --- Docker/install_fs_pruned.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Docker/install_fs_pruned.sh b/Docker/install_fs_pruned.sh index ddc1980d..6f55580b 100755 --- a/Docker/install_fs_pruned.sh +++ b/Docker/install_fs_pruned.sh @@ -97,7 +97,7 @@ function run_parallel () } -# get FreeSurfer and upnack (some of it) +# get FreeSurfer and unpack (some of it) echo "Downloading FS and unpacking portions ..." wget --no-check-certificate -qO- $fslink | tar zxv --no-same-owner -C $where \ --exclude='freesurfer/average/*.gca' \ From d83d1720d8a33e3603509c6614548f15cb4df723 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Wed, 4 Sep 2024 11:05:48 +0200 Subject: [PATCH 21/31] Fix formatting in CerebNet/apply_warp.py --- CerebNet/apply_warp.py | 72 ++++++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/CerebNet/apply_warp.py b/CerebNet/apply_warp.py index 22f209cc..a774deff 100644 --- a/CerebNet/apply_warp.py +++ b/CerebNet/apply_warp.py @@ -1,7 +1,4 @@ -import argparse -from os.path import join - -import nibabel as nib +#!/bin/python # Copyright 2022 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn # @@ -16,13 +13,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + # IMPORTS +import argparse +from pathlib import Path +from numbers import Number +from typing import cast + +import nibabel as nib import numpy as np +from numpy import typing as npt from CerebNet.datasets import utils -def save_nii_image(img_data, save_path, header, affine): +def save_nii_image( + img_data: npt.ArrayLike, + save_path: Path | str, + header: nib.Nifti1Header | nib.Nifti2Header, + affine: npt.NDArray[float], +): """ Save an image data array as a NIfTI file. 
@@ -30,33 +40,40 @@ def save_nii_image(img_data, save_path, header, affine): ---------- img_data : ndarray The image data to be saved. - save_path : str + save_path : Path, str The path (including file name) where the image will be saved. - header : nibabel.Nifti1Header + header : nibabel.Nifti1Header, nibabel.Nifti2Header The header information for the NIfTI file. affine : ndarray The affine matrix for the NIfTI file. """ - + if not isinstance(header, nib.Nifti1Header): + header = nib.Nifti1Header.from_header(header) img_out = nib.Nifti1Image(img_data, header=header, affine=affine) print(f"Saving {save_path}") nib.save(img_out, save_path) -def main(img_path, lbl_path, warp_path, result_path, patch_size): +def main( + img_path: Path | str, + lbl_path: Path | str, + warp_path: Path | str, + result_path: Path | str, + patch_size, +): """ Load, warp, crop, and save both an image and its corresponding label based on a given warp field. Parameters ---------- - img_path : str + img_path : Path, str Path to the T1-weighted MRI image to be warped. - lbl_path : str + lbl_path : Path, str Path to the label image corresponding to the T1 image, to be warped similarly. - warp_path : str + warp_path : Path, str Path to the warp field file used to warp the images. - result_path : str + result_path : Path, str Directory path where the warped and cropped images will be saved. patch_size : tuple of int The dimensions (height, width, depth) cropped images after warping. @@ -65,9 +82,16 @@ def main(img_path, lbl_path, warp_path, result_path, patch_size): img, img_file = utils.load_reorient_rescale_image(img_path) lbl_file = nib.load(lbl_path) - label = np.asarray(lbl_file.get_fdata(), dtype=np.int16) - - warp_field = np.asarray(nib.load(warp_path).get_fdata()) + # if not isinstance(lbl_file, nib.analyze.SpatialImage): + if not isinstance(lbl_file, nib.Nifti1Image | nib.Nifti2Image): + raise ValueError(f"{lbl_file} is not a valid file format!") + lbl_header = cast(nib.Nifti1Header | nib.Nifti2Header, lbl_file.header) + label = np.asarray(lbl_file.dataobj, dtype=np.int16) + + warp_file = nib.load(warp_path) + if not isinstance(warp_file, nib.analyze.SpatialImage): + raise ValueError(f"{warp_file} is not a valid file format!") + warp_field = np.asarray(warp_file.dataobj, dtype=float) img = utils.map_size(img, base_shape=warp_field.shape[:3]) label = utils.map_size(label, base_shape=warp_field.shape[:3]) warped_img = utils.apply_warp_field(warp_field, img, interpol_order=3) @@ -80,14 +104,14 @@ def main(img_path, lbl_path, warp_path, result_path, patch_size): img_file.header['dim'][1:4] = patch_size img_file.set_data_dtype(img.dtype) - lbl_file.header['dim'][1:4] = patch_size + lbl_header['dim'][1:4] = patch_size save_nii_image(img, - join(result_path, "T1_warped_cropped.nii.gz"), + Path(result_path) / "T1_warped_cropped.nii.gz", header=img_file.header, affine=img_file.affine) save_nii_image(label, - join(result_path, "label_warped_cropped.nii.gz"), - header=lbl_file.header, + Path(result_path) / "label_warped_cropped.nii.gz", + header=lbl_header, affine=lbl_file.affine) @@ -95,13 +119,13 @@ def make_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--img_path", help="path to T1 image", - type=str) + type=Path) parser.add_argument("--lbl_path", help="path to label image", - type=str) + type=Path) parser.add_argument("--result_path", help="folder to store the results", - type=str) + type=Path) parser.add_argument("--warp_filename", help="Warp field file", @@ -113,7 +137,7 
@@ def make_parser() -> argparse.ArgumentParser: if __name__ == '__main__': parser = make_parser() args = parser.parse_args() - warp_path = str(join(args.result_path, args.warp_filename)) + warp_path = Path(args.result_path) / args.warp_filename main( args.img_path, args.lbl_path, From 04c236bb53f022cb6e42bfc4dbdd5a78cc28efa4 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 4 Sep 2024 12:41:24 +0200 Subject: [PATCH 22/31] bump version to a dev --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c6d16506..a38cf2c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = 'setuptools.build_meta' [project] name = 'fastsurfer' -version = '2.3.0' +version = '2.4.0-dev' description = 'A fast and accurate deep-learning based neuroimaging pipeline' readme = 'README.md' license = {file = 'LICENSE'} From 751c40db32df99b9c013c55a8a27660f3e7b50aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Wed, 4 Sep 2024 14:26:20 +0200 Subject: [PATCH 23/31] HypVINN/run_prediction.py - revert changes to the docstring of main - move replacement of strings into a decorator function - fix indentation errors in doc FastSurferCNN/checkpoint.py - import Scheduler from torch instead of string-declaring (which does not work with pipe for Union) FastSurferCNN/parser_defaults.py - revert | None to Optional[] - add # noqa: UP0007 to ignore this ruff rule - add documentation for this - remove fields import - add Optional import - revert parser == None removal - reformat for number of characters per line --- FastSurferCNN/utils/checkpoint.py | 9 +- FastSurferCNN/utils/parser_defaults.py | 134 +++++++++++-------------- HypVINN/run_prediction.py | 23 +++-- 3 files changed, 85 insertions(+), 81 deletions(-) diff --git a/FastSurferCNN/utils/checkpoint.py b/FastSurferCNN/utils/checkpoint.py index 967f2841..2e373b58 100644 --- a/FastSurferCNN/utils/checkpoint.py +++ b/FastSurferCNN/utils/checkpoint.py @@ -17,7 +17,7 @@ from collections.abc import MutableSequence from functools import lru_cache from pathlib import Path -from typing import Literal, TypedDict, cast, overload +from typing import Literal, TypedDict, cast, overload, TYPE_CHECKING import requests import torch @@ -27,7 +27,12 @@ from FastSurferCNN.utils import Plane, logging from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT -Scheduler = "torch.optim.lr_scheduler" +if TYPE_CHECKING: + from torch.optim import lr_scheduler as Scheduler +else: + class Scheduler: + ... + LOGGER = logging.getLogger(__name__) # Defaults diff --git a/FastSurferCNN/utils/parser_defaults.py b/FastSurferCNN/utils/parser_defaults.py index 7713c0ce..c0017feb 100644 --- a/FastSurferCNN/utils/parser_defaults.py +++ b/FastSurferCNN/utils/parser_defaults.py @@ -29,9 +29,9 @@ import argparse import types from collections.abc import Iterable, Mapping -from dataclasses import Field, dataclass, fields +from dataclasses import Field, dataclass from pathlib import Path -from typing import Literal, Protocol, TypeVar, get_args, get_origin +from typing import Literal, Optional, Protocol, TypeVar, get_args, get_origin from FastSurferCNN.utils import PLANES, Plane from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one @@ -71,14 +71,12 @@ def __arg( """ Create stub function, which sets default settings for argparse arguments. - The positional and keyword arguments function as if they were directly passed to - parser.add_arguments(). 
+ The positional and keyword arguments function as if they were directly passed to parser.add_arguments(). - The result will be a stub function, which has as first argument a parser (or other - object with an add_argument method) to which the argument is added. The stub - function also accepts positional and keyword arguments, which overwrite the default - arguments. Additionally, these specific values can be callables, which will be - called upon the default values (to alter the default value). + The result will be a stub function, which has as first argument a parser (or other object with an add_argument + method) to which the argument is added. The stub function also accepts positional and keyword arguments, which + overwrite the default arguments. Additionally, these specific values can be callables, which will be called upon the + default values (to alter the default value). This function is private for this module. """ @@ -126,7 +124,7 @@ def _stub(parser: CanAddArguments | type[dict], *flags, **kwargs): _flags = flags if len(flags) != 0 else default_flags if hasattr(parser, "add_argument"): return parser.add_argument(*_flags, **kwargs) - elif isinstance(parser, dict): + elif parser is dict or isinstance(parser, dict): return {"flag": _flags[0], "flags": _flags, **kwargs} else: raise ValueError( @@ -145,6 +143,13 @@ def _stub(parser: CanAddArguments | type[dict], *flags, **kwargs): class SubjectDirectoryConfig: """ This class describes the 'minimal' parameters used by SubjectList. + + Notes + ----- + Important: + Data Types of fields should stay `Optional[]` and not be replaced by ` | None`, so the Parser can use + the type in argparse as the value for `type` of `parser.add_argument()` (`Optional` is a callable, while `Union` is + not). """ orig_name: str = field( help="Name of T1 full head MRI. Absolute path if single image else common " @@ -154,63 +159,57 @@ class SubjectDirectoryConfig: ) pred_name: str = field( default="mri/aparc.DKTatlas+aseg.deep.mgz", - help="Name of intermediate DL-based segmentation file (similar to aparc+aseg). " - "When using FastSurfer, this segmentation is already conformed, since " - "inference is always based on a conformed image. Absolute path if single " - "image else common image name. Default: mri/aparc.DKTatlas+aseg.deep.mgz", + help="Name of intermediate DL-based segmentation file (similar to aparc+aseg). When using FastSurfer, this " + "segmentation is already conformed, since inference is always based on a conformed image. Absolute path " + "if single image else common image name. Default: mri/aparc.DKTatlas+aseg.deep.mgz", ) conf_name: str = field( default="mri/orig.mgz", - help="Name under which the conformed input image will be saved, in the same " - "directory as the segmentation (the input image is always conformed " - "first, if it is not already conformed). The original input image is " - "saved in the output directory as $id/mri/orig/001.mgz. Default: " - "mri/orig.mgz.", + help="Name under which the conformed input image will be saved, in the same directory as the segmentation (the " + "input image is always conformed first, if it is not already conformed). The original input image is " + "saved in the output directory as $id/mri/orig/001.mgz. Default: mri/orig.mgz.", flags=("--conformed_name",), ) - in_dir: Path | None = field( + + in_dir: Optional[Path] = field( # noqa: UP007 flags=("--in_dir",), default=None, - help="Directory in which input volume(s) are located. 
Optional, if full path " - "is defined for --t1.", + help="Directory in which input volume(s) are located. Optional, if full path is defined for --t1.", ) - csv_file: Path | None = field( + csv_file: Optional[Path] = field( # noqa: UP007 flags=("--csv_file",), default=None, help="Csv-file with subjects to analyze (alternative to --tag)", ) - sid: str | None = field( + sid: Optional[str] = field( # noqa: UP007 flags=("--sid",), default=None, - help="Optional: directly set the subject id to use. Can be used for single " - "subject input. For multi-subject processing, use remove suffix if sid is " - "not second to last element of input file passed to --t1", + help="Optional: directly set the subject id to use. Can be used for single subject input. For multi-subject " + "processing, use remove suffix if sid is not second to last element of input file passed to --t1", ) search_tag: str = field( flags=("--tag",), default="*", - help="Search tag to process only certain subjects. If a single image should be " - "analyzed, set the tag with its id. Default: processes all.", + help="Search tag to process only certain subjects. If a single image should be analyzed, set the tag with its " + "id. Default: processes all.", ) brainmask_name: str = field( default="mri/mask.mgz", - help="Name under which the brainmask image will be saved, in the same " - "directory as the segmentation. The brainmask is created from the " - "aparc_aseg segmentation (dilate 5, erode 4, largest component). Default: " + help="Name under which the brainmask image will be saved, in the same directory as the segmentation. The " + "brainmask is created from the aparc_aseg segmentation (dilate 5, erode 4, largest component). Default: " "`mri/mask.mgz`.", flags=("--brainmask_name",), ) remove_suffix: str = field( flags=("--remove_suffix",), default="", - help="Optional: remove suffix from path definition of input file to yield " - "correct subject name (e.g. /ses-x/anat/ for BIDS or /mri/ for FreeSurfer " - "input). Default: do not remove anything.", + help="Optional: remove suffix from path definition of input file to yield correct subject name (e.g. " + "/ses-x/anat/ for BIDS or /mri/ for FreeSurfer input). Default: do not remove anything.", ) - out_dir: Path | None = field( + out_dir: Optional[Path] = field( # noqa: UP007 default=None, - help="Directory in which evaluation results should be written. Will be created " - "if it does not exist. Optional if full path is defined for --pred_name.", + help="Directory in which evaluation results should be written. Will be created if it does not exist. Optional " + "if full path is defined for --pred_name.", ) @@ -230,8 +229,7 @@ class SubjectDirectoryConfig: type=str, dest="norm_name", default="mri/norm.mgz", - help="Name under which the bias field corrected image is stored. Default: " - "mri/norm.mgz.", + help="Name under which the bias field corrected image is stored. Default: mri/norm.mgz.", ), "brainmask_name": __arg("--brainmask_name", dc=SubjectDirectoryConfig), "aseg_name": __arg( @@ -239,36 +237,33 @@ class SubjectDirectoryConfig: type=str, dest="aseg_name", default="mri/aseg.auto_noCCseg.mgz", - help="Name under which the reduced aseg segmentation will be saved, in the " - "same directory as the aparc-aseg segmentation (labels of full aparc " - "segmentation are reduced to aseg). 
Default: mri/aseg.auto_noCCseg.mgz.", + help="Name under which the reduced aseg segmentation will be saved, in the same directory as the aparc-aseg " + "segmentation (labels of full aparc segmentation are reduced to aseg). Default: " + "mri/aseg.auto_noCCseg.mgz.", ), "seg_log": __arg( "--seg_log", type=str, dest="log_name", default="", - help="Absolute path to file in which run logs will be saved. If not set, logs " - "will not be saved.", + help="Absolute path to file in which run logs will be saved. If not set, logs will not be saved.", ), "device": __arg( "--device", default="auto", - help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or " - "specify a certain gpu (e.g. cuda:1), default: auto", + help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or specify a certain gpu (e.g. cuda:1), " + "Default: auto", ), "viewagg_device": __arg( "--viewagg_device", dest="viewagg_device", type=str, default="auto", - help="Define the device, where the view aggregation should be run. By default, " - "the program checks if you have enough memory to run the view aggregation " - "on the gpu (cuda). The total memory is considered for this decision. If " - "this fails, or you actively overwrote the check with setting " - "> --viewagg_device cpu <, view agg is run on the cpu. Equivalently, if " - "you define > --viewagg_device cuda <, view agg will be run on the gpu " - "(no memory check will be done).", + help="Define the device, where the view aggregation should be run. By default, the program checks if you have " + "enough memory to run the view aggregation on the gpu (cuda). The total memory is considered for this " + "decision. If this fails, or you actively overwrote the check with setting > --viewagg_device cpu <, view " + "agg is run on the cpu. Equivalently, if you define > --viewagg_device cuda <, view agg will be run on " + "the gpu (no memory check will be done).", ), "in_dir": __arg("--in_dir", dc=SubjectDirectoryConfig, fieldname="in_dir"), "tag": __arg( @@ -290,19 +285,17 @@ class SubjectDirectoryConfig: type=str, dest="qc_log", default="", - help="Absolute path to file in which a list of subjects that failed QC check " - "(when processing multiple subjects) will be saved. If not set, the file " - "will not be saved.", + help="Absolute path to file in which a list of subjects that failed QC check (when processing multiple " + "subjects) will be saved. If not set, the file will not be saved.", ), "vox_size": __arg( "--vox_size", type=__vox_size, default="min", dest="vox_size", - help="Choose the primary voxelsize to process, must be either a number between " - "0 and 1 (below 0.7 is experimental) or 'min' (default). A number forces " - "processing at that specific voxel size, 'min' determines the voxel size " - "from the image itself (conforming to the minimum voxel size, or 1 if the " + help="Choose the primary voxelsize to process, must be either a number between 0 and 1 (below 0.7 is " + "experimental) or 'min' (default). A number forces processing at that specific voxel size, 'min' " + "determines the voxel size from the image itself (conforming to the minimum voxel size, or 1 if the " "minimum voxel size is above 0.95mm). ", ), "conform_to_1mm_threshold": __arg( @@ -310,9 +303,8 @@ class SubjectDirectoryConfig: type=__conform_to_one, default=0.95, dest="conform_to_1mm_threshold", - help="The voxelsize threshold, above which images will be conformed to 1mm " - "isotropic, if the --vox_size argument is also 'min' (the --vox_size " - "default setting). 
Contrary to conform.py, the default behavior of " + help="The voxelsize threshold, above which images will be conformed to 1mm isotropic, if the --vox_size " + "argument is also 'min' (the --vox_size default setting). Contrary to conform.py, the default behavior of " "%(prog)s is to resample all images above 0.95mm to 1mm.", ), "lut": __arg( @@ -332,16 +324,14 @@ class SubjectDirectoryConfig: dest="threads", default=get_num_threads(), type=int, - help=f"Number of threads to use (defaults to number of hardware threads: " - f"{get_num_threads()})", + help=f"Number of threads to use (defaults to number of hardware threads: {get_num_threads()})", ), "async_io": __arg( "--async_io", dest="async_io", action="store_true", - help="Allow asynchronous file operations (default: off). Note, this may impact " - "the order of messages in the log, but speed up the segmentation " - "specifically for slow file systems.", + help="Allow asynchronous file operations (default: off). Note, this may impact the order of messages in the " + "log, but speed up the segmentation specifically for slow file systems.", ), } @@ -403,11 +393,9 @@ def add_plane_flags( The parser to add flags to. configtype : Literal["checkpoint", "config"] The type of files (for help text and prefix from "checkpoint" and "config". - "checkpoint" will lead to flags like "--ckpt_{plane}", "config" to - "--cfg_{plane}". + "checkpoint" will lead to flags like "--ckpt_{plane}", "config" to "--cfg_{plane}". files : Mapping[Plane, Path | str] - A dictionary of plane to filename. Relative files are assumed to be relative to - the FastSurfer root directory. + A dictionary of plane to filename. Relative files are assumed to be relative to the FastSurfer root directory. defaults_path : Path, str A path to the file to load defaults from. diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index dfe487fd..649b127c 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -160,6 +160,18 @@ def option_parse() -> argparse.ArgumentParser: return parser +def _update_docstring(**kwargs): + """ + Make custom replacements in the docstring. + """ + + def stub(f): + f.__doc__ = f.__doc__.format(**kwargs) + return f + return stub + + +@_update_docstring(HYPVINN_SEG_NAME=HYPVINN_SEG_NAME, HYPVINN_MASK_NAME=HYPVINN_MASK_NAME) def main( out_dir: Path, t2: Path | None, @@ -207,10 +219,10 @@ def main( The path to the coronal configuration file. cfg_sag : Path The path to the sagittal configuration file. - hypo_segfile : str, default is in HYPVINN_SEG_NAME as specified in config. - The name of the hypothalamus segmentation file. Default is in HYPVINN_SEG_NAME. - hypo_maskfile : str, default is in HYPVINN_MASK_NAME - The name of the hypothalamus mask file. Default is in HYPVINN_MASK_NAME. + hypo_segfile : str, default="{HYPVINN_SEG_NAME}" + The name of the hypothalamus segmentation file. Default is {HYPVINN_SEG_NAME}. + hypo_maskfile : str, default="{HYPVINN_MASK_NAME}" + The name of the hypothalamus mask file. Default is {HYPVINN_MASK_NAME}. allow_root : bool, default=False Whether to allow running as root user. Default is False. qc_snapshots : bool, optional @@ -466,8 +478,7 @@ def load_volumes( ------- tuple A tuple containing the following elements: - - modalities: A dictionary with keys 't1' and/or 't2' and values - being the corresponding loaded and rescaled images. + - modalities: A dictionary of `ndarrays` of rescaled images for keys 't1' and/or 't2'. - affine: The affine transformation of the loaded image(s). 
- header: The header of the loaded image(s). - zoom: The zoom level of the loaded image(s). From 8ce424888a14937e054929d61ee8f6652b01045a Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 4 Sep 2024 14:59:10 +0200 Subject: [PATCH 24/31] fix import sorting --- CerebNet/apply_warp.py | 1 - FastSurferCNN/utils/checkpoint.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/CerebNet/apply_warp.py b/CerebNet/apply_warp.py index a774deff..c3254aae 100644 --- a/CerebNet/apply_warp.py +++ b/CerebNet/apply_warp.py @@ -17,7 +17,6 @@ # IMPORTS import argparse from pathlib import Path -from numbers import Number from typing import cast import nibabel as nib diff --git a/FastSurferCNN/utils/checkpoint.py b/FastSurferCNN/utils/checkpoint.py index 2e373b58..324184d2 100644 --- a/FastSurferCNN/utils/checkpoint.py +++ b/FastSurferCNN/utils/checkpoint.py @@ -17,7 +17,7 @@ from collections.abc import MutableSequence from functools import lru_cache from pathlib import Path -from typing import Literal, TypedDict, cast, overload, TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, TypedDict, cast, overload import requests import torch From ecf3a94ac0218491c584224e1aab29210af90e7c Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Wed, 4 Sep 2024 15:19:44 +0200 Subject: [PATCH 25/31] Update CONTRIBUTING.md stable is not default any longer. --- doc/overview/CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/overview/CONTRIBUTING.md b/doc/overview/CONTRIBUTING.md index e8384973..48956a62 100644 --- a/doc/overview/CONTRIBUTING.md +++ b/doc/overview/CONTRIBUTING.md @@ -65,7 +65,7 @@ Enhancement suggestions are tracked as [GitHub issues](https://github.com/Deep-M 7. Create your feature branch from dev (`git checkout -b my-new-feature`) 8. Commit your changes (`git commit -am 'Add some feature'`) 9. Push to the branch to your github (`git push origin my-new-feature`) -10. Create new pull request on github web interface from that branch into Deep-NI **dev branch** (not into stable, which is default) +10. Create new pull request on github web interface from that branch into Deep-NI **dev branch** (not into stable) If lots of things changed in the meantime or the pull request is showing conflicts you should rebase your branch to the current upstream dev. This is the preferred way, but only possible if you are the sole develop or your branch: From 1865050eeb774d7750202891ed5f6654a001321c Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Thu, 5 Sep 2024 16:20:38 +0200 Subject: [PATCH 26/31] fix typo --- Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb b/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb index e12a8fb8..4d75019e 100644 --- a/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb +++ b/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb @@ -409,7 +409,7 @@ } ], "source": [ - "#@title The first part of FastSurfer creates a whole-brain segmentation into 95 classes. Here, we use the pretrained deep-learning network FastSurferCNN using the checkpoints stored at the open source project deep-mi/fastsurfer to to run the model inference on a single image.\n", + "#@title The first part of FastSurfer creates a whole-brain segmentation into 95 classes. 
Here, we use the pretrained deep-learning network FastSurferCNN using the checkpoints stored at the open source project deep-mi/fastsurfer to run the model inference on a single image.\n", "\n", "# Note, you should also add --3T, if you are processing data from a 3T scanner.\n", "! FASTSURFER_HOME=$FASTSURFER_HOME \\\n", From 4557d3bc4d9d54ed908cd222030bc038efc54c2a Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Thu, 5 Sep 2024 16:31:21 +0200 Subject: [PATCH 27/31] skip biasfield cerebnet and hypvinn in quick segmentation --- Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb b/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb index 4d75019e..69e146c9 100644 --- a/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb +++ b/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb @@ -417,6 +417,7 @@ " --sd \"{SETUP_DIR}fastsurfer_seg\" \\\n", " --sid Tutorial \\\n", " --seg_only --py python3 \\\n", + " --no_biasfield --no_cereb --no_hypothal \\\n", " --allow_root" ] }, From 945847fa4ae2251c1d1e2ee53e6bee8c4a40a974 Mon Sep 17 00:00:00 2001 From: David Kuegler Date: Mon, 9 Sep 2024 16:58:41 +0200 Subject: [PATCH 28/31] Subject directory should default to environment SUBJECTS_DIR variable Fix initialization of sd variable in run_fastsurfer.sh --- run_fastsurfer.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh index ebed4eb2..16cd5726 100755 --- a/run_fastsurfer.sh +++ b/run_fastsurfer.sh @@ -37,6 +37,7 @@ reconsurfdir="$FASTSURFER_HOME/recon_surf" # Regular flags defaults subject="" +sd="$SUBJECTS_DIR" t1="" t2="" merged_segfile="" From d2382ea3f358e898c7e3da7f8ce83750947ec2fe Mon Sep 17 00:00:00 2001 From: David Kuegler Date: Mon, 9 Sep 2024 16:46:42 +0200 Subject: [PATCH 29/31] Restore CerebNet checkpoint functions Restore FastSurfer checkpoint functions inherited by CerebNet via import. Those functions are used in downstream modules. 
--- CerebNet/utils/checkpoint.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/CerebNet/utils/checkpoint.py b/CerebNet/utils/checkpoint.py index 2c83469d..985fd330 100644 --- a/CerebNet/utils/checkpoint.py +++ b/CerebNet/utils/checkpoint.py @@ -20,8 +20,25 @@ import yacs from FastSurferCNN.utils import logging +from FastSurferCNN.utils.checkpoint import ( + create_checkpoint_dir, + get_checkpoint, + get_checkpoint_path, + load_from_checkpoint, + save_checkpoint, +) from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT +__all__ = [ + "create_checkpoint_dir", + "get_checkpoint", + "get_checkpoint_path", + "is_checkpoint_epoch", + "load_from_checkpoint", + "save_checkpoint", + "YAML_DEFAULT", +] + # DEFAULTS YAML_DEFAULT = FASTSURFER_ROOT / "CerebNet/config/checkpoint_paths.yaml" From c4c67d35ede7fb047ba4ac1092c8ff8f2d3897ea Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Mon, 9 Sep 2024 18:50:52 +0200 Subject: [PATCH 30/31] bugfix rotate_sphere.py --- recon_surf/rotate_sphere.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recon_surf/rotate_sphere.py b/recon_surf/rotate_sphere.py index 3c7b4b72..c40348a6 100644 --- a/recon_surf/rotate_sphere.py +++ b/recon_surf/rotate_sphere.py @@ -135,7 +135,7 @@ def align_aparc_centroids( # lids=np.array([8,9,22,24,31]) # lids=np.array([8,22,24]) - if label_ids is not None: + if label_ids is None: # use all joint labels except -1 and 0: lids = np.intersect1d(labels_mov, labels_dst) lids = lids[(lids > 0)] From 77231fbad2f0468b27274133843c6f7ab67a9515 Mon Sep 17 00:00:00 2001 From: Martin Reuter Date: Mon, 9 Sep 2024 22:25:07 +0200 Subject: [PATCH 31/31] bug fix (if none) in align_seg.py --- recon_surf/align_seg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recon_surf/align_seg.py b/recon_surf/align_seg.py index b272cf3c..1f2f29da 100755 --- a/recon_surf/align_seg.py +++ b/recon_surf/align_seg.py @@ -124,7 +124,7 @@ def get_seg_centroids( centroids_dst List of centroids of target segmentation. """ - if label_ids is not None: + if label_ids is None: # use all joint labels except -1 and 0: nda1 = sitk.GetArrayFromImage(seg_mov) nda2 = sitk.GetArrayFromImage(seg_dst)
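[Editor's note — illustrative sketch, not part of the patch series above.] The last two commits flip an inverted condition: the branch that derives the label ids from the two segmentations must run when *no* label ids were supplied, hence `if label_ids is None` rather than `is not None`. A minimal standalone sketch of the corrected fallback, assuming NumPy; the helper name `select_label_ids` is hypothetical and only mirrors the logic visible in the rotate_sphere.py / align_seg.py hunks:

import numpy as np

def select_label_ids(labels_mov, labels_dst, label_ids=None):
    """Return the label ids to use; default to the joint positive labels."""
    if label_ids is None:
        # No explicit selection given: fall back to all labels present in both
        # segmentations, excluding background (0) and unknown (-1) — the intent
        # restored by the `is None` fix.
        lids = np.intersect1d(labels_mov, labels_dst)
        lids = lids[lids > 0]
    else:
        lids = np.asarray(label_ids)
    return lids

# Tiny usage example:
print(select_label_ids(np.array([0, 8, 22, 24]), np.array([0, 8, 24, 31])))
# -> [ 8 24]

With the pre-fix `is not None` check, passing no label ids would have skipped this fallback entirely, which is the bug these two patches address.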