Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zdeveloper committed Aug 30, 2024
1 parent ebdd734 commit 413ab81
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
Empty file added OCR/tests/empty_file3
Empty file.
32 changes: 20 additions & 12 deletions OCR/tests/segmentation_template_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,23 @@
path = os.path.dirname(__file__)


segmentation_template = os.path.join(path, "./assets/form_segmention_template.png")
raw_image = os.path.join(path, "./assets/form_filled.png")
segmentation_template_path = os.path.join(path, "./assets/form_segmention_template.png")
raw_image_path = os.path.join(path, "./assets/form_filled.png")
labels_path = os.path.join(path, "./assets/labels.json")


class TestImageSegmenter:
@pytest.fixture(autouse=True)
def setup(self):
self.raw_image = raw_image
self.segmentation_template = segmentation_template
self.raw_image_path = raw_image_path
self.segmentation_template_path = segmentation_template_path
self.labels_path = labels_path
self.segmenter = ImageSegmenter(self.raw_image, self.segmentation_template, self.labels_path)
self.segmenter = ImageSegmenter()

def test_segment(self):
segments = self.segmenter.segment()
segments = self.segmenter.load_and_segment(
self.raw_image_path, self.segmentation_template_path, self.labels_path
)
assert isinstance(segments, dict)
with open(self.labels_path, "r") as f:
labels = json.load(f)
Expand All @@ -32,7 +34,9 @@ def test_segment(self):

def test_segment_shapes(self):
expected_shapes = {"nbs_patient_id": (57, 366, 3), "nbs_cas_id": (41, 376, 3)}
segments = self.segmenter.segment()
segments = self.segmenter.load_and_segment(
self.raw_image_path, self.segmentation_template_path, self.labels_path
)
for label, segment in segments.items():
assert segment.shape == expected_shapes[label]

Expand All @@ -41,23 +45,27 @@ def test_no_matching_pixels(self):
raw_image = np.ones((10, 10, 3), dtype=np.uint8)
cv.imwrite("no_matching_colors_raw.png", raw_image)
cv.imwrite("no_matching_colors_seg.png", segmentation_template)
segmenter = ImageSegmenter("no_matching_colors_raw.png", "no_matching_colors_seg.png", self.labels_path)
segments = segmenter.segment()
segmenter = ImageSegmenter()
segments = segmenter.load_and_segment("no_matching_colors_raw.png", "no_matching_colors_seg.png", self.labels_path)
assert len(segments) == 2
assert segments["nbs_patient_id"] is None
assert segments["nbs_cas_id"] is None
os.remove("no_matching_colors_raw.png")
os.remove("no_matching_colors_seg.png")

def test_invalid_file_paths(self):
segmenter = ImageSegmenter()

with pytest.raises(FileNotFoundError):
ImageSegmenter("invalid_path", "invalid_path", {})
segmenter.load_and_segment("invalid_path", "invalid_path", {})

def test_invalid_image_files(self):
with open("empty_file1", "w"), open("empty_file2", "w"):
segmenter = ImageSegmenter()

with open("empty_file1", "w"), open("empty_file2", "w"), open("empty_file3", "w"):
pass

with pytest.raises(ValueError):
ImageSegmenter("empty_file1", "empty_file2", {})
segmenter.load_and_segment("empty_file1", "empty_file2", "empty_file3")
os.remove("empty_file1")
os.remove("empty_file2")

0 comments on commit 413ab81

Please sign in to comment.