Skip to content

Commit

Permalink
Add new segmentation method (#84)
Browse files Browse the repository at this point in the history
* add new segmentation template

* re-enable ocr on main

* update defaults

* try to please the lint gods

* remove black and add ruff

* use ruff to reformat
  • Loading branch information
zdeveloper authored Apr 23, 2024
1 parent 58f6d8c commit 4cdca32
Show file tree
Hide file tree
Showing 12 changed files with 217 additions and 139 deletions.
2 changes: 0 additions & 2 deletions OCR/.idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions OCR/.idea/ruff.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions OCR/ocr/main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
import os

from ocr.services.image_segmenter import ImageSegmenter
from ocr.services.image_ocr import ImageOCR

path = os.path.dirname(__file__)


def main():
segmentation_template = os.path.join(path, "../tests/assets/form_segmention_template.png")
raw_image = os.path.join(path, "../tests/assets/form_filled.png")
labels_path = os.path.join(path, "../tests/assets/labels.json")
segmentation_template = os.path.join(path, "../tests/assets/form_segmentation_template_hep_page_1.png")
raw_image = os.path.join(path, "../tests/assets/form_filled_hep.jpg")
labels_path = os.path.join(path, "../tests/assets/labels_hep_page1.json")

segmenter = ImageSegmenter(raw_image, segmentation_template, labels_path)
segments = segmenter.segment()

print("{:<20} {:<20}".format("Label", "Segment shape"))
for label, segment in segments.items():
print("{:<20} {:<20}".format(label, f"{segment.shape}"))
segment_shape = segment.shape if segment is not None else "INVALID"
print("{:<20} {:<20}".format(f"{segment_shape}", label))
# cv.imwrite(f"{label}_segment.png", segment)

ocr = ImageOCR()
values = ocr.image_to_text(segments=segments)
Expand Down
3 changes: 3 additions & 0 deletions OCR/ocr/services/image_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ def __init__(self, model="microsoft/trocr-base-printed"):
def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, str]:
digitized: dict[str, str] = {}
for label, image in segments.items():
if image is None:
continue

pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

generated_ids = self.model.generate(pixel_values)
Expand Down
108 changes: 88 additions & 20 deletions OCR/ocr/services/image_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,90 @@
import os


def crop_zeros(image):
    """Trim an array down to the bounding box of its non-zero entries.

    Only the first two axes are cropped; any further axes are kept whole.

    Returns:
        The cropped slice of ``image``, or ``None`` when ``image`` contains
        no non-zero entry at all.
    """
    # One row per non-zero entry, one column per axis of ``image``.
    nonzero_coords = np.argwhere(image)
    if nonzero_coords.shape[0] == 0:
        return None

    # Per-axis extremes give the two opposite corners of the bounding box.
    lower = nonzero_coords.min(axis=0)
    upper = nonzero_coords.max(axis=0)

    # Slice ends are exclusive, hence the +1 to keep the last row/column.
    row_span = slice(lower[0], upper[0] + 1)
    col_span = slice(lower[1], upper[1] + 1)
    return image[row_span, col_span]


def segment_by_mask_then_crop(self) -> dict[str, np.ndarray]:
    """Cut one segment per label by masking the raw image, then trimming zeros.

    The file at ``self.labels`` is parsed as JSON mapping comma-separated
    colour strings to label names. For each label, the pixels of
    ``self.segmentation_template`` that equal the label's colour form a 0/1
    mask; ``self.raw_image`` is multiplied by that mask and the product is
    handed to ``crop_zeros``.

    Returns:
        Mapping of label name to the value ``crop_zeros`` produced for it,
        which is ``None`` when the masked image has no non-zero pixel.

    Side effects:
        When ``self.debug`` is True, prints diagnostics and writes
        ``<label>_mask.png`` and (for non-empty segments) ``<label>.png``
        via ``cv.imwrite``.
    """
    # Parse the labels first so the file is closed before the image work.
    with open(self.labels, "r") as f:
        labels = json.load(f)

    # Neither image is modified below (the mask product allocates a fresh
    # array), so convert once here instead of deep-copying both per label.
    raw_image = np.asarray(self.raw_image)
    segmentation_template = np.asarray(self.segmentation_template)

    segments = {}
    for color, label in labels.items():
        # The channel order in the label string is reversed before matching
        # (presumably R,G,B in the file vs. OpenCV's B,G,R -- confirm).
        bgr = tuple(map(int, reversed(color.split(","))))

        # 1 where every channel of the template pixel matches, else 0.
        mask = np.all(segmentation_template == bgr, axis=2).astype(int)
        # Trailing axis lets the mask broadcast across the image channels.
        mask = mask[:, :, np.newaxis]

        segments[label] = crop_zeros(raw_image * mask)

        if self.debug is True:
            print(f"label: {label}")
            print(f"color {bgr}")
            print("mask.shape", mask.shape)
            # Scale the 0/1 mask to 0/255 so it is visible when saved.
            cv.imwrite(f"{label}_mask.png", mask * 255)
            if segments[label] is not None:
                print("segment.shape", segments[label].shape)
                cv.imwrite(f"{label}.png", segments[label])
            print("====")

    return segments


