Skip to content

Commit

Permalink
feat(Proofs): ML: receipt content extraction (#715)
Browse files Browse the repository at this point in the history
  • Loading branch information
TTalex authored Feb 22, 2025
1 parent d05320f commit 11dad87
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 3 deletions.
88 changes: 87 additions & 1 deletion open_prices/proofs/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,20 @@ class Labels(typing.TypedDict):
labels: list[Label]


# Schema of a single item of a receipt; used as the element type of
# `Receipt.items`.
class ReceiptItem(typing.TypedDict):
    # `Products` is declared elsewhere in this module
    product: Products
    # numeric price of the item; currency/unit is not specified by this
    # schema -- TODO confirm against the consumer of the prediction
    price: float
    # presumably the product name as printed on the receipt -- confirm
    product_name: str


# Structured content of one receipt. This class is passed as the
# `response_schema` of the Gemini request in `extract_from_receipt`, whose
# prompt asks the model to use empty strings for unknown values.
class Receipt(typing.TypedDict):
    store_name: str
    store_address: str
    store_city_name: str
    # kept as a plain string: no date format is enforced by this schema
    date: str
    # the items extracted from the receipt
    items: list[ReceiptItem]


def extract_from_price_tag(image: Image.Image) -> Label:
"""Extract price tag information from an image.
Expand Down Expand Up @@ -258,6 +272,28 @@ def extract_from_price_tags(images: Image.Image) -> Labels:
return json.loads(response.text)


def extract_from_receipt(image: Image.Image) -> Receipt:
    """Extract receipt information from an image.

    The image is downscaled so that neither side exceeds 1024 pixels, then
    sent to the Gemini model together with a short instruction. The model is
    asked to answer with JSON following the ``Receipt`` schema.

    :param image: the receipt image to run the extraction on; it is never
        modified, resizing happens on a copy
    :return: the decoded JSON response of the model
    :raises json.JSONDecodeError: if the response text is not valid JSON
    """
    # Gemini model max payload size is 20MB
    # To prevent the payload from being too large, we resize the images before
    # upload
    max_size = 1024
    if image.width > max_size or image.height > max_size:
        # thumbnail() resizes in place, so work on a copy to leave the
        # caller's image untouched
        image = image.copy()
        image.thumbnail((max_size, max_size))

    response = model.generate_content(
        [
            "Extract all relevant information, use empty strings for unknown values.",
            image,
        ],
        generation_config=genai.GenerationConfig(
            response_mime_type="application/json", response_schema=Receipt
        ),
    )
    return json.loads(response.text)


def predict_proof_type(
image: Image.Image,
model_name: str = PROOF_CLASSIFICATION_MODEL_NAME,
Expand Down Expand Up @@ -654,15 +690,63 @@ def run_and_save_proof_type_prediction(
)


def run_and_save_receipt_extraction_prediction(
    image: Image.Image, proof: Proof, overwrite: bool = False
) -> ProofPrediction | None:
    """Run the receipt extraction model and save the prediction in
    ProofPrediction table.

    :param image: the image to run the model on
    :param proof: the Proof instance to associate the ProofPrediction with
    :param overwrite: whether to overwrite existing prediction, defaults to
        False
    :return: the ProofPrediction instance created, or None if the proof is
        not of type RECEIPT, or if the prediction already exists and
        overwrite is False
    """
    if proof.type != proof_constants.TYPE_RECEIPT:
        logger.debug("Skipping proof %s, not of type RECEIPT", proof.id)
        return None

    # NOTE(review): existing predictions are looked up by model name only,
    # not by prediction type -- confirm no other ProofPrediction type is
    # produced with GEMINI_MODEL_NAME, otherwise it would be matched here too.
    existing_predictions = ProofPrediction.objects.filter(
        proof=proof, model_name=GEMINI_MODEL_NAME
    )
    if existing_predictions.exists():
        if not overwrite:
            logger.debug(
                "Proof %s already has a prediction for model %s",
                proof.id,
                GEMINI_MODEL_NAME,
            )
            return None
        logger.info(
            "Overwriting existing receipt extraction prediction for proof %s",
            proof.id,
        )
        existing_predictions.delete()

    prediction = extract_from_receipt(image)

    return ProofPrediction.objects.create(
        proof=proof,
        type=proof_constants.PROOF_PREDICTION_RECEIPT_EXTRACTION_TYPE,
        model_name=GEMINI_MODEL_NAME,
        model_version=GEMINI_MODEL_VERSION,
        data=prediction,
    )


def run_and_save_proof_prediction(
proof: Proof, run_price_tag_extraction: bool = True
proof: Proof,
run_price_tag_extraction: bool = True,
run_receipt_extraction: bool = True,
) -> None:
"""Run all ML models on a specific proof, and save the predictions in DB.
Currently, the following models are run:
- proof type classification model
- price tag detection model (object detector)
- price tag extraction model
- receipt extraction model
:param proof: the Proof instance to run the models on
:param run_price_tag_extraction: whether to run the price tag extraction
Expand All @@ -683,3 +767,5 @@ def run_and_save_proof_prediction(
run_and_save_price_tag_detection(
image, proof, run_extraction=run_price_tag_extraction
)
if run_receipt_extraction:
run_and_save_receipt_extraction_prediction(image, proof)
12 changes: 10 additions & 2 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,11 @@ def test_run_and_save_proof_prediction_for_receipt_proof(self):
return_value=None,
) as mock_detect_price_tags,
):
run_and_save_proof_prediction(proof, run_price_tag_extraction=False)
run_and_save_proof_prediction(
proof,
run_price_tag_extraction=False,
run_receipt_extraction=False,
)
mock_predict_proof_type.assert_called_once()
mock_detect_price_tags.assert_not_called()

Expand Down Expand Up @@ -513,7 +517,11 @@ def test_run_and_save_proof_prediction_for_price_tag_proof(self):
return_value=detect_price_tags_response,
) as mock_detect_price_tags,
):
run_and_save_proof_prediction(proof, run_price_tag_extraction=False)
run_and_save_proof_prediction(
proof,
run_price_tag_extraction=False,
run_receipt_extraction=False,
)
mock_predict_proof_type.assert_called_once()
mock_detect_price_tags.assert_called_once()

Expand Down

0 comments on commit 11dad87

Please sign in to comment.