Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon7th No51】add the MixTex model into PaddleOCR #14417

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions configs/rec/MixTex.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
Global:
use_gpu: False
epoch_num: 2
log_smooth_window: 20 # for logging metrics during training procedure
print_batch_step: 10
save_model_dir: ./output/MixTex
save_epoch_step: 5
max_seq_len: 768
eval_batch_step: [0, 5]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/datasets/pme_demo/0000013.png
infer_mode: False
use_space_char: False
rec_char_dict_path: ./ppocr/utils/dict/mixtex
save_res_path: ./output/rec/predicts_mixtex.txt
d2s_train_image_shape: [3, 400, 500]


Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs : [3]
values : [0.0005, 0.00005]
warmup_epoch: 1
regularizer:
name: L2
factor: 3.0e-05

Architecture:
model_type: rec
algorithm: MixTex
in_channels: 3
Transform:
Backbone:
name: SwinTransformer_tiny_patch4_window7_224
img_size: 224
patch_size: 4
num_classes: 25678 # number of classes in the vocabulary
input_channel:
is_predict: False
is_export: False
Head:
name: RobertHead
pad_value: 1
is_export: False
decoder_args:
vocab_size: 25681
cross_attend: True
rel_pos_bias: False
use_scalenorm: False
attention_probs_dropout_prob: 0.1
bos_token_id: 0
chunk_size_feed_forward: 0
diversity_penalty: 0.0
do_sample: False
eos_token_id: 2
hidden_act: gelu
hidden_dropout_prob: 0.1
hidden_size: 768
max_position_embeddings: 770
# max_position_embeddings: 1000
num_attention_heads: 12
num_hidden_layers: 4
pad_token_id: 1
temperature: 1.0
# tie_word_embeddings: True
top_k: 50
top_p: 1.0
intermediate_size: 3072
type_vocab_size: 1
initializer_range: 0.02


Loss:
name: MixTexLoss

PostProcess:
name: MixTexDecode
rec_char_dict_path: ./ppocr/utils/dict/mixtex

Metric:
name: MixTexMetric
main_indicator: exp_rate
cal_blue_score: True

Train:
dataset:
name: MixTexDataSet
data_dir: D:/study/dl/MixTex/data/Pseudo-Latext-ZhEn
batch_size_per_pair: 24
transforms:
- RecResizeImg:
image_shape: [3, 400, 500]
- RescaleImage:
scale: 0.00392156862745098
- KeepKeys:
keep_keys: ['image']
loader:
shuffle: True
batch_size_per_card: 10
drop_last: False
num_workers: 0
collate_fn: MixTexCollator

Eval:
dataset:
name: MixTexDataSet
data_dir: D:/study/dl/MixTex/data/Pseudo-Latext-ZhEn
data:
batch_size_per_pair: 24
transforms:
- RecResizeImg:
image_shape: [3, 400, 500]
- RescaleImage:
scale: 0.00392156862745098
- KeepKeys:
keep_keys: ['image']
loader:
shuffle: True
batch_size_per_card: 10
drop_last: False
num_workers: 0
2 changes: 2 additions & 0 deletions ppocr/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ppocr.data.pubtab_dataset import PubTabDataSet
from ppocr.data.multi_scale_sampler import MultiScaleSampler
from ppocr.data.latexocr_dataset import LaTeXOCRDataSet
from ppocr.data.mixtex_dataset import MixTexDataSet

