diff --git a/.github/workflows/quicktest.yaml b/.github/workflows/quicktest.yaml
index b6d827f3..bcf1a7b3 100644
--- a/.github/workflows/quicktest.yaml
+++ b/.github/workflows/quicktest.yaml
@@ -1,90 +1,109 @@
-name: FastSurfer Singularity
+name: quicktest
+
+# File: quicktest.yaml
+# Author: Taha Abdullah
+# Created on: 2023-03-04
+# Functionality: This workflow runs quick integration tests on FastSurfer commits. It checks out the new
+#   FastSurfer repo, sets up Python, builds an Apptainer image, runs FastSurfer on sample MRI data, and
+#   runs pytest to check whether the results are acceptable.
+# Usage: This workflow is triggered exclusively via manual workflow-dispatch in DeepMI/FastSurfer.
+
 on:
+#  pull_request:
   workflow_dispatch:
 
 jobs:
   # Checkout repo
   checkout:
-    runs-on: ci-gpu
+    runs-on: self-hosted
     steps:
       - uses: actions/checkout@v2
 
-  # Prepare job: Set up Python, Go, Singularity
+  # Prepare job: Set up Python, Go, Apptainer
   prepare-job:
-    runs-on: ci-gpu
+    runs-on: self-hosted
     needs: checkout
     steps:
       - name: Set up Python 3.10
         uses: actions/setup-python@v3
         with:
           python-version: "3.10"
+      - name: Install package
+        run: |
+          python -m pip install --progress-bar off --upgrade pip setuptools wheel
+          python -m pip install --progress-bar off .[test]
-      - name: Set up Go
-        uses: actions/setup-go@v5
-        with:
-          go-version: '^1.13.1' # The Go version to download (if necessary) and use.
+#      - name: Set up Go
+#        uses: actions/setup-go@v5
+#        with:
+#          go-version: '^1.13.1' # The Go version to download (if necessary) and use.
       - name: Set up Singularity
         uses: eWaterCycle/setup-singularity@v7
         with:
-          singularity-version: 3.8.3
+          singularity-version: 3.8.7
 
-  # Build Docker Image and convert it to Singularity
-  build-singularity-image:
-    runs-on: ci-gpu
+  # Build Docker Image and convert it to Apptainer
+  build-apptainer-image:
+    runs-on: self-hosted
     needs: prepare-job
     steps:
-      - name: Build Docker Image and convert to Singularity
+      - name: Build Docker Image and convert to Apptainer
         run: |
-          cd $RUNNER_SINGULARITY_IMGS
+          cd $RUNNER_FASTSURFER_IMGS
           FILE="fastsurfer-gpu.sif"
           if [ ! -f "$FILE" ]; then
            # If the file does not exist, build the file
            echo "SIF File does not exist. Building file."
-            PYTHONPATH=$PYTHONPATH
-            cd $PYTHONPATH
+            cd $FASTSURFER_HOME
            python3 Docker/build.py --device cuda --tag fastsurfer_gpu:cuda
-            cd $RUNNER_SINGULARITY_IMGS
-            singularity build --force fastsurfer-gpu.sif docker-daemon://fastsurfer_gpu:cuda
+            cd $RUNNER_FASTSURFER_IMGS
+            apptainer build --force fastsurfer-gpu.sif docker-daemon://fastsurfer_gpu:cuda
           else
            echo "File already exists"
-            cd $PYTHONPATH
+            cd $FASTSURFER_HOME
           fi
 
   # Run FastSurfer on MRI data
   run-fastsurfer:
-    runs-on: ci-gpu
-    needs: build-singularity-image
+    runs-on: self-hosted
+    needs: build-apptainer-image
     steps:
       - name: Run FastSurfer
         run: |
-          singularity exec --nv \
+          cd $RUNNER_FS_OUTPUT
+          # DIRECTORY="subjectX"
+          echo "pwd: $(pwd)"
+          # if [ -d "$DIRECTORY" ]; then
+          #   # if output already exists, delete it and run again
+          #   echo "Output already exists. Deleting output directory and running FastSurfer again."
+          #   rm -rf $DIRECTORY
+          # fi
+          apptainer exec --nv \
             --no-home \
             --bind $GITHUB_WORKSPACE:/fastsurfer-dev \
             --env FASTSURFER_HOME=/fastsurfer-dev \
             -B $RUNNER_FS_MRI_DATA:/data \
             -B $RUNNER_FS_OUTPUT:/output \
-            -B $RUNNER_FS_LICENSE:/fs_license \
-            $RUNNER_SINGULARITY_IMGS/fastsurfer-gpu.sif \
-            /fastsurfer/run_fastsurfer.sh \
+            -B $RUNNER_FS_LICENSE:/fs_license/.license \
+            $RUNNER_FASTSURFER_IMGS/fastsurfer-gpu.sif \
+            /fastsurfer/brun_fastsurfer.sh \
             --fs_license /fs_license/.license \
-            --t1 /data/subjectx/orig.mgz \
-            --sid subjectX --sd /output \
-            --parallel --3T
+            --subject_list /data/subject_list.txt \
+            --sd /output \
+            --parallel --3T \
+            --parallel_subjects surf
 
-  # Test file existence
-  test-file-existence:
-    runs-on: ci-gpu
-    needs: run-fastsurfer
-    steps:
-      - name: Test File Existence
-        run: |
-          python3 test/quick_test/test_file_existence.py $RUNNER_FS_OUTPUT_FILES
-
-  # Test for errors in log files
-  test-error-messages:
-    runs-on: ci-gpu
-    needs: [run-fastsurfer, test-file-existence]
-    steps:
-      - name: Test Log Files For Error Messages
-        run: |
-          python3 test/quick_test/test_errors.py $RUNNER_FS_OUTPUT_LOGS
+  # Run pytest
+  run-pytest:
+    runs-on: self-hosted
+    needs: run-fastsurfer
+    steps:
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v3
+        with:
+          python-version: "3.10"
+      - name: Install package
+        run: |
+          python -m pip install --progress-bar off --upgrade pip setuptools wheel
+          python -m pip install --progress-bar off .[test]
+      - name: Run pytest
+        run: pytest test/quick_test
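For orientation, the image-caching logic in the build step above boils down to "rebuild only if the SIF file is missing". This is a minimal Python sketch of that logic — not part of the repository — assuming the same runner environment variables (RUNNER_FASTSURFER_IMGS, FASTSURFER_HOME) and an `apptainer` binary on PATH:

```python
import os
import subprocess
from pathlib import Path


def ensure_sif_image() -> Path:
    """Build fastsurfer-gpu.sif only if it is not already cached on the runner."""
    imgs_dir = Path(os.environ["RUNNER_FASTSURFER_IMGS"])
    fastsurfer_home = Path(os.environ["FASTSURFER_HOME"])
    sif = imgs_dir / "fastsurfer-gpu.sif"
    if not sif.exists():
        # Build the Docker image first, then convert it to an Apptainer image.
        subprocess.run(
            ["python3", "Docker/build.py", "--device", "cuda", "--tag", "fastsurfer_gpu:cuda"],
            cwd=fastsurfer_home, check=True,
        )
        subprocess.run(
            ["apptainer", "build", "--force", str(sif), "docker-daemon://fastsurfer_gpu:cuda"],
            check=True,
        )
    return sif
```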
diff --git a/.github/workflows/quicktest_runner.yaml b/.github/workflows/quicktest_runner.yaml
new file mode 100644
index 00000000..e7b9fdbf
--- /dev/null
+++ b/.github/workflows/quicktest_runner.yaml
@@ -0,0 +1,87 @@
+name: quicktest-runner
+
+# File: quicktest_runner.yaml
+# Author: Taha Abdullah
+# Created on: 2023-07-10
+# Functionality: This workflow runs FastSurfer on MRI data and runs pytest to check if the results are acceptable. It
+#   also checks if the FastSurfer environment and output already exist, and if not, it creates them.
+# Usage: This workflow is triggered on pull requests to the dev and stable branches. It can also be triggered manually
+#   with workflow-dispatch.
+# Expected/Used Environment Variables:
+# - MAMBAPATH: Path to the micromamba binary.
+# - MAMBAROOT: Root path for micromamba.
+# - RUNNER_FS_OUTPUT: Path to the directory where FastSurfer output is stored.
+# - RUNNER_FS_MRI_DATA: Path to the directory where MRI data is stored.
+# - SUBJECTS_DIR: Path under which the per-commit output directory is created.
+# - FREESURFER_HOME: Path to the FreeSurfer directory.
+# - FS_LICENSE: Path to the FreeSurfer license file.
+
+on:
+  pull_request:
+    branches:
+      - dev
+      - stable
+  workflow_dispatch:
+
+jobs:
+  # Checkout repo
+  checkout:
+    runs-on: self-hosted
+    steps:
+      - uses: actions/checkout@v2
+
+  # Check the environment variables, then run FastSurfer
+  run-fastsurfer:
+    runs-on: self-hosted
+    needs: checkout
+    steps:
+      # Check if the environment variables used in further steps are present
+      - name: Check Environment Variables
+        run: |
+          REQUIRED_ENV_VARS=(
+            "MAMBAPATH"
+            "MAMBAROOT"
+            "RUNNER_FS_OUTPUT"
+            "RUNNER_FS_MRI_DATA"
+            "FREESURFER_HOME"
+            "FS_LICENSE"
+          )
+
+          for VAR_NAME in "${REQUIRED_ENV_VARS[@]}"; do
+            if [ -z "${!VAR_NAME}" ]; then
+              echo "Error: Required environment variable $VAR_NAME is not set"
+              exit 1
+            fi
+          done
+
+          if [ ! -f "$FS_LICENSE" ]; then
+            echo "Error: FreeSurfer license file does not exist at $FS_LICENSE"
+            exit 1
+          fi
+
+          if [ ! -d "$FREESURFER_HOME" ]; then
+            echo "Error: FreeSurfer installation directory does not exist at $FREESURFER_HOME"
+            exit 1
+          fi
+      # Run FastSurfer on test subjects
+      - name: Run FastSurfer
+        run: |
+          echo "Running FastSurfer..."
+          echo "Output will be saved in $SUBJECTS_DIR/${GITHUB_SHA:0:7}"
+          export FASTSURFER_HOME=$(pwd)
+          export THIS_RUN_OUTDIR=${GITHUB_SHA:0:7}
+          mkdir -p $SUBJECTS_DIR/$THIS_RUN_OUTDIR
+          export TEST_DIR=$THIS_RUN_OUTDIR
+          ./brun_fastsurfer.sh --subject_list $RUNNER_FS_MRI_DATA/subjects_list.txt \
+            --sd $SUBJECTS_DIR/$THIS_RUN_OUTDIR \
+            --parallel --threads 4 --3T --parallel_subjects surf
+
+  # Test FastSurfer output
+  run-pytest:
+    runs-on: self-hosted
+    if: always()
+    needs: run-fastsurfer
+    steps:
+      - name: Run pytest
+        run: |
+          source /venv-pytest/bin/activate
+          python -m pytest test/quick_test
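The "Check Environment Variables" step fails fast when the self-hosted runner is misconfigured. A hypothetical Python equivalent of the same contract — the variable names are the ones the workflow and the conftest.py fixtures below rely on — might look like this:

```python
import os
from pathlib import Path

REQUIRED_ENV_VARS = [
    "MAMBAPATH", "MAMBAROOT", "RUNNER_FS_OUTPUT",
    "RUNNER_FS_MRI_DATA", "FREESURFER_HOME", "FS_LICENSE",
]


def check_environment() -> None:
    """Fail early with a clear message if the runner is misconfigured."""
    missing = [name for name in REQUIRED_ENV_VARS if not os.environ.get(name)]
    if missing:
        raise RuntimeError(f"Required environment variables not set: {missing}")
    if not Path(os.environ["FS_LICENSE"]).is_file():
        raise RuntimeError("FreeSurfer license file does not exist at $FS_LICENSE")
    if not Path(os.environ["FREESURFER_HOME"]).is_dir():
        raise RuntimeError("FreeSurfer installation directory does not exist at $FREESURFER_HOME")
```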
diff --git a/pyproject.toml b/pyproject.toml
index 71598d9a..92bf3942 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -83,9 +83,13 @@ style = [
     'pydocstyle[toml]',
     'ruff',
 ]
+quicktest = [
+    'pytest>=8.2.2',
+]
 all = [
     'fastsurfer[doc]',
     'fastsurfer[style]',
+    'fastsurfer[quicktest]',
 ]
 full = [
     'fastsurfer[all]',
diff --git a/test/__init__.py b/test/__init__.py
index 29780dc9..db6655bb 100644
--- a/test/__init__.py
+++ b/test/__init__.py
@@ -1,8 +1,5 @@
-
-
-
-__all__ = [ # This is a list of modules that should be imported when using the import * syntax
-    'test_file_existence',
-    'test_error_messages',
-    'test_errors'
-    ]
+__all__ = [  # Modules imported when using the `import *` syntax
+    "test_file_existence",
+    "test_error_messages",
+    "test_errors",
+]
diff --git a/test/quick_test/common.py b/test/quick_test/common.py
new file mode 100644
index 00000000..a708b25c
--- /dev/null
+++ b/test/quick_test/common.py
@@ -0,0 +1,31 @@
+import os
+from logging import getLogger
+
+logger = getLogger(__name__)
+
+
+__all__ = ["load_test_subjects"]
+
+
+def load_test_subjects():
+    """
+    Load the test subject names from the subjects list file.
+
+    The list file is resolved from the SUBJECTS_DIR and SUBJECTS_LIST
+    environment variables.
+
+    Returns:
+        test_subjects (list): List of test subject names.
+    """
+
+    subjects_dir = os.environ["SUBJECTS_DIR"]
+    subjects_list = os.environ["SUBJECTS_LIST"]
+
+    test_subjects = []
+
+    # Read one subject name per line from the subjects list file
+    with open(os.path.join(subjects_dir, subjects_list)) as file:
+        for line in file:
+            filename = line.strip()
+            logger.debug(filename)
+            test_subjects.append(filename)
+
+    return test_subjects
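`load_test_subjects()` is what the test modules below feed into `@pytest.mark.parametrize`, so each subject becomes its own test case. A hypothetical usage, assuming the repository root is on `sys.path` and `$SUBJECTS_DIR` contains a list file with one subject name per line (paths and names here are placeholders):

```python
import os

os.environ["SUBJECTS_DIR"] = "/data/fastsurfer-ci"       # placeholder path
os.environ["SUBJECTS_LIST"] = "subjects_list.txt"        # e.g. "subjectX\nsubjectY\n"

from test.quick_test.common import load_test_subjects

print(load_test_subjects())  # -> ["subjectX", "subjectY"]
```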
diff --git a/test/quick_test/conftest.py b/test/quick_test/conftest.py
new file mode 100644
index 00000000..3303740f
--- /dev/null
+++ b/test/quick_test/conftest.py
@@ -0,0 +1,26 @@
+import os
+from pathlib import Path
+
+import pytest
+
+__all__ = ["subjects_dir", "test_dir", "reference_dir", "subjects_list"]
+
+
+@pytest.fixture
+def subjects_dir():
+    return Path(os.environ["SUBJECTS_DIR"])
+
+
+@pytest.fixture
+def test_dir():
+    return Path(os.environ["TEST_DIR"])
+
+
+@pytest.fixture
+def reference_dir():
+    return Path(os.environ["REFERENCE_DIR"])
+
+
+@pytest.fixture
+def subjects_list():
+    return Path(os.environ["SUBJECTS_LIST"])
diff --git a/test/quick_test/data/errors.yaml b/test/quick_test/data/errors.yaml
deleted file mode 100644
index a54810ae..00000000
--- a/test/quick_test/data/errors.yaml
+++ /dev/null
@@ -1,11 +0,0 @@
-errors:
-  - "error"
-  - "error:"
-  - "exception"
-  - "traceback"
-
-whitelist:
-  - "without error"
-  - "not included"
-  - "distance"
-  - "correcting"
diff --git a/test/quick_test/data/logfile.errors.yaml b/test/quick_test/data/logfile.errors.yaml
new file mode 100644
index 00000000..06148b6a
--- /dev/null
+++ b/test/quick_test/data/logfile.errors.yaml
@@ -0,0 +1,14 @@
+errors:
+  - "error"
+  - "error:"
+  - "exception"
+  - "traceback"
+
+whitelist:
+  - "without error"
+  - "not included"
+  - "distance"
+  - "correcting"
+  - "error="
+  - "rms error"
+  - "mcsrch error"
diff --git a/test/quick_test/data/thresholds/aparc+DKT.stats.yaml b/test/quick_test/data/thresholds/aparc+DKT.stats.yaml
new file mode 100644
index 00000000..77aeef1e
--- /dev/null
+++ b/test/quick_test/data/thresholds/aparc+DKT.stats.yaml
@@ -0,0 +1,23 @@
+default_threshold: 0.01
+
+thresholds:
+  BrainSegVol: 0.1
+  BrainSegVolNotVent: 0.1
+  VentricleChoroidVol: 0.1
+  lhCortexVol: 0.1
+  rhCortexVol: 0.1
+  CortexVol: 0.1
+  lhCerebralWhiteMatterVol: 0.1
+  rhCerebralWhiteMatterVol: 0.1
+  CerebralWhiteMatterVol: 0.1
+  SubCortGrayVol: 0.1
+  TotalGrayVol: 0.1
+  SupraTentorialVol: 0.1
+  SupraTentorialVolNotVent: 0.1
+  MaskVol: 0.1
+  BrainSegVol-to-eTIV: 0.1
+  MaskVol-to-eTIV: 0.1
+  lhSurfaceHoles: 0.1
+  rhSurfaceHoles: 0.1
+  SurfaceHoles: 0.1
+  eTIV: 0.1
diff --git a/test/quick_test/data/thresholds/aseg.stats.yaml b/test/quick_test/data/thresholds/aseg.stats.yaml
new file mode 100644
index 00000000..09a7edd7
--- /dev/null
+++ b/test/quick_test/data/thresholds/aseg.stats.yaml
@@ -0,0 +1,46 @@
+default_threshold: 0.01
+
+thresholds:
+  BrainSegVol: 0.1
+  BrainSegVolNotVent: 0.1
+  VentricleChoroidVol: 0.1
+  lhCortexVol: 0.1
+  rhCortexVol: 0.1
+  CortexVol: 0.1
+  lhCerebralWhiteMatterVol: 0.1
+  rhCerebralWhiteMatterVol: 0.1
+  CerebralWhiteMatterVol: 0.1
+  SubCortGrayVol: 0.1
+  TotalGrayVol: 0.1
+  SupraTentorialVol: 0.1
+  SupraTentorialVolNotVent: 0.1
+  MaskVol: 0.1
+  BrainSegVol-to-eTIV: 0.1
+  MaskVol-to-eTIV: 0.1
+  lhSurfaceHoles: 0.1
+  rhSurfaceHoles: 0.1
+  SurfaceHoles: 0.1
+  eTIV: 0.1
+
+
+structs:
+  - 'BrainSegVol'
+  - 'BrainSegVolNotVent'
+  - 'VentricleChoroidVol'
+  - 'lhCortexVol'
+  - 'rhCortexVol'
+  - 'CortexVol'
+  - 'lhCerebralWhiteMatterVol'
+  - 'rhCerebralWhiteMatterVol'
+  - 'CerebralWhiteMatterVol'
+  - 'SubCortGrayVol'
+  - 'TotalGrayVol'
+  - 'SupraTentorialVol'
+  - 'SupraTentorialVolNotVent'
+  - 'MaskVol'
+  - 'BrainSegVol-to-eTIV'
+  - 'MaskVol-to-eTIV'
+  - 'lhSurfaceHoles'
+  - 'rhSurfaceHoles'
+  - 'SurfaceHoles'
+  - 'eTIV'
\ No newline at end of file
diff --git a/test/quick_test/data/thresholds/labels.yaml b/test/quick_test/data/thresholds/labels.yaml
new file mode 100644
index 00000000..77857b4d
--- /dev/null
+++ b/test/quick_test/data/thresholds/labels.yaml
@@ -0,0 +1,4 @@
+# Labels
+labels: [2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15,
+         16, 17, 18, 24, 26, 28, 41, 42, 43, 44,
+         46, 47, 49, 50, 51, 52, 53, 54, 58, 60]
\ No newline at end of file
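The threshold files pair a `default_threshold` with per-measure overrides; the intended lookup is "use the override if one exists, otherwise fall back to the default". A minimal sketch of that pattern, assuming the YAML layout above:

```python
import yaml

with open("test/quick_test/data/thresholds/aseg.stats.yaml") as f:
    data = yaml.safe_load(f)

default_threshold = data.get("default_threshold")
thresholds = data.get("thresholds", {})


def threshold_for(measure: str) -> float:
    # Fall back to the default for measures without an explicit override.
    return thresholds.get(measure, default_threshold)


print(threshold_for("eTIV"))           # 0.1  (explicit override)
print(threshold_for("Left-Amygdala"))  # 0.01 (default)
```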
diff --git a/test/quick_test/test_errors.py b/test/quick_test/test_errors.py
deleted file mode 100644
index c1686e42..00000000
--- a/test/quick_test/test_errors.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import argparse
-import sys
-import unittest
-from pathlib import Path
-
-import yaml
-
-
-class TestErrors(unittest.TestCase):
-    """
-    A test case class to check for the word "error" in the given log files.
-    """
-
-    error_file_path: Path = Path("./test/quick_test/data/errors.yaml")
-
-    error_flag = False
-
-    @classmethod
-    def setUpClass(cls):
-        """
-        Set up the test class.
-        This method retrieves the log directory from the command line argument,
-        and assigns it to a class variable.
-        """
-
-        # Open the error_file_path and read the errors and whitelist into arrays
-        with open(cls.error_file_path) as file:
-            data = yaml.safe_load(file)
-            cls.errors = data.get('errors', [])
-            cls.whitelist = data.get('whitelist', [])
-
-        # Retrieve the log files in given log directory
-        try:
-            # cls.log_directory = Path(cls.log_directory)
-            print(cls.log_directory)
-            cls.log_files = [file for file in cls.log_directory.iterdir() if file.suffix == '.log']
-        except FileNotFoundError:
-            raise FileNotFoundError(f"Log directory not found at path: {cls.log_directory}") from None
-
-    def test_find_errors_in_logs(self):
-        """
-        Test that the words "error", "exception", and "traceback" are not in the log files.
-
-        This method retrieves the log files in the log directory, reads each log file line by line,
-        and checks that none of the keywords are in any line.
-        """
-
-        files_with_errors = {}
-
-        # Check if any of the keywords are in the log files
-        for log_file in self.log_files:
-            rel_path = log_file.relative_to(self.log_directory)
-            print(f"Checking file: {rel_path}")
-            try:
-                with log_file.open('r') as file:
-                    lines = file.readlines()
-                    lines_with_errors = []
-                    for line_number, line in enumerate(lines, start=1):
-                        if any(error in line.lower() for error in self.errors):
-                            if not any(white in line.lower() for white in self.whitelist):
-                                # Get two lines before and after the current line
-                                context = lines[max(0, line_number-2):min(len(lines), line_number+3)]
-                                lines_with_errors.append((line_number, context))
-                    print(lines_with_errors)
-                    files_with_errors[rel_path] = lines_with_errors
-                    self.error_flag = True
-            except FileNotFoundError:
-                raise FileNotFoundError(f"Log file not found at path: {log_file}") from None
-                continue
-
-        # Print the lines and context with errors for each file
-        for file, lines in files_with_errors.items():
-            print(f"\nFile {file}, in line {files_with_errors[file][0][0]}:")
-            for _line_number, line in lines:
-                print(*line, sep = "")
-
-        # Assert that there are no lines with any of the keywords
-        self.assertEqual(self.error_flag, False, f"Found errors in the following files: {files_with_errors}")
-        print("No errors found in any log files.")
-
-
-if __name__ == '__main__':
-    """
-    Main entry point of the script.
-
-    This block checks if there are any command line arguments,
-    assigns the first argument to the log_directory class variable
-    """
-
-    parser = argparse.ArgumentParser(description="Test for errors in log files.")
-    parser.add_argument('log_directory', type=Path, help="The directory containing the log files.")
-
-    args = parser.parse_args()
-
-    TestErrors.log_directory = args.log_directory
-
-    unittest.main(argv=[sys.argv[0]])
diff --git a/test/quick_test/test_errors_in_logfiles.py b/test/quick_test/test_errors_in_logfiles.py
new file mode 100644
index 00000000..2342d3d4
--- /dev/null
+++ b/test/quick_test/test_errors_in_logfiles.py
@@ -0,0 +1,118 @@
+from logging import getLogger
+from pathlib import Path
+
+import pytest
+import yaml
+
+from .common import load_test_subjects
+
+logger = getLogger(__name__)
+
+
+def load_errors():
+    """
+    Load the error and whitelist strings from ./data/logfile.errors.yaml.
+
+    Returns
+    -------
+    errors : list[str]
+        List of error keywords.
+    whitelist : list[str]
+        List of whitelisted strings that should not be treated as errors.
+    """
+
+    # Open the error_file_path and read the errors and whitelist into arrays
+    error_file_path = Path(__file__).parent / "data" / "logfile.errors.yaml"
+
+    with open(error_file_path) as file:
+        data = yaml.safe_load(file)
+        errors = data.get("errors", [])
+        whitelist = data.get("whitelist", [])
+
+    return errors, whitelist
+
+
+def load_log_files(test_subject: Path):
+    """
+    Retrieve the log files in the given subject's scripts directory.
+
+    Parameters
+    ----------
+    test_subject : Path
+        Subject directory to test.
+
+    Returns
+    -------
+    log_files : list[Path]
+        List of log files in the log directory.
+    """
+
+    log_directory = test_subject / "scripts"
+    log_files = [file for file in Path(log_directory).iterdir() if file.suffix == ".log"]
+
+    return log_files
+
+
+@pytest.mark.parametrize("test_subject", load_test_subjects())
+def test_errors(subjects_dir: Path, test_dir: Path, test_subject: Path):
+    """
+    Test if there are any errors in the log files.
+
+    Parameters
+    ----------
+    subjects_dir : Path
+        Subjects directory.
+        Filled by pytest fixture from conftest.py.
+    test_dir : Path
+        Test directory.
+        Filled by pytest fixture from conftest.py.
+    test_subject : Path
+        Subject to test.
+
+    Raises
+    ------
+    AssertionError
+        If any of the error keywords are found in the log files.
+    """
+
+    test_subject = subjects_dir / test_dir / test_subject
+    log_files = load_log_files(test_subject)
+
+    error_flag = False
+
+    errors, whitelist = load_errors()
+
+    files_with_errors = {}
+
+    # Check if any of the keywords are in the log files
+    for log_file in log_files:
+        rel_path = log_file.relative_to(subjects_dir)
+        logger.debug(f"Checking file: {rel_path}")
+        try:
+            with log_file.open("r") as file:
+                lines = file.readlines()
+                lines_with_errors = []
+                for line_number, line in enumerate(lines, start=1):
+                    if any(error in line.lower() for error in errors):
+                        if not any(white in line.lower() for white in whitelist):
+                            # Get two lines before and after the current line
+                            context = lines[max(0, line_number - 2) : min(len(lines), line_number + 3)]
+                            lines_with_errors.append((line_number, context))
+                # Only flag the file if it actually contains offending lines
+                if lines_with_errors:
+                    files_with_errors[rel_path] = lines_with_errors
+                    error_flag = True
+        except FileNotFoundError:
+            raise FileNotFoundError(f"Log file not found at path: {log_file}") from None
+
+    # Log the offending lines and their context for each file
+    for file, lines in files_with_errors.items():
+        logger.debug(f"\nFile {file}, in line {lines[0][0]}:")
+        for _line_number, line in lines:
+            logger.debug("".join(line))
+
+    # Assert that there are no lines with any of the keywords
+    assert not error_flag, f"Found errors in the following files: {files_with_errors}"
+    logger.debug("\nNo errors found in any log files.")
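The keyword/whitelist matching above is why benign lines such as FreeSurfer's "rms error" output do not fail the test. A hypothetical demonstration with made-up log lines:

```python
errors = ["error", "error:", "exception", "traceback"]
whitelist = ["without error", "not included", "distance", "correcting",
             "error=", "rms error", "mcsrch error"]


def is_error_line(line: str) -> bool:
    # A line is flagged only if it contains an error keyword
    # and none of the whitelisted phrases.
    lowered = line.lower()
    return any(e in lowered for e in errors) and not any(w in lowered for w in whitelist)


print(is_error_line("mris_sphere: rms error = 0.03"))       # False (whitelisted)
print(is_error_line("Error: could not open segmentation"))  # True
```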
diff --git a/test/quick_test/test_file_existence.py b/test/quick_test/test_file_existence.py
index 0d44176b..ea91e07f 100644
--- a/test/quick_test/test_file_existence.py
+++ b/test/quick_test/test_file_existence.py
@@ -1,73 +1,71 @@
-import argparse
-import sys
-import unittest
+from logging import getLogger
 from pathlib import Path
 
-import yaml
+import pytest
 
+from .common import load_test_subjects
 
-class TestFileExistence(unittest.TestCase):
-    """
-    A test case class to check the existence of files in a folder based on a YAML file.
-
-    This class defines test methods to verify if each file specified in the YAML file exists in the given folder.
-    """
-
-    file_path: Path = Path("./test/quick_test/data/files.yaml")
-
-    @classmethod
-    def setUpClass(cls):
-        """
-        Set up the test case by loading the YAML file and extracting the folder path.
+logger = getLogger(__name__)
 
-        This method is executed once before any test methods in the class.
-        """
-        # Open the file_path and read the files into an array
-        with cls.file_path.open('r') as file:
-            data = yaml.safe_load(file)
-            cls.files = data.get('files', [])
-
-        # Get a list of all files in the folder recursively
-        cls.filenames = []
-        for file in cls.folder_path.glob('**/*'):
-            if file.is_file():
-                # Get the relative path from the current directory to the file
-                rel_path = file.relative_to(cls.folder_path)
-                cls.filenames.append(str(rel_path))
-
-    def test_file_existence(self):
-        """
-        Test method to check the existence of files in the folder.
+def get_files_from_folder(folder_path: Path):
+    """
+    Get the list of files in the directory relative to the folder path.
 
-        This method gets a list of all files in the folder recursively and checks
-        if each file specified in the YAML file exists in the folder.
-        """
+    Parameters
+    ----------
+    folder_path : Path
+        Path to the folder.
 
-        # Check if each file in the YAML file exists in the folder
-        if not self.files:
-            self.fail("The 'files' key was not found in the YAML file")
+    Returns
+    -------
+    list
+        List of files in the directory.
+    """
 
-        for file in self.files:
-            print(f"Checking for file: {file}")
-            self.assertIn(file, self.filenames, f"File '{file}' does not exist in the folder.")
+    # Get a list of all files in the folder recursively
+    filenames = []
+    for file in Path(folder_path).rglob("*"):
+        if file.is_file():
+            # Skip directories; only file paths are compared
+            filenames.append(str(file.relative_to(folder_path)))
 
-        print("All files present")
+    return filenames
 
 
-if __name__ == '__main__':
+@pytest.mark.parametrize("test_subject", load_test_subjects())
+def test_file_existence(subjects_dir: Path, test_dir: Path, reference_dir: Path, test_subject: Path):
     """
-    Main entry point of the script.
-
-    This block checks if there are any command line arguments, assigns the first argument
-    to the error_file_path class variable, and runs the unittest main function.
+    Test that every file in the reference subject also exists in the test subject.
+
+    Parameters
+    ----------
+    subjects_dir : Path
+        Path to the subjects directory.
+        Filled by pytest fixture from conftest.py.
+    test_dir : Path
+        Name of the test directory.
+        Filled by pytest fixture from conftest.py.
+    reference_dir : Path
+        Name of the reference directory.
+        Filled by pytest fixture from conftest.py.
+    test_subject : Path
+        Name of the test subject.
+
+    Raises
+    ------
+    AssertionError
+        If a file in the reference list does not exist in the test list.
     """
 
-    parser = argparse.ArgumentParser(description="Test for file existence based on a YAML file.")
-    parser.add_argument('folder_path', type=Path, help="The path to the folder to check.")
+    # Get reference files from the reference subject directory
+    reference_subject = subjects_dir / reference_dir / test_subject
+    reference_files = get_files_from_folder(reference_subject)
 
-    args = parser.parse_args()
+    # Get test list of files in the test subject directory
+    test_subject = subjects_dir / test_dir / test_subject
+    test_files = get_files_from_folder(test_subject)
 
-    TestFileExistence.folder_path = args.folder_path
+    # Check if each file in the reference list exists in the test list
+    missing_files = [file for file in reference_files if file not in test_files]
+    assert not missing_files, f"Files '{missing_files}' do not exist in test subject."
 
-    unittest.main(argv=[sys.argv[0]])
+    logger.debug("\nAll files present.")
diff --git a/test/quick_test/test_images.py b/test/quick_test/test_images.py
new file mode 100644
index 00000000..97a3a1ab
--- /dev/null
+++ b/test/quick_test/test_images.py
@@ -0,0 +1,230 @@
+from collections import OrderedDict
+from logging import getLogger
+from pathlib import Path
+
+import nibabel as nib
+import nibabel.cmdline.diff
+import numpy as np
+import pytest
+
+from CerebNet.utils.metrics import dice_score
+
+from .common import load_test_subjects
+
+logger = getLogger(__name__)
+
+
+def load_image(subject_path: Path, image_name: Path):
+    """
+    Load the image data using nibabel.
+
+    Parameters
+    ----------
+    subject_path : Path
+        Path to the subject directory.
+    image_name : Path
+        Name of the image file.
+
+    Returns
+    -------
+    nibabel.nifti1.Nifti1Image
+        Image data.
+    """
+    image_path = subject_path / "mri" / image_name
+    image = nib.load(image_path)
+
+    return image
+
+
+def compute_dice_score(test_data, reference_data, labels):
+    """
+    Compute the dice score for each class.
+
+    Parameters
+    ----------
+    test_data : np.ndarray
+        Test image data.
+    reference_data : np.ndarray
+        Reference image data.
+    labels : np.ndarray
+        Unique labels in the image data.
+
+    Returns
+    -------
+    np.ndarray
+        Dice scores for each class.
+    """
+
+    # Classes
+    num_classes = len(labels)
+
+    dscore = np.zeros(shape=num_classes)
+
+    for idx in range(num_classes):
+        current_label = labels[idx]
+
+        pred = (test_data == current_label).astype(int)
+        gt = (reference_data == current_label).astype(int)
+
+        dscore[idx] = dice_score(pred, gt)
+
+    logger.debug("\nDice score: %s", dscore)
+
+    return dscore
+
+
+def compute_mean_square_error(test_data, reference_data):
+    """
+    Compute the mean square error between the test and reference data.
+
+    Parameters
+    ----------
+    test_data : np.ndarray
+        Test image data.
+    reference_data : np.ndarray
+        Reference image data.
+
+    Returns
+    -------
+    float
+        Mean square error.
+    """
+
+    mse = ((test_data - reference_data) ** 2).mean()
+    logger.debug("\nMean square error: %s", mse)
+
+    return mse
+
+
+@pytest.mark.parametrize("test_subject", load_test_subjects())
+def test_image_headers(subjects_dir: Path, test_dir: Path, reference_dir: Path, test_subject: Path):
+    """
+    Test the image headers by comparing the headers of the test and reference images.
+
+    Parameters
+    ----------
+    subjects_dir : Path
+        Path to the subjects directory.
+        Filled by pytest fixture from conftest.py.
+    test_dir : Path
+        Name of test directory.
+        Filled by pytest fixture from conftest.py.
+    reference_dir : Path
+        Name of reference directory.
+        Filled by pytest fixture from conftest.py.
+    test_subject : Path
+        Name of the test subject.
+
+    Raises
+    ------
+    AssertionError
+        If the image headers do not match.
+    """
+
+    # Load images; keep test_subject unmodified so both paths resolve correctly
+    test_path = subjects_dir / test_dir / test_subject
+    test_image = load_image(test_path, "brain.mgz")
+
+    reference_path = subjects_dir / reference_dir / test_subject
+    reference_image = load_image(reference_path, "brain.mgz")
+
+    # Get the image headers
+    headers = [test_image.header, reference_image.header]
+
+    # Check the image headers
+    header_diff = nibabel.cmdline.diff.get_headers_diff(headers)
+    assert header_diff == OrderedDict(), f"Image headers do not match: {header_diff}"
+    logger.debug("Image headers are correct")
+@pytest.mark.parametrize("test_subject", load_test_subjects())
+def test_seg_data(subjects_dir: Path, test_dir: Path, reference_dir: Path, test_subject: Path):
+    """
+    Test the segmentation data by calculating and comparing dice scores.
+
+    Parameters
+    ----------
+    subjects_dir : Path
+        Path to the subjects directory.
+        Filled by pytest fixture from conftest.py.
+    test_dir : Path
+        Name of test directory.
+        Filled by pytest fixture from conftest.py.
+    reference_dir : Path
+        Name of reference directory.
+        Filled by pytest fixture from conftest.py.
+    test_subject : Path
+        Name of the test subject.
+
+    Raises
+    ------
+    AssertionError
+        If the dice score is not 0 for all classes.
+    """
+
+    test_file = subjects_dir / test_dir / test_subject
+    test_image = load_image(test_file, "aseg.mgz")
+
+    reference_subject = subjects_dir / reference_dir / test_subject
+    reference_image = load_image(reference_subject, "aseg.mgz")
+
+    labels = np.unique([np.asarray(reference_image.dataobj), np.asarray(test_image.dataobj)])
+
+    # Get the image data
+    test_data = np.asarray(test_image.dataobj)
+    reference_data = np.asarray(reference_image.dataobj)
+
+    # Compute the dice score
+    dscore = compute_dice_score(test_data, reference_data, labels)
+
+    # Check the dice score
+    np.testing.assert_allclose(
+        dscore, 0, atol=1e-6, rtol=1e-6, err_msg="Dice scores are not within range for all classes"
+    )
+
+    logger.debug("Dice scores are within range for all classes")
+
+
+@pytest.mark.parametrize("test_subject", load_test_subjects())
+def test_int_data(subjects_dir: Path, test_dir: Path, reference_dir: Path, test_subject: Path):
+    """
+    Test the intensity data by calculating and comparing the mean square error.
+
+    Parameters
+    ----------
+    subjects_dir : Path
+        Path to the subjects directory.
+        Filled by pytest fixture from conftest.py.
+    test_dir : Path
+        Name of test directory.
+        Filled by pytest fixture from conftest.py.
+    reference_dir : Path
+        Name of reference directory.
+        Filled by pytest fixture from conftest.py.
+    test_subject : Path
+        Name of the test subject.
+
+    Raises
+    ------
+    AssertionError
+        If the mean square error is not 0.
+    """
+
+    test_file = subjects_dir / test_dir / test_subject
+    test_image = load_image(test_file, "brain.mgz")
+
+    reference_subject = subjects_dir / reference_dir / test_subject
+    reference_image = load_image(reference_subject, "brain.mgz")
+
+    # Get the image data
+    test_data = test_image.get_fdata()
+    reference_data = reference_image.get_fdata()
+
+    mse = compute_mean_square_error(test_data, reference_data)
+
+    # Check the image data
+    assert mse == 0, "Mean square error is not 0"
+
+    logger.debug("\nImage data matches")
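The test_stats.py module added below parses two sections of FreeSurfer-style stats files: "# Measure" header lines (via read_measure_stats) and the whitespace-separated table following "# ColHeaders" (via read_table). For orientation, here is an illustrative excerpt with made-up values, and how the measure lines decompose:

```python
# Illustrative aseg.stats excerpt (values are fabricated for demonstration).
sample = """\
# Measure BrainSeg, BrainSegVol, Brain Segmentation Volume, 1234567.0, mm^3
# Measure EstimatedTotalIntraCranialVol, eTIV, Estimated Total Intracranial Volume, 1500000.0, mm^3
# ColHeaders Index SegId NVoxels Volume_mm3 StructName normMean normStdDev normMin normMax normRange
  1   4   7500   7500.0  Left-Lateral-Ventricle  25.0  10.0  5.0  80.0  75.0
"""

for line in sample.splitlines():
    if line.startswith("# Measure"):
        # Fields are comma-separated: key, name, description, value, unit.
        parts = line.removeprefix("# Measure").strip().split(", ")
        print(parts[1], "=", float(parts[3]))  # e.g. BrainSegVol = 1234567.0
```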
diff --git a/test/quick_test/test_stats.py b/test/quick_test/test_stats.py
new file mode 100644
index 00000000..adb3b253
--- /dev/null
+++ b/test/quick_test/test_stats.py
@@ -0,0 +1,261 @@
+import os
+from logging import getLogger
+from pathlib import Path
+
+import pandas as pd
+import pytest
+import yaml
+
+from .common import load_test_subjects
+
+logger = getLogger(__name__)
+
+
+@pytest.fixture
+def thresholds():
+    """
+    Load the thresholds from ./data/thresholds/aseg.stats.yaml.
+
+    Returns
+    -------
+    default_threshold : float
+        Default threshold value.
+    thresholds : dict
+        Dictionary containing the per-measure thresholds.
+    """
+
+    thresholds_file = Path(__file__).parent / "data/thresholds/aseg.stats.yaml"
+
+    # Open the file_path and read the thresholds into a dictionary
+    with open(thresholds_file) as file:
+        data = yaml.safe_load(file)
+        default_threshold = data.get("default_threshold")
+        thresholds = data.get("thresholds", {})
+
+    return default_threshold, thresholds
+
+
+def load_stats_file(test_subject: Path):
+    """
+    Locate the stats file in the given subject directory.
+
+    Parameters
+    ----------
+    test_subject : Path
+        Path to the test subject.
+
+    Returns
+    -------
+    stats_file : Path
+        Path to the subject's aseg.stats or aparc+DKT.stats file.
+    """
+
+    files = os.listdir(test_subject / "stats")
+
+    if "aseg.stats" in files:
+        return test_subject / "stats" / "aseg.stats"
+    elif "aparc+DKT.stats" in files:
+        return test_subject / "stats" / "aparc+DKT.stats"
+    else:
+        raise ValueError("Unknown stats file")
+
+
+def load_structs(test_file: Path):
+    """
+    Load the structs from the thresholds file matching the given stats file.
+
+    Parameters
+    ----------
+    test_file : Path
+        Path to the stats file.
+
+    Returns
+    -------
+    structs : list
+        List of structs.
+    """
+
+    if test_file.name == "aseg.stats":
+        structs_file = Path(__file__).parent / "data/thresholds/aseg.stats.yaml"
+    elif test_file.name == "aparc+DKT.stats":
+        structs_file = Path(__file__).parent / "data/thresholds/aparc+DKT.stats.yaml"
+    else:
+        raise ValueError("Unknown test file")
+
+    # Open the file_path and read the structs into a list
+    with open(structs_file) as file:
+        data = yaml.safe_load(file)
+        structs = data.get("structs", [])
+
+    return structs
+
+
+def read_measure_stats(file_path: Path):
+    """
+    Read the measure stats from the given file path.
+
+    Parameters
+    ----------
+    file_path : Path
+        Path to the stats file.
+
+    Returns
+    -------
+    measure : list
+        List of measure names.
+    measurements : dict
+        Dictionary mapping measure names to their values.
+    """
+
+    measure = []
+    measurements = {}
+
+    # Retrieve lines starting with "# Measure" from the stats file
+    with open(file_path) as file:
+        # Read each line in the file
+        for line in file:
+            # Check if the line starts with "# Measure"
+            if line.startswith("# Measure"):
+                # Strip "# Measure" from the line
+                line = line.removeprefix("# Measure").strip()
+                # Append the measure name and record its value
+                line = line.split(", ")
+                measure.append(line[1])
+                measurements[line[1]] = float(line[3])
+
+    return measure, measurements
+
+
+def read_table(file_path: Path):
+    """
+    Read the segmentation stats table from the subject's aseg.stats file.
+
+    Parameters
+    ----------
+    file_path : Path
+        Path to the subject directory.
+
+    Returns
+    -------
+    table : pandas.DataFrame
+        Table containing the segmentation statistics.
+    """
+
+    table_start = 0
+    columns = []
+
+    file_path = file_path / "stats" / "aseg.stats"
+
+    # Retrieve stats table from the stats file
+    with open(file_path) as file:
+        # Read each line in the file
+        for i, line in enumerate(file, 1):
+            # Check if the line starts with "# ColHeaders"
+            if line.startswith("# ColHeaders"):
+                table_start = i
+                columns = line.removeprefix("# ColHeaders").strip().split(" ")
+
+    # Read the table into a pandas dataframe
+    table = pd.read_table(file_path, skiprows=table_start, sep=r"\s+", header=None)
+    table.columns = columns
+    table.set_index(columns[0], inplace=True)
+
+    return table
+@pytest.mark.parametrize("test_subject", load_test_subjects())
+def test_measure_exists(subjects_dir: Path, test_dir: Path, test_subject: Path):
+    """
+    Test if the expected measures exist in the stats file.
+
+    Parameters
+    ----------
+    subjects_dir : Path
+        Path to the subjects directory.
+    test_dir : Path
+        Name of the test directory.
+    test_subject : Path
+        Name of the test subject.
+
+    Raises
+    ------
+    AssertionError
+        If a measure does not exist in the stats file.
+    """
+
+    test_subject = subjects_dir / test_dir / test_subject
+    test_file = load_stats_file(test_subject)
+    data = read_measure_stats(test_file)
+    errors = []
+
+    for struct in load_structs(test_file):
+        if struct not in data[1]:
+            errors.append(f"measure {struct} is missing from the stats file")
+
+    # Check if all measures exist in stats file
+    assert len(errors) == 0, ", ".join(errors)
+
+
+@pytest.mark.parametrize("test_subject", load_test_subjects())
+def test_tables(subjects_dir: Path, test_dir: Path, reference_dir: Path, test_subject: Path, thresholds):
+    """
+    Test if the table values are within the threshold.
+
+    Parameters
+    ----------
+    subjects_dir : Path
+        Path to the subjects directory.
+    test_dir : Path
+        Name of the test directory.
+    reference_dir : Path
+        Name of the reference directory.
+    test_subject : Path
+        Name of the test subject.
+    thresholds : tuple
+        Tuple containing the default threshold and the per-measure thresholds.
+
+    Raises
+    ------
+    AssertionError
+        If the table values are not within the threshold.
+    """
+
+    # Load the test and reference tables
+    test_file = subjects_dir / test_dir / test_subject
+    test_table = read_table(test_file)
+
+    reference_subject = subjects_dir / reference_dir / test_subject
+    ref_table = read_table(reference_subject)
+
+    # Load the thresholds
+    default_threshold, thresholds = thresholds
+
+    variations = {}
+
+    # Check if table values are within the threshold
+    for i in ref_table.index:
+        struct = ref_table.loc[i, "StructName"]
+        for j in ref_table.columns:
+            if j == "StructName":
+                continue
+            # Use the structure-specific threshold if one is defined
+            threshold = thresholds.get(struct, default_threshold)
+            if ref_table.loc[i, j] == 0:
+                continue
+            variation = (test_table.loc[i, j] / ref_table.loc[i, j]) - 1
+            if abs(variation) > threshold:
+                # Collect all offending columns per structure
+                variations.setdefault(struct, {})[j] = abs(variation)
+
+    if variations:
+        logger.debug("\nVariations greater than threshold:")
+        for key, value in variations.items():
+            logger.debug("%s: %s", key, value)
+
+    assert not variations, "Variations greater than threshold found."
+
+    logger.debug("\nAll table values are within the threshold.")
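Outside CI, the whole quick-test suite can be driven the same way the workflows do, by setting the environment variables that the conftest.py fixtures read. A hypothetical local invocation — all paths and directory names below are placeholders:

```python
import os

import pytest

os.environ["SUBJECTS_DIR"] = "/data/fastsurfer-ci"     # root of all runs (placeholder)
os.environ["TEST_DIR"] = "abc1234"                     # output of the commit under test
os.environ["REFERENCE_DIR"] = "reference"              # known-good FastSurfer output
os.environ["SUBJECTS_LIST"] = "subjects_list.txt"      # one subject name per line

# Run the quick tests and propagate pytest's exit code.
raise SystemExit(pytest.main(["test/quick_test"]))
```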