From 413ab81e2d085d5909b71a07650076cd523d26a8 Mon Sep 17 00:00:00 2001 From: Zedd Shmais Date: Fri, 30 Aug 2024 11:33:34 -0500 Subject: [PATCH] fix tests --- OCR/tests/empty_file3 | 0 OCR/tests/segmentation_template_test.py | 32 +++++++++++++++---------- 2 files changed, 20 insertions(+), 12 deletions(-) create mode 100644 OCR/tests/empty_file3 diff --git a/OCR/tests/empty_file3 b/OCR/tests/empty_file3 new file mode 100644 index 00000000..e69de29b diff --git a/OCR/tests/segmentation_template_test.py b/OCR/tests/segmentation_template_test.py index c36b9155..7cbb1c71 100644 --- a/OCR/tests/segmentation_template_test.py +++ b/OCR/tests/segmentation_template_test.py @@ -8,21 +8,23 @@ path = os.path.dirname(__file__) -segmentation_template = os.path.join(path, "./assets/form_segmention_template.png") -raw_image = os.path.join(path, "./assets/form_filled.png") +segmentation_template_path = os.path.join(path, "./assets/form_segmention_template.png") +raw_image_path = os.path.join(path, "./assets/form_filled.png") labels_path = os.path.join(path, "./assets/labels.json") class TestImageSegmenter: @pytest.fixture(autouse=True) def setup(self): - self.raw_image = raw_image - self.segmentation_template = segmentation_template + self.raw_image_path = raw_image_path + self.segmentation_template_path = segmentation_template_path self.labels_path = labels_path - self.segmenter = ImageSegmenter(self.raw_image, self.segmentation_template, self.labels_path) + self.segmenter = ImageSegmenter() def test_segment(self): - segments = self.segmenter.segment() + segments = self.segmenter.load_and_segment( + self.raw_image_path, self.segmentation_template_path, self.labels_path + ) assert isinstance(segments, dict) with open(self.labels_path, "r") as f: labels = json.load(f) @@ -32,7 +34,9 @@ def test_segment(self): def test_segment_shapes(self): expected_shapes = {"nbs_patient_id": (57, 366, 3), "nbs_cas_id": (41, 376, 3)} - segments = self.segmenter.segment() + segments = self.segmenter.load_and_segment( + self.raw_image_path, self.segmentation_template_path, self.labels_path + ) for label, segment in segments.items(): assert segment.shape == expected_shapes[label] @@ -41,8 +45,8 @@ def test_no_matching_pixels(self): raw_image = np.ones((10, 10, 3), dtype=np.uint8) cv.imwrite("no_matching_colors_raw.png", raw_image) cv.imwrite("no_matching_colors_seg.png", segmentation_template) - segmenter = ImageSegmenter("no_matching_colors_raw.png", "no_matching_colors_seg.png", self.labels_path) - segments = segmenter.segment() + segmenter = ImageSegmenter() + segments = segmenter.load_and_segment("no_matching_colors_raw.png", "no_matching_colors_seg.png", self.labels_path) assert len(segments) == 2 assert segments["nbs_patient_id"] is None assert segments["nbs_cas_id"] is None @@ -50,14 +54,18 @@ def test_no_matching_pixels(self): os.remove("no_matching_colors_seg.png") def test_invalid_file_paths(self): + segmenter = ImageSegmenter() + with pytest.raises(FileNotFoundError): - ImageSegmenter("invalid_path", "invalid_path", {}) + segmenter.load_and_segment("invalid_path", "invalid_path", {}) def test_invalid_image_files(self): - with open("empty_file1", "w"), open("empty_file2", "w"): + segmenter = ImageSegmenter() + + with open("empty_file1", "w"), open("empty_file2", "w"), open("empty_file3", "w"): pass with pytest.raises(ValueError): - ImageSegmenter("empty_file1", "empty_file2", {}) + segmenter.load_and_segment("empty_file1", "empty_file2", "empty_file3") os.remove("empty_file1") os.remove("empty_file2")