-
Notifications
You must be signed in to change notification settings - Fork 130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
OCR Block v2 #706
Draft
stellasphere
wants to merge
8
commits into
main
Choose a base branch
from
trocr-block
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
OCR Block v2 #706
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
f659d10
OCR model options
stellasphere 21ce0c3
Add TrOCR
stellasphere fc33aab
Update Dockerfile.onnx.cpu
stellasphere f61305d
Added Google Cloud Vision
stellasphere 3f48e7b
Add EasyOCR and Mathpix
stellasphere 8d3d754
Formatting
stellasphere 5411c3f
More formatting
stellasphere 89482b8
Remove EasyOCR
stellasphere File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
inference/core/workflows/core_steps/models/foundation/ocr/models/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Callable, List | ||
|
||
from inference.core.workflows.core_steps.common.entities import ( | ||
StepExecutionMode, | ||
) | ||
from inference.core.workflows.execution_engine.entities.base import ( | ||
Batch, | ||
WorkflowImageData, | ||
) | ||
from inference.core.workflows.prototypes.block import BlockResult | ||
|
||
|
||
class BaseOCRModel(ABC): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please do not share common base class for OCR blocks |
||
|
||
def __init__(self, model_manager, api_key): | ||
self.model_manager = model_manager | ||
self.api_key = api_key | ||
|
||
@abstractmethod | ||
def run( | ||
self, | ||
images: Batch[WorkflowImageData], | ||
step_execution_mode: StepExecutionMode, | ||
post_process_result: Callable[ | ||
[Batch[WorkflowImageData], List[dict]], BlockResult | ||
], | ||
) -> BlockResult: | ||
pass |
64 changes: 64 additions & 0 deletions
64
inference/core/workflows/core_steps/models/foundation/ocr/models/doctr.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest | ||
from inference.core.workflows.core_steps.common.entities import ( | ||
StepExecutionMode, | ||
) | ||
from inference.core.workflows.core_steps.common.utils import load_core_model | ||
from inference.core.workflows.execution_engine.entities.base import ( | ||
Batch, | ||
WorkflowImageData, | ||
) | ||
from inference.core.workflows.prototypes.block import BlockResult | ||
from typing import Callable, List | ||
|
||
from .base import BaseOCRModel | ||
|
||
|
||
class DoctrOCRModel(BaseOCRModel): | ||
|
||
def run( | ||
self, | ||
images: Batch[WorkflowImageData], | ||
step_execution_mode: StepExecutionMode, | ||
post_process_result: Callable[ | ||
[Batch[WorkflowImageData], List[dict]], BlockResult | ||
], | ||
) -> BlockResult: | ||
if step_execution_mode is StepExecutionMode.LOCAL: | ||
return self.run_locally(images, post_process_result) | ||
elif step_execution_mode is StepExecutionMode.REMOTE: | ||
return self.run_remotely(images, post_process_result) | ||
|
||
def run_locally( | ||
self, | ||
images: Batch[WorkflowImageData], | ||
post_process_result: Callable[ | ||
[Batch[WorkflowImageData], List[dict]], BlockResult | ||
], | ||
) -> BlockResult: | ||
predictions = [] | ||
for single_image in images: | ||
inference_request = DoctrOCRInferenceRequest( | ||
image=single_image.to_inference_format(numpy_preferred=True), | ||
api_key=self.api_key, | ||
) | ||
doctr_model_id = load_core_model( | ||
model_manager=self.model_manager, | ||
inference_request=inference_request, | ||
core_model="doctr", | ||
) | ||
result = self.model_manager.infer_from_request_sync( | ||
doctr_model_id, inference_request | ||
) | ||
predictions.append(result.model_dump()) | ||
return post_process_result(images, predictions) | ||
|
||
def run_remotely( | ||
self, | ||
images: Batch[WorkflowImageData], | ||
post_process_result: Callable[ | ||
[Batch[WorkflowImageData], List[dict]], BlockResult | ||
], | ||
) -> BlockResult: | ||
raise NotImplementedError( | ||
"Remote execution is not implemented for DoctrOCRModel." | ||
) |
64 changes: 64 additions & 0 deletions
64
inference/core/workflows/core_steps/models/foundation/ocr/models/google_cloud_vision.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# models/google_cloud_vision.py | ||
|
||
from .base import BaseOCRModel | ||
from inference.core.workflows.core_steps.common.entities import ( | ||
StepExecutionMode, | ||
) | ||
from inference.core.workflows.execution_engine.entities.base import ( | ||
Batch, | ||
WorkflowImageData, | ||
) | ||
from typing import Optional | ||
import requests | ||
|
||
|
||
class GoogleCloudVisionOCRModel(BaseOCRModel): | ||
def __init__( | ||
self, model_manager, api_key: Optional[str], google_cloud_api_key: str | ||
): | ||
super().__init__(model_manager, api_key) | ||
self.google_cloud_api_key = google_cloud_api_key | ||
|
||
def run( | ||
self, | ||
images: Batch[WorkflowImageData], | ||
step_execution_mode: StepExecutionMode, | ||
post_process_result, | ||
): | ||
predictions = [] | ||
for image_data in images: | ||
encoded_image = image_data.base64_image | ||
url = ( | ||
f"https://vision.googleapis.com/v1/images:annotate" | ||
f"?key={self.google_cloud_api_key}" | ||
) | ||
|
||
payload = { | ||
"requests": [ | ||
{ | ||
"image": {"content": encoded_image}, | ||
"features": [{"type": "TEXT_DETECTION"}], | ||
} | ||
] | ||
} | ||
# Send the request | ||
response = requests.post(url, json=payload) | ||
if response.status_code == 200: | ||
result = response.json() | ||
text_annotations = result["responses"][0].get( | ||
"textAnnotations", | ||
[], | ||
) | ||
if text_annotations: | ||
text = text_annotations[0]["description"] | ||
else: | ||
text = "" | ||
else: | ||
error_info = response.json().get("error", {}) | ||
message = error_info.get("message", response.text) | ||
raise Exception( | ||
f"Google Cloud Vision API request failed: {message}", | ||
) | ||
prediction = {"result": text} | ||
predictions.append(prediction) | ||
return post_process_result(images, predictions) |
83 changes: 83 additions & 0 deletions
83
inference/core/workflows/core_steps/models/foundation/ocr/models/mathpix.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from .base import BaseOCRModel | ||
from inference.core.workflows.core_steps.common.entities import ( | ||
StepExecutionMode, | ||
) | ||
from inference.core.workflows.execution_engine.entities.base import ( | ||
Batch, | ||
WorkflowImageData, | ||
) | ||
from typing import Optional, List, Callable | ||
from inference.core.workflows.prototypes.block import BlockResult | ||
|
||
import requests | ||
import json | ||
import base64 | ||
|
||
|
||
class MathpixOCRModel(BaseOCRModel): | ||
def __init__( | ||
self, | ||
model_manager, | ||
api_key: Optional[str], | ||
mathpix_app_id: str, | ||
mathpix_app_key: str, | ||
): | ||
super().__init__(model_manager, api_key) | ||
self.mathpix_app_id = mathpix_app_id | ||
self.mathpix_app_key = mathpix_app_key | ||
|
||
def run( | ||
self, | ||
images: Batch[WorkflowImageData], | ||
step_execution_mode: StepExecutionMode, | ||
post_process_result: Callable[ | ||
[Batch[WorkflowImageData], List[dict]], BlockResult | ||
], | ||
) -> BlockResult: | ||
predictions = [] | ||
for image_data in images: | ||
# Decode base64 image to bytes | ||
image_bytes = base64.b64decode(image_data.base64_image) | ||
|
||
# Prepare the request | ||
url = "https://api.mathpix.com/v3/text" | ||
headers = { | ||
"app_id": self.mathpix_app_id, | ||
"app_key": self.mathpix_app_key, | ||
} | ||
data = { | ||
"options_json": json.dumps( | ||
{ | ||
"math_inline_delimiters": ["$", "$"], | ||
"rm_spaces": True, | ||
} | ||
) | ||
} | ||
files = {"file": ("image.jpg", image_bytes, "image/jpeg")} | ||
|
||
# Send the request | ||
response = requests.post( | ||
url, | ||
headers=headers, | ||
data=data, | ||
files=files, | ||
) | ||
|
||
if response.status_code == 200: | ||
result = response.json() | ||
# Extract the text result | ||
text = result.get("text", "") | ||
else: | ||
error_info = response.json().get("error", {}) | ||
message = error_info.get("message", response.text) | ||
detailed_message = error_info.get("detail", "") | ||
|
||
raise Exception( | ||
f"Mathpix API request failed: {message} \n\n" | ||
f"Detailed: {detailed_message}" | ||
) | ||
|
||
prediction = {"result": text} | ||
predictions.append(prediction) | ||
|
||
return post_process_result(images, predictions) |
65 changes: 65 additions & 0 deletions
65
inference/core/workflows/core_steps/models/foundation/ocr/models/trocr.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from typing import Callable, List | ||
|
||
from inference.core.entities.requests.trocr import TrOCRInferenceRequest | ||
from inference.core.workflows.core_steps.common.entities import ( | ||
StepExecutionMode, | ||
) | ||
from inference.core.workflows.core_steps.common.utils import load_core_model | ||
from inference.core.workflows.execution_engine.entities.base import ( | ||
Batch, | ||
WorkflowImageData, | ||
) | ||
from inference.core.workflows.prototypes.block import BlockResult | ||
|
||
from .base import BaseOCRModel | ||
|
||
|
||
class TrOCRModel(BaseOCRModel): | ||
|
||
def run( | ||
self, | ||
images: Batch[WorkflowImageData], | ||
step_execution_mode: StepExecutionMode, | ||
post_process_result: Callable[ | ||
[Batch[WorkflowImageData], List[dict]], BlockResult | ||
], | ||
) -> BlockResult: | ||
if step_execution_mode is StepExecutionMode.LOCAL: | ||
return self.run_locally(images, post_process_result) | ||
elif step_execution_mode is StepExecutionMode.REMOTE: | ||
return self.run_remotely(images, post_process_result) | ||
|
||
def run_locally( | ||
self, | ||
images: Batch[WorkflowImageData], | ||
post_process_result: Callable[ | ||
[Batch[WorkflowImageData], List[dict]], BlockResult | ||
], | ||
) -> BlockResult: | ||
predictions = [] | ||
for single_image in images: | ||
inference_request = TrOCRInferenceRequest( | ||
image=single_image.to_inference_format(numpy_preferred=True), | ||
api_key=self.api_key, | ||
) | ||
trocr_model_id = load_core_model( | ||
model_manager=self.model_manager, | ||
inference_request=inference_request, | ||
core_model="trocr", | ||
) | ||
result = self.model_manager.infer_from_request_sync( | ||
trocr_model_id, inference_request | ||
) | ||
predictions.append(result.model_dump()) | ||
return post_process_result(images, predictions) | ||
|
||
def run_remotely( | ||
self, | ||
images: Batch[WorkflowImageData], | ||
post_process_result: Callable[ | ||
[Batch[WorkflowImageData], List[dict]], BlockResult | ||
], | ||
) -> BlockResult: | ||
raise NotImplementedError( | ||
"Remote execution is not implemented for TrOCRModel.", | ||
) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@EmilyGavrilenko added .dev image: https://github.com/roboflow/inference/blob/main/CONTRIBUTING.md