Skip to content

Commit

Permalink
PR fixes:
Browse files Browse the repository at this point in the history
- ruff fixes

- docstrings: added description for fixture filled arguments
  • Loading branch information
taha-abdullah committed Sep 18, 2024
1 parent 810a170 commit 054ea58
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 46 deletions.
6 changes: 4 additions & 2 deletions test/quick_test/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os

from logging import getLogger

logger = getLogger(__name__)


__all__ = ["load_test_subjects"]


def load_test_subjects():
"""
Load the test files from the given file path.
Expand All @@ -19,7 +21,7 @@ def load_test_subjects():
test_subjects = []

# Load the reference and test files
with open(os.path.join(subjects_dir, subjects_list), "r") as file:
with open(os.path.join(subjects_dir, subjects_list)) as file:
for line in file:
filename = line.strip()
logger.debug(filename)
Expand Down
2 changes: 2 additions & 0 deletions test/quick_test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import pytest

__all__ = ["subjects_dir", "test_dir", "reference_dir", "subjects_list"]


@pytest.fixture
def subjects_dir():
Expand Down
22 changes: 12 additions & 10 deletions test/quick_test/test_errors_in_logfiles.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest
import yaml
from logging import getLogger
from pathlib import Path

from .common import *
import pytest
import yaml

from logging import getLogger
from .common import load_test_subjects

logger = getLogger(__name__)

Expand All @@ -25,7 +25,7 @@ def load_errors():

error_file_path = Path(__file__).parent / "data" / "logfile.errors.yaml"

with open(error_file_path, "r") as file:
with open(error_file_path) as file:
data = yaml.safe_load(file)
errors = data.get("errors", [])
whitelist = data.get("whitelist", [])
Expand Down Expand Up @@ -65,8 +65,10 @@ def test_errors(subjects_dir: Path, test_dir: Path, test_subject: Path):
----------
subjects_dir : Path
Subjects directory.
Filled by pytest fixture from conftest.py.
test_dir : Path
Tests directory.
Filled by pytest fixture from conftest.py.
test_subject : Path
Subject to test.
Expand All @@ -93,22 +95,22 @@ def test_errors(subjects_dir: Path, test_dir: Path, test_subject: Path):
with log_file.open("r") as file:
lines = file.readlines()
lines_with_errors = []
for line_number, line in enumerate(lines, start=1):
for _line_number, line in enumerate(lines, start=1):
if any(error in line.lower() for error in errors):
if not any(white in line.lower() for white in whitelist):
# Get two lines before and after the current line
context = lines[max(0, line_number - 2) : min(len(lines), line_number + 3)]
lines_with_errors.append((line_number, context))
context = lines[max(0, _line_number - 2) : min(len(lines), _line_number + 3)]
lines_with_errors.append((_line_number, context))
# print(lines_with_errors)
files_with_errors[rel_path] = lines_with_errors
error_flag = True
except FileNotFoundError:
raise FileNotFoundError(f"Log file not found at path: {log_file}")
raise FileNotFoundError(f"Log file not found at path: {log_file}") from None

# Print the lines and context with errors for each file
for file, lines in files_with_errors.items():
logger.debug(f"\nFile {file}, in line {files_with_errors[file][0][0]}:")
for line_number, line in lines:
for _line_number, line in lines:
logger.debug(*line, sep="")

# Assert that there are no lines with any of the keywords
Expand Down
14 changes: 7 additions & 7 deletions test/quick_test/test_file_existence.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import pytest

from .common import *

from logging import getLogger

from pathlib import Path

import pytest

from .common import load_test_subjects

logger = getLogger(__name__)


Expand Down Expand Up @@ -41,10 +40,13 @@ def test_file_existence(subjects_dir: Path, test_dir: Path, reference_dir: Path,
----------
subjects_dir : Path
Path to the subjects directory.
Filled by pytest fixture from conftest.py.
test_dir : Path
Name of the test directory.
Filled by pytest fixture from conftest.py.
reference_dir : Path
Name of the reference directory.
Filled by pytest fixture from conftest.py.
test_subject : Path
Name of the test subject.
Expand All @@ -54,8 +56,6 @@ def test_file_existence(subjects_dir: Path, test_dir: Path, reference_dir: Path,
If a file in the reference list does not exist in the test list.
"""

print(test_subject)

# Get reference files from the reference subject directory
reference_subject = subjects_dir / reference_dir / test_subject
reference_files = get_files_from_folder(reference_subject)
Expand Down
24 changes: 16 additions & 8 deletions test/quick_test/test_images.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import pytest
from collections import OrderedDict
from logging import getLogger
from pathlib import Path

import nibabel as nib
import nibabel.cmdline.diff
import numpy as np
from pathlib import Path

from collections import OrderedDict

from .common import *
import pytest

from CerebNet.utils.metrics import dice_score

from logging import getLogger
from .common import load_test_subjects

logger = getLogger(__name__)

