Skip to content

Commit

Permalink
add post process openai
Browse files: browse the repository at this point in the history
  • Loading branch information
mrzaizai2k committed Sep 13, 2024
1 parent a3b395b commit 4db5bbe
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 29 deletions.
23 changes: 4 additions & 19 deletions config/invoice_template.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
```json
{
"invoice_info": {
"amount": ,
Expand Down Expand Up @@ -59,7 +60,6 @@
"order_number": "",
"table_number": "",
"table_group": "",
"server": "",
"merchant_name": "",
"merchant_id": "",
"merchant_coc_number": "",
Expand All @@ -73,32 +73,17 @@
"merchant_website": "",
"merchant_email": "",
"merchant_address": "",
"merchant_street_name": "",
"merchant_house_number": "",
"merchant_city": "",
"merchant_municipality": "",
"merchant_province": "",
"merchant_country": "",
"merchant_country_code": "",
"merchant_phone": "",
"merchant_main_activity_code": "",
"customer_name": "",
"customer_number": "",
"customer_reference": "",
"customer_address": "",
"customer_street_name": "",
"customer_house_number": "",
"customer_city": "",
"customer_municipality": "",
"customer_province": "",
"customer_country": "",
"customer_phone": "",
"customer_website": "",
"customer_vat_number": "",
"customer_coc_number": "",
"customer_bank_account_number": "",
"customer_bank_account_number_bic": "",
"customer_website": "",
"customer_email": "",
"document_language": "",
},
}
}
```
4 changes: 3 additions & 1 deletion config/page_1_template.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
```json
{
"invoice_info": {
"name": "Tümmler, Dirk",
Expand Down Expand Up @@ -102,4 +103,5 @@
"sign_date": "13.08.2024",
"has_employee_signature": true
}
}
}
```
4 changes: 3 additions & 1 deletion config/page_2_template.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
```json
{
"invoice_info": {
"name": "Tümmler, Dirk",
Expand Down Expand Up @@ -39,4 +40,5 @@
],
"has_employee_signature": true
}
}
}
```
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ pyyaml
pymongo
uvicorn
pytesseract
rank_bm25
pydantic
25 changes: 22 additions & 3 deletions src/base_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def postprocess(self, invoice_template:str,

# Prepare the prompt for the LLM
prompt = f"""
You are a helpful assistant that responds in JSON format with the invoice information in English. Don't add any annotations there. Remember to close any bracket. And just output the field that has value, don't return field that are empty.
You are a helpful assistant that responds in JSON format with the invoice information in English. Don't add any annotations there.
Remember to close any bracket. And just output the field that has value, don't return field that are empty. return the key names as in the template is a MUST.
Use the text from the model response and the text from OCR. Describe what's in the image as the template here:
{invoice_template}.
The OCR text is: {ocr_text}
Expand Down Expand Up @@ -103,19 +104,35 @@ def __init__(self, config_path: str = "config/config.yaml"):
self.client = OpenAI(api_key=self.OPENAI_API_KEY)

# Send the OCR text and the base64-encoded image to the chat-completions
# client and return the text content of the first choice in the response.
# The model name, temperature and token limit are read from self.model,
# self.temperature and self.max_tokens.
def _extract_invoice_llm(self, ocr_text, base64_image:str, invoice_template:str):

response = self.client.chat.completions.create(
model=self.model,
messages=[
# System message: instructs the model to answer in JSON and omit empty fields.
{"role": "system", "content": "You are a helpful assistant that responds in JSON format with the invoice information in English. Don't add any annotations there. Remember to close any bracket. And just output the field that has value, don't return field that are empty."},
{"role": "user", "content": [
# NOTE(review): the next two text parts look like the before/after lines of
# this commit's diff (the first embeds invoice_template, the second does
# not) - confirm which one the committed file actually keeps. If only the
# second remains, invoice_template is unused by this method.
{"type": "text", "text": f"From the image of the bill and the text from OCR, extract the information. The ocr text is: {ocr_text} \n The invoice template: \n {invoice_template}"},
{"type": "text", "text": f"From the image of the bill and the text from OCR, extract the information. The ocr text is: {ocr_text} \n."},
# The image is passed inline as a data URL declared as image/png.
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}
]}
],
temperature=self.temperature,
max_tokens=self.max_tokens,
)
# Only the first choice's message content is returned, unparsed.
return response.choices[0].message.content

def postprocess(self, invoice_template: str,
                ocr_text: str, model_text: str) -> str:
    """Ask the chat model to fill the invoice template from two text sources.

    A single text-only request is sent through ``self.client`` using
    ``self.model``, ``self.temperature`` and ``self.max_tokens``.

    Args:
        invoice_template: Template text embedded verbatim in the prompt.
        ocr_text: OCR text embedded verbatim in the prompt.
        model_text: Earlier model output embedded verbatim in the prompt.

    Returns:
        The message content of the first choice in the response, unparsed.
    """
    system_message = {
        "role": "system",
        "content": "You are a helpful assistant that responds in JSON format with the invoice information in English as in invoice template. Don't add any annotations there. Remember to close any bracket. And just output the field that has value, don't return field that are empty.",
    }
    user_text = f"From the model text and the text from OCR, Fill in the value to the invoice template. The invoice template: \n {invoice_template} \n The ocr text is: {ocr_text} \n. The model text is {model_text}"
    user_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": user_text},
        ],
    }

    completion = self.client.chat.completions.create(
        model=self.model,
        messages=[system_message, user_message],
        temperature=self.temperature,
        max_tokens=self.max_tokens,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def extract_json(self, text: str) -> dict:
start_index = text.find('{')
Expand All @@ -125,11 +142,13 @@ def extract_json(self, text: str) -> dict:
result = eval(json_string)
return result

@timeit
@retry_on_failure(max_retries=3, delay=1.0)
def extract_invoice(self, ocr_text, image: Union[str, np.ndarray], invoice_template: str) -> dict:
    """Run the extraction pipeline for one invoice image.

    Steps, in order: encode the image with ``self.encode_image``, query the
    model with ``self._extract_invoice_llm``, refine that answer against the
    template with ``self.postprocess``, then parse it with
    ``self.extract_json``.

    Args:
        ocr_text: OCR text forwarded to both model calls.
        image: Image handed to ``self.encode_image``.
        invoice_template: Template text forwarded to both model calls.

    Returns:
        Whatever ``self.extract_json`` produces from the refined answer.
    """
    # NOTE(review): the decorators are defined elsewhere; presumably they time
    # the call and retry it on failure - confirm against their definitions.
    encoded_image = self.encode_image(image)
    raw_llm_answer = self._extract_invoice_llm(
        ocr_text,
        encoded_image,
        invoice_template=invoice_template,
    )
    refined_answer = self.postprocess(
        invoice_template=invoice_template,
        ocr_text=ocr_text,
        model_text=raw_llm_answer,
    )
    return self.extract_json(refined_answer)

Expand Down
21 changes: 17 additions & 4 deletions src/invoice_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def get_document_type(ocr_result: dict, config: dict) -> str:

def get_document_template(document_type: str, config: dict):
    """Return the template text registered for ``document_type``.

    The entry ``config['invoice_dict'][document_type]`` is passed to
    ``read_txt_file`` and the result is returned. The earlier version returned
    the configured entry directly, which left the file read unreachable.

    Args:
        document_type: Key into ``config['invoice_dict']``.
        config: Configuration mapping containing an ``'invoice_dict'`` entry.

    Returns:
        Whatever ``read_txt_file`` returns for the configured entry.

    Raises:
        KeyError: If ``'invoice_dict'`` or ``document_type`` is missing.
    """
    # presumably a path to a template text file - verify against config.yaml
    template_path = config['invoice_dict'][document_type]
    return read_txt_file(template_path)

def extract_invoice_info(base64_img:str, ocr_reader:OcrReader, invoice_extractor:BaseExtractor, config:dict) -> dict:
result = {}
Expand All @@ -108,29 +109,41 @@ def extract_invoice_info(base64_img:str, ocr_reader:OcrReader, invoice_extractor
rotate_image = ocr_reader.get_rotated_image(pil_img)
invoice_info = invoice_extractor.extract_invoice(ocr_text=ocr_result['text'], image=rotate_image,
invoice_template=invoice_template)

invoice_info = validate_invoice(invoice_info, invoice_type)
result['translator'] = ocr_reader['translator']
result['ocr_detector'] = ocr_reader['ocr_detector']
result['invoice_info'] = invoice_info
result['invoice_type'] = invoice_type
result['ocr_info'] = ocr_result
result['llm_extractor'] = invoice_extractor['llm_extractor']
result['post_processor'] = invoice_extractor['post_processor']

result["last_modified_at"] = get_current_time(timezone=config['timezone'])
result["status"] = "completed"

return result



def validate_invoice(invoice_info: dict, invoice_type: str) -> dict:
    """Return ``invoice_info`` unchanged; placeholder for per-type validation.

    Args:
        invoice_info: Extracted invoice data.
        invoice_type: Document type label; currently not used to alter the
            result.

    Returns:
        The same ``invoice_info`` object that was passed in.
    """
    # The previous body special-cased "invoice 3" but returned the same object
    # on both paths, so the branch was a no-op. Type-specific rules belong
    # here once they exist.
    return invoice_info



# Script entry point: load the YAML config, build the OCR reader and the
# OpenAI extractor, run the extraction on one sample image and print the result.
if __name__ == "__main__":
config_path = "config/config.yaml"
config = read_config(config_path)

ocr_reader = OcrReader(config_path=config_path, translator=GoogleTranslator())
invoice_extractor = OpenAIExtractor(config_path=config_path)
# NOTE(review): the two img_path assignments look like the before/after lines
# of this commit's diff; as written only the second value is used.
img_path = "test/images/page_9.png"
img_path = "test/images/page_6.png"
base64_img = convert_img_path_to_base64(img_path)
result = extract_invoice_info(base64_img=base64_img, ocr_reader=ocr_reader,
invoice_extractor=invoice_extractor, config=config)
# NOTE(review): the two print calls also look like a before/after diff pair -
# confirm whether the full result or only result['invoice_info'] is printed.
print("info", result)
print("info", result['invoice_info'])


Empty file added src/validate_invoice.py
Empty file.

0 comments on commit 4db5bbe

Please sign in to comment.