def segment_by_color_bounding_box(self) -> dict[str, np.ndarray]:
    """Crop the raw image to the bounding box of each label's colour.

    The file at ``self.labels`` is parsed as JSON mapping ``"R,G,B"`` strings
    to label names. For each label, the smallest rectangle enclosing every
    matching pixel of ``self.segmentation_template`` is sliced out of
    ``self.raw_image``.

    Returns:
        Mapping of label name to the cropped slice, or to ``None`` when the
        colour does not occur in the template.
    """
    with open(self.labels, "r") as handle:
        color_to_label = json.load(handle)

    crops = {}
    for rgb_text, name in color_to_label.items():
        # Labels are written as RGB; flip to the BGR order OpenCV uses.
        bgr = tuple(int(part) for part in reversed(rgb_text.split(",")))

        # True wherever every channel of a template pixel equals the colour.
        hits = np.all(self.segmentation_template == bgr, axis=-1)
        found = np.nonzero(hits)
        rows = found[0]

        if rows.size == 0:
            # Colour absent from the template: record the label as empty.
            crops[name] = None
            continue

        cols = found[1]
        top, bottom = rows.min(), rows.max()
        left, right = cols.min(), cols.max()
        # +1 because slice ends are exclusive.
        crops[name] = self.raw_image[top : bottom + 1, left : right + 1]

    return crops


class ImageSegmenter:
def __init__(self, raw_image, segmentation_template, labels):
def __init__(
self,
raw_image,
segmentation_template,
labels,
segmentation_function=segment_by_mask_then_crop,
debug=False,
):
self.debug = debug
self.segmentation_function = segmentation_function

if not os.path.isfile(raw_image) or not os.path.isfile(segmentation_template):
raise FileNotFoundError("One or more input files do not exist.")

Expand All @@ -18,24 +100,10 @@ def __init__(self, raw_image, segmentation_template, labels):
raise ValueError(f"Failed to open image file: {segmentation_template}")

self.labels = labels
self.segments = {}

if self.debug is True:
print(f"raw_image shape: {self.raw_image.shape}")
print(f"segmentation_template shape: {self.segmentation_template.shape}")

def segment(self) -> dict[str, np.ndarray]:
with open(self.labels, "r") as f:
labels = json.load(f)
# iterate over the labels
for color, label in labels.items():
color = tuple(map(int, color.split(",")))
# find indices of the color in the segmentation template where the color matches the expected colors
indices = np.where(np.all(self.segmentation_template == color, axis=-1))
# if there are no matching pixels
if indices[0].size == 0:
raise ValueError(f"No pixels found for color {color} in segmentation template.")
# if there are matching pixels
if indices[0].size > 0:
# Find the x-y coordinates
y_min, y_max = indices[0].min(), indices[0].max()
x_min, x_max = indices[1].min(), indices[1].max()
# crop the area and store the image in the dict
self.segments[label] = self.raw_image[y_min : y_max + 1, x_min : x_max + 1]
return self.segments
return self.segmentation_function(self)
125 changes: 27 additions & 98 deletions OCR/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion OCR/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ transformers = {extras = ["torch"], version = "^4.39.3"}
pillow = "^10.3.0"

[tool.poetry.group.dev.dependencies]
black = "^24.3.0"
ruff = "^0.3.7"
pytest = "^8.1.1"

[build-system]
Expand Down
Binary file added OCR/tests/assets/form_filled_hep.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 4cdca32

Please sign in to comment.