-
Notifications
You must be signed in to change notification settings - Fork 202
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #28 from echo840/main
add_ocrbench
- Loading branch information
Showing
4 changed files
with
221 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
import os | ||
|
||
|
||
def generate_submission_file(file_name, args, subpath="submissions"):
    """Build an absolute path for an output artifact under ``args.output_path``.

    Args:
        file_name: Name of the file to create (e.g. ``"ocrbench_results.txt"``).
        args: Namespace-like object carrying an ``output_path`` attribute
            (root directory for all evaluation outputs).
        subpath: Sub-directory under ``output_path``. Defaults to
            ``"submissions"`` for backward compatibility; callers may override
            (e.g. ``"results"``).

    Returns:
        Absolute path of ``file_name`` inside ``output_path/subpath``. The
        directory is created if missing; the file itself is NOT created.
    """
    path = os.path.join(args.output_path, subpath)
    # exist_ok avoids a race/failure when several tasks share the directory.
    os.makedirs(path, exist_ok=True)
    return os.path.abspath(os.path.join(path, file_name))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# lmms-eval task config for OCRBench (OCR benchmark for large multimodal models).
dataset_path: echo840/OCRBench  # Hugging Face Hub dataset id
dataset_kwargs:
  token: True  # forward the user's HF auth token when downloading — TODO confirm the repo requires it
task: "ocrbench"
test_split: test  # the Hub dataset is pushed with split="test" (see upload script)
output_type: generate_until
# Hooks resolved from the task's utils.py module.
doc_to_visual: !function utils.ocrbench_doc_to_visual
doc_to_text: !function utils.ocrbench_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 128
  temperature: 0  # deterministic decoding
  top_p: 0
  num_beams: 1
  do_sample: false
process_results: !function utils.ocrbench_process_results
metric_list:
  - metric: ocrbench_accuracy
    aggregation: !function utils.ocrbench_aggregate_accuracy
    higher_is_better: true
metadata:
  - version: 0.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import json | ||
|
||
import datasets | ||
from PIL import Image as PIL_Image | ||
import json | ||
from uuid import uuid4 | ||
from datasets import Dataset, Features | ||
import pandas as pd | ||
from tqdm import tqdm | ||
import io | ||
|
||
# Find for instance the citation on arxiv or on the dataset repo/website | ||
_CITATION = """https://arxiv.org/abs/2305.07895""" | ||
_DESCRIPTION = "OCRBench is a comprehensive evaluation benchmark designed to assess the OCR capabilities of Large Multimodal Models." | ||
|
||
|
||
def image2byte(image):
    """Serialize a PIL image to raw JPEG bytes via an in-memory buffer."""
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return buffer.getvalue()
|
||
|
||
def get_builder_config(VERSION):
    """Return the dataset's BuilderConfig list: a single "ocrbench" config at VERSION."""
    return [
        datasets.BuilderConfig(
            name="ocrbench",
            version=VERSION,
            description="ocrbench",
        )
    ]
|
||
|
||
# --- One-off script: package the local OCRBench release as a HF Hub dataset ---
# Local paths to the OCRBench release; replace "pathto" before running.
ocrbench_json = "pathto/OCRBench/OCRBench.json"
img_dir = "pathto/OCRBench_Images/"

# Schema of the uploaded dataset (must match what utils.py reads back:
# "image", "question", "question_type", "answer", "dataset").
dataset_features = Features(
    {
        "dataset": datasets.Value("string"),  # source dataset name (e.g. "HME100k")
        "question": datasets.Value("string"),
        "question_type": datasets.Value("string"),  # one of the ten OCRBench categories
        "answer": datasets.features.Sequence(datasets.Value("string")),  # one or more acceptable answers
        "image": datasets.Image(),
    }
)

# Column-oriented accumulator, converted to a DataFrame below.
df_items = {
    "dataset": [],
    "question": [],
    "question_type": [],
    "answer": [],
    "image": [],
}
# img_feature = datasets.Image(decode=False)
with open(ocrbench_json, "r") as f:
    data = json.load(f)
# NOTE(review): original indentation lost in the diff; assuming the loop sits
# after the `with` block — confirm against the source file.
for i in tqdm(range(len(data))):
    dataset_name = data[i]["dataset_name"]
    image_path = img_dir + data[i]["image_path"]
    question = data[i]["question"]
    answers = data[i]["answers"]
    question_type = data[i]["type"]
    # Normalize single-string answers to a list to match the Sequence feature.
    if type(answers) == str:
        answers = [answers]
    # Re-encode every image as RGB JPEG bytes so the upload is self-contained.
    img = PIL_Image.open(image_path).convert("RGB")
    byte_data = image2byte(img)
    image = {"bytes": byte_data, "path": ""}
    df_items["image"].append(image)
    df_items["question"].append(str(question))
    df_items["answer"].append(answers)
    df_items["question_type"].append(str(question_type))
    df_items["dataset"].append(str(dataset_name))

