Merge pull request #28 from echo840/main
add_ocrbench
Luodian authored Mar 25, 2024
2 parents d7b207f + e00d0ca commit 9dfb53a
Showing 4 changed files with 221 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lmms_eval/tasks/_task_utils/file_utils.py
@@ -1,8 +1,8 @@
 import os
 
 
-def generate_submission_file(file_name, args):
-    path = os.path.join(args.output_path, "submissions")
+def generate_submission_file(file_name, args, subpath="submissions"):
+    path = os.path.join(args.output_path, subpath)
     os.makedirs(path, exist_ok=True)
     path = os.path.join(path, file_name)
     return os.path.abspath(path)
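
The only change to the shared helper is the new optional subpath argument, which lets a task write its output somewhere other than the default submissions/ folder; the OCRBench aggregation below passes subpath="results". A minimal usage sketch (the argparse.Namespace stand-in for the parsed lmms-eval arguments is illustrative):

import argparse

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

# Illustrative stand-in for the parsed CLI arguments; only output_path is read here.
args = argparse.Namespace(output_path="./logs")

# Default behaviour is unchanged: the file lands under <output_path>/submissions/.
print(generate_submission_file("foo.json", args))

# New keyword: OCRBench writes its report under <output_path>/results/ instead.
print(generate_submission_file("ocrbench_results.txt", args, subpath="results"))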
22 changes: 22 additions & 0 deletions lmms_eval/tasks/ocrbench/ocrbench.yaml
@@ -0,0 +1,22 @@
dataset_path: echo840/OCRBench
dataset_kwargs:
  token: True
task: "ocrbench"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.ocrbench_doc_to_visual
doc_to_text: !function utils.ocrbench_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 128
  temperature: 0
  top_p: 0
  num_beams: 1
  do_sample: false
process_results: !function utils.ocrbench_process_results
metric_list:
  - metric: ocrbench_accuracy
    aggregation: !function utils.ocrbench_aggregate_accuracy
    higher_is_better: true
metadata:
  - version: 0.0
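
With this YAML in place under lmms_eval/tasks/ocrbench/, the !function tags resolve to the helpers in the sibling utils.py, and the task should be selectable by name like any other lmms-eval task. A hedged example invocation (model choice and flags are illustrative and may differ across versions; consult the repository README for the exact CLI):

python -m lmms_eval --model llava \
    --model_args pretrained="liuhaotian/llava-v1.5-7b" \
    --tasks ocrbench \
    --batch_size 1 \
    --log_samples \
    --output_path ./logs/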
94 changes: 94 additions & 0 deletions lmms_eval/tasks/ocrbench/upload_ocrbench.py
@@ -0,0 +1,94 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import io
import json
from uuid import uuid4

import datasets
import pandas as pd
from datasets import Dataset, Features
from PIL import Image as PIL_Image
from tqdm import tqdm

# Find for instance the citation on arxiv or on the dataset repo/website
_CITATION = """https://arxiv.org/abs/2305.07895"""
_DESCRIPTION = "OCRBench is a comprehensive evaluation benchmark designed to assess the OCR capabilities of Large Multimodal Models."


def image2byte(image):
    img_bytes = io.BytesIO()
    image.save(img_bytes, format="JPEG")
    image_bytes = img_bytes.getvalue()
    return image_bytes


def get_builder_config(VERSION):
    builder_config = [
        datasets.BuilderConfig(
            name=f"ocrbench",
            version=VERSION,
            description=f"ocrbench",
        )
    ]
    return builder_config


ocrbench_json = "pathto/OCRBench/OCRBench.json"
img_dir = "pathto/OCRBench_Images/"

dataset_features = Features(
    {
        "dataset": datasets.Value("string"),
        "question": datasets.Value("string"),
        "question_type": datasets.Value("string"),
        "answer": datasets.features.Sequence(datasets.Value("string")),
        "image": datasets.Image(),
    }
)

df_items = {
    "dataset": [],
    "question": [],
    "question_type": [],
    "answer": [],
    "image": [],
}
# img_feature = datasets.Image(decode=False)
with open(ocrbench_json, "r") as f:
    data = json.load(f)
for i in tqdm(range(len(data))):
    dataset_name = data[i]["dataset_name"]
    image_path = img_dir + data[i]["image_path"]
    question = data[i]["question"]
    answers = data[i]["answers"]
    question_type = data[i]["type"]
    if type(answers) == str:
        answers = [answers]
    img = PIL_Image.open(image_path).convert("RGB")
    byte_data = image2byte(img)
    image = {"bytes": byte_data, "path": ""}
    df_items["image"].append(image)
    df_items["question"].append(str(question))
    df_items["answer"].append(answers)
    df_items["question_type"].append(str(question_type))
    df_items["dataset"].append(str(dataset_name))

df_items = pd.DataFrame(df_items)
df_items.head()
dataset = Dataset.from_pandas(df_items, features=dataset_features)
hub_dataset_path = "echo840/OCRBench"
dataset.push_to_hub(repo_id=hub_dataset_path, split="test")
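
The script above is a one-off uploader: it reads the local OCRBench.json annotations, bundles each image as JPEG bytes, and pushes the result to the echo840/OCRBench Hub repository that ocrbench.yaml points at. A quick way to sanity-check the pushed data (a sketch; field names follow the Features block above):

from datasets import load_dataset

# Split name matches test_split in ocrbench.yaml.
ds = load_dataset("echo840/OCRBench", split="test")

sample = ds[0]
# Features defined above: dataset, question, question_type, answer (sequence of strings), image (decoded PIL image).
print(sample["dataset"], sample["question_type"])
print(sample["question"], sample["answer"])
print(sample["image"].size)  # the task's doc_to_visual converts this image to RGB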
103 changes: 103 additions & 0 deletions lmms_eval/tasks/ocrbench/utils.py
@@ -0,0 +1,103 @@
import logging

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

logger = logging.getLogger("lmms-eval")

# Add the following functions to your existing utils.py file
OCRBench_score = {
    "Regular Text Recognition": 0,
    "Irregular Text Recognition": 0,
    "Artistic Text Recognition": 0,
    "Handwriting Recognition": 0,
    "Digit String Recognition": 0,
    "Non-Semantic Text Recognition": 0,
    "Scene Text-centric VQA": 0,
    "Doc-oriented VQA": 0,
    "Key Information Extraction": 0,
    "Handwritten Mathematical Expression Recognition": 0,
}


def ocrbench_doc_to_visual(doc):
    # Assuming the 'doc' dictionary has a key 'image' with image data
    return [doc["image"].convert("RGB")]


def ocrbench_doc_to_text(doc):
    # Assuming the 'doc' dictionary has a key 'question' with the question text
    question = doc["question"].strip()
    return f"{question}"


def ocrbench_process_results(doc, results):
    pred = results[0].lower().strip()
    gt_ans = doc["answer"]
    dataset_name = doc["dataset"]

    score = 0
    if dataset_name == "HME100k":
        if type(gt_ans) == list:
            for j in range(len(gt_ans)):
                answer = gt_ans[j].strip().replace("\n", " ").replace(" ", "")
                predict = pred.strip().replace("\n", " ").replace(" ", "")
                if answer in predict:
                    score = 1
        else:
            answer = gt_ans.strip().replace("\n", " ").replace(" ", "")
            predict = pred.strip().replace("\n", " ").replace(" ", "")
            if answer in predict:
                score = 1
    else:
        if type(gt_ans) == list:
            for j in range(len(gt_ans)):
                answer = gt_ans[j].lower().strip().replace("\n", " ")
                predict = pred.lower().strip().replace("\n", " ")
                if answer in predict:
                    score = 1
        else:
            answer = gt_ans.lower().strip().replace("\n", " ")
            predict = pred.lower().strip().replace("\n", " ")
            if answer in predict:
                score = 1
    return {
        "ocrbench_accuracy": {"question_type": doc["question_type"], "score": score, "prediction": pred, "ground_truth": gt_ans},
    }


def ocrbench_aggregate_accuracy(results, args):
    for result in results:
        OCRBench_score[result["question_type"]] += result["score"]
    recognition_score = (
        OCRBench_score["Regular Text Recognition"]
        + OCRBench_score["Irregular Text Recognition"]
        + OCRBench_score["Artistic Text Recognition"]
        + OCRBench_score["Handwriting Recognition"]
        + OCRBench_score["Digit String Recognition"]
        + OCRBench_score["Non-Semantic Text Recognition"]
    )
    Final_score = recognition_score + OCRBench_score["Scene Text-centric VQA"] + OCRBench_score["Doc-oriented VQA"] + OCRBench_score["Key Information Extraction"] + OCRBench_score["Handwritten Mathematical Expression Recognition"]
    file_name = generate_submission_file("ocrbench_results.txt", args, subpath="results")
    with open(file_name, "w") as f:
        print("######################### OCRBench #############################", file=f)
        print(f"Text Recognition(Total 300): {recognition_score}", file=f)
        print("---------------- Details of Recognition Score ------------------", file=f)
        print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}", file=f)
        print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}", file=f)
        print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}", file=f)
        print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}", file=f)
        print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}", file=f)
        print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}", file=f)
print("----------------------------------------------------------------")
print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}", file=f)
print("--------------------- Final Score ------------------------------", file=f)
print(f"Final Score(Total 1000): {Final_score}", file=f)
logger.info(f"OCR Bench results saved to {file_name}")
# return {"Final Score":Final_score,"Text Recognition":recognition_score,'Scene Text-centric VQA':OCRBench_score['Scene Text-centric VQA'],'Doc-oriented VQA':OCRBench_score['Doc-oriented VQA'],'Key Information Extraction':OCRBench_score['Key Information Extraction'],'Handwritten Mathematical Expression Recognition':OCRBench_score['Handwritten Mathematical Expression Recognition']}
return Final_score
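
Scoring in ocrbench_process_results is containment-based: a prediction counts as correct if any ground-truth string appears as a substring of the lower-cased prediction, with all whitespace removed for the HME100k formula subset. A hypothetical doc/prediction pair (not taken from the dataset; the subset name is only illustrative) showing the non-HME100k branch:

doc = {
    "dataset": "IIIT5K",  # illustrative subset name; anything other than "HME100k" uses the case-insensitive branch
    "question": "What is written in the image?",
    "question_type": "Regular Text Recognition",
    "answer": ["coffee"],
}
result = ocrbench_process_results(doc, ["The sign reads COFFEE."])
print(result["ocrbench_accuracy"]["score"])  # 1, because "coffee" is contained in the lower-cased prediction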
