Skip to content

Commit

Permalink
Added logic to segment pdf (#33)
Browse files Browse the repository at this point in the history
* initial

* addressed case with no matching pixels

* edited comments

* edited test file name

* delete old test file

* minor typo

* added test to verify that ndarray values are matching

* edited tests

---------

Co-authored-by: Arindam Kulshi <[email protected]>
  • Loading branch information
arinkulshi-skylight and arinkulshi authored Mar 26, 2024
1 parent 35e509e commit ab5fbbf
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 0 deletions.
Empty file added OCR/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions OCR/segmentation_template_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

from dotenv import load_dotenv
import os
import cv2 as cv

from services.image_segmenter import ImageSegmenter

def main():
    """Segment the configured raw image and save two of the cropped fields.

    Loads a ``.env`` file, reads ``SEGMENTATION_TEMPLATE_PATH``,
    ``RAW_IMAGE_PATH`` and ``LABELS_PATH`` from the environment, runs
    ``ImageSegmenter`` on them, prints the shape of every returned segment and
    writes the ``nbs_patient_id`` and ``nbs_cas_id`` crops as PNG files
    (relative paths, so they land in the current working directory).

    Raises:
        ValueError: If any of the three environment variables is unset or
            empty.
        OSError: If OpenCV reports that a segment image could not be written.
    """
    load_dotenv()

    segmentation_template = os.getenv('SEGMENTATION_TEMPLATE_PATH')
    raw_image = os.getenv('RAW_IMAGE_PATH')
    labels_path = os.getenv('LABELS_PATH')

    # Fail early with a readable message instead of letting ``None`` reach
    # ``os.path.isfile`` inside ImageSegmenter, which raises a bare TypeError.
    configured = {
        'SEGMENTATION_TEMPLATE_PATH': segmentation_template,
        'RAW_IMAGE_PATH': raw_image,
        'LABELS_PATH': labels_path,
    }
    missing = [name for name, value in configured.items() if not value]
    if missing:
        raise ValueError(
            f"Missing required environment variable(s): {', '.join(missing)}"
        )

    segmenter = ImageSegmenter(raw_image, segmentation_template, labels_path)
    segments = segmenter.segment()

    segment_info = {label: segment_data.shape for label, segment_data in segments.items()}
    print(segment_info)

    # Segment label -> output file name.
    output_paths = {
        'nbs_patient_id': 'nbs_patient_id_image_path.png',
        'nbs_cas_id': 'nbs_cas_id_image_path.png',
    }

    # Save the images. cv.imwrite signals failure through its boolean return
    # value, so check it rather than silently producing no file.
    for label, output_path in output_paths.items():
        if not cv.imwrite(output_path, segments[label]):
            raise OSError(f"Failed to write segment '{label}' to {output_path}")


if __name__ == "__main__":
    main()
Empty file added OCR/services/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions OCR/services/image_segmenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import cv2 as cv
import numpy as np
import json
import os


class ImageSegmenter:
    """Crop labelled regions out of an image using a colour-coded template.

    The labels file is a JSON object whose keys are comma-separated integer
    colour strings (e.g. ``"255,0,0"``) and whose values are segment names.
    For each colour, the bounding box of the template pixels that exactly
    match it is cropped out of the raw image at the same pixel coordinates.

    NOTE(review): ``cv.imread`` returns channels in BGR order by default, so
    the colour keys presumably have to be written as B,G,R -- confirm against
    the labels file.
    """

    def __init__(self, raw_image, segmentation_template, labels):
        """Load the raw image and the segmentation template from disk.

        Args:
            raw_image: Path of the image to crop segments from.
            segmentation_template: Path of the colour-coded template image.
            labels: Path of the JSON labels file. It is stored as-is and only
                opened when ``segment()`` is called.

        Raises:
            FileNotFoundError: If either image path is not an existing file.
            ValueError: If OpenCV cannot decode either image.
        """
        if not os.path.isfile(raw_image) or not os.path.isfile(segmentation_template):
            raise FileNotFoundError("One or more input files do not exist.")

        # cv.imread returns None (rather than raising) for unreadable files.
        self.raw_image = cv.imread(raw_image)
        if self.raw_image is None:
            raise ValueError(f"Failed to open image file: {raw_image}")

        self.segmentation_template = cv.imread(segmentation_template)
        if self.segmentation_template is None:
            raise ValueError(f"Failed to open image file: {segmentation_template}")

        self.labels = labels
        self.segments = {}

    def segment(self):
        """Crop every labelled region and return the crops keyed by label.

        Returns:
            The ``self.segments`` dict mapping each label to the cropped
            region of the raw image. The crops are NumPy slices (views) of
            ``self.raw_image``, not copies.

        Raises:
            ValueError: If a colour from the labels file does not occur in
                the segmentation template, or if its bounding box extends
                past the edges of the raw image.
        """
        with open(self.labels, 'r', encoding='utf-8') as f:
            labels = json.load(f)

        raw_height, raw_width = self.raw_image.shape[:2]

        for color_key, label in labels.items():
            color = tuple(map(int, color_key.split(',')))
            # Row/column indices of template pixels whose channels all equal
            # the expected colour.
            indices = np.where(np.all(self.segmentation_template == color, axis=-1))
            if indices[0].size == 0:
                raise ValueError(f"No pixels found for color {color} in segmentation template.")

            # Bounding box of the matching pixels (inclusive on both ends).
            y_min, y_max = indices[0].min(), indices[0].max()
            x_min, x_max = indices[1].min(), indices[1].max()

            # Slicing past the end of an array silently truncates, so a
            # template larger than the raw image would otherwise yield a
            # partial or empty crop without any error.
            if y_max >= raw_height or x_max >= raw_width:
                raise ValueError(
                    f"Segment '{label}' (rows {y_min}-{y_max}, cols {x_min}-{x_max}) "
                    f"extends outside the raw image of size {raw_width}x{raw_height}."
                )

            self.segments[label] = self.raw_image[y_min:y_max + 1, x_min:x_max + 1]
        return self.segments



Empty file added OCR/tests/__init__.py
Empty file.
64 changes: 64 additions & 0 deletions OCR/tests/segmentation_template_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
import json
import os
from OCR.services.image_segmenter import ImageSegmenter
from dotenv import load_dotenv
import numpy as np
import cv2 as cv


# Resolve the fixture paths once at import time; load_dotenv() lets a local
# .env file supply the variables before they are read from the environment.
load_dotenv()
segmentation_template, raw_image, labels_path = (
    os.getenv(variable_name)
    for variable_name in (
        'SEGMENTATION_TEMPLATE_PATH',
        'RAW_IMAGE_PATH',
        'LABELS_PATH',
    )
)

class TestImageSegmenter:
    """Tests for ImageSegmenter, driven by the paths read from the environment."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build a fresh segmenter from the configured paths before each test."""
        self.raw_image = raw_image
        self.segmentation_template = segmentation_template
        self.labels_path = labels_path
        self.segmenter = ImageSegmenter(self.raw_image, self.segmentation_template, self.labels_path)

    def test_segment(self):
        """segment() returns one ndarray per label in the labels file."""
        segments = self.segmenter.segment()
        assert isinstance(segments, dict)
        with open(self.labels_path, 'r') as f:
            labels = json.load(f)
        assert set(segments.keys()) == set(labels.values())
        for segment in segments.values():
            assert isinstance(segment, np.ndarray)

    # Renamed from test_segment_shapes: two methods shared that name, so this
    # one was overwritten by the later definition and never collected.
    def test_segment_dimensions(self):
        """Every segment is a 3-dimensional array."""
        segments = self.segmenter.segment()
        for segment in segments.values():
            assert len(segment.shape) == 3

    def test_segment_shapes(self):
        """Each segment has the expected shape for the configured sample data."""
        expected_shapes = {'nbs_patient_id': (41, 376, 3), 'nbs_cas_id': (57, 366, 3)}
        segments = self.segmenter.segment()
        for label, segment in segments.items():
            assert segment.shape == expected_shapes[label]

    def test_no_matching_pixels(self, tmp_path):
        """An all-zero template makes segment() raise ValueError."""
        # tmp_path keeps the scratch file out of the working directory and is
        # cleaned up by pytest even when the assertion fails.
        template_path = str(tmp_path / 'no_matching_colors.png')
        segmentation_template = np.zeros((10, 10, 3), dtype=np.uint8)
        cv.imwrite(template_path, segmentation_template)
        segmenter = ImageSegmenter(self.raw_image, template_path, self.labels_path)
        with pytest.raises(ValueError):
            segmenter.segment()

    def test_invalid_file_paths(self):
        """Non-existent image paths raise FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            ImageSegmenter('invalid_path', 'invalid_path', {})

    def test_invalid_image_files(self, tmp_path):
        """Existing files that are not decodable images raise ValueError."""
        empty_file1 = tmp_path / 'empty_file1'
        empty_file2 = tmp_path / 'empty_file2'
        empty_file1.touch()
        empty_file2.touch()

        with pytest.raises(ValueError):
            ImageSegmenter(str(empty_file1), str(empty_file2), {})


0 comments on commit ab5fbbf

Please sign in to comment.