generated from CDCgov/template
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial * addressed case with no matching pixels * edited comments * edited test file name * delete old test file * minor typo * added test to verify that ndarray values are matchin * edited tests --------- Co-authored-by: Arindam Kulshi <[email protected]>
- Loading branch information
1 parent
35e509e
commit ab5fbbf
Showing
6 changed files
with
143 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
|
||
from dotenv import load_dotenv | ||
import os | ||
import cv2 as cv | ||
|
||
from services.image_segmenter import ImageSegmenter | ||
|
||
def main(): | ||
load_dotenv() | ||
|
||
segmentation_template = os.getenv('SEGMENTATION_TEMPLATE_PATH') | ||
raw_image = os.getenv('RAW_IMAGE_PATH') | ||
labels_path = os.getenv('LABELS_PATH') | ||
|
||
|
||
|
||
segmenter = ImageSegmenter(raw_image, segmentation_template,labels_path) | ||
segments = segmenter.segment() | ||
|
||
segment_info = {label: segment_data.shape for label, segment_data in segments.items()} | ||
print(segment_info) | ||
|
||
|
||
nbs_patient_id_image_path = 'nbs_patient_id_image_path.png' | ||
nbs_cas_id_image_path = 'nbs_cas_id_image_path.png' | ||
|
||
# Save the images | ||
cv.imwrite(nbs_patient_id_image_path , segments['nbs_patient_id']) | ||
cv.imwrite(nbs_cas_id_image_path, segments['nbs_cas_id']) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import cv2 as cv | ||
import numpy as np | ||
import json | ||
import os | ||
|
||
|
||
class ImageSegmenter: | ||
def __init__(self, raw_image, segmentation_template, labels): | ||
|
||
if not os.path.isfile(raw_image) or not os.path.isfile(segmentation_template): | ||
raise FileNotFoundError("One or more input files do not exist.") | ||
|
||
self.raw_image = cv.imread(raw_image) | ||
if self.raw_image is None: | ||
raise ValueError(f"Failed to open image file: {raw_image}") | ||
|
||
self.segmentation_template = cv.imread(segmentation_template) | ||
if self.segmentation_template is None: | ||
raise ValueError(f"Failed to open image file: {segmentation_template}") | ||
|
||
|
||
self.labels = labels | ||
self.segments = {} | ||
|
||
def segment(self): | ||
with open(self.labels, 'r') as f: | ||
labels = json.load(f) | ||
#iterate over the labels | ||
for color, label in labels.items(): | ||
color = tuple(map(int, color.split(','))) | ||
#find indices of the color in the segmentation template where the color matches the expected colors | ||
indices = np.where(np.all(self.segmentation_template == color, axis=-1)) | ||
#if there are no matching pixels | ||
if indices[0].size == 0: | ||
raise ValueError(f"No pixels found for color {color} in segmentation template.") | ||
#if there are matching pixels | ||
if indices[0].size > 0: | ||
#Find the x-y coordinates | ||
y_min, y_max = indices[0].min(), indices[0].max() | ||
x_min, x_max = indices[1].min(), indices[1].max() | ||
#crop the area and store the image in the dict | ||
self.segments[label] = self.raw_image[y_min:y_max+1, x_min:x_max+1] | ||
return self.segments | ||
|
||
|
||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import pytest | ||
import json | ||
import os | ||
from OCR.services.image_segmenter import ImageSegmenter | ||
from dotenv import load_dotenv | ||
import numpy as np | ||
import cv2 as cv | ||
|
||
|
||
load_dotenv() | ||
segmentation_template = os.getenv('SEGMENTATION_TEMPLATE_PATH') | ||
raw_image = os.getenv('RAW_IMAGE_PATH') | ||
labels_path = os.getenv('LABELS_PATH') | ||
|
||
class TestImageSegmenter: | ||
@pytest.fixture(autouse=True) | ||
def setup(self): | ||
self.raw_image = raw_image | ||
self.segmentation_template = segmentation_template | ||
self.labels_path = labels_path | ||
self.segmenter = ImageSegmenter(self.raw_image, self.segmentation_template, self.labels_path) | ||
|
||
def test_segment(self): | ||
segments = self.segmenter.segment() | ||
assert isinstance(segments, dict) | ||
with open(self.labels_path, 'r') as f: | ||
labels = json.load(f) | ||
assert set(segments.keys()) == set(labels.values()) | ||
for segment in segments.values(): | ||
assert isinstance(segment, np.ndarray) | ||
|
||
def test_segment_shapes(self): | ||
segments = self.segmenter.segment() | ||
for segment in segments.values(): | ||
assert len(segment.shape) == 3 | ||
|
||
def test_segment_shapes(self): | ||
expected_shapes = {'nbs_patient_id': (41, 376, 3), 'nbs_cas_id': (57, 366, 3)} | ||
segments = self.segmenter.segment() | ||
for label, segment in segments.items(): | ||
assert segment.shape == expected_shapes[label] | ||
|
||
def test_no_matching_pixels(self): | ||
segmentation_template = np.zeros((10, 10, 3), dtype=np.uint8) | ||
cv.imwrite('no_matching_colors.png', segmentation_template) | ||
segmenter = ImageSegmenter(self.raw_image, 'no_matching_colors.png', self.labels_path) | ||
with pytest.raises(ValueError): | ||
segmenter.segment() | ||
os.remove('no_matching_colors.png') | ||
|
||
def test_invalid_file_paths(self): | ||
with pytest.raises(FileNotFoundError): | ||
ImageSegmenter('invalid_path', 'invalid_path', {}) | ||
|
||
def test_invalid_image_files(self): | ||
with open('empty_file1', 'w'), open('empty_file2', 'w'): | ||
pass | ||
|
||
with pytest.raises(ValueError): | ||
ImageSegmenter('empty_file1', 'empty_file2', {}) | ||
os.remove('empty_file1') | ||
os.remove('empty_file2') | ||
|
||
|