Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ruff linter and formatter. #4

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions benchmark/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,32 @@
import collections
import copy
import json
import os
import time

import datasets
from tabulate import tabulate

from surya.benchmark.bbox import get_pdf_lines
from surya.benchmark.metrics import precision_recall
from surya.benchmark.tesseract import tesseract_bboxes, tesseract_parallel
from surya.model.segformer import load_model, load_processor
from surya.model.processing import open_pdf, get_page_images
from surya.benchmark.tesseract import tesseract_parallel
from surya.detection import batch_inference
from surya.model.processing import get_page_images, open_pdf
from surya.model.segformer import load_model, load_processor
from surya.postprocessing.heatmap import draw_bboxes_on_image
from surya.postprocessing.util import rescale_bbox
from surya.settings import settings
import os
import time
from tabulate import tabulate
import datasets


def main():
parser = argparse.ArgumentParser(description="Detect bboxes in a PDF.")
parser.add_argument("--pdf_path", type=str, help="Path to PDF to detect bboxes in.", default=None)
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
parser.add_argument(
"--results_dir",
type=str,
help="Path to JSON file with OCR results.",
default=os.path.join(settings.RESULT_DIR, "benchmark"),
)
parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=100)
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
args = parser.parse_args()
Expand All @@ -34,7 +40,7 @@ def main():
doc = open_pdf(args.pdf_path)
page_count = len(doc)
page_indices = list(range(page_count))
page_indices = page_indices[:args.max]
page_indices = page_indices[: args.max]

images = get_page_images(doc, page_indices)
doc.close()
Expand Down Expand Up @@ -72,10 +78,7 @@ def main():
surya_metrics = precision_recall(surya_boxes, cb)
tess_metrics = precision_recall(tb, cb)

page_metrics[idx] = {
"surya": surya_metrics,
"tesseract": tess_metrics
}
page_metrics[idx] = {"surya": surya_metrics, "tesseract": tess_metrics}

if args.debug:
bbox_image = draw_bboxes_on_image(surya_boxes, copy.deepcopy(images[idx]))
Expand All @@ -93,12 +96,9 @@ def main():
mean_metrics[k][m] = sum(metric) / len(metric)

out_data = {
"times": {
"surya": surya_time,
"tesseract": tess_time
},
"times": {"surya": surya_time, "tesseract": tess_time},
"metrics": mean_metrics,
"page_metrics": page_metrics
"page_metrics": page_metrics,
}

with open(os.path.join(result_path, "results.json"), "w+") as f:
Expand All @@ -107,11 +107,13 @@ def main():
table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types
table_data = [
["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types],
["tesseract", tess_time, tess_time / len(images)] + [mean_metrics["tesseract"][m] for m in metric_types]
["tesseract", tess_time, tess_time / len(images)] + [mean_metrics["tesseract"][m] for m in metric_types],
]

print(tabulate(table_data, headers=table_headers, tablefmt="github"))
print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. There is a precision penalty for multiple boxes overlapping reference lines.")
print(
"Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. There is a precision penalty for multiple boxes overlapping reference lines."
)
print(f"Wrote results to {result_path}")


Expand Down
12 changes: 8 additions & 4 deletions benchmark/pymupdf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@
import os

from surya.benchmark.bbox import get_pdf_lines
from surya.model.processing import get_page_images, open_pdf
from surya.postprocessing.heatmap import draw_bboxes_on_image

from surya.model.processing import open_pdf, get_page_images
from surya.settings import settings


def main():
parser = argparse.ArgumentParser(description="Draw pymupdf line bboxes on images.")
parser.add_argument("pdf_path", type=str, help="Path to PDF to detect bboxes in.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "pymupdf"))
parser.add_argument(
"--results_dir",
type=str,
help="Path to JSON file with OCR results.",
default=os.path.join(settings.RESULT_DIR, "pymupdf"),
)
args = parser.parse_args()

doc = open_pdf(args.pdf_path)
Expand All @@ -34,6 +38,6 @@ def main():

print(f"Wrote results to {result_path}")


if __name__ == "__main__":
main()

12 changes: 8 additions & 4 deletions benchmark/tesseract_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@
import os

from surya.benchmark.tesseract import tesseract_bboxes
from surya.model.processing import get_page_images, open_pdf
from surya.postprocessing.heatmap import draw_bboxes_on_image

from surya.model.processing import open_pdf, get_page_images
from surya.settings import settings


def main():
parser = argparse.ArgumentParser(description="Draw tesseract bboxes on imagese.")
parser.add_argument("pdf_path", type=str, help="Path to PDF to detect bboxes in.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "tesseract"))
parser.add_argument(
"--results_dir",
type=str,
help="Path to JSON file with OCR results.",
default=os.path.join(settings.RESULT_DIR, "tesseract"),
)
args = parser.parse_args()

doc = open_pdf(args.pdf_path)
Expand All @@ -33,6 +37,6 @@ def main():

print(f"Wrote results to {result_path}")


if __name__ == "__main__":
main()

22 changes: 10 additions & 12 deletions detect_text.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import argparse
import copy
import json
import os
from collections import defaultdict

import filetype
from PIL import Image

from surya.model.segformer import load_model, load_processor
from surya.model.processing import open_pdf, get_page_images
from surya.detection import batch_inference
from surya.model.processing import get_page_images, open_pdf
from surya.model.segformer import load_model, load_processor
from surya.postprocessing.affinity import draw_lines_on_image
from surya.postprocessing.heatmap import draw_bboxes_on_image
from surya.settings import settings
import os
import filetype


def get_name_from_path(path):
Expand Down Expand Up @@ -68,7 +68,12 @@ def load_from_folder(folder_path, max_pages=None):
def main():
parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).")
parser.add_argument("input_path", type=str, help="Path to pdf or image file to detect bboxes in.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya"))
parser.add_argument(
"--results_dir",
type=str,
help="Path to JSON file with OCR results.",
default=os.path.join(settings.RESULT_DIR, "surya"),
)
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
Expand Down Expand Up @@ -121,10 +126,3 @@ def main():

if __name__ == "__main__":
main()







Loading