Expand Down Expand Up @@ -106,10 +105,13 @@ def test_image_headers(subjects_dir: Path, test_dir: Path, reference_dir: Path,
----------
subjects_dir : Path
Path to the subjects directory.
Filled by pytest fixture from conftest.py.
test_dir : Path
Name of test directory.
Filled by pytest fixture from conftest.py.
reference_dir: Path
Name of reference directory.
Filled by pytest fixture from conftest.py.
test_subject : Path
Name of the test subject.
Expand Down Expand Up @@ -144,10 +146,13 @@ def test_seg_data(subjects_dir: Path, test_dir: Path, reference_dir: Path, test_
----------
subjects_dir : Path
Path to the subjects directory.
Filled by pytest fixture from conftest.py.
test_dir : Path
Name of test directory.
Filled by pytest fixture from conftest.py.
reference_dir : Path
Name of reference directory.
Filled by pytest fixture from conftest.py.
test_subject : Path
Name of the test subject.
Expand All @@ -174,7 +179,7 @@ def test_seg_data(subjects_dir: Path, test_dir: Path, reference_dir: Path, test_

# Check the dice score
np.testing.assert_allclose(
dscore, 0, atol=1e-6, rtol=1e-6, err_msg=f"Dice scores are not within range for all classes"
dscore, 0, atol=1e-6, rtol=1e-6, err_msg="Dice scores are not within range for all classes"
)

# assert dscore == 1, "Dice scores are not 1 for all classes"
Expand All @@ -191,10 +196,13 @@ def test_int_data(subjects_dir: Path, test_dir: Path, reference_dir: Path, test_
----------
subjects_dir : Path
Path to the subjects directory.
Filled by pytest fixture from conftest.py.
test_dir : Path
Name of test directory.
Filled by pytest fixture from conftest.py.
reference_dir : Path
Name of reference directory.
Filled by pytest fixture from conftest.py.
test_subject : Path
Name of the test subject.
Expand Down
32 changes: 13 additions & 19 deletions test/quick_test/test_stats.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from .conftest import *

import os
from logging import getLogger
from pathlib import Path

import pandas as pd
import pytest
import yaml

from .common import *

from logging import getLogger
from .common import load_test_subjects

logger = getLogger(__name__)

Expand All @@ -29,7 +27,7 @@ def thresholds():
thresholds_file = Path(__file__).parent / "data/thresholds/aseg.stats.yaml"

# Open the file_path and read the thresholds into a dictionary
with open(thresholds_file, "r") as file:
with open(thresholds_file) as file:
data = yaml.safe_load(file)
default_threshold = data.get("default_threshold")
thresholds = data.get("thresholds", {})
Expand Down Expand Up @@ -76,15 +74,15 @@ def load_structs(test_file: Path):
List of structs.
"""

if "aseg" in str(test_file):
if test_file.name == "aseg.stats":
structs_file = Path(__file__).parent / "data/thresholds/aseg.stats.yaml"
elif "aparc+DKT" in str(test_file):
elif test_file.name == "aparc+DKT.stats":
structs_file = Path(__file__).parent / "data/thresholds/aparc+DKT.stats.yaml"
else:
raise ValueError("Unknown test file")

# Open the file_path and read the structs: into a list
with open(structs_file, "r") as file:
with open(structs_file) as file:
data = yaml.safe_load(file)
structs = data.get("structs", [])

Expand Down Expand Up @@ -112,18 +110,17 @@ def read_measure_stats(file_path: Path):
measurements = {}

# Retrieve lines starting with "# Measure" from the stats file
with open(file_path, "r") as file:
with open(file_path) as file:
# Read each line in the file
for i, line in enumerate(file, 1):
for _i, line in enumerate(file, 1):
# Check if the line starts with "# ColHeaders"
if line.startswith("# ColHeaders"):
table_start = i
columns = line.strip("# ColHeaders").strip().split(" ")
line.removeprefix("# ColHeaders").strip().split(" ")

# Check if the line starts with "# Measure"
if line.startswith("# Measure"):
# Strip "# Measure" from the line
line = line.strip("# Measure").strip()
line = line.removeprefix("# Measure").strip()
# Append the measure to the list
line = line.split(", ")
measure.append(line[1])
Expand Down Expand Up @@ -153,17 +150,16 @@ def read_table(file_path: Path):
file_path = file_path / "stats" / "aseg.stats"

# Retrieve stats table from the stats file
with open(file_path, "r") as file:
with open(file_path) as file:
# Read each line in the file
for i, line in enumerate(file, 1):
# Check if the line starts with "# ColHeaders"
if line.startswith("# ColHeaders"):
table_start = i
columns = line.strip("# ColHeaders").strip().split(" ")
columns = line.removeprefix("# ColHeaders").strip().split(" ")

# Read the reference table into a pandas dataframe
table = pd.read_table(file_path, skiprows=table_start, sep="\s+", header=None)
table_numeric = table.apply(pd.to_numeric, errors="coerce")
table.columns = columns
table.set_index(columns[0], inplace=True)

Expand Down Expand Up @@ -202,8 +198,6 @@ def test_measure_exists(subjects_dir: Path, test_dir: Path, test_subject: Path):
f"for struct {struct} the value {data[1].get(struct)} is not close to " f"{ref_data[1].get(struct)}"
)

stats_data = read_measure_stats(test_file)

# Check if all measures exist in stats file
assert len(errors) == 0, ", ".join(errors)

Expand Down

0 comments on commit 054ea58

Please sign in to comment.