df_items = pd.DataFrame(df_items)
df_items.head()  # no-op: result discarded (leftover from notebook-style debugging)
dataset = Dataset.from_pandas(df_items, features=dataset_features)
hub_dataset_path = "echo840/OCRBench"
dataset.push_to_hub(repo_id=hub_dataset_path, split="test")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import logging | ||
|
||
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file | ||
|
||
logger = logging.getLogger("lmms-eval") | ||
|
||
# Add the following functions to your existing utils.py file
# The ten OCRBench question-type categories, each starting at a score of 0.
# Keys must match the "question_type" values stored in the dataset; used by
# ocrbench_aggregate_accuracy when tallying per-category scores.
OCRBench_score = {
    "Regular Text Recognition": 0,
    "Irregular Text Recognition": 0,
    "Artistic Text Recognition": 0,
    "Handwriting Recognition": 0,
    "Digit String Recognition": 0,
    "Non-Semantic Text Recognition": 0,
    "Scene Text-centric VQA": 0,
    "Doc-oriented VQA": 0,
    "Key Information Extraction": 0,
    "Handwritten Mathematical Expression Recognition": 0,
}
|
||
|
||
def ocrbench_doc_to_visual(doc):
    """Return the sample's image, converted to RGB, as a single-element list."""
    image = doc["image"]
    return [image.convert("RGB")]
|
||
|
||
def ocrbench_doc_to_text(doc):
    """Return the sample's question text with surrounding whitespace stripped."""
    return doc["question"].strip()
|
||
|
||
def ocrbench_process_results(doc, results):
    """Score a single OCRBench prediction by substring match against the answers.

    Args:
        doc: Dataset sample with "answer" (str or list of str), "dataset"
            (source dataset name) and "question_type" keys.
        results: Model outputs; only results[0] is scored.

    Returns:
        Dict keyed by the metric name, carrying question_type, the 0/1 score,
        the normalized prediction, and the ground truth for aggregation.
    """
    pred = results[0].lower().strip()
    gt_ans = doc["answer"]
    dataset_name = doc["dataset"]

    # Original code duplicated the whole match loop for str vs list ground
    # truths; normalize to a list once instead.
    answers = gt_ans if isinstance(gt_ans, list) else [gt_ans]

    score = 0
    if dataset_name == "HME100k":
        # Math expressions: compare with ALL whitespace removed. Note the
        # ground truth is NOT lowercased here (pred already is) — preserved
        # from the original scoring logic.
        predict = pred.replace("\n", " ").replace(" ", "")
        for answer in answers:
            if answer.strip().replace("\n", " ").replace(" ", "") in predict:
                score = 1
                break
    else:
        # Text/VQA: case-insensitive substring match with newlines flattened.
        predict = pred.replace("\n", " ")
        for answer in answers:
            if answer.lower().strip().replace("\n", " ") in predict:
                score = 1
                break
    return {
        "ocrbench_accuracy": {"question_type": doc["question_type"], "score": score, "prediction": pred, "ground_truth": gt_ans},
    }
|
||
|
||
def ocrbench_aggregate_accuracy(results, args):
    """Aggregate per-sample OCRBench scores and write a category breakdown report.

    Args:
        results: Per-sample dicts from ocrbench_process_results, each with a
            "question_type" (a key of OCRBench_score) and a 0/1 "score".
        args: Namespace forwarded to generate_submission_file (carries
            output_path); the report goes to <output_path>/results/.

    Returns:
        The final score — the sum across all ten categories (max 1000).
    """
    # Tally into a fresh local dict. The original mutated the module-level
    # OCRBench_score in place, so scores accumulated across repeated
    # evaluations in the same process.
    scores = dict.fromkeys(OCRBench_score, 0)
    for result in results:
        scores[result["question_type"]] += result["score"]
    recognition_score = (
        scores["Regular Text Recognition"]
        + scores["Irregular Text Recognition"]
        + scores["Artistic Text Recognition"]
        + scores["Handwriting Recognition"]
        + scores["Digit String Recognition"]
        + scores["Non-Semantic Text Recognition"]
    )
    final_score = recognition_score + scores["Scene Text-centric VQA"] + scores["Doc-oriented VQA"] + scores["Key Information Extraction"] + scores["Handwritten Mathematical Expression Recognition"]
    file_name = generate_submission_file("ocrbench_results.txt", args, subpath="results")
    with open(file_name, "w") as f:
        print("######################### OCRBench #############################", file=f)
        print(f"Text Recognition(Total 300): {recognition_score}", file=f)
        print("---------------- Details of Recognition Score ------------------", file=f)
        print(f"Regular Text Recognition(Total 50): {scores['Regular Text Recognition']}", file=f)
        print(f"Irregular Text Recognition(Total 50): {scores['Irregular Text Recognition']}", file=f)
        print(f"Artistic Text Recognition(Total 50): {scores['Artistic Text Recognition']}", file=f)
        print(f"Handwriting Recognition(Total 50): {scores['Handwriting Recognition']}", file=f)
        print(f"Digit String Recognition(Total 50): {scores['Digit String Recognition']}", file=f)
        print(f"Non-Semantic Text Recognition(Total 50): {scores['Non-Semantic Text Recognition']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Scene Text-centric VQA(Total 200): {scores['Scene Text-centric VQA']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Doc-oriented VQA(Total 200): {scores['Doc-oriented VQA']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Key Information Extraction(Total 200): {scores['Key Information Extraction']}", file=f)
        # Fixed: this separator was missing file=f and leaked to stdout.
        print("----------------------------------------------------------------", file=f)
        print(f"Handwritten Mathematical Expression Recognition(Total 100): {scores['Handwritten Mathematical Expression Recognition']}", file=f)
        print("--------------------- Final Score ------------------------------", file=f)
        print(f"Final Score(Total 1000): {final_score}", file=f)
    logger.info(f"OCR Bench results saved to {file_name}")
    return final_score