# for PaddleX dataset_type
TextDetDataset = SimpleDataSet
Expand Down Expand Up @@ -97,6 +98,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
"PubTabTableRecDataset",
"KieDataset",
"LaTeXOCRDataSet",
"MixTexDataSet",
]
module_name = config[mode]["dataset"]["name"]
assert module_name in support_dict, Exception(
Expand Down
14 changes: 14 additions & 0 deletions ppocr/data/collate_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ def __call__(self, batch):
return images, labels, attention_mask


class MixTexCollator(object):
    """
    Collate function for MixTex.

    Only the first entry of ``batch`` is consumed. It is unpacked as an
    ``(images, labels)`` pair and returned unchanged; any further entries
    of ``batch`` are not looked at.

    Layout of that first entry, as described by the original author:
        image [batch_size, channel, maxHinbatch, maxWinbatch]
        label [batch_size, maxLabelLen]
    """

    def __call__(self, batch):
        first_entry = batch[0]
        # Unpacking enforces that the entry holds exactly two items.
        images, labels = first_entry
        return (images, labels)


class UniMERNetCollator(object):
"""
batch: [
Expand Down
1 change: 1 addition & 0 deletions ppocr/data/imaug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
RFLRecResizeImg,
SVTRRecAug,
ParseQRecAug,
RescaleImage,
)
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
Expand Down
20 changes: 20 additions & 0 deletions ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,7 @@ def encode(
)
for encoding in encodings
]

sanitized_tokens = {}
for key in tokens_and_encodings[0][0].keys():
stack = [e for item, _ in tokens_and_encodings for e in item[key]]
Expand Down Expand Up @@ -2191,3 +2192,22 @@ def __call__(self, data):
data["label"] = np.array(topk["input_ids"]).astype(np.int64)[0]
data["attention_mask"] = np.array(topk["attention_mask"]).astype(np.int64)[0]
return data


class MixTexLabelEncode:
    """
    Encode target text into token ids with a RoBERTa tokenizer.

    Args:
        rec_char_dict_path: passed to ``RobertaTokenizer.from_pretrained`` as
            ``pretrained_model_name_or_path``.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        rec_char_dict_path,
        **kwargs,
    ):
        # Imported inside __init__ rather than at module level.
        from paddlenlp.transformers.roberta.tokenizer import RobertaTokenizer

        self.tokenizer = RobertaTokenizer.from_pretrained(
            pretrained_model_name_or_path=rec_char_dict_path
        )

    def __call__(
        self, target_text, padding="max_length", max_length=256, truncation=True
    ):
        """
        Tokenize ``target_text`` and return the resulting ``input_ids``.

        Args:
            target_text: text handed to the tokenizer.
            padding: padding strategy forwarded to the tokenizer.
            max_length: maximum sequence length forwarded to the tokenizer.
            truncation: truncation flag forwarded to the tokenizer.

        Returns:
            The ``input_ids`` attribute of the tokenizer output.
        """
        # Forward the options by keyword: passed positionally they would be
        # bound to whatever parameters follow `text` in the tokenizer's
        # signature (e.g. `text_pair`) instead of the intended options.
        encoded = self.tokenizer(
            target_text,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
        )
        return encoded.input_ids
13 changes: 13 additions & 0 deletions ppocr/data/imaug/rec_img_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,19 @@ def __call__(self, data):
return data


class RescaleImage(object):
    """
    Multiply ``data["image"]`` by a constant factor and cast the result.

    Args:
        scale: factor the image is multiplied by.
        dtype: dtype the rescaled image is cast to (default ``np.float32``).
        **kwargs: accepted and ignored.
    """

    def __init__(self, scale, dtype=np.float32, **kwargs):
        self.scale = scale
        # Honour the caller-supplied dtype; it used to be overwritten with
        # np.float32 regardless of the argument.
        self.dtype = dtype

    def __call__(self, data):
        """Return a new dict whose only key, "image", holds the rescaled image."""
        img = data["image"]
        rescaled_image = (img * self.scale).astype(self.dtype)
        # A fresh dict is returned; other keys of `data` are not carried over.
        return {"image": rescaled_image}


def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
Expand Down
78 changes: 78 additions & 0 deletions ppocr/data/mixtex_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from datasets import load_dataset

import paddle
from paddle.io import Dataset
from .imaug.label_ops import MixTexLabelEncode
from .imaug import transform, create_operators

from paddlenlp.transformers.roberta.tokenizer import RobertaTokenizer


class MixTexDataSet(Dataset):
    """
    Dataset yielding ``(pixel_values, labels)`` pairs for MixTex.

    Records are loaded with ``datasets.load_dataset`` from
    ``config[mode]["dataset"]["data_dir"]``. Each record provides an
    ``"image"`` and a ``"text"`` field; the image goes through the configured
    transforms and the text is tokenized to ``max_seq_len`` ids.

    Args:
        config: full configuration dict; the ``"Global"`` section and
            ``config[mode]["dataset"]`` are read.
        mode: key of the section to read from ``config``.
        logger: logger object, stored on the instance.
        seed: accepted but not used by this class.
    """

    def __init__(self, config, mode, logger, seed=None):
        super(MixTexDataSet, self).__init__()
        self.logger = logger
        self.mode = mode.lower()

        global_config = config["Global"]
        dataset_config = config[mode]["dataset"]

        self.data_dir = dataset_config["data_dir"]
        self.image_size = global_config["d2s_train_image_shape"]
        self.batchsize = dataset_config["batch_size_per_pair"]
        self.max_seq_len = global_config["max_seq_len"]
        self.rec_char_dict_path = global_config["rec_char_dict_path"]
        # Wrapper object; the underlying tokenizer is `self.tokenizer.tokenizer`.
        self.tokenizer = MixTexLabelEncode(self.rec_char_dict_path)

        self.dataframe = load_dataset(self.data_dir)

        self.ops = create_operators(dataset_config["transforms"], global_config)
        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2)
        self.need_reset = True

    def __getitem__(self, idx):
        """Return ``(pixel_values, labels)`` for record ``idx``, both numpy arrays."""
        # NOTE(review): the "train" split is read regardless of `mode`;
        # confirm that this is intended for evaluation as well.
        # Fetch the record once instead of indexing the split once per field,
        # which avoids repeating the lookup (and whatever loading it entails).
        sample = self.dataframe["train"][idx]

        image = np.asarray(sample["image"].convert("RGB"))
        pixel_values = transform({"image": image}, self.ops)

        tokenizer = self.tokenizer.tokenizer
        target = tokenizer(
            sample["text"],
            padding="max_length",
            max_length=self.max_seq_len,
            truncation=True,
        ).input_ids
        pad_token_id = tokenizer.pad_token_id
        # Padding ids are mapped to 1.
        # NOTE(review): presumably this matches the loss's ignore_index; verify.
        labels = np.array(
            [label if label != pad_token_id else 1 for label in target]
        )

        # Stack the transformed outputs into a single array shaped
        # (len(pixel_values), *pixel_values[0].shape[:3]).
        first = pixel_values[0]
        stacked = np.array(pixel_values).reshape(
            (
                len(pixel_values),
                first.shape[0],
                first.shape[1],
                first.shape[2],
            )
        )
        return (stacked, labels)

    def __len__(self):
        """Return the number of records in the "train" split."""
        return len(self.dataframe["train"])
2 changes: 2 additions & 0 deletions ppocr/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from .rec_cppd_loss import CPPDLoss
from .rec_latexocr_loss import LaTeXOCRLoss
from .rec_unimernet_loss import UniMERNetLoss
from .rec_mixtex_loss import MixTexLoss

# cls loss
from .cls_loss import ClsLoss
Expand Down Expand Up @@ -110,6 +111,7 @@ def build_loss(config):
"ParseQLoss",
"CPPDLoss",
"LaTeXOCRLoss",
"MixTexLoss",
"UniMERNetLoss",
]
config = copy.deepcopy(config)
Expand Down
48 changes: 48 additions & 0 deletions ppocr/losses/rec_mixtex_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This code is adapted from:
https://github.com/lucidrains/x-transformers/blob/main/x_transformers/autoregressive_wrapper.py
"""

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np


class MixTexLoss(nn.Layer):
    """
    Cross-entropy loss used to train MixTex.

    The criterion is an ``nn.CrossEntropyLoss`` with mean reduction and
    ``ignore_index`` set to 1.
    """

    def __init__(self):
        super().__init__()
        self.ignore_index = 1
        self.cross = nn.CrossEntropyLoss(
            ignore_index=self.ignore_index, reduction="mean"
        )

    def forward(self, preds, batch):
        """
        Compute the loss between ``preds`` and the labels in ``batch[1]``.

        Args:
            preds: model output; flattened to two dimensions, keeping its
                last axis.
            batch: sequence whose element at index 1 holds the labels.

        Returns:
            dict with a single key ``"loss"``.
        """
        targets = paddle.to_tensor(batch[1], dtype=paddle.int32)
        last_dim = preds.shape[-1]
        flat_preds = paddle.reshape(preds, [-1, last_dim])
        flat_targets = paddle.reshape(targets, [-1])
        return {"loss": self.cross(flat_preds, flat_targets)}
3 changes: 2 additions & 1 deletion ppocr/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
__all__ = ["build_metric"]

from .det_metric import DetMetric, DetFCEMetric
from .rec_metric import RecMetric, CNTMetric, CANMetric, LaTeXOCRMetric
from .rec_metric import RecMetric, CNTMetric, CANMetric, LaTeXOCRMetric, MixTexMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
Expand Down Expand Up @@ -51,6 +51,7 @@ def build_metric(config):
"CNTMetric",
"CANMetric",
"LaTeXOCRMetric",
"MixTexMetric",
]

config = copy.deepcopy(config)
Expand Down
Loading