diff --git a/.github/ISSUE_TEMPLATE/custom.md b/.github/ISSUE_TEMPLATE/custom.md new file mode 100644 index 0000000..a6d4478 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/custom.md @@ -0,0 +1,19 @@ +--- +name: Issue template +about: Issue template for code error. +title: '' +labels: '' +assignees: '' + +--- + +请提供下述完整信息以便快速定位问题/Please provide the following information to quickly locate the problem + +- 系统环境/System Environment: +- 版本号/Version:Paddle: PaddleOCR: 问题相关组件/Related components: +- 运行指令/Command Code: +- 完整报错/Complete Error Message: + +我们提供了AceIssueSolver来帮助你解答问题,你是否想要它来解答(请填写yes/no)?/We provide AceIssueSolver to solve issues, do you want it? (Please write yes/no): + +请尽量不要包含图片在问题中/Please try to not include the image in the issue. diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..ca62dac --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,15 @@ +### PR 类型 PR types + + +### PR 变化内容类型 PR changes + + +### 描述 Description + + +### 提PR之前的检查 Check-list + +- [ ] 这个 PR 是提交到dygraph分支或者是一个cherry-pick,否则请先提交到dygarph分支。 + This PR is pushed to the dygraph branch or cherry-picked from the dygraph branch. Otherwise, please push your changes to the dygraph branch. +- [ ] 这个PR清楚描述了功能,帮助评审能提升效率。This PR have fully described what it does such that reviewers can speedup. +- [ ] 这个PR已经经过本地测试。This PR can be convered by current tests or already test locally by you. diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3a05fb7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,34 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +.ipynb_checkpoints/ +*.py[cod] +*$py.class + +# C extensions +*.so + +inference/ +inference_results/ +output/ +train_data/ +log/ +*.DS_Store +*.vs +*.user +*~ +*.vscode +*.idea + +*.log +.clang-format +.clang_format.hook + +build/ +dist/ +paddleocr.egg-info/ +/deploy/android_demo/app/OpenCV/ +/deploy/android_demo/app/PaddleLite/ +/deploy/android_demo/app/.cxx/ +/deploy/android_demo/app/cache/ +test_tipc/web/models/ +test_tipc/web/node_modules/ \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..5f7fec8 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,35 @@ +- repo: https://github.com/PaddlePaddle/mirrors-yapf.git + sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 + hooks: + - id: yapf + files: \.py$ +- repo: https://github.com/pre-commit/pre-commit-hooks + sha: a11d9314b22d8f8c7556443875b731ef05965464 + hooks: + - id: check-merge-conflict + - id: check-symlinks + - id: detect-private-key + files: (?!.*paddle)^.*$ + - id: end-of-file-fixer + files: \.md$ + - id: trailing-whitespace + files: \.md$ +- repo: https://github.com/Lucas-C/pre-commit-hooks + sha: v1.0.1 + hooks: + - id: forbid-crlf + files: \.md$ + - id: remove-crlf + files: \.md$ + - id: forbid-tabs + files: \.md$ + - id: remove-tabs + files: \.md$ +- repo: local + hooks: + - id: clang-format + name: clang-format + description: Format files with ClangFormat + entry: bash .clang_format.hook -i + language: system + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ \ No newline at end of file diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 0000000..4741fb4 --- /dev/null +++ b/.style.yapf @@ -0,0 +1,3 @@ +[style] +based_on_style = pep8 +column_limit = 80 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..5fe8694 --- /dev/null +++ b/LICENSE @@ -0,0 +1,203 @@ +Copyright (c) 2016 PaddlePaddle Authors. 
All Rights Reserved + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..f821618 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,10 @@ +include LICENSE +include README.md + +recursive-include ppocr/utils *.* +recursive-include ppocr/data *.py +recursive-include ppocr/postprocess *.py +recursive-include tools/infer *.py +recursive-include tools __init__.py +recursive-include ppocr/utils/e2e_utils *.py +recursive-include ppstructure *.py \ No newline at end of file diff --git a/StyleText/README.md b/StyleText/README.md new file mode 100644 index 0000000..eddedbd --- /dev/null +++ b/StyleText/README.md @@ -0,0 +1,219 @@ +English | [简体中文](README_ch.md) + +## Style Text + +### Contents +- [1. Introduction](#Introduction) +- [2. Preparation](#Preparation) +- [3. Quick Start](#Quick_Start) +- [4. Applications](#Applications) +- [5. Code Structure](#Code_structure) + + + +### Introduction + +
+ +
+ +
+ +
+
+
+The Style-Text data synthesis tool is based on the text editing algorithm "Editing Text in the Wild" [https://arxiv.org/abs/1908.03047](https://arxiv.org/abs/1908.03047), a joint research work by Baidu and HUST.
+
+Different from the commonly used GAN-based data synthesis tools, the main framework of Style-Text includes:
+* (1) Text foreground style transfer module.
+* (2) Background extraction module.
+* (3) Fusion module.
+
+After these three steps, you can quickly realize image text style transfer. The following figure shows some results of the data synthesis tool.
+
+
+ +
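+
+Conceptually, the three modules are chained as in the simplified sketch below, which mirrors `StyleTextRec.forward` in `arch/style_text_rec.py` (pre/post-processing and tensor shapes are omitted):
+
+```python
+def style_text_pipeline(model, style_image, text_image):
+    """model: a loaded StyleTextRec; style_image / text_image: preprocessed NCHW tensors."""
+    # (1) transfer the foreground text style from the style image onto the corpus text
+    fake_text = model.text_generator(style_image, text_image)["fake_text"]
+    # (2) erase the original text and reconstruct the background
+    fake_bg = model.bg_generator(style_image)["fake_bg"]
+    # (3) fuse the restyled text with the recovered background
+    return model.fusion_generator(fake_text, fake_bg)["fake_fusion"]
+```
+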
+
+
+
+### Preparation
+
+1. Please refer to the [QUICK INSTALLATION](../doc/doc_en/installation_en.md) to install PaddlePaddle. A Python 3 environment is strongly recommended.
+2. Download the pretrained models and unzip them:
+
+```bash
+cd StyleText
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
+unzip style_text_models.zip
+```
+
+If you save the models in another location, please modify the model paths in `configs/config.yml`. These three configurations need to be changed at the same time:
+
+```
+bg_generator:
+  pretrain: style_text_models/bg_generator
+...
+text_generator:
+  pretrain: style_text_models/text_generator
+...
+fusion_generator:
+  pretrain: style_text_models/fusion_generator
+```
+
+
+### Quick Start
+
+#### Synthesize a single image
+
+1. Run `tools/synth_image.py` to generate a demo image, which is saved in the current folder:
+
+```bash
+python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
+```
+
+* Note 1: The language option must correspond to the corpus. Currently, the tool only supports English (en), Simplified Chinese (ch) and Korean (ko).
+* Note 2: Style-Text is mainly used to generate images for OCR recognition models,
+  so the height of style images should be around 32 pixels. Images of other sizes may behave poorly.
+* Note 3: You can modify `use_gpu` in `configs/config.yml` to determine whether to use the GPU for prediction.
+
+
+
+For example, take the following style image and the corpus `PaddleOCR` as input.
+
+ +
+ +The result `fake_fusion.jpg` will be generated. + +
+ +
+
+What's more, the intermediate result `fake_bg.jpg` will also be saved; it is the background of the style image with the text removed.
+
+
+ +
+
+
+`fake_text.jpg` is the generated image with the same font style as `Style Input`.
+
+
+
+ +
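+
+If you only want to try a handful of style images or corpora without setting up the batch pipeline below, a small wrapper script is enough. The sketch below simply loops over `tools/synth_image.py` with the command-line options shown above; the image/corpus pairs are placeholders:
+
+```python
+import subprocess
+
+# placeholder style-image / corpus / language triples -- replace with your own
+jobs = [
+    ("examples/style_images/1.jpg", "PaddleOCR", "en"),
+    ("examples/style_images/2.jpg", "StyleText", "en"),
+]
+
+for style_image, corpus, language in jobs:
+    subprocess.run(
+        ["python3", "tools/synth_image.py",
+         "-c", "configs/config.yml",
+         "--style_image", style_image,
+         "--text_corpus", corpus,
+         "--language", language],
+        check=True,  # stop if a synthesis run fails
+    )
+    # Each run writes fake_fusion.jpg / fake_bg.jpg / fake_text.jpg to the
+    # current folder, so move or rename them here if you need to keep them all.
+```
+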
+
+
+#### Batch synthesis
+
+In actual application scenarios, it is often necessary to synthesize images in batches and add them to the training set. Style-Text can use a batch of style images and a corpus to synthesize data in batches. The synthesis process is as follows:
+
+1. The referenced dataset can be specified in `configs/dataset_config.yml`:
+
+   * `Global`:
+     * `output_dir`: Output path of the synthesized data.
+   * `StyleSampler`:
+     * `image_home`: Folder of the style images.
+     * `label_file`: File list of the style images. If labels are provided, it is the label file path.
+     * `with_label`: Whether `label_file` is a label file.
+   * `CorpusGenerator`:
+     * `method`: Method of the corpus generator; `FileCorpus` and `EnNumCorpus` are supported. If `EnNumCorpus` is used, no other configuration is needed; otherwise you need to set `corpus_file` and `language`.
+     * `language`: Language of the corpus. Currently, the tool only supports English (en), Simplified Chinese (ch) and Korean (ko).
+     * `corpus_file`: File path of the corpus. The corpus file should be a text file that will be split by line endings ('\n'). The corpus generator samples one line each time.
+
+
+Example of a corpus file:
+```
+PaddleOCR
+飞桨文字识别
+StyleText
+风格文本图像数据合成
+```
+
+We provide a general dataset containing Chinese, English and Korean (50,000 images in all) for your trial ([download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar)); some examples are given below:
+
+ +
+
+2. Run the following command to start the synthesis task:
+
+   ```bash
+   python3 tools/synth_dataset.py -c configs/dataset_config.yml
+   ```
+
+We also provide an example corpus and example style images in the `examples` folder.
+
+ + +
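+
+To point the batch pipeline at your own data instead of the bundled examples, you only need a style-image list and a corpus file that match the `StyleSampler` and `CorpusGenerator` settings above. A minimal sketch is given below; the `my_style_images` folder name and the tab-separated `path\tlabel` list format are assumptions (with `with_label: false`, a plain list of image paths should be enough):
+
+```python
+from pathlib import Path
+
+image_home = Path("examples")                  # StyleSampler.image_home
+style_dir = image_home / "my_style_images"     # hypothetical folder with your style images
+
+lines = []
+for img in sorted(style_dir.glob("*.jpg")):
+    label = "TODO"  # the text actually shown in this style image
+    lines.append(f"{img.relative_to(image_home)}\t{label}")
+(image_home / "image_list.txt").write_text("\n".join(lines) + "\n", encoding="utf-8")
+
+# A corpus file is simply one candidate string per line.
+Path("examples/corpus/example.txt").write_text("PaddleOCR\nStyleText\n", encoding="utf-8")
+```
+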
+If you run the command above directly, you will get example output data in the `output_data` folder.
+The synthesized images and labels will look as below:
+
+ +
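+
+Once the run completes and `label.txt` has been written (see the note on cached labels just below), a quick sanity check is useful before recognition training. The sketch below assumes the usual PaddleOCR tab-separated `image_path\tlabel` format for `label.txt` under the configured `output_dir` (`output_data` by default); adjust the paths if your configuration differs:
+
+```python
+import random
+from pathlib import Path
+
+label_file = Path("output_data/label.txt")  # assumed location: <output_dir>/label.txt
+samples = [l for l in label_file.read_text(encoding="utf-8").splitlines() if l.strip()]
+print(f"{len(samples)} synthesized samples")
+
+# simple 95/5 split into train / val lists for a recognition model
+random.seed(0)
+random.shuffle(samples)
+split = int(len(samples) * 0.95)
+Path("output_data/train_list.txt").write_text("\n".join(samples[:split]) + "\n", encoding="utf-8")
+Path("output_data/val_list.txt").write_text("\n".join(samples[split:]) + "\n", encoding="utf-8")
+```
+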
+There will be some cached files under the `label` folder. If the program exits unexpectedly, you can find the cached labels there.
+When the program finishes normally, you will find all the labels in `label.txt`, which gives the final results.
+
+
+### Applications
+We take two scenarios as examples, metal surface English and number recognition and general Korean recognition, to illustrate practical cases of using StyleText to synthesize data to improve text recognition. The following figure shows some examples of real scene images and composite images:
+
+ +
+ + +After adding the above synthetic data for training, the accuracy of the recognition model is improved, which is shown in the following table: + + +| Scenario | Characters | Raw Data | Test Data | Only Use Raw Data
Recognition Accuracy | New Synthetic Data | Simultaneous Use of Synthetic Data
Recognition Accuracy | Index Improvement | +| -------- | ---------- | -------- | -------- | -------------------------- | ------------ | ---------------------- | -------- | +| Metal surface | English and numbers | 2203 | 650 | 59.38% | 20000 | 75.46% | 16.08% | +| Random background | Korean | 5631 | 1230 | 30.12% | 100000 | 50.57% | 20.45% | + + +### Code Structure + +``` +StyleText +|-- arch // Network module files. +| |-- base_module.py +| |-- decoder.py +| |-- encoder.py +| |-- spectral_norm.py +| `-- style_text_rec.py +|-- configs // Config files. +| |-- config.yml +| `-- dataset_config.yml +|-- engine // Synthesis engines. +| |-- corpus_generators.py // Sample corpus from file or generate random corpus. +| |-- predictors.py // Predict using network. +| |-- style_samplers.py // Sample style images. +| |-- synthesisers.py // Manage other engines to synthesis images. +| |-- text_drawers.py // Generate standard input text images. +| `-- writers.py // Write synthesis images and labels into files. +|-- examples // Example files. +| |-- corpus +| | `-- example.txt +| |-- image_list.txt +| `-- style_images +| |-- 1.jpg +| `-- 2.jpg +|-- fonts // Font files. +| |-- ch_standard.ttf +| |-- en_standard.ttf +| `-- ko_standard.ttf +|-- tools // Program entrance. +| |-- __init__.py +| |-- synth_dataset.py // Synthesis dataset. +| `-- synth_image.py // Synthesis image. +`-- utils // Module of basic functions. + |-- config.py + |-- load_params.py + |-- logging.py + |-- math_functions.py + `-- sys_funcs.py +``` diff --git a/StyleText/README_ch.md b/StyleText/README_ch.md new file mode 100644 index 0000000..7818f2d --- /dev/null +++ b/StyleText/README_ch.md @@ -0,0 +1,205 @@ +简体中文 | [English](README.md) + +## Style Text + + +### 目录 +- [一、工具简介](#工具简介) +- [二、环境配置](#环境配置) +- [三、快速上手](#快速上手) +- [四、应用案例](#应用案例) +- [五、代码结构](#代码结构) + + +### 一、工具简介 +
+ +
+ +
+ +
+ + +Style-Text数据合成工具是基于百度和华科合作研发的文本编辑算法《Editing Text in the Wild》https://arxiv.org/abs/1908.03047 + +不同于常用的基于GAN的数据合成工具,Style-Text主要框架包括:1.文本前景风格迁移模块 2.背景抽取模块 3.融合模块。经过这样三步,就可以迅速实现图像文本风格迁移。下图是一些该数据合成工具效果图。 + +
+ +
+ + +### 二、环境配置 + +1. 参考[快速安装](../doc/doc_ch/installation.md),安装PaddleOCR。 +2. 进入`StyleText`目录,下载模型,并解压: + +```bash +cd StyleText +wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip +unzip style_text_models.zip +``` + +如果您将模型保存再其他位置,请在`configs/config.yml`中修改模型文件的地址,修改时需要同时修改这三个配置: + +``` +bg_generator: + pretrain: style_text_models/bg_generator +... +text_generator: + pretrain: style_text_models/text_generator +... +fusion_generator: + pretrain: style_text_models/fusion_generator +``` + + +### 三、快速上手 + +#### 合成单张图 +输入一张风格图和一段文字语料,运行tools/synth_image,合成单张图片,结果图像保存在当前目录下: + +```python +python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en +``` +* 注1:语言选项和语料相对应,目前支持英文(en)、简体中文(ch)和韩语(ko)。 +* 注2:Style-Text生成的数据主要应用于OCR识别场景。基于当前PaddleOCR识别模型的设计,我们主要支持高度在32左右的风格图像。 + 如果输入图像尺寸相差过多,效果可能不佳。 +* 注3:可以通过修改配置文件`configs/config.yml`中的`use_gpu`(true或者false)参数来决定是否使用GPU进行预测。 + + +例如,输入如下图片和语料"PaddleOCR": + +
+ +
+ +生成合成数据`fake_fusion.jpg`: +
+ +
+ +除此之外,程序还会生成并保存中间结果`fake_bg.jpg`:为风格参考图去掉文字后的背景; + +
+ +
+ +`fake_text.jpg`:是用提供的字符串,仿照风格参考图中文字的风格,生成在灰色背景上的文字图片。 + +
+ +
+ +#### 批量合成 +在实际应用场景中,经常需要批量合成图片,补充到训练集中。Style-Text可以使用一批风格图片和语料,批量合成数据。合成过程如下: + +1. 在`configs/dataset_config.yml`中配置目标场景风格图像和语料的路径,具体如下: + + * `Global`: + * `output_dir:`:保存合成数据的目录。 + * `StyleSampler`: + * `image_home`:风格图片目录; + * `label_file`:风格图片路径列表文件,如果所用数据集有label,则label_file为label文件路径; + * `with_label`:标志`label_file`是否为label文件。 + * `CorpusGenerator`: + * `method`:语料生成方法,目前有`FileCorpus`和`EnNumCorpus`可选。如果使用`EnNumCorpus`,则不需要填写其他配置,否则需要修改`corpus_file`和`language`; + * `language`:语料的语种,目前支持英文(en)、简体中文(ch)和韩语(ko); + * `corpus_file`: 语料文件路径。语料文件应使用文本文件。语料生成器首先会将语料按行切分,之后每次随机选取一行。 + + 语料文件格式示例: + ``` + PaddleOCR + 飞桨文字识别 + StyleText + 风格文本图像数据合成 + ... + ``` + + Style-Text也提供了一批中英韩5万张通用场景数据用作文本风格图像,便于合成场景丰富的文本图像,下图给出了一些示例。 + + 中英韩5万张通用场景数据: [下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar) + +
+ +
+ +2. 运行`tools/synth_dataset`合成数据: + + ``` bash + python3 tools/synth_dataset.py -c configs/dataset_config.yml + ``` + 我们在examples目录下提供了样例图片和语料。 +
+ + +
+ + 直接运行上述命令,可以在output_data中产生样例输出,包括图片和用于训练识别模型的标注文件: +
+ +
+ + 其中label目录下的标注文件为程序运行过程中产生的缓存,如果程序在中途异常终止,可以使用缓存的标注文件。 + 如果程序正常运行完毕,则会在output_data下生成label.txt,为最终的标注结果。 + + +### 四、应用案例 +下面以金属表面英文数字识别和通用韩语识别两个场景为例,说明使用Style-Text合成数据,来提升文本识别效果的实际案例。下图给出了一些真实场景图像和合成图像的示例: + +
+ +
+ +在添加上述合成数据进行训练后,识别模型的效果提升,如下表所示: + +| 场景 | 字符 | 原始数据 | 测试数据 | 只使用原始数据
识别准确率 | 新增合成数据 | 同时使用合成数据
识别准确率 | 指标提升 | +| -------- | ---------- | -------- | -------- | -------------------------- | ------------ | ---------------------- | -------- | +| 金属表面 | 英文和数字 | 2203 | 650 | 59.38% | 20000 | 75.46% | 16.08% | +| 随机背景 | 韩语 | 5631 | 1230 | 30.12% | 100000 | 50.57% | 20.45% | + + + +### 五、代码结构 + +``` +StyleText +|-- arch // 网络结构定义文件 +| |-- base_module.py +| |-- decoder.py +| |-- encoder.py +| |-- spectral_norm.py +| `-- style_text_rec.py +|-- configs // 配置文件 +| |-- config.yml +| `-- dataset_config.yml +|-- engine // 数据合成引擎 +| |-- corpus_generators.py // 从文本采样或随机生成语料 +| |-- predictors.py // 调用网络生成数据 +| |-- style_samplers.py // 采样风格图片 +| |-- synthesisers.py // 调度各个模块,合成数据 +| |-- text_drawers.py // 生成标准文字图片,用作输入 +| `-- writers.py // 将合成的图片和标签写入本地目录 +|-- examples // 示例文件 +| |-- corpus +| | `-- example.txt +| |-- image_list.txt +| `-- style_images +| |-- 1.jpg +| `-- 2.jpg +|-- fonts // 字体文件 +| |-- ch_standard.ttf +| |-- en_standard.ttf +| `-- ko_standard.ttf +|-- tools // 程序入口 +| |-- __init__.py +| |-- synth_dataset.py // 批量合成数据 +| `-- synth_image.py // 合成单张图片 +`-- utils // 其他基础功能模块 + |-- config.py + |-- load_params.py + |-- logging.py + |-- math_functions.py + `-- sys_funcs.py +``` diff --git a/StyleText/__init__.py b/StyleText/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/StyleText/arch/__init__.py b/StyleText/arch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/StyleText/arch/base_module.py b/StyleText/arch/base_module.py new file mode 100644 index 0000000..da2b6b8 --- /dev/null +++ b/StyleText/arch/base_module.py @@ -0,0 +1,255 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
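+# Basic building blocks shared by the Style-Text networks: CBN (conv + norm + act),
+# spectral-normalized convolutions (SNConv, SNConvTranspose), MiddleNet and ResBlock.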
+import paddle +import paddle.nn as nn + +from arch.spectral_norm import spectral_norm + + +class CBN(nn.Layer): + def __init__(self, + name, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + use_bias=False, + norm_layer=None, + act=None, + act_attr=None): + super(CBN, self).__init__() + if use_bias: + bias_attr = paddle.ParamAttr(name=name + "_bias") + else: + bias_attr = None + self._conv = paddle.nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + weight_attr=paddle.ParamAttr(name=name + "_weights"), + bias_attr=bias_attr) + if norm_layer: + self._norm_layer = getattr(paddle.nn, norm_layer)( + num_features=out_channels, name=name + "_bn") + else: + self._norm_layer = None + if act: + if act_attr: + self._act = getattr(paddle.nn, act)(**act_attr, + name=name + "_" + act) + else: + self._act = getattr(paddle.nn, act)(name=name + "_" + act) + else: + self._act = None + + def forward(self, x): + out = self._conv(x) + if self._norm_layer: + out = self._norm_layer(out) + if self._act: + out = self._act(out) + return out + + +class SNConv(nn.Layer): + def __init__(self, + name, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + use_bias=False, + norm_layer=None, + act=None, + act_attr=None): + super(SNConv, self).__init__() + if use_bias: + bias_attr = paddle.ParamAttr(name=name + "_bias") + else: + bias_attr = None + self._sn_conv = spectral_norm( + paddle.nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + weight_attr=paddle.ParamAttr(name=name + "_weights"), + bias_attr=bias_attr)) + if norm_layer: + self._norm_layer = getattr(paddle.nn, norm_layer)( + num_features=out_channels, name=name + "_bn") + else: + self._norm_layer = None + if act: + if act_attr: + self._act = getattr(paddle.nn, act)(**act_attr, + name=name + "_" + act) + else: + self._act = getattr(paddle.nn, act)(name=name + "_" + act) + else: + self._act = None + + def forward(self, x): + out = self._sn_conv(x) + if self._norm_layer: + out = self._norm_layer(out) + if self._act: + out = self._act(out) + return out + + +class SNConvTranspose(nn.Layer): + def __init__(self, + name, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + dilation=1, + groups=1, + use_bias=False, + norm_layer=None, + act=None, + act_attr=None): + super(SNConvTranspose, self).__init__() + if use_bias: + bias_attr = paddle.ParamAttr(name=name + "_bias") + else: + bias_attr = None + self._sn_conv_transpose = spectral_norm( + paddle.nn.Conv2DTranspose( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + dilation=dilation, + groups=groups, + weight_attr=paddle.ParamAttr(name=name + "_weights"), + bias_attr=bias_attr)) + if norm_layer: + self._norm_layer = getattr(paddle.nn, norm_layer)( + num_features=out_channels, name=name + "_bn") + else: + self._norm_layer = None + if act: + if act_attr: + self._act = getattr(paddle.nn, act)(**act_attr, + name=name + "_" + act) + else: + self._act = getattr(paddle.nn, act)(name=name + "_" + act) + else: + self._act = None + + def forward(self, x): + out = self._sn_conv_transpose(x) + if self._norm_layer: + out = self._norm_layer(out) + if self._act: + 
out = self._act(out) + return out + + +class MiddleNet(nn.Layer): + def __init__(self, name, in_channels, mid_channels, out_channels, + use_bias): + super(MiddleNet, self).__init__() + self._sn_conv1 = SNConv( + name=name + "_sn_conv1", + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + use_bias=use_bias, + norm_layer=None, + act=None) + self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate") + self._sn_conv2 = SNConv( + name=name + "_sn_conv2", + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=3, + use_bias=use_bias) + self._sn_conv3 = SNConv( + name=name + "_sn_conv3", + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + use_bias=use_bias) + + def forward(self, x): + + sn_conv1 = self._sn_conv1.forward(x) + pad_2d = self._pad2d.forward(sn_conv1) + sn_conv2 = self._sn_conv2.forward(pad_2d) + sn_conv3 = self._sn_conv3.forward(sn_conv2) + return sn_conv3 + + +class ResBlock(nn.Layer): + def __init__(self, name, channels, norm_layer, use_dropout, use_dilation, + use_bias): + super(ResBlock, self).__init__() + if use_dilation: + padding_mat = [1, 1, 1, 1] + else: + padding_mat = [0, 0, 0, 0] + self._pad1 = nn.Pad2D(padding_mat, mode="replicate") + + self._sn_conv1 = SNConv( + name=name + "_sn_conv1", + in_channels=channels, + out_channels=channels, + kernel_size=3, + padding=0, + norm_layer=norm_layer, + use_bias=use_bias, + act="ReLU", + act_attr=None) + if use_dropout: + self._dropout = nn.Dropout(0.5) + else: + self._dropout = None + self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate") + self._sn_conv2 = SNConv( + name=name + "_sn_conv2", + in_channels=channels, + out_channels=channels, + kernel_size=3, + norm_layer=norm_layer, + use_bias=use_bias, + act="ReLU", + act_attr=None) + + def forward(self, x): + pad1 = self._pad1.forward(x) + sn_conv1 = self._sn_conv1.forward(pad1) + pad2 = self._pad2.forward(sn_conv1) + sn_conv2 = self._sn_conv2.forward(pad2) + return sn_conv2 + x diff --git a/StyleText/arch/decoder.py b/StyleText/arch/decoder.py new file mode 100644 index 0000000..36f07c5 --- /dev/null +++ b/StyleText/arch/decoder.py @@ -0,0 +1,251 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
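+# Decoder networks used by Style-Text: a plain Decoder, a U-Net style DecoderUnet
+# with skip connections, and a SingleDecoder variant.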
+import paddle +import paddle.nn as nn + +from arch.base_module import SNConv, SNConvTranspose, ResBlock + + +class Decoder(nn.Layer): + def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer, + act, act_attr, conv_block_dropout, conv_block_num, + conv_block_dilation, out_conv_act, out_conv_act_attr): + super(Decoder, self).__init__() + conv_blocks = [] + for i in range(conv_block_num): + conv_blocks.append( + ResBlock( + name="{}_conv_block_{}".format(name, i), + channels=encode_dim * 8, + norm_layer=norm_layer, + use_dropout=conv_block_dropout, + use_dilation=conv_block_dilation, + use_bias=use_bias)) + self.conv_blocks = nn.Sequential(*conv_blocks) + self._up1 = SNConvTranspose( + name=name + "_up1", + in_channels=encode_dim * 8, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up2 = SNConvTranspose( + name=name + "_up2", + in_channels=encode_dim * 4, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up3 = SNConvTranspose( + name=name + "_up3", + in_channels=encode_dim * 2, + out_channels=encode_dim, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate") + self._out_conv = SNConv( + name=name + "_out_conv", + in_channels=encode_dim, + out_channels=out_channels, + kernel_size=3, + use_bias=use_bias, + norm_layer=None, + act=out_conv_act, + act_attr=out_conv_act_attr) + + def forward(self, x): + if isinstance(x, (list, tuple)): + x = paddle.concat(x, axis=1) + output_dict = dict() + output_dict["conv_blocks"] = self.conv_blocks.forward(x) + output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"]) + output_dict["up2"] = self._up2.forward(output_dict["up1"]) + output_dict["up3"] = self._up3.forward(output_dict["up2"]) + output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"]) + output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"]) + return output_dict + + +class DecoderUnet(nn.Layer): + def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer, + act, act_attr, conv_block_dropout, conv_block_num, + conv_block_dilation, out_conv_act, out_conv_act_attr): + super(DecoderUnet, self).__init__() + conv_blocks = [] + for i in range(conv_block_num): + conv_blocks.append( + ResBlock( + name="{}_conv_block_{}".format(name, i), + channels=encode_dim * 8, + norm_layer=norm_layer, + use_dropout=conv_block_dropout, + use_dilation=conv_block_dilation, + use_bias=use_bias)) + self._conv_blocks = nn.Sequential(*conv_blocks) + self._up1 = SNConvTranspose( + name=name + "_up1", + in_channels=encode_dim * 8, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up2 = SNConvTranspose( + name=name + "_up2", + in_channels=encode_dim * 8, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up3 = SNConvTranspose( + name=name + "_up3", + in_channels=encode_dim * 4, + out_channels=encode_dim, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + 
act=act, + act_attr=act_attr) + self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate") + self._out_conv = SNConv( + name=name + "_out_conv", + in_channels=encode_dim, + out_channels=out_channels, + kernel_size=3, + use_bias=use_bias, + norm_layer=None, + act=out_conv_act, + act_attr=out_conv_act_attr) + + def forward(self, x, y, feature2, feature1): + output_dict = dict() + output_dict["conv_blocks"] = self._conv_blocks( + paddle.concat( + (x, y), axis=1)) + output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"]) + output_dict["up2"] = self._up2.forward( + paddle.concat( + (output_dict["up1"], feature2), axis=1)) + output_dict["up3"] = self._up3.forward( + paddle.concat( + (output_dict["up2"], feature1), axis=1)) + output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"]) + output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"]) + return output_dict + + +class SingleDecoder(nn.Layer): + def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer, + act, act_attr, conv_block_dropout, conv_block_num, + conv_block_dilation, out_conv_act, out_conv_act_attr): + super(SingleDecoder, self).__init__() + conv_blocks = [] + for i in range(conv_block_num): + conv_blocks.append( + ResBlock( + name="{}_conv_block_{}".format(name, i), + channels=encode_dim * 4, + norm_layer=norm_layer, + use_dropout=conv_block_dropout, + use_dilation=conv_block_dilation, + use_bias=use_bias)) + self._conv_blocks = nn.Sequential(*conv_blocks) + self._up1 = SNConvTranspose( + name=name + "_up1", + in_channels=encode_dim * 4, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up2 = SNConvTranspose( + name=name + "_up2", + in_channels=encode_dim * 8, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up3 = SNConvTranspose( + name=name + "_up3", + in_channels=encode_dim * 4, + out_channels=encode_dim, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate") + self._out_conv = SNConv( + name=name + "_out_conv", + in_channels=encode_dim, + out_channels=out_channels, + kernel_size=3, + use_bias=use_bias, + norm_layer=None, + act=out_conv_act, + act_attr=out_conv_act_attr) + + def forward(self, x, feature2, feature1): + output_dict = dict() + output_dict["conv_blocks"] = self._conv_blocks.forward(x) + output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"]) + output_dict["up2"] = self._up2.forward( + paddle.concat( + (output_dict["up1"], feature2), axis=1)) + output_dict["up3"] = self._up3.forward( + paddle.concat( + (output_dict["up2"], feature1), axis=1)) + output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"]) + output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"]) + return output_dict diff --git a/StyleText/arch/encoder.py b/StyleText/arch/encoder.py new file mode 100644 index 0000000..b884cda --- /dev/null +++ b/StyleText/arch/encoder.py @@ -0,0 +1,186 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import paddle.nn as nn + +from arch.base_module import SNConv, SNConvTranspose, ResBlock + + +class Encoder(nn.Layer): + def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer, + act, act_attr, conv_block_dropout, conv_block_num, + conv_block_dilation): + super(Encoder, self).__init__() + self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate") + self._in_conv = SNConv( + name=name + "_in_conv", + in_channels=in_channels, + out_channels=encode_dim, + kernel_size=7, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down1 = SNConv( + name=name + "_down1", + in_channels=encode_dim, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down2 = SNConv( + name=name + "_down2", + in_channels=encode_dim * 2, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down3 = SNConv( + name=name + "_down3", + in_channels=encode_dim * 4, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + conv_blocks = [] + for i in range(conv_block_num): + conv_blocks.append( + ResBlock( + name="{}_conv_block_{}".format(name, i), + channels=encode_dim * 4, + norm_layer=norm_layer, + use_dropout=conv_block_dropout, + use_dilation=conv_block_dilation, + use_bias=use_bias)) + self._conv_blocks = nn.Sequential(*conv_blocks) + + def forward(self, x): + out_dict = dict() + x = self._pad2d(x) + out_dict["in_conv"] = self._in_conv.forward(x) + out_dict["down1"] = self._down1.forward(out_dict["in_conv"]) + out_dict["down2"] = self._down2.forward(out_dict["down1"]) + out_dict["down3"] = self._down3.forward(out_dict["down2"]) + out_dict["res_blocks"] = self._conv_blocks.forward(out_dict["down3"]) + return out_dict + + +class EncoderUnet(nn.Layer): + def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer, + act, act_attr): + super(EncoderUnet, self).__init__() + self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate") + self._in_conv = SNConv( + name=name + "_in_conv", + in_channels=in_channels, + out_channels=encode_dim, + kernel_size=7, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down1 = SNConv( + name=name + "_down1", + in_channels=encode_dim, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down2 = SNConv( + name=name + "_down2", + in_channels=encode_dim * 2, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down3 = SNConv( + name=name + "_down3", + in_channels=encode_dim * 2, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._down4 = SNConv( + name=name + 
"_down4", + in_channels=encode_dim * 2, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up1 = SNConvTranspose( + name=name + "_up1", + in_channels=encode_dim * 2, + out_channels=encode_dim * 2, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + self._up2 = SNConvTranspose( + name=name + "_up2", + in_channels=encode_dim * 4, + out_channels=encode_dim * 4, + kernel_size=3, + stride=2, + padding=1, + use_bias=use_bias, + norm_layer=norm_layer, + act=act, + act_attr=act_attr) + + def forward(self, x): + output_dict = dict() + x = self._pad2d(x) + output_dict['in_conv'] = self._in_conv.forward(x) + output_dict['down1'] = self._down1.forward(output_dict['in_conv']) + output_dict['down2'] = self._down2.forward(output_dict['down1']) + output_dict['down3'] = self._down3.forward(output_dict['down2']) + output_dict['down4'] = self._down4.forward(output_dict['down3']) + output_dict['up1'] = self._up1.forward(output_dict['down4']) + output_dict['up2'] = self._up2.forward( + paddle.concat( + (output_dict['down3'], output_dict['up1']), axis=1)) + output_dict['concat'] = paddle.concat( + (output_dict['down2'], output_dict['up2']), axis=1) + return output_dict diff --git a/StyleText/arch/spectral_norm.py b/StyleText/arch/spectral_norm.py new file mode 100644 index 0000000..21d0afc --- /dev/null +++ b/StyleText/arch/spectral_norm.py @@ -0,0 +1,150 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +def normal_(x, mean=0., std=1.): + temp_value = paddle.normal(mean, std, shape=x.shape) + x.set_value(temp_value) + return x + + +class SpectralNorm(object): + def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): + self.name = name + self.dim = dim + if n_power_iterations <= 0: + raise ValueError('Expected n_power_iterations to be positive, but ' + 'got n_power_iterations={}'.format( + n_power_iterations)) + self.n_power_iterations = n_power_iterations + self.eps = eps + + def reshape_weight_to_matrix(self, weight): + weight_mat = weight + if self.dim != 0: + # transpose dim to front + weight_mat = weight_mat.transpose([ + self.dim, + * [d for d in range(weight_mat.dim()) if d != self.dim] + ]) + + height = weight_mat.shape[0] + + return weight_mat.reshape([height, -1]) + + def compute_weight(self, module, do_power_iteration): + weight = getattr(module, self.name + '_orig') + u = getattr(module, self.name + '_u') + v = getattr(module, self.name + '_v') + weight_mat = self.reshape_weight_to_matrix(weight) + + if do_power_iteration: + with paddle.no_grad(): + for _ in range(self.n_power_iterations): + v.set_value( + F.normalize( + paddle.matmul( + weight_mat, + u, + transpose_x=True, + transpose_y=False), + axis=0, + epsilon=self.eps, )) + + u.set_value( + F.normalize( + paddle.matmul(weight_mat, v), + axis=0, + epsilon=self.eps, )) + if self.n_power_iterations > 0: + u = u.clone() + v = v.clone() + + sigma = paddle.dot(u, paddle.mv(weight_mat, v)) + weight = weight / sigma + return weight + + def remove(self, module): + with paddle.no_grad(): + weight = self.compute_weight(module, do_power_iteration=False) + delattr(module, self.name) + delattr(module, self.name + '_u') + delattr(module, self.name + '_v') + delattr(module, self.name + '_orig') + + module.add_parameter(self.name, weight.detach()) + + def __call__(self, module, inputs): + setattr( + module, + self.name, + self.compute_weight( + module, do_power_iteration=module.training)) + + @staticmethod + def apply(module, name, n_power_iterations, dim, eps): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + raise RuntimeError( + "Cannot register two spectral_norm hooks on " + "the same parameter {}".format(name)) + + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + + with paddle.no_grad(): + weight_mat = fn.reshape_weight_to_matrix(weight) + h, w = weight_mat.shape + + # randomly initialize u and v + u = module.create_parameter([h]) + u = normal_(u, 0., 1.) + v = module.create_parameter([w]) + v = normal_(v, 0., 1.) + u = F.normalize(u, axis=0, epsilon=fn.eps) + v = F.normalize(v, axis=0, epsilon=fn.eps) + + # delete fn.name form parameters, otherwise you can not set attribute + del module._parameters[fn.name] + module.add_parameter(fn.name + "_orig", weight) + # still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be an Parameter and + # gets added as a parameter. Instead, we register weight * 1.0 as a plain + # attribute. 
+ setattr(module, fn.name, weight * 1.0) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + module.register_forward_pre_hook(fn) + return fn + + +def spectral_norm(module, + name='weight', + n_power_iterations=1, + eps=1e-12, + dim=None): + + if dim is None: + if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose, + nn.Conv3DTranspose, nn.Linear)): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + return module diff --git a/StyleText/arch/style_text_rec.py b/StyleText/arch/style_text_rec.py new file mode 100644 index 0000000..599927c --- /dev/null +++ b/StyleText/arch/style_text_rec.py @@ -0,0 +1,285 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import paddle.nn as nn + +from arch.base_module import MiddleNet, ResBlock +from arch.encoder import Encoder +from arch.decoder import Decoder, DecoderUnet, SingleDecoder +from utils.load_params import load_dygraph_pretrain +from utils.logging import get_logger + + +class StyleTextRec(nn.Layer): + def __init__(self, config): + super(StyleTextRec, self).__init__() + self.logger = get_logger() + self.text_generator = TextGenerator(config["Predictor"][ + "text_generator"]) + self.bg_generator = BgGeneratorWithMask(config["Predictor"][ + "bg_generator"]) + self.fusion_generator = FusionGeneratorSimple(config["Predictor"][ + "fusion_generator"]) + bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"] + text_generator_pretrain = config["Predictor"]["text_generator"][ + "pretrain"] + fusion_generator_pretrain = config["Predictor"]["fusion_generator"][ + "pretrain"] + load_dygraph_pretrain( + self.bg_generator, + self.logger, + path=bg_generator_pretrain, + load_static_weights=False) + load_dygraph_pretrain( + self.text_generator, + self.logger, + path=text_generator_pretrain, + load_static_weights=False) + load_dygraph_pretrain( + self.fusion_generator, + self.logger, + path=fusion_generator_pretrain, + load_static_weights=False) + + def forward(self, style_input, text_input): + text_gen_output = self.text_generator.forward(style_input, text_input) + fake_text = text_gen_output["fake_text"] + fake_sk = text_gen_output["fake_sk"] + bg_gen_output = self.bg_generator.forward(style_input) + bg_encode_feature = bg_gen_output["bg_encode_feature"] + bg_decode_feature1 = bg_gen_output["bg_decode_feature1"] + bg_decode_feature2 = bg_gen_output["bg_decode_feature2"] + fake_bg = bg_gen_output["fake_bg"] + + fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg) + fake_fusion = fusion_gen_output["fake_fusion"] + return { + "fake_fusion": fake_fusion, + "fake_text": fake_text, + "fake_sk": fake_sk, + "fake_bg": fake_bg, + } + + +class TextGenerator(nn.Layer): + def __init__(self, config): + super(TextGenerator, self).__init__() + name = config["module_name"] + encode_dim = config["encode_dim"] + norm_layer = config["norm_layer"] + conv_block_dropout = 
config["conv_block_dropout"] + conv_block_num = config["conv_block_num"] + conv_block_dilation = config["conv_block_dilation"] + if norm_layer == "InstanceNorm2D": + use_bias = True + else: + use_bias = False + self.encoder_text = Encoder( + name=name + "_encoder_text", + in_channels=3, + encode_dim=encode_dim, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation) + self.encoder_style = Encoder( + name=name + "_encoder_style", + in_channels=3, + encode_dim=encode_dim, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation) + self.decoder_text = Decoder( + name=name + "_decoder_text", + encode_dim=encode_dim, + out_channels=int(encode_dim / 2), + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation, + out_conv_act="Tanh", + out_conv_act_attr=None) + self.decoder_sk = Decoder( + name=name + "_decoder_sk", + encode_dim=encode_dim, + out_channels=1, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation, + out_conv_act="Sigmoid", + out_conv_act_attr=None) + + self.middle = MiddleNet( + name=name + "_middle_net", + in_channels=int(encode_dim / 2) + 1, + mid_channels=encode_dim, + out_channels=3, + use_bias=use_bias) + + def forward(self, style_input, text_input): + style_feature = self.encoder_style.forward(style_input)["res_blocks"] + text_feature = self.encoder_text.forward(text_input)["res_blocks"] + fake_c_temp = self.decoder_text.forward([text_feature, + style_feature])["out_conv"] + fake_sk = self.decoder_sk.forward([text_feature, + style_feature])["out_conv"] + fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1)) + return {"fake_sk": fake_sk, "fake_text": fake_text} + + +class BgGeneratorWithMask(nn.Layer): + def __init__(self, config): + super(BgGeneratorWithMask, self).__init__() + name = config["module_name"] + encode_dim = config["encode_dim"] + norm_layer = config["norm_layer"] + conv_block_dropout = config["conv_block_dropout"] + conv_block_num = config["conv_block_num"] + conv_block_dilation = config["conv_block_dilation"] + self.output_factor = config.get("output_factor", 1.0) + + if norm_layer == "InstanceNorm2D": + use_bias = True + else: + use_bias = False + + self.encoder_bg = Encoder( + name=name + "_encoder_bg", + in_channels=3, + encode_dim=encode_dim, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation) + + self.decoder_bg = SingleDecoder( + name=name + "_decoder_bg", + encode_dim=encode_dim, + out_channels=3, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation, + out_conv_act="Tanh", + out_conv_act_attr=None) + + self.decoder_mask = Decoder( + name=name + "_decoder_mask", + encode_dim=encode_dim // 2, + out_channels=1, + use_bias=use_bias, + norm_layer=norm_layer, + act="ReLU", + act_attr=None, + 
conv_block_dropout=conv_block_dropout, + conv_block_num=conv_block_num, + conv_block_dilation=conv_block_dilation, + out_conv_act="Sigmoid", + out_conv_act_attr=None) + + self.middle = MiddleNet( + name=name + "_middle_net", + in_channels=3 + 1, + mid_channels=encode_dim, + out_channels=3, + use_bias=use_bias) + + def forward(self, style_input): + encode_bg_output = self.encoder_bg(style_input) + decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"], + encode_bg_output["down2"], + encode_bg_output["down1"]) + + fake_c_temp = decode_bg_output["out_conv"] + fake_bg_mask = self.decoder_mask.forward(encode_bg_output[ + "res_blocks"])["out_conv"] + fake_bg = self.middle( + paddle.concat( + (fake_c_temp, fake_bg_mask), axis=1)) + return { + "bg_encode_feature": encode_bg_output["res_blocks"], + "bg_decode_feature1": decode_bg_output["up1"], + "bg_decode_feature2": decode_bg_output["up2"], + "fake_bg": fake_bg, + "fake_bg_mask": fake_bg_mask, + } + + +class FusionGeneratorSimple(nn.Layer): + def __init__(self, config): + super(FusionGeneratorSimple, self).__init__() + name = config["module_name"] + encode_dim = config["encode_dim"] + norm_layer = config["norm_layer"] + conv_block_dropout = config["conv_block_dropout"] + conv_block_dilation = config["conv_block_dilation"] + if norm_layer == "InstanceNorm2D": + use_bias = True + else: + use_bias = False + + self._conv = nn.Conv2D( + in_channels=6, + out_channels=encode_dim, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=paddle.ParamAttr(name=name + "_conv_weights"), + bias_attr=False) + + self._res_block = ResBlock( + name="{}_conv_block".format(name), + channels=encode_dim, + norm_layer=norm_layer, + use_dropout=conv_block_dropout, + use_dilation=conv_block_dilation, + use_bias=use_bias) + + self._reduce_conv = nn.Conv2D( + in_channels=encode_dim, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + groups=1, + weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"), + bias_attr=False) + + def forward(self, fake_text, fake_bg): + fake_concat = paddle.concat((fake_text, fake_bg), axis=1) + fake_concat_tmp = self._conv(fake_concat) + output_res = self._res_block(fake_concat_tmp) + fake_fusion = self._reduce_conv(output_res) + return {"fake_fusion": fake_fusion} diff --git a/StyleText/configs/config.yml b/StyleText/configs/config.yml new file mode 100644 index 0000000..3b10b3d --- /dev/null +++ b/StyleText/configs/config.yml @@ -0,0 +1,54 @@ +Global: + output_num: 10 + output_dir: output_data + use_gpu: false + image_height: 32 + image_width: 320 +TextDrawer: + fonts: + en: fonts/en_standard.ttf + ch: fonts/ch_standard.ttf + ko: fonts/ko_standard.ttf +Predictor: + method: StyleTextRecPredictor + algorithm: StyleTextRec + scale: 0.00392156862745098 + mean: + - 0.5 + - 0.5 + - 0.5 + std: + - 0.5 + - 0.5 + - 0.5 + expand_result: false + bg_generator: + pretrain: style_text_models/bg_generator + module_name: bg_generator + generator_type: BgGeneratorWithMask + encode_dim: 64 + norm_layer: null + conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true + output_factor: 1.05 + text_generator: + pretrain: style_text_models/text_generator + module_name: text_generator + generator_type: TextGenerator + encode_dim: 64 + norm_layer: InstanceNorm2D + conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true + fusion_generator: + pretrain: style_text_models/fusion_generator + module_name: fusion_generator + generator_type: FusionGeneratorSimple + encode_dim: 64 + norm_layer: null 
+ conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true +Writer: + method: SimpleWriter diff --git a/StyleText/configs/dataset_config.yml b/StyleText/configs/dataset_config.yml new file mode 100644 index 0000000..aa4ec69 --- /dev/null +++ b/StyleText/configs/dataset_config.yml @@ -0,0 +1,64 @@ +Global: + output_num: 10 + output_dir: output_data + use_gpu: false + image_height: 32 + image_width: 320 + standard_font: fonts/en_standard.ttf +TextDrawer: + fonts: + en: fonts/en_standard.ttf + ch: fonts/ch_standard.ttf + ko: fonts/ko_standard.ttf +StyleSampler: + method: DatasetSampler + image_home: examples + label_file: examples/image_list.txt + with_label: true +CorpusGenerator: + method: FileCorpus + language: ch + corpus_file: examples/corpus/example.txt +Predictor: + method: StyleTextRecPredictor + algorithm: StyleTextRec + scale: 0.00392156862745098 + mean: + - 0.5 + - 0.5 + - 0.5 + std: + - 0.5 + - 0.5 + - 0.5 + expand_result: false + bg_generator: + pretrain: style_text_models/bg_generator + module_name: bg_generator + generator_type: BgGeneratorWithMask + encode_dim: 64 + norm_layer: null + conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true + output_factor: 1.05 + text_generator: + pretrain: style_text_models/text_generator + module_name: text_generator + generator_type: TextGenerator + encode_dim: 64 + norm_layer: InstanceNorm2D + conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true + fusion_generator: + pretrain: style_text_models/fusion_generator + module_name: fusion_generator + generator_type: FusionGeneratorSimple + encode_dim: 64 + norm_layer: null + conv_block_num: 4 + conv_block_dropout: false + conv_block_dilation: true +Writer: + method: SimpleWriter diff --git a/StyleText/doc/images/1.png b/StyleText/doc/images/1.png new file mode 100644 index 0000000..8f7574b Binary files /dev/null and b/StyleText/doc/images/1.png differ diff --git a/StyleText/doc/images/10.png b/StyleText/doc/images/10.png new file mode 100644 index 0000000..6123cff Binary files /dev/null and b/StyleText/doc/images/10.png differ diff --git a/StyleText/doc/images/11.png b/StyleText/doc/images/11.png new file mode 100644 index 0000000..ebfa093 Binary files /dev/null and b/StyleText/doc/images/11.png differ diff --git a/StyleText/doc/images/12.png b/StyleText/doc/images/12.png new file mode 100644 index 0000000..74ba4a0 Binary files /dev/null and b/StyleText/doc/images/12.png differ diff --git a/StyleText/doc/images/2.png b/StyleText/doc/images/2.png new file mode 100644 index 0000000..ce9bf47 Binary files /dev/null and b/StyleText/doc/images/2.png differ diff --git a/StyleText/doc/images/3.png b/StyleText/doc/images/3.png new file mode 100644 index 0000000..0fb73a3 Binary files /dev/null and b/StyleText/doc/images/3.png differ diff --git a/StyleText/doc/images/4.jpg b/StyleText/doc/images/4.jpg new file mode 100644 index 0000000..d881074 Binary files /dev/null and b/StyleText/doc/images/4.jpg differ diff --git a/StyleText/doc/images/5.png b/StyleText/doc/images/5.png new file mode 100644 index 0000000..b7d28b7 Binary files /dev/null and b/StyleText/doc/images/5.png differ diff --git a/StyleText/doc/images/6.png b/StyleText/doc/images/6.png new file mode 100644 index 0000000..75af727 Binary files /dev/null and b/StyleText/doc/images/6.png differ diff --git a/StyleText/doc/images/7.jpg b/StyleText/doc/images/7.jpg new file mode 100644 index 0000000..887094f Binary files /dev/null and b/StyleText/doc/images/7.jpg differ diff --git 
a/StyleText/doc/images/8.jpg b/StyleText/doc/images/8.jpg new file mode 100644 index 0000000..234d7f3 Binary files /dev/null and b/StyleText/doc/images/8.jpg differ diff --git a/StyleText/doc/images/9.png b/StyleText/doc/images/9.png new file mode 100644 index 0000000..1797802 Binary files /dev/null and b/StyleText/doc/images/9.png differ diff --git a/StyleText/engine/__init__.py b/StyleText/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/StyleText/engine/corpus_generators.py b/StyleText/engine/corpus_generators.py new file mode 100644 index 0000000..186d15f --- /dev/null +++ b/StyleText/engine/corpus_generators.py @@ -0,0 +1,66 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random + +from utils.logging import get_logger + + +class FileCorpus(object): + def __init__(self, config): + self.logger = get_logger() + self.logger.info("using FileCorpus") + + self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + + corpus_file = config["CorpusGenerator"]["corpus_file"] + self.language = config["CorpusGenerator"]["language"] + with open(corpus_file, 'r') as f: + corpus_raw = f.read() + self.corpus_list = corpus_raw.split("\n")[:-1] + assert len(self.corpus_list) > 0 + random.shuffle(self.corpus_list) + self.index = 0 + + def generate(self, corpus_length=0): + if self.index >= len(self.corpus_list): + self.index = 0 + random.shuffle(self.corpus_list) + corpus = self.corpus_list[self.index] + if corpus_length != 0: + corpus = corpus[0:corpus_length] + if corpus_length > len(corpus): + self.logger.warning("generated corpus is shorter than expected.") + self.index += 1 + return self.language, corpus + + +class EnNumCorpus(object): + def __init__(self, config): + self.logger = get_logger() + self.logger.info("using NumberCorpus") + self.num_list = "0123456789" + self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + self.height = config["Global"]["image_height"] + self.max_width = config["Global"]["image_width"] + + def generate(self, corpus_length=0): + corpus = "" + if corpus_length == 0: + corpus_length = random.randint(5, 15) + for i in range(corpus_length): + if random.random() < 0.2: + corpus += "{}".format(random.choice(self.en_char_list)) + else: + corpus += "{}".format(random.choice(self.num_list)) + return "en", corpus diff --git a/StyleText/engine/predictors.py b/StyleText/engine/predictors.py new file mode 100644 index 0000000..ca9ab9c --- /dev/null +++ b/StyleText/engine/predictors.py @@ -0,0 +1,139 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import cv2 +import math +import paddle + +from arch import style_text_rec +from utils.sys_funcs import check_gpu +from utils.logging import get_logger + + +class StyleTextRecPredictor(object): + def __init__(self, config): + algorithm = config['Predictor']['algorithm'] + assert algorithm in ["StyleTextRec" + ], "Generator {} not supported.".format(algorithm) + use_gpu = config["Global"]['use_gpu'] + check_gpu(use_gpu) + paddle.set_device('gpu' if use_gpu else 'cpu') + self.logger = get_logger() + self.generator = getattr(style_text_rec, algorithm)(config) + self.height = config["Global"]["image_height"] + self.width = config["Global"]["image_width"] + self.scale = config["Predictor"]["scale"] + self.mean = config["Predictor"]["mean"] + self.std = config["Predictor"]["std"] + self.expand_result = config["Predictor"]["expand_result"] + + def reshape_to_same_height(self, img_list): + h = img_list[0].shape[0] + for idx in range(1, len(img_list)): + new_w = round(1.0 * img_list[idx].shape[1] / + img_list[idx].shape[0] * h) + img_list[idx] = cv2.resize(img_list[idx], (new_w, h)) + return img_list + + def predict_single_image(self, style_input, text_input): + style_input = self.rep_style_input(style_input, text_input) + tensor_style_input = self.preprocess(style_input) + tensor_text_input = self.preprocess(text_input) + style_text_result = self.generator.forward(tensor_style_input, + tensor_text_input) + fake_fusion = self.postprocess(style_text_result["fake_fusion"]) + fake_text = self.postprocess(style_text_result["fake_text"]) + fake_sk = self.postprocess(style_text_result["fake_sk"]) + fake_bg = self.postprocess(style_text_result["fake_bg"]) + bbox = self.get_text_boundary(fake_text) + if bbox: + left, right, top, bottom = bbox + fake_fusion = fake_fusion[top:bottom, left:right, :] + fake_text = fake_text[top:bottom, left:right, :] + fake_sk = fake_sk[top:bottom, left:right, :] + fake_bg = fake_bg[top:bottom, left:right, :] + + # fake_fusion = self.crop_by_text(img_fake_fusion, img_fake_text) + return { + "fake_fusion": fake_fusion, + "fake_text": fake_text, + "fake_sk": fake_sk, + "fake_bg": fake_bg, + } + + def predict(self, style_input, text_input_list): + if not isinstance(text_input_list, (tuple, list)): + return self.predict_single_image(style_input, text_input_list) + + synth_result_list = [] + for text_input in text_input_list: + synth_result = self.predict_single_image(style_input, text_input) + synth_result_list.append(synth_result) + + for key in synth_result: + res = [r[key] for r in synth_result_list] + res = self.reshape_to_same_height(res) + synth_result[key] = np.concatenate(res, axis=1) + return synth_result + + def preprocess(self, img): + img = (img.astype('float32') * self.scale - self.mean) / self.std + img_height, img_width, channel = img.shape + assert channel == 3, "Please use an rgb image." 
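+        # Scale the image so its height equals the configured image_height while
+        # keeping the aspect ratio (width clamped to image_width), right-pad the
+        # canvas with zeros and convert HWC -> NCHW before wrapping it as a tensor.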
+ ratio = img_width / float(img_height) + if math.ceil(self.height * ratio) > self.width: + resized_w = self.width + else: + resized_w = int(math.ceil(self.height * ratio)) + img = cv2.resize(img, (resized_w, self.height)) + + new_img = np.zeros([self.height, self.width, 3]).astype('float32') + new_img[:, 0:resized_w, :] = img + img = new_img.transpose((2, 0, 1)) + img = img[np.newaxis, :, :, :] + return paddle.to_tensor(img) + + def postprocess(self, tensor): + img = tensor.numpy()[0] + img = img.transpose((1, 2, 0)) + img = (img * self.std + self.mean) / self.scale + img = np.maximum(img, 0.0) + img = np.minimum(img, 255.0) + img = img.astype('uint8') + return img + + def rep_style_input(self, style_input, text_input): + rep_num = int(1.2 * (text_input.shape[1] / text_input.shape[0]) / + (style_input.shape[1] / style_input.shape[0])) + 1 + style_input = np.tile(style_input, reps=[1, rep_num, 1]) + max_width = int(self.width / self.height * style_input.shape[0]) + style_input = style_input[:, :max_width, :] + return style_input + + def get_text_boundary(self, text_img): + img_height = text_img.shape[0] + img_width = text_img.shape[1] + bounder = 3 + text_canny_img = cv2.Canny(text_img, 10, 20) + edge_num_h = text_canny_img.sum(axis=0) + no_zero_list_h = np.where(edge_num_h > 0)[0] + edge_num_w = text_canny_img.sum(axis=1) + no_zero_list_w = np.where(edge_num_w > 0)[0] + if len(no_zero_list_h) == 0 or len(no_zero_list_w) == 0: + return None + left = max(no_zero_list_h[0] - bounder, 0) + right = min(no_zero_list_h[-1] + bounder, img_width) + top = max(no_zero_list_w[0] - bounder, 0) + bottom = min(no_zero_list_w[-1] + bounder, img_height) + return [left, right, top, bottom] diff --git a/StyleText/engine/style_samplers.py b/StyleText/engine/style_samplers.py new file mode 100644 index 0000000..e171d58 --- /dev/null +++ b/StyleText/engine/style_samplers.py @@ -0,0 +1,62 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
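+# DatasetSampler cycles through a shuffled label list (each entry is either an
+# image path or "path<TAB>label"), reads the style image from image_home and
+# rescales it to the configured image_height while preserving its aspect ratio.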
+import numpy as np +import random +import cv2 + + +class DatasetSampler(object): + def __init__(self, config): + self.image_home = config["StyleSampler"]["image_home"] + label_file = config["StyleSampler"]["label_file"] + self.dataset_with_label = config["StyleSampler"]["with_label"] + self.height = config["Global"]["image_height"] + self.index = 0 + with open(label_file, "r") as f: + label_raw = f.read() + self.path_label_list = label_raw.split("\n")[:-1] + assert len(self.path_label_list) > 0 + random.shuffle(self.path_label_list) + + def sample(self): + if self.index >= len(self.path_label_list): + random.shuffle(self.path_label_list) + self.index = 0 + if self.dataset_with_label: + path_label = self.path_label_list[self.index] + rel_image_path, label = path_label.split('\t') + else: + rel_image_path = self.path_label_list[self.index] + label = None + img_path = "{}/{}".format(self.image_home, rel_image_path) + image = cv2.imread(img_path) + origin_height = image.shape[0] + ratio = self.height / origin_height + width = int(image.shape[1] * ratio) + height = int(image.shape[0] * ratio) + image = cv2.resize(image, (width, height)) + + self.index += 1 + if label: + return {"image": image, "label": label} + else: + return {"image": image} + + +def duplicate_image(image, width): + image_width = image.shape[1] + dup_num = width // image_width + 1 + image = np.tile(image, reps=[1, dup_num, 1]) + cropped_image = image[:, :width, :] + return cropped_image diff --git a/StyleText/engine/synthesisers.py b/StyleText/engine/synthesisers.py new file mode 100644 index 0000000..6461d9e --- /dev/null +++ b/StyleText/engine/synthesisers.py @@ -0,0 +1,77 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
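+# ImageSynthesiser combines a text drawer with a style-transfer predictor to
+# synthesize one styled image per call; DatasetSynthesiser extends it with a
+# style sampler, corpus generator and writer to batch-generate a labelled dataset.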
+import os +import numpy as np +import cv2 + +from utils.config import ArgsParser, load_config, override_config +from utils.logging import get_logger +from engine import style_samplers, corpus_generators, text_drawers, predictors, writers + + +class ImageSynthesiser(object): + def __init__(self): + self.FLAGS = ArgsParser().parse_args() + self.config = load_config(self.FLAGS.config) + self.config = override_config(self.config, options=self.FLAGS.override) + self.output_dir = self.config["Global"]["output_dir"] + if not os.path.exists(self.output_dir): + os.mkdir(self.output_dir) + self.logger = get_logger( + log_file='{}/predict.log'.format(self.output_dir)) + + self.text_drawer = text_drawers.StdTextDrawer(self.config) + + predictor_method = self.config["Predictor"]["method"] + assert predictor_method is not None + self.predictor = getattr(predictors, predictor_method)(self.config) + + def synth_image(self, corpus, style_input, language="en"): + corpus_list, text_input_list = self.text_drawer.draw_text( + corpus, language, style_input_width=style_input.shape[1]) + synth_result = self.predictor.predict(style_input, text_input_list) + return synth_result + + +class DatasetSynthesiser(ImageSynthesiser): + def __init__(self): + super(DatasetSynthesiser, self).__init__() + self.tag = self.FLAGS.tag + self.output_num = self.config["Global"]["output_num"] + corpus_generator_method = self.config["CorpusGenerator"]["method"] + self.corpus_generator = getattr(corpus_generators, + corpus_generator_method)(self.config) + + style_sampler_method = self.config["StyleSampler"]["method"] + assert style_sampler_method is not None + self.style_sampler = style_samplers.DatasetSampler(self.config) + self.writer = writers.SimpleWriter(self.config, self.tag) + + def synth_dataset(self): + for i in range(self.output_num): + style_data = self.style_sampler.sample() + style_input = style_data["image"] + corpus_language, text_input_label = self.corpus_generator.generate() + text_input_label_list, text_input_list = self.text_drawer.draw_text( + text_input_label, + corpus_language, + style_input_width=style_input.shape[1]) + + text_input_label = "".join(text_input_label_list) + + synth_result = self.predictor.predict(style_input, text_input_list) + fake_fusion = synth_result["fake_fusion"] + self.writer.save_image(fake_fusion, text_input_label) + self.writer.save_label() + self.writer.merge_label() diff --git a/StyleText/engine/text_drawers.py b/StyleText/engine/text_drawers.py new file mode 100644 index 0000000..6ccc423 --- /dev/null +++ b/StyleText/engine/text_drawers.py @@ -0,0 +1,85 @@ +from PIL import Image, ImageDraw, ImageFont +import numpy as np +import cv2 +from utils.logging import get_logger + + +class StdTextDrawer(object): + def __init__(self, config): + self.logger = get_logger() + self.max_width = config["Global"]["image_width"] + self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + self.height = config["Global"]["image_height"] + self.font_dict = {} + self.load_fonts(config["TextDrawer"]["fonts"]) + self.support_languages = list(self.font_dict) + + def load_fonts(self, fonts_config): + for language in fonts_config: + font_path = fonts_config[language] + font_height = self.get_valid_height(font_path) + font = ImageFont.truetype(font_path, font_height) + self.font_dict[language] = font + + def get_valid_height(self, font_path): + font = ImageFont.truetype(font_path, self.height - 4) + left, top, right, bottom = font.getbbox(self.char_list) + _, font_height = right - left, 
bottom - top + if font_height <= self.height - 4: + return self.height - 4 + else: + return int((self.height - 4)**2 / font_height) + + def draw_text(self, + corpus, + language="en", + crop=True, + style_input_width=None): + if language not in self.support_languages: + self.logger.warning( + "language {} not supported, use en instead.".format(language)) + language = "en" + if crop: + width = min(self.max_width, len(corpus) * self.height) + 4 + else: + width = len(corpus) * self.height + 4 + + if style_input_width is not None: + width = min(width, style_input_width) + + corpus_list = [] + text_input_list = [] + + while len(corpus) != 0: + bg = Image.new("RGB", (width, self.height), color=(127, 127, 127)) + draw = ImageDraw.Draw(bg) + char_x = 2 + font = self.font_dict[language] + i = 0 + while i < len(corpus): + char_i = corpus[i] + char_size = font.getsize(char_i)[0] + # split when char_x exceeds char size and index is not 0 (at least 1 char should be wroten on the image) + if char_x + char_size >= width and i != 0: + text_input = np.array(bg).astype(np.uint8) + text_input = text_input[:, 0:char_x, :] + + corpus_list.append(corpus[0:i]) + text_input_list.append(text_input) + corpus = corpus[i:] + i = 0 + break + draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font) + char_x += char_size + + i += 1 + # the whole text is shorter than style input + if i == len(corpus): + text_input = np.array(bg).astype(np.uint8) + text_input = text_input[:, 0:char_x, :] + + corpus_list.append(corpus[0:i]) + text_input_list.append(text_input) + break + + return corpus_list, text_input_list diff --git a/StyleText/engine/writers.py b/StyleText/engine/writers.py new file mode 100644 index 0000000..0df75e7 --- /dev/null +++ b/StyleText/engine/writers.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
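+# SimpleWriter saves generated images under <output_dir>/images/<tag>/, records an
+# image-to-text mapping, flushes it to label/<tag>_label.txt every 100 images and
+# merges all per-tag label files into <output_dir>/label.txt via merge_label().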
+import os +import cv2 +import glob + +from utils.logging import get_logger + + +class SimpleWriter(object): + def __init__(self, config, tag): + self.logger = get_logger() + self.output_dir = config["Global"]["output_dir"] + self.counter = 0 + self.label_dict = {} + self.tag = tag + self.label_file_index = 0 + + def save_image(self, image, text_input_label): + image_home = os.path.join(self.output_dir, "images", self.tag) + if not os.path.exists(image_home): + os.makedirs(image_home) + + image_path = os.path.join(image_home, "{}.png".format(self.counter)) + # todo support continue synth + cv2.imwrite(image_path, image) + self.logger.info("generate image: {}".format(image_path)) + + image_name = os.path.join(self.tag, "{}.png".format(self.counter)) + self.label_dict[image_name] = text_input_label + + self.counter += 1 + if not self.counter % 100: + self.save_label() + + def save_label(self): + label_raw = "" + label_home = os.path.join(self.output_dir, "label") + if not os.path.exists(label_home): + os.mkdir(label_home) + for image_path in self.label_dict: + label = self.label_dict[image_path] + label_raw += "{}\t{}\n".format(image_path, label) + label_file_path = os.path.join(label_home, + "{}_label.txt".format(self.tag)) + with open(label_file_path, "w") as f: + f.write(label_raw) + self.label_file_index += 1 + + def merge_label(self): + label_raw = "" + label_file_regex = os.path.join(self.output_dir, "label", + "*_label.txt") + label_file_list = glob.glob(label_file_regex) + for label_file_i in label_file_list: + with open(label_file_i, "r") as f: + label_raw += f.read() + label_file_path = os.path.join(self.output_dir, "label.txt") + with open(label_file_path, "w") as f: + f.write(label_raw) diff --git a/StyleText/examples/corpus/example.txt b/StyleText/examples/corpus/example.txt new file mode 100644 index 0000000..93ba35a --- /dev/null +++ b/StyleText/examples/corpus/example.txt @@ -0,0 +1,2 @@ +Paddle +飞桨文字识别 diff --git a/StyleText/examples/image_list.txt b/StyleText/examples/image_list.txt new file mode 100644 index 0000000..b07be03 --- /dev/null +++ b/StyleText/examples/image_list.txt @@ -0,0 +1,2 @@ +style_images/1.jpg NEATNESS +style_images/2.jpg 锁店君和宾馆 diff --git a/StyleText/examples/style_images/1.jpg b/StyleText/examples/style_images/1.jpg new file mode 100644 index 0000000..4da7838 Binary files /dev/null and b/StyleText/examples/style_images/1.jpg differ diff --git a/StyleText/examples/style_images/2.jpg b/StyleText/examples/style_images/2.jpg new file mode 100644 index 0000000..0ab932b Binary files /dev/null and b/StyleText/examples/style_images/2.jpg differ diff --git a/StyleText/fonts/ch_standard.ttf b/StyleText/fonts/ch_standard.ttf new file mode 100644 index 0000000..cdb7fa5 Binary files /dev/null and b/StyleText/fonts/ch_standard.ttf differ diff --git a/StyleText/fonts/en_standard.ttf b/StyleText/fonts/en_standard.ttf new file mode 100644 index 0000000..2e31d02 Binary files /dev/null and b/StyleText/fonts/en_standard.ttf differ diff --git a/StyleText/fonts/ko_standard.ttf b/StyleText/fonts/ko_standard.ttf new file mode 100644 index 0000000..982bd87 Binary files /dev/null and b/StyleText/fonts/ko_standard.ttf differ diff --git a/StyleText/tools/__init__.py b/StyleText/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/StyleText/tools/synth_dataset.py b/StyleText/tools/synth_dataset.py new file mode 100644 index 0000000..a75f7f3 --- /dev/null +++ b/StyleText/tools/synth_dataset.py @@ -0,0 +1,31 @@ +# Copyright (c) 2020 PaddlePaddle Authors. 
All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +from engine.synthesisers import DatasetSynthesiser + + +def synth_dataset(): + dataset_synthesiser = DatasetSynthesiser() + dataset_synthesiser.synth_dataset() + + +if __name__ == '__main__': + synth_dataset() diff --git a/StyleText/tools/synth_image.py b/StyleText/tools/synth_image.py new file mode 100644 index 0000000..cbc3118 --- /dev/null +++ b/StyleText/tools/synth_image.py @@ -0,0 +1,82 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import cv2 +import sys +import glob + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +from utils.config import ArgsParser +from engine.synthesisers import ImageSynthesiser + + +def synth_image(): + args = ArgsParser().parse_args() + image_synthesiser = ImageSynthesiser() + style_image_path = args.style_image + img = cv2.imread(style_image_path) + text_corpus = args.text_corpus + language = args.language + + synth_result = image_synthesiser.synth_image(text_corpus, img, language) + fake_fusion = synth_result["fake_fusion"] + fake_text = synth_result["fake_text"] + fake_bg = synth_result["fake_bg"] + cv2.imwrite("fake_fusion.jpg", fake_fusion) + cv2.imwrite("fake_text.jpg", fake_text) + cv2.imwrite("fake_bg.jpg", fake_bg) + + +def batch_synth_images(): + image_synthesiser = ImageSynthesiser() + + corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt" + style_data_dir = "../StyleTextRec_data/test_20201208/style_images/" + save_path = "./output_data/" + corpus_list = [] + with open(corpus_file, "rb") as fin: + lines = fin.readlines() + for line in lines: + substr = line.decode("utf-8").strip("\n").split("\t") + corpus_list.append(substr) + style_img_list = glob.glob("{}/*.jpg".format(style_data_dir)) + corpus_num = len(corpus_list) + style_img_num = len(style_img_list) + for cno in range(corpus_num): + for sno in range(style_img_num): + corpus, lang = corpus_list[cno] + style_img_path = style_img_list[sno] + img = cv2.imread(style_img_path) + synth_result = image_synthesiser.synth_image(corpus, img, lang) + fake_fusion = synth_result["fake_fusion"] + fake_text = synth_result["fake_text"] + fake_bg = synth_result["fake_bg"] + for tp in 
range(2): + if tp == 0: + prefix = "%s/c%d_s%d_" % (save_path, cno, sno) + else: + prefix = "%s/s%d_c%d_" % (save_path, sno, cno) + cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion) + cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text) + cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg) + cv2.imwrite("%s_input_style.jpg" % prefix, img) + print(cno, corpus_num, sno, style_img_num) + + +if __name__ == '__main__': + # batch_synth_images() + synth_image() diff --git a/StyleText/utils/__init__.py b/StyleText/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/StyleText/utils/config.py b/StyleText/utils/config.py new file mode 100644 index 0000000..b2f8a61 --- /dev/null +++ b/StyleText/utils/config.py @@ -0,0 +1,224 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import yaml +import os +from argparse import ArgumentParser, RawDescriptionHelpFormatter + + +def override(dl, ks, v): + """ + Recursively replace dict of list + + Args: + dl(dict or list): dict or list to be replaced + ks(list): list of keys + v(str): value to be replaced + """ + + def str2num(v): + try: + return eval(v) + except Exception: + return v + + assert isinstance(dl, (list, dict)), ("{} should be a list or a dict") + assert len(ks) > 0, ('lenght of keys should larger than 0') + if isinstance(dl, list): + k = str2num(ks[0]) + if len(ks) == 1: + assert k < len(dl), ('index({}) out of range({})'.format(k, dl)) + dl[k] = str2num(v) + else: + override(dl[k], ks[1:], v) + else: + if len(ks) == 1: + #assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl)) + if not ks[0] in dl: + logger.warning('A new filed ({}) detected!'.format(ks[0], dl)) + dl[ks[0]] = str2num(v) + else: + assert ks[0] in dl, ( + '({}) doesn\'t exist in {}, a new dict field is invalid'. 
+ format(ks[0], dl)) + override(dl[ks[0]], ks[1:], v) + + +def override_config(config, options=None): + """ + Recursively override the config + + Args: + config(dict): dict to be replaced + options(list): list of pairs(key0.key1.idx.key2=value) + such as: [ + 'topk=2', + 'VALID.transforms.1.ResizeImage.resize_short=300' + ] + + Returns: + config(dict): replaced config + """ + if options is not None: + for opt in options: + assert isinstance(opt, str), ( + "option({}) should be a str".format(opt)) + assert "=" in opt, ( + "option({}) should contain a =" + "to distinguish between key and value".format(opt)) + pair = opt.split('=') + assert len(pair) == 2, ("there can be only a = in the option") + key, value = pair + keys = key.split('.') + override(config, keys, value) + + return config + + +class ArgsParser(ArgumentParser): + def __init__(self): + super(ArgsParser, self).__init__( + formatter_class=RawDescriptionHelpFormatter) + self.add_argument("-c", "--config", help="configuration file to use") + self.add_argument( + "-t", "--tag", default="0", help="tag for marking worker") + self.add_argument( + '-o', + '--override', + action='append', + default=[], + help='config options to be overridden') + self.add_argument( + "--style_image", default="examples/style_images/1.jpg", help="tag for marking worker") + self.add_argument( + "--text_corpus", default="PaddleOCR", help="tag for marking worker") + self.add_argument( + "--language", default="en", help="tag for marking worker") + + def parse_args(self, argv=None): + args = super(ArgsParser, self).parse_args(argv) + assert args.config is not None, \ + "Please specify --config=configure_file_path." + return args + + +def load_config(file_path): + """ + Load config from yml/yaml file. + Args: + file_path (str): Path of the config file to be loaded. 
+ Returns: config + """ + ext = os.path.splitext(file_path)[1] + assert ext in ['.yml', '.yaml'], "only support yaml files for now" + with open(file_path, 'rb') as f: + config = yaml.load(f, Loader=yaml.Loader) + + return config + + +def gen_config(): + base_config = { + "Global": { + "algorithm": "SRNet", + "use_gpu": True, + "start_epoch": 1, + "stage1_epoch_num": 100, + "stage2_epoch_num": 100, + "log_smooth_window": 20, + "print_batch_step": 2, + "save_model_dir": "./output/SRNet", + "use_visualdl": False, + "save_epoch_step": 10, + "vgg_pretrain": "./pretrained/VGG19_pretrained", + "vgg_load_static_pretrain": True + }, + "Architecture": { + "model_type": "data_aug", + "algorithm": "SRNet", + "net_g": { + "name": "srnet_net_g", + "encode_dim": 64, + "norm": "batch", + "use_dropout": False, + "init_type": "xavier", + "init_gain": 0.02, + "use_dilation": 1 + }, + # input_nc, ndf, netD, + # n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0' + "bg_discriminator": { + "name": "srnet_bg_discriminator", + "input_nc": 6, + "ndf": 64, + "netD": "basic", + "norm": "none", + "init_type": "xavier", + }, + "fusion_discriminator": { + "name": "srnet_fusion_discriminator", + "input_nc": 6, + "ndf": 64, + "netD": "basic", + "norm": "none", + "init_type": "xavier", + } + }, + "Loss": { + "lamb": 10, + "perceptual_lamb": 1, + "muvar_lamb": 50, + "style_lamb": 500 + }, + "Optimizer": { + "name": "Adam", + "learning_rate": { + "name": "lambda", + "lr": 0.0002, + "lr_decay_iters": 50 + }, + "beta1": 0.5, + "beta2": 0.999, + }, + "Train": { + "batch_size_per_card": 8, + "num_workers_per_card": 4, + "dataset": { + "delimiter": "\t", + "data_dir": "/", + "label_file": "tmp/label.txt", + "transforms": [{ + "DecodeImage": { + "to_rgb": True, + "to_np": False, + "channel_first": False + } + }, { + "NormalizeImage": { + "scale": 1. / 255., + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "order": None + } + }, { + "ToCHWImage": None + }] + } + } + } + with open("config.yml", "w") as f: + yaml.dump(base_config, f) + + +if __name__ == '__main__': + gen_config() diff --git a/StyleText/utils/load_params.py b/StyleText/utils/load_params.py new file mode 100644 index 0000000..be05613 --- /dev/null +++ b/StyleText/utils/load_params.py @@ -0,0 +1,27 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
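+# load_dygraph_pretrain loads "<path>.pdparams" into the given model, logging the
+# source path and raising ValueError when the parameter file is missing.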
+import os +import paddle + +__all__ = ['load_dygraph_pretrain'] + + +def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): + if not os.path.exists(path + '.pdparams'): + raise ValueError("Model pretrain path {} does not " + "exists.".format(path)) + param_state_dict = paddle.load(path + '.pdparams') + model.set_state_dict(param_state_dict) + logger.info("load pretrained model from {}".format(path)) + return diff --git a/StyleText/utils/logging.py b/StyleText/utils/logging.py new file mode 100644 index 0000000..f700fe2 --- /dev/null +++ b/StyleText/utils/logging.py @@ -0,0 +1,65 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import logging +import functools +import paddle.distributed as dist + +logger_initialized = {} + + +@functools.lru_cache() +def get_logger(name='srnet', log_file=None, log_level=logging.INFO): + """Initialize and get a logger by name. + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified a FileHandler will also be added. + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + formatter = logging.Formatter( + '[%(asctime)s] %(name)s %(levelname)s: %(message)s', + datefmt="%Y/%m/%d %H:%M:%S") + + stream_handler = logging.StreamHandler(stream=sys.stdout) + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + if log_file is not None and dist.get_rank() == 0: + log_file_folder = os.path.split(log_file)[0] + os.makedirs(log_file_folder, exist_ok=True) + file_handler = logging.FileHandler(log_file, 'a') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + if dist.get_rank() == 0: + logger.setLevel(log_level) + else: + logger.setLevel(logging.ERROR) + logger_initialized[name] = True + return logger diff --git a/StyleText/utils/math_functions.py b/StyleText/utils/math_functions.py new file mode 100644 index 0000000..3dc8d91 --- /dev/null +++ b/StyleText/utils/math_functions.py @@ -0,0 +1,45 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle + + +def compute_mean_covariance(img): + batch_size = img.shape[0] + channel_num = img.shape[1] + height = img.shape[2] + width = img.shape[3] + num_pixels = height * width + + # batch_size * channel_num * 1 * 1 + mu = img.mean(2, keepdim=True).mean(3, keepdim=True) + + # batch_size * channel_num * num_pixels + img_hat = img - mu.expand_as(img) + img_hat = img_hat.reshape([batch_size, channel_num, num_pixels]) + # batch_size * num_pixels * channel_num + img_hat_transpose = img_hat.transpose([0, 2, 1]) + # batch_size * channel_num * channel_num + covariance = paddle.bmm(img_hat, img_hat_transpose) + covariance = covariance / num_pixels + + return mu, covariance + + +def dice_coefficient(y_true_cls, y_pred_cls, training_mask): + eps = 1e-5 + intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask) + union = paddle.sum(y_true_cls * training_mask) + paddle.sum( + y_pred_cls * training_mask) + eps + loss = 1. - (2 * intersection / union) + return loss diff --git a/StyleText/utils/sys_funcs.py b/StyleText/utils/sys_funcs.py new file mode 100644 index 0000000..203d91d --- /dev/null +++ b/StyleText/utils/sys_funcs.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import os +import errno +import paddle + + +def get_check_global_params(mode): + check_params = [ + 'use_gpu', 'max_text_length', 'image_shape', 'image_shape', + 'character_type', 'loss_type' + ] + if mode == "train_eval": + check_params = check_params + [ + 'train_batch_size_per_card', 'test_batch_size_per_card' + ] + elif mode == "test": + check_params = check_params + ['test_batch_size_per_card'] + return check_params + + +def check_gpu(use_gpu): + """ + Log error and exit when set use_gpu=true in paddlepaddle + cpu version. + """ + err = "Config use_gpu cannot be set as true while you are " \ + "using paddlepaddle cpu version ! \nPlease try: \n" \ + "\t1. Install paddlepaddle-gpu to run model on GPU \n" \ + "\t2. 
Set use_gpu as false in config file to run " \ + "model on CPU" + if use_gpu: + try: + if not paddle.is_compiled_with_cuda(): + print(err) + sys.exit(1) + except: + print("Fail to check gpu state.") + sys.exit(1) + + +def _mkdir_if_not_exist(path, logger): + """ + mkdir if not exists, ignore the exception when multiprocess mkdir together + """ + if not os.path.exists(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno == errno.EEXIST and os.path.isdir(path): + logger.warning( + 'be happy if some process has already created {}'.format( + path)) + else: + raise OSError('Failed to mkdir {}'.format(path)) diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..a7c32e9 --- /dev/null +++ b/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .paddleocr import * + +__version__ = paddleocr.VERSION +__all__ = [ + 'PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', + 'save_structure_res', 'download_with_progressbar', 'sorted_layout_boxes', + 'convert_info_docx', 'to_excel' +] diff --git "a/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/PCB\345\255\227\347\254\246\350\257\206\345\210\253.md" "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/PCB\345\255\227\347\254\246\350\257\206\345\210\253.md" new file mode 100644 index 0000000..c695e82 --- /dev/null +++ "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/PCB\345\255\227\347\254\246\350\257\206\345\210\253.md" @@ -0,0 +1,652 @@ +# 基于PP-OCRv3的PCB字符识别 + +- [1. 项目介绍](#1-项目介绍) +- [2. 安装说明](#2-安装说明) +- [3. 数据准备](#3-数据准备) +- [4. 文本检测](#4-文本检测) + - [4.1 预训练模型直接评估](#41-预训练模型直接评估) + - [4.2 预训练模型+验证集padding直接评估](#42-预训练模型验证集padding直接评估) + - [4.3 预训练模型+fine-tune](#43-预训练模型fine-tune) +- [5. 文本识别](#5-文本识别) + - [5.1 预训练模型直接评估](#51-预训练模型直接评估) + - [5.2 三种fine-tune方案](#52-三种fine-tune方案) +- [6. 模型导出](#6-模型导出) +- [7. 端对端评测](#7-端对端评测) +- [8. Jetson部署](#8-Jetson部署) +- [9. 总结](#9-总结) +- [更多资源](#更多资源) + +# 1. 项目介绍 + +印刷电路板(PCB)是电子产品中的核心器件,对于板件质量的测试与监控是生产中必不可少的环节。在一些场景中,通过PCB中信号灯颜色和文字组合可以定位PCB局部模块质量问题,PCB文字识别中存在如下难点: + +- 裁剪出的PCB图片宽高比例较小 +- 文字区域整体面积也较小 +- 包含垂直、水平多种方向文本 + +针对本场景,PaddleOCR基于全新的PP-OCRv3通过合成数据、微调以及其他场景适配方法完成小字符文本识别任务,满足企业上线要求。PCB检测、识别效果如 **图1** 所示: + +
+
图1 PCB检测识别效果
+ +注:欢迎在AIStudio领取免费算力体验线上实训,项目链接: [基于PP-OCRv3实现PCB字符识别](https://aistudio.baidu.com/aistudio/projectdetail/4008973) + +# 2. 安装说明 + + +下载PaddleOCR源码,安装依赖环境。 + + +```python +# 如仍需安装or安装更新,可以执行以下步骤 +git clone https://github.com/PaddlePaddle/PaddleOCR.git +# git clone https://gitee.com/PaddlePaddle/PaddleOCR +``` + + +```python +# 安装依赖包 +pip install -r /home/aistudio/PaddleOCR/requirements.txt +``` + +# 3. 数据准备 + +我们通过图片合成工具生成 **图2** 所示的PCB图片,整图只有高25、宽150左右、文字区域高9、宽45左右,包含垂直和水平2种方向的文本: + +
+
图2 数据集示例
+ +暂时不开源生成的PCB数据集,但是通过更换背景,通过如下代码生成数据即可: + +``` +cd gen_data +python3 gen.py --num_img=10 +``` + +生成图片参数解释: + +``` +num_img:生成图片数量 +font_min_size、font_max_size:字体最大、最小尺寸 +bg_path:文字区域背景存放路径 +det_bg_path:整图背景存放路径 +fonts_path:字体路径 +corpus_path:语料路径 +output_dir:生成图片存储路径 +``` + +这里生成 **100张** 相同尺寸和文本的图片,如 **图3** 所示,方便大家跑通实验。通过如下代码解压数据集: + +
+
图3 案例提供数据集示例
+ + +```python +tar xf ./data/data148165/dataset.tar -C ./ +``` + +在生成数据集的时需要生成检测和识别训练需求的格式: + + +- **文本检测** + +标注文件格式如下,中间用'\t'分隔: + +``` +" 图像文件名 json.dumps编码的图像标注信息" +ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}] +``` + +json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。 `transcription` 表示当前文本框的文字,***当其内容为“###”时,表示该文本框无效,在训练时会跳过。*** + +- **文本识别** + +标注文件的格式如下, txt文件中默认请将图片路径和图片标签用'\t'分割,如用其他方式分割将造成训练报错。 + +``` +" 图像文件名 图像标注信息 " + +train_data/rec/train/word_001.jpg 简单可依赖 +train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单 +... +``` + + +# 4. 文本检测 + +选用飞桨OCR开发套件[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)中的PP-OCRv3模型进行文本检测和识别。针对检测模型和识别模型,进行了共计9个方面的升级: + +- PP-OCRv3检测模型对PP-OCRv2中的CML协同互学习文本检测蒸馏策略进行了升级,分别针对教师模型和学生模型进行进一步效果优化。其中,在对教师模型优化时,提出了大感受野的PAN结构LK-PAN和引入了DML蒸馏策略;在对学生模型优化时,提出了残差注意力机制的FPN结构RSE-FPN。 + +- PP-OCRv3的识别模块是基于文本识别算法SVTR优化。SVTR不再采用RNN结构,通过引入Transformers结构更加有效地挖掘文本行图像的上下文信息,从而提升文本识别能力。PP-OCRv3通过轻量级文本识别网络SVTR_LCNet、Attention损失指导CTC损失训练策略、挖掘文字上下文信息的数据增广策略TextConAug、TextRotNet自监督预训练模型、UDML联合互学习策略、UIM无标注数据挖掘方案,6个方面进行模型加速和效果提升。 + +更多细节请参考PP-OCRv3[技术报告](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/PP-OCRv3_introduction.md)。 + + +我们使用 **3种方案** 进行检测模型的训练、评估: +- **PP-OCRv3英文超轻量检测预训练模型直接评估** +- PP-OCRv3英文超轻量检测预训练模型 + **验证集padding**直接评估 +- PP-OCRv3英文超轻量检测预训练模型 + **fine-tune** + +## **4.1 预训练模型直接评估** + +我们首先通过PaddleOCR提供的预训练模型在验证集上进行评估,如果评估指标能满足效果,可以直接使用预训练模型,不再需要训练。 + +使用预训练模型直接评估步骤如下: + +**1)下载预训练模型** + + +PaddleOCR已经提供了PP-OCR系列模型,部分模型展示如下表所示: + +| 模型简介 | 模型名称 | 推荐场景 | 检测模型 | 方向分类器 | 识别模型 | +| ------------------------------------- | ----------------------- | --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | +| 中英文超轻量PP-OCRv3模型(16.2M) | ch_PP-OCRv3_xx | 移动端&服务器端 | [推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar) | +| 英文超轻量PP-OCRv3模型(13.4M) | en_PP-OCRv3_xx | 移动端&服务器端 | [推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_distill_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_train.tar) | +| 中英文超轻量PP-OCRv2模型(13.0M) | ch_PP-OCRv2_xx | 移动端&服务器端 | [推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) | 
[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar) | +| 中英文超轻量PP-OCR mobile模型(9.4M) | ch_ppocr_mobile_v2.0_xx | 移动端&服务器端 | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_pre.tar) | +| 中英文通用PP-OCR server模型(143.4M) | ch_ppocr_server_v2.0_xx | 服务器端 | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_pre.tar) | + +更多模型下载(包括多语言),可以参[考PP-OCR系列模型下载](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/models_list.md) + +这里我们使用PP-OCRv3英文超轻量检测模型,下载并解压预训练模型: + + + + +```python +# 如果更换其他模型,更新下载链接和解压指令就可以 +cd /home/aistudio/PaddleOCR +mkdir pretrain_models +cd pretrain_models +# 下载英文预训练模型 +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_distill_train.tar +tar xf en_PP-OCRv3_det_distill_train.tar && rm -rf en_PP-OCRv3_det_distill_train.tar +%cd .. +``` + +**模型评估** + + +首先修改配置文件`configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml`中的以下字段: +``` +Eval.dataset.data_dir:指向验证集图片存放目录,'/home/aistudio/dataset' +Eval.dataset.label_file_list:指向验证集标注文件,'/home/aistudio/dataset/det_gt_val.txt' +Eval.dataset.transforms.DetResizeForTest: 尺寸 + limit_side_len: 48 + limit_type: 'min' +``` + +然后在验证集上进行评估,具体代码如下: + + + +```python +cd /home/aistudio/PaddleOCR +python tools/eval.py \ + -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml \ + -o Global.checkpoints="./pretrain_models/en_PP-OCRv3_det_distill_train/best_accuracy" +``` + +## **4.2 预训练模型+验证集padding直接评估** + +考虑到PCB图片比较小,宽度只有25左右、高度只有140-170左右,我们在原图的基础上进行padding,再进行检测评估,padding前后效果对比如 **图4** 所示: + +
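+Figure 4 below compares a sample before and after padding. As a complement, the sketch below shows one way the padding and the corresponding label shift could be done. It is not the official preprocessing script: the helper name `pad_to_canvas`, the 300x300 canvas, the gray fill value and the centered placement are assumptions made for illustration.
+
+```python
+# Hypothetical helper (not part of PaddleOCR): pad a small PCB crop onto a fixed
+# canvas and shift its detection points by the same offset.
+import cv2
+import numpy as np
+
+
+def pad_to_canvas(img, boxes, canvas=300, fill=127):
+    h, w = img.shape[:2]
+    top, left = (canvas - h) // 2, (canvas - w) // 2
+    padded = cv2.copyMakeBorder(img, top, canvas - h - top, left, canvas - w - left,
+                                cv2.BORDER_CONSTANT, value=(fill, fill, fill))
+    # every annotated point moves by the same (left, top) offset
+    shifted = [{"transcription": b["transcription"],
+                "points": [[x + left, y + top] for x, y in b["points"]]}
+               for b in boxes]
+    return padded, shifted
+
+
+img = np.full((25, 150, 3), 200, dtype=np.uint8)  # stand-in for cv2.imread(...)
+boxes = [{"transcription": "E295", "points": [[8, 3], [47, 3], [47, 10], [8, 10]]}]
+padded, shifted = pad_to_canvas(img, boxes)
+print(padded.shape, shifted[0]["points"])
+```
+
+The shifted labels can then be written back in the same tab-separated detection label format described in the data preparation section, which is also the format expected for the `det_gt_padding_val.txt` file referenced below.
+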
+
图4 padding前后对比图
+ +将图片都padding到300*300大小,因为坐标信息发生了变化,我们同时要修改标注文件,在`/home/aistudio/dataset`目录里也提供了padding之后的图片,大家也可以尝试训练和评估: + +同上,我们需要修改配置文件`configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml`中的以下字段: +``` +Eval.dataset.data_dir:指向验证集图片存放目录,'/home/aistudio/dataset' +Eval.dataset.label_file_list:指向验证集标注文件,/home/aistudio/dataset/det_gt_padding_val.txt +Eval.dataset.transforms.DetResizeForTest: 尺寸 + limit_side_len: 1100 + limit_type: 'min' +``` + +如需获取已训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 +
+ +
+将下载或训练完成的模型放置在对应目录下即可完成模型评估。 + + +```python +cd /home/aistudio/PaddleOCR +python tools/eval.py \ + -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml \ + -o Global.checkpoints="./pretrain_models/en_PP-OCRv3_det_distill_train/best_accuracy" +``` + +## **4.3 预训练模型+fine-tune** + + +基于预训练模型,在生成的1500图片上进行fine-tune训练和评估,其中train数据1200张,val数据300张,修改配置文件`configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml`中的以下字段: +``` +Global.epoch_num: 这里设置为1,方便快速跑通,实际中根据数据量调整该值 +Global.save_model_dir:模型保存路径 +Global.pretrained_model:指向预训练模型路径,'./pretrain_models/en_PP-OCRv3_det_distill_train/student.pdparams' +Optimizer.lr.learning_rate:调整学习率,本实验设置为0.0005 +Train.dataset.data_dir:指向训练集图片存放目录,'/home/aistudio/dataset' +Train.dataset.label_file_list:指向训练集标注文件,'/home/aistudio/dataset/det_gt_train.txt' +Train.dataset.transforms.EastRandomCropData.size:训练尺寸改为[480,64] +Eval.dataset.data_dir:指向验证集图片存放目录,'/home/aistudio/dataset/' +Eval.dataset.label_file_list:指向验证集标注文件,'/home/aistudio/dataset/det_gt_val.txt' +Eval.dataset.transforms.DetResizeForTest:评估尺寸,添加如下参数 + limit_side_len: 64 + limit_type:'min' +``` +执行下面命令启动训练: + + +```python +cd /home/aistudio/PaddleOCR/ +python tools/train.py \ + -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml +``` + +**模型评估** + + +使用训练好的模型进行评估,更新模型路径`Global.checkpoints`: + + +```python +cd /home/aistudio/PaddleOCR/ +python3 tools/eval.py \ + -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml \ + -o Global.checkpoints="./output/ch_PP-OCR_V3_det/latest" +``` + +使用训练好的模型进行评估,指标如下所示: + + +| 序号 | 方案 | hmean | 效果提升 | 实验分析 | +| -------- | -------- | -------- | -------- | -------- | +| 1 | PP-OCRv3英文超轻量检测预训练模型 | 64.64% | - | 提供的预训练模型具有泛化能力 | +| 2 | PP-OCRv3英文超轻量检测预训练模型 + 验证集padding | 72.13% |+7.49% | padding可以提升尺寸较小图片的检测效果| +| 3 | PP-OCRv3英文超轻量检测预训练模型 + fine-tune | 100.00% | +27.87% | fine-tune会提升垂类场景效果 | + + +``` +注:上述实验结果均是在1500张图片(1200张训练集,300张测试集)上训练、评估的得到,AIstudio只提供了100张数据,所以指标有所差异属于正常,只要策略有效、规律相同即可。 +``` + +# 5. 文本识别 + +我们分别使用如下4种方案进行训练、评估: + +- **方案1**:**PP-OCRv3中英文超轻量识别预训练模型直接评估** +- **方案2**:PP-OCRv3中英文超轻量检测预训练模型 + **fine-tune** +- **方案3**:PP-OCRv3中英文超轻量检测预训练模型 + fine-tune + **公开通用识别数据集** +- **方案4**:PP-OCRv3中英文超轻量检测预训练模型 + fine-tune + **增加PCB图像数量** + + +## **5.1 预训练模型直接评估** + +同检测模型,我们首先使用PaddleOCR提供的识别预训练模型在PCB验证集上进行评估。 + +使用预训练模型直接评估步骤如下: + +**1)下载预训练模型** + + +我们使用PP-OCRv3中英文超轻量文本识别模型,下载并解压预训练模型: + + +```python +# 如果更换其他模型,更新下载链接和解压指令就可以 +cd /home/aistudio/PaddleOCR/pretrain_models/ +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar +tar xf ch_PP-OCRv3_rec_train.tar && rm -rf ch_PP-OCRv3_rec_train.tar +cd .. 
+``` + +**模型评估** + + +首先修改配置文件`configs/det/ch_PP-OCRv3/ch_PP-OCRv2_rec_distillation.yml`中的以下字段: + +``` +Metric.ignore_space: True:忽略空格 +Eval.dataset.data_dir:指向验证集图片存放目录,'/home/aistudio/dataset' +Eval.dataset.label_file_list:指向验证集标注文件,'/home/aistudio/dataset/rec_gt_val.txt' +``` + +我们使用下载的预训练模型进行评估: + + +```python +cd /home/aistudio/PaddleOCR +python3 tools/eval.py \ + -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \ + -o Global.checkpoints=pretrain_models/ch_PP-OCRv3_rec_train/best_accuracy + +``` + +## **5.2 三种fine-tune方案** + +方案2、3、4训练和评估方式是相同的,因此在我们了解每个技术方案之后,再具体看修改哪些参数是相同,哪些是不同的。 + +**方案介绍:** + +1) **方案2**:预训练模型 + **fine-tune** + +- 在预训练模型的基础上进行fine-tune,使用1500张PCB进行训练和评估,其中训练集1200张,验证集300张。 + + +2) **方案3**:预训练模型 + fine-tune + **公开通用识别数据集** + +- 当识别数据比较少的情况,可以考虑添加公开通用识别数据集。在方案2的基础上,添加公开通用识别数据集,如lsvt、rctw等。 + +3)**方案4**:预训练模型 + fine-tune + **增加PCB图像数量** + +- 如果能够获取足够多真实场景,我们可以通过增加数据量提升模型效果。在方案2的基础上,增加PCB的数量到2W张左右。 + + +**参数修改:** + +接着我们看需要修改的参数,以上方案均需要修改配置文件`configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml`的参数,**修改一次即可**: + +``` +Global.pretrained_model:指向预训练模型路径,'pretrain_models/ch_PP-OCRv3_rec_train/best_accuracy' +Optimizer.lr.values:学习率,本实验设置为0.0005 +Train.loader.batch_size_per_card: batch size,默认128,因为数据量小于128,因此我们设置为8,数据量大可以按默认的训练 +Eval.loader.batch_size_per_card: batch size,默认128,设置为4 +Metric.ignore_space: 忽略空格,本实验设置为True +``` + +**更换不同的方案**每次需要修改的参数: +``` +Global.epoch_num: 这里设置为1,方便快速跑通,实际中根据数据量调整该值 +Global.save_model_dir:指向模型保存路径 +Train.dataset.data_dir:指向训练集图片存放目录 +Train.dataset.label_file_list:指向训练集标注文件 +Eval.dataset.data_dir:指向验证集图片存放目录 +Eval.dataset.label_file_list:指向验证集标注文件 +``` + +同时**方案3**修改以下参数 +``` +Eval.dataset.label_file_list:添加公开通用识别数据标注文件 +Eval.dataset.ratio_list:数据和公开通用识别数据每次采样比例,按实际修改即可 +``` +如 **图5** 所示: +
+
图5 添加公开通用识别数据配置文件示例
+ + +我们提取Student模型的参数,在PCB数据集上进行fine-tune,可以参考如下代码: + + +```python +import paddle +# 加载预训练模型 +all_params = paddle.load("./pretrain_models/ch_PP-OCRv3_rec_train/best_accuracy.pdparams") +# 查看权重参数的keys +print(all_params.keys()) +# 学生模型的权重提取 +s_params = {key[len("student_model."):]: all_params[key] for key in all_params if "student_model." in key} +# 查看学生模型权重参数的keys +print(s_params.keys()) +# 保存 +paddle.save(s_params, "./pretrain_models/ch_PP-OCRv3_rec_train/student.pdparams") +``` + +修改参数后,**每个方案**都执行如下命令启动训练: + + + +```python +cd /home/aistudio/PaddleOCR/ +python3 tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml +``` + + +使用训练好的模型进行评估,更新模型路径`Global.checkpoints`: + + +```python +cd /home/aistudio/PaddleOCR/ +python3 tools/eval.py \ + -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml \ + -o Global.checkpoints=./output/rec_ppocr_v3/latest +``` + +所有方案评估指标如下: + +| 序号 | 方案 | acc | 效果提升 | 实验分析 | +| -------- | -------- | -------- | -------- | -------- | +| 1 | PP-OCRv3中英文超轻量识别预训练模型直接评估 | 46.67% | - | 提供的预训练模型具有泛化能力 | +| 2 | PP-OCRv3中英文超轻量识别预训练模型 + fine-tune | 42.02% |-4.65% | 在数据量不足的情况,反而比预训练模型效果低(也可以通过调整超参数再试试)| +| 3 | PP-OCRv3中英文超轻量识别预训练模型 + fine-tune + 公开通用识别数据集 | 77.00% | +30.33% | 在数据量不足的情况下,可以考虑补充公开数据训练 | +| 4 | PP-OCRv3中英文超轻量识别预训练模型 + fine-tune + 增加PCB图像数量 | 99.99% | +22.99% | 如果能获取更多数据量的情况,可以通过增加数据量提升效果 | + +``` +注:上述实验结果均是在1500张图片(1200张训练集,300张测试集)、2W张图片、添加公开通用识别数据集上训练、评估的得到,AIstudio只提供了100张数据,所以指标有所差异属于正常,只要策略有效、规律相同即可。 +``` + +# 6. 模型导出 + +inference 模型(paddle.jit.save保存的模型) 一般是模型训练,把模型结构和模型参数保存在文件中的固化模型,多用于预测部署场景。 训练过程中保存的模型是checkpoints模型,保存的只有模型的参数,多用于恢复训练等。 与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。 + + +```python +# 导出检测模型 +python3 tools/export_model.py \ + -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml \ + -o Global.pretrained_model="./output/ch_PP-OCR_V3_det/latest" \ + Global.save_inference_dir="./inference_model/ch_PP-OCR_V3_det/" +``` + +因为上述模型只训练了1个epoch,因此我们使用训练最优的模型进行预测,存储在`/home/aistudio/best_models/`目录下,解压即可 + + +```python +cd /home/aistudio/best_models/ +wget https://paddleocr.bj.bcebos.com/fanliku/PCB/det_ppocr_v3_en_infer_PCB.tar +tar xf /home/aistudio/best_models/det_ppocr_v3_en_infer_PCB.tar -C /home/aistudio/PaddleOCR/pretrain_models/ +``` + + +```python +# 检测模型inference模型预测 +cd /home/aistudio/PaddleOCR/ +python3 tools/infer/predict_det.py \ + --image_dir="/home/aistudio/dataset/imgs/0000.jpg" \ + --det_algorithm="DB" \ + --det_model_dir="./pretrain_models/det_ppocr_v3_en_infer_PCB/" \ + --det_limit_side_len=48 \ + --det_limit_type='min' \ + --det_db_unclip_ratio=2.5 \ + --use_gpu=True +``` + +结果存储在`inference_results`目录下,检测如下图所示: +
+
图6 检测结果
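上面检测命令中的 `det_limit_side_len=48`、`det_limit_type='min'` 用于控制推理前的图像缩放:`'min'` 模式下只有当图片最短边小于该阈值时才会放大图片,否则保持原尺寸,最后再把宽高对齐到 32 的倍数,因此适合 PCB 这类小尺寸文本图像。其简化逻辑大致如下(仅为示意代码,并非 PaddleOCR 源码,细节以 `DetResizeForTest` 的实现为准):

```python
def det_resize_shape(h, w, limit_side_len=48, limit_type="min"):
    """简化示意:估算检测推理时图片会被缩放到的尺寸"""
    if limit_type == "min":
        # 最短边小于阈值时放大,否则不缩放
        ratio = limit_side_len / min(h, w) if min(h, w) < limit_side_len else 1.0
    else:  # "max" 模式:最长边超过阈值时缩小
        ratio = limit_side_len / max(h, w) if max(h, w) > limit_side_len else 1.0
    # 缩放后再对齐到 32 的倍数
    resize_h = max(int(round(h * ratio / 32) * 32), 32)
    resize_w = max(int(round(w * ratio / 32) * 32), 32)
    return resize_h, resize_w

# 例如一张 30x200 的小尺寸 PCB 文本图片:最短边 30 < 48,会被放大后再对齐到 32 的倍数
print(det_resize_shape(30, 200))
```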
+ + +同理,导出识别模型并进行推理。 + +```python +# 导出识别模型 +python3 tools/export_model.py \ + -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml \ + -o Global.pretrained_model="./output/rec_ppocr_v3/latest" \ + Global.save_inference_dir="./inference_model/rec_ppocr_v3/" + +``` + +同检测模型,识别模型也只训练了1个epoch,因此我们使用训练最优的模型进行预测,存储在`/home/aistudio/best_models/`目录下,解压即可 + + +```python +cd /home/aistudio/best_models/ +wget https://paddleocr.bj.bcebos.com/fanliku/PCB/rec_ppocr_v3_ch_infer_PCB.tar +tar xf /home/aistudio/best_models/rec_ppocr_v3_ch_infer_PCB.tar -C /home/aistudio/PaddleOCR/pretrain_models/ +``` + + +```python +# 识别模型inference模型预测 +cd /home/aistudio/PaddleOCR/ +python3 tools/infer/predict_rec.py \ + --image_dir="../test_imgs/0000_rec.jpg" \ + --rec_model_dir="./pretrain_models/rec_ppocr_v3_ch_infer_PCB" \ + --rec_image_shape="3, 48, 320" \ + --use_space_char=False \ + --use_gpu=True +``` + +```python +# 检测+识别模型inference模型预测 +cd /home/aistudio/PaddleOCR/ +python3 tools/infer/predict_system.py \ + --image_dir="../test_imgs/0000.jpg" \ + --det_model_dir="./pretrain_models/det_ppocr_v3_en_infer_PCB" \ + --det_limit_side_len=48 \ + --det_limit_type='min' \ + --det_db_unclip_ratio=2.5 \ + --rec_model_dir="./pretrain_models/rec_ppocr_v3_ch_infer_PCB" \ + --rec_image_shape="3, 48, 320" \ + --draw_img_save_dir=./det_rec_infer/ \ + --use_space_char=False \ + --use_angle_cls=False \ + --use_gpu=True + +``` + +端到端预测结果存储在`det_res_infer`文件夹内,结果如下图所示: +
+
图7 检测+识别结果
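`predict_system.py` 除了保存可视化图片外,还会把文本结果写入 `det_rec_infer/system_results.txt`,每行为 `图片名\tJSON`(具体格式见下文第 7 节)。可以用如下示意脚本快速查看每张图识别到的文本(结果文件路径为假设,请按实际的 `--draw_img_save_dir` 输出目录修改):

```python
import json

# 示意脚本:读取 predict_system.py 保存的文本结果并打印识别内容
result_file = "./det_rec_infer/system_results.txt"  # 假设的输出路径,按实际修改

with open(result_file, "r", encoding="utf-8") as f:
    for line in f:
        img_name, json_str = line.strip().split("\t", 1)
        items = json.loads(json_str)
        texts = [item["transcription"] for item in items]  # 每个检测框的识别文本
        print(img_name, texts)
```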
+ +# 7. 端对端评测 + +接下来介绍文本检测+文本识别的端对端指标评估方式。主要分为三步: + +1)首先运行`tools/infer/predict_system.py`,将`image_dir`改为需要评估的数据文件家,得到保存的结果: + + +```python +# 检测+识别模型inference模型预测 +python3 tools/infer/predict_system.py \ + --image_dir="../dataset/imgs/" \ + --det_model_dir="./pretrain_models/det_ppocr_v3_en_infer_PCB" \ + --det_limit_side_len=48 \ + --det_limit_type='min' \ + --det_db_unclip_ratio=2.5 \ + --rec_model_dir="./pretrain_models/rec_ppocr_v3_ch_infer_PCB" \ + --rec_image_shape="3, 48, 320" \ + --draw_img_save_dir=./det_rec_infer/ \ + --use_space_char=False \ + --use_angle_cls=False \ + --use_gpu=True +``` + +得到保存结果,文本检测识别可视化图保存在`det_rec_infer/`目录下,预测结果保存在`det_rec_infer/system_results.txt`中,格式如下:`0018.jpg [{"transcription": "E295", "points": [[88, 33], [137, 33], [137, 40], [88, 40]]}]` + +2)然后将步骤一保存的数据转换为端对端评测需要的数据格式: 修改 `tools/end2end/convert_ppocr_label.py`中的代码,convert_label函数中设置输入标签路径,Mode,保存标签路径等,对预测数据的GTlabel和预测结果的label格式进行转换。 +``` +ppocr_label_gt = "/home/aistudio/dataset/det_gt_val.txt" +convert_label(ppocr_label_gt, "gt", "./save_gt_label/") + +ppocr_label_gt = "/home/aistudio/PaddleOCR/PCB_result/det_rec_infer/system_results.txt" +convert_label(ppocr_label_gt, "pred", "./save_PPOCRV2_infer/") +``` + +运行`convert_ppocr_label.py`: + + +```python + python3 tools/end2end/convert_ppocr_label.py +``` + +得到如下结果: +``` +├── ./save_gt_label/ +├── ./save_PPOCRV2_infer/ +``` + +3) 最后,执行端对端评测,运行`tools/end2end/eval_end2end.py`计算端对端指标,运行方式如下: + + +```python +pip install editdistance +python3 tools/end2end/eval_end2end.py ./save_gt_label/ ./save_PPOCRV2_infer/ +``` + +使用`预训练模型+fine-tune'检测模型`、`预训练模型 + 2W张PCB图片funetune`识别模型,在300张PCB图片上评估得到如下结果,fmeasure为主要关注的指标: +
+
图8 端到端评估指标
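端对端指标中,precision 为匹配正确的文本数占预测文本总数的比例,recall 为匹配正确的文本数占标注文本总数的比例,fmeasure 为二者的调和平均。可以通过如下示意函数理解三者的关系(仅为公式示意,实际的框匹配与文本比对规则以 `tools/end2end/eval_end2end.py` 为准):

```python
def end2end_fmeasure(num_match, num_pred, num_gt):
    """示意:由匹配正确数、预测总数、标注总数计算端对端 precision / recall / fmeasure"""
    precision = num_match / num_pred if num_pred else 0.0
    recall = num_match / num_gt if num_gt else 0.0
    fmeasure = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, fmeasure

# 例如(假设的数字):300 张图共标注 1000 条文本,预测 980 条,其中 930 条匹配正确
print(end2end_fmeasure(930, 980, 1000))
```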
+ +``` +注: 使用上述命令不能跑出该结果,因为数据集不相同,可以更换为自己训练好的模型,按上述流程运行 +``` + +# 8. Jetson部署 + +我们只需要以下步骤就可以完成Jetson nano部署模型,简单易操作: + +**1、在Jetson nano开发版上环境准备:** + +* 安装PaddlePaddle + +* 下载PaddleOCR并安装依赖 + +**2、执行预测** + +* 将推理模型下载到jetson + +* 执行检测、识别、串联预测即可 + +详细[参考流程](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/deploy/Jetson/readme_ch.md)。 + +# 9. 总结 + +检测实验分别使用PP-OCRv3预训练模型在PCB数据集上进行了直接评估、验证集padding、 fine-tune 3种方案,识别实验分别使用PP-OCRv3预训练模型在PCB数据集上进行了直接评估、 fine-tune、添加公开通用识别数据集、增加PCB图片数量4种方案,指标对比如下: + +* 检测 + + +| 序号 | 方案 | hmean | 效果提升 | 实验分析 | +| ---- | -------------------------------------------------------- | ------ | -------- | ------------------------------------- | +| 1 | PP-OCRv3英文超轻量检测预训练模型直接评估 | 64.64% | - | 提供的预训练模型具有泛化能力 | +| 2 | PP-OCRv3英文超轻量检测预训练模型 + 验证集padding直接评估 | 72.13% | +7.49% | padding可以提升尺寸较小图片的检测效果 | +| 3 | PP-OCRv3英文超轻量检测预训练模型 + fine-tune | 100.00% | +27.87% | fine-tune会提升垂类场景效果 | + +* 识别 + +| 序号 | 方案 | acc | 效果提升 | 实验分析 | +| ---- | ------------------------------------------------------------ | ------ | -------- | ------------------------------------------------------------ | +| 1 | PP-OCRv3中英文超轻量识别预训练模型直接评估 | 46.67% | - | 提供的预训练模型具有泛化能力 | +| 2 | PP-OCRv3中英文超轻量识别预训练模型 + fine-tune | 42.02% | -4.65% | 在数据量不足的情况,反而比预训练模型效果低(也可以通过调整超参数再试试) | +| 3 | PP-OCRv3中英文超轻量识别预训练模型 + fine-tune + 公开通用识别数据集 | 77.00% | +30.33% | 在数据量不足的情况下,可以考虑补充公开数据训练 | +| 4 | PP-OCRv3中英文超轻量识别预训练模型 + fine-tune + 增加PCB图像数量 | 99.99% | +22.99% | 如果能获取更多数据量的情况,可以通过增加数据量提升效果 | + +* 端到端 + +| det | rec | fmeasure | +| --------------------------------------------- | ------------------------------------------------------------ | -------- | +| PP-OCRv3英文超轻量检测预训练模型 + fine-tune | PP-OCRv3中英文超轻量识别预训练模型 + fine-tune + 增加PCB图像数量 | 93.30% | + +*结论* + +PP-OCRv3的检测模型在未经过fine-tune的情况下,在PCB数据集上也有64.64%的精度,说明具有泛化能力。验证集padding之后,精度提升7.5%,在图片尺寸较小的情况,我们可以通过padding的方式提升检测效果。经过 fine-tune 后能够极大的提升检测效果,精度达到100%。 + +PP-OCRv3的识别模型方案1和方案2对比可以发现,当数据量不足的情况,预训练模型精度可能比fine-tune效果还要高,所以我们可以先尝试预训练模型直接评估。如果在数据量不足的情况下想进一步提升模型效果,可以通过添加公开通用识别数据集,识别效果提升30%,非常有效。最后如果我们能够采集足够多的真实场景数据集,可以通过增加数据量提升模型效果,精度达到99.99%。 + +# 更多资源 + +- 更多深度学习知识、产业案例、面试宝典等,请参考:[awesome-DeepLearning](https://github.com/paddlepaddle/awesome-DeepLearning) + +- 更多PaddleOCR使用教程,请参考:[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph) + + +- 飞桨框架相关资料,请参考:[飞桨深度学习平台](https://www.paddlepaddle.org.cn/?fr=paddleEdu_aistudio) + +# 参考 + +* 数据生成代码库:https://github.com/zcswdt/Color_OCR_image_generator diff --git "a/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/background/bg.jpg" "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/background/bg.jpg" new file mode 100644 index 0000000..3cb6eab Binary files /dev/null and "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/background/bg.jpg" differ diff --git "a/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/corpus/text.txt" "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/corpus/text.txt" new file mode 100644 index 0000000..8b8cb79 --- /dev/null +++ "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/corpus/text.txt" @@ -0,0 +1,30 @@ +5ZQ +I4UL +PWL +SNOG +ZL02 +1C30 +O3H +YHRS +N03S +1U5Y +JTK +EN4F +YKJ +DWNH +R42W +X0V +4OF5 +08AM +Y93S +GWE2 +0KR +9U2A +DBQ +Y6J +ROZ +K06 +KIEY +NZQJ +UN1B +6X4 \ No newline at end of file diff --git "a/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/det_background/1.png" 
"b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/det_background/1.png" new file mode 100644 index 0000000..8a49eaa Binary files /dev/null and "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/det_background/1.png" differ diff --git "a/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/det_background/2.png" "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/det_background/2.png" new file mode 100644 index 0000000..c3fcc0c Binary files /dev/null and "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/det_background/2.png" differ diff --git "a/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/gen.py" "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/gen.py" new file mode 100644 index 0000000..97024d1 --- /dev/null +++ "b/applications/PCB\345\255\227\347\254\246\350\257\206\345\210\253/gen_data/gen.py" @@ -0,0 +1,263 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/zcswdt/Color_OCR_image_generator +""" +import os +import random +from PIL import Image, ImageDraw, ImageFont +import json +import argparse + + +def get_char_lines(txt_root_path): + """ + desc:get corpus line + """ + txt_files = os.listdir(txt_root_path) + char_lines = [] + for txt in txt_files: + f = open(os.path.join(txt_root_path, txt), mode='r', encoding='utf-8') + lines = f.readlines() + f.close() + for line in lines: + char_lines.append(line.strip()) + return char_lines + + +def get_horizontal_text_picture(image_file, chars, fonts_list, cf): + """ + desc:gen horizontal text picture + """ + img = Image.open(image_file) + if img.mode != 'RGB': + img = img.convert('RGB') + img_w, img_h = img.size + + # random choice font + font_path = random.choice(fonts_list) + # random choice font size + font_size = random.randint(cf.font_min_size, cf.font_max_size) + font = ImageFont.truetype(font_path, font_size) + + ch_w = [] + ch_h = [] + for ch in chars: + left, top, right, bottom = font.getbbox(ch) + wt, ht = right - left, bottom - top + ch_w.append(wt) + ch_h.append(ht) + f_w = sum(ch_w) + f_h = max(ch_h) + + # add space + char_space_width = max(ch_w) + f_w += (char_space_width * (len(chars) - 1)) + + x1 = random.randint(0, img_w - f_w) + y1 = random.randint(0, img_h - f_h) + x2 = x1 + f_w + y2 = y1 + f_h + + crop_y1 = y1 + crop_x1 = x1 + crop_y2 = y2 + crop_x2 = x2 + + best_color = (0, 0, 0) + draw = ImageDraw.Draw(img) + for i, ch in enumerate(chars): + draw.text((x1, y1), ch, best_color, font=font) + x1 += (ch_w[i] + char_space_width) + crop_img = img.crop((crop_x1, crop_y1, crop_x2, crop_y2)) + return crop_img, chars + + +def get_vertical_text_picture(image_file, chars, fonts_list, cf): + """ + desc:gen vertical text picture + """ + img = Image.open(image_file) + if img.mode != 'RGB': + img = img.convert('RGB') + img_w, img_h = img.size + # random 
choice font + font_path = random.choice(fonts_list) + # random choice font size + font_size = random.randint(cf.font_min_size, cf.font_max_size) + font = ImageFont.truetype(font_path, font_size) + + ch_w = [] + ch_h = [] + for ch in chars: + left, top, right, bottom = font.getbbox(ch) + wt, ht = right - left, bottom - top + ch_w.append(wt) + ch_h.append(ht) + f_w = max(ch_w) + f_h = sum(ch_h) + + x1 = random.randint(0, img_w - f_w) + y1 = random.randint(0, img_h - f_h) + x2 = x1 + f_w + y2 = y1 + f_h + + crop_y1 = y1 + crop_x1 = x1 + crop_y2 = y2 + crop_x2 = x2 + + best_color = (0, 0, 0) + draw = ImageDraw.Draw(img) + i = 0 + for ch in chars: + draw.text((x1, y1), ch, best_color, font=font) + y1 = y1 + ch_h[i] + i = i + 1 + crop_img = img.crop((crop_x1, crop_y1, crop_x2, crop_y2)) + crop_img = crop_img.transpose(Image.ROTATE_90) + return crop_img, chars + + +def get_fonts(fonts_path): + """ + desc: get all fonts + """ + font_files = os.listdir(fonts_path) + fonts_list=[] + for font_file in font_files: + font_path=os.path.join(fonts_path, font_file) + fonts_list.append(font_path) + return fonts_list + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--num_img', type=int, default=30, help="Number of images to generate") + parser.add_argument('--font_min_size', type=int, default=11) + parser.add_argument('--font_max_size', type=int, default=12, + help="Help adjust the size of the generated text and the size of the picture") + parser.add_argument('--bg_path', type=str, default='./background', + help='The generated text pictures will be pasted onto the pictures of this folder') + parser.add_argument('--det_bg_path', type=str, default='./det_background', + help='The generated text pictures will use the pictures of this folder as the background') + parser.add_argument('--fonts_path', type=str, default='../../StyleText/fonts', + help='The font used to generate the picture') + parser.add_argument('--corpus_path', type=str, default='./corpus', + help='The corpus used to generate the text picture') + parser.add_argument('--output_dir', type=str, default='./output/', help='Images save dir') + + + cf = parser.parse_args() + # save path + if not os.path.exists(cf.output_dir): + os.mkdir(cf.output_dir) + + # get corpus + txt_root_path = cf.corpus_path + char_lines = get_char_lines(txt_root_path=txt_root_path) + + # get all fonts + fonts_path = cf.fonts_path + fonts_list = get_fonts(fonts_path) + + # rec bg + img_root_path = cf.bg_path + imnames=os.listdir(img_root_path) + + # det bg + det_bg_path = cf.det_bg_path + bg_pics = os.listdir(det_bg_path) + + # OCR det files + det_val_file = open(cf.output_dir + 'det_gt_val.txt', 'w', encoding='utf-8') + det_train_file = open(cf.output_dir + 'det_gt_train.txt', 'w', encoding='utf-8') + # det imgs + det_save_dir = 'imgs/' + if not os.path.exists(cf.output_dir + det_save_dir): + os.mkdir(cf.output_dir + det_save_dir) + det_val_save_dir = 'imgs_val/' + if not os.path.exists(cf.output_dir + det_val_save_dir): + os.mkdir(cf.output_dir + det_val_save_dir) + + # OCR rec files + rec_val_file = open(cf.output_dir + 'rec_gt_val.txt', 'w', encoding='utf-8') + rec_train_file = open(cf.output_dir + 'rec_gt_train.txt', 'w', encoding='utf-8') + # rec imgs + rec_save_dir = 'rec_imgs/' + if not os.path.exists(cf.output_dir + rec_save_dir): + os.mkdir(cf.output_dir + rec_save_dir) + rec_val_save_dir = 'rec_imgs_val/' + if not os.path.exists(cf.output_dir + rec_val_save_dir): + os.mkdir(cf.output_dir + rec_val_save_dir) + + + val_ratio = 
cf.num_img * 0.2 # val dataset ratio + + print('start generating...') + for i in range(0, cf.num_img): + imname = random.choice(imnames) + img_path = os.path.join(img_root_path, imname) + + rnd = random.random() + # gen horizontal text picture + if rnd < 0.5: + gen_img, chars = get_horizontal_text_picture(img_path, char_lines[i], fonts_list, cf) + ori_w, ori_h = gen_img.size + gen_img = gen_img.crop((0, 3, ori_w, ori_h)) + # gen vertical text picture + else: + gen_img, chars = get_vertical_text_picture(img_path, char_lines[i], fonts_list, cf) + ori_w, ori_h = gen_img.size + gen_img = gen_img.crop((3, 0, ori_w, ori_h)) + + ori_w, ori_h = gen_img.size + + # rec imgs + save_img_name = str(i).zfill(4) + '.jpg' + if i < val_ratio: + save_dir = os.path.join(rec_val_save_dir, save_img_name) + line = save_dir + '\t' + char_lines[i] + '\n' + rec_val_file.write(line) + else: + save_dir = os.path.join(rec_save_dir, save_img_name) + line = save_dir + '\t' + char_lines[i] + '\n' + rec_train_file.write(line) + gen_img.save(cf.output_dir + save_dir, quality = 95, subsampling=0) + + # det img + # random choice bg + bg_pic = random.sample(bg_pics, 1)[0] + det_img = Image.open(os.path.join(det_bg_path, bg_pic)) + # the PCB position is fixed, modify it according to your own scenario + if bg_pic == '1.png': + x1 = 38 + y1 = 3 + else: + x1 = 34 + y1 = 1 + + det_img.paste(gen_img, (x1, y1)) + # text pos + chars_pos = [[x1, y1], [x1 + ori_w, y1], [x1 + ori_w, y1 + ori_h], [x1, y1 + ori_h]] + label = [{"transcription":char_lines[i], "points":chars_pos}] + if i < val_ratio: + save_dir = os.path.join(det_val_save_dir, save_img_name) + det_val_file.write(save_dir + '\t' + json.dumps( + label, ensure_ascii=False) + '\n') + else: + save_dir = os.path.join(det_save_dir, save_img_name) + det_train_file.write(save_dir + '\t' + json.dumps( + label, ensure_ascii=False) + '\n') + det_img.save(cf.output_dir + save_dir, quality = 95, subsampling=0) diff --git a/applications/README.md b/applications/README.md new file mode 100644 index 0000000..950adf7 --- /dev/null +++ b/applications/README.md @@ -0,0 +1,78 @@ +[English](README_en.md) | 简体中文 + +# 场景应用 + +PaddleOCR场景应用覆盖通用,制造、金融、交通行业的主要OCR垂类应用,在PP-OCR、PP-Structure的通用能力基础之上,以notebook的形式展示利用场景数据微调、模型优化方法、数据增广等内容,为开发者快速落地OCR应用提供示范与启发。 + +- [教程文档](#1) + - [通用](#11) + - [制造](#12) + - [金融](#13) + - [交通](#14) + +- [模型下载](#2) + + + +## 教程文档 + + + +### 通用 + +| 类别 | 亮点 | 模型下载 | 教程 | 示例图 | +| ---------------------- | ------------------------------------------------------------ | -------------- | --------------------------------------- | ------------------------------------------------------------ | +| 高精度中文识别模型SVTR | 比PP-OCRv3识别模型精度高3%,
可用于数据挖掘或对预测效率要求不高的场景。 | [模型下载](#2) | [中文](./高精度中文识别模型.md)/English | | +| 手写体识别 | 新增字形支持 | [模型下载](#2) | [中文](./手写文字识别.md)/English | | + + + +### 制造 + +| 类别 | 亮点 | 模型下载 | 教程 | 示例图 | +| -------------- | ------------------------------ | -------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | +| 数码管识别 | 数码管数据合成、漏识别调优 | [模型下载](#2) | [中文](./光功率计数码管字符识别/光功率计数码管字符识别.md)/English | | +| 液晶屏读数识别 | 检测模型蒸馏、Serving部署 | [模型下载](#2) | [中文](./液晶屏读数识别.md)/English | | +| 包装生产日期 | 点阵字符合成、过曝过暗文字识别 | [模型下载](#2) | [中文](./包装生产日期识别.md)/English | | +| PCB文字识别 | 小尺寸文本检测与识别 | [模型下载](#2) | [中文](./PCB字符识别/PCB字符识别.md)/English | | +| 电表识别 | 大分辨率图像检测调优 | [模型下载](#2) | | | +| 液晶屏缺陷检测 | 非文字字符识别 | | | | + + + +### 金融 + +| 类别 | 亮点 | 模型下载 | 教程 | 示例图 | +| -------------- | ----------------------------- | -------------- | ----------------------------------------- | ------------------------------------------------------------ | +| 表单VQA | 多模态通用表单结构化提取 | [模型下载](#2) | [中文](./多模态表单识别.md)/English | | +| 增值税发票 | 关键信息抽取,SER、RE任务训练 | [模型下载](#2) | [中文](./发票关键信息抽取.md)/English | | +| 印章检测与识别 | 端到端弯曲文本识别 | [模型下载](#2) | [中文](./印章弯曲文字识别.md)/English | | +| 通用卡证识别 | 通用结构化提取 | [模型下载](#2) | [中文](./快速构建卡证类OCR.md)/English | | +| 身份证识别 | 结构化提取、图像阴影 | | | | +| 合同比对 | 密集文本检测、NLP关键信息抽取 | [模型下载](#2) | [中文](./扫描合同关键信息提取.md)/English | | + + + +### 交通 + +| 类别 | 亮点 | 模型下载 | 教程 | 示例图 | +| ----------------- | ------------------------------ | -------------- | ----------------------------------- | ------------------------------------------------------------ | +| 车牌识别 | 多角度图像、轻量模型、端侧部署 | [模型下载](#2) | [中文](./轻量级车牌识别.md)/English | | +| 驾驶证/行驶证识别 | 尽请期待 | | | | +| 快递单识别 | 尽请期待 | | | | + + + +## 模型下载 + +如需下载上述场景中已经训练好的垂类模型,可以扫描下方二维码,关注公众号填写问卷后,加入PaddleOCR官方交流群获取20G OCR学习大礼包(内含《动手学OCR》电子书、课程回放视频、前沿论文等重磅资料) + +
+ +
+ +如果您是企业开发者且未在上述场景中找到合适的方案,可以填写[OCR应用合作调研问卷](https://paddle.wjx.cn/vj/QwF7GKw.aspx),免费与官方团队展开不同层次的合作,包括但不限于问题抽象、确定技术方案、项目答疑、共同研发等。如果您已经使用PaddleOCR落地项目,也可以填写此问卷,与飞桨平台共同宣传推广,提升企业技术品宣。期待您的提交! + + +traffic + diff --git a/applications/README_en.md b/applications/README_en.md new file mode 100644 index 0000000..df18465 --- /dev/null +++ b/applications/README_en.md @@ -0,0 +1,79 @@ +English| [简体中文](README.md) + +# Application + +PaddleOCR scene application covers general, manufacturing, finance, transportation industry of the main OCR vertical applications, on the basis of the general capabilities of PP-OCR, PP-Structure, in the form of notebook to show the use of scene data fine-tuning, model optimization methods, data augmentation and other content, for developers to quickly land OCR applications to provide demonstration and inspiration. + +- [Tutorial](#1) + - [General](#11) + - [Manufacturing](#12) + - [Finance](#13) + - [Transportation](#14) + +- [Model Download](#2) + + + +## Tutorial + + + +### General + +| Case | Feature | Model Download | Tutorial | Example | +| ---------------------------------------------- | ---------------- | -------------------- | --------------------------------------- | ------------------------------------------------------------ | +| High-precision Chineses recognition model SVTR | New model | [Model Download](#2) | [中文](./高精度中文识别模型.md)/English | | +| Chinese handwriting recognition | New font support | [Model Download](#2) | [中文](./手写文字识别.md)/English | | + + + +### Manufacturing + +| Case | Feature | Model Download | Tutorial | Example | +| ------------------------------ | ------------------------------------------------------------ | -------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | +| Digital tube | Digital tube data sythesis, recognition model fine-tuning | [Model Download](#2) | [中文](./光功率计数码管字符识别/光功率计数码管字符识别.md)/English | | +| LCD screen | Detection model distillation, serving deployment | [Model Download](#2) | [中文](./液晶屏读数识别.md)/English | | +| Packaging production data | Dot matrix character synthesis, overexposure and overdark text recognition | [Model Download](#2) | [中文](./包装生产日期识别.md)/English | | +| PCB text recognition | Small size text detection and recognition | [Model Download](#2) | [中文](./PCB字符识别/PCB字符识别.md)/English | | +| Meter text recognition | High-resolution image detection fine-tuning | [Model Download](#2) | | | +| LCD character defect detection | Non-text character recognition | | | | + + + +### Finance + +| Case | Feature | Model Download | Tutorial | Example | +| ----------------------------------- | -------------------------------------------------- | -------------------- | ----------------------------------------- | ------------------------------------------------------------ | +| Form visual question and answer | Multimodal general form structured extraction | [Model Download](#2) | [中文](./多模态表单识别.md)/English | | +| VAT invoice | Key information extraction, SER, RE task fine-tune | [Model Download](#2) | [中文](./发票关键信息抽取.md)/English | | +| Seal detection and recognition | End-to-end curved text recognition | [Model Download](#2) | [中文](./印章弯曲文字识别.md)/English | | +| Universal card recognition | Universal structured extraction | [Model Download](#2) | [中文](./快速构建卡证类OCR.md)/English | | +| ID card recognition | Structured extraction, image shading | | | | +| Contract key information extraction | Dense text detection, NLP concatenation | 
[Model Download](#2) | [中文](./扫描合同关键信息提取.md)/English | | + + + +### Transportation + +| Case | Feature | Model Download | Tutorial | Example | +| ----------------------------------------------- | ------------------------------------------------------------ | -------------------- | ----------------------------------- | ------------------------------------------------------------ | +| License plate recognition | Multi-angle images, lightweight models, edge-side deployment | [Model Download](#2) | [中文](./轻量级车牌识别.md)/English | | +| Driver's license/driving license identification | coming soon | | | | +| Express text recognition | coming soon | | | | + + + +## Model Download + +- For international developers: We're building a way to download these trained models, and since the current tutorials are Chinese, if you are good at both Chinese and English, or willing to polish English documents, please let us know in [discussion](https://github.com/PaddlePaddle/PaddleOCR/discussions). +- For Chinese developer: If you want to download the trained application model in the above scenarios, scan the QR code below with your WeChat, follow the PaddlePaddle official account to fill in the questionnaire, and join the PaddleOCR official group to get the 20G OCR learning materials (including "Dive into OCR" e-book, course video, application models and other materials) + +
+ +
+ + If you are an enterprise developer and have not found a suitable solution in the above scenarios, you can fill in the [OCR Application Cooperation Survey Questionnaire](https://paddle.wjx.cn/vj/QwF7GKw.aspx) to carry out different levels of cooperation with the official team **for free**, including but not limited to problem abstraction, technical solution determination, project Q&A, joint research and development, etc. If you have already used paddleOCR in your project, you can also fill out this questionnaire to jointly promote with the PaddlePaddle and enhance the technical publicity of enterprises. Looking forward to your submission! + + +trackgit-views + diff --git "a/applications/\344\270\255\346\226\207\350\241\250\346\240\274\350\257\206\345\210\253.md" "b/applications/\344\270\255\346\226\207\350\241\250\346\240\274\350\257\206\345\210\253.md" new file mode 100644 index 0000000..d61514f --- /dev/null +++ "b/applications/\344\270\255\346\226\207\350\241\250\346\240\274\350\257\206\345\210\253.md" @@ -0,0 +1,472 @@ +# 智能运营:通用中文表格识别 + +- [1. 背景介绍](#1-背景介绍) +- [2. 中文表格识别](#2-中文表格识别) +- [2.1 环境准备](#21-环境准备) +- [2.2 准备数据集](#22-准备数据集) + - [2.2.1 划分训练测试集](#221-划分训练测试集) + - [2.2.2 查看数据集](#222-查看数据集) +- [2.3 训练](#23-训练) +- [2.4 验证](#24-验证) +- [2.5 训练引擎推理](#25-训练引擎推理) +- [2.6 模型导出](#26-模型导出) +- [2.7 预测引擎推理](#27-预测引擎推理) +- [2.8 表格识别](#28-表格识别) +- [3. 表格属性识别](#3-表格属性识别) +- [3.1 代码、环境、数据准备](#31-代码环境数据准备) + - [3.1.1 代码准备](#311-代码准备) + - [3.1.2 环境准备](#312-环境准备) + - [3.1.3 数据准备](#313-数据准备) +- [3.2 表格属性识别训练](#32-表格属性识别训练) +- [3.3 表格属性识别推理和部署](#33-表格属性识别推理和部署) + - [3.3.1 模型转换](#331-模型转换) + - [3.3.2 模型推理](#332-模型推理) + +## 1. 背景介绍 + +中文表格识别在金融行业有着广泛的应用,如保险理赔、财报分析和信息录入等领域。当前,金融行业的表格识别主要以手动录入为主,开发一种自动表格识别成为丞待解决的问题。 +![](https://ai-studio-static-online.cdn.bcebos.com/d1e7780f0c7745ada4be540decefd6288e4d59257d8141f6842682a4c05d28b6) + + +在金融行业中,表格图像主要有清单类的单元格密集型表格,申请表类的大单元格表格,拍照表格和倾斜表格四种主要形式。 + +![](https://ai-studio-static-online.cdn.bcebos.com/da82ae8ef8fd479aaa38e1049eb3a681cf020dc108fa458eb3ec79da53b45fd1) +![](https://ai-studio-static-online.cdn.bcebos.com/5ffff2093a144a6993a75eef71634a52276015ee43a04566b9c89d353198c746) + + +当前的表格识别算法不能很好的处理这些场景下的表格图像。在本例中,我们使用PP-StructureV2最新发布的表格识别模型SLANet来演示如何进行中文表格是识别。同时,为了方便作业流程,我们使用表格属性识别模型对表格图像的属性进行识别,对表格的难易程度进行判断,加快人工进行校对速度。 + +本项目AI Studio链接:https://aistudio.baidu.com/aistudio/projectdetail/4588067 + +## 2. 中文表格识别 +### 2.1 环境准备 + + +```python +# 下载PaddleOCR代码 +! git clone -b dygraph https://gitee.com/paddlepaddle/PaddleOCR +``` + + +```python +# 安装PaddleOCR环境 +! pip install -r PaddleOCR/requirements.txt --force-reinstall +! pip install protobuf==3.19 +``` + +### 2.2 准备数据集 + +本例中使用的数据集采用表格[生成工具](https://github.com/WenmuZhou/TableGeneration)制作。 + +使用如下命令对数据集进行解压,并查看数据集大小 + + +```python +! cd data/data165849 && tar -xf table_gen_dataset.tar && cd - +! 
wc -l data/data165849/table_gen_dataset/gt.txt +``` + +#### 2.2.1 划分训练测试集 + +使用下述命令将数据集划分为训练集和测试集, 这里将90%划分为训练集,10%划分为测试集 + + +```python +import random +with open('/home/aistudio/data/data165849/table_gen_dataset/gt.txt') as f: + lines = f.readlines() +random.shuffle(lines) +train_len = int(len(lines)*0.9) +train_list = lines[:train_len] +val_list = lines[train_len:] + +# 保存结果 +with open('/home/aistudio/train.txt','w',encoding='utf-8') as f: + f.writelines(train_list) +with open('/home/aistudio/val.txt','w',encoding='utf-8') as f: + f.writelines(val_list) +``` + +划分完成后,数据集信息如下 + +|类型|数量|图片地址|标注文件路径| +|---|---|---|---| +|训练集|18000|/home/aistudio/data/data165849/table_gen_dataset|/home/aistudio/train.txt| +|测试集|2000|/home/aistudio/data/data165849/table_gen_dataset|/home/aistudio/val.txt| + +#### 2.2.2 查看数据集 + + +```python +import cv2 +import os, json +import numpy as np +from matplotlib import pyplot as plt +%matplotlib inline + +def parse_line(data_dir, line): + data_line = line.strip("\n") + info = json.loads(data_line) + file_name = info['filename'] + cells = info['html']['cells'].copy() + structure = info['html']['structure']['tokens'].copy() + + img_path = os.path.join(data_dir, file_name) + if not os.path.exists(img_path): + print(img_path) + return None + data = { + 'img_path': img_path, + 'cells': cells, + 'structure': structure, + 'file_name': file_name + } + return data + +def draw_bbox(img_path, points, color=(255, 0, 0), thickness=2): + if isinstance(img_path, str): + img_path = cv2.imread(img_path) + img_path = img_path.copy() + for point in points: + cv2.polylines(img_path, [point.astype(int)], True, color, thickness) + return img_path + + +def rebuild_html(data): + html_code = data['structure'] + cells = data['cells'] + to_insert = [i for i, tag in enumerate(html_code) if tag in ('', '>')] + + for i, cell in zip(to_insert[::-1], cells[::-1]): + if cell['tokens']: + text = ''.join(cell['tokens']) + # skip empty text + sp_char_list = ['', '', '\u2028', ' ', '', ''] + text_remove_style = skip_char(text, sp_char_list) + if len(text_remove_style) == 0: + continue + html_code.insert(i + 1, text) + + html_code = ''.join(html_code) + return html_code + + +def skip_char(text, sp_char_list): + """ + skip empty cell + @param text: text in cell + @param sp_char_list: style char and special code + @return: + """ + for sp_char in sp_char_list: + text = text.replace(sp_char, '') + return text + +save_dir = '/home/aistudio/vis' +os.makedirs(save_dir, exist_ok=True) +image_dir = '/home/aistudio/data/data165849/' +html_str = '' + +# 解析标注信息并还原html表格 +data = parse_line(image_dir, val_list[0]) + +img = cv2.imread(data['img_path']) +img_name = ''.join(os.path.basename(data['file_name']).split('.')[:-1]) +img_save_name = os.path.join(save_dir, img_name) +boxes = [np.array(x['bbox']) for x in data['cells']] +show_img = draw_bbox(data['img_path'], boxes) +cv2.imwrite(img_save_name + '_show.jpg', show_img) + +html = rebuild_html(data) +html_str += html +html_str += '
' + +# 显示标注的html字符串 +from IPython.core.display import display, HTML +display(HTML(html_str)) +# 显示单元格坐标 +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() +``` + +### 2.3 训练 + +这里选用PP-StructureV2中的表格识别模型[SLANet](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/configs/table/SLANet.yml) + +SLANet是PP-StructureV2全新推出的表格识别模型,相比PP-StructureV1中TableRec-RARE,在速度不变的情况下精度提升4.7%。TEDS提升2% + + +|算法|Acc|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|Speed| +| --- | --- | --- | ---| +| EDD[2] |x| 88.30% |x| +| TableRec-RARE(ours) | 71.73%| 93.88% |779ms| +| SLANet(ours) | 76.31%| 95.89%|766ms| + +进行训练之前先使用如下命令下载预训练模型 + + +```python +# 进入PaddleOCR工作目录 +os.chdir('/home/aistudio/PaddleOCR') +# 下载英文预训练模型 +! wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar --no-check-certificate +! cd ./pretrain_models/ && tar xf en_ppstructure_mobile_v2.0_SLANet_train.tar && cd ../ +``` + +使用如下命令即可启动训练,需要修改的配置有 + +|字段|修改值|含义| +|---|---|---| +|Global.pretrained_model|./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy.pdparams|指向英文表格预训练模型地址| +|Global.eval_batch_step|562|模型多少step评估一次,一般设置为一个epoch总的step数| +|Optimizer.lr.name|Const|学习率衰减器 | +|Optimizer.lr.learning_rate|0.0005|学习率设为之前的0.05倍 | +|Train.dataset.data_dir|/home/aistudio/data/data165849|指向训练集图片存放目录 | +|Train.dataset.label_file_list|/home/aistudio/data/data165849/table_gen_dataset/train.txt|指向训练集标注文件 | +|Train.loader.batch_size_per_card|32|训练时每张卡的batch_size | +|Train.loader.num_workers|1|训练集多进程数据读取的进程数,在aistudio中需要设为1 | +|Eval.dataset.data_dir|/home/aistudio/data/data165849|指向测试集图片存放目录 | +|Eval.dataset.label_file_list|/home/aistudio/data/data165849/table_gen_dataset/val.txt|指向测试集标注文件 | +|Eval.loader.batch_size_per_card|32|测试时每张卡的batch_size | +|Eval.loader.num_workers|1|测试集多进程数据读取的进程数,在aistudio中需要设为1 | + + +已经修改好的配置存储在 `/home/aistudio/SLANet_ch.yml` + + +```python +import os +os.chdir('/home/aistudio/PaddleOCR') +! python3 tools/train.py -c /home/aistudio/SLANet_ch.yml +``` + +大约在7个epoch后达到最高精度 97.49% + +### 2.4 验证 + +训练完成后,可使用如下命令在测试集上评估最优模型的精度 + + +```python +! python3 tools/eval.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams +``` + +### 2.5 训练引擎推理 +使用如下命令可使用训练引擎对单张图片进行推理 + + +```python +import os;os.chdir('/home/aistudio/PaddleOCR') +! python3 tools/infer_table.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams Global.infer_img=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg +``` + + +```python +import cv2 +from matplotlib import pyplot as plt +%matplotlib inline + +# 显示原图 +show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() + +# 显示预测的单元格 +show_img = cv2.imread('/home/aistudio/PaddleOCR/output/infer/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() +``` + +### 2.6 模型导出 + +使用如下命令可将模型导出为inference模型 + + +```python +! 
python3 tools/export_model.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams Global.save_inference_dir=/home/aistudio/SLANet_ch/infer +``` + +### 2.7 预测引擎推理 +使用如下命令可使用预测引擎对单张图片进行推理 + + + +```python +os.chdir('/home/aistudio/PaddleOCR/ppstructure') +! python3 table/predict_structure.py \ + --table_model_dir=/home/aistudio/SLANet_ch/infer \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ + --image_dir=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg \ + --output=../output/inference +``` + + +```python +# 显示原图 +show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() + +# 显示预测的单元格 +show_img = cv2.imread('/home/aistudio/PaddleOCR/output/inference/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() +``` + +### 2.8 表格识别 + +在表格结构模型训练完成后,可结合OCR检测识别模型,对表格内容进行识别。 + +首先下载PP-OCRv3文字检测识别模型 + + +```python +# 下载PP-OCRv3文本检测识别模型并解压 +! wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar --no-check-certificate +! wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar --no-check-certificate +! cd ./inference/ && tar xf ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_rec_slim_infer.tar && cd ../ +``` + +模型下载完成后,使用如下命令进行表格识别 + + +```python +import os;os.chdir('/home/aistudio/PaddleOCR/ppstructure') +! python3 table/predict_table.py \ + --det_model_dir=inference/ch_PP-OCRv3_det_slim_infer \ + --rec_model_dir=inference/ch_PP-OCRv3_rec_slim_infer \ + --table_model_dir=/home/aistudio/SLANet_ch/infer \ + --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ + --image_dir=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg \ + --output=../output/table +``` + + +```python +# 显示原图 +show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() + +# 显示预测结果 +from IPython.core.display import display, HTML +display(HTML('
alleadersh不贰过,推从自己参与浙江数。另一方
AnSha自己越共商共建工作协商w.east 抓好改革试点任务
EdimeImisesElec怀天下”。22.26 31.614.30 794.94
ip Profundi:2019年12月1Horspro444.482.41 87679.98
iehaiTrain组长蒋蕊Toafterdec203.4323.54 44266.62
Tyint roudlyRol谢您的好意,我知道ErChows48.9010316
NaFlint一辈的aterreclam7823.869829.237.96 3068
家上下游企业,5Tr景象。当地球上的我们Urelaw799.62354.9612.9833
赛事( uestCh复制的业务模式并Listicjust9.239253.22
Ca Iskole扶贫"之名引导 Papua 7191.901.653.6248
避讳ir但由于Fficeof0.226.377.173397.75
ndaTurk百处遗址gMa1288.342053.662.29885.45
')) +``` + +## 3. 表格属性识别 +### 3.1 代码、环境、数据准备 +#### 3.1.1 代码准备 +首先,我们需要准备训练表格属性的代码,PaddleClas集成了PULC方案,该方案可以快速获得一个在CPU上用时2ms的属性识别模型。PaddleClas代码可以clone下载得到。获取方式如下: + + + +```python +! git clone -b develop https://gitee.com/paddlepaddle/PaddleClas +``` + +#### 3.1.2 环境准备 +其次,我们需要安装训练PaddleClas相关的依赖包 + + +```python +! pip install -r PaddleClas/requirements.txt --force-reinstall +! pip install protobuf==3.20.0 +``` + + +#### 3.1.3 数据准备 + +最后,准备训练数据。在这里,我们一共定义了表格的6个属性,分别是表格来源、表格数量、表格颜色、表格清晰度、表格有无干扰、表格角度。其可视化如下: + +![](https://user-images.githubusercontent.com/45199522/190587903-ccdfa6fb-51e8-42de-b08b-a127cb04e304.png) + +这里,我们提供了一个表格属性的demo子集,可以快速迭代体验。下载方式如下: + + +```python +%cd PaddleClas/dataset +!wget https://paddleclas.bj.bcebos.com/data/PULC/table_attribute.tar +!tar -xf table_attribute.tar +%cd ../PaddleClas/dataset +%cd ../ +``` + +### 3.2 表格属性识别训练 +表格属性训练整体pipelinie如下: + +![](https://user-images.githubusercontent.com/45199522/190599426-3415b38e-e16e-4e68-9253-2ff531b1b5ca.png) + +1.训练过程中,图片经过预处理之后,送入到骨干网络之中,骨干网络将抽取表格图片的特征,最终该特征连接输出的FC层,FC层经过Sigmoid激活函数后和真实标签做交叉熵损失函数,优化器通过对该损失函数做梯度下降来更新骨干网络的参数,经过多轮训练后,骨干网络的参数可以对为止图片做很好的预测; + +2.推理过程中,图片经过预处理之后,送入到骨干网络之中,骨干网络加载学习好的权重后对该表格图片做出预测,预测的结果为一个6维向量,该向量中的每个元素反映了每个属性对应的概率值,通过对该值进一步卡阈值之后,得到最终的输出,最终的输出描述了该表格的6个属性。 + +当准备好相关的数据之后,可以一键启动表格属性的训练,训练代码如下: + + +```python + +!python tools/train.py -c ./ppcls/configs/PULC/table_attribute/PPLCNet_x1_0.yaml -o Global.device=cpu -o Global.epochs=10 +``` + +### 3.3 表格属性识别推理和部署 +#### 3.3.1 模型转换 +当训练好模型之后,需要将模型转换为推理模型进行部署。转换脚本如下: + + +```python +!python tools/export_model.py -c ppcls/configs/PULC/table_attribute/PPLCNet_x1_0.yaml -o Global.pretrained_model=output/PPLCNet_x1_0/best_model +``` + +执行以上命令之后,会在当前目录上生成`inference`文件夹,该文件夹中保存了当前精度最高的推理模型。 + +#### 3.3.2 模型推理 +安装推理需要的paddleclas包, 此时需要通过下载安装paddleclas的develop的whl包 + + + +```python +!pip install https://paddleclas.bj.bcebos.com/whl/paddleclas-0.0.0-py3-none-any.whl +``` + +进入`deploy`目录下即可对模型进行推理 + + +```python +%cd deploy/ +``` + +推理命令如下: + + +```python +!python python/predict_cls.py -c configs/PULC/table_attribute/inference_table_attribute.yaml -o Global.inference_model_dir="../inference" -o Global.infer_imgs="../dataset/table_attribute/Table_val/val_9.jpg" +!python python/predict_cls.py -c configs/PULC/table_attribute/inference_table_attribute.yaml -o Global.inference_model_dir="../inference" -o Global.infer_imgs="../dataset/table_attribute/Table_val/val_3253.jpg" +``` + +推理的表格图片: + +![](https://user-images.githubusercontent.com/45199522/190596141-74f4feda-b082-46d7-908d-b0bd5839b430.png) + +预测结果如下: +``` +val_9.jpg: {'attributes': ['Scanned', 'Little', 'Black-and-White', 'Clear', 'Without-Obstacles', 'Horizontal'], 'output': [1, 1, 1, 1, 1, 1]} +``` + + +推理的表格图片: + +![](https://user-images.githubusercontent.com/45199522/190597086-2e685200-22d0-4042-9e46-f61f24e02e4e.png) + +预测结果如下: +``` +val_3253.jpg: {'attributes': ['Photo', 'Little', 'Black-and-White', 'Blurry', 'Without-Obstacles', 'Tilted'], 'output': [0, 1, 1, 0, 1, 0]} +``` + +对比两张图片可以发现,第一张图片比较清晰,表格属性的结果也偏向于比较容易识别,我们可以更相信表格识别的结果,第二张图片比较模糊,且存在倾斜现象,表格识别可能存在错误,需要我们人工进一步校验。通过表格的属性识别能力,可以进一步将“人工”和“智能”很好的结合起来,为表格识别能力的落地的精度提供保障。 diff --git "a/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/corpus/digital.txt" "b/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/corpus/digital.txt" new file mode 100644 index 
0000000..26b06e7 --- /dev/null +++ "b/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/corpus/digital.txt" @@ -0,0 +1,43 @@ +46.39 +40.08 +89.52 +-71.93 +23.19 +-81.02 +-34.09 +05.87 +-67.80 +-51.56 +-34.58 +37.91 +56.98 +29.01 +-90.13 +35.55 +66.07 +-90.35 +-50.93 +42.42 +21.40 +-30.99 +-71.78 +25.60 +-48.69 +-72.28 +-17.55 +-99.93 +-47.35 +-64.89 +-31.28 +-90.01 +05.17 +30.91 +30.56 +-06.90 +79.05 +67.74 +-32.31 +94.22 +28.75 +51.03 +-58.96 diff --git "a/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/fonts/DS-DIGI.TTF" "b/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/fonts/DS-DIGI.TTF" new file mode 100644 index 0000000..0925877 Binary files /dev/null and "b/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/fonts/DS-DIGI.TTF" differ diff --git "a/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/fonts/DS-DIGIB.TTF" "b/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/fonts/DS-DIGIB.TTF" new file mode 100644 index 0000000..064ad47 Binary files /dev/null and "b/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/fonts/DS-DIGIB.TTF" differ diff --git "a/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253.md" "b/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253.md" new file mode 100644 index 0000000..25e32cf --- /dev/null +++ "b/applications/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253/\345\205\211\345\212\237\347\216\207\350\256\241\346\225\260\347\240\201\347\256\241\345\255\227\347\254\246\350\257\206\345\210\253.md" @@ -0,0 +1,467 @@ +# 光功率计数码管字符识别 + +本案例将使用OCR技术自动识别光功率计显示屏文字,通过本章您可以掌握: + +- PaddleOCR快速使用 +- 数据合成方法 +- 数据挖掘方法 +- 基于现有数据微调 + +## 1. 背景介绍 + +光功率计(optical power meter )是指用于测量绝对光功率或通过一段光纤的光功率相对损耗的仪器。在光纤系统中,测量光功率是最基本的,非常像电子学中的万用表;在光纤测量中,光功率计是重负荷常用表。 + + + +目前光功率计缺少将数据直接输出的功能,需要人工读数。这一项工作单调重复,如果可以使用机器替代人工,将节约大量成本。针对上述问题,希望通过摄像头拍照->智能读数的方式高效地完成此任务。 + +为实现智能读数,通常会采取文本检测+文本识别的方案: + +第一步,使用文本检测模型定位出光功率计中的数字部分; + +第二步,使用文本识别模型获得准确的数字和单位信息。 + +本项目主要介绍如何完成第二步文本识别部分,包括:真实评估集的建立、训练数据的合成、基于 PP-OCRv3 和 SVTR_Tiny 两个模型进行训练,以及评估和推理。 + +本项目难点如下: + +- 光功率计数码管字符数据较少,难以获取。 +- 数码管中小数点占像素较少,容易漏识别。 + +针对以上问题, 本例选用 PP-OCRv3 和 SVTR_Tiny 两个高精度模型训练,同时提供了真实数据挖掘案例和数据合成案例。基于 PP-OCRv3 模型,在构建的真实评估集上精度从 52% 提升至 72%,SVTR_Tiny 模型精度可达到 78.9%。 + +aistudio项目链接: [光功率计数码管字符识别](https://aistudio.baidu.com/aistudio/projectdetail/4049044?contributionType=1) + +## 2. 
PaddleOCR 快速使用 + +PaddleOCR 旨在打造一套丰富、领先、且实用的OCR工具库,助力开发者训练出更好的模型,并应用落地。 + +![](https://github.com/PaddlePaddle/PaddleOCR/raw/release/2.5/doc/imgs_results/ch_ppocr_mobile_v2.0/test_add_91.jpg) + + +官方提供了适用于通用场景的高精轻量模型,首先使用官方提供的 PP-OCRv3 模型预测图片,验证下当前模型在光功率计场景上的效果。 + +- 准备环境 + +``` +python3 -m pip install -U pip +python3 -m pip install paddleocr +``` + + +- 测试效果 + +测试图: + +![](https://ai-studio-static-online.cdn.bcebos.com/8dca91f016884e16ad9216d416da72ea08190f97d87b4be883f15079b7ebab9a) + + +``` +paddleocr --lang=ch --det=Fase --image_dir=data +``` + +得到如下测试结果: + +``` +('.7000', 0.6885431408882141) +``` + +发现数字识别较准,然而对负号和小数点识别不准确。 由于PP-OCRv3的训练数据大多为通用场景数据,在特定的场景上效果可能不够好。因此需要基于场景数据进行微调。 + +下面就主要介绍如何在光功率计(数码管)场景上微调训练。 + + +## 3. 开始训练 + +### 3.1 数据准备 + +特定的工业场景往往很难获取开源的真实数据集,光功率计也是如此。在实际工业场景中,可以通过摄像头采集的方法收集大量真实数据,本例中重点介绍数据合成方法和真实数据挖掘方法,如何利用有限的数据优化模型精度。 + +数据集分为两个部分:合成数据,真实数据, 其中合成数据由 text_renderer 工具批量生成得到, 真实数据通过爬虫等方式在百度图片中搜索并使用 PPOCRLabel 标注得到。 + + +- 合成数据 + +本例中数据合成工具使用的是 [text_renderer](https://github.com/Sanster/text_renderer), 该工具可以合成用于文本识别训练的文本行数据: + +![](https://github.com/oh-my-ocr/text_renderer/raw/master/example_data/effect_layout_image/char_spacing_compact.jpg) + +![](https://github.com/oh-my-ocr/text_renderer/raw/master/example_data/effect_layout_image/color_image.jpg) + + +``` +export https_proxy=http://172.19.57.45:3128 +git clone https://github.com/oh-my-ocr/text_renderer +``` + +``` +import os +python3 setup.py develop +python3 -m pip install -r docker/requirements.txt +python3 main.py \ + --config example_data/example.py \ + --dataset img \ + --num_processes 2 \ + --log_period 10 +``` + +给定字体和语料,就可以合成较为丰富样式的文本行数据。 光功率计识别场景,目标是正确识别数码管文本,因此需要收集部分数码管字体,训练语料,用于合成文本识别数据。 + +将收集好的语料存放在 example_data 路径下: + +``` +ln -s ./fonts/DS* text_renderer/example_data/font/ +ln -s ./corpus/digital.txt text_renderer/example_data/text/ +``` + +修改 text_renderer/example_data/font_list/font_list.txt ,选择需要的字体开始合成: + +``` +python3 main.py \ + --config example_data/digital_example.py \ + --dataset img \ + --num_processes 2 \ + --log_period 10 +``` + +合成图片会被存在目录 text_renderer/example_data/digital/chn_data 下 + +查看合成的数据样例: + +![](https://ai-studio-static-online.cdn.bcebos.com/7d5774a273f84efba5b9ce7fd3f86e9ef24b6473e046444db69fa3ca20ac0986) + + +- 真实数据挖掘 + +模型训练需要使用真实数据作为评价指标,否则很容易过拟合到简单的合成数据中。没有开源数据的情况下,可以利用部分无标注数据+标注工具获得真实数据。 + + +1. 数据搜集 + +使用[爬虫工具](https://github.com/Joeclinton1/google-images-download.git)获得无标注数据 + +2. [PPOCRLabel](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.5/PPOCRLabel) 完成半自动标注 + +PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置PP-OCR模型对数据自动标注和重新识别。使用Python3和PyQT5编写,支持矩形框标注、表格标注、不规则文本标注、关键信息标注模式,导出格式可直接用于PaddleOCR检测和识别模型的训练。 + +![](https://github.com/PaddlePaddle/PaddleOCR/raw/release/2.5/PPOCRLabel/data/gif/steps_en.gif) + + +收集完数据后就可以进行分配了,验证集中一般都是真实数据,训练集中包含合成数据+真实数据。本例中标注了155张图片,其中训练集和验证集的数目为100和55。 + + +最终 `data` 文件夹应包含以下几部分: + +``` +|-data + |- synth_train.txt + |- real_train.txt + |- real_eval.txt + |- synthetic_data + |- word_001.png + |- word_002.jpg + |- word_003.jpg + | ... + |- real_data + |- word_001.png + |- word_002.jpg + |- word_003.jpg + | ... + ... 
+``` + +### 3.2 模型选择 + +本案例提供了2种文本识别模型:PP-OCRv3 识别模型 和 SVTR_Tiny: + +[PP-OCRv3 识别模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/PP-OCRv3_introduction.md):PP-OCRv3的识别模块是基于文本识别算法SVTR优化。SVTR不再采用RNN结构,通过引入Transformers结构更加有效地挖掘文本行图像的上下文信息,从而提升文本识别能力。并进行了一系列结构改进加速模型预测。 + +[SVTR_Tiny](https://arxiv.org/abs/2205.00159):SVTR提出了一种用于场景文本识别的单视觉模型,该模型在patch-wise image tokenization框架内,完全摒弃了序列建模,在精度具有竞争力的前提下,模型参数量更少,速度更快。 + +以上两个策略在自建中文数据集上的精度和速度对比如下: + +| ID | 策略 | 模型大小 | 精度 | 预测耗时(CPU + MKLDNN)| +|-----|-----|--------|----| --- | +| 01 | PP-OCRv2 | 8M | 74.80% | 8.54ms | +| 02 | SVTR_Tiny | 21M | 80.10% | 97.00ms | +| 03 | SVTR_LCNet(h32) | 12M | 71.90% | 6.60ms | +| 04 | SVTR_LCNet(h48) | 12M | 73.98% | 7.60ms | +| 05 | + GTC | 12M | 75.80% | 7.60ms | +| 06 | + TextConAug | 12M | 76.30% | 7.60ms | +| 07 | + TextRotNet | 12M | 76.90% | 7.60ms | +| 08 | + UDML | 12M | 78.40% | 7.60ms | +| 09 | + UIM | 12M | 79.40% | 7.60ms | + + +### 3.3 开始训练 + +首先下载 PaddleOCR 代码库 + +``` +git clone -b release/2.5 https://github.com/PaddlePaddle/PaddleOCR.git +``` + +PaddleOCR提供了训练脚本、评估脚本和预测脚本,本节将以 PP-OCRv3 中文识别模型为例: + +**Step1:下载预训练模型** + +首先下载 pretrain model,您可以下载训练好的模型在自定义数据上进行finetune + +``` +cd PaddleOCR/ +# 下载PP-OCRv3 中文预训练模型 +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar +# 解压模型参数 +cd pretrain_models +tar -xf ch_PP-OCRv3_rec_train.tar && rm -rf ch_PP-OCRv3_rec_train.tar +``` + +**Step2:自定义字典文件** + +接下来需要提供一个字典({word_dict_name}.txt),使模型在训练时,可以将所有出现的字符映射为字典的索引。 + +因此字典需要包含所有希望被正确识别的字符,{word_dict_name}.txt需要写成如下格式,并以 `utf-8` 编码格式保存: + +``` +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +- +. +``` + +word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,“3.14” 将被映射成 [3, 11, 1, 4] + +* 内置字典 + +PaddleOCR内置了一部分字典,可以按需使用。 + +`ppocr/utils/ppocr_keys_v1.txt` 是一个包含6623个字符的中文字典 + +`ppocr/utils/ic15_dict.txt` 是一个包含36个字符的英文字典 + +* 自定义字典 + +内置字典面向通用场景,具体的工业场景中,可能需要识别特殊字符,或者只需识别某几个字符,此时自定义字典会更提升模型精度。例如在光功率计场景中,需要识别数字和单位。 + +遍历真实数据标签中的字符,制作字典`digital_dict.txt`如下所示: + +``` +- +. +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +B +E +F +H +L +N +T +W +d +k +m +n +o +z +``` + + + + +**Step3:修改配置文件** + +为了更好的使用预训练模型,训练推荐使用[ch_PP-OCRv3_rec_distillation.yml](../../configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml)配置文件,并参考下列说明修改配置文件: + +以 `ch_PP-OCRv3_rec_distillation.yml` 为例: +``` +Global: + ... + # 添加自定义字典,如修改字典请将路径指向新字典 + character_dict_path: ppocr/utils/dict/digital_dict.txt + ... + # 识别空格 + use_space_char: True + + +Optimizer: + ... + # 添加学习率衰减策略 + lr: + name: Cosine + learning_rate: 0.001 + ... + +... + +Train: + dataset: + # 数据集格式,支持LMDBDataSet以及SimpleDataSet + name: SimpleDataSet + # 数据集路径 + data_dir: ./data/ + # 训练集标签文件 + label_file_list: + - ./train_data/digital_img/digital_train.txt #11w + - ./train_data/digital_img/real_train.txt #100 + - ./train_data/digital_img/dbm_img/dbm.txt #3w + ratio_list: + - 0.3 + - 1.0 + - 1.0 + transforms: + ... + - RecResizeImg: + # 修改 image_shape 以适应长文本 + image_shape: [3, 48, 320] + ... + loader: + ... + # 单卡训练的batch_size + batch_size_per_card: 256 + ... + +Eval: + dataset: + # 数据集格式,支持LMDBDataSet以及SimpleDataSet + name: SimpleDataSet + # 数据集路径 + data_dir: ./data + # 验证集标签文件 + label_file_list: + - ./train_data/digital_img/real_val.txt + transforms: + ... + - RecResizeImg: + # 修改 image_shape 以适应长文本 + image_shape: [3, 48, 320] + ... + loader: + # 单卡验证的batch_size + batch_size_per_card: 256 + ... 
+``` +**注意,训练/预测/评估时的配置文件请务必与训练一致。** + +**Step4:启动训练** + +*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false* + +``` +# GPU训练 支持单卡,多卡训练 +# 训练数码管数据 训练日志会自动保存为 "{save_model_dir}" 下的train.log + +#单卡训练(训练周期长,不建议) +python3 tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model=./pretrain_models/ch_PP-OCRv3_rec_train/best_accuracy + +#多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model=./pretrain_models/en_PP-OCRv3_rec_train/best_accuracy +``` + + +PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml` 中修改 `eval_batch_step` 设置评估频率,默认每500个iter评估一次。评估过程中默认将最佳acc模型,保存为 `output/ch_PP-OCRv3_rec_distill/best_accuracy` 。 + +如果验证集很大,测试将会比较耗时,建议减少评估次数,或训练完再进行评估。 + +### SVTR_Tiny 训练 + +SVTR_Tiny 训练步骤与上面一致,SVTR支持的配置和模型训练权重可以参考[算法介绍文档](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/algorithm_rec_svtr.md) + +**Step1:下载预训练模型** + +``` +# 下载 SVTR_Tiny 中文识别预训练模型和配置文件 +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_ch_train.tar +# 解压模型参数 +tar -xf rec_svtr_tiny_none_ctc_ch_train.tar && rm -rf rec_svtr_tiny_none_ctc_ch_train.tar +``` +**Step2:自定义字典文件** + +字典依然使用自定义的 digital_dict.txt + +**Step3:修改配置文件** + +配置文件中对应修改字典路径和数据路径 + +**Step4:启动训练** + +``` +## 单卡训练 +python tools/train.py -c rec_svtr_tiny_none_ctc_ch_train/rec_svtr_tiny_6local_6global_stn_ch.yml \ + -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_ch_train/best_accuracy +``` + +### 3.4 验证效果 + +如需获取已训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 +
+ +
+将下载或训练完成的模型放置在对应目录下即可完成模型推理 + +* 指标评估 + +训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。评估数据集可以通过 `configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml` 修改Eval中的 `label_file_path` 设置。 + +``` +# GPU 评估, Global.checkpoints 为待测权重 +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.checkpoints={path/to/weights}/best_accuracy +``` + +* 测试识别效果 + +使用 PaddleOCR 训练好的模型,可以通过以下脚本进行快速预测。 + +默认预测图片存储在 `infer_img` 里,通过 `-o Global.checkpoints` 加载训练好的参数文件: + +根据配置文件中设置的 `save_model_dir` 和 `save_epoch_step` 字段,会有以下几种参数被保存下来: + +``` +output/rec/ +├── best_accuracy.pdopt +├── best_accuracy.pdparams +├── best_accuracy.states +├── config.yml +├── iter_epoch_3.pdopt +├── iter_epoch_3.pdparams +├── iter_epoch_3.states +├── latest.pdopt +├── latest.pdparams +├── latest.states +└── train.log +``` + +其中 best_accuracy.* 是评估集上的最优模型;iter_epoch_x.* 是以 `save_epoch_step` 为间隔保存下来的模型;latest.* 是最后一个epoch的模型。 + +``` +# 预测英文结果 +python3 tools/infer_rec.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=test_digital.png +``` + +预测图片: + +![](https://ai-studio-static-online.cdn.bcebos.com/8dca91f016884e16ad9216d416da72ea08190f97d87b4be883f15079b7ebab9a) + + +得到输入图像的预测结果: + +``` +infer_img: test_digital.png + result: ('-70.00', 0.9998967) +``` diff --git "a/applications/\345\214\205\350\243\205\347\224\237\344\272\247\346\227\245\346\234\237\350\257\206\345\210\253.md" "b/applications/\345\214\205\350\243\205\347\224\237\344\272\247\346\227\245\346\234\237\350\257\206\345\210\253.md" new file mode 100644 index 0000000..73c174c --- /dev/null +++ "b/applications/\345\214\205\350\243\205\347\224\237\344\272\247\346\227\245\346\234\237\350\257\206\345\210\253.md" @@ -0,0 +1,685 @@ +# 一种基于PaddleOCR的产品包装生产日期识别模型 + +- [1. 项目介绍](#1-项目介绍) +- [2. 环境搭建](#2-环境搭建) +- [3. 数据准备](#3-数据准备) +- [4. 直接使用PP-OCRv3模型评估](#4-直接使用PPOCRv3模型评估) +- [5. 基于合成数据finetune](#5-基于合成数据finetune) + - [5.1 Text Renderer数据合成方法](#51-TextRenderer数据合成方法) + - [5.1.1 下载Text Renderer代码](#511-下载TextRenderer代码) + - [5.1.2 准备背景图片](#512-准备背景图片) + - [5.1.3 准备语料](#513-准备语料) + - [5.1.4 下载字体](#514-下载字体) + - [5.1.5 运行数据合成命令](#515-运行数据合成命令) + - [5.2 模型训练](#52-模型训练) +- [6. 基于真实数据finetune](#6-基于真实数据finetune) + - [6.1 python爬虫获取数据](#61-python爬虫获取数据) + - [6.2 数据挖掘](#62-数据挖掘) + - [6.3 模型训练](#63-模型训练) +- [7. 基于合成+真实数据finetune](#7-基于合成+真实数据finetune) + + +## 1. 项目介绍 + +产品包装生产日期是计算机视觉图像识别技术在工业场景中的一种应用。产品包装生产日期识别技术要求能够将产品生产日期从复杂背景中提取并识别出来,在物流管理、物资管理中得到广泛应用。 + +![](https://ai-studio-static-online.cdn.bcebos.com/d9e0533cc1df47ffa3bbe99de9e42639a3ebfa5bce834bafb1ca4574bf9db684) + + +- 项目难点 + +1. 没有训练数据 +2. 图像质量层次不齐: 角度倾斜、图片模糊、光照不足、过曝等问题严重 + +针对以上问题, 本例选用PP-OCRv3这一开源超轻量OCR系统进行包装产品生产日期识别系统的开发。直接使用PP-OCRv3进行评估的精度为62.99%。为提升识别精度,我们首先使用数据合成工具合成了3k数据,基于这部分数据进行finetune,识别精度提升至73.66%。由于合成数据与真实数据之间的分布存在差异,为进一步提升精度,我们使用网络爬虫配合数据挖掘策略得到了1k带标签的真实数据,基于真实数据finetune的精度为71.33%。最后,我们综合使用合成数据和真实数据进行finetune,将识别精度提升至86.99%。各策略的精度提升效果如下: + +| 策略 | 精度| +| :--------------- | :-------- | +| PP-OCRv3评估 | 62.99| +| 合成数据finetune | 73.66| +| 真实数据finetune | 71.33| +| 真实+合成数据finetune | 86.99| + +AIStudio项目链接: [一种基于PaddleOCR的包装生产日期识别方法](https://aistudio.baidu.com/aistudio/projectdetail/4287736) + +## 2. 
环境搭建 + +本任务基于Aistudio完成, 具体环境如下: + +- 操作系统: Linux +- PaddlePaddle: 2.3 +- PaddleOCR: Release/2.5 +- text_renderer: master + +下载PaddlleOCR代码并安装依赖库: +```bash +git clone -b dygraph https://gitee.com/paddlepaddle/PaddleOCR + +# 安装依赖库 +cd PaddleOCR +pip install -r PaddleOCR/requirements.txt +``` + +## 3. 数据准备 + +本项目使用人工预标注的300张图像作为测试集。 + +部分数据示例如下: + +![](https://ai-studio-static-online.cdn.bcebos.com/39ff30e0ab0442579712255e6a9ea6b5271169c98e624e6eb2b8781f003bfea0) + + +标签文件格式如下: +```txt +数据路径 标签(中间以制表符分隔) +``` + +|数据集类型|数量| +|---|---| +|测试集| 300| + +数据集[下载链接](https://aistudio.baidu.com/aistudio/datasetdetail/149770),下载后可以通过下方命令解压: + +```bash +tar -xvf data.tar +mv data ${PaddleOCR_root} +``` + +数据解压后的文件结构如下: + +```shell +PaddleOCR +├── data +│ ├── mining_images # 挖掘的真实数据示例 +│ ├── mining_train.list # 挖掘的真实数据文件列表 +│ ├── render_images # 合成数据示例 +│ ├── render_train.list # 合成数据文件列表 +│ ├── val # 测试集数据 +│ └── val.list # 测试集数据文件列表 +| ├── bg # 合成数据所需背景图像 +│ └── corpus # 合成数据所需语料 +``` + +## 4. 直接使用PP-OCRv3模型评估 + +准备好测试数据后,可以使用PaddleOCR的PP-OCRv3模型进行识别。 + +- 下载预训练模型 + +首先需要下载PP-OCR v3中英文识别模型文件,下载链接可以在https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/ppocr_introduction.md#6 获取,下载命令: + +```bash +cd ${PaddleOCR_root} +mkdir ckpt +wget -nc -P ckpt https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar +pushd ckpt/ +tar -xvf ch_PP-OCRv3_rec_train.tar +popd +``` + +- 模型评估 + +使用以下命令进行PP-OCRv3评估: + +```bash +python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \ + -o Global.checkpoints=ckpt/ch_PP-OCRv3_rec_train/best_accuracy \ + Eval.dataset.data_dir=./data \ + Eval.dataset.label_file_list=["./data/val.list"] + +``` + +其中各参数含义如下: + +```bash +-c: 指定使用的配置文件,ch_PP-OCRv3_rec_distillation.yml对应于OCRv3识别模型。 +-o: 覆盖配置文件中参数 +Global.checkpoints: 指定评估使用的模型文件路径 +Eval.dataset.data_dir: 指定评估数据集路径 +Eval.dataset.label_file_list: 指定评估数据集文件列表 +``` + +## 5. 基于合成数据finetune + +### 5.1 Text Renderer数据合成方法 + +#### 5.1.1 下载Text Renderer代码 + +首先从github或gitee下载Text Renderer代码,并安装相关依赖。 + +```bash +git clone https://gitee.com/wowowoll/text_renderer.git + +# 安装依赖库 +cd text_renderer +pip install -r requirements.txt +``` + +使用text renderer合成数据之前需要准备好背景图片、语料以及字体库,下面将逐一介绍各个步骤。 + +#### 5.1.2 准备背景图片 + +观察日常生活中常见的包装生产日期图片,我们可以发现其背景相对简单。为此我们可以从网上找一下图片,截取部分图像块作为背景图像。 + +本项目已准备了部分图像作为背景图片,在第3部分完成数据准备后,可以得到我们准备好的背景图像,示例如下: + +![](https://ai-studio-static-online.cdn.bcebos.com/456ae2acb27d4a94896c478812aee0bc3551c703d7bd40c9be4dc983c7b3fc8a) + + + +背景图像存放于如下位置: + +```shell +PaddleOCR +├── data +| ├── bg # 合成数据所需背景图像 +``` + +#### 5.1.3 准备语料 + +观察测试集生产日期图像,我们可以知道如下数据有如下特点: +1. 由年月日组成,中间可能以“/”、“-”、“:”、“.”或者空格间隔,也可能以汉字年月日分隔 +2. 
有些生产日期包含在产品批号中,此时可能包含具体时间、英文字母或数字标识 + +基于以上两点,我们编写语料生成脚本: + +```python +import random +from random import choice +import os + +cropus_num = 2000 #设置语料数量 + +def get_cropus(f): + # 随机生成年份 + year = random.randint(0, 22) + # 随机生成月份 + month = random.randint(1, 12) + # 随机生成日期 + day_dict = {31: [1,3,5,7,8,10,12], 30: [4,6,9,11], 28: [2]} + for item in day_dict: + if month in day_dict[item]: + day = random.randint(0, item) + # 随机生成小时 + hours = random.randint(0, 24) + # 随机生成分钟 + minute = random.randint(0, 60) + # 随机生成秒数 + second = random.randint(0, 60) + + # 随机生成产品标识字符 + length = random.randint(0, 6) + file_id = [] + flag = 0 + my_dict = [i for i in range(48,58)] + [j for j in range(40, 42)] + [k for k in range(65,90)] # 大小写字母 + 括号 + + for i in range(1, length): + if flag: + if i == flag+2: #括号匹配 + file_id.append(')') + flag = 0 + continue + sel = choice(my_dict) + if sel == 41: + continue + if sel == 40: + if i == 1 or i > length-3: + continue + flag = i + my_ascii = chr(sel) + file_id.append(my_ascii) + file_id_str = ''.join(file_id) + + #随机生成产品标识字符 + file_id2 = random.randint(0, 9) + + rad = random.random() + if rad < 0.3: + f.write('20{:02d}{:02d}{:02d} {}'.format(year, month, day, file_id_str)) + elif 0.3 < rad < 0.5: + f.write('20{:02d}年{:02d}月{:02d}日'.format(year, month, day)) + elif 0.5 < rad < 0.7: + f.write('20{:02d}/{:02d}/{:02d}'.format(year, month, day)) + elif 0.7 < rad < 0.8: + f.write('20{:02d}-{:02d}-{:02d}'.format(year, month, day)) + elif 0.8 < rad < 0.9: + f.write('20{:02d}.{:02d}.{:02d}'.format(year, month, day)) + else: + f.write('{:02d}:{:02d}:{:02d} {:02d}'.format(hours, minute, second, file_id2)) + +if __name__ == "__main__": + file_path = '/home/aistudio/text_renderer/my_data/cropus' + if not os.path.exists(file_path): + os.makedirs(file_path) + file_name = os.path.join(file_path, 'books.txt') + f = open(file_name, 'w') + for i in range(cropus_num): + get_cropus(f) + if i < cropus_num-1: + f.write('\n') + + f.close() +``` + +本项目已准备了部分语料,在第3部分完成数据准备后,可以得到我们准备好的语料库,默认位置如下: + +```shell +PaddleOCR +├── data +│ └── corpus #合成数据所需语料 +``` + +#### 5.1.4 下载字体 + +观察包装生产日期,我们可以发现其使用的字体为点阵体。字体可以在如下网址下载: +https://www.fonts.net.cn/fonts-en/tag-dianzhen-1.html + +本项目已准备了部分字体,在第3部分完成数据准备后,可以得到我们准备好的字体,默认位置如下: + +```shell +PaddleOCR +├── data +│ └── fonts #合成数据所需字体 +``` + +下载好字体后,还需要在list文件中指定字体文件存放路径,脚本如下: + +```bash +cd text_renderer/my_data/ +touch fonts.list +ls /home/aistudio/PaddleOCR/data/fonts/* > fonts.list +``` + +#### 5.1.5 运行数据合成命令 + +完成数据准备后,my_data文件结构如下: + +```shell +my_data/ +├── cropus +│ └── books.txt #语料库 +├── eng.txt #字符列表 +└── fonts.list #字体列表 +``` + +在运行合成数据命令之前,还有两处细节需要手动修改: +1. 将默认配置文件`text_renderer/configs/default.yaml`中第9行enable的值设为`true`,即允许合成彩色图像。否则合成的都是灰度图。 + +```yaml + # color boundary is in R,G,B format + font_color: ++ enable: true #false +``` + +2. 
将`text_renderer/textrenderer/renderer.py`第184行作如下修改,取消padding。否则图片两端会有一些空白。 + +```python +padding = random.randint(s_bbox_width // 10, s_bbox_width // 8) #修改前 +padding = 0 #修改后 +``` + +运行数据合成命令: + +```bash +cd /home/aistudio/text_renderer/ +python main.py --num_img=3000 \ + --fonts_list='./my_data/fonts.list' \ + --corpus_dir "./my_data/cropus" \ + --corpus_mode "list" \ + --bg_dir "/home/aistudio/PaddleOCR/data/bg/" \ + --img_width 0 +``` + +合成好的数据默认保存在`text_renderer/output`目录下,可进入该目录查看合成的数据。 + + +合成数据示例如下 +![](https://ai-studio-static-online.cdn.bcebos.com/d686a48d465a43d09fbee51924fdca42ee21c50e676646da8559fb9967b94185) + +数据合成好后,还需要生成如下格式的训练所需的标注文件, +``` +图像路径 标签 +``` + +使用如下脚本即可生成标注文件: + +```python +import random + +abspath = '/home/aistudio/text_renderer/output/default/' + +#标注文件生成路径 +fout = open('./render_train.list', 'w', encoding='utf-8') + +with open('./output/default/tmp_labels.txt','r') as f: + lines = f.readlines() + for item in lines: + label = item[9:] + filename = item[:8] + '.jpg' + fout.write(abspath + filename + '\t' + label) + + fout.close() +``` + +经过以上步骤,我们便完成了包装生产日期数据合成。 +数据位于`text_renderer/output`,标注文件位于`text_renderer/render_train.list`。 + +本项目提供了生成好的数据供大家体验,完成步骤3的数据准备后,可得数据路径位于: + +```shell +PaddleOCR +├── data +│ ├── render_images # 合成数据示例 +│ ├── render_train.list #合成数据文件列表 +``` + +### 5.2 模型训练 + +准备好合成数据后,我们可以使用以下命令,利用合成数据进行finetune: +```bash +cd ${PaddleOCR_root} +python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \ + -o Global.pretrained_model=./ckpt/ch_PP-OCRv3_rec_train/best_accuracy \ + Global.epoch_num=20 \ + Global.eval_batch_step='[0, 20]' \ + Train.dataset.data_dir=./data \ + Train.dataset.label_file_list=['./data/render_train.list'] \ + Train.loader.batch_size_per_card=64 \ + Eval.dataset.data_dir=./data \ + Eval.dataset.label_file_list=["./data/val.list"] \ + Eval.loader.batch_size_per_card=64 + +``` + +其中各参数含义如下: + +```txt +-c: 指定使用的配置文件,ch_PP-OCRv3_rec_distillation.yml对应于OCRv3识别模型。 +-o: 覆盖配置文件中参数 +Global.pretrained_model: 指定finetune使用的预训练模型 +Global.epoch_num: 指定训练的epoch数 +Global.eval_batch_step: 间隔多少step做一次评估 +Train.dataset.data_dir: 训练数据集路径 +Train.dataset.label_file_list: 训练集文件列表 +Train.loader.batch_size_per_card: 训练单卡batch size +Eval.dataset.data_dir: 评估数据集路径 +Eval.dataset.label_file_list: 评估数据集文件列表 +Eval.loader.batch_size_per_card: 评估单卡batch size +``` + +## 6. 基于真实数据finetune + + +使用合成数据finetune能提升我们模型的识别精度,但由于合成数据和真实数据之间的分布可能有一定差异,因此作用有限。为进一步提高识别精度,本节介绍如何挖掘真实数据进行模型finetune。 + +数据挖掘的整体思路如下: +1. 使用python爬虫从网上获取大量无标签数据 +2. 使用模型从大量无标签数据中构建出有效训练集 + +### 6.1 python爬虫获取数据 + +- 推荐使用[爬虫工具](https://github.com/Joeclinton1/google-images-download)获取无标签图片。 + +图片获取后,可按如下目录格式组织: + +```txt +sprider +├── file.list +├── data +│ ├── 00000.jpg +│ ├── 00001.jpg +... +``` + +### 6.2 数据挖掘 + +我们使用PaddleOCR对获取到的图片进行挖掘,具体步骤如下: +1. 使用 PP-OCRv3检测模型+svtr-tiny识别模型,对每张图片进行预测。 +2. 使用数据挖掘策略,得到有效图片。 +3. 将有效图片对应的图像区域和标签提取出来,构建训练集。 + + +首先下载预训练模型,PP-OCRv3检测模型下载链接:https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar + +如需获取svtr-tiny高精度中文识别预训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 +
+ + +完成下载后,可将模型存储于如下位置: + +```shell +PaddleOCR +├── data +│ ├── rec_vit_sub_64_363_all/ # svtr_tiny高精度识别模型 +``` + +```bash +# 下载解压PP-OCRv3检测模型 +cd ${PaddleOCR_root} +wget -nc -P ckpt https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar +pushd ckpt +tar -xvf ch_PP-OCRv3_det_infer.tar +popd ckpt +``` + +在使用PPOCRv3检测模型+svtr-tiny识别模型进行预测之前,有如下两处细节需要手动修改: +1. 将`tools/infer/predict_rec.py`中第110行`imgW`修改为`320` + +```python +#imgW = int((imgH * max_wh_ratio)) +imgW = 320 +``` + +2. 将`tools/infer/predict_system.py`第169行添加如下一行,将预测分数也写入结果文件中。 + +```python +"scores": rec_res[idx][1], +``` + +模型预测命令: +```bash +python tools/infer/predict_system.py \ + --image_dir="/home/aistudio/sprider/data" \ + --det_model_dir="./ckpt/ch_PP-OCRv3_det_infer/" \ + --rec_model_dir="/home/aistudio/PaddleOCR/data/rec_vit_sub_64_363_all/" \ + --rec_image_shape="3,32,320" +``` + +获得预测结果后,我们使用数据挖掘策略得到有效图片。具体挖掘策略如下: +1. 预测置信度高于95% +2. 识别结果包含字符‘20’,即年份 +3. 没有中文,或者有中文并且‘日’和'月'同时在识别结果中 + +```python +# 获取有效预测 + +import json +import re + +zh_pattern = re.compile(u'[\u4e00-\u9fa5]+') #正则表达式,筛选字符是否包含中文 + +file_path = '/home/aistudio/PaddleOCR/inference_results/system_results.txt' +out_path = '/home/aistudio/PaddleOCR/selected_results.txt' +f_out = open(out_path, 'w') + +with open(file_path, "r", encoding='utf-8') as fin: + lines = fin.readlines() + + +for line in lines: + flag = False + # 读取文件内容 + file_name, json_file = line.strip().split('\t') + preds = json.loads(json_file) + res = [] + for item in preds: + transcription = item['transcription'] #获取识别结果 + scores = item['scores'] #获取识别得分 + # 挖掘策略 + if scores > 0.95: + if '20' in transcription and len(transcription) > 4 and len(transcription) < 12: + word = transcription + if not(zh_pattern.search(word) and ('日' not in word or '月' not in word)): + flag = True + res.append(item) + save_pred = file_name + "\t" + json.dumps( + res, ensure_ascii=False) + "\n" + if flag ==True: + f_out.write(save_pred) + +f_out.close() +``` + +然后将有效预测对应的图像区域和标签提取出来,构建训练集。具体实现脚本如下: + +```python +import cv2 +import json +import numpy as np + +PATH = '/home/aistudio/PaddleOCR/inference_results/' #数据原始路径 +SAVE_PATH = '/home/aistudio/mining_images/' #裁剪后数据保存路径 +file_list = '/home/aistudio/PaddleOCR/selected_results.txt' #数据预测结果 +label_file = '/home/aistudio/mining_images/mining_train.list' #输出真实数据训练集标签list + +if not os.path.exists(SAVE_PATH): + os.mkdir(SAVE_PATH) + +f_label = open(label_file, 'w') + + +def get_rotate_crop_image(img, points): + """ + 根据检测结果points,从输入图像img中裁剪出相应的区域 + """ + assert len(points) == 4, "shape of points must be 4*2" + img_crop_width = int( + max( + np.linalg.norm(points[0] - points[1]), + np.linalg.norm(points[2] - points[3]))) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]))) + pts_std = np.float32([[0, 0], [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + # 形变或倾斜,会做透视变换,reshape成矩形 + dst_img = cv2.warpPerspective( + img, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + +def crop_and_get_filelist(file_list): + with open(file_list, "r", encoding='utf-8') as fin: + lines = fin.readlines() + + img_num = 0 + for line in lines: + img_name, json_file = line.strip().split('\t') + preds = json.loads(json_file) + for item in 
preds: + transcription = item['transcription'] + points = item['points'] + points = np.array(points).astype('float32') + #print('processing {}...'.format(img_name)) + + img = cv2.imread(PATH+img_name) + dst_img = get_rotate_crop_image(img, points) + h, w, c = dst_img.shape + newWidth = int((32. / h) * w) + newImg = cv2.resize(dst_img, (newWidth, 32)) + new_img_name = '{:05d}.jpg'.format(img_num) + cv2.imwrite(SAVE_PATH+new_img_name, dst_img) + f_label.write(SAVE_PATH+new_img_name+'\t'+transcription+'\n') + img_num += 1 + + +crop_and_get_filelist(file_list) +f_label.close() +``` + +### 6.3 模型训练 + +通过数据挖掘,我们得到了真实场景数据和对应的标签。接下来使用真实数据finetune,观察精度提升效果。 + + +利用真实数据进行finetune: + +```bash +cd ${PaddleOCR_root} +python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \ + -o Global.pretrained_model=./ckpt/ch_PP-OCRv3_rec_train/best_accuracy \ + Global.epoch_num=20 \ + Global.eval_batch_step='[0, 20]' \ + Train.dataset.data_dir=./data \ + Train.dataset.label_file_list=['./data/mining_train.list'] \ + Train.loader.batch_size_per_card=64 \ + Eval.dataset.data_dir=./data \ + Eval.dataset.label_file_list=["./data/val.list"] \ + Eval.loader.batch_size_per_card=64 +``` + +各参数含义参考第6部分合成数据finetune,只需要对训练数据路径做相应的修改: + +```txt +Train.dataset.data_dir: 训练数据集路径 +Train.dataset.label_file_list: 训练集文件列表 +``` + +示例使用我们提供的真实数据进行finetune,如想换成自己的数据,只需要相应的修改`Train.dataset.data_dir`和`Train.dataset.label_file_list`参数即可。 + +由于数据量不大,这里仅训练20个epoch即可。训练完成后,可以得到合成数据finetune后的精度为best acc=**71.33%**。 + +由于数量比较少,精度会比合成数据finetue的略低。 + + +## 7. 基于合成+真实数据finetune + +为进一步提升模型精度,我们结合使用合成数据和挖掘到的真实数据进行finetune。 + +利用合成+真实数据进行finetune,各参数含义参考第6部分合成数据finetune,只需要对训练数据路径做相应的修改: + +```txt +Train.dataset.data_dir: 训练数据集路径 +Train.dataset.label_file_list: 训练集文件列表 +``` + +生成训练list文件: +```bash +# 生成训练集文件list +cat /home/aistudio/PaddleOCR/data/render_train.list /home/aistudio/PaddleOCR/data/mining_train.list > /home/aistudio/PaddleOCR/data/render_mining_train.list +``` + +启动训练: +```bash +cd ${PaddleOCR_root} +python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \ + -o Global.pretrained_model=./ckpt/ch_PP-OCRv3_rec_train/best_accuracy \ + Global.epoch_num=40 \ + Global.eval_batch_step='[0, 20]' \ + Train.dataset.data_dir=./data \ + Train.dataset.label_file_list=['./data/render_mining_train.list'] \ + Train.loader.batch_size_per_card=64 \ + Eval.dataset.data_dir=./data \ + Eval.dataset.label_file_list=["./data/val.list"] \ + Eval.loader.batch_size_per_card=64 +``` + +示例使用我们提供的真实+合成数据进行finetune,如想换成自己的数据,只需要相应的修改Train.dataset.data_dir和Train.dataset.label_file_list参数即可。 + +由于数据量不大,这里仅训练40个epoch即可。训练完成后,可以得到合成数据finetune后的精度为best acc=**86.99%**。 + +可以看到,相较于原始PP-OCRv3的识别精度62.99%,使用合成数据+真实数据finetune后,识别精度能提升24%。 + +如需获取已训练模型,可以同样扫描上方二维码下载,将下载或训练完成的模型放置在对应目录下即可完成模型推理。 + +模型的推理部署方法可以参考repo文档: https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/deploy/README_ch.md diff --git "a/applications/\345\215\260\347\253\240\345\274\257\346\233\262\346\226\207\345\255\227\350\257\206\345\210\253.md" "b/applications/\345\215\260\347\253\240\345\274\257\346\233\262\346\226\207\345\255\227\350\257\206\345\210\253.md" new file mode 100644 index 0000000..07a146f --- /dev/null +++ "b/applications/\345\215\260\347\253\240\345\274\257\346\233\262\346\226\207\345\255\227\350\257\206\345\210\253.md" @@ -0,0 +1,1034 @@ +# 印章弯曲文字识别 + +- [1. 项目介绍](#1-----) +- [2. 环境搭建](#2-----) + * [2.1 准备PaddleDetection环境](#21---paddledetection--) + * [2.2 准备PaddleOCR环境](#22---paddleocr--) +- [3. 
数据集准备](#3------) + * [3.1 数据标注](#31-----) + * [3.2 数据处理](#32-----) +- [4. 印章检测实践](#4-------) +- [5. 印章文字识别实践](#5---------) + * [5.1 端对端印章文字识别实践](#51------------) + * [5.2 两阶段印章文字识别实践](#52------------) + + [5.2.1 印章文字检测](#521-------) + + [5.2.2 印章文字识别](#522-------) + + +# 1. 项目介绍 + +弯曲文字识别在OCR任务中有着广泛的应用,比如:自然场景下的招牌,艺术文字,以及常见的印章文字识别。 + +在本项目中,将以印章识别任务为例,介绍如何使用PaddleDetection和PaddleOCR完成印章检测和印章文字识别任务。 + +项目难点: +1. 缺乏训练数据 +2. 图像质量参差不齐,图像模糊,文字不清晰 + +针对以上问题,本项目选用PaddleOCR里的PPOCRLabel工具完成数据标注。基于PaddleDetection完成印章区域检测,然后通过PaddleOCR里的端对端OCR算法和两阶段OCR算法分别完成印章文字识别任务。不同任务的精度效果如下: + + +| 任务 | 训练数据数量 | 精度 | +| -------- | - | -------- | +| 印章检测 | 1000 | 95.00% | +| 印章文字识别-端对端OCR方法 | 700 | 47.00% | +| 印章文字识别-两阶段OCR方法 | 700 | 55.00% | + +点击进入 [AI Studio 项目](https://aistudio.baidu.com/aistudio/projectdetail/4586113) + +# 2. 环境搭建 + +本项目需要准备PaddleDetection和PaddleOCR的项目运行环境,其中PaddleDetection用于实现印章检测任务,PaddleOCR用于实现文字识别任务 + + +## 2.1 准备PaddleDetection环境 + +下载PaddleDetection代码: +``` +!git clone https://github.com/PaddlePaddle/PaddleDetection.git +# 如果克隆github代码较慢,请从gitee上克隆代码 +#git clone https://gitee.com/PaddlePaddle/PaddleDetection.git +``` + +安装PaddleDetection依赖 +``` +!cd PaddleDetection && pip install -r requirements.txt +``` + +## 2.2 准备PaddleOCR环境 + +下载PaddleOCR代码: +``` +!git clone https://github.com/PaddlePaddle/PaddleOCR.git +# 如果克隆github代码较慢,请从gitee上克隆代码 +#git clone https://gitee.com/PaddlePaddle/PaddleOCR.git +``` + +安装PaddleOCR依赖 +``` +!cd PaddleOCR && git checkout dygraph && pip install -r requirements.txt +``` + +# 3. 数据集准备 + +## 3.1 数据标注 + +本项目中使用[PPOCRLabel](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.6/PPOCRLabel)工具标注印章检测数据,标注内容包括印章的位置以及印章中文字的位置和文字内容。 + + +注:PPOCRLabel的使用方法参考[文档](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.6/PPOCRLabel)。 + +PPOCRlabel标注印章数据步骤: +- 打开数据集所在文件夹 +- 按下快捷键Q进行4点(多点)标注——针对印章文本识别, + - 印章弯曲文字包围框采用偶数点标注(比如4点,8点,16点),按照阅读顺序,以16点标注为例,从文字左上方开始标注->到文字右上方标注8个点->到文字右下方->文字左下方8个点,一共8个点,形成包围曲线,参考下图。如果文字弯曲程度不高,为了减小标注工作量,可以采用4点、8点标注,需要注意的是,文字上下点数相同。(总点数尽量不要超过18个) + - 对于需要识别的印章中非弯曲文字,采用4点框标注即可 + - 对应包围框的文字部分默认是”待识别”,需要修改为包围框内的具体文字内容 +- 快捷键W进行矩形标注——针对印章区域检测,印章检测区域保证标注框包围整个印章,包围框对应文字可以设置为'印章区域',方便后续处理。 +- 针对印章中的水平文字可以视情况考虑矩形或四点标注:保证按行标注即可。如果背景文字与印章文字比较接近,标注时尽量避开背景文字。 +- 标注完成后修改右侧文本结果,确认无误后点击下方check(或CTRL+V),确认本张图片的标注。 +- 所有图片标注完成后,在顶部菜单栏点击File -> Export Label导出label.txt。 + +标注完成后,可视化效果如下: +![](https://ai-studio-static-online.cdn.bcebos.com/f5acbc4f50dd401a8f535ed6a263f94b0edff82c1aed4285836a9ead989b9c13) + +数据标注完成后,标签中包含印章检测的标注和印章文字识别的标注,如下所示: +``` +img/1.png [{"transcription": "印章区域", "points": [[87, 245], [214, 245], [214, 369], [87, 369]], "difficult": false}, {"transcription": "国家税务总局泸水市税务局第二税务分局", "points": [[110, 314], [116, 290], [131, 275], [152, 273], [170, 277], [181, 289], [186, 303], [186, 312], [201, 311], [198, 289], [189, 272], [175, 259], [152, 252], [124, 257], [100, 280], [94, 312]], "difficult": false}, {"transcription": "征税专用章", "points": [[117, 334], [183, 334], [183, 352], [117, 352]], "difficult": false}] +``` +标注中包含表示'印章区域'的坐标和'印章文字'坐标以及文字内容。 + + + +## 3.2 数据处理 + +标注时为了方便标注,没有区分印章区域的标注框和文字区域的标注框,可以通过python代码完成标签的划分。 + +在本项目的'/home/aistudio/work/seal_labeled_datas'目录下,存放了标注的数据示例,如下: + + +![](https://ai-studio-static-online.cdn.bcebos.com/3d762970e2184177a2c633695a31029332a4cd805631430ea797309492e45402) + +标签文件'/home/aistudio/work/seal_labeled_datas/Label.txt'中的标注内容如下: + +``` +img/test1.png [{"transcription": "待识别", "points": [[408, 232], [537, 232], [537, 352], [408, 352]], "difficult": false}, {"transcription": "电子回单", "points": [[437, 
305], [504, 305], [504, 322], [437, 322]], "difficult": false}, {"transcription": "云南省农村信用社", "points": [[417, 290], [434, 295], [438, 281], [446, 267], [455, 261], [472, 258], [489, 264], [498, 277], [502, 295], [526, 289], [518, 267], [503, 249], [475, 232], [446, 239], [429, 255], [418, 275]], "difficult": false}, {"transcription": "专用章", "points": [[437, 319], [503, 319], [503, 338], [437, 338]], "difficult": false}] +``` + + +为了方便训练,我们需要通过python代码将用于训练印章检测和训练印章文字识别的标注区分开。 + + +``` +import numpy as np +import json +import cv2 +import os +from shapely.geometry import Polygon + + +def poly2box(poly): + xmin = np.min(np.array(poly)[:, 0]) + ymin = np.min(np.array(poly)[:, 1]) + xmax = np.max(np.array(poly)[:, 0]) + ymax = np.max(np.array(poly)[:, 1]) + return np.array([[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]) + + +def draw_text_det_res(dt_boxes, src_im, color=(255, 255, 0)): + for box in dt_boxes: + box = np.array(box).astype(np.int32).reshape(-1, 2) + cv2.polylines(src_im, [box], True, color=color, thickness=2) + return src_im + +class LabelDecode(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + label = json.loads(data['label']) + + nBox = len(label) + seal_boxes = self.get_seal_boxes(label) + + gt_label = [] + + for seal_box in seal_boxes: + seal_anno = {'seal_box': seal_box} + boxes, txts, txt_tags = [], [], [] + + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['transcription'] + try: + ints = self.get_intersection(box, seal_box) + except Exception as E: + print(E) + continue + + if abs(Polygon(box).area - self.get_intersection(box, seal_box)) < 1e-3 and \ + abs(Polygon(box).area - self.get_union(box, seal_box)) > 1e-3: + + boxes.append(box) + txts.append(txt) + if txt in ['*', '###', '待识别']: + txt_tags.append(True) + else: + txt_tags.append(False) + + seal_anno['polys'] = boxes + seal_anno['texts'] = txts + seal_anno['ignore_tags'] = txt_tags + + gt_label.append(seal_anno) + + return gt_label + + def get_seal_boxes(self, label): + + nBox = len(label) + seal_box = [] + for bno in range(0, nBox): + box = label[bno]['points'] + if len(box) == 4: + seal_box.append(box) + + if len(seal_box) == 0: + return None + + seal_box = self.valid_seal_box(seal_box) + return seal_box + + + def is_seal_box(self, box, boxes): + is_seal = True + for poly in boxes: + if list(box.shape()) != list(box.shape.shape()): + if abs(Polygon(box).area - self.get_intersection(box, poly)) < 1e-3: + return False + else: + if np.sum(np.array(box) - np.array(poly)) < 1e-3: + # continue when the box is same with poly + continue + if abs(Polygon(box).area - self.get_intersection(box, poly)) < 1e-3: + return False + return is_seal + + + def valid_seal_box(self, boxes): + if len(boxes) == 1: + return boxes + + new_boxes = [] + flag = True + for k in range(0, len(boxes)): + flag = True + tmp_box = boxes[k] + for i in range(0, len(boxes)): + if k == i: continue + if abs(Polygon(tmp_box).area - self.get_intersection(tmp_box, boxes[i])) < 1e-3: + flag = False + continue + if flag: + new_boxes.append(tmp_box) + + return new_boxes + + + def get_union(self, pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + def get_intersection_over_union(self, pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + def get_intersection(self, pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + def expand_points_num(self, boxes): + max_points_num = 0 + for box in boxes: + if len(box) > max_points_num: + max_points_num = len(box) + ex_boxes = [] 
+ for box in boxes: + ex_box = box + [box[-1]] * (max_points_num - len(box)) + ex_boxes.append(ex_box) + return ex_boxes + + +def gen_extract_label(data_dir, label_file, seal_gt, seal_ppocr_gt): + label_decode_func = LabelDecode() + gts = open(label_file, "r").readlines() + + seal_gt_list = [] + seal_ppocr_list = [] + + for idx, line in enumerate(gts): + img_path, label = line.strip().split("\t") + data = {'label': label, 'img_path':img_path} + res = label_decode_func(data) + src_img = cv2.imread(os.path.join(data_dir, img_path)) + if res is None: + print("ERROR! res is None!") + continue + + anno = [] + for i, gt in enumerate(res): + # print(i, box, type(box), ) + anno.append({'polys': gt['seal_box'], 'cls':1}) + + seal_gt_list.append(f"{img_path}\t{json.dumps(anno)}\n") + seal_ppocr_list.append(f"{img_path}\t{json.dumps(res)}\n") + + if not os.path.exists(os.path.dirname(seal_gt)): + os.makedirs(os.path.dirname(seal_gt)) + if not os.path.exists(os.path.dirname(seal_ppocr_gt)): + os.makedirs(os.path.dirname(seal_ppocr_gt)) + + with open(seal_gt, "w") as f: + f.writelines(seal_gt_list) + f.close() + + with open(seal_ppocr_gt, 'w') as f: + f.writelines(seal_ppocr_list) + f.close() + +def vis_seal_ppocr(data_dir, label_file, save_dir): + + datas = open(label_file, 'r').readlines() + for idx, line in enumerate(datas): + img_path, label = line.strip().split('\t') + img_path = os.path.join(data_dir, img_path) + + label = json.loads(label) + src_im = cv2.imread(img_path) + if src_im is None: + continue + + for anno in label: + seal_box = anno['seal_box'] + txt_boxes = anno['polys'] + + # vis seal box + src_im = draw_text_det_res([seal_box], src_im, color=(255, 255, 0)) + src_im = draw_text_det_res(txt_boxes, src_im, color=(255, 0, 0)) + + save_path = os.path.join(save_dir, os.path.basename(img_path)) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + # print(src_im.shape) + cv2.imwrite(save_path, src_im) + + +def draw_html(img_dir, save_name): + import glob + + images_dir = glob.glob(img_dir + "/*") + print(len(images_dir)) + + html_path = save_name + with open(html_path, 'w') as html: + html.write('\n\n') + html.write('\n') + html.write("") + + html.write("\n") + html.write(f'\n") + html.write(f'' % (base)) + html.write("\n") + + html.write('\n') + html.write('
\n GT') + + for i, filename in enumerate(sorted(images_dir)): + if filename.endswith("txt"): continue + print(filename) + + base = "{}".format(filename) + if True: + html.write("
{filename}\n GT') + html.write('GT 310\n
\n') + html.write('\n\n') + print("ok") + + +def crop_seal_from_img(label_file, data_dir, save_dir, save_gt_path): + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + datas = open(label_file, 'r').readlines() + all_gts = [] + count = 0 + for idx, line in enumerate(datas): + img_path, label = line.strip().split('\t') + img_path = os.path.join(data_dir, img_path) + + label = json.loads(label) + src_im = cv2.imread(img_path) + if src_im is None: + continue + + for c, anno in enumerate(label): + seal_poly = anno['seal_box'] + txt_boxes = anno['polys'] + txts = anno['texts'] + ignore_tags = anno['ignore_tags'] + + box = poly2box(seal_poly) + img_crop = src_im[box[0][1]:box[2][1], box[0][0]:box[2][0], :] + + save_path = os.path.join(save_dir, f"{idx}_{c}.jpg") + cv2.imwrite(save_path, np.array(img_crop)) + + img_gt = [] + for i in range(len(txts)): + txt_boxes_crop = np.array(txt_boxes[i]) + txt_boxes_crop[:, 1] -= box[0, 1] + txt_boxes_crop[:, 0] -= box[0, 0] + img_gt.append({'transcription': txts[i], "points": txt_boxes_crop.tolist(), "ignore_tag": ignore_tags[i]}) + + if len(img_gt) >= 1: + count += 1 + save_gt = f"{os.path.basename(save_path)}\t{json.dumps(img_gt)}\n" + + all_gts.append(save_gt) + + print(f"The num of all image: {len(all_gts)}, and the number of useful image: {count}") + if not os.path.exists(os.path.dirname(save_gt_path)): + os.makedirs(os.path.dirname(save_gt_path)) + + with open(save_gt_path, "w") as f: + f.writelines(all_gts) + f.close() + print("Done") + + + +if __name__ == "__main__": + + # 数据处理 + gen_extract_label("./seal_labeled_datas", "./seal_labeled_datas/Label.txt", "./seal_ppocr_gt/seal_det_img.txt", "./seal_ppocr_gt/seal_ppocr_img.txt") + vis_seal_ppocr("./seal_labeled_datas", "./seal_ppocr_gt/seal_ppocr_img.txt", "./seal_ppocr_gt/seal_ppocr_vis/") + draw_html("./seal_ppocr_gt/seal_ppocr_vis/", "./vis_seal_ppocr.html") + seal_ppocr_img_label = "./seal_ppocr_gt/seal_ppocr_img.txt" + crop_seal_from_img(seal_ppocr_img_label, "./seal_labeled_datas/", "./seal_img_crop", "./seal_img_crop/label.txt") + +``` + +处理完成后,生成的文件如下: +``` +├── seal_img_crop/ +│ ├── 0_0.jpg +│ ├── ... +│ └── label.txt +├── seal_ppocr_gt/ +│ ├── seal_det_img.txt +│ ├── seal_ppocr_img.txt +│ └── seal_ppocr_vis/ +│ ├── test1.png +│ ├── ... 
+└── vis_seal_ppocr.html + +``` +其中`seal_img_crop/label.txt`文件为印章识别标签文件,其内容格式为: +``` +0_0.jpg [{"transcription": "\u7535\u5b50\u56de\u5355", "points": [[29, 73], [96, 73], [96, 90], [29, 90]], "ignore_tag": false}, {"transcription": "\u4e91\u5357\u7701\u519c\u6751\u4fe1\u7528\u793e", "points": [[9, 58], [26, 63], [30, 49], [38, 35], [47, 29], [64, 26], [81, 32], [90, 45], [94, 63], [118, 57], [110, 35], [95, 17], [67, 0], [38, 7], [21, 23], [10, 43]], "ignore_tag": false}, {"transcription": "\u4e13\u7528\u7ae0", "points": [[29, 87], [95, 87], [95, 106], [29, 106]], "ignore_tag": false}] +``` +可以直接用于PaddleOCR的PGNet算法的训练。 + +`seal_ppocr_gt/seal_det_img.txt`为印章检测标签文件,其内容格式为: +``` +img/test1.png [{"polys": [[408, 232], [537, 232], [537, 352], [408, 352]], "cls": 1}] +``` +为了使用PaddleDetection工具完成印章检测模型的训练,需要将`seal_det_img.txt`转换为COCO或者VOC的数据标注格式。 + +可以直接使用下述代码将印章检测标注转换成VOC格式。 + + +``` +import numpy as np +import json +import cv2 +import os +from shapely.geometry import Polygon + +seal_train_gt = "./seal_ppocr_gt/seal_det_img.txt" +# 注:仅用于示例,实际使用中需要分别转换训练集和测试集的标签 +seal_valid_gt = "./seal_ppocr_gt/seal_det_img.txt" + +def gen_main_train_txt(mode='train'): + if mode == "train": + file_path = seal_train_gt + if mode in ['valid', 'test']: + file_path = seal_valid_gt + + save_path = f"./seal_VOC/ImageSets/Main/{mode}.txt" + save_train_path = f"./seal_VOC/{mode}.txt" + if not os.path.exists(os.path.dirname(save_path)): + os.makedirs(os.path.dirname(save_path)) + + datas = open(file_path, 'r').readlines() + img_names = [] + train_names = [] + for line in datas: + img_name = line.strip().split('\t')[0] + img_name = os.path.basename(img_name) + (i_name, extension) = os.path.splitext(img_name) + t_name = 'JPEGImages/'+str(img_name)+' '+'Annotations/'+str(i_name)+'.xml\n' + train_names.append(t_name) + img_names.append(i_name + "\n") + + with open(save_train_path, "w") as f: + f.writelines(train_names) + f.close() + + with open(save_path, "w") as f: + f.writelines(img_names) + f.close() + + print(f"{mode} save done") + + +def gen_xml_label(mode='train'): + if mode == "train": + file_path = seal_train_gt + if mode in ['valid', 'test']: + file_path = seal_valid_gt + + datas = open(file_path, 'r').readlines() + img_names = [] + train_names = [] + anno_path = "./seal_VOC/Annotations" + img_path = "./seal_VOC/JPEGImages" + + if not os.path.exists(anno_path): + os.makedirs(anno_path) + if not os.path.exists(img_path): + os.makedirs(img_path) + + for idx, line in enumerate(datas): + img_name, label = line.strip().split('\t') + img = cv2.imread(os.path.join("./seal_labeled_datas", img_name)) + cv2.imwrite(os.path.join(img_path, os.path.basename(img_name)), img) + height, width, c = img.shape + img_name = os.path.basename(img_name) + (i_name, extension) = os.path.splitext(img_name) + label = json.loads(label) + + xml_file = open(("./seal_VOC/Annotations" + '/' + i_name + '.xml'), 'w') + xml_file.write('\n') + xml_file.write(' seal_VOC\n') + xml_file.write(' ' + str(img_name) + '\n') + xml_file.write(' ' + 'Annotations/' + str(img_name) + '\n') + xml_file.write(' \n') + xml_file.write(' ' + str(width) + '\n') + xml_file.write(' ' + str(height) + '\n') + xml_file.write(' 3\n') + xml_file.write(' \n') + xml_file.write(' 0\n') + + for anno in label: + poly = anno['polys'] + if anno['cls'] == 1: + gt_cls = 'redseal' + xmin = np.min(np.array(poly)[:, 0]) + ymin = np.min(np.array(poly)[:, 1]) + xmax = np.max(np.array(poly)[:, 0]) + ymax = np.max(np.array(poly)[:, 1]) + xmin,ymin,xmax,ymax= 
int(xmin),int(ymin),int(xmax),int(ymax) + xml_file.write(' \n') + xml_file.write(' '+str(gt_cls)+'\n') + xml_file.write(' Unspecified\n') + xml_file.write(' 0\n') + xml_file.write(' 0\n') + xml_file.write(' \n') + xml_file.write(' '+str(xmin)+'\n') + xml_file.write(' '+str(ymin)+'\n') + xml_file.write(' '+str(xmax)+'\n') + xml_file.write(' '+str(ymax)+'\n') + xml_file.write(' \n') + xml_file.write(' \n') + xml_file.write('') + xml_file.close() + print(f'{mode} xml save done!') + + +gen_main_train_txt() +gen_main_train_txt('valid') +gen_xml_label('train') +gen_xml_label('valid') + +``` + +数据处理完成后,转换为VOC格式的印章检测数据存储在~/data/seal_VOC目录下,目录组织结构为: + +``` +├── Annotations/ +├── ImageSets/ +│   └── Main/ +│   ├── train.txt +│   └── valid.txt +├── JPEGImages/ +├── train.txt +└── valid.txt +└── label_list.txt +``` + +Annotations下为数据的标签,JPEGImages目录下为图像文件,label_list.txt为标注检测框类别标签文件。 + +在接下来一节中,将介绍如何使用PaddleDetection工具库完成印章检测模型的训练。 + +# 4. 印章检测实践 + +在实际应用中,印章多是出现在合同,发票,公告等场景中,印章文字识别的任务需要排除图像中背景文字的影响,因此需要先检测出图像中的印章区域。 + + +借助PaddleDetection目标检测库可以很容易的实现印章检测任务,使用PaddleDetection训练印章检测任务流程如下: + +- 选择算法 +- 修改数据集配置路径 +- 启动训练 + + +**算法选择** + +PaddleDetection中有许多检测算法可以选择,考虑到每条数据中印章区域较为清晰,且考虑到性能需求。在本项目中,我们采用mobilenetv3为backbone的ppyolo算法完成印章检测任务,对应的配置文件是:configs/ppyolo/ppyolo_mbv3_large.yml + + + +**修改配置文件** + +配置文件中的默认数据路径是COCO, +需要修改为印章检测的数据路径,主要修改如下: +在配置文件'configs/ppyolo/ppyolo_mbv3_large.yml'末尾增加如下内容: +``` +metric: VOC +map_type: 11point +num_classes: 2 + +TrainDataset: + !VOCDataSet + dataset_dir: dataset/seal_VOC + anno_path: train.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + +EvalDataset: + !VOCDataSet + dataset_dir: dataset/seal_VOC + anno_path: test.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + +TestDataset: + !ImageFolder + anno_path: dataset/seal_VOC/label_list.txt +``` + +配置文件中设置的数据路径在PaddleDetection/dataset目录下,我们可以将处理后的印章检测训练数据移动到PaddleDetection/dataset目录下或者创建一个软连接。 + +``` +!ln -s seal_VOC ./PaddleDetection/dataset/ +``` + +另外图象中印章数量比较少,可以调整NMS后处理的检测框数量,即keep_top_k,nms_top_k 从100,1000,调整为10,100。在配置文件'configs/ppyolo/ppyolo_mbv3_large.yml'末尾增加如下内容完成后处理参数的调整 +``` +BBoxPostProcess: + decode: + name: YOLOBox + conf_thresh: 0.005 + downsample_ratio: 32 + clip_bbox: true + scale_x_y: 1.05 + nms: + name: MultiClassNMS + keep_top_k: 10 # 修改前100 + nms_threshold: 0.45 + nms_top_k: 100 # 修改前1000 + score_threshold: 0.005 +``` + + +修改完成后,需要在PaddleDetection中增加印章数据的处理代码,即在PaddleDetection/ppdet/data/source/目录下创建seal.py文件,文件中填充如下代码: +``` +import os +import numpy as np +from ppdet.core.workspace import register, serializable +from .dataset import DetDataset +import cv2 +import json + +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + + +@register +@serializable +class SealDataSet(DetDataset): + """ + Load dataset with COCO format. + + Args: + dataset_dir (str): root directory for dataset. + image_dir (str): directory for images. + anno_path (str): coco annotation file path. + data_fields (list): key name of data dictionary, at least have 'image'. + sample_num (int): number of samples to load, -1 means all. + load_crowd (bool): whether to load crowded ground-truth. + False as default + allow_empty (bool): whether to load empty entry. False as default + empty_ratio (float): the ratio of empty record number to total + record's, if empty_ratio is out of [0. ,1.), do not sample the + records and use all the empty entries. 1. 
as default + """ + + def __init__(self, + dataset_dir=None, + image_dir=None, + anno_path=None, + data_fields=['image'], + sample_num=-1, + load_crowd=False, + allow_empty=False, + empty_ratio=1.): + super(SealDataSet, self).__init__(dataset_dir, image_dir, anno_path, + data_fields, sample_num) + self.load_image_only = False + self.load_semantic = False + self.load_crowd = load_crowd + self.allow_empty = allow_empty + self.empty_ratio = empty_ratio + + def _sample_empty(self, records, num): + # if empty_ratio is out of [0. ,1.), do not sample the records + if self.empty_ratio < 0. or self.empty_ratio >= 1.: + return records + import random + sample_num = min( + int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records)) + records = random.sample(records, sample_num) + return records + + def parse_dataset(self): + anno_path = os.path.join(self.dataset_dir, self.anno_path) + image_dir = os.path.join(self.dataset_dir, self.image_dir) + + records = [] + empty_records = [] + ct = 0 + + assert anno_path.endswith('.txt'), \ + 'invalid seal_gt file: ' + anno_path + + all_datas = open(anno_path, 'r').readlines() + + for idx, line in enumerate(all_datas): + im_path, label = line.strip().split('\t') + img_path = os.path.join(image_dir, im_path) + label = json.loads(label) + im_h, im_w, im_c = cv2.imread(img_path).shape + + coco_rec = { + 'im_file': img_path, + 'im_id': np.array([idx]), + 'h': im_h, + 'w': im_w, + } if 'image' in self.data_fields else {} + + if not self.load_image_only: + bboxes = [] + for anno in label: + poly = anno['polys'] + # poly to box + x1 = np.min(np.array(poly)[:, 0]) + y1 = np.min(np.array(poly)[:, 1]) + x2 = np.max(np.array(poly)[:, 0]) + y2 = np.max(np.array(poly)[:, 1]) + eps = 1e-5 + if x2 - x1 > eps and y2 - y1 > eps: + clean_box = [ + round(float(x), 3) for x in [x1, y1, x2, y2] + ] + anno = {'clean_box': clean_box, 'gt_cls':int(anno['cls'])} + bboxes.append(anno) + else: + logger.info("invalid box") + + num_bbox = len(bboxes) + if num_bbox <= 0: + continue + + gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32) + gt_class = np.zeros((num_bbox, 1), dtype=np.int32) + is_crowd = np.zeros((num_bbox, 1), dtype=np.int32) + # gt_poly = [None] * num_bbox + + for i, box in enumerate(bboxes): + gt_class[i][0] = box['gt_cls'] + gt_bbox[i, :] = box['clean_box'] + is_crowd[i][0] = 0 + + gt_rec = { + 'is_crowd': is_crowd, + 'gt_class': gt_class, + 'gt_bbox': gt_bbox, + # 'gt_poly': gt_poly, + } + + for k, v in gt_rec.items(): + if k in self.data_fields: + coco_rec[k] = v + + records.append(coco_rec) + ct += 1 + if self.sample_num > 0 and ct >= self.sample_num: + break + self.roidbs = records +``` + +**启动训练** + +启动单卡训练的命令为: +``` +!python3 tools/train.py -c configs/ppyolo/ppyolo_mbv3_large.yml --eval + +# 分布式训练命令为: +!python3 -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyolo/ppyolo_mbv3_large.yml --eval +``` + +训练完成后,日志中会打印模型的精度: + +``` +[07/05 11:42:09] ppdet.engine INFO: Eval iter: 0 +[07/05 11:42:14] ppdet.metrics.metrics INFO: Accumulating evaluatation results... +[07/05 11:42:14] ppdet.metrics.metrics INFO: mAP(0.50, 11point) = 99.31% +[07/05 11:42:14] ppdet.engine INFO: Total sample number: 112, averge FPS: 26.45840794253432 +[07/05 11:42:14] ppdet.engine INFO: Best test bbox ap is 0.996. 
+``` + + +我们可以使用训练好的模型观察预测结果: +``` +!python3 tools/infer.py -c configs/ppyolo/ppyolo_mbv3_large.yml -o weights=./output/ppyolo_mbv3_large/model_final.pdparams --img_dir=./test.jpg +``` +预测结果如下: + +![](https://ai-studio-static-online.cdn.bcebos.com/0f650c032b0f4d56bd639713924768cc820635e9977845008d233f465291a29e) + +# 5. 印章文字识别实践 + +在使用ppyolo检测到印章区域后,接下来借助PaddleOCR里的文字识别能力,完成印章中文字的识别。 + +PaddleOCR中的OCR算法包含文字检测算法,文字识别算法以及OCR端对端算法。 + +文字检测算法负责检测到图像中的文字,再由文字识别模型识别出检测到的文字,进而实现OCR的任务。文字检测+文字识别串联完成OCR任务的架构称为两阶段的OCR算法。相对应的端对端的OCR方法可以用一个算法同时完成文字检测和识别的任务。 + + +| 文字检测 | 文字识别 | 端对端算法 | +| -------- | -------- | -------- | +| DB\DB++\EAST\SAST\PSENet | SVTR\CRNN\NRTN\Abinet\SAR\... | PGNet | + + +本节中将分别介绍端对端的文字检测识别算法以及两阶段的文字检测识别算法在印章检测识别任务上的实践。 + + +## 5.1 端对端印章文字识别实践 + +本节介绍使用PaddleOCR里的PGNet算法完成印章文字识别。 + +PGNet属于端对端的文字检测识别算法,在PaddleOCR中的配置文件为: +[PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/e2e/e2e_r50_vd_pg.yml) + +使用PGNet完成文字检测识别任务的步骤为: +- 修改配置文件 +- 启动训练 + +PGNet默认配置文件的数据路径为totaltext数据集路径,本次训练中,需要修改为上一节数据处理后得到的标签文件和数据目录: + +训练数据配置修改后如下: +``` +Train: + dataset: + name: PGDataSet + data_dir: ./train_data/seal_ppocr + label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt] + ratio_list: [1.0] +``` +测试数据集配置修改后如下: +``` +Eval: + dataset: + name: PGDataSet + data_dir: ./train_data/seal_ppocr_test + label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt] +``` + +启动训练的命令为: +``` +!python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml +``` +模型训练完成后,可以得到最终的精度为47.4%。数据量较少,以及数据质量较差会影响模型的训练精度,如果有更多的数据参与训练,精度将进一步提升。 + +如需获取已训练模型,请扫文末的二维码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 + +## 5.2 两阶段印章文字识别实践 + +上一节介绍了使用PGNet实现印章识别任务的训练流程。本小节将介绍使用PaddleOCR里的文字检测和文字识别算法分别完成印章文字的检测和识别。 + +### 5.2.1 印章文字检测 + +PaddleOCR中包含丰富的文字检测算法,包含DB,DB++,EAST,SAST,PSENet等等。其中DB,DB++,PSENet均支持弯曲文字检测,本项目中,使用DB++作为印章弯曲文字检测算法。 + +PaddleOCR中发布的db++文字检测算法模型是英文文本检测模型,因此需要重新训练模型。 + + +修改[DB++配置文件](DB++的默认配置文件位于[configs/det/det_r50_db++_icdar15.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/det/det_r50_db%2B%2B_icdar15.yml) +中的数据路径: + + +``` +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/seal_ppocr + label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt] + ratio_list: [1.0] +``` +测试数据集配置修改后如下: +``` +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/seal_ppocr_test + label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt] +``` + + +启动训练: +``` +!python3 tools/train.py -c configs/det/det_r50_db++_icdar15.yml -o Global.epoch_num=100 +``` + +考虑到数据较少,通过Global.epoch_num设置仅训练100个epoch。 +模型训练完成后,在测试集上预测的可视化效果如下: + +![](https://ai-studio-static-online.cdn.bcebos.com/498119182f0a414ab86ae2de752fa31c9ddc3a74a76847049cc57884602cb269) + + +如需获取已训练模型,请扫文末的二维码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 + + +### 5.2.2 印章文字识别 + +上一节中完成了印章文字的检测模型训练,本节介绍印章文字识别模型的训练。识别模型采用SVTR算法,SVTR算法是IJCAI收录的文字识别算法,SVTR模型具备超轻量高精度的特点。 + +在启动训练之前,需要准备印章文字识别需要的数据集,需要使用如下代码,将印章中的文字区域剪切出来构建训练集。 + +``` +import cv2 +import numpy as np + +def get_rotate_crop_image(img, points): + ''' + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + ''' + assert len(points) == 4, "shape of points must be 4*2" + img_crop_width = int( + max( + np.linalg.norm(points[0] - 
points[1]), + np.linalg.norm(points[2] - points[3]))) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]))) + pts_std = np.float32([[0, 0], [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + +def run(data_dir, label_file, save_dir): + datas = open(label_file, 'r').readlines() + for idx, line in enumerate(datas): + img_path, label = line.strip().split('\t') + img_path = os.path.join(data_dir, img_path) + + label = json.loads(label) + src_im = cv2.imread(img_path) + if src_im is None: + continue + + for anno in label: + seal_box = anno['seal_box'] + txt_boxes = anno['polys'] + crop_im = get_rotate_crop_image(src_im, text_boxes) + + save_path = os.path.join(save_dir, f'{idx}.png') + if not os.path.exists(save_dir): + os.makedirs(save_dir) + # print(src_im.shape) + cv2.imwrite(save_path, crop_im) + +``` + + +数据处理完成后,即可配置训练的配置文件。SVTR配置文件选择[configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml) +修改SVTR配置文件中的训练数据部分如下: + +``` +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/seal_ppocr_crop/ + label_file_list: + - ./train_data/seal_ppocr_crop/train_list.txt +``` + +修改预测部分配置文件: +``` +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/seal_ppocr_crop/ + label_file_list: + - ./train_data/seal_ppocr_crop_test/train_list.txt +``` + +启动训练: + +``` +!python3 tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml + +``` + +训练完成后可以发现测试集指标达到了61%。 +由于数据较少,训练时会发现在训练集上的acc指标远大于测试集上的acc指标,即出现过拟合现象。通过补充数据和一些数据增强可以缓解这个问题。 + + + +如需获取已训练模型,请扫下图二维码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 + +
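另外,针对上文提到的因数据量较少导致的过拟合问题,除了补充真实数据,也可以先对裁剪出的印章文字图像做简单的离线增强来扩充识别训练集。下面给出一个基于 OpenCV 的最小示意,其中增强方式、标注格式(`图像名\t文本`)与路径均为示例,实际也可以改用 PaddleOCR 配置文件中的在线数据增强:

```python
# 对裁剪后的印章文字图像做简单离线增强,扩充识别训练集(示意脚本,路径与参数均为示例)
import os
import cv2

def augment_one(img):
    """返回若干张增强后的图像:亮度扰动、高斯模糊、轻微旋转"""
    h, w = img.shape[:2]
    rot = cv2.getRotationMatrix2D((w / 2, h / 2), 5, 1.0)
    return [
        cv2.convertScaleAbs(img, alpha=1.0, beta=30),                      # 亮度扰动
        cv2.GaussianBlur(img, (3, 3), 0),                                  # 模糊,模拟低质量印章图像
        cv2.warpAffine(img, rot, (w, h), borderMode=cv2.BORDER_REPLICATE)  # 轻微旋转
    ]

def augment_dataset(label_file, data_dir, save_dir, out_label_file):
    os.makedirs(save_dir, exist_ok=True)
    with open(label_file, "r", encoding="utf-8") as f, \
         open(out_label_file, "w", encoding="utf-8") as fout:
        for idx, line in enumerate(f):
            img_name, text = line.strip().split("\t", 1)
            img = cv2.imread(os.path.join(data_dir, img_name))
            if img is None:
                continue
            for k, aug in enumerate(augment_one(img)):
                new_name = "aug_{}_{}.jpg".format(idx, k)
                cv2.imwrite(os.path.join(save_dir, new_name), aug)
                fout.write(new_name + "\t" + text + "\n")

# 示例调用(路径仅为示例,请按实际数据位置修改):
# augment_dataset("./train_data/seal_ppocr_crop/train_list.txt",
#                 "./train_data/seal_ppocr_crop/",
#                 "./train_data/seal_ppocr_crop_aug/",
#                 "./train_data/seal_ppocr_crop_aug/train_list.txt")
```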
diff --git "a/applications/\345\217\221\347\245\250\345\205\263\351\224\256\344\277\241\346\201\257\346\212\275\345\217\226.md" "b/applications/\345\217\221\347\245\250\345\205\263\351\224\256\344\277\241\346\201\257\346\212\275\345\217\226.md" new file mode 100644 index 0000000..b8a8ee2 --- /dev/null +++ "b/applications/\345\217\221\347\245\250\345\205\263\351\224\256\344\277\241\346\201\257\346\212\275\345\217\226.md" @@ -0,0 +1,343 @@ + +# 基于VI-LayoutXLM的发票关键信息抽取 + +- [1. 项目背景及意义](#1-项目背景及意义) +- [2. 项目内容](#2-项目内容) +- [3. 安装环境](#3-安装环境) +- [4. 关键信息抽取](#4-关键信息抽取) + - [4.1 文本检测](#41-文本检测) + - [4.2 文本识别](#42-文本识别) + - [4.3 语义实体识别](#43-语义实体识别) + - [4.4 关系抽取](#44-关系抽取) + + + +## 1. 项目背景及意义 + +关键信息抽取在文档场景中被广泛使用,如身份证中的姓名、住址信息抽取,快递单中的姓名、联系方式等关键字段内容的抽取。传统基于模板匹配的方案需要针对不同的场景制定模板并进行适配,较为繁琐,不够鲁棒。基于该问题,我们借助飞桨提供的PaddleOCR套件中的关键信息抽取方案,实现对增值税发票场景的关键信息抽取。 + +## 2. 项目内容 + +本项目基于PaddleOCR开源套件,以VI-LayoutXLM多模态关键信息抽取模型为基础,针对增值税发票场景进行适配,提取该场景的关键信息。 + +## 3. 安装环境 + +```bash +# 首先git官方的PaddleOCR项目,安装需要的依赖 +# 第一次运行打开该注释 +git clone https://gitee.com/PaddlePaddle/PaddleOCR.git +cd PaddleOCR +# 安装PaddleOCR的依赖 +pip install -r requirements.txt +# 安装关键信息抽取任务的依赖 +pip install -r ./ppstructure/kie/requirements.txt +``` + +## 4. 关键信息抽取 + +基于文档图像的关键信息抽取包含3个部分:(1)文本检测(2)文本识别(3)关键信息抽取方法,包括语义实体识别或者关系抽取,下面分别进行介绍。 + +### 4.1 文本检测 + + +本文重点关注发票的关键信息抽取模型训练与预测过程,因此在关键信息抽取过程中,直接使用标注的文本检测与识别标注信息进行测试,如果你希望自定义该场景的文本检测模型,完成端到端的关键信息抽取部分,请参考[文本检测模型训练教程](../doc/doc_ch/detection.md),按照训练数据格式准备数据,并完成该场景下垂类文本检测模型的微调过程。 + + +### 4.2 文本识别 + +本文重点关注发票的关键信息抽取模型训练与预测过程,因此在关键信息抽取过程中,直接使用提供的文本检测与识别标注信息进行测试,如果你希望自定义该场景的文本检测模型,完成端到端的关键信息抽取部分,请参考[文本识别模型训练教程](../doc/doc_ch/recognition.md),按照训练数据格式准备数据,并完成该场景下垂类文本识别模型的微调过程。 + +### 4.3 语义实体识别 (Semantic Entity Recognition) + +语义实体识别指的是给定一段文本行,确定其类别(如`姓名`、`住址`等类别)。PaddleOCR中提供了基于VI-LayoutXLM的多模态语义实体识别方法,融合文本、位置与版面信息,相比LayoutXLM多模态模型,去除了其中的视觉骨干网络特征提取部分,引入符合阅读顺序的文本行排序方法,同时使用UDML联合互蒸馏方法进行训练,最终在精度与速度方面均超越LayoutXLM。更多关于VI-LayoutXLM的算法介绍与精度指标,请参考:[VI-LayoutXLM算法介绍](../doc/doc_ch/algorithm_kie_vi_layoutxlm.md)。 + +#### 4.3.1 准备数据 + +发票场景为例,我们首先需要标注出其中的关键字段,我们将其标注为`问题-答案`的key-value pair,如下,编号No为12270830,则`No`字段标注为question,`12270830`字段标注为answer。如下图所示。 + +
+ +**注意:** + +* 如果文本检测模型数据标注过程中,没有标注 **非关键信息内容** 的检测框,那么在标注关键信息抽取任务的时候,也不需要标注该部分,如上图所示;如果标注的过程,如果同时标注了**非关键信息内容** 的检测框,那么我们需要将该部分的label记为other。 +* 标注过程中,需要以文本行为单位进行标注,无需标注单个字符的位置信息。 + + +已经处理好的增值税发票数据集从这里下载:[增值税发票数据集下载链接](https://aistudio.baidu.com/aistudio/datasetdetail/165561)。 + +下载好发票数据集,并解压在train_data目录下,目录结构如下所示。 + +``` +train_data + |--zzsfp + |---class_list.txt + |---imgs/ + |---train.json + |---val.json +``` + +其中`class_list.txt`是包含`other`, `question`, `answer`,3个种类的的类别列表(不区分大小写),`imgs`目录底下,`train.json`与`val.json`分别表示训练与评估集合的标注文件。训练集中包含30张图片,验证集中包含8张图片。部分标注如下所示。 + +```py +b33.jpg [{"transcription": "No", "label": "question", "points": [[2882, 472], [3026, 472], [3026, 588], [2882, 588]], }, {"transcription": "12269563", "label": "answer", "points": [[3066, 448], [3598, 448], [3598, 576], [3066, 576]], ]}] +``` + +相比于OCR检测的标注,仅多了`label`字段。 + + +#### 4.3.2 开始训练 + + +VI-LayoutXLM的配置为[ser_vi_layoutxlm_xfund_zh_udml.yml](../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml),需要修改数据、类别数目以及配置文件。 + +```yml +Architecture: + model_type: &model_type "kie" + name: DistillationModel + algorithm: Distillation + Models: + Teacher: + pretrained: + freeze_params: false + return_all_feats: true + model_type: *model_type + algorithm: &algorithm "LayoutXLM" + Transform: + Backbone: + name: LayoutXLMForSer + pretrained: True + # one of base or vi + mode: vi + checkpoints: + # 定义类别数目 + num_classes: &num_classes 5 + ... + +PostProcess: + name: DistillationSerPostProcess + model_name: ["Student", "Teacher"] + key: backbone_out + # 定义类别文件 + class_path: &class_path train_data/zzsfp/class_list.txt + +Train: + dataset: + name: SimpleDataSet + # 定义训练数据目录与标注文件 + data_dir: train_data/zzsfp/imgs + label_file_list: + - train_data/zzsfp/train.json + ... + +Eval: + dataset: + # 定义评估数据目录与标注文件 + name: SimpleDataSet + data_dir: train_data/zzsfp/imgs + label_file_list: + - train_data/zzsfp/val.json + ... +``` + +LayoutXLM与VI-LayoutXLM针对该场景的训练结果如下所示。 + +| 模型 | 迭代轮数 | Hmean | +| :---: | :---: | :---: | +| LayoutXLM | 50 | 100.00% | +| VI-LayoutXLM | 50 | 100.00% | + +可以看出,由于当前数据量较少,场景比较简单,因此2个模型的Hmean均达到了100%。 + + +#### 4.3.3 模型评估 + +模型训练过程中,使用的是知识蒸馏的策略,最终保留了学生模型的参数,在评估时,我们需要针对学生模型的配置文件进行修改: [ser_vi_layoutxlm_xfund_zh.yml](../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml),修改内容与训练配置相同,包括**类别数、类别映射文件、数据目录**。 + +修改完成后,执行下面的命令完成评估过程。 + +```bash +# 注意:需要根据你的配置文件地址与保存的模型地址,对评估命令进行修改 +python3 tools/eval.py -c ./fapiao/ser_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy +``` + +输出结果如下所示。 + +``` +[2022/08/18 08:49:58] ppocr INFO: metric eval *************** +[2022/08/18 08:49:58] ppocr INFO: precision:1.0 +[2022/08/18 08:49:58] ppocr INFO: recall:1.0 +[2022/08/18 08:49:58] ppocr INFO: hmean:1.0 +[2022/08/18 08:49:58] ppocr INFO: fps:1.9740402401574881 +``` + +#### 4.3.4 模型预测 + +使用下面的命令进行预测。 + +```bash +python3 tools/infer_kie_token_ser.py -c fapiao/ser_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/XFUND/zh_val/val.json Global.infer_mode=False +``` + +预测结果会保存在配置文件中的`Global.save_res_path`目录中。 + +部分预测结果如下所示。 + +
+ + +* 注意:在预测时,使用的文本检测与识别结果为标注的结果,直接从json文件里面进行读取。 + +如果希望使用OCR引擎结果得到的结果进行推理,则可以使用下面的命令进行推理。 + + +```bash +python3 tools/infer_kie_token_ser.py -c fapiao/ser_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/imgs/b25.jpg Global.infer_mode=True +``` + +结果如下所示。 + +
+ +它会使用PP-OCRv3的文本检测与识别模型进行获取文本位置与内容信息。 + +可以看出,由于训练的过程中,没有标注额外的字段为other类别,所以大多数检测出来的字段被预测为question或者answer。 + +如果希望构建基于你在垂类场景训练得到的OCR检测与识别模型,可以使用下面的方法传入检测与识别的inference 模型路径,即可完成OCR文本检测与识别以及SER的串联过程。 + +```bash +python3 tools/infer_kie_token_ser.py -c fapiao/ser_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/imgs/b25.jpg Global.infer_mode=True Global.kie_rec_model_dir="your_rec_model" Global.kie_det_model_dir="your_det_model" +``` + +### 4.4 关系抽取(Relation Extraction) + +使用SER模型,可以获取图像中所有的question与answer的字段,继续这些字段的类别,我们需要进一步获取question与answer之间的连接,因此需要进一步训练关系抽取模型,解决该问题。本文也基于VI-LayoutXLM多模态预训练模型,进行下游RE任务的模型训练。 + +#### 4.4.1 准备数据 + +以发票场景为例,相比于SER任务,RE中还需要标记每个文本行的id信息以及链接关系linking,如下所示。 + +
+ + +标注文件的部分内容如下所示。 + +```py +b33.jpg [{"transcription": "No", "label": "question", "points": [[2882, 472], [3026, 472], [3026, 588], [2882, 588]], "id": 0, "linking": [[0, 1]]}, {"transcription": "12269563", "label": "answer", "points": [[3066, 448], [3598, 448], [3598, 576], [3066, 576]], "id": 1, "linking": [[0, 1]]}] +``` + +相比与SER的标注,多了`id`与`linking`的信息,分别表示唯一标识以及连接关系。 + +已经处理好的增值税发票数据集从这里下载:[增值税发票数据集下载链接](https://aistudio.baidu.com/aistudio/datasetdetail/165561)。 + +#### 4.4.2 开始训练 + +基于VI-LayoutXLM的RE任务配置为[re_vi_layoutxlm_xfund_zh_udml.yml](../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml),需要修改**数据路径、类别列表文件**。 + +```yml +Train: + dataset: + name: SimpleDataSet + # 定义训练数据目录与标注文件 + data_dir: train_data/zzsfp/imgs + label_file_list: + - train_data/zzsfp/train.json + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: True + algorithm: *algorithm + class_path: &class_path train_data/zzsfp/class_list.txt + ... + +Eval: + dataset: + # 定义评估数据目录与标注文件 + name: SimpleDataSet + data_dir: train_data/zzsfp/imgs + label_file_list: + - train_data/zzsfp/val.json + ... + +``` + +LayoutXLM与VI-LayoutXLM针对该场景的训练结果如下所示。 + +| 模型 | 迭代轮数 | Hmean | +| :---: | :---: | :---: | +| LayoutXLM | 50 | 98.00% | +| VI-LayoutXLM | 50 | 99.30% | + +可以看出,对于VI-LayoutXLM相比LayoutXLM的Hmean高了1.3%。 + +如需获取已训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 + +
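在启动训练或排查标注问题时,可以先用一小段脚本检查 4.4.1 节标注中 `id` 与 `linking` 字段的对应关系,例如根据 `linking` 还原 question-answer 配对。以下脚本仅为示意,假设标注行的组织方式与 4.4.1 节展示的一致(示例中省略了 `points` 字段):

```python
# 根据RE标注中的id与linking字段还原question-answer配对(示意脚本)
import json

def parse_qa_pairs(label_line):
    # 标注行形如:图像名 + 制表符 + json标注,每个元素含transcription/label/id/linking
    img_name, label = label_line.strip().split("\t", 1)
    items = {item["id"]: item for item in json.loads(label)}
    pairs = []
    for item in items.values():
        for qid, aid in item.get("linking", []):
            if items.get(qid, {}).get("label") == "question" and \
                    items.get(aid, {}).get("label") == "answer":
                pair = (items[qid]["transcription"], items[aid]["transcription"])
                if pair not in pairs:
                    pairs.append(pair)
    return img_name, pairs

# 以4.4.1节的标注片段为例(省略points字段)
demo = ('b33.jpg\t[{"transcription": "No", "label": "question", "id": 0, "linking": [[0, 1]]}, '
        '{"transcription": "12269563", "label": "answer", "id": 1, "linking": [[0, 1]]}]')
print(parse_qa_pairs(demo))  # 输出:('b33.jpg', [('No', '12269563')])
```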
+ + +#### 4.4.3 模型评估 + +模型训练过程中,使用的是知识蒸馏的策略,最终保留了学生模型的参数,在评估时,我们需要针对学生模型的配置文件进行修改: [re_vi_layoutxlm_xfund_zh.yml](../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml),修改内容与训练配置相同,包括**类别映射文件、数据目录**。 + +修改完成后,执行下面的命令完成评估过程。 + +```bash +# 注意:需要根据你的配置文件地址与保存的模型地址,对评估命令进行修改 +python3 tools/eval.py -c ./fapiao/re_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/re_vi_layoutxlm_fapiao_udml/best_accuracy +``` + +输出结果如下所示。 + +```py +[2022/08/18 12:17:14] ppocr INFO: metric eval *************** +[2022/08/18 12:17:14] ppocr INFO: precision:1.0 +[2022/08/18 12:17:14] ppocr INFO: recall:0.9873417721518988 +[2022/08/18 12:17:14] ppocr INFO: hmean:0.9936305732484078 +[2022/08/18 12:17:14] ppocr INFO: fps:2.765963539771157 +``` + +#### 4.4.4 模型预测 + +使用下面的命令进行预测。 + +```bash +# -c 后面的是RE任务的配置文件 +# -o 后面的字段是RE任务的配置 +# -c_ser 后面的是SER任务的配置文件 +# -c_ser 后面的字段是SER任务的配置 +python3 tools/infer_kie_token_ser_re.py -c fapiao/re_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/re_vi_layoutxlm_fapiao_trained/best_accuracy Global.infer_img=./train_data/zzsfp/val.json Global.infer_mode=False -c_ser fapiao/ser_vi_layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_trained/best_accuracy +``` + +预测结果会保存在配置文件中的`Global.save_res_path`目录中。 + +部分预测结果如下所示。 + +
+ + +* 注意:在预测时,使用的文本检测与识别结果为标注的结果,直接从json文件里面进行读取。 + +如果希望使用OCR引擎结果得到的结果进行推理,则可以使用下面的命令进行推理。 + +```bash +python3 tools/infer_kie_token_ser_re.py -c fapiao/re_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/re_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/val.json Global.infer_mode=True -c_ser fapiao/ser_vi_layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy +``` + +如果希望构建基于你在垂类场景训练得到的OCR检测与识别模型,可以使用下面的方法传入,即可完成SER + RE的串联过程。 + +```bash +python3 tools/infer_kie_token_ser_re.py -c fapiao/re_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/re_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/val.json Global.infer_mode=True -c_ser fapiao/ser_vi_layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.kie_rec_model_dir="your_rec_model" Global.kie_det_model_dir="your_det_model" +``` diff --git "a/applications/\345\244\232\346\250\241\346\200\201\350\241\250\345\215\225\350\257\206\345\210\253.md" "b/applications/\345\244\232\346\250\241\346\200\201\350\241\250\345\215\225\350\257\206\345\210\253.md" new file mode 100644 index 0000000..471ca63 --- /dev/null +++ "b/applications/\345\244\232\346\250\241\346\200\201\350\241\250\345\215\225\350\257\206\345\210\253.md" @@ -0,0 +1,899 @@ +# 多模态表单识别 +- [多模态表单识别](#多模态表单识别) + - [1 项目说明](#1-项目说明) + - [2 安装说明](#2-安装说明) + - [3 数据准备](#3-数据准备) + - [3.1 下载处理好的数据集](#31-下载处理好的数据集) + - [3.2 转换为PaddleOCR检测和识别格式](#32-转换为paddleocr检测和识别格式) + - [4 OCR](#4-ocr) + - [4.1 文本检测](#41-文本检测) + - [4.1.1 方案1:预训练模型](#411-方案1预训练模型) + - [4.1.2 方案2:XFUND数据集+fine-tune](#412-方案2xfund数据集fine-tune) + - [4.2 文本识别](#42-文本识别) + - [4.2.1 方案1:预训练模型](#421-方案1预训练模型) + - [4.2.2 方案2:XFUND数据集+finetune](#422-方案2xfund数据集finetune) + - [4.2.3 方案3:XFUND数据集+finetune+真实通用识别数据](#423-方案3xfund数据集finetune真实通用识别数据) + - [5 文档视觉问答(DOC-VQA)](#5-文档视觉问答doc-vqa) + - [5.1 SER](#51-ser) + - [5.1.1 模型训练](#511-模型训练) + - [5.1.2 模型评估](#512-模型评估) + - [5.1.3 模型预测](#513-模型预测) + - [5.2 RE](#52-re) + - [5.2.1 模型训练](#521-模型训练) + - [5.2.2 模型评估](#522-模型评估) + - [5.2.3 模型预测](#523-模型预测) + - [6 导出Excel](#6-导出excel) + - [获得模型](#获得模型) + - [更多资源](#更多资源) + - [参考链接](#参考链接) + +## 1 项目说明 + +计算机视觉在金融领域的应用覆盖文字识别、图像识别、视频识别等,其中文字识别(OCR)是金融领域中的核心AI能力,其应用覆盖客户服务、风险防控、运营管理等各项业务,针对的对象包括通用卡证票据识别(银行卡、身份证、营业执照等)、通用文本表格识别(印刷体、多语言、手写体等)以及一些金融特色票据凭证。通过因此如果能够在结构化信息提取时同时利用文字、页面布局等信息,便可增强不同版式下的泛化性。 + +表单识别旨在识别各种具有表格性质的证件、房产证、营业执照、个人信息表、发票等关键键值对(如姓名-张三),其广泛应用于银行、证券、公司财务等领域,具有很高的商业价值。本次范例项目开源了全流程表单识别方案,能够在多个场景快速实现迁移能力。表单识别通常存在以下难点: + +- 人工摘录工作效率低; +- 国内常见表单版式多; +- 传统技术方案泛化效果不满足。 + + +表单识别包含两大阶段:OCR阶段和文档视觉问答阶段。 + +其中,OCR阶段选取了PaddleOCR的PP-OCRv2模型,主要由文本检测和文本识别两个模块组成。DOC-VQA文档视觉问答阶段基于PaddleNLP自然语言处理算法库实现的LayoutXLM模型,支持基于多模态方法的语义实体识别(Semantic Entity Recognition, SER)以及关系抽取(Relation Extraction, RE)任务。本案例流程如 **图1** 所示: + +
图1 多模态表单识别流程图
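结合图1,可以用下面一段示意代码梳理两个阶段之间的数据流。其中 `detect_text`、`recognize_text`、`ser`、`re_link` 均为假设的占位实现,仅用于说明各阶段的输入输出关系,并非 PaddleOCR 或 PaddleNLP 的实际接口:

```python
# 两阶段表单识别数据流示意(以下函数均为占位实现,仅说明输入输出关系)
def detect_text(image):
    # 阶段一·文本检测(对应PP-OCRv2检测模型):返回每个文本行的四点框
    return [[[0, 0], [100, 0], [100, 20], [0, 20]]]

def recognize_text(image, box):
    # 阶段一·文本识别(对应PP-OCRv2识别模型):返回框内的文本内容
    return "邮政地址:"

def ser(image, boxes, texts):
    # 阶段二·语义实体识别(对应LayoutXLM SER):为每个文本行预测question/answer/other
    return [{"text": t, "box": b, "label": "question"} for t, b in zip(texts, boxes)]

def re_link(entities):
    # 阶段二·关系抽取(对应LayoutXLM RE):输出question-answer配对,这里返回空列表占位
    return []

def form_recognition(image):
    boxes = detect_text(image)                          # 文本检测
    texts = [recognize_text(image, b) for b in boxes]   # 文本识别
    entities = ser(image, boxes, texts)                 # 语义实体识别
    relations = re_link(entities)                       # 关系抽取
    return entities, relations
```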
+ +注:欢迎再AIStudio领取免费算力体验线上实训,项目链接: [多模态表单识别](https://aistudio.baidu.com/aistudio/projectdetail/3884375?contributionType=1) + +## 2 安装说明 + + +下载PaddleOCR源码,上述AIStudio项目中已经帮大家打包好的PaddleOCR(已经修改好配置文件),无需下载解压即可,只需安装依赖环境~ + + +```python +unzip -q PaddleOCR.zip +``` + + +```python +# 如仍需安装or安装更新,可以执行以下步骤 +# git clone https://github.com/PaddlePaddle/PaddleOCR.git -b dygraph +# git clone https://gitee.com/PaddlePaddle/PaddleOCR +``` + + +```python +# 安装依赖包 +pip install -U pip +pip install -r /home/aistudio/PaddleOCR/requirements.txt +pip install paddleocr + +pip install yacs gnureadline paddlenlp==2.2.1 +pip install xlsxwriter +``` + +## 3 数据准备 + +这里使用[XFUN数据集](https://github.com/doc-analysis/XFUND)做为实验数据集。 XFUN数据集是微软提出的一个用于KIE任务的多语言数据集,共包含七个数据集,每个数据集包含149张训练集和50张验证集 + +分别为:ZH(中文)、JA(日语)、ES(西班牙)、FR(法语)、IT(意大利)、DE(德语)、PT(葡萄牙) + +本次实验选取中文数据集作为我们的演示数据集。法语数据集作为实践课程的数据集,数据集样例图如 **图2** 所示。 + +
图2 数据集样例,左中文,右法语
+ +### 3.1 下载处理好的数据集 + + +处理好的XFUND中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar) ,可以运行如下指令完成中文数据集下载和解压。 + +
图3 下载数据集
+ + +```python +wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar +tar -xf XFUND.tar + +# XFUN其他数据集使用下面的代码进行转换 +# 代码链接:https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppstructure/vqa/helper/trans_xfun_data.py +# %cd PaddleOCR +# python3 ppstructure/vqa/tools/trans_xfun_data.py --ori_gt_path=path/to/json_path --output_path=path/to/save_path +# %cd ../ +``` + +运行上述指令后在 /home/aistudio/PaddleOCR/ppstructure/vqa/XFUND 目录下有2个文件夹,目录结构如下所示: + +```bash +/home/aistudio/PaddleOCR/ppstructure/vqa/XFUND + └─ zh_train/ 训练集 + ├── image/ 图片存放文件夹 + ├── xfun_normalize_train.json 标注信息 + └─ zh_val/ 验证集 + ├── image/ 图片存放文件夹 + ├── xfun_normalize_val.json 标注信息 + +``` + +该数据集的标注格式为 + +```bash +{ + "height": 3508, # 图像高度 + "width": 2480, # 图像宽度 + "ocr_info": [ + { + "text": "邮政地址:", # 单个文本内容 + "label": "question", # 文本所属类别 + "bbox": [261, 802, 483, 859], # 单个文本框 + "id": 54, # 文本索引 + "linking": [[54, 60]], # 当前文本和其他文本的关系 [question, answer] + "words": [] + }, + { + "text": "湖南省怀化市市辖区", + "label": "answer", + "bbox": [487, 810, 862, 859], + "id": 60, + "linking": [[54, 60]], + "words": [] + } + ] +} +``` + +### 3.2 转换为PaddleOCR检测和识别格式 + +使用XFUND训练PaddleOCR检测和识别模型,需要将数据集格式改为训练需求的格式。 + +
图4 转换为OCR格式
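为便于理解图4所示的转换过程,下面给出一个把 XFUND 标注中 `ocr_info` 字段映射为一行文本检测标注的最小示意。目标格式的完整说明见下文"文本检测""文本识别"两部分;示意中的函数名与图像路径均为本文假设,实际转换请使用下文提供的 `trans_xfund_data.py` 脚本:

```python
# 将XFUND标注中的ocr_info字段映射为一行文本检测标注(示意,非官方脚本实现)
import json

def ocr_info_to_det_line(img_path, anno):
    """生成一行检测标注:图像路径 + 制表符 + json.dumps([{transcription, points}, ...])"""
    det_label = []
    for item in anno["ocr_info"]:
        x1, y1, x2, y2 = item["bbox"]
        det_label.append({
            "transcription": item["text"],
            # bbox [x1, y1, x2, y2] 转为从左上角开始顺时针排列的四点坐标
            "points": [[x1, y1], [x2, y1], [x2, y2], [x1, y2]],
        })
    return img_path + "\t" + json.dumps(det_label, ensure_ascii=False)

if __name__ == "__main__":
    # 以3.1节给出的标注片段为例(图像路径仅为示例)
    demo_anno = {
        "height": 3508,
        "width": 2480,
        "ocr_info": [
            {"text": "邮政地址:", "label": "question", "bbox": [261, 802, 483, 859]},
            {"text": "湖南省怀化市市辖区", "label": "answer", "bbox": [487, 810, 862, 859]},
        ],
    }
    print(ocr_info_to_det_line("zh_train/image/demo.jpg", demo_anno))
```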
+ +- **文本检测** 标注文件格式如下,中间用'\t'分隔: + +" 图像文件名 json.dumps编码的图像标注信息" +ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}] + +json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。 `transcription` 表示当前文本框的文字,***当其内容为“###”时,表示该文本框无效,在训练时会跳过。*** + +- **文本识别** 标注文件的格式如下, txt文件中默认请将图片路径和图片标签用'\t'分割,如用其他方式分割将造成训练报错。 + +``` +" 图像文件名 图像标注信息 " + +train_data/rec/train/word_001.jpg 简单可依赖 +train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单 +... +``` + + + + +```python +unzip -q /home/aistudio/data/data140302/XFUND_ori.zip -d /home/aistudio/data/data140302/ +``` + +已经提供转换脚本,执行如下代码即可转换成功: + + +```python +%cd /home/aistudio/ +python trans_xfund_data.py +``` + +## 4 OCR + +选用飞桨OCR开发套件[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/README_ch.md)中的PP-OCRv2模型进行文本检测和识别。PP-OCRv2在PP-OCR的基础上,进一步在5个方面重点优化,检测模型采用CML协同互学习知识蒸馏策略和CopyPaste数据增广策略;识别模型采用LCNet轻量级骨干网络、UDML 改进知识蒸馏策略和[Enhanced CTC loss](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/enhanced_ctc_loss.md)损失函数改进,进一步在推理速度和预测效果上取得明显提升。更多细节请参考PP-OCRv2[技术报告](https://arxiv.org/abs/2109.03144)。 + +### 4.1 文本检测 + +我们使用2种方案进行训练、评估: +- **PP-OCRv2中英文超轻量检测预训练模型** +- **XFUND数据集+fine-tune** + +#### 4.1.1 方案1:预训练模型 + +**1)下载预训练模型** + +
图5 文本检测方案1-下载预训练模型
+ + +PaddleOCR已经提供了PP-OCR系列模型,部分模型展示如下表所示: + +| 模型简介 | 模型名称 | 推荐场景 | 检测模型 | 方向分类器 | 识别模型 | +| ------------------------------------- | ----------------------- | --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | +| 中英文超轻量PP-OCRv2模型(13.0M) | ch_PP-OCRv2_xx | 移动端&服务器端 | [推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar) | +| 中英文超轻量PP-OCR mobile模型(9.4M) | ch_ppocr_mobile_v2.0_xx | 移动端&服务器端 | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_pre.tar) | +| 中英文通用PP-OCR server模型(143.4M) | ch_ppocr_server_v2.0_xx | 服务器端 | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) | [推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_pre.tar) | + +更多模型下载(包括多语言),可以参考[PP-OCR 系列模型下载](./doc/doc_ch/models_list.md) + + +这里我们使用PP-OCRv2中英文超轻量检测模型,下载并解压预训练模型: + + + + +```python +%cd /home/aistudio/PaddleOCR/pretrain/ +wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar +tar -xf ch_PP-OCRv2_det_distill_train.tar && rm -rf ch_PP-OCRv2_det_distill_train.tar +% cd .. +``` + +**2)模型评估** + +
图6 文本检测方案1-模型评估
+ +接着使用下载的超轻量检测模型在XFUND验证集上进行评估,由于蒸馏需要包含多个网络,甚至多个Student网络,在计算指标的时候只需要计算一个Student网络的指标即可,key字段设置为Student则表示只计算Student网络的精度。 + +``` +Metric: + name: DistillationMetric + base_metric_name: DetMetric + main_indicator: hmean + key: "Student" +``` +首先修改配置文件`configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml`中的以下字段: +``` +Eval.dataset.data_dir:指向验证集图片存放目录 +Eval.dataset.label_file_list:指向验证集标注文件 +``` + + +然后在XFUND验证集上进行评估,具体代码如下: + + +```python +%cd /home/aistudio/PaddleOCR +python tools/eval.py \ + -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_distill.yml \ + -o Global.checkpoints="./pretrain_models/ch_PP-OCRv2_det_distill_train/best_accuracy" +``` + +使用预训练模型进行评估,指标如下所示: + +| 方案 | hmeans | +| -------- | -------- | +| PP-OCRv2中英文超轻量检测预训练模型 | 77.26% | + +使用文本检测预训练模型在XFUND验证集上评估,达到77%左右,充分说明ppocr提供的预训练模型具有泛化能力。 + +#### 4.1.2 方案2:XFUND数据集+fine-tune + +PaddleOCR提供的蒸馏预训练模型包含了多个模型的参数,我们提取Student模型的参数,在XFUND数据集上进行finetune,可以参考如下代码: + +```python +import paddle +# 加载预训练模型 +all_params = paddle.load("pretrain/ch_PP-OCRv2_det_distill_train/best_accuracy.pdparams") +# 查看权重参数的keys +# print(all_params.keys()) +# 学生模型的权重提取 +s_params = {key[len("student_model."):]: all_params[key] for key in all_params if "student_model." in key} +# 查看学生模型权重参数的keys +print(s_params.keys()) +# 保存 +paddle.save(s_params, "pretrain/ch_PP-OCRv2_det_distill_train/student.pdparams") +``` + +**1)模型训练** + +
+
图7 文本检测方案2-模型训练
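+
+开始训练之前,可以先用下面的示意代码确认上一步提取并保存的Student权重能够正常加载(路径与上文保存路径保持一致):
+
+```python
+import paddle
+
+# 示意代码:加载上一步保存的学生模型权重,粗略检查参数数量与key
+s_params = paddle.load("pretrain/ch_PP-OCRv2_det_distill_train/student.pdparams")
+print("参数个数:", len(s_params))
+print("部分key:", list(s_params.keys())[:5])
+```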
+
+
+修改配置文件`configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml`中的以下字段:
+```
+Global.pretrained_model:指向预训练模型路径
+Train.dataset.data_dir:指向训练集图片存放目录
+Train.dataset.label_file_list:指向训练集标注文件
+Eval.dataset.data_dir:指向验证集图片存放目录
+Eval.dataset.label_file_list:指向验证集标注文件
+Optimizer.lr.learning_rate:调整学习率,本实验设置为0.005
+Train.dataset.transforms.EastRandomCropData.size:训练尺寸改为[1600, 1600]
+Eval.dataset.transforms.DetResizeForTest:评估尺寸,添加如下参数
+    limit_side_len: 1600
+    limit_type: 'min'
+```
+执行下面命令启动训练:
+
+
+```python
+CUDA_VISIBLE_DEVICES=0 python tools/train.py \
+    -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml
+```
+
+**2)模型评估**
+
+
+
图8 文本检测方案2-模型评估
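+
+上一小节通过直接修改配置文件来设置预训练模型路径和学习率;与本文其他命令一样,也可以在启动训练时用`-o`选项在命令行覆盖这些字段,便于快速实验(示意命令,取值请按实际情况调整):
+
+```python
+CUDA_VISIBLE_DEVICES=0 python tools/train.py \
+    -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml \
+    -o Global.pretrained_model="pretrain/ch_PP-OCRv2_det_distill_train/student" \
+       Optimizer.lr.learning_rate=0.005
+```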
+ +使用训练好的模型进行评估,更新模型路径`Global.checkpoints`。如需获取已训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 + +
+ +
+
+将下载或训练完成的模型放置在对应目录下即可完成模型评估。
+
+
+```python
+%cd /home/aistudio/PaddleOCR/
+python tools/eval.py \
+    -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml \
+    -o Global.checkpoints="pretrain/ch_db_mv3-student1600-finetune/best_accuracy"
+```
+
+同时我们提供了未finetune的模型(配置文件中`pretrained_model`设置为空,`learning_rate`设置为0.001)作为对比:
+
+
+```python
+%cd /home/aistudio/PaddleOCR/
+python tools/eval.py \
+    -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml \
+    -o Global.checkpoints="pretrain/ch_db_mv3-student1600/best_accuracy"
+```
+
+使用训练好的模型进行评估,指标如下所示:
+
+| 方案 | hmeans |
+| -------- | -------- |
+| XFUND数据集 | 79.27% |
+| XFUND数据集+fine-tune | 85.24% |
+
+对比仅使用XFUND数据集从头训练的模型,使用XFUND数据集+finetune训练在验证集上的hmean达到85%左右,说明finetune会提升垂类场景效果。
+
+**3)导出模型**
+
+
+
图9 文本检测方案2-模型导出
+
+在模型训练过程中保存的模型文件包含了前向预测和反向传播的过程,而实际的工业部署不需要反向传播,因此需要将模型导出成部署所需的inference模型格式。执行下面命令,即可导出模型。
+
+
+```python
+# 加载配置文件`ch_PP-OCRv2_det_student.yml`,从`pretrain/ch_db_mv3-student1600-finetune`目录下加载`best_accuracy`模型
+# inference模型保存在`./output/det_db_inference`目录下
+%cd /home/aistudio/PaddleOCR/
+python tools/export_model.py \
+    -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_student.yml \
+    -o Global.pretrained_model="pretrain/ch_db_mv3-student1600-finetune/best_accuracy" \
+    Global.save_inference_dir="./output/det_db_inference/"
+```
+
+转换成功后,在`./output/det_db_inference/`目录下有三个文件:
+```
+./output/det_db_inference/
+    ├── inference.pdiparams         # 检测inference模型的参数文件
+    ├── inference.pdiparams.info    # 检测inference模型的参数信息,可忽略
+    └── inference.pdmodel           # 检测inference模型的program文件
+```
+
+**4)模型预测**
+
+
+
图10 文本检测方案2-模型预测
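+
+执行预测之前,可以先用示意代码检查上一步导出的inference模型文件是否齐全:
+
+```python
+import os
+
+# 示意代码:确认检测inference模型的关键文件存在(目录与上文导出命令一致)
+infer_dir = "./output/det_db_inference/"
+for name in ["inference.pdmodel", "inference.pdiparams"]:
+    path = os.path.join(infer_dir, name)
+    print(path, "OK" if os.path.exists(path) else "缺失")
+```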
+ +加载上面导出的模型,执行如下命令对验证集或测试集图片进行预测: + +``` +det_model_dir:预测模型 +image_dir:测试图片路径 +use_gpu:是否使用GPU +``` + +检测可视化结果保存在`/home/aistudio/inference_results/`目录下,查看检测效果。 + + +```python +%pwd +!python tools/infer/predict_det.py \ + --det_algorithm="DB" \ + --det_model_dir="./output/det_db_inference/" \ + --image_dir="./doc/vqa/input/zh_val_21.jpg" \ + --use_gpu=True +``` + +总结,我们分别使用PP-OCRv2中英文超轻量检测预训练模型、XFUND数据集+finetune2种方案进行评估、训练等,指标对比如下: + +| 方案 | hmeans | 结果分析 | +| -------- | -------- | -------- | +| PP-OCRv2中英文超轻量检测预训练模型 | 77.26% | ppocr提供的预训练模型有泛化能力 | +| XFUND数据集 | 79.27% | | +| XFUND数据集+finetune | 85.24% | finetune会提升垂类场景效果 | + +### 4.2 文本识别 + +我们分别使用如下3种方案进行训练、评估: + +- PP-OCRv2中英文超轻量识别预训练模型 +- XFUND数据集+fine-tune +- XFUND数据集+fine-tune+真实通用识别数据 + +#### 4.2.1 方案1:预训练模型 + +**1)下载预训练模型** + +
+ +
图11 文本识别方案1-下载预训练模型
+
+我们使用PP-OCRv2中英文超轻量文本识别模型,下载并解压预训练模型:
+
+
+```python
+%cd /home/aistudio/PaddleOCR/pretrain/
+wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
+tar -xf ch_PP-OCRv2_rec_train.tar && rm -rf ch_PP-OCRv2_rec_train.tar
+%cd ..
+```
+
+**2)模型评估**
+
+
+ +
图12 文本识别方案1-模型评估
+
+首先修改配置文件`configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml`中的以下字段:
+
+```
+Eval.dataset.data_dir:指向验证集图片存放目录
+Eval.dataset.label_file_list:指向验证集标注文件
+```
+
+我们使用下载的预训练模型进行评估:
+
+
+```python
+%cd /home/aistudio/PaddleOCR
+CUDA_VISIBLE_DEVICES=0 python tools/eval.py \
+    -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml \
+    -o Global.checkpoints=./pretrain/ch_PP-OCRv2_rec_train/best_accuracy
+```
+
+使用预训练模型进行评估,指标如下所示:
+
+| 方案 | acc |
+| -------- | -------- |
+| PP-OCRv2中英文超轻量识别预训练模型 | 67.48% |
+
+使用文本识别预训练模型在XFUND验证集上评估,acc达到67%左右,充分说明ppocr提供的预训练模型具有泛化能力。
+
+#### 4.2.2 方案2:XFUND数据集+finetune
+
+同检测模型,我们提取Student模型的参数,在XFUND数据集上进行finetune,可以参考如下代码:
+
+
+```python
+import paddle
+# 加载预训练模型
+all_params = paddle.load("pretrain/ch_PP-OCRv2_rec_train/best_accuracy.pdparams")
+# 查看权重参数的keys
+print(all_params.keys())
+# 学生模型的权重提取
+s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
+# 查看学生模型权重参数的keys
+print(s_params.keys())
+# 保存
+paddle.save(s_params, "pretrain/ch_PP-OCRv2_rec_train/student.pdparams")
+```
+
+**1)模型训练**
+
+
+
图13 文本识别方案2-模型训练
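+
+需要注意,上面的权重提取依赖于checkpoint中子网络的key前缀:检测蒸馏模型使用`student_model.`,而识别蒸馏模型使用`Student.`。如果不确定,可以先用下面的示意代码打印几个key确认后再做切片:
+
+```python
+import paddle
+
+# 示意代码:打印权重文件中的部分key,确认学生网络的前缀
+all_params = paddle.load("pretrain/ch_PP-OCRv2_rec_train/best_accuracy.pdparams")
+print(list(all_params.keys())[:5])
+```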
+ +修改配置文件`configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml`中的以下字段: + +``` +Global.pretrained_model:指向预训练模型路径 +Global.character_dict_path: 字典路径 +Optimizer.lr.values:学习率 +Train.dataset.data_dir:指向训练集图片存放目录 +Train.dataset.label_file_list:指向训练集标注文件 +Eval.dataset.data_dir:指向验证集图片存放目录 +Eval.dataset.label_file_list:指向验证集标注文件 +``` +执行如下命令启动训练: + +```python +%cd /home/aistudio/PaddleOCR/ +CUDA_VISIBLE_DEVICES=0 python tools/train.py \ + -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml +``` + +**2)模型评估** + +
+ +
图14 文本识别方案2-模型评估
+
+使用训练好的模型进行评估,更新模型路径`Global.checkpoints`,这里为大家提供训练好的模型`./pretrain/rec_mobile_pp-OCRv2-student-finetune/best_accuracy`。
+
+
+```python
+%cd /home/aistudio/PaddleOCR/
+CUDA_VISIBLE_DEVICES=0 python tools/eval.py \
+    -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml \
+    -o Global.checkpoints=./pretrain/rec_mobile_pp-OCRv2-student-finetune/best_accuracy
+```
+
+使用finetune后的模型进行评估,指标如下所示:
+
+| 方案 | acc |
+| -------- | -------- |
+| XFUND数据集+finetune | 72.33% |
+
+使用XFUND数据集+finetune训练,在验证集上acc达到72%左右,说明finetune会提升垂类场景效果。
+
+#### 4.2.3 方案3:XFUND数据集+finetune+真实通用识别数据
+
+接着我们在上述`XFUND数据集+finetune`实验的基础上,添加真实通用识别数据,进一步提升识别效果。首先需要准备真实通用识别数据并上传到AIStudio。
+
+**1)模型训练**
+
+
+ +
图15 文本识别方案3-模型训练
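+
+方案3的核心是在识别训练集中混入真实通用识别数据,主要涉及配置文件中`Train.dataset.label_file_list`与`Train.dataset.ratio_list`两个字段。下面给出一个示意写法(其中文件路径均为假设值,`ratio_list`表示各标注文件在每个epoch中的采样比例):
+
+```
+Train:
+  dataset:
+    label_file_list:
+      - ./train_data/XFUND/rec_gt_train.txt        # XFUND识别标注文件(路径为假设值)
+      - ./train_data/real_data/rec_gt_train.txt    # 真实通用识别数据标注文件(路径为假设值)
+    ratio_list: [1.0, 0.2]                         # 示意:XFUND全量采样,通用数据按20%采样
+```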
+
+在上述`XFUND数据集+finetune`实验所用配置文件`configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml`的基础上,继续修改以下字段:
+
+```
+Train.dataset.label_file_list:在原有标注文件基础上,追加真实通用识别数据的标注文件路径
+Train.dataset.ratio_list:设置各标注文件的采样比例,实现动态采样
+```
+执行如下命令启动训练:
+
+
+
+```python
+%cd /home/aistudio/PaddleOCR/
+CUDA_VISIBLE_DEVICES=0 python tools/train.py \
+    -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
+```
+
+**2)模型评估**
+
+
+ +
图16 文本识别方案3-模型评估
+
+使用训练好的模型进行评估,更新模型路径`Global.checkpoints`。
+
+
+```python
+CUDA_VISIBLE_DEVICES=0 python tools/eval.py \
+    -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml \
+    -o Global.checkpoints=./pretrain/rec_mobile_pp-OCRv2-student-realdata/best_accuracy
+```
+
+使用该模型进行评估,指标如下所示:
+
+| 方案 | acc |
+| -------- | -------- |
+| XFUND数据集+fine-tune+真实通用识别数据 | 85.29% |
+
+使用XFUND数据集+finetune+真实通用识别数据训练,在验证集上acc达到85%左右,相比方案2(72.33%)有明显提升,说明真实通用识别数据对于性能提升很有帮助。
+
+**3)导出模型**
+
+
+
图17 文本识别方案3-导出模型
+ +导出模型只保留前向预测的过程: + + +```python +!python tools/export_model.py \ + -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml \ + -o Global.pretrained_model=pretrain/rec_mobile_pp-OCRv2-student-realdata/best_accuracy \ + Global.save_inference_dir=./output/rec_crnn_inference/ +``` + +**4)模型预测** + +
+ +
图18 文本识别方案3-模型预测
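+
+串联预测会同时用到前面导出的检测、识别inference模型以及识别字典,运行前可以用示意代码确认这些路径都已就绪(具体文件名以实际导出结果为准):
+
+```python
+import os
+
+# 示意代码:检查串联预测依赖的模型与字典文件
+paths = [
+    "./output/det_db_inference/inference.pdmodel",
+    "./output/rec_crnn_inference/inference.pdmodel",
+    "/home/aistudio/XFUND/word_dict.txt",
+]
+for p in paths:
+    print(p, "OK" if os.path.exists(p) else "缺失")
+```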
+
+加载上面导出的模型,执行如下命令对验证集或测试集图片进行预测,检测、识别的可视化结果保存在`/home/aistudio/inference_results/`目录下。需要通过`--rec_char_dict_path`指定使用的字典路径。
+
+
+```python
+python tools/infer/predict_system.py \
+    --image_dir="./doc/vqa/input/zh_val_21.jpg" \
+    --det_model_dir="./output/det_db_inference/" \
+    --rec_model_dir="./output/rec_crnn_inference/" \
+    --rec_image_shape="3, 32, 320" \
+    --rec_char_dict_path="/home/aistudio/XFUND/word_dict.txt"
+```
+
+总结,我们分别使用PP-OCRv2中英文超轻量识别预训练模型、XFUND数据集+finetune、XFUND数据集+finetune+真实通用识别数据3种方案进行评估、训练等,指标对比如下:
+
+| 方案 | acc | 结果分析 |
+| -------- | -------- | -------- |
+| PP-OCRv2中英文超轻量识别预训练模型 | 67.48% | ppocr提供的预训练模型具有泛化能力 |
+| XFUND数据集+fine-tune | 72.33% | finetune会提升垂类场景效果 |
+| XFUND数据集+fine-tune+真实通用识别数据 | 85.29% | 真实通用识别数据对于性能提升很有帮助 |
+
+## 5 文档视觉问答(DOC-VQA)
+
+VQA指视觉问答,主要针对图像内容进行提问和回答;DOC-VQA是VQA任务中的一种,主要针对文本图像的文字内容提出问题。
+
+PaddleOCR中DOC-VQA系列算法基于PaddleNLP自然语言处理算法库实现了LayoutXLM论文中的算法,支持基于多模态方法的 **语义实体识别 (Semantic Entity Recognition, SER)** 以及 **关系抽取 (Relation Extraction, RE)** 任务。
+
+如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。
+
+
+```python
+%cd pretrain
+#下载SER模型
+wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar && tar -xvf ser_LayoutXLM_xfun_zh.tar
+#下载RE模型
+wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar && tar -xvf re_LayoutXLM_xfun_zh.tar
+%cd ../
+```
+
+### 5.1 SER
+
+SER: 语义实体识别 (Semantic Entity Recognition),可以完成对图像中的文本识别与分类。
+
+
+
图19 SER测试效果图
+ +**图19** 中不同颜色的框表示不同的类别,对于XFUND数据集,有QUESTION, ANSWER, HEADER 3种类别 + +- 深紫色:HEADER +- 浅紫色:QUESTION +- 军绿色:ANSWER + +在OCR检测框的左上方也标出了对应的类别和OCR识别结果。 + +#### 5.1.1 模型训练 + +
+ +
图20 SER-模型训练
+
+启动训练之前,需要修改配置文件 `configs/vqa/ser/layoutxlm.yml` 中的以下四个字段:
+
+    1. Train.dataset.data_dir:指向训练集图片存放目录
+    2. Train.dataset.label_file_list:指向训练集标注文件
+    3. Eval.dataset.data_dir:指向验证集图片存放目录
+    4. Eval.dataset.label_file_list:指向验证集标注文件
+
+
+
+```python
+%cd /home/aistudio/PaddleOCR/
+CUDA_VISIBLE_DEVICES=0 python tools/train.py -c configs/vqa/ser/layoutxlm.yml
+```
+
+最终会打印出`precision`, `recall`, `hmean`等指标。在`./output/ser_layoutxlm/`文件夹中会保存训练日志、最优模型和最新epoch的模型。
+
+#### 5.1.2 模型评估
+
+
+ +
图21 SER-模型评估
+ +我们使用下载的预训练模型进行评估,如果使用自己训练好的模型进行评估,将待评估的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段即可。 + + + + +```python +CUDA_VISIBLE_DEVICES=0 python tools/eval.py \ + -c configs/vqa/ser/layoutxlm.yml \ + -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ +``` + +最终会打印出`precision`, `recall`, `hmean`等指标,预训练模型评估指标如下: + +
+
图 SER预训练模型评估指标
+ +#### 5.1.3 模型预测 + +
+ +
图22 SER-模型预测
+ +使用如下命令即可完成`OCR引擎 + SER`的串联预测, 以SER预训练模型为例: + + +```python +CUDA_VISIBLE_DEVICES=0 python tools/infer_vqa_token_ser.py \ + -c configs/vqa/ser/layoutxlm.yml \ + -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ \ + Global.infer_img=doc/vqa/input/zh_val_42.jpg +``` + +最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`。通过如下命令查看预测图片: + + +```python +import cv2 +from matplotlib import pyplot as plt +# 在notebook中使用matplotlib.pyplot绘图时,需要添加该命令进行显示 +%matplotlib inline + +img = cv2.imread('output/ser/zh_val_42_ser.jpg') +plt.figure(figsize=(48,24)) +plt.imshow(img) +``` + +### 5.2 RE + +基于 RE 任务,可以完成对图象中的文本内容的关系提取,如判断问题对(pair)。 + +
+
图23 RE预测效果图
+ +图中红色框表示问题,蓝色框表示答案,问题和答案之间使用绿色线连接。在OCR检测框的左上方也标出了对应的类别和OCR识别结果。 + +#### 5.2.1 模型训练 + +
+ +
图24 RE-模型训练
+
+启动训练之前,需要修改配置文件`configs/vqa/re/layoutxlm.yml`中的以下四个字段:
+
+    Train.dataset.data_dir:指向训练集图片存放目录
+    Train.dataset.label_file_list:指向训练集标注文件
+    Eval.dataset.data_dir:指向验证集图片存放目录
+    Eval.dataset.label_file_list:指向验证集标注文件
+
+
+
+```python
+CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml
+```
+
+最终会打印出`precision`, `recall`, `hmean`等指标。在`./output/re_layoutxlm/`文件夹中会保存训练日志、最优模型和最新epoch的模型。
+
+#### 5.2.2 模型评估
+
+
+
图25 RE-模型评估
+ + +我们使用下载的预训练模型进行评估,如果使用自己训练好的模型进行评估,将待评估的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段即可。 + + +```python +CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py \ + -c configs/vqa/re/layoutxlm.yml \ + -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ +``` + +最终会打印出`precision`, `recall`, `hmean`等指标,预训练模型评估指标如下: + +
+
图 RE预训练模型评估指标
+ +#### 5.2.3 模型预测 + +
+ +
图26 RE-模型预测
+ +使用如下命令即可完成OCR引擎 + SER + RE的串联预测, 以预训练SER和RE模型为例, + +最终会在config.Global.save_res_path字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为infer_results.txt。 + + +```python +cd /home/aistudio/PaddleOCR +CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser_re.py \ + -c configs/vqa/re/layoutxlm.yml \ + -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ \ + Global.infer_img=test_imgs/ \ + -c_ser configs/vqa/ser/layoutxlm.yml \ + -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ +``` + +最终会在config.Global.save_res_path字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为infer_results.txt, 每一行表示一张图片的结果,每张图片的结果如下所示,前面表示测试图片路径,后面为测试结果:key字段及对应的value字段。 + +``` +test_imgs/t131.jpg {"政治面税": "群众", "性别": "男", "籍贯": "河北省邯郸市", "婚姻状况": "亏末婚口已婚口已娇", "通讯地址": "邯郸市阳光苑7号楼003", "民族": "汉族", "毕业院校": "河南工业大学", "户口性质": "口农村城镇", "户口地址": "河北省邯郸市", "联系电话": "13288888888", "健康状况": "健康", "姓名": "小六", "好高cm": "180", "出生年月": "1996年8月9日", "文化程度": "本科", "身份证号码": "458933777777777777"} +```` + +展示预测结果 + +```python +import cv2 +from matplotlib import pyplot as plt +%matplotlib inline + +img = cv2.imread('./output/re/t131_ser.jpg') +plt.figure(figsize=(48,24)) +plt.imshow(img) +``` + +## 6 导出Excel + +
+
图27 导出Excel
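+
+本节需要对`tools/infer_vqa_token_ser_re.py`中写出预测结果的代码做少量修改;不同版本代码的行号可能略有差异,可以先用下面的命令定位到包含`ser_result`的写出位置(示意):
+
+```
+grep -n "ser_result" tools/infer_vqa_token_ser_re.py
+```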
+ +为了输出信息匹配对,我们修改`tools/infer_vqa_token_ser_re.py`文件中的`line 194-197`。 +``` + fout.write(img_path + "\t" + json.dumps( + { + "ser_result": result, + }, ensure_ascii=False) + "\n") + +``` +更改为 +``` +result_key = {} +for ocr_info_head, ocr_info_tail in result: + result_key[ocr_info_head['text']] = ocr_info_tail['text'] + +fout.write(img_path + "\t" + json.dumps( + result_key, ensure_ascii=False) + "\n") +``` + +同时将输出结果导出到Excel中,效果如 图28 所示: + +
+
图28 Excel效果图
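+
+下面给出将抽取结果写入Excel的完整代码。需要注意,个别图片可能缺失某些字段,直接用`result_key['姓名']`这类写法取值会抛出`KeyError`;一个更稳妥的取值方式如下(示意,可用来替换完整代码中构造`row_data`的那一行):
+
+```python
+# 示意:字段缺失时以空字符串兜底,避免KeyError中断写入
+def build_row(result_key, title):
+    return [result_key.get(k, '') for k in title]
+```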
+ + +```python +import json +import xlsxwriter as xw + +workbook = xw.Workbook('output/re/infer_results.xlsx') +format1 = workbook.add_format({ + 'align': 'center', + 'valign': 'vcenter', + 'text_wrap': True, +}) +worksheet1 = workbook.add_worksheet('sheet1') +worksheet1.activate() +title = ['姓名', '性别', '民族', '文化程度', '身份证号码', '联系电话', '通讯地址'] +worksheet1.write_row('A1', title) +i = 2 + +with open('output/re/infer_results.txt', 'r', encoding='utf-8') as fin: + lines = fin.readlines() + for line in lines: + img_path, result = line.strip().split('\t') + result_key = json.loads(result) + # 写入Excel + row_data = [result_key['姓名'], result_key['性别'], result_key['民族'], result_key['文化程度'], result_key['身份证号码'], + result_key['联系电话'], result_key['通讯地址']] + row = 'A' + str(i) + worksheet1.write_row(row, row_data, format1) + i+=1 +workbook.close() +``` + +## 更多资源 + +- 更多深度学习知识、产业案例、面试宝典等,请参考:[awesome-DeepLearning](https://github.com/paddlepaddle/awesome-DeepLearning) + +- 更多PaddleOCR使用教程,请参考:[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph) + +- 更多PaddleNLP使用教程,请参考:[PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP) + +- 飞桨框架相关资料,请参考:[飞桨深度学习平台](https://www.paddlepaddle.org.cn/?fr=paddleEdu_aistudio) + +## 参考链接 + +- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf + +- microsoft/unilm/layoutxlm, https://github.com/microsoft/unilm/tree/master/layoutxlm + +- XFUND dataset, https://github.com/doc-analysis/XFUND + diff --git "a/applications/\345\277\253\351\200\237\346\236\204\345\273\272\345\215\241\350\257\201\347\261\273OCR.md" "b/applications/\345\277\253\351\200\237\346\236\204\345\273\272\345\215\241\350\257\201\347\261\273OCR.md" new file mode 100644 index 0000000..50b70ff --- /dev/null +++ "b/applications/\345\277\253\351\200\237\346\236\204\345\273\272\345\215\241\350\257\201\347\261\273OCR.md" @@ -0,0 +1,775 @@ +# 快速构建卡证类OCR + + +- [快速构建卡证类OCR](#快速构建卡证类ocr) + - [1. 金融行业卡证识别应用](#1-金融行业卡证识别应用) + - [1.1 金融行业中的OCR相关技术](#11-金融行业中的ocr相关技术) + - [1.2 金融行业中的卡证识别场景介绍](#12-金融行业中的卡证识别场景介绍) + - [1.3 OCR落地挑战](#13-ocr落地挑战) + - [2. 卡证识别技术解析](#2-卡证识别技术解析) + - [2.1 卡证分类模型](#21-卡证分类模型) + - [2.2 卡证识别模型](#22-卡证识别模型) + - [3. OCR技术拆解](#3-ocr技术拆解) + - [3.1技术流程](#31技术流程) + - [3.2 OCR技术拆解---卡证分类](#32-ocr技术拆解---卡证分类) + - [卡证分类:数据、模型准备](#卡证分类数据模型准备) + - [卡证分类---修改配置文件](#卡证分类---修改配置文件) + - [卡证分类---训练](#卡证分类---训练) + - [3.2 OCR技术拆解---卡证识别](#32-ocr技术拆解---卡证识别) + - [身份证识别:检测+分类](#身份证识别检测分类) + - [数据标注](#数据标注) + - [4 . 项目实践](#4--项目实践) + - [4.1 环境准备](#41-环境准备) + - [4.2 配置文件修改](#42-配置文件修改) + - [4.3 代码修改](#43-代码修改) + - [4.3.1 数据读取](#431-数据读取) + - [4.3.2 head修改](#432--head修改) + - [4.3.3 修改loss](#433-修改loss) + - [4.3.4 后处理](#434-后处理) + - [4.4. 模型启动](#44-模型启动) + - [5 总结](#5-总结) + - [References](#references) + +## 1. 金融行业卡证识别应用 + +### 1.1 金融行业中的OCR相关技术 + +* 《“十四五”数字经济发展规划》指出,2020年我国数字经济核心产业增加值占GDP比重达7.8%,随着数字经济迈向全面扩展,到2025年该比例将提升至10%。 + +* 在过去数年的跨越发展与积累沉淀中,数字金融、金融科技已在对金融业的重塑与再造中充分印证了其自身价值。 + +* 以智能为目标,提升金融数字化水平,实现业务流程自动化,降低人力成本。 + + +![](https://ai-studio-static-online.cdn.bcebos.com/8bb381f164c54ea9b4043cf66fc92ffdea8aaf851bab484fa6e19bd2f93f154f) + + + +### 1.2 金融行业中的卡证识别场景介绍 + +应用场景:身份证、银行卡、营业执照、驾驶证等。 + +应用难点:由于数据的采集来源多样,以及实际采集数据各种噪声:反光、褶皱、模糊、倾斜等各种问题干扰。 + +![](https://ai-studio-static-online.cdn.bcebos.com/981640e17d05487e961162f8576c9e11634ca157f79048d4bd9d3bc21722afe8) + + + +### 1.3 OCR落地挑战 + + +![](https://ai-studio-static-online.cdn.bcebos.com/a5973a8ddeff4bd7ac082f02dc4d0c79de21e721b41641cbb831f23c2cb8fce2) + + + + + +## 2. 
卡证识别技术解析 + + +![](https://ai-studio-static-online.cdn.bcebos.com/d7f96effc2434a3ca2d4144ff33c50282b830670c892487d8d7dec151921cce7) + + +### 2.1 卡证分类模型 + +卡证分类:基于PPLCNet + +与其他轻量级模型相比在CPU环境下ImageNet数据集上的表现 + +![](https://ai-studio-static-online.cdn.bcebos.com/cbda3390cb994f98a3c8a9ba88c90c348497763f6c9f4b4797f7d63d84da5f63) + +![](https://ai-studio-static-online.cdn.bcebos.com/dedab7b7fd6543aa9e7f625132b24e3ba3f200e361fa468dac615f7814dfb98d) + + + +* 模型来自模型库PaddleClas,它是一个图像识别和图像分类任务的工具集,助力使用者训练出更好的视觉模型和应用落地。 + +### 2.2 卡证识别模型 + +* 检测:DBNet 识别:SVRT + +![](https://ai-studio-static-online.cdn.bcebos.com/9a7a4e19edc24310b46620f2ee7430f918223b93d4f14a15a52973c096926bad) + + +* PPOCRv3在文本检测、识别进行了一系列改进优化,在保证精度的同时提升预测效率 + + +![](https://ai-studio-static-online.cdn.bcebos.com/6afdbb77e8db4aef9b169e4e94c5d90a9764cfab4f2c4c04aa9afdf4f54d7680) + + +![](https://ai-studio-static-online.cdn.bcebos.com/c1a7d197847a4f168848c59b8e625d1d5e8066b778144395a8b9382bb85dc364) + + +## 3. OCR技术拆解 + +### 3.1技术流程 + +![](https://ai-studio-static-online.cdn.bcebos.com/89ba046177864d8783ced6cb31ba92a66ca2169856a44ee59ac2bb18e44a6c4b) + + +### 3.2 OCR技术拆解---卡证分类 + +#### 卡证分类:数据、模型准备 + + +A 使用爬虫获取无标注数据,将相同类别的放在同一文件夹下,文件名从0开始命名。具体格式如下图所示。 + +​ 注:卡证类数据,建议每个类别数据量在500张以上 +![](https://ai-studio-static-online.cdn.bcebos.com/6f875b6e695e4fe5aedf427beb0d4ce8064ad7cc33c44faaad59d3eb9732639d) + + +B 一行命令生成标签文件 + +``` +tree -r -i -f | grep -E "jpg|JPG|jpeg|JPEG|png|PNG|webp" | awk -F "/" '{print $0" "$2}' > train_list.txt +``` + +C [下载预训练模型 ](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/PP-LCNet.md) + + + +#### 卡证分类---修改配置文件 + + +配置文件主要修改三个部分: + + 全局参数:预训练模型路径/训练轮次/图像尺寸 + + 模型结构:分类数 + + 数据处理:训练/评估数据路径 + + + ![](https://ai-studio-static-online.cdn.bcebos.com/e0dc05039c7444c5ab1260ff550a408748df8d4cfe864223adf390e51058dbd5) + +#### 卡证分类---训练 + + +指定配置文件启动训练: + +``` +!python /home/aistudio/work/PaddleClas/tools/train.py -c /home/aistudio/work/PaddleClas/ppcls/configs/PULC/text_image_orientation/PPLCNet_x1_0.yaml +``` +![](https://ai-studio-static-online.cdn.bcebos.com/06af09bde845449ba0a676410f4daa1cdc3983ac95034bdbbafac3b7fd94042f) + +​ 注:日志中显示了训练结果和评估结果(训练时可以设置固定轮数评估一次) + + +### 3.2 OCR技术拆解---卡证识别 + +卡证识别(以身份证检测为例) +存在的困难及问题: + + * 在自然场景下,由于各种拍摄设备以及光线、角度不同等影响导致实际得到的证件影像千差万别。 + + * 如何快速提取需要的关键信息 + + * 多行的文本信息,检测结果如何正确拼接 + + ![](https://ai-studio-static-online.cdn.bcebos.com/4f8f5533a2914e0a821f4a639677843c32ec1f08a1b1488d94c0b8bfb6e72d2d) + + + +* OCR技术拆解---OCR工具库 + + PaddleOCR是一个丰富、领先且实用的OCR工具库,助力开发者训练出更好的模型并应用落地 + + +身份证识别:用现有的方法识别 + +![](https://ai-studio-static-online.cdn.bcebos.com/12d402e6a06d482a88f979e0ebdfb39f4d3fc8b80517499689ec607ddb04fbf3) + + + + +#### 身份证识别:检测+分类 + +> 方法:基于现有的dbnet检测模型,加入分类方法。检测同时进行分类,从一定程度上优化识别流程 + +![](https://ai-studio-static-online.cdn.bcebos.com/e1e798c87472477fa0bfca0da12bb0c180845a3e167a4761b0d26ff4330a5ccb) + + +![](https://ai-studio-static-online.cdn.bcebos.com/23a5a19c746441309864586e467f995ec8a551a3661640e493fc4d77520309cd) + +#### 数据标注 + +使用PaddleOCRLable进行快速标注 + +![](https://ai-studio-static-online.cdn.bcebos.com/a73180425fa14f919ce52d9bf70246c3995acea1831843cca6c17d871b8f5d95) + + +* 修改PPOCRLabel.py,将下图中的kie参数设置为True + + +![](https://ai-studio-static-online.cdn.bcebos.com/d445cf4d850e4063b9a7fc6a075c12204cf912ff23ec471fa2e268b661b3d693) + + +* 数据标注踩坑分享 + +![](https://ai-studio-static-online.cdn.bcebos.com/89f42eccd600439fa9e28c97ccb663726e4e54ce3a854825b4c3b7d554ea21df) + +​ 注:两者只有标注有差别,训练参数数据集都相同 + +## 4 . 
项目实践 + +AIStudio项目链接:[快速构建卡证类OCR](https://aistudio.baidu.com/aistudio/projectdetail/4459116) + +### 4.1 环境准备 + +1)拉取[paddleocr](https://github.com/PaddlePaddle/PaddleOCR)项目,如果从github上拉取速度慢可以选择从gitee上获取。 +``` +!git clone https://github.com/PaddlePaddle/PaddleOCR.git -b release/2.6 /home/aistudio/work/ +``` + +2)获取并解压预训练模型,如果要使用其他模型可以从模型库里自主选择合适模型。 +``` +!wget -P work/pre_trained/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar +!tar -vxf /home/aistudio/work/pre_trained/ch_PP-OCRv3_det_distill_train.tar -C /home/aistudio/work/pre_trained +``` +3) 安装必要依赖 +``` +!pip install -r /home/aistudio/work/requirements.txt +``` + +### 4.2 配置文件修改 + +修改配置文件 *work/configs/det/detmv3db.yml* + +具体修改说明如下: + +![](https://ai-studio-static-online.cdn.bcebos.com/fcdf517af5a6466294d72db7450209378d8efd9b77764e329d3f2aff3579a20c) + + 注:在上述的配置文件的Global变量中需要添加以下两个参数: + +​ label_list 为标签表 +​ num_classes 为分类数 +​ 上述两个参数根据实际的情况配置即可 + + +![](https://ai-studio-static-online.cdn.bcebos.com/0b056be24f374812b61abf43305774767ae122c8479242f98aa0799b7bfc81d4) + +其中lable_list内容如下例所示,***建议第一个参数设置为 background,不要设置为实际要提取的关键信息种类***: + +![](https://ai-studio-static-online.cdn.bcebos.com/9fc78bbcdf754898b9b2c7f000ddf562afac786482ab4f2ab063e2242faa542a) + +配置文件中的其他设置说明 + +![](https://ai-studio-static-online.cdn.bcebos.com/c7fc5e631dd44bc8b714630f4e49d9155a831d9e56c64e2482ded87081d0db22) + +![](https://ai-studio-static-online.cdn.bcebos.com/8d1022ac25d9474daa4fb236235bd58760039d58ad46414f841559d68e0d057f) + +![](https://ai-studio-static-online.cdn.bcebos.com/ee927ad9ebd442bb96f163a7ebbf4bc95e6bedee97324a51887cf82de0851fd3) + + + + +### 4.3 代码修改 + + +#### 4.3.1 数据读取 + + + +* 修改 PaddleOCR/ppocr/data/imaug/label_ops.py中的DetLabelEncode + + +```python +class DetLabelEncode(object): + + # 修改检测标签的编码处,新增了参数分类数:num_classes,重写初始化方法,以及分类标签的读取 + + def __init__(self, label_list, num_classes=8, **kwargs): + self.num_classes = num_classes + self.label_list = [] + if label_list: + if isinstance(label_list, str): + with open(label_list, 'r+', encoding='utf-8') as f: + for line in f.readlines(): + self.label_list.append(line.replace("\n", "")) + else: + self.label_list = label_list + else: + assert ' please check label_list whether it is none or config is right' + + if num_classes != len(self.label_list): # 校验分类数和标签的一致性 + assert 'label_list length is not equal to the num_classes' + + def __call__(self, data): + label = data['label'] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags, classes = [], [], [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['key_cls'] # 此处将kie中的参数作为分类读取 + boxes.append(box) + txts.append(txt) + + if txt in ['*', '###']: + txt_tags.append(True) + if self.num_classes > 1: + classes.append(-2) + else: + txt_tags.append(False) + if self.num_classes > 1: # 将KIE内容的key标签作为分类标签使用 + classes.append(int(self.label_list.index(txt))) + + if len(boxes) == 0: + + return None + boxes = self.expand_points_num(boxes) + boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool_) + classes = classes + data['polys'] = boxes + data['texts'] = txts + data['ignore_tags'] = txt_tags + if self.num_classes > 1: + data['classes'] = classes + return data +``` + +* 修改 PaddleOCR/ppocr/data/imaug/make_shrink_map.py中的MakeShrinkMap类。这里需要注意的是,如果我们设置的label_list中的第一个参数为要检测的信息那么会得到如下的mask, + +举例说明: +这是检测的mask图,图中有四个mask那么实际对应的分类应该是4类 + 
+![](https://ai-studio-static-online.cdn.bcebos.com/42d2188d3d6b498880952e12c3ceae1efabf135f8d9f4c31823f09ebe02ba9d2) + + + +label_list中第一个为关键分类,则得到的分类Mask实际如下,与上图相比,少了一个box: + +![](https://ai-studio-static-online.cdn.bcebos.com/864604967256461aa7c5d32cd240645e9f4c70af773341d5911f22d5a3e87b5f) + + + +```python +class MakeShrinkMap(object): + r''' + Making binary mask from detection data with ICDAR format. + Typically following the process of class `MakeICDARData`. + ''' + + def __init__(self, min_text_size=8, shrink_ratio=0.4, num_classes=8, **kwargs): + self.min_text_size = min_text_size + self.shrink_ratio = shrink_ratio + self.num_classes = num_classes # 添加了分类 + + def __call__(self, data): + image = data['image'] + text_polys = data['polys'] + ignore_tags = data['ignore_tags'] + if self.num_classes > 1: + classes = data['classes'] + + h, w = image.shape[:2] + text_polys, ignore_tags = self.validate_polygons(text_polys, + ignore_tags, h, w) + gt = np.zeros((h, w), dtype=np.float32) + mask = np.ones((h, w), dtype=np.float32) + gt_class = np.zeros((h, w), dtype=np.float32) # 新增分类 + for i in range(len(text_polys)): + polygon = text_polys[i] + height = max(polygon[:, 1]) - min(polygon[:, 1]) + width = max(polygon[:, 0]) - min(polygon[:, 0]) + if ignore_tags[i] or min(height, width) < self.min_text_size: + cv2.fillPoly(mask, + polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + else: + polygon_shape = Polygon(polygon) + subject = [tuple(l) for l in polygon] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + shrinked = [] + + # Increase the shrink ratio every time we get multiple polygon returned back + possible_ratios = np.arange(self.shrink_ratio, 1, + self.shrink_ratio) + np.append(possible_ratios, 1) + for ratio in possible_ratios: + distance = polygon_shape.area * ( + 1 - np.power(ratio, 2)) / polygon_shape.length + shrinked = padding.Execute(-distance) + if len(shrinked) == 1: + break + + if shrinked == []: + cv2.fillPoly(mask, + polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + continue + + for each_shirnk in shrinked: + shirnk = np.array(each_shirnk).reshape(-1, 2) + cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1) + if self.num_classes > 1: # 绘制分类的mask + cv2.fillPoly(gt_class, polygon.astype(np.int32)[np.newaxis, :, :], classes[i]) + + + data['shrink_map'] = gt + + if self.num_classes > 1: + data['class_mask'] = gt_class + + data['shrink_mask'] = mask + return data +``` + +由于在训练数据中会对数据进行resize设置,yml中的操作为:EastRandomCropData,所以需要修改PaddleOCR/ppocr/data/imaug/random_crop_data.py中的EastRandomCropData + + +```python +class EastRandomCropData(object): + def __init__(self, + size=(640, 640), + max_tries=10, + min_crop_side_ratio=0.1, + keep_ratio=True, + num_classes=8, + **kwargs): + self.size = size + self.max_tries = max_tries + self.min_crop_side_ratio = min_crop_side_ratio + self.keep_ratio = keep_ratio + self.num_classes = num_classes + + def __call__(self, data): + img = data['image'] + text_polys = data['polys'] + ignore_tags = data['ignore_tags'] + texts = data['texts'] + if self.num_classes > 1: + classes = data['classes'] + all_care_polys = [ + text_polys[i] for i, tag in enumerate(ignore_tags) if not tag + ] + # 计算crop区域 + crop_x, crop_y, crop_w, crop_h = crop_area( + img, all_care_polys, self.min_crop_side_ratio, self.max_tries) + # crop 图片 保持比例填充 + scale_w = self.size[0] / crop_w + scale_h = self.size[1] / crop_h + scale = min(scale_w, scale_h) + h = int(crop_h * scale) + 
w = int(crop_w * scale) + if self.keep_ratio: + padimg = np.zeros((self.size[1], self.size[0], img.shape[2]), + img.dtype) + padimg[:h, :w] = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) + img = padimg + else: + img = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], + tuple(self.size)) + # crop 文本框 + text_polys_crop = [] + ignore_tags_crop = [] + texts_crop = [] + classes_crop = [] + for poly, text, tag,class_index in zip(text_polys, texts, ignore_tags,classes): + poly = ((poly - (crop_x, crop_y)) * scale).tolist() + if not is_poly_outside_rect(poly, 0, 0, w, h): + text_polys_crop.append(poly) + ignore_tags_crop.append(tag) + texts_crop.append(text) + if self.num_classes > 1: + classes_crop.append(class_index) + data['image'] = img + data['polys'] = np.array(text_polys_crop) + data['ignore_tags'] = ignore_tags_crop + data['texts'] = texts_crop + if self.num_classes > 1: + data['classes'] = classes_crop + return data +``` + +#### 4.3.2 head修改 + + + +主要修改 ppocr/modeling/heads/det_db_head.py,将Head类中的最后一层的输出修改为实际的分类数,同时在DBHead中新增分类的head。 + +![](https://ai-studio-static-online.cdn.bcebos.com/0e25da2ccded4af19e95c85c3d3287ab4d53e31a4eed4607b6a4cb637c43f6d3) + + + +#### 4.3.3 修改loss + + +修改PaddleOCR/ppocr/losses/det_db_loss.py中的DBLoss类,分类采用交叉熵损失函数进行计算。 + +![](https://ai-studio-static-online.cdn.bcebos.com/dc10a070018d4d27946c26ec24a2a85bc3f16422f4964f72a9b63c6170d954e1) + + +#### 4.3.4 后处理 + + + +由于涉及到eval以及后续推理能否正常使用,我们需要修改后处理的相关代码,修改位置 PaddleOCR/ppocr/postprocess/db_postprocess.py中的DBPostProcess类 + + +```python +class DBPostProcess(object): + """ + The post process for Differentiable Binarization (DB). + """ + + def __init__(self, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=2.0, + use_dilation=False, + score_mode="fast", + **kwargs): + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + self.min_size = 3 + self.score_mode = score_mode + assert score_mode in [ + "slow", "fast" + ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) + + self.dilation_kernel = None if not use_dilation else np.array( + [[1, 1], [1, 1]]) + + def boxes_from_bitmap(self, pred, _bitmap, classes, dest_width, dest_height): + """ + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + """ + + bitmap = _bitmap + height, width = bitmap.shape + + outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, + cv2.CHAIN_APPROX_SIMPLE) + if len(outs) == 3: + img, contours, _ = outs[0], outs[1], outs[2] + elif len(outs) == 2: + contours, _ = outs[0], outs[1] + + num_contours = min(len(contours), self.max_candidates) + + boxes = [] + scores = [] + class_indexes = [] + class_scores = [] + for index in range(num_contours): + contour = contours[index] + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + if self.score_mode == "fast": + score, class_index, class_score = self.box_score_fast(pred, points.reshape(-1, 2), classes) + else: + score, class_index, class_score = self.box_score_slow(pred, contour, classes) + if self.box_thresh > score: + continue + + box = self.unclip(points).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self.min_size + 2: + continue + box = np.array(box) + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) 
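+            # 上面已将检测框坐标从分割图尺度映射回原图尺度;下面依次收集框、文本得分以及本方案新增的类别索引与类别得分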
+ + boxes.append(box.astype(np.int16)) + scores.append(score) + + class_indexes.append(class_index) + class_scores.append(class_score) + + if classes is None: + return np.array(boxes, dtype=np.int16), scores + else: + return np.array(boxes, dtype=np.int16), scores, class_indexes, class_scores + + def unclip(self, box): + unclip_ratio = self.unclip_ratio + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [ + points[index_1], points[index_2], points[index_3], points[index_4] + ] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box, classes): + ''' + box_score_fast: use bbox mean score as the mean score + ''' + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + + if classes is None: + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], None, None + else: + k = 999 + class_mask = np.full((ymax - ymin + 1, xmax - xmin + 1), k, dtype=np.int32) + + cv2.fillPoly(class_mask, box.reshape(1, -1, 2).astype(np.int32), 0) + classes = classes[ymin:ymax + 1, xmin:xmax + 1] + + new_classes = classes + class_mask + a = new_classes.reshape(-1) + b = np.where(a >= k) + classes = np.delete(a, b[0].tolist()) + + class_index = np.argmax(np.bincount(classes)) + class_score = np.sum(classes == class_index) / len(classes) + + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], class_index, class_score + + def box_score_slow(self, bitmap, contour, classes): + """ + box_score_slow: use polyon mean score as the mean score + """ + h, w = bitmap.shape[:2] + contour = contour.copy() + contour = np.reshape(contour, (-1, 2)) + + xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) + xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) + ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) + ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + + contour[:, 0] = contour[:, 0] - xmin + contour[:, 1] = contour[:, 1] - ymin + + cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) + + if classes is None: + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], None, None + else: + k = 999 + class_mask = np.full((ymax - ymin + 1, xmax - xmin + 1), k, dtype=np.int32) + + cv2.fillPoly(class_mask, contour.reshape(1, -1, 2).astype(np.int32), 0) + classes = classes[ymin:ymax + 1, xmin:xmax + 1] + + new_classes = classes + class_mask + a = new_classes.reshape(-1) + b = np.where(a >= k) + classes = np.delete(a, b[0].tolist()) + + class_index = 
np.argmax(np.bincount(classes)) + class_score = np.sum(classes == class_index) / len(classes) + + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], class_index, class_score + + def __call__(self, outs_dict, shape_list): + pred = outs_dict['maps'] + if isinstance(pred, paddle.Tensor): + pred = pred.numpy() + pred = pred[:, 0, :, :] + segmentation = pred > self.thresh + + if "classes" in outs_dict: + classes = outs_dict['classes'] + if isinstance(classes, paddle.Tensor): + classes = classes.numpy() + classes = classes[:, 0, :, :] + + else: + classes = None + + boxes_batch = [] + for batch_index in range(pred.shape[0]): + src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] + if self.dilation_kernel is not None: + mask = cv2.dilate( + np.array(segmentation[batch_index]).astype(np.uint8), + self.dilation_kernel) + else: + mask = segmentation[batch_index] + + if classes is None: + boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, None, + src_w, src_h) + boxes_batch.append({'points': boxes}) + else: + boxes, scores, class_indexes, class_scores = self.boxes_from_bitmap(pred[batch_index], mask, + classes[batch_index], + src_w, src_h) + boxes_batch.append({'points': boxes, "classes": class_indexes, "class_scores": class_scores}) + + return boxes_batch +``` + +### 4.4. 模型启动 + +在完成上述步骤后我们就可以正常启动训练 + +``` +!python /home/aistudio/work/PaddleOCR/tools/train.py -c /home/aistudio/work/PaddleOCR/configs/det/det_mv3_db.yml +``` + +其他命令: +``` +!python /home/aistudio/work/PaddleOCR/tools/eval.py -c /home/aistudio/work/PaddleOCR/configs/det/det_mv3_db.yml +!python /home/aistudio/work/PaddleOCR/tools/infer_det.py -c /home/aistudio/work/PaddleOCR/configs/det/det_mv3_db.yml +``` +模型推理 +``` +!python /home/aistudio/work/PaddleOCR/tools/infer/predict_det.py --image_dir="/home/aistudio/work/test_img/" --det_model_dir="/home/aistudio/work/PaddleOCR/output/infer" +``` + +## 5 总结 + +1. 分类+检测在一定程度上能够缩短用时,具体的模型选取要根据业务场景恰当选择。 +2. 数据标注需要多次进行测试调整标注方法,一般进行检测模型微调,需要标注至少上百张。 +3. 设置合理的batch_size以及resize大小,同时注意lr设置。 + + +## References + +1 https://github.com/PaddlePaddle/PaddleOCR + +2 https://github.com/PaddlePaddle/PaddleClas + +3 https://blog.csdn.net/YY007H/article/details/124491217 diff --git "a/applications/\346\211\213\345\206\231\346\226\207\345\255\227\350\257\206\345\210\253.md" "b/applications/\346\211\213\345\206\231\346\226\207\345\255\227\350\257\206\345\210\253.md" new file mode 100644 index 0000000..09d1bba --- /dev/null +++ "b/applications/\346\211\213\345\206\231\346\226\207\345\255\227\350\257\206\345\210\253.md" @@ -0,0 +1,251 @@ +# 基于PP-OCRv3的手写文字识别 + +- [1. 项目背景及意义](#1-项目背景及意义) +- [2. 项目内容](#2-项目内容) +- [3. PP-OCRv3识别算法介绍](#3-PP-OCRv3识别算法介绍) +- [4. 安装环境](#4-安装环境) +- [5. 数据准备](#5-数据准备) +- [6. 模型训练](#6-模型训练) + - [6.1 下载预训练模型](#61-下载预训练模型) + - [6.2 修改配置文件](#62-修改配置文件) + - [6.3 开始训练](#63-开始训练) +- [7. 模型评估](#7-模型评估) +- [8. 模型导出推理](#8-模型导出推理) + - [8.1 模型导出](#81-模型导出) + - [8.2 模型推理](#82-模型推理) + + +## 1. 项目背景及意义 +目前光学字符识别(OCR)技术在我们的生活当中被广泛使用,但是大多数模型在通用场景下的准确性还有待提高。针对于此我们借助飞桨提供的PaddleOCR套件较容易的实现了在垂类场景下的应用。手写体在日常生活中较为常见,然而手写体的识别却存在着很大的挑战,因为每个人的手写字体风格不一样,这对于视觉模型来说还是相当有挑战的。因此训练一个手写体识别模型具有很好的现实意义。下面给出一些手写体的示例图: + +![example](https://ai-studio-static-online.cdn.bcebos.com/7a8865b2836f42d382e7c3fdaedc4d307d797fa2bcd0466e9f8b7705efff5a7b) + +## 2. 项目内容 +本项目基于PaddleOCR套件,以PP-OCRv3识别模型为基础,针对手写文字识别场景进行优化。 + +Aistudio项目链接:[OCR手写文字识别](https://aistudio.baidu.com/aistudio/projectdetail/4330587) + +## 3. 
PP-OCRv3识别算法介绍 +PP-OCRv3的识别模块是基于文本识别算法[SVTR](https://arxiv.org/abs/2205.00159)优化。SVTR不再采用RNN结构,通过引入Transformers结构更加有效地挖掘文本行图像的上下文信息,从而提升文本识别能力。如下图所示,PP-OCRv3采用了6个优化策略。 + +![v3_rec](https://ai-studio-static-online.cdn.bcebos.com/d4f5344b5b854d50be738671598a89a45689c6704c4d481fb904dd7cf72f2a1a) + +优化策略汇总如下: + +* SVTR_LCNet:轻量级文本识别网络 +* GTC:Attention指导CTC训练策略 +* TextConAug:挖掘文字上下文信息的数据增广策略 +* TextRotNet:自监督的预训练模型 +* UDML:联合互学习策略 +* UIM:无标注数据挖掘方案 + +详细优化策略描述请参考[PP-OCRv3优化策略](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/PP-OCRv3_introduction.md#3-%E8%AF%86%E5%88%AB%E4%BC%98%E5%8C%96) + +## 4. 安装环境 + + +```python +# 首先git官方的PaddleOCR项目,安装需要的依赖 +git clone https://github.com/PaddlePaddle/PaddleOCR.git +cd PaddleOCR +pip install -r requirements.txt +``` + +## 5. 数据准备 +本项目使用公开的手写文本识别数据集,包含Chinese OCR, 中科院自动化研究所-手写中文数据集[CASIA-HWDB2.x](http://www.nlpr.ia.ac.cn/databases/handwriting/Download.html),以及由中科院手写数据和网上开源数据合并组合的[数据集](https://aistudio.baidu.com/aistudio/datasetdetail/102884/0)等,该项目已经挂载处理好的数据集,可直接下载使用进行训练。 + + +```python +下载并解压数据 +tar -xf hw_data.tar +``` + +## 6. 模型训练 +### 6.1 下载预训练模型 +首先需要下载我们需要的PP-OCRv3识别预训练模型,更多选择请自行选择其他的[文字识别模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/models_list.md#2-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B) + + +```python +# 使用该指令下载需要的预训练模型 +wget -P ./pretrained_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar +# 解压预训练模型文件 +tar -xf ./pretrained_models/ch_PP-OCRv3_rec_train.tar -C pretrained_models +``` + +### 6.2 修改配置文件 +我们使用`configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml`,主要修改训练轮数和学习率参相关参数,设置预训练模型路径,设置数据集路径。 另外,batch_size可根据自己机器显存大小进行调整。 具体修改如下几个地方: + +``` + epoch_num: 100 # 训练epoch数 + save_model_dir: ./output/ch_PP-OCR_v3_rec + save_epoch_step: 10 + eval_batch_step: [0, 100] # 评估间隔,每隔100step评估一次 + pretrained_model: ./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy # 预训练模型路径 + + + lr: + name: Cosine # 修改学习率衰减策略为Cosine + learning_rate: 0.0001 # 修改fine-tune的学习率 + warmup_epoch: 2 # 修改warmup轮数 + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data # 训练集图片路径 + ext_op_transform_idx: 1 + label_file_list: + - ./train_data/chineseocr-data/rec_hand_line_all_label_train.txt # 训练集标签 + - ./train_data/handwrite/HWDB2.0Train_label.txt + - ./train_data/handwrite/HWDB2.1Train_label.txt + - ./train_data/handwrite/HWDB2.2Train_label.txt + - ./train_data/handwrite/hwdb_ic13/handwriting_hwdb_train_labels.txt + - ./train_data/handwrite/HW_Chinese/train_hw.txt + ratio_list: + - 0.1 + - 1.0 + - 1.0 + - 1.0 + - 0.02 + - 1.0 + loader: + shuffle: true + batch_size_per_card: 64 + drop_last: true + num_workers: 4 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data # 测试集图片路径 + label_file_list: + - ./train_data/chineseocr-data/rec_hand_line_all_label_val.txt # 测试集标签 + - ./train_data/handwrite/HWDB2.0Test_label.txt + - ./train_data/handwrite/HWDB2.1Test_label.txt + - ./train_data/handwrite/HWDB2.2Test_label.txt + - ./train_data/handwrite/hwdb_ic13/handwriting_hwdb_val_labels.txt + - ./train_data/handwrite/HW_Chinese/test_hw.txt + loader: + shuffle: false + drop_last: false + batch_size_per_card: 64 + num_workers: 4 +``` +由于数据集大多是长文本,因此需要**注释**掉下面的数据增广策略,以便训练出更好的模型。 +``` +- RecConAug: + prob: 0.5 + ext_data_num: 2 + image_shape: [48, 320, 3] +``` + + +### 6.3 开始训练 +我们使用上面修改好的配置文件`configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml`,预训练模型,数据集路径,学习率,训练轮数等都已经设置完毕后,可以使用下面命令开始训练。 + + +```python +# 开始训练识别模型 +python tools/train.py -c 
configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml + +``` + +## 7. 模型评估 +在训练之前,我们可以直接使用下面命令来评估预训练模型的效果: + + + +```python +# 评估预训练模型 +python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy" +``` +``` +[2022/07/14 10:46:22] ppocr INFO: load pretrain successful from ./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy +eval model:: 100%|████████████████████████████| 687/687 [03:29<00:00, 3.27it/s] +[2022/07/14 10:49:52] ppocr INFO: metric eval *************** +[2022/07/14 10:49:52] ppocr INFO: acc:0.03724954461811258 +[2022/07/14 10:49:52] ppocr INFO: norm_edit_dis:0.4859541065843199 +[2022/07/14 10:49:52] ppocr INFO: Teacher_acc:0.0371584699368947 +[2022/07/14 10:49:52] ppocr INFO: Teacher_norm_edit_dis:0.48718814890536477 +[2022/07/14 10:49:52] ppocr INFO: fps:947.8562684823883 +``` + +可以看出,直接加载预训练模型进行评估,效果较差,因为预训练模型并不是基于手写文字进行单独训练的,所以我们需要基于预训练模型进行finetune。 +训练完成后,可以进行测试评估,评估命令如下: + + + +```python +# 评估finetune效果 +python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_rec/best_accuracy" + +``` + +评估结果如下,可以看出识别准确率为54.3%。 +``` +[2022/07/14 10:54:06] ppocr INFO: metric eval *************** +[2022/07/14 10:54:06] ppocr INFO: acc:0.5430100180913 +[2022/07/14 10:54:06] ppocr INFO: norm_edit_dis:0.9203322593158589 +[2022/07/14 10:54:06] ppocr INFO: Teacher_acc:0.5401183969626324 +[2022/07/14 10:54:06] ppocr INFO: Teacher_norm_edit_dis:0.919827504507755 +[2022/07/14 10:54:06] ppocr INFO: fps:928.948733797251 +``` + +如需获取已训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 +
+ +
+将下载或训练完成的模型放置在对应目录下即可完成模型推理。 + +## 8. 模型导出推理 +训练完成后,可以将训练模型转换成inference模型。inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。 + + +### 8.1 模型导出 +导出命令如下: + + + +```python +# 转化为推理模型 +python tools/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_rec/best_accuracy" Global.save_inference_dir="./inference/rec_ppocrv3/" + +``` + +### 8.2 模型推理 +导出模型后,可以使用如下命令进行推理预测: + + + +```python +# 推理预测 +python tools/infer/predict_rec.py --image_dir="train_data/handwrite/HWDB2.0Test_images/104-P16_4.jpg" --rec_model_dir="./inference/rec_ppocrv3/Student" +``` + +``` +[2022/07/14 10:55:56] ppocr INFO: In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320 +[2022/07/14 10:55:58] ppocr INFO: Predicts of train_data/handwrite/HWDB2.0Test_images/104-P16_4.jpg:('品结构,差异化的多品牌渗透使欧莱雅确立了其在中国化妆', 0.9904912114143372) +``` + + +```python +# 可视化文字识别图片 +from PIL import Image +import matplotlib.pyplot as plt +import numpy as np +import os + + +img_path = 'train_data/handwrite/HWDB2.0Test_images/104-P16_4.jpg' + +def vis(img_path): + plt.figure() + image = Image.open(img_path) + plt.imshow(image) + plt.show() + # image = image.resize([208, 208]) + + +vis(img_path) +``` + + +![res](https://ai-studio-static-online.cdn.bcebos.com/ad7c02745491498d82e0ce95f4a274f9b3920b2f467646858709359b7af9d869) diff --git "a/applications/\346\211\253\346\217\217\345\220\210\345\220\214\345\205\263\351\224\256\344\277\241\346\201\257\346\217\220\345\217\226.md" "b/applications/\346\211\253\346\217\217\345\220\210\345\220\214\345\205\263\351\224\256\344\277\241\346\201\257\346\217\220\345\217\226.md" new file mode 100644 index 0000000..26c64a3 --- /dev/null +++ "b/applications/\346\211\253\346\217\217\345\220\210\345\220\214\345\205\263\351\224\256\344\277\241\346\201\257\346\217\220\345\217\226.md" @@ -0,0 +1,284 @@ +# 金融智能核验:扫描合同关键信息抽取 + +本案例将使用OCR技术和通用信息抽取技术,实现合同关键信息审核和比对。通过本章的学习,你可以快速掌握: + +1. 使用PaddleOCR提取扫描文本内容 +2. 使用PaddleNLP抽取自定义信息 + +点击进入 [AI Studio 项目](https://aistudio.baidu.com/aistudio/projectdetail/4545772) + +## 1. 项目背景 +合同审核广泛应用于大中型企业、上市公司、证券、基金公司中,是规避风险的重要任务。 +- 合同内容对比:合同审核场景中,快速找出不同版本合同修改区域、版本差异;如合同盖章归档场景中有效识别实际签署的纸质合同、电子版合同差异。 + +- 合规性检查:法务人员进行合同审核,如合同完备性检查、大小写金额检查、签约主体一致性检查、双方权利和义务对等性分析等。 + +- 风险点识别:通过合同审核可识别事实倾向型风险点和数值计算型风险点等,例如交付地点约定不明、合同总价款不一致、重要条款缺失等风险点。 + + +![](https://ai-studio-static-online.cdn.bcebos.com/d5143df967fa4364a38868793fe7c57b0c0b1213930243babd6ae01423dcbc4d) + +传统业务中大多使用人工进行纸质版合同审核,存在成本高,工作量大,效率低的问题,且一旦出错将造成巨额损失。 + + +本项目针对以上场景,使用PaddleOCR+PaddleNLP快速提取文本内容,经过少量数据微调即可准确抽取关键信息,**高效完成合同内容对比、合规性检查、风险点识别等任务,提高效率,降低风险**。 + +![](https://ai-studio-static-online.cdn.bcebos.com/54f3053e6e1b47a39b26e757006fe2c44910d60a3809422ab76c25396b92e69b) + + +## 2. 
解决方案 + +### 2.1 扫描合同文本内容提取 + +使用PaddleOCR开源的模型可以快速完成扫描文档的文本内容提取,在清晰文档上识别准确率可达到95%+。下面来快速体验一下: + +#### 2.1.1 环境准备 + +[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)提供了适用于通用场景的高精轻量模型,提供数据预处理-模型推理-后处理全流程,支持pip安装: + +``` +python -m pip install paddleocr +``` + +#### 2.1.2 效果测试 + +使用一张合同图片作为测试样本,感受ppocrv3模型效果: + + + +使用中文检测+识别模型提取文本,实例化PaddleOCR类: + +``` +from paddleocr import PaddleOCR, draw_ocr + +# paddleocr目前支持中英文、英文、法语、德语、韩语、日语等80个语种,可以通过修改lang参数进行切换 +ocr = PaddleOCR(use_angle_cls=False, lang="ch") # need to run only once to download and load model into memory +``` + +一行命令启动预测,预测结果包括`检测框`和`文本识别内容`: + +``` +img_path = "./test_img/hetong2.jpg" +result = ocr.ocr(img_path, cls=False) +for line in result: + print(line) + +# 可视化结果 +from PIL import Image + +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.show() +``` + +#### 2.1.3 图片预处理 + +通过上图可视化结果可以看到,印章部分造成的文本遮盖,影响了文本识别结果,因此可以考虑通道提取,去除图片中的红色印章: + +``` +import cv2 +import numpy as np +import matplotlib.pyplot as plt + +#读入图像,三通道 +image=cv2.imread("./test_img/hetong2.jpg",cv2.IMREAD_COLOR) #timg.jpeg + +#获得三个通道 +Bch,Gch,Rch=cv2.split(image) + +#保存三通道图片 +cv2.imwrite('blue_channel.jpg',Bch) +cv2.imwrite('green_channel.jpg',Gch) +cv2.imwrite('red_channel.jpg',Rch) +``` +#### 2.1.4 合同文本信息提取 + +经过2.1.3的预处理后,合同照片的红色通道被分离,获得了一张相对更干净的图片,此时可以再次使用ppocr模型提取文本内容: + +``` +import numpy as np +import cv2 + + +img_path = './red_channel.jpg' +result = ocr.ocr(img_path, cls=False) + +# 可视化结果 +from PIL import Image + +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf') +im_show = Image.fromarray(im_show) +vis = np.array(im_show) +im_show.show() +``` + +忽略检测框内容,提取完整的合同文本: + +``` +txts = [line[1][0] for line in result] +all_context = "\n".join(txts) +print(all_context) +``` + +通过以上环节就完成了扫描合同关键信息抽取的第一步:文本内容提取,接下来可以基于识别出的文本内容抽取关键信息 + +### 2.2 合同关键信息抽取 + +#### 2.2.1 环境准备 + +安装PaddleNLP + + +``` +pip install --upgrade pip +pip install --upgrade paddlenlp +``` + +#### 2.2.2 合同关键信息抽取 + +PaddleNLP 使用 Taskflow 统一管理多场景任务的预测功能,其中`information_extraction` 通过大量的有标签样本进行训练,在通用的场景中一般可以直接使用,只需更换关键字即可。例如在合同信息抽取中,我们重新定义抽取关键字: + +甲方、乙方、币种、金额、付款方式 + + +将使用OCR提取好的文本作为输入,使用三行命令可以对上文中提取到的合同文本进行关键信息抽取: + +``` +from paddlenlp import Taskflow +schema = ["甲方","乙方","总价"] +ie = Taskflow('information_extraction', schema=schema) +ie.set_schema(schema) +ie(all_context) +``` + +可以看到UIE模型可以准确的提取出关键信息,用于后续的信息比对或审核。 + +## 3.效果优化 + +### 3.1 文本识别后处理调优 + +实际图片采集过程中,可能出现部分图片弯曲等问题,导致使用默认参数识别文本时存在漏检,影响关键信息获取。 + +例如下图: + + + + +直接进行预测: + +``` +img_path = "./test_img/hetong3.jpg" +# 预测结果 +result = ocr.ocr(img_path, cls=False) +# 可视化结果 +from PIL import Image + +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.show() +``` + +可视化结果可以看到,弯曲图片存在漏检,一般来说可以通过调整后处理参数解决,无需重新训练模型。漏检问题往往是因为检测模型获得的分割图太小,生成框的得分过低被过滤掉了,通常有两种方式调整参数: +- 开启`use_dilatiion=True` 膨胀分割区域 +- 调小`det_db_box_thresh`阈值 + +``` +# 重新实例化 PaddleOCR +ocr = PaddleOCR(use_angle_cls=False, lang="ch", 
det_db_box_thresh=0.3, use_dilation=True) + +# 预测并可视化 +img_path = "./test_img/hetong3.jpg" +# 预测结果 +result = ocr.ocr(img_path, cls=False) +# 可视化结果 +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.show() +``` + +可以看到漏检问题被很好的解决,提取完整的文本内容: + +``` +txts = [line[1][0] for line in result] +context = "\n".join(txts) +print(context) +``` + +### 3.2 关键信息提取调优 + +UIE通过大量有标签样本进行训练,得到了一个开箱即用的高精模型。 然而针对不同场景,可能会出现部分实体无法被抽取的情况。通常来说有以下几个方法进行效果调优: + + +- 修改 schema +- 添加正则方法 +- 标注小样本微调模型 + +**修改schema** + +Prompt和原文描述越像,抽取效果越好,例如 +``` +三:合同价格:总价为人民币大写:参拾玖万捌仟伍佰 +元,小写:398500.00元。总价中包括站房工程建设、安装 +及相关避雷、消防、接地、电力、材料费、检验费、安全、 +验收等所需费用及其他相关费用和税金。 +``` +schema = ["总金额"] 时无法准确抽取,与原文描述差异较大。 修改 schema = ["总价"] 再次尝试: + +``` +from paddlenlp import Taskflow +# schema = ["总金额"] +schema = ["总价"] +ie = Taskflow('information_extraction', schema=schema) +ie.set_schema(schema) +ie(all_context) +``` + + +**模型微调** + +UIE的建模方式主要是通过 `Prompt` 方式来建模, `Prompt` 在小样本上进行微调效果非常有效。详细的数据标注+模型微调步骤可以参考项目: + +[PaddleNLP信息抽取技术重磅升级!](https://aistudio.baidu.com/aistudio/projectdetail/3914778?channelType=0&channel=0) + +[工单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/3914778?contributionType=1) + +[快递单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/4038499?contributionType=1) + + +## 总结 + +扫描合同的关键信息提取可以使用 PaddleOCR + PaddleNLP 组合实现,两个工具均有以下优势: + +* 使用简单:whl包一键安装,3行命令调用 +* 效果领先:优秀的模型效果可覆盖几乎全部的应用场景 +* 调优成本低:OCR模型可通过后处理参数的调整适配略有偏差的扫描文本, UIE模型可以通过极少的标注样本微调,成本很低。 + +## 作业 + +尝试自己解析出 `test_img/homework.png` 扫描合同中的 [甲方、乙方] 关键词: + + + + + + + +更多场景下的垂类模型获取,请扫下图二维码填写问卷,加入PaddleOCR官方交流群获取模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 + + diff --git "a/applications/\346\266\262\346\231\266\345\261\217\350\257\273\346\225\260\350\257\206\345\210\253.md" "b/applications/\346\266\262\346\231\266\345\261\217\350\257\273\346\225\260\350\257\206\345\210\253.md" new file mode 100644 index 0000000..f70fa06 --- /dev/null +++ "b/applications/\346\266\262\346\231\266\345\261\217\350\257\273\346\225\260\350\257\206\345\210\253.md" @@ -0,0 +1,616 @@ +# 基于PP-OCRv3的液晶屏读数识别 + +- [1. 项目背景及意义](#1-项目背景及意义) +- [2. 项目内容](#2-项目内容) +- [3. 安装环境](#3-安装环境) +- [4. 文字检测](#4-文字检测) + - [4.1 PP-OCRv3检测算法介绍](#41-PP-OCRv3检测算法介绍) + - [4.2 数据准备](#42-数据准备) + - [4.3 模型训练](#43-模型训练) + - [4.3.1 预训练模型直接评估](#431-预训练模型直接评估) + - [4.3.2 预训练模型直接finetune](#432-预训练模型直接finetune) + - [4.3.3 基于预训练模型Finetune_student模型](#433-基于预训练模型Finetune_student模型) + - [4.3.4 基于预训练模型Finetune_teacher模型](#434-基于预训练模型Finetune_teacher模型) + - [4.3.5 采用CML蒸馏进一步提升student模型精度](#435-采用CML蒸馏进一步提升student模型精度) + - [4.3.6 模型导出推理](#436-4.3.6-模型导出推理) +- [5. 文字识别](#5-文字识别) + - [5.1 PP-OCRv3识别算法介绍](#51-PP-OCRv3识别算法介绍) + - [5.2 数据准备](#52-数据准备) + - [5.3 模型训练](#53-模型训练) + - [5.4 模型导出推理](#54-模型导出推理) +- [6. 系统串联](#6-系统串联) + - [6.1 后处理](#61-后处理) +- [7. PaddleServing部署](#7-PaddleServing部署) + + +## 1. 项目背景及意义 +目前光学字符识别(OCR)技术在我们的生活当中被广泛使用,但是大多数模型在通用场景下的准确性还有待提高,针对于此我们借助飞桨提供的PaddleOCR套件较容易的实现了在垂类场景下的应用。 + +该项目以国家质量基础(NQI)为准绳,充分利用大数据、云计算、物联网等高新技术,构建覆盖计量端、实验室端、数据端和硬件端的完整计量解决方案,解决传统计量校准中存在的难题,拓宽计量检测服务体系和服务领域;解决无数传接口或数传接口不统一、不公开的计量设备,以及计量设备所处的环境比较恶劣,不适合人工读取数据。通过OCR技术实现远程计量,引领计量行业向智慧计量转型和发展。 + +## 2. 项目内容 +本项目基于PaddleOCR开源套件,以PP-OCRv3检测和识别模型为基础,针对液晶屏读数识别场景进行优化。 + +Aistudio项目链接:[OCR液晶屏读数识别](https://aistudio.baidu.com/aistudio/projectdetail/4080130) + +## 3. 
安装环境 + +```python +# 首先git官方的PaddleOCR项目,安装需要的依赖 +# 第一次运行打开该注释 +# git clone https://gitee.com/PaddlePaddle/PaddleOCR.git +cd PaddleOCR +pip install -r requirements.txt +``` + +## 4. 文字检测 +文本检测的任务是定位出输入图像中的文字区域。近年来学术界关于文本检测的研究非常丰富,一类方法将文本检测视为目标检测中的一个特定场景,基于通用目标检测算法进行改进适配,如TextBoxes[1]基于一阶段目标检测器SSD[2]算法,调整目标框使之适合极端长宽比的文本行,CTPN[3]则是基于Faster RCNN[4]架构改进而来。但是文本检测与目标检测在目标信息以及任务本身上仍存在一些区别,如文本一般长宽比较大,往往呈“条状”,文本行之间可能比较密集,弯曲文本等,因此又衍生了很多专用于文本检测的算法。本项目基于PP-OCRv3算法进行优化。 + +### 4.1 PP-OCRv3检测算法介绍 +PP-OCRv3检测模型是对PP-OCRv2中的CML(Collaborative Mutual Learning) 协同互学习文本检测蒸馏策略进行了升级。如下图所示,CML的核心思想结合了①传统的Teacher指导Student的标准蒸馏与 ②Students网络之间的DML互学习,可以让Students网络互学习的同时,Teacher网络予以指导。PP-OCRv3分别针对教师模型和学生模型进行进一步效果优化。其中,在对教师模型优化时,提出了大感受野的PAN结构LK-PAN和引入了DML(Deep Mutual Learning)蒸馏策略;在对学生模型优化时,提出了残差注意力机制的FPN结构RSE-FPN。 +![](https://ai-studio-static-online.cdn.bcebos.com/c306b2f028364805a55494d435ab553a76cf5ae5dd3f4649a948ea9aeaeb28b8) + +详细优化策略描述请参考[PP-OCRv3优化策略](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/PP-OCRv3_introduction.md#2) + +### 4.2 数据准备 +[计量设备屏幕字符检测数据集](https://aistudio.baidu.com/aistudio/datasetdetail/127845)数据来源于实际项目中各种计量设备的数显屏,以及在网上搜集的一些其他数显屏,包含训练集755张,测试集355张。 + +```python +# 在PaddleOCR下创建新的文件夹train_data +mkdir train_data +# 下载数据集并解压到指定路径下 +unzip icdar2015.zip -d train_data +``` + +```python +# 随机查看文字检测数据集图片 +from PIL import Image +import matplotlib.pyplot as plt +import numpy as np +import os + + +train = './train_data/icdar2015/text_localization/test' +# 从指定目录中选取一张图片 +def get_one_image(train): + plt.figure() + files = os.listdir(train) + n = len(files) + ind = np.random.randint(0,n) + img_dir = os.path.join(train,files[ind]) + image = Image.open(img_dir) + plt.imshow(image) + plt.show() + image = image.resize([208, 208]) + +get_one_image(train) +``` +![det_png](https://ai-studio-static-online.cdn.bcebos.com/0639da09b774458096ae577e82b2c59e89ced6a00f55458f946997ab7472a4f8) + +### 4.3 模型训练 + +#### 4.3.1 预训练模型直接评估 +下载我们需要的PP-OCRv3检测预训练模型,更多选择请自行选择其他的[文字检测模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/models_list.md#1-%E6%96%87%E6%9C%AC%E6%A3%80%E6%B5%8B%E6%A8%A1%E5%9E%8B) + +```python +#使用该指令下载需要的预训练模型 +wget -P ./pretrained_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar +# 解压预训练模型文件 +tar -xf ./pretrained_models/ch_PP-OCRv3_det_distill_train.tar -C pretrained_models +``` + +在训练之前,我们可以直接使用下面命令来评估预训练模型的效果: + +```python +# 评估预训练模型 +python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy" +``` + +结果如下: + +| | 方案 |hmeans| +|---|---------------------------|---| +| 0 | PP-OCRv3中英文超轻量检测预训练模型直接预测 |47.50%| + +#### 4.3.2 预训练模型直接finetune +##### 修改配置文件 +我们使用configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml,主要修改训练轮数和学习率参相关参数,设置预训练模型路径,设置数据集路径。 另外,batch_size可根据自己机器显存大小进行调整。 具体修改如下几个地方: +``` +epoch:100 +save_epoch_step:10 +eval_batch_step:[0, 50] +save_model_dir: ./output/ch_PP-OCR_v3_det/ +pretrained_model: ./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy +learning_rate: 0.00025 +num_workers: 0 # 如果单卡训练,建议将Train和Eval的loader部分的num_workers设置为0,否则会出现`/dev/shm insufficient`的报错 +``` + +##### 开始训练 +使用我们上面修改的配置文件configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml,训练命令如下: + +```python +# 开始训练模型 +python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy +``` + +评估训练好的模型: + +```python +# 评估训练好的模型 +python tools/eval.py 
-c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det/best_accuracy" +``` + +结果如下: +| | 方案 |hmeans| +|---|---------------------------|---| +| 0 | PP-OCRv3中英文超轻量检测预训练模型直接预测 |47.50%| +| 1 | PP-OCRv3中英文超轻量检测预训练模型fintune |65.20%| + +#### 4.3.3 基于预训练模型Finetune_student模型 + +我们使用configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml,主要修改训练轮数和学习率参相关参数,设置预训练模型路径,设置数据集路径。 另外,batch_size可根据自己机器显存大小进行调整。 具体修改如下几个地方: +``` +epoch:100 +save_epoch_step:10 +eval_batch_step:[0, 50] +save_model_dir: ./output/ch_PP-OCR_v3_det_student/ +pretrained_model: ./pretrained_models/ch_PP-OCRv3_det_distill_train/student +learning_rate: 0.00025 +num_workers: 0 # 如果单卡训练,建议将Train和Eval的loader部分的num_workers设置为0,否则会出现`/dev/shm insufficient`的报错 +``` + +训练命令如下: + +```python +python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/student +``` + +评估训练好的模型: + +```python +# 评估训练好的模型 +python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det_student/best_accuracy" +``` + +结果如下: +| | 方案 |hmeans| +|---|---------------------------|---| +| 0 | PP-OCRv3中英文超轻量检测预训练模型直接预测 |47.50%| +| 1 | PP-OCRv3中英文超轻量检测预训练模型fintune |65.20%| +| 2 | PP-OCRv3中英文超轻量检测预训练模型fintune学生模型 |80.00%| + +#### 4.3.4 基于预训练模型Finetune_teacher模型 + +首先需要从提供的预训练模型best_accuracy.pdparams中提取teacher参数,组合成适合dml训练的初始化模型,提取代码如下: + +```python +cd ./pretrained_models/ +# transform teacher params in best_accuracy.pdparams into teacher_dml.paramers +import paddle + +# load pretrained model +all_params = paddle.load("ch_PP-OCRv3_det_distill_train/best_accuracy.pdparams") +# print(all_params.keys()) + +# keep teacher params +t_params = {key[len("Teacher."):]: all_params[key] for key in all_params if "Teacher." in key} + +# print(t_params.keys()) + +s_params = {"Student." + key: t_params[key] for key in t_params} +s2_params = {"Student2." 
+ key: t_params[key] for key in t_params} +s_params = {**s_params, **s2_params} +# print(s_params.keys()) + +paddle.save(s_params, "ch_PP-OCRv3_det_distill_train/teacher_dml.pdparams") + +``` + +我们使用configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml,主要修改训练轮数和学习率参相关参数,设置预训练模型路径,设置数据集路径。 另外,batch_size可根据自己机器显存大小进行调整。 具体修改如下几个地方: +``` +epoch:100 +save_epoch_step:10 +eval_batch_step:[0, 50] +save_model_dir: ./output/ch_PP-OCR_v3_det_teacher/ +pretrained_model: ./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_dml +learning_rate: 0.00025 +num_workers: 0 # 如果单卡训练,建议将Train和Eval的loader部分的num_workers设置为0,否则会出现`/dev/shm insufficient`的报错 +``` + +训练命令如下: + +```python +python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_dml +``` + +评估训练好的模型: + +```python +# 评估训练好的模型 +python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det_teacher/best_accuracy" +``` + +结果如下: +| | 方案 |hmeans| +|---|---------------------------|---| +| 0 | PP-OCRv3中英文超轻量检测预训练模型直接预测 |47.50%| +| 1 | PP-OCRv3中英文超轻量检测预训练模型fintune |65.20%| +| 2 | PP-OCRv3中英文超轻量检测预训练模型fintune学生模型 |80.00%| +| 3 | PP-OCRv3中英文超轻量检测预训练模型fintune教师模型 |84.80%| + +#### 4.3.5 采用CML蒸馏进一步提升student模型精度 + +需要从4.3.3和4.3.4训练得到的best_accuracy.pdparams中提取各自代表student和teacher的参数,组合成适合cml训练的初始化模型,提取代码如下: + +```python +# transform teacher params and student parameters into cml model +import paddle + +all_params = paddle.load("./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy.pdparams") +# print(all_params.keys()) + +t_params = paddle.load("./output/ch_PP-OCR_v3_det_teacher/best_accuracy.pdparams") +# print(t_params.keys()) + +s_params = paddle.load("./output/ch_PP-OCR_v3_det_student/best_accuracy.pdparams") +# print(s_params.keys()) + +for key in all_params: + # teacher is OK + if "Teacher." in key: + new_key = key.replace("Teacher", "Student") + #print("{} >> {}\n".format(key, new_key)) + assert all_params[key].shape == t_params[new_key].shape + all_params[key] = t_params[new_key] + + if "Student." in key: + new_key = key.replace("Student.", "") + #print("{} >> {}\n".format(key, new_key)) + assert all_params[key].shape == s_params[new_key].shape + all_params[key] = s_params[new_key] + + if "Student2." in key: + new_key = key.replace("Student2.", "") + print("{} >> {}\n".format(key, new_key)) + assert all_params[key].shape == s_params[new_key].shape + all_params[key] = s_params[new_key] + +paddle.save(all_params, "./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_cml_student.pdparams") +``` + +训练命令如下: + +```python +python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_cml_student Global.save_model_dir=./output/ch_PP-OCR_v3_det_finetune/ +``` + +评估训练好的模型: + +```python +# 评估训练好的模型 +python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det_finetune/best_accuracy" +``` + +结果如下: +| | 方案 |hmeans| +|---|---------------------------|---| +| 0 | PP-OCRv3中英文超轻量检测预训练模型直接预测 |47.50%| +| 1 | PP-OCRv3中英文超轻量检测预训练模型fintune |65.20%| +| 2 | PP-OCRv3中英文超轻量检测预训练模型fintune学生模型 |80.00%| +| 3 | PP-OCRv3中英文超轻量检测预训练模型fintune教师模型 |84.80%| +| 4 | 基于2和3训练好的模型fintune |82.70%| + +如需获取已训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 +
+ +
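+如果在复现上述CML训练时发现指标明显异常,可以先确认合并得到的初始化模型中是否同时包含 Teacher、Student、Student2 三个分支的参数。下面给出一个简单的检查脚本示意(路径沿用前文,仅供参考):
+
+```python
+# 检查合并后的CML初始化模型是否包含三个分支的参数(示意脚本)
+import paddle
+
+params = paddle.load("./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_cml_student.pdparams")
+for prefix in ["Teacher.", "Student.", "Student2."]:
+    n = len([k for k in params if k.startswith(prefix)])
+    print(prefix, n)  # 三个分支的参数数量均应大于0
+```
+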
+将下载或训练完成的模型放置在对应目录下即可完成模型推理。 + +#### 4.3.6 模型导出推理 +训练完成后,可以将训练模型转换成inference模型。inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。 +##### 4.3.6.1 模型导出 +导出命令如下: + +```python +# 转化为推理模型 +python tools/export_model.py \ +-c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml \ +-o Global.pretrained_model=./output/ch_PP-OCR_v3_det_finetune/best_accuracy \ +-o Global.save_inference_dir="./inference/det_ppocrv3" + +``` + +##### 4.3.6.2 模型推理 +导出模型后,可以使用如下命令进行推理预测: + +```python +# 推理预测 +python tools/infer/predict_det.py --image_dir="train_data/icdar2015/text_localization/test/1.jpg" --det_model_dir="./inference/det_ppocrv3/Student" +``` + +## 5. 文字识别 +文本识别的任务是识别出图像中的文字内容,一般输入来自于文本检测得到的文本框截取出的图像文字区域。文本识别一般可以根据待识别文本形状分为规则文本识别和不规则文本识别两大类。规则文本主要指印刷字体、扫描文本等,文本大致处在水平线位置;不规则文本往往不在水平位置,存在弯曲、遮挡、模糊等问题。不规则文本场景具有很大的挑战性,也是目前文本识别领域的主要研究方向。本项目基于PP-OCRv3算法进行优化。 + +### 5.1 PP-OCRv3识别算法介绍 +PP-OCRv3的识别模块是基于文本识别算法[SVTR](https://arxiv.org/abs/2205.00159)优化。SVTR不再采用RNN结构,通过引入Transformers结构更加有效地挖掘文本行图像的上下文信息,从而提升文本识别能力。如下图所示,PP-OCRv3采用了6个优化策略。 +![](https://ai-studio-static-online.cdn.bcebos.com/d4f5344b5b854d50be738671598a89a45689c6704c4d481fb904dd7cf72f2a1a) + +优化策略汇总如下: +* SVTR_LCNet:轻量级文本识别网络 +* GTC:Attention指导CTC训练策略 +* TextConAug:挖掘文字上下文信息的数据增广策略 +* TextRotNet:自监督的预训练模型 +* UDML:联合互学习策略 +* UIM:无标注数据挖掘方案 + +详细优化策略描述请参考[PP-OCRv3优化策略](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/PP-OCRv3_introduction.md#3-%E8%AF%86%E5%88%AB%E4%BC%98%E5%8C%96) + +### 5.2 数据准备 +[计量设备屏幕字符识别数据集](https://aistudio.baidu.com/aistudio/datasetdetail/128714)数据来源于实际项目中各种计量设备的数显屏,以及在网上搜集的一些其他数显屏,包含训练集19912张,测试集4099张。 + +```python +# 解压下载的数据集到指定路径下 +unzip ic15_data.zip -d train_data +``` + +```python +# 随机查看文字检测数据集图片 +from PIL import Image +import matplotlib.pyplot as plt +import numpy as np +import os + +train = './train_data/ic15_data/train' +# 从指定目录中选取一张图片 +def get_one_image(train): + plt.figure() + files = os.listdir(train) + n = len(files) + ind = np.random.randint(0,n) + img_dir = os.path.join(train,files[ind]) + image = Image.open(img_dir) + plt.imshow(image) + plt.show() + image = image.resize([208, 208]) + +get_one_image(train) +``` + +![rec_png](https://ai-studio-static-online.cdn.bcebos.com/3de0d475c69746d0a184029001ef07c85fd68816d66d4beaa10e6ef60030f9b4) + +### 5.3 模型训练 +#### 下载预训练模型 +下载我们需要的PP-OCRv3识别预训练模型,更多选择请自行选择其他的[文字识别模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/models_list.md#2-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B) + +```python +# 使用该指令下载需要的预训练模型 +wget -P ./pretrained_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar +# 解压预训练模型文件 +tar -xf ./pretrained_models/ch_PP-OCRv3_rec_train.tar -C pretrained_models +``` + +#### 修改配置文件 +我们使用configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml,主要修改训练轮数和学习率参相关参数,设置预训练模型路径,设置数据集路径。 另外,batch_size可根据自己机器显存大小进行调整。 具体修改如下几个地方: +``` + epoch_num: 100 # 训练epoch数 + save_model_dir: ./output/ch_PP-OCR_v3_rec + save_epoch_step: 10 + eval_batch_step: [0, 100] # 评估间隔,每隔100step评估一次 + cal_metric_during_train: true + pretrained_model: ./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy # 预训练模型路径 + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + use_space_char: true # 使用空格 + + lr: + name: Cosine # 修改学习率衰减策略为Cosine + learning_rate: 0.0002 # 修改fine-tune的学习率 + warmup_epoch: 2 # 修改warmup轮数 + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data/ # 训练集图片路径 + ext_op_transform_idx: 1 + label_file_list: + - ./train_data/ic15_data/rec_gt_train.txt # 训练集标签 + ratio_list: + - 1.0 + 
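+  # loader部分:batch_size_per_card 可根据显存大小调整;如出现`/dev/shm insufficient`报错,可适当调小 num_workers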
loader: + shuffle: true + batch_size_per_card: 64 + drop_last: true + num_workers: 4 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data/ # 测试集图片路径 + label_file_list: + - ./train_data/ic15_data/rec_gt_test.txt # 测试集标签 + ratio_list: + - 1.0 + loader: + shuffle: false + drop_last: false + batch_size_per_card: 64 + num_workers: 4 +``` + +在训练之前,我们可以直接使用下面命令来评估预训练模型的效果: + +```python +# 评估预训练模型 +python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy" +``` + +结果如下: +| | 方案 |accuracy| +|---|---------------------------|---| +| 0 | PP-OCRv3中英文超轻量识别预训练模型直接预测 |70.40%| + +#### 开始训练 +我们使用上面修改好的配置文件configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml,预训练模型,数据集路径,学习率,训练轮数等都已经设置完毕后,可以使用下面命令开始训练。 + +```python +# 开始训练识别模型 +python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml +``` + +训练完成后,可以对训练模型中最好的进行测试,评估命令如下: + +```python +# 评估finetune效果 +python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.checkpoints="./output/ch_PP-OCR_v3_rec/best_accuracy" +``` + +结果如下: +| | 方案 |accuracy| +|---|---------------------------|---| +| 0 | PP-OCRv3中英文超轻量识别预训练模型直接预测 |70.40%| +| 1 | PP-OCRv3中英文超轻量识别预训练模型finetune |82.20%| + +如需获取已训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 +
+ +
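+在导出推理模型之前,也可以先用 tools/infer_rec.py 对单张测试图片快速验证训练好的识别模型效果,下面是一条示意命令(图片路径请替换为实际样本):
+
+```python
+# 快速验证训练好的识别模型(示意命令)
+python tools/infer_rec.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.checkpoints="./output/ch_PP-OCR_v3_rec/best_accuracy" Global.infer_img="./train_data/ic15_data/test/1_crop_0.jpg"
+```
+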
+将下载或训练完成的模型放置在对应目录下即可完成模型推理。 + +### 5.4 模型导出推理 +训练完成后,可以将训练模型转换成inference模型。inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。 +#### 模型导出 +导出命令如下: + +```python +# 转化为推理模型 +python tools/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_rec/best_accuracy" Global.save_inference_dir="./inference/rec_ppocrv3/" +``` + +#### 模型推理 +导出模型后,可以使用如下命令进行推理预测 + +```python +# 推理预测 +python tools/infer/predict_rec.py --image_dir="train_data/ic15_data/test/1_crop_0.jpg" --rec_model_dir="./inference/rec_ppocrv3/Student" +``` + +## 6. 系统串联 +我们将上面训练好的检测和识别模型进行系统串联测试,命令如下: + +```python +#串联测试 +python3 tools/infer/predict_system.py --image_dir="./train_data/icdar2015/text_localization/test/142.jpg" --det_model_dir="./inference/det_ppocrv3/Student" --rec_model_dir="./inference/rec_ppocrv3/Student" +``` + +测试结果保存在`./inference_results/`目录下,可以用下面代码进行可视化 + +```python +%cd /home/aistudio/PaddleOCR +# 显示结果 +import matplotlib.pyplot as plt +from PIL import Image +img_path= "./inference_results/142.jpg" +img = Image.open(img_path) +plt.figure("test_img", figsize=(30,30)) +plt.imshow(img) +plt.show() +``` + +![sys_res_png](https://ai-studio-static-online.cdn.bcebos.com/901ab741cb46441ebec510b37e63b9d8d1b7c95f63cc4e5e8757f35179ae6373) + +### 6.1 后处理 +如果需要获取key-value信息,可以基于启发式的规则,将识别结果与关键字库进行匹配;如果匹配上了,则取该字段为key, 后面一个字段为value。 + +```python +def postprocess(rec_res): + keys = ["型号", "厂家", "版本号", "检定校准分类", "计量器具编号", "烟尘流量", + "累积体积", "烟气温度", "动压", "静压", "时间", "试验台编号", "预测流速", + "全压", "烟温", "流速", "工况流量", "标杆流量", "烟尘直读嘴", "烟尘采样嘴", + "大气压", "计前温度", "计前压力", "干球温度", "湿球温度", "流量", "含湿量"] + key_value = [] + if len(rec_res) > 1: + for i in range(len(rec_res) - 1): + rec_str, _ = rec_res[i] + for key in keys: + if rec_str in key: + key_value.append([rec_str, rec_res[i + 1][0]]) + break + return key_value +key_value = postprocess(filter_rec_res) +``` + +## 7. 
PaddleServing部署 +首先需要安装PaddleServing部署相关的环境 + +```python +python -m pip install paddle-serving-server-gpu +python -m pip install paddle_serving_client +python -m pip install paddle-serving-app +``` + +### 7.1 转化检测模型 + +```python +cd deploy/pdserving/ +python -m paddle_serving_client.convert --dirname ../../inference/det_ppocrv3/Student/ \ + --model_filename inference.pdmodel \ + --params_filename inference.pdiparams \ + --serving_server ./ppocr_det_v3_serving/ \ + --serving_client ./ppocr_det_v3_client/ +``` + +### 7.2 转化识别模型 + +```python +python -m paddle_serving_client.convert --dirname ../../inference/rec_ppocrv3/Student \ + --model_filename inference.pdmodel \ + --params_filename inference.pdiparams \ + --serving_server ./ppocr_rec_v3_serving/ \ + --serving_client ./ppocr_rec_v3_client/ +``` + + +### 7.3 启动服务 +首先可以将后处理代码加入到web_service.py中,具体修改如下: +``` +# 代码153行后面增加下面代码 +def _postprocess(rec_res): + keys = ["型号", "厂家", "版本号", "检定校准分类", "计量器具编号", "烟尘流量", + "累积体积", "烟气温度", "动压", "静压", "时间", "试验台编号", "预测流速", + "全压", "烟温", "流速", "工况流量", "标杆流量", "烟尘直读嘴", "烟尘采样嘴", + "大气压", "计前温度", "计前压力", "干球温度", "湿球温度", "流量", "含湿量"] + key_value = [] + if len(rec_res) > 1: + for i in range(len(rec_res) - 1): + rec_str, _ = rec_res[i] + for key in keys: + if rec_str in key: + key_value.append([rec_str, rec_res[i + 1][0]]) + break + return key_value +key_value = _postprocess(rec_list) +res = {"result": str(key_value)} +# res = {"result": str(result_list)} +``` + +启动服务端 +```python +python web_service.py 2>&1 >log.txt +``` + +### 7.4 发送请求 +然后再开启一个新的终端,运行下面的客户端代码 + +```python +python pipeline_http_client.py --image_dir ../../train_data/icdar2015/text_localization/test/142.jpg +``` + +可以获取到最终的key-value结果: +``` +大气压, 100.07kPa +干球温度, 0000℃ +计前温度, 0000℃ +湿球温度, 0000℃ +计前压力, -0000kPa +流量, 00.0L/min +静压, 00000kPa +含湿量, 00.0 % +``` diff --git "a/applications/\350\275\273\351\207\217\347\272\247\350\275\246\347\211\214\350\257\206\345\210\253.md" "b/applications/\350\275\273\351\207\217\347\272\247\350\275\246\347\211\214\350\257\206\345\210\253.md" new file mode 100644 index 0000000..c9b76ee --- /dev/null +++ "b/applications/\350\275\273\351\207\217\347\272\247\350\275\246\347\211\214\350\257\206\345\210\253.md" @@ -0,0 +1,832 @@ +# 一种基于PaddleOCR的轻量级车牌识别模型 + +- [1. 项目介绍](#1-项目介绍) +- [2. 环境搭建](#2-环境搭建) +- [3. 数据集准备](#3-数据集准备) + - [3.1 数据集标注规则](#31-数据集标注规则) + - [3.2 制作符合PP-OCR训练格式的标注文件](#32-制作符合pp-ocr训练格式的标注文件) +- [4. 实验](#4-实验) + - [4.1 检测](#41-检测) + - [4.1.1 预训练模型直接预测](#411-预训练模型直接预测) + - [4.1.2 CCPD车牌数据集fine-tune](#412-ccpd车牌数据集fine-tune) + - [4.1.3 CCPD车牌数据集fine-tune+量化训练](#413-ccpd车牌数据集fine-tune量化训练) + - [4.1.4 模型导出](#414-模型导出) + - [4.2 识别](#42-识别) + - [4.2.1 预训练模型直接预测](#421-预训练模型直接预测) + - [4.2.2 预训练模型直接预测+改动后处理](#422-预训练模型直接预测改动后处理) + - [4.2.3 CCPD车牌数据集fine-tune](#423-ccpd车牌数据集fine-tune) + - [4.2.4 CCPD车牌数据集fine-tune+量化训练](#424-ccpd车牌数据集fine-tune量化训练) + - [4.2.5 模型导出](#425-模型导出) + - [4.3 计算End2End指标](#43-计算End2End指标) + - [4.4 部署](#44-部署) + - [4.5 实验总结](#45-实验总结) + +## 1. 项目介绍 + +车牌识别(Vehicle License Plate Recognition,VLPR) 是计算机视频图像识别技术在车辆牌照识别中的一种应用。车牌识别技术要求能够将运动中的汽车牌照从复杂背景中提取并识别出来,在高速公路车辆管理,停车场管理和城市交通中得到广泛应用。 + +本项目难点如下: + +1. 车牌在图像中的尺度差异大、在车辆上的悬挂位置不固定 +2. 车牌图像质量层次不齐: 角度倾斜、图片模糊、光照不足、过曝等问题严重 +3. 边缘和端测场景应用对模型大小有限制,推理速度有要求 + +针对以上问题, 本例选用 PP-OCRv3 这一开源超轻量OCR系统进行车牌识别系统的开发。基于PP-OCRv3模型,在CCPD数据集达到99%的检测和94%的识别精度,模型大小12.8M(2.5M+10.3M)。基于量化对模型体积进行进一步压缩到5.8M(1M+4.8M), 同时推理速度提升25%。 + + + +aistudio项目链接: [基于PaddleOCR的轻量级车牌识别范例](https://aistudio.baidu.com/aistudio/projectdetail/3919091?contributionType=1) + +## 2. 
环境搭建 + +本任务基于Aistudio完成, 具体环境如下: + +- 操作系统: Linux +- PaddlePaddle: 2.3 +- paddleslim: 2.2.2 +- PaddleOCR: Release/2.5 + +下载 PaddleOCR代码 + +```bash +git clone -b dygraph https://github.com/PaddlePaddle/PaddleOCR +``` + +安装依赖库 + +```bash +pip install -r PaddleOCR/requirements.txt +``` + +## 3. 数据集准备 + +所使用的数据集为 CCPD2020 新能源车牌数据集,该数据集为 + +该数据集分布如下: + +|数据集类型|数量| +|---|---| +|训练集| 5769| +|验证集| 1001| +|测试集| 5006| + +数据集图片示例如下: +![](https://ai-studio-static-online.cdn.bcebos.com/3bce057a8e0c40a0acbd26b2e29e4e2590a31bc412764be7b9e49799c69cb91c) + +数据集可以从这里下载 https://aistudio.baidu.com/aistudio/datasetdetail/101595 + +下载好数据集后对数据集进行解压 + +```bash +unzip -d /home/aistudio/data /home/aistudio/data/data101595/CCPD2020.zip +``` + +### 3.1 数据集标注规则 + +CPPD数据集的图片文件名具有特殊规则,详细可查看:https://github.com/detectRecog/CCPD + +具体规则如下: + +例如: 025-95_113-154&383_386&473-386&473_177&454_154&383_363&402-0_0_22_27_27_33_16-37-15.jpg + +每个名称可以分为七个字段,以-符号作为分割。这些字段解释如下。 + +- 025:车牌面积与整个图片区域的面积比。025 (25%) + +- 95_113:水平倾斜程度和垂直倾斜度。水平 95度 垂直 113度 + +- 154&383_386&473:左上和右下顶点的坐标。左上(154,383) 右下(386,473) + +- 386&473_177&454_154&383_363&402:整个图像中车牌的四个顶点的精确(x,y)坐标。这些坐标从右下角顶点开始。(386,473) (177,454) (154,383) (363,402) + +- 0_0_22_27_27_33_16:CCPD中的每个图像只有一个车牌。每个车牌号码由一个汉字,一个字母和五个字母或数字组成。有效的中文车牌由七个字符组成:省(1个字符),字母(1个字符),字母+数字(5个字符)。“ 0_0_22_27_27_33_16”是每个字符的索引。这三个数组定义如下。每个数组的最后一个字符是字母O,而不是数字0。我们将O用作“无字符”的符号,因为中文车牌字符中没有O。因此以上车牌拼起来即为 皖AY339S + +- 37:牌照区域的亮度。 37 (37%) + +- 15:车牌区域的模糊度。15 (15%) + +```python +provinces = ["皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑", "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤", "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁", "新", "警", "学", "O"] +alphabets = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W','X', 'Y', 'Z', 'O'] +ads = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X','Y', 'Z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'O'] +``` + +### 3.2 制作符合PP-OCR训练格式的标注文件 + +在开始训练之前,可使用如下代码制作符合PP-OCR训练格式的标注文件。 + + +```python +import cv2 +import os +import json +from tqdm import tqdm +import numpy as np + +provinces = ["皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑", "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤", "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁", "新", "警", "学", "O"] +alphabets = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'O'] +ads = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'O'] + +def make_label(img_dir, save_gt_folder, phase): + crop_img_save_dir = os.path.join(save_gt_folder, phase, 'crop_imgs') + os.makedirs(crop_img_save_dir, exist_ok=True) + + f_det = open(os.path.join(save_gt_folder, phase, 'det.txt'), 'w', encoding='utf-8') + f_rec = open(os.path.join(save_gt_folder, phase, 'rec.txt'), 'w', encoding='utf-8') + + i = 0 + for filename in tqdm(os.listdir(os.path.join(img_dir, phase))): + str_list = filename.split('-') + if len(str_list) < 5: + continue + coord_list = str_list[3].split('_') + txt_list = str_list[4].split('_') + boxes = [] + for coord in coord_list: + boxes.append([int(x) for x in coord.split("&")]) + boxes = [boxes[2], boxes[3], boxes[0], boxes[1]] + lp_number = provinces[int(txt_list[0])] + alphabets[int(txt_list[1])] + ''.join([ads[int(x)] for x in txt_list[2:]]) + + # det + det_info = [{'points':boxes, 'transcription':lp_number}] 
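+        # 检测标注文件 det.txt 每行格式为: 图片相对路径\t标注JSON
+        # 标注JSON是一个列表,每个元素包含 points(车牌四点坐标)和 transcription(车牌号文本)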
+ f_det.write('{}\t{}\n'.format(os.path.join(phase, filename), json.dumps(det_info, ensure_ascii=False))) + + # rec + boxes = np.float32(boxes) + img = cv2.imread(os.path.join(img_dir, phase, filename)) + # crop_img = img[int(boxes[:,1].min()):int(boxes[:,1].max()),int(boxes[:,0].min()):int(boxes[:,0].max())] + crop_img = get_rotate_crop_image(img, boxes) + crop_img_save_filename = '{}_{}.jpg'.format(i,'_'.join(txt_list)) + crop_img_save_path = os.path.join(crop_img_save_dir, crop_img_save_filename) + cv2.imwrite(crop_img_save_path, crop_img) + f_rec.write('{}/crop_imgs/{}\t{}\n'.format(phase, crop_img_save_filename, lp_number)) + i+=1 + f_det.close() + f_rec.close() + +def get_rotate_crop_image(img, points): + ''' + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + ''' + assert len(points) == 4, "shape of points must be 4*2" + img_crop_width = int( + max( + np.linalg.norm(points[0] - points[1]), + np.linalg.norm(points[2] - points[3]))) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]))) + pts_std = np.float32([[0, 0], [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + +img_dir = '/home/aistudio/data/CCPD2020/ccpd_green' +save_gt_folder = '/home/aistudio/data/CCPD2020/PPOCR' +# phase = 'train' # change to val and test to make val dataset and test dataset +for phase in ['train','val','test']: + make_label(img_dir, save_gt_folder, phase) +``` + +通过上述命令可以完成了`训练集`,`验证集`和`测试集`的制作,制作完成的数据集信息如下: + +| 类型 | 数据集 | 图片地址 | 标签地址 | 图片数量 | +| --- | --- | --- | --- | --- | +| 检测 | 训练集 | /home/aistudio/data/CCPD2020/ccpd_green/train | /home/aistudio/data/CCPD2020/PPOCR/train/det.txt | 5769 | +| 检测 | 验证集 | /home/aistudio/data/CCPD2020/ccpd_green/val | /home/aistudio/data/CCPD2020/PPOCR/val/det.txt | 1001 | +| 检测 | 测试集 | /home/aistudio/data/CCPD2020/ccpd_green/test | /home/aistudio/data/CCPD2020/PPOCR/test/det.txt | 5006 | +| 识别 | 训练集 | /home/aistudio/data/CCPD2020/PPOCR/train/crop_imgs | /home/aistudio/data/CCPD2020/PPOCR/train/rec.txt | 5769 | +| 识别 | 验证集 | /home/aistudio/data/CCPD2020/PPOCR/val/crop_imgs | /home/aistudio/data/CCPD2020/PPOCR/val/rec.txt | 1001 | +| 识别 | 测试集 | /home/aistudio/data/CCPD2020/PPOCR/test/crop_imgs | /home/aistudio/data/CCPD2020/PPOCR/test/rec.txt | 5006 | + +在普遍的深度学习流程中,都是在训练集训练,在验证集选择最优模型后在测试集上进行测试。在本例中,我们省略中间步骤,直接在训练集训练,在测试集选择最优模型,因此我们只使用训练集和测试集。 + +## 4. 实验 + +由于数据集比较少,为了模型更好和更快的收敛,这里选用 PaddleOCR 中的 PP-OCRv3 模型进行文本检测和识别,并且使用 PP-OCRv3 模型参数作为预训练模型。PP-OCRv3在PP-OCRv2的基础上,中文场景端到端Hmean指标相比于PP-OCRv2提升5%, 英文数字模型端到端效果提升11%。详细优化细节请参考[PP-OCRv3](../doc/doc_ch/PP-OCRv3_introduction.md)技术报告。 + +由于车牌场景均为端侧设备部署,因此对速度和模型大小有比较高的要求,因此还需要采用量化训练的方式进行模型大小的压缩和模型推理速度的加速。模型量化可以在基本不损失模型的精度的情况下,将FP32精度的模型参数转换为Int8精度,减小模型参数大小并加速计算,使用量化后的模型在移动端等部署时更具备速度优势。 + +因此,本实验中对于车牌检测和识别有如下3种方案: + +1. PP-OCRv3中英文超轻量预训练模型直接预测 +2. CCPD车牌数据集在PP-OCRv3模型上fine-tune +3. 
CCPD车牌数据集在PP-OCRv3模型上fine-tune后量化 + +### 4.1 检测 +#### 4.1.1 预训练模型直接预测 + +从下表中下载PP-OCRv3文本检测预训练模型 + +|模型名称|模型简介|配置文件|推理模型大小|下载地址| +| --- | --- | --- | --- | --- | +|ch_PP-OCRv3_det| 【最新】原始超轻量模型,支持中英文、多语种文本检测 |[ch_PP-OCRv3_det_cml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml)| 3.8M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar)| + +使用如下命令下载预训练模型 + +```bash +mkdir models +cd models +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar +tar -xf ch_PP-OCRv3_det_distill_train.tar +cd /home/aistudio/PaddleOCR +``` + +预训练模型下载完成后,我们使用[ch_PP-OCRv3_det_student.yml](../configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml) 配置文件进行后续实验,在开始评估之前需要对配置文件中部分字段进行设置,具体如下: + +1. 模型存储和训练相关: + 1. Global.pretrained_model: 指向PP-OCRv3文本检测预训练模型地址 +2. 数据集相关 + 1. Eval.dataset.data_dir:指向测试集图片存放目录 + 2. Eval.dataset.label_file_list:指向测试集标注文件 + +上述字段均为必须修改的字段,可以通过修改配置文件的方式改动,也可在不需要修改配置文件的情况下,改变训练的参数。这里使用不改变配置文件的方式 。使用如下命令进行PP-OCRv3文本检测预训练模型的评估 + + +```bash +python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o \ + Global.pretrained_model=models/ch_PP-OCRv3_det_distill_train/student.pdparams \ + Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/ccpd_green \ + Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/det.txt] +``` +上述指令中,通过-c 选择训练使用配置文件,通过-o参数在不需要修改配置文件的情况下,改变训练的参数。 + +使用预训练模型进行评估,指标如下所示: + +| 方案 |hmeans| +|---------------------------|---| +| PP-OCRv3中英文超轻量检测预训练模型直接预测 |76.12%| + +#### 4.1.2 CCPD车牌数据集fine-tune + +**训练** + +为了进行fine-tune训练,我们需要在配置文件中设置需要使用的预训练模型地址,学习率和数据集等参数。 具体如下: + +1. 模型存储和训练相关: + 1. Global.pretrained_model: 指向PP-OCRv3文本检测预训练模型地址 + 2. Global.eval_batch_step: 模型多少step评估一次,这里设为从第0个step开始没隔772个step评估一次,772为一个epoch总的step数。 +2. 优化器相关: + 1. Optimizer.lr.name: 学习率衰减器设为常量 Const + 2. Optimizer.lr.learning_rate: 做 fine-tune 实验,学习率需要设置的比较小,此处学习率设为配置文件中的0.05倍 + 3. Optimizer.lr.warmup_epoch: warmup_epoch设为0 +3. 数据集相关: + 1. Train.dataset.data_dir:指向训练集图片存放目录 + 2. Train.dataset.label_file_list:指向训练集标注文件 + 3. Eval.dataset.data_dir:指向测试集图片存放目录 + 4. 
Eval.dataset.label_file_list:指向测试集标注文件 + +使用如下代码即可启动在CCPD车牌数据集上的fine-tune。 + +```bash +python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o \ + Global.pretrained_model=models/ch_PP-OCRv3_det_distill_train/student.pdparams \ + Global.save_model_dir=output/CCPD/det \ + Global.eval_batch_step="[0, 772]" \ + Optimizer.lr.name=Const \ + Optimizer.lr.learning_rate=0.0005 \ + Optimizer.lr.warmup_epoch=0 \ + Train.dataset.data_dir=/home/aistudio/data/CCPD2020/ccpd_green \ + Train.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/train/det.txt] \ + Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/ccpd_green \ + Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/det.txt] +``` + +在上述命令中,通过`-o`的方式修改了配置文件中的参数。 + + +**评估** + +训练完成后使用如下命令进行评估 + + +```bash +python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o \ + Global.pretrained_model=output/CCPD/det/best_accuracy.pdparams \ + Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/ccpd_green \ + Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/det.txt] +``` + +使用预训练模型和CCPD车牌数据集fine-tune,指标分别如下: + +|方案|hmeans| +|---|---| +|PP-OCRv3中英文超轻量检测预训练模型直接预测|76.12%| +|PP-OCRv3中英文超轻量检测预训练模型 fine-tune|99.00%| + +可以看到进行fine-tune能显著提升车牌检测的效果。 + +#### 4.1.3 CCPD车牌数据集fine-tune+量化训练 + +此处采用 PaddleOCR 中提供好的[量化教程](../deploy/slim/quantization/README.md)对模型进行量化训练。 + +量化训练可通过如下命令启动: + +```bash +python3.7 deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o \ + Global.pretrained_model=output/CCPD/det/best_accuracy.pdparams \ + Global.save_model_dir=output/CCPD/det_quant \ + Global.eval_batch_step="[0, 772]" \ + Optimizer.lr.name=Const \ + Optimizer.lr.learning_rate=0.0005 \ + Optimizer.lr.warmup_epoch=0 \ + Train.dataset.data_dir=/home/aistudio/data/CCPD2020/ccpd_green \ + Train.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/train/det.txt] \ + Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/ccpd_green \ + Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/det.txt] +``` + +量化后指标对比如下 + +|方案|hmeans| 模型大小 | 预测速度(lite) | +|---|---|------|------------| +|PP-OCRv3中英文超轻量检测预训练模型 fine-tune|99.00%| 2.5M | 223ms | +|PP-OCRv3中英文超轻量检测预训练模型 fine-tune+量化|98.91%| 1.0M | 189ms | + +可以看到通过量化训练在精度几乎无损的情况下,降低模型体积60%并且推理速度提升15%。 + +速度测试基于[PaddleOCR lite教程](../deploy/lite/readme_ch.md)完成。 + +#### 4.1.4 模型导出 + +使用如下命令可以将训练好的模型进行导出 + +* 非量化模型 +```bash +python tools/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o \ + Global.pretrained_model=output/CCPD/det/best_accuracy.pdparams \ + Global.save_inference_dir=output/det/infer +``` +* 量化模型 +```bash +python deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o \ + Global.pretrained_model=output/CCPD/det_quant/best_accuracy.pdparams \ + Global.save_inference_dir=output/det/infer +``` + +### 4.2 识别 +#### 4.2.1 预训练模型直接预测 + +从下表中下载PP-OCRv3文本识别预训练模型 + +|模型名称|模型简介|配置文件|推理模型大小|下载地址| +| --- | --- | --- | --- | --- | +|ch_PP-OCRv3_rec|【最新】原始超轻量模型,支持中英文、数字识别|[ch_PP-OCRv3_rec_distillation.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml)| 12.4M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar) | + +使用如下命令下载预训练模型 + +```bash +mkdir models +cd models +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar +tar -xf 
ch_PP-OCRv3_rec_train.tar +cd /home/aistudio/PaddleOCR +``` + +PaddleOCR提供的PP-OCRv3识别模型采用蒸馏训练策略,因此提供的预训练模型中会包含`Teacher`和`Student`模型的参数,详细信息可参考[knowledge_distillation.md](../doc/doc_ch/knowledge_distillation.md)。 因此,模型下载完成后需要使用如下代码提取`Student`模型的参数: + +```python +import paddle +# 加载预训练模型 +all_params = paddle.load("models/ch_PP-OCRv3_rec_train/best_accuracy.pdparams") +# 查看权重参数的keys +print(all_params.keys()) +# 学生模型的权重提取 +s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key} +# 查看学生模型权重参数的keys +print(s_params.keys()) +# 保存 +paddle.save(s_params, "models/ch_PP-OCRv3_rec_train/student.pdparams") +``` + +预训练模型下载完成后,我们使用[ch_PP-OCRv3_rec.yml](../configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml) 配置文件进行后续实验,在开始评估之前需要对配置文件中部分字段进行设置,具体如下: + +1. 模型存储和训练相关: + 1. Global.pretrained_model: 指向PP-OCRv3文本识别预训练模型地址 +2. 数据集相关 + 1. Eval.dataset.data_dir:指向测试集图片存放目录 + 2. Eval.dataset.label_file_list:指向测试集标注文件 + +使用如下命令进行PP-OCRv3文本识别预训练模型的评估 + +```bash +python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ + Global.pretrained_model=models/ch_PP-OCRv3_rec_train/student.pdparams \ + Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/PPOCR \ + Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/rec.txt] +``` + +如需获取已训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 +
+ +
+ + +评估部分日志如下: +```bash +[2022/05/12 19:52:02] ppocr INFO: load pretrain successful from models/ch_PP-OCRv3_rec_train/best_accuracy +eval model:: 100%|██████████████████████████████| 40/40 [00:15<00:00, 2.57it/s] +[2022/05/12 19:52:17] ppocr INFO: metric eval *************** +[2022/05/12 19:52:17] ppocr INFO: acc:0.0 +[2022/05/12 19:52:17] ppocr INFO: norm_edit_dis:0.8656084923002452 +[2022/05/12 19:52:17] ppocr INFO: Teacher_acc:0.000399520574511545 +[2022/05/12 19:52:17] ppocr INFO: Teacher_norm_edit_dis:0.8657902943394548 +[2022/05/12 19:52:17] ppocr INFO: fps:1443.1801978719905 + +``` +使用预训练模型进行评估,指标如下所示: + +|方案|acc| +|---|---| +|PP-OCRv3中英文超轻量识别预训练模型直接预测|0%| + +从评估日志中可以看到,直接使用PP-OCRv3预训练模型进行评估,acc非常低,但是norm_edit_dis很高。因此,我们猜测是模型大部分文字识别是对的,只有少部分文字识别错误。使用如下命令进行infer查看模型的推理结果进行验证: + + +```bash +python tools/infer_rec.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ + Global.pretrained_model=models/ch_PP-OCRv3_rec_train/student.pdparams \ + Global.infer_img=/home/aistudio/data/CCPD2020/PPOCR/test/crop_imgs/0_0_0_3_32_30_31_30_30.jpg +``` + +输出部分日志如下: +```bash +[2022/05/01 08:51:57] ppocr INFO: train with paddle 2.2.2 and device CUDAPlace(0) +W0501 08:51:57.127391 11326 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1 +W0501 08:51:57.132315 11326 device_context.cc:465] device: 0, cuDNN Version: 7.6. +[2022/05/01 08:52:00] ppocr INFO: load pretrain successful from models/ch_PP-OCRv3_rec_train/student +[2022/05/01 08:52:00] ppocr INFO: infer_img: /home/aistudio/data/CCPD2020/PPOCR/test/crop_imgs/0_0_3_32_30_31_30_30.jpg +[2022/05/01 08:52:00] ppocr INFO: result: {"Student": {"label": "皖A·D86766", "score": 0.9552637934684753}, "Teacher": {"label": "皖A·D86766", "score": 0.9917094707489014}} +[2022/05/01 08:52:00] ppocr INFO: success! +``` + +从infer结果可以看到,车牌中的文字大部分都识别正确,只是多识别出了一个`·`。针对这种情况,有如下两种方案: +1. 直接通过后处理去掉多识别的`·`。 +2. 进行 fine-tune。 + +#### 4.2.2 预训练模型直接预测+改动后处理 + +直接通过后处理去掉多识别的`·`,在后处理的改动比较简单,只需在 [ppocr/postprocess/rec_postprocess.py](../ppocr/postprocess/rec_postprocess.py) 文件的76行添加如下代码: +```python +text = text.replace('·','') +``` + +改动前后指标对比: + +|方案|acc| +|---|---| +|PP-OCRv3中英文超轻量识别预训练模型直接预测|0.20%| +|PP-OCRv3中英文超轻量识别预训练模型直接预测+后处理去掉多识别的`·`|90.97%| + +可以看到,去掉多余的`·`能大幅提高精度。 + +#### 4.2.3 CCPD车牌数据集fine-tune + +**训练** + +为了进行fine-tune训练,我们需要在配置文件中设置需要使用的预训练模型地址,学习率和数据集等参数。 具体如下: + +1. 模型存储和训练相关: + 1. Global.pretrained_model: 指向PP-OCRv3文本识别预训练模型地址 + 2. Global.eval_batch_step: 模型多少step评估一次,这里设为从第0个step开始没隔45个step评估一次,45为一个epoch总的step数。 +2. 优化器相关 + 1. Optimizer.lr.name: 学习率衰减器设为常量 Const + 2. Optimizer.lr.learning_rate: 做 fine-tune 实验,学习率需要设置的比较小,此处学习率设为配置文件中的0.05倍 + 3. Optimizer.lr.warmup_epoch: warmup_epoch设为0 +3. 数据集相关 + 1. Train.dataset.data_dir:指向训练集图片存放目录 + 2. Train.dataset.label_file_list:指向训练集标注文件 + 3. Eval.dataset.data_dir:指向测试集图片存放目录 + 4. 
Eval.dataset.label_file_list:指向测试集标注文件 + +使用如下命令启动 fine-tune + +```bash +python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ + Global.pretrained_model=models/ch_PP-OCRv3_rec_train/student.pdparams \ + Global.save_model_dir=output/CCPD/rec/ \ + Global.eval_batch_step="[0, 90]" \ + Optimizer.lr.name=Const \ + Optimizer.lr.learning_rate=0.0005 \ + Optimizer.lr.warmup_epoch=0 \ + Train.dataset.data_dir=/home/aistudio/data/CCPD2020/PPOCR \ + Train.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/train/rec.txt] \ + Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/PPOCR \ + Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/rec.txt] +``` + +**评估** + +训练完成后使用如下命令进行评估 + +```bash +python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ + Global.pretrained_model=output/CCPD/rec/best_accuracy.pdparams \ + Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/PPOCR \ + Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/rec.txt] +``` + +使用预训练模型和CCPD车牌数据集fine-tune,指标分别如下: + +|方案| acc | +|---|--------| +|PP-OCRv3中英文超轻量识别预训练模型直接预测| 0.00% | +|PP-OCRv3中英文超轻量识别预训练模型直接预测+后处理去掉多识别的`·`| 90.97% | +|PP-OCRv3中英文超轻量识别预训练模型 fine-tune| 94.54% | + +可以看到进行fine-tune能显著提升车牌识别的效果。 + +#### 4.2.4 CCPD车牌数据集fine-tune+量化训练 + +此处采用 PaddleOCR 中提供好的[量化教程](../deploy/slim/quantization/README.md)对模型进行量化训练。 + +量化训练可通过如下命令启动: + +```bash +python3.7 deploy/slim/quantization/quant.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ + Global.pretrained_model=output/CCPD/rec/best_accuracy.pdparams \ + Global.save_model_dir=output/CCPD/rec_quant/ \ + Global.eval_batch_step="[0, 90]" \ + Optimizer.lr.name=Const \ + Optimizer.lr.learning_rate=0.0005 \ + Optimizer.lr.warmup_epoch=0 \ + Train.dataset.data_dir=/home/aistudio/data/CCPD2020/PPOCR \ + Train.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/train/rec.txt] \ + Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/PPOCR \ + Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/rec.txt] +``` + +量化后指标对比如下 + +|方案| acc | 模型大小 | 预测速度(lite) | +|---|--------|-------|------------| +|PP-OCRv3中英文超轻量识别预训练模型 fine-tune| 94.54% | 10.3M | 4.2ms | +|PP-OCRv3中英文超轻量识别预训练模型 fine-tune + 量化| 93.40% | 4.8M | 1.8ms | + +可以看到量化后能降低模型体积53%并且推理速度提升57%,但是由于识别数据过少,量化带来了1%的精度下降。 + +速度测试基于[PaddleOCR lite教程](../deploy/lite/readme_ch.md)完成。 + +#### 4.2.5 模型导出 + +使用如下命令可以将训练好的模型进行导出。 + +* 非量化模型 +```bash +python tools/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ + Global.pretrained_model=output/CCPD/rec/best_accuracy.pdparams \ + Global.save_inference_dir=output/CCPD/rec/infer +``` +* 量化模型 +```bash +python deploy/slim/quantization/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ + Global.pretrained_model=output/CCPD/rec_quant/best_accuracy.pdparams \ + Global.save_inference_dir=output/CCPD/rec_quant/infer +``` + +### 4.3 计算End2End指标 + +端到端指标可通过 [PaddleOCR内置脚本](../tools/end2end/readme.md) 进行计算,具体步骤如下: + +1. 
导出模型 + +通过如下命令进行模型的导出。注意,量化模型导出时,需要配置eval数据集 + +```bash +# 检测模型 + +# 预训练模型 +python tools/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o \ + Global.pretrained_model=models/ch_PP-OCRv3_det_distill_train/student.pdparams \ + Global.save_inference_dir=output/ch_PP-OCRv3_det_distill_train/infer + +# 非量化模型 +python tools/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o \ + Global.pretrained_model=output/CCPD/det/best_accuracy.pdparams \ + Global.save_inference_dir=output/CCPD/det/infer + +# 量化模型 +python deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o \ + Global.pretrained_model=output/CCPD/det_quant/best_accuracy.pdparams \ + Global.save_inference_dir=output/CCPD/det_quant/infer \ + Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/ccpd_green \ + Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/det.txt] \ + Eval.loader.num_workers=0 + +# 识别模型 + +# 预训练模型 +python tools/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ + Global.pretrained_model=models/ch_PP-OCRv3_rec_train/student.pdparams \ + Global.save_inference_dir=output/ch_PP-OCRv3_rec_train/infer + +# 非量化模型 +python tools/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ + Global.pretrained_model=output/CCPD/rec/best_accuracy.pdparams \ + Global.save_inference_dir=output/CCPD/rec/infer + +# 量化模型 +python deploy/slim/quantization/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o \ + Global.pretrained_model=output/CCPD/rec_quant/best_accuracy.pdparams \ + Global.save_inference_dir=output/CCPD/rec_quant/infer \ + Eval.dataset.data_dir=/home/aistudio/data/CCPD2020/PPOCR \ + Eval.dataset.label_file_list=[/home/aistudio/data/CCPD2020/PPOCR/test/rec.txt] +``` + +2. 用导出的模型对测试集进行预测 + +此处,分别使用PP-OCRv3预训练模型,fintune模型和量化模型对测试集的所有图像进行预测,命令如下: + +```bash +# PP-OCRv3中英文超轻量检测预训练模型,PP-OCRv3中英文超轻量识别预训练模型 +python3 tools/infer/predict_system.py --det_model_dir=models/ch_PP-OCRv3_det_distill_train/infer --rec_model_dir=models/ch_PP-OCRv3_rec_train/infer --det_limit_side_len=736 --det_limit_type=min --image_dir=/home/aistudio/data/CCPD2020/ccpd_green/test/ --draw_img_save_dir=infer/pretrain --use_dilation=true + +# PP-OCRv3中英文超轻量检测预训练模型+fine-tune,PP-OCRv3中英文超轻量识别预训练模型+fine-tune +python3 tools/infer/predict_system.py --det_model_dir=output/CCPD/det/infer --rec_model_dir=output/CCPD/rec/infer --det_limit_side_len=736 --det_limit_type=min --image_dir=/home/aistudio/data/CCPD2020/ccpd_green/test/ --draw_img_save_dir=infer/fine-tune --use_dilation=true + +# PP-OCRv3中英文超轻量检测预训练模型 fine-tune +量化,PP-OCRv3中英文超轻量识别预训练模型 fine-tune +量化 结果转换和评估 +python3 tools/infer/predict_system.py --det_model_dir=output/CCPD/det_quant/infer --rec_model_dir=output/CCPD/rec_quant/infer --det_limit_side_len=736 --det_limit_type=min --image_dir=/home/aistudio/data/CCPD2020/ccpd_green/test/ --draw_img_save_dir=infer/quant --use_dilation=true +``` + +3. 
转换label并计算指标 + +将gt和上一步保存的预测结果转换为端对端评测需要的数据格式,并根据转换后的数据进行端到端指标计算 + +```bash +python3 tools/end2end/convert_ppocr_label.py --mode=gt --label_path=/home/aistudio/data/CCPD2020/PPOCR/test/det.txt --save_folder=end2end/gt + +# PP-OCRv3中英文超轻量检测预训练模型,PP-OCRv3中英文超轻量识别预训练模型 结果转换和评估 +python3 tools/end2end/convert_ppocr_label.py --mode=pred --label_path=infer/pretrain/system_results.txt --save_folder=end2end/pretrain +python3 tools/end2end/eval_end2end.py end2end/gt end2end/pretrain + +# PP-OCRv3中英文超轻量检测预训练模型,PP-OCRv3中英文超轻量识别预训练模型+后处理去掉多识别的`·` 结果转换和评估 +# 需手动修改后处理函数 +python3 tools/end2end/convert_ppocr_label.py --mode=pred --label_path=infer/post/system_results.txt --save_folder=end2end/post +python3 tools/end2end/eval_end2end.py end2end/gt end2end/post + +# PP-OCRv3中英文超轻量检测预训练模型 fine-tune,PP-OCRv3中英文超轻量识别预训练模型 fine-tune 结果转换和评估 +python3 tools/end2end/convert_ppocr_label.py --mode=pred --label_path=infer/fine-tune/system_results.txt --save_folder=end2end/fine-tune +python3 tools/end2end/eval_end2end.py end2end/gt end2end/fine-tune + +# PP-OCRv3中英文超轻量检测预训练模型 fine-tune +量化,PP-OCRv3中英文超轻量识别预训练模型 fine-tune +量化 结果转换和评估 +python3 tools/end2end/convert_ppocr_label.py --mode=pred --label_path=infer/quant/system_results.txt --save_folder=end2end/quant +python3 tools/end2end/eval_end2end.py end2end/gt end2end/quant +``` + +日志如下: +```bash +The convert label saved in end2end/gt +The convert label saved in end2end/pretrain +start testing... +hit, dt_count, gt_count 2 5988 5006 +character_acc: 70.42% +avg_edit_dist_field: 2.37 +avg_edit_dist_img: 2.37 +precision: 0.03% +recall: 0.04% +fmeasure: 0.04% +The convert label saved in end2end/post +start testing... +hit, dt_count, gt_count 4224 5988 5006 +character_acc: 81.59% +avg_edit_dist_field: 1.47 +avg_edit_dist_img: 1.47 +precision: 70.54% +recall: 84.38% +fmeasure: 76.84% +The convert label saved in end2end/fine-tune +start testing... +hit, dt_count, gt_count 4286 4898 5006 +character_acc: 94.16% +avg_edit_dist_field: 0.47 +avg_edit_dist_img: 0.47 +precision: 87.51% +recall: 85.62% +fmeasure: 86.55% +The convert label saved in end2end/quant +start testing... +hit, dt_count, gt_count 4349 4951 5006 +character_acc: 94.13% +avg_edit_dist_field: 0.47 +avg_edit_dist_img: 0.47 +precision: 87.84% +recall: 86.88% +fmeasure: 87.36% +``` + +各个方案端到端指标如下: + +|模型| 指标 | +|---|--------| +|PP-OCRv3中英文超轻量检测预训练模型
PP-OCRv3中英文超轻量识别预训练模型| 0.04% | +|PP-OCRv3中英文超轻量检测预训练模型
PP-OCRv3中英文超轻量识别预训练模型 + 后处理去掉多识别的`·`| 78.27% | +|PP-OCRv3中英文超轻量检测预训练模型+fine-tune
PP-OCRv3中英文超轻量识别预训练模型+fine-tune| 87.14% | +|PP-OCRv3中英文超轻量检测预训练模型+fine-tune+量化
PP-OCRv3中英文超轻量识别预训练模型+fine-tune+量化| 88.00% | + +从结果中可以看到对预训练模型不做修改,只根据场景下的具体情况进行后处理的修改就能大幅提升端到端指标到78.27%,在CCPD数据集上进行 fine-tune 后指标进一步提升到87.14%, 在经过量化训练之后,由于检测模型的recall变高,指标进一步提升到88%。但是这个结果仍旧不符合检测模型+识别模型的真实性能(99%*94%=93%),因此我们需要对 base case 进行具体分析。 + +在之前的端到端预测结果中,可以看到很多不符合车牌标注的文字被识别出来, 因此可以进行简单的过滤来提升precision + +为了快速评估,我们在 ` tools/end2end/convert_ppocr_label.py` 脚本的 58 行加入如下代码,对非8个字符的结果进行过滤 +```python +if len(txt) != 8: # 车牌字符串长度为8 + continue +``` + +此外,通过可视化box可以发现有很多框都是竖直翻转之后的框,并且没有完全框住车牌边界,因此需要进行框的竖直翻转以及轻微扩大,示意图如下: + +![](https://ai-studio-static-online.cdn.bcebos.com/59ab0411c8eb4dfd917fb2b6e5b69a17ee7ca48351444aec9ac6104b79ff1028) + +修改前后个方案指标对比如下: + + +各个方案端到端指标如下: + +|模型|base|A:识别结果过滤|B:use_dilation|C:flip_box|best| +|---|---|---|---|---|---| +|PP-OCRv3中英文超轻量检测预训练模型
PP-OCRv3中英文超轻量识别预训练模型|0.04%|0.08%|0.02%|0.05%|0.00%(A)| +|PP-OCRv3中英文超轻量检测预训练模型
PP-OCRv3中英文超轻量识别预训练模型 + 后处理去掉多识别的`·`|78.27%|90.84%|78.61%|79.43%|91.66%(A+B+C)| +|PP-OCRv3中英文超轻量检测预训练模型+fine-tune
PP-OCRv3中英文超轻量识别预训练模型+fine-tune|87.14%|90.40%|87.66%|89.98%|92.50%(A+B+C)| +|PP-OCRv3中英文超轻量检测预训练模型+fine-tune+量化
PP-OCRv3中英文超轻量识别预训练模型+fine-tune+量化|88.00%|90.54%|88.50%|89.46%|92.02%(A+B+C)| + + +从结果中可以看到对预训练模型不做修改,只根据场景下的具体情况进行后处理的修改就能大幅提升端到端指标到91.66%,在CCPD数据集上进行 fine-tune 后指标进一步提升到92.5%, 在经过量化训练之后,指标变为92.02%。 + +### 4.4 部署 + +- 基于 Paddle Inference 的python推理 + +检测模型和识别模型分别 fine-tune 并导出为inference模型之后,可以使用如下命令基于 Paddle Inference 进行端到端推理并对结果进行可视化。 + +```bash +python tools/infer/predict_system.py \ + --det_model_dir=output/CCPD/det/infer/ \ + --rec_model_dir=output/CCPD/rec/infer/ \ + --image_dir="/home/aistudio/data/CCPD2020/ccpd_green/test/04131106321839081-92_258-159&509_530&611-527&611_172&599_159&509_530&525-0_0_3_32_30_31_30_30-109-106.jpg" \ + --rec_image_shape=3,48,320 +``` +推理结果如下 + +![](https://ai-studio-static-online.cdn.bcebos.com/76b6a0939c2c4cf49039b6563c4b28e241e11285d7464e799e81c58c0f7707a7) + +- 端侧部署 + +端侧部署我们采用基于 PaddleLite 的 cpp 推理。Paddle Lite是飞桨轻量化推理引擎,为手机、IOT端提供高效推理能力,并广泛整合跨平台硬件,为端侧部署及应用落地问题提供轻量化的部署方案。具体可参考 [PaddleOCR lite教程](../deploy/lite/readme_ch.md) + + +### 4.5 实验总结 + +我们分别使用PP-OCRv3中英文超轻量预训练模型在车牌数据集上进行了直接评估和 fine-tune 和 fine-tune +量化3种方案的实验,并基于[PaddleOCR lite教程](../deploy/lite/readme_ch.md)进行了速度测试,指标对比如下: + +- 检测 + +|方案|hmeans| 模型大小 | 预测速度(lite) | +|---|---|------|------------| +|PP-OCRv3中英文超轻量检测预训练模型直接预测|76.12%|2.5M| 233ms | +|PP-OCRv3中英文超轻量检测预训练模型 fine-tune|99.00%| 2.5M | 233ms | +|PP-OCRv3中英文超轻量检测预训练模型 fine-tune + 量化|98.91%| 1.0M | 189ms |fine-tune + +- 识别 + +|方案| acc | 模型大小 | 预测速度(lite) | +|---|--------|-------|------------| +|PP-OCRv3中英文超轻量识别预训练模型直接预测| 0.00% |10.3M| 4.2ms | +|PP-OCRv3中英文超轻量识别预训练模型直接预测+后处理去掉多识别的`·`| 90.97% |10.3M| 4.2ms | +|PP-OCRv3中英文超轻量识别预训练模型 fine-tune| 94.54% | 10.3M | 4.2ms | +|PP-OCRv3中英文超轻量识别预训练模型 fine-tune + 量化| 93.40% | 4.8M | 1.8ms | + + +- 端到端指标如下: + +|方案|fmeasure|模型大小|预测速度(lite) | +|---|---|---|---| +|PP-OCRv3中英文超轻量检测预训练模型
PP-OCRv3中英文超轻量识别预训练模型|0.08%|12.8M|298ms| +|PP-OCRv3中英文超轻量检测预训练模型
PP-OCRv3中英文超轻量识别预训练模型 + 后处理去掉多识别的`·`|91.66%|12.8M|298ms| +|PP-OCRv3中英文超轻量检测预训练模型+fine-tune
PP-OCRv3中英文超轻量识别预训练模型+fine-tune|92.50%|12.8M|298ms| +|PP-OCRv3中英文超轻量检测预训练模型+fine-tune+量化
PP-OCRv3中英文超轻量识别预训练模型+fine-tune+量化|92.02%|5.80M|224ms| + + +**结论** + +PP-OCRv3的检测模型在未经过fine-tune的情况下,在车牌数据集上也有一定的精度,经过 fine-tune 后能够极大的提升检测效果,精度达到99%。在使用量化训练后检测模型的精度几乎无损,并且模型大小压缩60%。 + +PP-OCRv3的识别模型在未经过fine-tune的情况下,在车牌数据集上精度为0,但是经过分析可以知道,模型大部分字符都预测正确,但是会多预测一个特殊字符,去掉这个特殊字符后,精度达到90%。PP-OCRv3识别模型在经过 fine-tune 后识别精度进一步提升,达到94.4%。在使用量化训练后识别模型大小压缩53%,但是由于数据量多少,带来了1%的精度损失。 + +从端到端结果中可以看到对预训练模型不做修改,只根据场景下的具体情况进行后处理的修改就能大幅提升端到端指标到91.66%,在CCPD数据集上进行 fine-tune 后指标进一步提升到92.5%, 在经过量化训练之后,指标轻微下降到92.02%但模型大小降低54%。 diff --git "a/applications/\351\253\230\347\262\276\345\272\246\344\270\255\346\226\207\350\257\206\345\210\253\346\250\241\345\236\213.md" "b/applications/\351\253\230\347\262\276\345\272\246\344\270\255\346\226\207\350\257\206\345\210\253\346\250\241\345\236\213.md" new file mode 100644 index 0000000..b233855 --- /dev/null +++ "b/applications/\351\253\230\347\262\276\345\272\246\344\270\255\346\226\207\350\257\206\345\210\253\346\250\241\345\236\213.md" @@ -0,0 +1,107 @@ +# 高精度中文场景文本识别模型SVTR + +## 1. 简介 + +PP-OCRv3是百度开源的超轻量级场景文本检测识别模型库,其中超轻量的场景中文识别模型SVTR_LCNet使用了SVTR算法结构。为了保证速度,SVTR_LCNet将SVTR模型的Local Blocks替换为LCNet,使用两层Global Blocks。在中文场景中,PP-OCRv3识别主要使用如下优化策略([详细技术报告](../doc/doc_ch/PP-OCRv3_introduction.md)): +- GTC:Attention指导CTC训练策略; +- TextConAug:挖掘文字上下文信息的数据增广策略; +- TextRotNet:自监督的预训练模型; +- UDML:联合互学习策略; +- UIM:无标注数据挖掘方案。 + +其中 *UIM:无标注数据挖掘方案* 使用了高精度的SVTR中文模型进行无标注文件的刷库,该模型在PP-OCRv3识别的数据集上训练,精度对比如下表。 + +|中文识别算法|模型|UIM|精度| +| --- | --- | --- |--- | +|PP-OCRv3|SVTR_LCNet| w/o |78.40%| +|PP-OCRv3|SVTR_LCNet| w |79.40%| +|SVTR|SVTR-Tiny|-|82.50%| + +aistudio项目链接: [高精度中文场景文本识别模型SVTR](https://aistudio.baidu.com/aistudio/projectdetail/4263032) + +## 2. SVTR中文模型使用 + +### 环境准备 + + +本任务基于Aistudio完成, 具体环境如下: + +- 操作系统: Linux +- PaddlePaddle: 2.3 +- PaddleOCR: dygraph + +下载 PaddleOCR代码 + +```bash +git clone -b dygraph https://github.com/PaddlePaddle/PaddleOCR +``` + +安装依赖库 + +```bash +pip install -r PaddleOCR/requirements.txt -i https://mirror.baidu.com/pypi/simple +``` + +### 快速使用 + +获取SVTR中文模型文件,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 +
+ +
+ +```bash +# 解压模型文件 +tar xf svtr_ch_high_accuracy.tar +``` + +预测中文文本,以下图为例: +![](../doc/imgs_words/ch/word_1.jpg) + +预测命令: + +```bash +# CPU预测 +python tools/infer_rec.py -c configs/rec/rec_svtrnet_ch.yml -o Global.pretrained_model=./svtr_ch_high_accuracy/best_accuracy Global.infer_img=./doc/imgs_words/ch/word_1.jpg Global.use_gpu=False + +# GPU预测 +#python tools/infer_rec.py -c configs/rec/rec_svtrnet_ch.yml -o Global.pretrained_model=./svtr_ch_high_accuracy/best_accuracy Global.infer_img=./doc/imgs_words/ch/word_1.jpg Global.use_gpu=True +``` + +可以看到最后打印结果为 +- result: 韩国小馆 0.9853458404541016 + +0.9853458404541016为预测置信度。 + +### 推理模型导出与预测 + +inference 模型(paddle.jit.save保存的模型) 一般是模型训练,把模型结构和模型参数保存在文件中的固化模型,多用于预测部署场景。 训练过程中保存的模型是checkpoints模型,保存的只有模型的参数,多用于恢复训练等。 与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。 + +运行识别模型转inference模型命令,如下: + +```bash +python tools/export_model.py -c configs/rec/rec_svtrnet_ch.yml -o Global.pretrained_model=./svtr_ch_high_accuracy/best_accuracy Global.save_inference_dir=./inference/svtr_ch +``` + +转换成功后,在目录下有三个文件: +```shell +inference/svtr_ch/ + ├── inference.pdiparams # 识别inference模型的参数文件 + ├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略 + └── inference.pdmodel # 识别inference模型的program文件 +``` + +inference模型预测,命令如下: + +```bash +# CPU预测 +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_1.jpg" --rec_algorithm='SVTR' --rec_model_dir=./inference/svtr_ch/ --rec_image_shape='3, 32, 320' --rec_char_dict_path=ppocr/utils/ppocr_keys_v1.txt --use_gpu=False + +# GPU预测 +#python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_1.jpg" --rec_algorithm='SVTR' --rec_model_dir=./inference/svtr_ch/ --rec_image_shape='3, 32, 320' --rec_char_dict_path=ppocr/utils/ppocr_keys_v1.txt --use_gpu=True +``` + +**注意** + +- 使用SVTR算法时,需要指定--rec_algorithm='SVTR' +- 如果使用自定义字典训练的模型,需要将--rec_char_dict_path=ppocr/utils/ppocr_keys_v1.txt修改为自定义的字典 +- --rec_image_shape='3, 32, 320' 该参数不能去掉 diff --git a/benchmark/PaddleOCR_DBNet/.gitattributes b/benchmark/PaddleOCR_DBNet/.gitattributes new file mode 100644 index 0000000..8543e0a --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/.gitattributes @@ -0,0 +1,2 @@ +*.html linguist-language=python +*.ipynb linguist-language=python \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/.gitignore b/benchmark/PaddleOCR_DBNet/.gitignore new file mode 100644 index 0000000..cef1c73 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/.gitignore @@ -0,0 +1,16 @@ +.DS_Store +*.pth +*.pyc +*.pyo +*.log +*.tmp +*.pkl +__pycache__/ +.idea/ +output/ +test/*.jpg +datasets/ +index/ +train_log/ +log/ +profiling_log/ \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/LICENSE.md b/benchmark/PaddleOCR_DBNet/LICENSE.md new file mode 100644 index 0000000..b09cd78 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/LICENSE.md @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/benchmark/PaddleOCR_DBNet/README.MD b/benchmark/PaddleOCR_DBNet/README.MD
new file mode 100644
index 0000000..dbc07fa
--- /dev/null
+++ b/benchmark/PaddleOCR_DBNet/README.MD
@@ -0,0 +1,132 @@
+# Real-time Scene Text Detection with Differentiable Binarization
+
+**Note**: some code is inherited from [WenmuZhou/DBNet.pytorch](https://github.com/WenmuZhou/DBNet.pytorch)
+
+[中文解读](https://zhuanlan.zhihu.com/p/94677957)
+
+![network](imgs/paper/db.jpg)
+
+## Update
+2020-06-07: Added grayscale training. When training on grayscale images, remove `dataset.args.transforms.Normalize` from the config.
+
+## Install Using Conda
+```
+conda env create -f environment.yml
+git clone https://github.com/WenmuZhou/DBNet.paddle.git
+cd DBNet.paddle/
+```
+
+or
+
+## Install Manually
+```bash
+conda create -n dbnet python=3.6
+conda activate dbnet
+
+conda install ipython pip
+
+# python dependencies
+pip install -r requirement.txt
+
+# clone repo
+git clone https://github.com/WenmuZhou/DBNet.paddle.git
+cd DBNet.paddle/
+```
+
+## Requirements
+* paddlepaddle 2.4+
+
+## Download
+
+TBD
+
+## Data Preparation
+
+Training data: prepare a text file `train.txt` in the following format, using '\t' as the separator:
+```
+./datasets/train/img/001.jpg ./datasets/train/gt/001.txt
+```
+
+Validation data: prepare a text file `test.txt` in the same format, using '\t' as the separator:
+```
+./datasets/test/img/001.jpg ./datasets/test/gt/001.txt
+```
+- Store images in the `img` folder
+- Store ground truth in the `gt` folder
+
+The ground truth files are `.txt` files with the following format:
+```
+x1, y1, x2, y2, x3, y3, x4, y4, annotation
+```
+A minimal sketch for generating these list files is shown after the Paddle Inference section below.
+
+## Train
+1. Set `dataset['train']['dataset']['args']['data_path']` and `dataset['validate']['dataset']['args']['data_path']` in [config/icdar2015_resnet18_fpn_DBhead_polyLR.yaml](config/icdar2015_resnet18_fpn_DBhead_polyLR.yaml)
+* Single-GPU training
+```bash
+bash singlel_gpu_train.sh
+```
+* Multi-GPU training
+```bash
+bash multi_gpu_train.sh
+```
+## Test
+
+[eval.py](tools/eval.py) is used to evaluate the model on the test dataset.
+
+1. Set `model_path` in [eval.sh](eval.sh)
+2. Use the following script to test:
+```bash
+bash eval.sh
+```
+
+## Predict
+[predict.py](tools/predict.py) can be used to run inference on all images in a folder.
+1. Set `model_path`, `input_folder`, and `output_folder` in [predict.sh](predict.sh)
+2. Use the following script to predict:
+```
+bash predict.sh
+```
+You can change the `model_path` in the `predict.sh` file to your model location.
+
+Tip: if the result is not good, you can adjust `thre` in [predict.sh](predict.sh).
+
+## Export Model
+
+[export_model.py](tools/export_model.py) is used to export a trained checkpoint as an inference model.
+
+Use the following script to export the inference model:
+```
+python tools/export_model.py --config_file config/icdar2015_resnet50_FPN_DBhead_polyLR.yaml -o trainer.resume_checkpoint=model_best.pth trainer.output_dir=output/infer
+```
+
+## Paddle Inference
+
+[infer.py](tools/infer.py) can be used to run inference on all images in a folder with the exported inference model.
+
+Use the following script to run inference:
+```
+python tools/infer.py --model-dir=output/infer/ --img-path imgs/paper/db.jpg
+```
+
+
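+If you prefer to call the exported model directly from Python instead of going through [infer.py](tools/infer.py), a minimal sketch using the standard Paddle Inference Python API might look like the one below. The model file names (`inference.pdmodel` / `inference.pdiparams`), the resize-to-multiple-of-32 step, and the normalization constants are assumptions based on the configs in this repo; this is an illustration, not a drop-in replacement for `tools/infer.py`.
+
+```python
+# Minimal sketch, not the repository's own inference code.
+# Assumes the exported model is stored as output/infer/inference.pdmodel / inference.pdiparams.
+import cv2
+import numpy as np
+from paddle import inference
+
+config = inference.Config("output/infer/inference.pdmodel",
+                          "output/infer/inference.pdiparams")
+predictor = inference.create_predictor(config)
+
+# Read the image, convert BGR -> RGB, normalize with the mean/std used in the training configs,
+# and resize so both sides are multiples of 32 (as DBNet expects).
+img = cv2.imread("imgs/paper/db.jpg")[:, :, ::-1].astype("float32") / 255.0
+h, w = img.shape[:2]
+img = cv2.resize(img, (w // 32 * 32, h // 32 * 32))
+img = (img - np.array([0.485, 0.456, 0.406], dtype="float32")) / np.array(
+    [0.229, 0.224, 0.225], dtype="float32")
+img = np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis])  # NCHW
+
+input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
+input_handle.copy_from_cpu(img)
+predictor.run()
+output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
+prob_map = output_handle.copy_to_cpu()  # raw probability map
+# Post-process with SegDetectorRepresenter (thresh/box_thresh from the config) to obtain boxes.
+```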
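+
+The `train.txt` / `test.txt` list files described in the Data Preparation section can be produced with a few lines of plain Python. The sketch below only illustrates the expected `img/` + `gt/` layout and the tab-separated format; the dataset root and the label naming (same stem as the image) are assumptions you may need to adapt to your own data.
+
+```python
+# Minimal sketch: write train.txt as "<image path>\t<label path>" per line.
+# Assumes images at ./datasets/train/img/xxx.jpg with labels at ./datasets/train/gt/xxx.txt.
+from pathlib import Path
+
+root = Path("./datasets/train")
+with open("train.txt", "w", encoding="utf-8") as f:
+    for img_path in sorted((root / "img").glob("*.jpg")):
+        gt_path = root / "gt" / (img_path.stem + ".txt")
+        if gt_path.exists():
+            f.write(f"{img_path}\t{gt_path}\n")
+```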

+## Performance

+ +### [ICDAR 2015](http://rrc.cvc.uab.es/?ch=4) +only train on ICDAR2015 dataset + +| Method | image size (short size) |learning rate | Precision (%) | Recall (%) | F-measure (%) | FPS | +|:--------------------------:|:-------:|:--------:|:--------:|:------------:|:---------------:|:-----:| +| ImageNet-resnet50-FPN-DBHead(torch) |736 |1e-3|90.19 | 78.14 | 83.88 | 27 | +| ImageNet-resnet50-FPN-DBHead(paddle) |736 |1e-3| 89.47 | 79.03 | 83.92 | 27 | +| ImageNet-resnet50-FPN-DBHead(paddle_amp) |736 |1e-3| 88.62 | 79.95 | 84.06 | 27 | + + +### examples +TBD + + +### reference +1. https://arxiv.org/pdf/1911.08947.pdf +2. https://github.com/WenmuZhou/DBNet.pytorch + +**If this repository helps you,please star it. Thanks.** diff --git a/benchmark/PaddleOCR_DBNet/base/__init__.py b/benchmark/PaddleOCR_DBNet/base/__init__.py new file mode 100644 index 0000000..223e9e0 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/base/__init__.py @@ -0,0 +1,2 @@ +from .base_trainer import BaseTrainer +from .base_dataset import BaseDataSet \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/base/base_dataset.py b/benchmark/PaddleOCR_DBNet/base/base_dataset.py new file mode 100644 index 0000000..4a839a8 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/base/base_dataset.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/4 13:12 +# @Author : zhoujun +import copy +from paddle.io import Dataset +from data_loader.modules import * + + +class BaseDataSet(Dataset): + def __init__(self, + data_path: str, + img_mode, + pre_processes, + filter_keys, + ignore_tags, + transform=None, + target_transform=None): + assert img_mode in ['RGB', 'BRG', 'GRAY'] + self.ignore_tags = ignore_tags + self.data_list = self.load_data(data_path) + item_keys = [ + 'img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags' + ] + for item in item_keys: + assert item in self.data_list[ + 0], 'data_list from load_data must contains {}'.format( + item_keys) + self.img_mode = img_mode + self.filter_keys = filter_keys + self.transform = transform + self.target_transform = target_transform + self._init_pre_processes(pre_processes) + + def _init_pre_processes(self, pre_processes): + self.aug = [] + if pre_processes is not None: + for aug in pre_processes: + if 'args' not in aug: + args = {} + else: + args = aug['args'] + if isinstance(args, dict): + cls = eval(aug['type'])(**args) + else: + cls = eval(aug['type'])(args) + self.aug.append(cls) + + def load_data(self, data_path: str) -> list: + """ + 把数据加载为一个list: + :params data_path: 存储数据的文件夹或者文件 + return a dict ,包含了,'img_path','img_name','text_polys','texts','ignore_tags' + """ + raise NotImplementedError + + def apply_pre_processes(self, data): + for aug in self.aug: + data = aug(data) + return data + + def __getitem__(self, index): + try: + data = copy.deepcopy(self.data_list[index]) + im = cv2.imread(data['img_path'], 1 + if self.img_mode != 'GRAY' else 0) + if self.img_mode == 'RGB': + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + data['img'] = im + data['shape'] = [im.shape[0], im.shape[1]] + data = self.apply_pre_processes(data) + + if self.transform: + data['img'] = self.transform(data['img']) + data['text_polys'] = data['text_polys'].tolist() + if len(self.filter_keys): + data_dict = {} + for k, v in data.items(): + if k not in self.filter_keys: + data_dict[k] = v + return data_dict + else: + return data + except: + return self.__getitem__(np.random.randint(self.__len__())) + + def __len__(self): + return len(self.data_list) diff --git 
a/benchmark/PaddleOCR_DBNet/base/base_trainer.py b/benchmark/PaddleOCR_DBNet/base/base_trainer.py new file mode 100644 index 0000000..82c308d --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/base/base_trainer.py @@ -0,0 +1,250 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/23 21:50 +# @Author : zhoujun + +import os +import pathlib +import shutil +from pprint import pformat + +import anyconfig +import paddle +import numpy as np +import random +from paddle.jit import to_static +from paddle.static import InputSpec + +from utils import setup_logger + + +class BaseTrainer: + def __init__(self, + config, + model, + criterion, + train_loader, + validate_loader, + metric_cls, + post_process=None): + config['trainer']['output_dir'] = os.path.join( + str(pathlib.Path(os.path.abspath(__name__)).parent), + config['trainer']['output_dir']) + config['name'] = config['name'] + '_' + model.name + self.save_dir = config['trainer']['output_dir'] + self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') + + os.makedirs(self.checkpoint_dir, exist_ok=True) + + self.global_step = 0 + self.start_epoch = 0 + self.config = config + self.criterion = criterion + # logger and tensorboard + self.visualdl_enable = self.config['trainer'].get('visual_dl', False) + self.epochs = self.config['trainer']['epochs'] + self.log_iter = self.config['trainer']['log_iter'] + if paddle.distributed.get_rank() == 0: + anyconfig.dump(config, os.path.join(self.save_dir, 'config.yaml')) + self.logger = setup_logger(os.path.join(self.save_dir, 'train.log')) + self.logger_info(pformat(self.config)) + + self.model = self.apply_to_static(model) + + # device + if paddle.device.cuda.device_count( + ) > 0 and paddle.device.is_compiled_with_cuda(): + self.with_cuda = True + random.seed(self.config['trainer']['seed']) + np.random.seed(self.config['trainer']['seed']) + paddle.seed(self.config['trainer']['seed']) + else: + self.with_cuda = False + self.logger_info('train with and paddle {}'.format(paddle.__version__)) + # metrics + self.metrics = { + 'recall': 0, + 'precision': 0, + 'hmean': 0, + 'train_loss': float('inf'), + 'best_model_epoch': 0 + } + + self.train_loader = train_loader + if validate_loader is not None: + assert post_process is not None and metric_cls is not None + self.validate_loader = validate_loader + self.post_process = post_process + self.metric_cls = metric_cls + self.train_loader_len = len(train_loader) + + if self.validate_loader is not None: + self.logger_info( + 'train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader'. 
+ format( + len(self.train_loader.dataset), self.train_loader_len, + len(self.validate_loader.dataset), + len(self.validate_loader))) + else: + self.logger_info( + 'train dataset has {} samples,{} in dataloader'.format( + len(self.train_loader.dataset), self.train_loader_len)) + + self._initialize_scheduler() + + self._initialize_optimizer() + + # resume or finetune + if self.config['trainer']['resume_checkpoint'] != '': + self._load_checkpoint( + self.config['trainer']['resume_checkpoint'], resume=True) + elif self.config['trainer']['finetune_checkpoint'] != '': + self._load_checkpoint( + self.config['trainer']['finetune_checkpoint'], resume=False) + + if self.visualdl_enable and paddle.distributed.get_rank() == 0: + from visualdl import LogWriter + self.writer = LogWriter(self.save_dir) + + # 混合精度训练 + self.amp = self.config.get('amp', None) + if self.amp == 'None': + self.amp = None + if self.amp: + self.amp['scaler'] = paddle.amp.GradScaler( + init_loss_scaling=self.amp.get("scale_loss", 1024), + use_dynamic_loss_scaling=self.amp.get( + 'use_dynamic_loss_scaling', True)) + self.model, self.optimizer = paddle.amp.decorate( + models=self.model, + optimizers=self.optimizer, + level=self.amp.get('amp_level', 'O2')) + + # 分布式训练 + if paddle.device.cuda.device_count() > 1: + self.model = paddle.DataParallel(self.model) + # make inverse Normalize + self.UN_Normalize = False + for t in self.config['dataset']['train']['dataset']['args'][ + 'transforms']: + if t['type'] == 'Normalize': + self.normalize_mean = t['args']['mean'] + self.normalize_std = t['args']['std'] + self.UN_Normalize = True + + def apply_to_static(self, model): + support_to_static = self.config['trainer'].get('to_static', False) + if support_to_static: + specs = None + print('static') + specs = [InputSpec([None, 3, -1, -1])] + model = to_static(model, input_spec=specs) + self.logger_info( + "Successfully to apply @to_static with specs: {}".format(specs)) + return model + + def train(self): + """ + Full training logic + """ + for epoch in range(self.start_epoch + 1, self.epochs + 1): + self.epoch_result = self._train_epoch(epoch) + self._on_epoch_finish() + if paddle.distributed.get_rank() == 0 and self.visualdl_enable: + self.writer.close() + self._on_train_finish() + + def _train_epoch(self, epoch): + """ + Training logic for an epoch + + :param epoch: Current epoch number + """ + raise NotImplementedError + + def _eval(self, epoch): + """ + eval logic for an epoch + + :param epoch: Current epoch number + """ + raise NotImplementedError + + def _on_epoch_finish(self): + raise NotImplementedError + + def _on_train_finish(self): + raise NotImplementedError + + def _save_checkpoint(self, epoch, file_name): + """ + Saving checkpoints + + :param epoch: current epoch number + :param log: logging information of the epoch + :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar' + """ + state_dict = self.model.state_dict() + state = { + 'epoch': epoch, + 'global_step': self.global_step, + 'state_dict': state_dict, + 'optimizer': self.optimizer.state_dict(), + 'config': self.config, + 'metrics': self.metrics + } + filename = os.path.join(self.checkpoint_dir, file_name) + paddle.save(state, filename) + + def _load_checkpoint(self, checkpoint_path, resume): + """ + Resume from saved checkpoints + :param checkpoint_path: Checkpoint path to be resumed + """ + self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path)) + checkpoint = paddle.load(checkpoint_path) + 
self.model.set_state_dict(checkpoint['state_dict']) + if resume: + self.global_step = checkpoint['global_step'] + self.start_epoch = checkpoint['epoch'] + self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch + # self.scheduler.load_state_dict(checkpoint['scheduler']) + self.optimizer.set_state_dict(checkpoint['optimizer']) + if 'metrics' in checkpoint: + self.metrics = checkpoint['metrics'] + self.logger_info("resume from checkpoint {} (epoch {})".format( + checkpoint_path, self.start_epoch)) + else: + self.logger_info("finetune from checkpoint {}".format( + checkpoint_path)) + + def _initialize(self, name, module, *args, **kwargs): + module_name = self.config[name]['type'] + module_args = self.config[name].get('args', {}) + assert all([k not in module_args for k in kwargs + ]), 'Overwriting kwargs given in config file is not allowed' + module_args.update(kwargs) + return getattr(module, module_name)(*args, **module_args) + + def _initialize_scheduler(self): + self.lr_scheduler = self._initialize('lr_scheduler', + paddle.optimizer.lr) + + def _initialize_optimizer(self): + self.optimizer = self._initialize( + 'optimizer', + paddle.optimizer, + parameters=self.model.parameters(), + learning_rate=self.lr_scheduler) + + def inverse_normalize(self, batch_img): + if self.UN_Normalize: + batch_img[:, 0, :, :] = batch_img[:, 0, :, :] * self.normalize_std[ + 0] + self.normalize_mean[0] + batch_img[:, 1, :, :] = batch_img[:, 1, :, :] * self.normalize_std[ + 1] + self.normalize_mean[1] + batch_img[:, 2, :, :] = batch_img[:, 2, :, :] * self.normalize_std[ + 2] + self.normalize_mean[2] + + def logger_info(self, s): + if paddle.distributed.get_rank() == 0: + self.logger.info(s) diff --git a/benchmark/PaddleOCR_DBNet/config/SynthText.yaml b/benchmark/PaddleOCR_DBNet/config/SynthText.yaml new file mode 100644 index 0000000..61d5da7 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/config/SynthText.yaml @@ -0,0 +1,40 @@ +name: DBNet +dataset: + train: + dataset: + type: SynthTextDataset # 数据集类型 + args: + data_path: ''# SynthTextDataset 根目录 + pre_processes: # 数据的预处理过程,包含augment和标签制作 + - type: IaaAugment # 使用imgaug进行变换 + args: + - {'type':Fliplr, 'args':{'p':0.5}} + - {'type': Affine, 'args':{'rotate':[-10,10]}} + - {'type':Resize,'args':{'size':[0.5,3]}} + - type: EastRandomCropData + args: + size: [640,640] + max_tries: 50 + keep_ratio: true + - type: MakeBorderMap + args: + shrink_ratio: 0.4 + - type: MakeShrinkMap + args: + shrink_ratio: 0.4 + min_text_size: 8 + transforms: # 对图片进行的变换方式 + - type: ToTensor + args: {} + - type: Normalize + args: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + img_mode: RGB + filter_keys: ['img_path','img_name','text_polys','texts','ignore_tags','shape'] # 返回数据之前,从数据字典里删除的key + ignore_tags: ['*', '###'] + loader: + batch_size: 1 + shuffle: true + num_workers: 0 + collate_fn: '' \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/config/SynthText_resnet18_FPN_DBhead_polyLR.yaml b/benchmark/PaddleOCR_DBNet/config/SynthText_resnet18_FPN_DBhead_polyLR.yaml new file mode 100644 index 0000000..a665e94 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/config/SynthText_resnet18_FPN_DBhead_polyLR.yaml @@ -0,0 +1,65 @@ +name: DBNet +base: ['config/SynthText.yaml'] +arch: + type: Model + backbone: + type: resnet18 + pretrained: true + neck: + type: FPN + inner_channels: 256 + head: + type: DBHead + out_channels: 2 + k: 50 +post_processing: + type: SegDetectorRepresenter + args: + thresh: 0.3 + box_thresh: 0.7 + max_candidates: 1000 + unclip_ratio: 
1.5 # from paper +metric: + type: QuadMetric + args: + is_output_polygon: false +loss: + type: DBLoss + alpha: 1 + beta: 10 + ohem_ratio: 3 +optimizer: + type: Adam + args: + lr: 0.001 + weight_decay: 0 + amsgrad: true +lr_scheduler: + type: WarmupPolyLR + args: + warmup_epoch: 3 +trainer: + seed: 2 + epochs: 1200 + log_iter: 10 + show_images_iter: 50 + resume_checkpoint: '' + finetune_checkpoint: '' + output_dir: output + visual_dl: false +amp: + scale_loss: 1024 + amp_level: O2 + custom_white_list: [] + custom_black_list: ['exp', 'sigmoid', 'concat'] +dataset: + train: + dataset: + args: + data_path: ./datasets/SynthText + img_mode: RGB + loader: + batch_size: 2 + shuffle: true + num_workers: 6 + collate_fn: '' \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/config/icdar2015.yaml b/benchmark/PaddleOCR_DBNet/config/icdar2015.yaml new file mode 100644 index 0000000..4551b14 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/config/icdar2015.yaml @@ -0,0 +1,69 @@ +name: DBNet +dataset: + train: + dataset: + type: ICDAR2015Dataset # 数据集类型 + args: + data_path: # 一个存放 img_path \t gt_path的文件 + - '' + pre_processes: # 数据的预处理过程,包含augment和标签制作 + - type: IaaAugment # 使用imgaug进行变换 + args: + - {'type':Fliplr, 'args':{'p':0.5}} + - {'type': Affine, 'args':{'rotate':[-10,10]}} + - {'type':Resize,'args':{'size':[0.5,3]}} + - type: EastRandomCropData + args: + size: [640,640] + max_tries: 50 + keep_ratio: true + - type: MakeBorderMap + args: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - type: MakeShrinkMap + args: + shrink_ratio: 0.4 + min_text_size: 8 + transforms: # 对图片进行的变换方式 + - type: ToTensor + args: {} + - type: Normalize + args: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + img_mode: RGB + filter_keys: [img_path,img_name,text_polys,texts,ignore_tags,shape] # 返回数据之前,从数据字典里删除的key + ignore_tags: ['*', '###'] + loader: + batch_size: 1 + shuffle: true + num_workers: 0 + collate_fn: '' + validate: + dataset: + type: ICDAR2015Dataset + args: + data_path: + - '' + pre_processes: + - type: ResizeShortSize + args: + short_size: 736 + resize_text_polys: false + transforms: + - type: ToTensor + args: {} + - type: Normalize + args: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + img_mode: RGB + filter_keys: [] + ignore_tags: ['*', '###'] + loader: + batch_size: 1 + shuffle: true + num_workers: 0 + collate_fn: ICDARCollectFN \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/config/icdar2015_dcn_resnet18_FPN_DBhead_polyLR.yaml b/benchmark/PaddleOCR_DBNet/config/icdar2015_dcn_resnet18_FPN_DBhead_polyLR.yaml new file mode 100644 index 0000000..608ef42 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/config/icdar2015_dcn_resnet18_FPN_DBhead_polyLR.yaml @@ -0,0 +1,82 @@ +name: DBNet +base: ['config/icdar2015.yaml'] +arch: + type: Model + backbone: + type: deformable_resnet18 + pretrained: true + neck: + type: FPN + inner_channels: 256 + head: + type: DBHead + out_channels: 2 + k: 50 +post_processing: + type: SegDetectorRepresenter + args: + thresh: 0.3 + box_thresh: 0.7 + max_candidates: 1000 + unclip_ratio: 1.5 # from paper +metric: + type: QuadMetric + args: + is_output_polygon: false +loss: + type: DBLoss + alpha: 1 + beta: 10 + ohem_ratio: 3 +optimizer: + type: Adam + args: + lr: 0.001 + weight_decay: 0 + amsgrad: true +lr_scheduler: + type: WarmupPolyLR + args: + warmup_epoch: 3 +trainer: + seed: 2 + epochs: 1200 + log_iter: 10 + show_images_iter: 50 + resume_checkpoint: '' + finetune_checkpoint: '' + output_dir: output + visual_dl: false 
+amp: + scale_loss: 1024 + amp_level: O2 + custom_white_list: [] + custom_black_list: ['exp', 'sigmoid', 'concat'] +dataset: + train: + dataset: + args: + data_path: + - ./datasets/train.txt + img_mode: RGB + loader: + batch_size: 1 + shuffle: true + num_workers: 6 + collate_fn: '' + validate: + dataset: + args: + data_path: + - ./datasets/test.txt + pre_processes: + - type: ResizeShortSize + args: + short_size: 736 + resize_text_polys: false + img_mode: RGB + loader: + batch_size: 1 + shuffle: true + num_workers: 6 + collate_fn: ICDARCollectFN \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml b/benchmark/PaddleOCR_DBNet/config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml new file mode 100644 index 0000000..62c392b --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml @@ -0,0 +1,82 @@ +name: DBNet +base: ['config/icdar2015.yaml'] +arch: + type: Model + backbone: + type: resnet18 + pretrained: true + neck: + type: FPN + inner_channels: 256 + head: + type: DBHead + out_channels: 2 + k: 50 +post_processing: + type: SegDetectorRepresenter + args: + thresh: 0.3 + box_thresh: 0.7 + max_candidates: 1000 + unclip_ratio: 1.5 # from paper +metric: + type: QuadMetric + args: + is_output_polygon: false +loss: + type: DBLoss + alpha: 1 + beta: 10 + ohem_ratio: 3 +optimizer: + type: Adam + args: + lr: 0.001 + weight_decay: 0 + amsgrad: true +lr_scheduler: + type: WarmupPolyLR + args: + warmup_epoch: 3 +trainer: + seed: 2 + epochs: 1200 + log_iter: 10 + show_images_iter: 50 + resume_checkpoint: '' + finetune_checkpoint: '' + output_dir: output + visual_dl: false +amp: + scale_loss: 1024 + amp_level: O2 + custom_white_list: [] + custom_black_list: ['exp', 'sigmoid', 'concat'] +dataset: + train: + dataset: + args: + data_path: + - ./datasets/train.txt + img_mode: RGB + loader: + batch_size: 1 + shuffle: true + num_workers: 6 + collate_fn: '' + validate: + dataset: + args: + data_path: + - ./datasets/test.txt + pre_processes: + - type: ResizeShortSize + args: + short_size: 736 + resize_text_polys: false + img_mode: RGB + loader: + batch_size: 1 + shuffle: true + num_workers: 6 + collate_fn: ICDARCollectFN diff --git a/benchmark/PaddleOCR_DBNet/config/icdar2015_resnet18_FPN_DBhead_polyLR_finetune.yaml b/benchmark/PaddleOCR_DBNet/config/icdar2015_resnet18_FPN_DBhead_polyLR_finetune.yaml new file mode 100644 index 0000000..9b018d5 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/config/icdar2015_resnet18_FPN_DBhead_polyLR_finetune.yaml @@ -0,0 +1,83 @@ +name: DBNet +base: ['config/icdar2015.yaml'] +arch: + type: Model + backbone: + type: resnet18 + pretrained: true + neck: + type: FPN + inner_channels: 256 + head: + type: DBHead + out_channels: 2 + k: 50 +post_processing: + type: SegDetectorRepresenter + args: + thresh: 0.3 + box_thresh: 0.7 + max_candidates: 1000 + unclip_ratio: 1.5 # from paper +metric: + type: QuadMetric + args: + is_output_polygon: false +loss: + type: DBLoss + alpha: 1 + beta: 10 + ohem_ratio: 3 +optimizer: + type: Adam + args: + lr: 0.001 + weight_decay: 0 + amsgrad: true +lr_scheduler: + type: StepLR + args: + step_size: 10 + gama: 0.8 +trainer: + seed: 2 + epochs: 500 + log_iter: 10 + show_images_iter: 50 + resume_checkpoint: '' + finetune_checkpoint: '' + output_dir: output + visual_dl: false +amp: + scale_loss: 1024 + amp_level: O2 + custom_white_list: [] + custom_black_list: ['exp', 'sigmoid', 'concat'] +dataset: + train: + dataset: + args: + data_path: + - ./datasets/train.txt + 
img_mode: RGB + loader: + batch_size: 1 + shuffle: true + num_workers: 6 + collate_fn: '' + validate: + dataset: + args: + data_path: + - ./datasets/test.txt + pre_processes: + - type: ResizeShortSize + args: + short_size: 736 + resize_text_polys: false + img_mode: RGB + loader: + batch_size: 1 + shuffle: true + num_workers: 6 + collate_fn: ICDARCollectFN diff --git a/benchmark/PaddleOCR_DBNet/config/icdar2015_resnet50_FPN_DBhead_polyLR.yaml b/benchmark/PaddleOCR_DBNet/config/icdar2015_resnet50_FPN_DBhead_polyLR.yaml new file mode 100644 index 0000000..2a870fd --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/config/icdar2015_resnet50_FPN_DBhead_polyLR.yaml @@ -0,0 +1,79 @@ +name: DBNet +base: ['config/icdar2015.yaml'] +arch: + type: Model + backbone: + type: resnet50 + pretrained: true + neck: + type: FPN + inner_channels: 256 + head: + type: DBHead + out_channels: 2 + k: 50 +post_processing: + type: SegDetectorRepresenter + args: + thresh: 0.3 + box_thresh: 0.7 + max_candidates: 1000 + unclip_ratio: 1.5 # from paper +metric: + type: QuadMetric + args: + is_output_polygon: false +loss: + type: DBLoss + alpha: 1 + beta: 10 + ohem_ratio: 3 +optimizer: + type: Adam +lr_scheduler: + type: Polynomial + args: + learning_rate: 0.001 + warmup_epoch: 3 +trainer: + seed: 2 + epochs: 1200 + log_iter: 10 + show_images_iter: 50 + resume_checkpoint: '' + finetune_checkpoint: '' + output_dir: output/fp16_o2 + visual_dl: false +amp: + scale_loss: 1024 + amp_level: O2 + custom_white_list: [] + custom_black_list: ['exp', 'sigmoid', 'concat'] +dataset: + train: + dataset: + args: + data_path: + - ./datasets/train.txt + img_mode: RGB + loader: + batch_size: 16 + shuffle: true + num_workers: 6 + collate_fn: '' + validate: + dataset: + args: + data_path: + - ./datasets/test.txt + pre_processes: + - type: ResizeShortSize + args: + short_size: 736 + resize_text_polys: false + img_mode: RGB + loader: + batch_size: 1 + shuffle: true + num_workers: 6 + collate_fn: ICDARCollectFN diff --git a/benchmark/PaddleOCR_DBNet/config/open_dataset.yaml b/benchmark/PaddleOCR_DBNet/config/open_dataset.yaml new file mode 100644 index 0000000..9726758 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/config/open_dataset.yaml @@ -0,0 +1,73 @@ +name: DBNet +dataset: + train: + dataset: + type: DetDataset # 数据集类型 + args: + data_path: # 一个存放 img_path \t gt_path的文件 + - '' + pre_processes: # 数据的预处理过程,包含augment和标签制作 + - type: IaaAugment # 使用imgaug进行变换 + args: + - {'type':Fliplr, 'args':{'p':0.5}} + - {'type': Affine, 'args':{'rotate':[-10,10]}} + - {'type':Resize,'args':{'size':[0.5,3]}} + - type: EastRandomCropData + args: + size: [640,640] + max_tries: 50 + keep_ratio: true + - type: MakeBorderMap + args: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - type: MakeShrinkMap + args: + shrink_ratio: 0.4 + min_text_size: 8 + transforms: # 对图片进行的变换方式 + - type: ToTensor + args: {} + - type: Normalize + args: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + img_mode: RGB + load_char_annotation: false + expand_one_char: false + filter_keys: [img_path,img_name,text_polys,texts,ignore_tags,shape] # 返回数据之前,从数据字典里删除的key + ignore_tags: ['*', '###'] + loader: + batch_size: 1 + shuffle: true + num_workers: 0 + collate_fn: '' + validate: + dataset: + type: DetDataset + args: + data_path: + - '' + pre_processes: + - type: ResizeShortSize + args: + short_size: 736 + resize_text_polys: false + transforms: + - type: ToTensor + args: {} + - type: Normalize + args: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + img_mode: RGB + 
load_char_annotation: false # 是否加载字符级标注 + expand_one_char: false # 是否对只有一个字符的框进行宽度扩充,扩充后w = w+h + filter_keys: [] + ignore_tags: ['*', '###'] + loader: + batch_size: 1 + shuffle: true + num_workers: 0 + collate_fn: ICDARCollectFN \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/config/open_dataset_dcn_resnet50_FPN_DBhead_polyLR.yaml b/benchmark/PaddleOCR_DBNet/config/open_dataset_dcn_resnet50_FPN_DBhead_polyLR.yaml new file mode 100644 index 0000000..6c81738 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/config/open_dataset_dcn_resnet50_FPN_DBhead_polyLR.yaml @@ -0,0 +1,86 @@ +name: DBNet +base: ['config/open_dataset.yaml'] +arch: + type: Model + backbone: + type: deformable_resnet18 + pretrained: true + neck: + type: FPN + inner_channels: 256 + head: + type: DBHead + out_channels: 2 + k: 50 +post_processing: + type: SegDetectorRepresenter + args: + thresh: 0.3 + box_thresh: 0.7 + max_candidates: 1000 + unclip_ratio: 1.5 # from paper +metric: + type: QuadMetric + args: + is_output_polygon: false +loss: + type: DBLoss + alpha: 1 + beta: 10 + ohem_ratio: 3 +optimizer: + type: Adam + args: + lr: 0.001 + weight_decay: 0 + amsgrad: true +lr_scheduler: + type: WarmupPolyLR + args: + warmup_epoch: 3 +trainer: + seed: 2 + epochs: 1200 + log_iter: 1 + show_images_iter: 1 + resume_checkpoint: '' + finetune_checkpoint: '' + output_dir: output + visual_dl: false +amp: + scale_loss: 1024 + amp_level: O2 + custom_white_list: [] + custom_black_list: ['exp', 'sigmoid', 'concat'] +dataset: + train: + dataset: + args: + data_path: + - ./datasets/train.json + img_mode: RGB + load_char_annotation: false + expand_one_char: false + loader: + batch_size: 2 + shuffle: true + num_workers: 6 + collate_fn: '' + validate: + dataset: + args: + data_path: + - ./datasets/test.json + pre_processes: + - type: ResizeShortSize + args: + short_size: 736 + resize_text_polys: false + img_mode: RGB + load_char_annotation: false + expand_one_char: false + loader: + batch_size: 1 + shuffle: true + num_workers: 6 + collate_fn: ICDARCollectFN diff --git a/benchmark/PaddleOCR_DBNet/config/open_dataset_resnest50_FPN_DBhead_polyLR.yaml b/benchmark/PaddleOCR_DBNet/config/open_dataset_resnest50_FPN_DBhead_polyLR.yaml new file mode 100644 index 0000000..d47ab06 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/config/open_dataset_resnest50_FPN_DBhead_polyLR.yaml @@ -0,0 +1,86 @@ +name: DBNet +base: ['config/open_dataset.yaml'] +arch: + type: Model + backbone: + type: resnest50 + pretrained: true + neck: + type: FPN + inner_channels: 256 + head: + type: DBHead + out_channels: 2 + k: 50 +post_processing: + type: SegDetectorRepresenter + args: + thresh: 0.3 + box_thresh: 0.7 + max_candidates: 1000 + unclip_ratio: 1.5 # from paper +metric: + type: QuadMetric + args: + is_output_polygon: false +loss: + type: DBLoss + alpha: 1 + beta: 10 + ohem_ratio: 3 +optimizer: + type: Adam + args: + lr: 0.001 + weight_decay: 0 + amsgrad: true +lr_scheduler: + type: WarmupPolyLR + args: + warmup_epoch: 3 +trainer: + seed: 2 + epochs: 1200 + log_iter: 1 + show_images_iter: 1 + resume_checkpoint: '' + finetune_checkpoint: '' + output_dir: output + visual_dl: false +amp: + scale_loss: 1024 + amp_level: O2 + custom_white_list: [] + custom_black_list: ['exp', 'sigmoid', 'concat'] +dataset: + train: + dataset: + args: + data_path: + - ./datasets/train.json + img_mode: RGB + load_char_annotation: false + expand_one_char: false + loader: + batch_size: 2 + shuffle: true + num_workers: 6 + collate_fn: '' + validate: + dataset: + args: + data_path: + - 
./datasets/test.json + pre_processes: + - type: ResizeShortSize + args: + short_size: 736 + resize_text_polys: false + img_mode: RGB + load_char_annotation: false + expand_one_char: false + loader: + batch_size: 1 + shuffle: true + num_workers: 6 + collate_fn: ICDARCollectFN diff --git a/benchmark/PaddleOCR_DBNet/config/open_dataset_resnet18_FPN_DBhead_polyLR.yaml b/benchmark/PaddleOCR_DBNet/config/open_dataset_resnet18_FPN_DBhead_polyLR.yaml new file mode 100644 index 0000000..ff16ddb --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/config/open_dataset_resnet18_FPN_DBhead_polyLR.yaml @@ -0,0 +1,93 @@ +name: DBNet +base: ['config/open_dataset.yaml'] +arch: + type: Model + backbone: + type: resnet18 + pretrained: true + neck: + type: FPN + inner_channels: 256 + head: + type: DBHead + out_channels: 2 + k: 50 +post_processing: + type: SegDetectorRepresenter + args: + thresh: 0.3 + box_thresh: 0.7 + max_candidates: 1000 + unclip_ratio: 1.5 # from paper +metric: + type: QuadMetric + args: + is_output_polygon: false +loss: + type: DBLoss + alpha: 1 + beta: 10 + ohem_ratio: 3 +optimizer: + type: Adam + args: + lr: 0.001 + weight_decay: 0 + amsgrad: true +lr_scheduler: + type: WarmupPolyLR + args: + warmup_epoch: 3 +trainer: + seed: 2 + epochs: 1200 + log_iter: 1 + show_images_iter: 1 + resume_checkpoint: '' + finetune_checkpoint: '' + output_dir: output + visual_dl: false +amp: + scale_loss: 1024 + amp_level: O2 + custom_white_list: [] + custom_black_list: ['exp', 'sigmoid', 'concat'] +dataset: + train: + dataset: + args: + data_path: + - ./datasets/train.json + transforms: # 对图片进行的变换方式 + - type: ToTensor + args: {} + - type: Normalize + args: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + img_mode: RGB + load_char_annotation: false + expand_one_char: false + loader: + batch_size: 2 + shuffle: true + num_workers: 6 + collate_fn: '' + validate: + dataset: + args: + data_path: + - ./datasets/test.json + pre_processes: + - type: ResizeShortSize + args: + short_size: 736 + resize_text_polys: false + img_mode: RGB + load_char_annotation: false + expand_one_char: false + loader: + batch_size: 1 + shuffle: true + num_workers: 6 + collate_fn: ICDARCollectFN diff --git a/benchmark/PaddleOCR_DBNet/data_loader/__init__.py b/benchmark/PaddleOCR_DBNet/data_loader/__init__.py new file mode 100644 index 0000000..afc6e56 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/data_loader/__init__.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/23 21:52 +# @Author : zhoujun +import copy + +import PIL +import numpy as np +import paddle +from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler + +from paddle.vision import transforms + + +def get_dataset(data_path, module_name, transform, dataset_args): + """ + 获取训练dataset + :param data_path: dataset文件列表,每个文件内以如下格式存储 ‘path/to/img\tlabel’ + :param module_name: 所使用的自定义dataset名称,目前只支持data_loaders.ImageDataset + :param transform: 该数据集使用的transforms + :param dataset_args: module_name的参数 + :return: 如果data_path列表不为空,返回对于的ConcatDataset对象,否则None + """ + from . 
import dataset + s_dataset = getattr(dataset, module_name)(transform=transform, + data_path=data_path, + **dataset_args) + return s_dataset + + +def get_transforms(transforms_config): + tr_list = [] + for item in transforms_config: + if 'args' not in item: + args = {} + else: + args = item['args'] + cls = getattr(transforms, item['type'])(**args) + tr_list.append(cls) + tr_list = transforms.Compose(tr_list) + return tr_list + + +class ICDARCollectFN: + def __init__(self, *args, **kwargs): + pass + + def __call__(self, batch): + data_dict = {} + to_tensor_keys = [] + for sample in batch: + for k, v in sample.items(): + if k not in data_dict: + data_dict[k] = [] + if isinstance(v, (np.ndarray, paddle.Tensor, PIL.Image.Image)): + if k not in to_tensor_keys: + to_tensor_keys.append(k) + data_dict[k].append(v) + for k in to_tensor_keys: + data_dict[k] = paddle.stack(data_dict[k], 0) + return data_dict + + +def get_dataloader(module_config, distributed=False): + if module_config is None: + return None + config = copy.deepcopy(module_config) + dataset_args = config['dataset']['args'] + if 'transforms' in dataset_args: + img_transfroms = get_transforms(dataset_args.pop('transforms')) + else: + img_transfroms = None + # 创建数据集 + dataset_name = config['dataset']['type'] + data_path = dataset_args.pop('data_path') + if data_path == None: + return None + + data_path = [x for x in data_path if x is not None] + if len(data_path) == 0: + return None + if 'collate_fn' not in config['loader'] or config['loader'][ + 'collate_fn'] is None or len(config['loader']['collate_fn']) == 0: + config['loader']['collate_fn'] = None + else: + config['loader']['collate_fn'] = eval(config['loader']['collate_fn'])() + + _dataset = get_dataset( + data_path=data_path, + module_name=dataset_name, + transform=img_transfroms, + dataset_args=dataset_args) + sampler = None + if distributed: + # 3)使用DistributedSampler + batch_sampler = DistributedBatchSampler( + dataset=_dataset, + batch_size=config['loader'].pop('batch_size'), + shuffle=config['loader'].pop('shuffle')) + else: + batch_sampler = BatchSampler( + dataset=_dataset, + batch_size=config['loader'].pop('batch_size'), + shuffle=config['loader'].pop('shuffle')) + loader = DataLoader( + dataset=_dataset, batch_sampler=batch_sampler, **config['loader']) + return loader diff --git a/benchmark/PaddleOCR_DBNet/data_loader/dataset.py b/benchmark/PaddleOCR_DBNet/data_loader/dataset.py new file mode 100644 index 0000000..29d3954 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/data_loader/dataset.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/23 21:54 +# @Author : zhoujun +import pathlib +import os +import cv2 +import numpy as np +import scipy.io as sio +from tqdm.auto import tqdm + +from base import BaseDataSet +from utils import order_points_clockwise, get_datalist, load, expand_polygon + + +class ICDAR2015Dataset(BaseDataSet): + def __init__(self, + data_path: str, + img_mode, + pre_processes, + filter_keys, + ignore_tags, + transform=None, + **kwargs): + super().__init__(data_path, img_mode, pre_processes, filter_keys, + ignore_tags, transform) + + def load_data(self, data_path: str) -> list: + data_list = get_datalist(data_path) + t_data_list = [] + for img_path, label_path in data_list: + data = self._get_annotation(label_path) + if len(data['text_polys']) > 0: + item = { + 'img_path': img_path, + 'img_name': pathlib.Path(img_path).stem + } + item.update(data) + t_data_list.append(item) + else: + print('there is no suit bbox in {}'.format(label_path)) + return 
t_data_list + + def _get_annotation(self, label_path: str) -> dict: + boxes = [] + texts = [] + ignores = [] + with open(label_path, encoding='utf-8', mode='r') as f: + for line in f.readlines(): + params = line.strip().strip('\ufeff').strip( + '\xef\xbb\xbf').split(',') + try: + box = order_points_clockwise( + np.array(list(map(float, params[:8]))).reshape(-1, 2)) + if cv2.contourArea(box) > 0: + boxes.append(box) + label = params[8] + texts.append(label) + ignores.append(label in self.ignore_tags) + except: + print('load label failed on {}'.format(label_path)) + data = { + 'text_polys': np.array(boxes), + 'texts': texts, + 'ignore_tags': ignores, + } + return data + + +class DetDataset(BaseDataSet): + def __init__(self, + data_path: str, + img_mode, + pre_processes, + filter_keys, + ignore_tags, + transform=None, + **kwargs): + self.load_char_annotation = kwargs['load_char_annotation'] + self.expand_one_char = kwargs['expand_one_char'] + super().__init__(data_path, img_mode, pre_processes, filter_keys, + ignore_tags, transform) + + def load_data(self, data_path: str) -> list: + """ + 从json文件中读取出 文本行的坐标和gt,字符的坐标和gt + :param data_path: + :return: + """ + data_list = [] + for path in data_path: + content = load(path) + for gt in tqdm( + content['data_list'], desc='read file {}'.format(path)): + img_path = os.path.join(content['data_root'], gt['img_name']) + polygons = [] + texts = [] + illegibility_list = [] + language_list = [] + for annotation in gt['annotations']: + if len(annotation['polygon']) == 0 or len(annotation[ + 'text']) == 0: + continue + if len(annotation['text']) > 1 and self.expand_one_char: + annotation['polygon'] = expand_polygon(annotation[ + 'polygon']) + polygons.append(annotation['polygon']) + texts.append(annotation['text']) + illegibility_list.append(annotation['illegibility']) + language_list.append(annotation['language']) + if self.load_char_annotation: + for char_annotation in annotation['chars']: + if len(char_annotation['polygon']) == 0 or len( + char_annotation['char']) == 0: + continue + polygons.append(char_annotation['polygon']) + texts.append(char_annotation['char']) + illegibility_list.append(char_annotation[ + 'illegibility']) + language_list.append(char_annotation['language']) + data_list.append({ + 'img_path': img_path, + 'img_name': gt['img_name'], + 'text_polys': np.array(polygons), + 'texts': texts, + 'ignore_tags': illegibility_list + }) + return data_list + + +class SynthTextDataset(BaseDataSet): + def __init__(self, + data_path: str, + img_mode, + pre_processes, + filter_keys, + transform=None, + **kwargs): + self.transform = transform + self.dataRoot = pathlib.Path(data_path) + if not self.dataRoot.exists(): + raise FileNotFoundError('Dataset folder is not exist.') + + self.targetFilePath = self.dataRoot / 'gt.mat' + if not self.targetFilePath.exists(): + raise FileExistsError('Target file is not exist.') + targets = {} + sio.loadmat( + self.targetFilePath, + targets, + squeeze_me=True, + struct_as_record=False, + variable_names=['imnames', 'wordBB', 'txt']) + + self.imageNames = targets['imnames'] + self.wordBBoxes = targets['wordBB'] + self.transcripts = targets['txt'] + super().__init__(data_path, img_mode, pre_processes, filter_keys, + transform) + + def load_data(self, data_path: str) -> list: + t_data_list = [] + for imageName, wordBBoxes, texts in zip( + self.imageNames, self.wordBBoxes, self.transcripts): + item = {} + wordBBoxes = np.expand_dims( + wordBBoxes, axis=2) if (wordBBoxes.ndim == 2) else wordBBoxes + _, _, numOfWords = 
wordBBoxes.shape + text_polys = wordBBoxes.reshape( + [8, numOfWords], order='F').T # num_words * 8 + text_polys = text_polys.reshape(numOfWords, 4, + 2) # num_of_words * 4 * 2 + transcripts = [word for line in texts for word in line.split()] + if numOfWords != len(transcripts): + continue + item['img_path'] = str(self.dataRoot / imageName) + item['img_name'] = (self.dataRoot / imageName).stem + item['text_polys'] = text_polys + item['texts'] = transcripts + item['ignore_tags'] = [x in self.ignore_tags for x in transcripts] + t_data_list.append(item) + return t_data_list diff --git a/benchmark/PaddleOCR_DBNet/data_loader/modules/__init__.py b/benchmark/PaddleOCR_DBNet/data_loader/modules/__init__.py new file mode 100644 index 0000000..bc055da --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/data_loader/modules/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/4 10:53 +# @Author : zhoujun +from .iaa_augment import IaaAugment +from .augment import * +from .random_crop_data import EastRandomCropData, PSERandomCrop +from .make_border_map import MakeBorderMap +from .make_shrink_map import MakeShrinkMap diff --git a/benchmark/PaddleOCR_DBNet/data_loader/modules/augment.py b/benchmark/PaddleOCR_DBNet/data_loader/modules/augment.py new file mode 100644 index 0000000..e81bc12 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/data_loader/modules/augment.py @@ -0,0 +1,304 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/23 21:52 +# @Author : zhoujun + +import math +import numbers +import random + +import cv2 +import numpy as np +from skimage.util import random_noise + + +class RandomNoise: + def __init__(self, random_rate): + self.random_rate = random_rate + + def __call__(self, data: dict): + """ + 对图片加噪声 + :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} + :return: + """ + if random.random() > self.random_rate: + return data + data['img'] = (random_noise( + data['img'], mode='gaussian', clip=True) * 255).astype(im.dtype) + return data + + +class RandomScale: + def __init__(self, scales, random_rate): + """ + :param scales: 尺度 + :param ramdon_rate: 随机系数 + :return: + """ + self.random_rate = random_rate + self.scales = scales + + def __call__(self, data: dict) -> dict: + """ + 从scales中随机选择一个尺度,对图片和文本框进行缩放 + :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} + :return: + """ + if random.random() > self.random_rate: + return data + im = data['img'] + text_polys = data['text_polys'] + + tmp_text_polys = text_polys.copy() + rd_scale = float(np.random.choice(self.scales)) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + tmp_text_polys *= rd_scale + + data['img'] = im + data['text_polys'] = tmp_text_polys + return data + + +class RandomRotateImgBox: + def __init__(self, degrees, random_rate, same_size=False): + """ + :param degrees: 角度,可以是一个数值或者list + :param ramdon_rate: 随机系数 + :param same_size: 是否保持和原图一样大 + :return: + """ + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError( + "If degrees is a single number, it must be positive.") + degrees = (-degrees, degrees) + elif isinstance(degrees, list) or isinstance( + degrees, tuple) or isinstance(degrees, np.ndarray): + if len(degrees) != 2: + raise ValueError( + "If degrees is a sequence, it must be of len 2.") + degrees = degrees + else: + raise Exception( + 'degrees must in Number or list or tuple or np.ndarray') + self.degrees = degrees + self.same_size = same_size + self.random_rate = random_rate + + def __call__(self, data: dict) -> dict: + """ + 从scales中随机选择一个尺度,对图片和文本框进行缩放 + :param 
data: {'img':,'text_polys':,'texts':,'ignore_tags':} + :return: + """ + if random.random() > self.random_rate: + return data + im = data['img'] + text_polys = data['text_polys'] + + # ---------------------- 旋转图像 ---------------------- + w = im.shape[1] + h = im.shape[0] + angle = np.random.uniform(self.degrees[0], self.degrees[1]) + + if self.same_size: + nw = w + nh = h + else: + # 角度变弧度 + rangle = np.deg2rad(angle) + # 计算旋转之后图像的w, h + nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) + nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) + # 构造仿射矩阵 + rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1) + # 计算原图中心点到新图中心点的偏移量 + rot_move = np.dot(rot_mat, + np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0])) + # 更新仿射矩阵 + rot_mat[0, 2] += rot_move[0] + rot_mat[1, 2] += rot_move[1] + # 仿射变换 + rot_img = cv2.warpAffine( + im, + rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), + flags=cv2.INTER_LANCZOS4) + + # ---------------------- 矫正bbox坐标 ---------------------- + # rot_mat是最终的旋转矩阵 + # 获取原始bbox的四个中点,然后将这四个点转换到旋转后的坐标系下 + rot_text_polys = list() + for bbox in text_polys: + point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1])) + point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1])) + point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1])) + point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1])) + rot_text_polys.append([point1, point2, point3, point4]) + data['img'] = rot_img + data['text_polys'] = np.array(rot_text_polys) + return data + + +class RandomResize: + def __init__(self, size, random_rate, keep_ratio=False): + """ + :param input_size: resize尺寸,数字或者list的形式,如果为list形式,就是[w,h] + :param ramdon_rate: 随机系数 + :param keep_ratio: 是否保持长宽比 + :return: + """ + if isinstance(size, numbers.Number): + if size < 0: + raise ValueError( + "If input_size is a single number, it must be positive.") + size = (size, size) + elif isinstance(size, list) or isinstance(size, tuple) or isinstance( + size, np.ndarray): + if len(size) != 2: + raise ValueError( + "If input_size is a sequence, it must be of len 2.") + size = (size[0], size[1]) + else: + raise Exception( + 'input_size must in Number or list or tuple or np.ndarray') + self.size = size + self.keep_ratio = keep_ratio + self.random_rate = random_rate + + def __call__(self, data: dict) -> dict: + """ + 从scales中随机选择一个尺度,对图片和文本框进行缩放 + :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} + :return: + """ + if random.random() > self.random_rate: + return data + im = data['img'] + text_polys = data['text_polys'] + + if self.keep_ratio: + # 将图片短边pad到和长边一样 + h, w, c = im.shape + max_h = max(h, self.size[0]) + max_w = max(w, self.size[1]) + im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8) + im_padded[:h, :w] = im.copy() + im = im_padded + text_polys = text_polys.astype(np.float32) + h, w, _ = im.shape + im = cv2.resize(im, self.size) + w_scale = self.size[0] / float(w) + h_scale = self.size[1] / float(h) + text_polys[:, :, 0] *= w_scale + text_polys[:, :, 1] *= h_scale + + data['img'] = im + data['text_polys'] = text_polys + return data + + +def resize_image(img, short_size): + height, width, _ = img.shape + if height < width: + new_height = short_size + new_width = new_height / height * width + else: + new_width = short_size + new_height = new_width / width * height + new_height = int(round(new_height / 32) * 32) + new_width = int(round(new_width / 32) * 32) + resized_img = cv2.resize(img, (new_width, new_height)) + return resized_img, (new_width / width, new_height / height) + + +class 
ResizeShortSize: + def __init__(self, short_size, resize_text_polys=True): + """ + :param size: resize尺寸,数字或者list的形式,如果为list形式,就是[w,h] + :return: + """ + self.short_size = short_size + self.resize_text_polys = resize_text_polys + + def __call__(self, data: dict) -> dict: + """ + 对图片和文本框进行缩放 + :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} + :return: + """ + im = data['img'] + text_polys = data['text_polys'] + + h, w, _ = im.shape + short_edge = min(h, w) + if short_edge < self.short_size: + # 保证短边 >= short_size + scale = self.short_size / short_edge + im = cv2.resize(im, dsize=None, fx=scale, fy=scale) + scale = (scale, scale) + # im, scale = resize_image(im, self.short_size) + if self.resize_text_polys: + # text_polys *= scale + text_polys[:, 0] *= scale[0] + text_polys[:, 1] *= scale[1] + + data['img'] = im + data['text_polys'] = text_polys + return data + + +class HorizontalFlip: + def __init__(self, random_rate): + """ + + :param random_rate: 随机系数 + """ + self.random_rate = random_rate + + def __call__(self, data: dict) -> dict: + """ + 从scales中随机选择一个尺度,对图片和文本框进行缩放 + :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} + :return: + """ + if random.random() > self.random_rate: + return data + im = data['img'] + text_polys = data['text_polys'] + + flip_text_polys = text_polys.copy() + flip_im = cv2.flip(im, 1) + h, w, _ = flip_im.shape + flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0] + + data['img'] = flip_im + data['text_polys'] = flip_text_polys + return data + + +class VerticallFlip: + def __init__(self, random_rate): + """ + + :param random_rate: 随机系数 + """ + self.random_rate = random_rate + + def __call__(self, data: dict) -> dict: + """ + 从scales中随机选择一个尺度,对图片和文本框进行缩放 + :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} + :return: + """ + if random.random() > self.random_rate: + return data + im = data['img'] + text_polys = data['text_polys'] + + flip_text_polys = text_polys.copy() + flip_im = cv2.flip(im, 0) + h, w, _ = flip_im.shape + flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1] + data['img'] = flip_im + data['text_polys'] = flip_text_polys + return data diff --git a/benchmark/PaddleOCR_DBNet/data_loader/modules/iaa_augment.py b/benchmark/PaddleOCR_DBNet/data_loader/modules/iaa_augment.py new file mode 100644 index 0000000..1cf891b --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/data_loader/modules/iaa_augment.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/4 18:06 +# @Author : zhoujun +import numpy as np +import imgaug +import imgaug.augmenters as iaa + + +class AugmenterBuilder(object): + def __init__(self): + pass + + def build(self, args, root=True): + if args is None or len(args) == 0: + return None + elif isinstance(args, list): + if root: + sequence = [self.build(value, root=False) for value in args] + return iaa.Sequential(sequence) + else: + return getattr( + iaa, + args[0])(* [self.to_tuple_if_list(a) for a in args[1:]]) + elif isinstance(args, dict): + cls = getattr(iaa, args['type']) + return cls(**{ + k: self.to_tuple_if_list(v) + for k, v in args['args'].items() + }) + else: + raise RuntimeError('unknown augmenter arg: ' + str(args)) + + def to_tuple_if_list(self, obj): + if isinstance(obj, list): + return tuple(obj) + return obj + + +class IaaAugment(): + def __init__(self, augmenter_args): + self.augmenter_args = augmenter_args + self.augmenter = AugmenterBuilder().build(self.augmenter_args) + + def __call__(self, data): + image = data['img'] + shape = image.shape + + if self.augmenter: + aug = 
self.augmenter.to_deterministic() + data['img'] = aug.augment_image(image) + data = self.may_augment_annotation(aug, data, shape) + return data + + def may_augment_annotation(self, aug, data, shape): + if aug is None: + return data + + line_polys = [] + for poly in data['text_polys']: + new_poly = self.may_augment_poly(aug, shape, poly) + line_polys.append(new_poly) + data['text_polys'] = np.array(line_polys) + return data + + def may_augment_poly(self, aug, img_shape, poly): + keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] + keypoints = aug.augment_keypoints( + [imgaug.KeypointsOnImage( + keypoints, shape=img_shape)])[0].keypoints + poly = [(p.x, p.y) for p in keypoints] + return poly diff --git a/benchmark/PaddleOCR_DBNet/data_loader/modules/make_border_map.py b/benchmark/PaddleOCR_DBNet/data_loader/modules/make_border_map.py new file mode 100644 index 0000000..2985f3c --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/data_loader/modules/make_border_map.py @@ -0,0 +1,143 @@ +import cv2 +import numpy as np +np.seterr(divide='ignore', invalid='ignore') +import pyclipper +from shapely.geometry import Polygon + + +class MakeBorderMap(): + def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7): + self.shrink_ratio = shrink_ratio + self.thresh_min = thresh_min + self.thresh_max = thresh_max + + def __call__(self, data: dict) -> dict: + """ + 从scales中随机选择一个尺度,对图片和文本框进行缩放 + :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} + :return: + """ + im = data['img'] + text_polys = data['text_polys'] + ignore_tags = data['ignore_tags'] + + canvas = np.zeros(im.shape[:2], dtype=np.float32) + mask = np.zeros(im.shape[:2], dtype=np.float32) + + for i in range(len(text_polys)): + if ignore_tags[i]: + continue + self.draw_border_map(text_polys[i], canvas, mask=mask) + canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min + + data['threshold_map'] = canvas + data['threshold_mask'] = mask + return data + + def draw_border_map(self, polygon, canvas, mask): + polygon = np.array(polygon) + assert polygon.ndim == 2 + assert polygon.shape[1] == 2 + + polygon_shape = Polygon(polygon) + if polygon_shape.area <= 0: + return + distance = polygon_shape.area * ( + 1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length + subject = [tuple(l) for l in polygon] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + + padded_polygon = np.array(padding.Execute(distance)[0]) + cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) + + xmin = padded_polygon[:, 0].min() + xmax = padded_polygon[:, 0].max() + ymin = padded_polygon[:, 1].min() + ymax = padded_polygon[:, 1].max() + width = xmax - xmin + 1 + height = ymax - ymin + 1 + + polygon[:, 0] = polygon[:, 0] - xmin + polygon[:, 1] = polygon[:, 1] - ymin + + xs = np.broadcast_to( + np.linspace( + 0, width - 1, num=width).reshape(1, width), (height, width)) + ys = np.broadcast_to( + np.linspace( + 0, height - 1, num=height).reshape(height, 1), (height, width)) + + distance_map = np.zeros( + (polygon.shape[0], height, width), dtype=np.float32) + for i in range(polygon.shape[0]): + j = (i + 1) % polygon.shape[0] + absolute_distance = self.distance(xs, ys, polygon[i], polygon[j]) + distance_map[i] = np.clip(absolute_distance / distance, 0, 1) + distance_map = distance_map.min(axis=0) + + xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) + xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) + ymin_valid = min(max(0, ymin), canvas.shape[0] - 1) + ymax_valid = 
min(max(0, ymax), canvas.shape[0] - 1) + canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax( + 1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height, + xmin_valid - xmin:xmax_valid - xmax + width], + canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1]) + + def distance(self, xs, ys, point_1, point_2): + ''' + compute the distance from point to a line + ys: coordinates in the first axis + xs: coordinates in the second axis + point_1, point_2: (x, y), the end of the line + ''' + height, width = xs.shape[:2] + square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[ + 1]) + square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[ + 1]) + square_distance = np.square(point_1[0] - point_2[0]) + np.square( + point_1[1] - point_2[1]) + + cosin = (square_distance - square_distance_1 - square_distance_2) / ( + 2 * np.sqrt(square_distance_1 * square_distance_2)) + square_sin = 1 - np.square(cosin) + square_sin = np.nan_to_num(square_sin) + + result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / + square_distance) + result[cosin < + 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin + < 0] + # self.extend_line(point_1, point_2, result) + return result + + def extend_line(self, point_1, point_2, result): + ex_point_1 = (int( + round(point_1[0] + (point_1[0] - point_2[0]) * ( + 1 + self.shrink_ratio))), int( + round(point_1[1] + (point_1[1] - point_2[1]) * ( + 1 + self.shrink_ratio)))) + cv2.line( + result, + tuple(ex_point_1), + tuple(point_1), + 4096.0, + 1, + lineType=cv2.LINE_AA, + shift=0) + ex_point_2 = (int( + round(point_2[0] + (point_2[0] - point_1[0]) * ( + 1 + self.shrink_ratio))), int( + round(point_2[1] + (point_2[1] - point_1[1]) * ( + 1 + self.shrink_ratio)))) + cv2.line( + result, + tuple(ex_point_2), + tuple(point_2), + 4096.0, + 1, + lineType=cv2.LINE_AA, + shift=0) + return ex_point_1, ex_point_2 diff --git a/benchmark/PaddleOCR_DBNet/data_loader/modules/make_shrink_map.py b/benchmark/PaddleOCR_DBNet/data_loader/modules/make_shrink_map.py new file mode 100644 index 0000000..3f268b9 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/data_loader/modules/make_shrink_map.py @@ -0,0 +1,133 @@ +import numpy as np +import cv2 + + +def shrink_polygon_py(polygon, shrink_ratio): + """ + 对框进行缩放,返回去的比例为1/shrink_ratio 即可 + """ + cx = polygon[:, 0].mean() + cy = polygon[:, 1].mean() + polygon[:, 0] = cx + (polygon[:, 0] - cx) * shrink_ratio + polygon[:, 1] = cy + (polygon[:, 1] - cy) * shrink_ratio + return polygon + + +def shrink_polygon_pyclipper(polygon, shrink_ratio): + from shapely.geometry import Polygon + import pyclipper + polygon_shape = Polygon(polygon) + distance = polygon_shape.area * ( + 1 - np.power(shrink_ratio, 2)) / polygon_shape.length + subject = [tuple(l) for l in polygon] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + shrinked = padding.Execute(-distance) + if shrinked == []: + shrinked = np.array(shrinked) + else: + shrinked = np.array(shrinked[0]).reshape(-1, 2) + return shrinked + + +class MakeShrinkMap(): + r''' + Making binary mask from detection data with ICDAR format. + Typically following the process of class `MakeICDARData`. 
+ ''' + + def __init__(self, + min_text_size=8, + shrink_ratio=0.4, + shrink_type='pyclipper'): + shrink_func_dict = { + 'py': shrink_polygon_py, + 'pyclipper': shrink_polygon_pyclipper + } + self.shrink_func = shrink_func_dict[shrink_type] + self.min_text_size = min_text_size + self.shrink_ratio = shrink_ratio + + def __call__(self, data: dict) -> dict: + """ + 从scales中随机选择一个尺度,对图片和文本框进行缩放 + :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} + :return: + """ + image = data['img'] + text_polys = data['text_polys'] + ignore_tags = data['ignore_tags'] + + h, w = image.shape[:2] + text_polys, ignore_tags = self.validate_polygons(text_polys, + ignore_tags, h, w) + gt = np.zeros((h, w), dtype=np.float32) + mask = np.ones((h, w), dtype=np.float32) + for i in range(len(text_polys)): + polygon = text_polys[i] + height = max(polygon[:, 1]) - min(polygon[:, 1]) + width = max(polygon[:, 0]) - min(polygon[:, 0]) + if ignore_tags[i] or min(height, width) < self.min_text_size: + cv2.fillPoly(mask, + polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + else: + shrinked = self.shrink_func(polygon, self.shrink_ratio) + if shrinked.size == 0: + cv2.fillPoly(mask, + polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + continue + cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1) + + data['shrink_map'] = gt + data['shrink_mask'] = mask + return data + + def validate_polygons(self, polygons, ignore_tags, h, w): + ''' + polygons (numpy.array, required): of shape (num_instances, num_points, 2) + ''' + if len(polygons) == 0: + return polygons, ignore_tags + assert len(polygons) == len(ignore_tags) + for polygon in polygons: + polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1) + polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1) + + for i in range(len(polygons)): + area = self.polygon_area(polygons[i]) + if abs(area) < 1: + ignore_tags[i] = True + if area > 0: + polygons[i] = polygons[i][::-1, :] + return polygons, ignore_tags + + def polygon_area(self, polygon): + return cv2.contourArea(polygon) + # edge = 0 + # for i in range(polygon.shape[0]): + # next_index = (i + 1) % polygon.shape[0] + # edge += (polygon[next_index, 0] - polygon[i, 0]) * (polygon[next_index, 1] - polygon[i, 1]) + # + # return edge / 2. 
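+
+# --- Illustrative sketch (not part of the original code) -----------------------
+# Both the shrink map (this module) and the border map (make_border_map.py)
+# derive the pyclipper offset from the polygon geometry as
+#     D = A * (1 - r^2) / L
+# where A is the polygon area, L its perimeter and r the shrink_ratio.
+# Executing the offset with -D shrinks the text region (probability map), while
+# +D dilates it (threshold / border map).  A minimal helper showing the shared
+# idea, kept here only for reference:
+def _offset_polygon_demo(polygon, shrink_ratio, sign=-1):
+    from shapely.geometry import Polygon
+    import pyclipper
+    poly = Polygon(polygon)
+    distance = poly.area * (1 - np.power(shrink_ratio, 2)) / poly.length
+    offset = pyclipper.PyclipperOffset()
+    offset.AddPath(
+        [tuple(p) for p in polygon], pyclipper.JT_ROUND,
+        pyclipper.ET_CLOSEDPOLYGON)
+    # sign=-1 -> shrink (as in shrink_polygon_pyclipper), sign=+1 -> dilate
+    return offset.Execute(sign * distance)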
+ + +if __name__ == '__main__': + from shapely.geometry import Polygon + import pyclipper + + polygon = np.array([[0, 0], [100, 10], [100, 100], [10, 90]]) + a = shrink_polygon_py(polygon, 0.4) + print(a) + print(shrink_polygon_py(a, 1 / 0.4)) + b = shrink_polygon_pyclipper(polygon, 0.4) + print(b) + poly = Polygon(b) + distance = poly.area * 1.5 / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(b, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + bounding_box = cv2.minAreaRect(expanded) + points = cv2.boxPoints(bounding_box) + print(points) diff --git a/benchmark/PaddleOCR_DBNet/data_loader/modules/random_crop_data.py b/benchmark/PaddleOCR_DBNet/data_loader/modules/random_crop_data.py new file mode 100644 index 0000000..fac2e4c --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/data_loader/modules/random_crop_data.py @@ -0,0 +1,206 @@ +import random + +import cv2 +import numpy as np + + +# random crop algorithm similar to https://github.com/argman/EAST +class EastRandomCropData(): + def __init__(self, + size=(640, 640), + max_tries=50, + min_crop_side_ratio=0.1, + require_original_image=False, + keep_ratio=True): + self.size = size + self.max_tries = max_tries + self.min_crop_side_ratio = min_crop_side_ratio + self.require_original_image = require_original_image + self.keep_ratio = keep_ratio + + def __call__(self, data: dict) -> dict: + """ + 从scales中随机选择一个尺度,对图片和文本框进行缩放 + :param data: {'img':,'text_polys':,'texts':,'ignore_tags':} + :return: + """ + im = data['img'] + text_polys = data['text_polys'] + ignore_tags = data['ignore_tags'] + texts = data['texts'] + all_care_polys = [ + text_polys[i] for i, tag in enumerate(ignore_tags) if not tag + ] + # 计算crop区域 + crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys) + # crop 图片 保持比例填充 + scale_w = self.size[0] / crop_w + scale_h = self.size[1] / crop_h + scale = min(scale_w, scale_h) + h = int(crop_h * scale) + w = int(crop_w * scale) + if self.keep_ratio: + if len(im.shape) == 3: + padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), + im.dtype) + else: + padimg = np.zeros((self.size[1], self.size[0]), im.dtype) + padimg[:h, :w] = cv2.resize( + im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) + img = padimg + else: + img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], + tuple(self.size)) + # crop 文本框 + text_polys_crop = [] + ignore_tags_crop = [] + texts_crop = [] + for poly, text, tag in zip(text_polys, texts, ignore_tags): + poly = ((poly - (crop_x, crop_y)) * scale).tolist() + if not self.is_poly_outside_rect(poly, 0, 0, w, h): + text_polys_crop.append(poly) + ignore_tags_crop.append(tag) + texts_crop.append(text) + data['img'] = img + data['text_polys'] = np.float32(text_polys_crop) + data['ignore_tags'] = ignore_tags_crop + data['texts'] = texts_crop + return data + + def is_poly_in_rect(self, poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].min() < x or poly[:, 0].max() > x + w: + return False + if poly[:, 1].min() < y or poly[:, 1].max() > y + h: + return False + return True + + def is_poly_outside_rect(self, poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].max() < x or poly[:, 0].min() > x + w: + return True + if poly[:, 1].max() < y or poly[:, 1].min() > y + h: + return True + return False + + def split_regions(self, axis): + regions = [] + min_axis = 0 + for i in range(1, axis.shape[0]): + if axis[i] != axis[i - 1] + 1: + region = axis[min_axis:i] + min_axis = i + regions.append(region) + return 
regions + + def random_select(self, axis, max_size): + xx = np.random.choice(axis, size=2) + xmin = np.min(xx) + xmax = np.max(xx) + xmin = np.clip(xmin, 0, max_size - 1) + xmax = np.clip(xmax, 0, max_size - 1) + return xmin, xmax + + def region_wise_random_select(self, regions, max_size): + selected_index = list(np.random.choice(len(regions), 2)) + selected_values = [] + for index in selected_index: + axis = regions[index] + xx = int(np.random.choice(axis, size=1)) + selected_values.append(xx) + xmin = min(selected_values) + xmax = max(selected_values) + return xmin, xmax + + def crop_area(self, im, text_polys): + h, w = im.shape[:2] + h_array = np.zeros(h, dtype=np.int32) + w_array = np.zeros(w, dtype=np.int32) + for points in text_polys: + points = np.round(points, decimals=0).astype(np.int32) + minx = np.min(points[:, 0]) + maxx = np.max(points[:, 0]) + w_array[minx:maxx] = 1 + miny = np.min(points[:, 1]) + maxy = np.max(points[:, 1]) + h_array[miny:maxy] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + + if len(h_axis) == 0 or len(w_axis) == 0: + return 0, 0, w, h + + h_regions = self.split_regions(h_axis) + w_regions = self.split_regions(w_axis) + + for i in range(self.max_tries): + if len(w_regions) > 1: + xmin, xmax = self.region_wise_random_select(w_regions, w) + else: + xmin, xmax = self.random_select(w_axis, w) + if len(h_regions) > 1: + ymin, ymax = self.region_wise_random_select(h_regions, h) + else: + ymin, ymax = self.random_select(h_axis, h) + + if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h: + # area too small + continue + num_poly_in_rect = 0 + for poly in text_polys: + if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, + ymax - ymin): + num_poly_in_rect += 1 + break + + if num_poly_in_rect > 0: + return xmin, ymin, xmax - xmin, ymax - ymin + + return 0, 0, w, h + + +class PSERandomCrop(): + def __init__(self, size): + self.size = size + + def __call__(self, data): + imgs = data['imgs'] + + h, w = imgs[0].shape[0:2] + th, tw = self.size + if w == tw and h == th: + return imgs + + # label中存在文本实例,并且按照概率进行裁剪,使用threshold_label_map控制 + if np.max(imgs[2]) > 0 and random.random() > 3 / 8: + # 文本实例的左上角点 + tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size + tl[tl < 0] = 0 + # 文本实例的右下角点 + br = np.max(np.where(imgs[2] > 0), axis=1) - self.size + br[br < 0] = 0 + # 保证选到右下角点时,有足够的距离进行crop + br[0] = min(br[0], h - th) + br[1] = min(br[1], w - tw) + + for _ in range(50000): + i = random.randint(tl[0], br[0]) + j = random.randint(tl[1], br[1]) + # 保证shrink_label_map有文本 + if imgs[1][i:i + th, j:j + tw].sum() <= 0: + continue + else: + break + else: + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + + # return i, j, th, tw + for idx in range(len(imgs)): + if len(imgs[idx].shape) == 3: + imgs[idx] = imgs[idx][i:i + th, j:j + tw, :] + else: + imgs[idx] = imgs[idx][i:i + th, j:j + tw] + data['imgs'] = imgs + return data diff --git a/benchmark/PaddleOCR_DBNet/environment.yml b/benchmark/PaddleOCR_DBNet/environment.yml new file mode 100644 index 0000000..571dbf2 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/environment.yml @@ -0,0 +1,21 @@ +name: dbnet +channels: + - conda-forge + - defaults +dependencies: + - anyconfig==0.9.10 + - future==0.18.2 + - imgaug==0.4.0 + - matplotlib==3.1.2 + - numpy==1.17.4 + - opencv + - pyclipper + - PyYAML==5.2 + - scikit-image==0.16.2 + - Shapely==1.6.4 + - tensorboard=2 + - tqdm==4.40.1 + - ipython + - pip + 
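+  # packages under "pip:" below are installed with pip inside the conda env;
+  # a plain pip list of the core dependencies is also kept in requirement.txt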
- pip: + - polygon3 diff --git a/benchmark/PaddleOCR_DBNet/eval.sh b/benchmark/PaddleOCR_DBNet/eval.sh new file mode 100644 index 0000000..b3bf468 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/eval.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py --model_path '' \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/generate_lists.sh b/benchmark/PaddleOCR_DBNet/generate_lists.sh new file mode 100644 index 0000000..84f408c --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/generate_lists.sh @@ -0,0 +1,17 @@ +#Only use if your file names of the images and txts are identical +rm ./datasets/train_img.txt +rm ./datasets/train_gt.txt +rm ./datasets/test_img.txt +rm ./datasets/test_gt.txt +rm ./datasets/train.txt +rm ./datasets/test.txt +ls ./datasets/train/img/*.jpg > ./datasets/train_img.txt +ls ./datasets/train/gt/*.txt > ./datasets/train_gt.txt +ls ./datasets/test/img/*.jpg > ./datasets/test_img.txt +ls ./datasets/test/gt/*.txt > ./datasets/test_gt.txt +paste ./datasets/train_img.txt ./datasets/train_gt.txt > ./datasets/train.txt +paste ./datasets/test_img.txt ./datasets/test_gt.txt > ./datasets/test.txt +rm ./datasets/train_img.txt +rm ./datasets/train_gt.txt +rm ./datasets/test_img.txt +rm ./datasets/test_gt.txt diff --git a/benchmark/PaddleOCR_DBNet/imgs/paper/db.jpg b/benchmark/PaddleOCR_DBNet/imgs/paper/db.jpg new file mode 100644 index 0000000..aa6c7e9 Binary files /dev/null and b/benchmark/PaddleOCR_DBNet/imgs/paper/db.jpg differ diff --git a/benchmark/PaddleOCR_DBNet/models/__init__.py b/benchmark/PaddleOCR_DBNet/models/__init__.py new file mode 100644 index 0000000..26ff73f --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/__init__.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/23 21:55 +# @Author : zhoujun +import copy +from .model import Model +from .losses import build_loss + +__all__ = ['build_loss', 'build_model'] +support_model = ['Model'] + + +def build_model(config): + """ + get architecture model class + """ + copy_config = copy.deepcopy(config) + arch_type = copy_config.pop('type') + assert arch_type in support_model, f'{arch_type} is not developed yet!, only {support_model} are support now' + arch_model = eval(arch_type)(copy_config) + return arch_model diff --git a/benchmark/PaddleOCR_DBNet/models/backbone/__init__.py b/benchmark/PaddleOCR_DBNet/models/backbone/__init__.py new file mode 100644 index 0000000..740c8d5 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/backbone/__init__.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/23 21:54 +# @Author : zhoujun + +from .resnet import * + +__all__ = ['build_backbone'] + +support_backbone = [ + 'resnet18', 'deformable_resnet18', 'deformable_resnet50', 'resnet50', + 'resnet34', 'resnet101', 'resnet152' +] + + +def build_backbone(backbone_name, **kwargs): + assert backbone_name in support_backbone, f'all support backbone is {support_backbone}' + backbone = eval(backbone_name)(**kwargs) + return backbone diff --git a/benchmark/PaddleOCR_DBNet/models/backbone/resnet.py b/benchmark/PaddleOCR_DBNet/models/backbone/resnet.py new file mode 100644 index 0000000..9b30b38 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/backbone/resnet.py @@ -0,0 +1,375 @@ +import math +import paddle +from paddle import nn + +BatchNorm2d = nn.BatchNorm2D + +__all__ = [ + 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'deformable_resnet18', 'deformable_resnet50', 'resnet152' +] + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 
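+    # torchvision checkpoint URLs: load_models() below downloads them with
+    # torch's model_zoo and load_torch_params() remaps the keys
+    # (running_mean -> _mean, running_var -> _variance, "module." stripped)
+    # onto the Paddle state dict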
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def constant_init(module, constant, bias=0): + module.weight = paddle.create_parameter( + shape=module.weight.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Constant(constant)) + if hasattr(module, 'bias'): + module.bias = paddle.create_parameter( + shape=module.bias.shape, + dtype='float32', + default_initializer=paddle.nn.initializer.Constant(bias)) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2D( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias_attr=False) + + +class BasicBlock(nn.Layer): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): + super(BasicBlock, self).__init__() + self.with_dcn = dcn is not None + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes, momentum=0.1) + self.relu = nn.ReLU() + self.with_modulated_dcn = False + if not self.with_dcn: + self.conv2 = nn.Conv2D( + planes, planes, kernel_size=3, padding=1, bias_attr=False) + else: + from paddle.version.ops import DeformConv2D + deformable_groups = dcn.get('deformable_groups', 1) + offset_channels = 18 + self.conv2_offset = nn.Conv2D( + planes, + deformable_groups * offset_channels, + kernel_size=3, + padding=1) + self.conv2 = DeformConv2D( + planes, planes, kernel_size=3, padding=1, bias_attr=False) + self.bn2 = BatchNorm2d(planes, momentum=0.1) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + # out = self.conv2(out) + if not self.with_dcn: + out = self.conv2(out) + else: + offset = self.conv2_offset(out) + out = self.conv2(out, offset) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Layer): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): + super(Bottleneck, self).__init__() + self.with_dcn = dcn is not None + self.conv1 = nn.Conv2D(inplanes, planes, kernel_size=1, bias_attr=False) + self.bn1 = BatchNorm2d(planes, momentum=0.1) + self.with_modulated_dcn = False + if not self.with_dcn: + self.conv2 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias_attr=False) + else: + deformable_groups = dcn.get('deformable_groups', 1) + from paddle.vision.ops import DeformConv2D + offset_channels = 18 + self.conv2_offset = nn.Conv2D( + planes, + deformable_groups * offset_channels, + stride=stride, + kernel_size=3, + padding=1) + self.conv2 = DeformConv2D( + planes, + planes, + kernel_size=3, + padding=1, + stride=stride, + bias_attr=False) + self.bn2 = BatchNorm2d(planes, momentum=0.1) + self.conv3 = nn.Conv2D( + planes, planes * 4, kernel_size=1, bias_attr=False) + self.bn3 = BatchNorm2d(planes * 4, momentum=0.1) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + self.dcn = dcn + self.with_dcn = dcn is not None + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + # out = self.conv2(out) + if not self.with_dcn: + out = self.conv2(out) + 
else: + offset = self.conv2_offset(out) + out = self.conv2(out, offset) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Layer): + def __init__(self, block, layers, in_channels=3, dcn=None): + self.dcn = dcn + self.inplanes = 64 + super(ResNet, self).__init__() + self.out_channels = [] + self.conv1 = nn.Conv2D( + in_channels, + 64, + kernel_size=7, + stride=2, + padding=3, + bias_attr=False) + self.bn1 = BatchNorm2d(64, momentum=0.1) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dcn=dcn) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dcn=dcn) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dcn=dcn) + + if self.dcn is not None: + for m in self.modules(): + if isinstance(m, Bottleneck) or isinstance(m, BasicBlock): + if hasattr(m, 'conv2_offset'): + constant_init(m.conv2_offset, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dcn=None): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias_attr=False), + BatchNorm2d( + planes * block.expansion, momentum=0.1), ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, dcn=dcn)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dcn=dcn)) + self.out_channels.append(planes * block.expansion) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x2 = self.layer1(x) + x3 = self.layer2(x2) + x4 = self.layer3(x3) + x5 = self.layer4(x4) + + return x2, x3, x4, x5 + + +def load_torch_params(paddle_model, torch_patams): + paddle_params = paddle_model.state_dict() + + fc_names = ['classifier'] + for key, torch_value in torch_patams.items(): + if 'num_batches_tracked' in key: + continue + key = key.replace("running_var", "_variance").replace( + "running_mean", "_mean").replace("module.", "") + torch_value = torch_value.detach().cpu().numpy() + if key in paddle_params: + flag = [i in key for i in fc_names] + if any(flag) and "weight" in key: # ignore bias + new_shape = [1, 0] + list(range(2, torch_value.ndim)) + print( + f"name: {key}, ori shape: {torch_value.shape}, new shape: {torch_value.transpose(new_shape).shape}" + ) + torch_value = torch_value.transpose(new_shape) + paddle_params[key] = torch_value + else: + print(f'{key} not in paddle') + paddle_model.set_state_dict(paddle_params) + + +def load_models(model, model_name): + import torch.utils.model_zoo as model_zoo + torch_patams = model_zoo.load_url(model_urls[model_name]) + load_torch_params(model, torch_patams) + + +def resnet18(pretrained=True, **kwargs): + """Constructs a ResNet-18 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + assert kwargs.get( + 'in_channels', + 3) == 3, 'in_channels must be 3 whem pretrained is True' + print('load from imagenet') + load_models(model, 'resnet18') + return model + + +def deformable_resnet18(pretrained=True, **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet( + BasicBlock, [2, 2, 2, 2], dcn=dict(deformable_groups=1), **kwargs) + if pretrained: + assert kwargs.get( + 'in_channels', + 3) == 3, 'in_channels must be 3 whem pretrained is True' + print('load from imagenet') + model.load_state_dict( + model_zoo.load_url(model_urls['resnet18']), strict=False) + return model + + +def resnet34(pretrained=True, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + assert kwargs.get( + 'in_channels', + 3) == 3, 'in_channels must be 3 whem pretrained is True' + model.load_state_dict( + model_zoo.load_url(model_urls['resnet34']), strict=False) + return model + + +def resnet50(pretrained=True, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + assert kwargs.get( + 'in_channels', + 3) == 3, 'in_channels must be 3 whem pretrained is True' + load_models(model, 'resnet50') + return model + + +def deformable_resnet50(pretrained=True, **kwargs): + """Constructs a ResNet-50 model with deformable conv. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet( + Bottleneck, [3, 4, 6, 3], dcn=dict(deformable_groups=1), **kwargs) + if pretrained: + assert kwargs.get( + 'in_channels', + 3) == 3, 'in_channels must be 3 whem pretrained is True' + model.load_state_dict( + model_zoo.load_url(model_urls['resnet50']), strict=False) + return model + + +def resnet101(pretrained=True, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + assert kwargs.get( + 'in_channels', + 3) == 3, 'in_channels must be 3 whem pretrained is True' + model.load_state_dict( + model_zoo.load_url(model_urls['resnet101']), strict=False) + return model + + +def resnet152(pretrained=True, **kwargs): + """Constructs a ResNet-152 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + assert kwargs.get( + 'in_channels', + 3) == 3, 'in_channels must be 3 whem pretrained is True' + model.load_state_dict( + model_zoo.load_url(model_urls['resnet152']), strict=False) + return model + + +if __name__ == '__main__': + + x = paddle.zeros([2, 3, 640, 640]) + net = resnet50(pretrained=True) + y = net(x) + for u in y: + print(u.shape) + + print(net.out_channels) diff --git a/benchmark/PaddleOCR_DBNet/models/basic.py b/benchmark/PaddleOCR_DBNet/models/basic.py new file mode 100644 index 0000000..f661878 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/basic.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/6 11:19 +# @Author : zhoujun +from paddle import nn + + +class ConvBnRelu(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + inplace=True): + super().__init__() + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias_attr=bias, + padding_mode=padding_mode) + self.bn = nn.BatchNorm2D(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x diff --git a/benchmark/PaddleOCR_DBNet/models/head/DBHead.py b/benchmark/PaddleOCR_DBNet/models/head/DBHead.py new file mode 100644 index 0000000..29277ce --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/head/DBHead.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/4 14:54 +# @Author : zhoujun +import paddle +from paddle import nn, ParamAttr + + +class DBHead(nn.Layer): + def __init__(self, in_channels, out_channels, k=50): + super().__init__() + self.k = k + self.binarize = nn.Sequential( + nn.Conv2D( + in_channels, + in_channels // 4, + 3, + padding=1, + weight_attr=ParamAttr( + initializer=nn.initializer.KaimingNormal())), + nn.BatchNorm2D( + in_channels // 4, + weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)), + bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4))), + nn.ReLU(), + nn.Conv2DTranspose( + in_channels // 4, + in_channels // 4, + 2, + 2, + weight_attr=ParamAttr( + initializer=nn.initializer.KaimingNormal())), + nn.BatchNorm2D( + in_channels // 4, + weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)), + bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4))), + nn.ReLU(), + nn.Conv2DTranspose( + in_channels // 4, + 1, + 2, + 2, + weight_attr=nn.initializer.KaimingNormal()), + nn.Sigmoid()) + + self.thresh = self._init_thresh(in_channels) + + def forward(self, x): + shrink_maps = self.binarize(x) + threshold_maps = self.thresh(x) + if self.training: + binary_maps = self.step_function(shrink_maps, threshold_maps) + y = paddle.concat( + (shrink_maps, threshold_maps, binary_maps), axis=1) + else: + y = paddle.concat((shrink_maps, threshold_maps), axis=1) + return y + + def _init_thresh(self, + inner_channels, + serial=False, + smooth=False, + bias=False): + in_channels = inner_channels + if serial: + in_channels += 1 + self.thresh = nn.Sequential( + nn.Conv2D( + in_channels, + inner_channels // 4, + 3, + padding=1, + bias_attr=bias, + weight_attr=ParamAttr( + initializer=nn.initializer.KaimingNormal())), + nn.BatchNorm2D( + inner_channels // 4, + 
weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)), + bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4))), + nn.ReLU(), + self._init_upsample( + inner_channels // 4, + inner_channels // 4, + smooth=smooth, + bias=bias), + nn.BatchNorm2D( + inner_channels // 4, + weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)), + bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4))), + nn.ReLU(), + self._init_upsample( + inner_channels // 4, 1, smooth=smooth, bias=bias), + nn.Sigmoid()) + return self.thresh + + def _init_upsample(self, + in_channels, + out_channels, + smooth=False, + bias=False): + if smooth: + inter_out_channels = out_channels + if out_channels == 1: + inter_out_channels = in_channels + module_list = [ + nn.Upsample( + scale_factor=2, mode='nearest'), nn.Conv2D( + in_channels, + inter_out_channels, + 3, + 1, + 1, + bias_attr=bias, + weight_attr=ParamAttr( + initializer=nn.initializer.KaimingNormal())) + ] + if out_channels == 1: + module_list.append( + nn.Conv2D( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=1, + bias_attr=True, + weight_attr=ParamAttr( + initializer=nn.initializer.KaimingNormal()))) + return nn.Sequential(module_list) + else: + return nn.Conv2DTranspose( + in_channels, + out_channels, + 2, + 2, + weight_attr=ParamAttr( + initializer=nn.initializer.KaimingNormal())) + + def step_function(self, x, y): + return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y))) diff --git a/benchmark/PaddleOCR_DBNet/models/head/__init__.py b/benchmark/PaddleOCR_DBNet/models/head/__init__.py new file mode 100644 index 0000000..5610c69 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/head/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# @Time : 2020/6/5 11:35 +# @Author : zhoujun +from .DBHead import DBHead + +__all__ = ['build_head'] +support_head = ['DBHead'] + + +def build_head(head_name, **kwargs): + assert head_name in support_head, f'all support head is {support_head}' + head = eval(head_name)(**kwargs) + return head \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/models/losses/DB_loss.py b/benchmark/PaddleOCR_DBNet/models/losses/DB_loss.py new file mode 100644 index 0000000..74d240c --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/losses/DB_loss.py @@ -0,0 +1,49 @@ +import paddle +from models.losses.basic_loss import BalanceCrossEntropyLoss, MaskL1Loss, DiceLoss + + +class DBLoss(paddle.nn.Layer): + def __init__(self, + alpha=1.0, + beta=10, + ohem_ratio=3, + reduction='mean', + eps=1e-06): + """ + Implement PSE Loss. 
+ :param alpha: binary_map loss 前面的系数 + :param beta: threshold_map loss 前面的系数 + :param ohem_ratio: OHEM的比例 + :param reduction: 'mean' or 'sum'对 batch里的loss 算均值或求和 + """ + super().__init__() + assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']" + self.alpha = alpha + self.beta = beta + self.bce_loss = BalanceCrossEntropyLoss(negative_ratio=ohem_ratio) + self.dice_loss = DiceLoss(eps=eps) + self.l1_loss = MaskL1Loss(eps=eps) + self.ohem_ratio = ohem_ratio + self.reduction = reduction + + def forward(self, pred, batch): + shrink_maps = pred[:, 0, :, :] + threshold_maps = pred[:, 1, :, :] + binary_maps = pred[:, 2, :, :] + loss_shrink_maps = self.bce_loss(shrink_maps, batch['shrink_map'], + batch['shrink_mask']) + loss_threshold_maps = self.l1_loss( + threshold_maps, batch['threshold_map'], batch['threshold_mask']) + metrics = dict( + loss_shrink_maps=loss_shrink_maps, + loss_threshold_maps=loss_threshold_maps) + if pred.shape[1] > 2: + loss_binary_maps = self.dice_loss(binary_maps, batch['shrink_map'], + batch['shrink_mask']) + metrics['loss_binary_maps'] = loss_binary_maps + loss_all = (self.alpha * loss_shrink_maps + self.beta * + loss_threshold_maps + loss_binary_maps) + metrics['loss'] = loss_all + else: + metrics['loss'] = loss_shrink_maps + return metrics diff --git a/benchmark/PaddleOCR_DBNet/models/losses/__init__.py b/benchmark/PaddleOCR_DBNet/models/losses/__init__.py new file mode 100644 index 0000000..9dc0f10 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/losses/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# @Time : 2020/6/5 11:36 +# @Author : zhoujun +import copy +from .DB_loss import DBLoss + +__all__ = ['build_loss'] +support_loss = ['DBLoss'] + + +def build_loss(config): + copy_config = copy.deepcopy(config) + loss_type = copy_config.pop('type') + assert loss_type in support_loss, f'all support loss is {support_loss}' + criterion = eval(loss_type)(**copy_config) + return criterion diff --git a/benchmark/PaddleOCR_DBNet/models/losses/basic_loss.py b/benchmark/PaddleOCR_DBNet/models/losses/basic_loss.py new file mode 100644 index 0000000..8e68cb1 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/losses/basic_loss.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/4 14:39 +# @Author : zhoujun +import paddle +import paddle.nn as nn + + +class BalanceCrossEntropyLoss(nn.Layer): + ''' + Balanced cross entropy loss. + Shape: + - Input: :math:`(N, 1, H, W)` + - GT: :math:`(N, 1, H, W)`, same shape as the input + - Mask: :math:`(N, H, W)`, same spatial shape as the input + - Output: scalar. 
+ + ''' + + def __init__(self, negative_ratio=3.0, eps=1e-6): + super(BalanceCrossEntropyLoss, self).__init__() + self.negative_ratio = negative_ratio + self.eps = eps + + def forward(self, + pred: paddle.Tensor, + gt: paddle.Tensor, + mask: paddle.Tensor, + return_origin=False): + ''' + Args: + pred: shape :math:`(N, 1, H, W)`, the prediction of network + gt: shape :math:`(N, 1, H, W)`, the target + mask: shape :math:`(N, H, W)`, the mask indicates positive regions + ''' + positive = (gt * mask) + negative = ((1 - gt) * mask) + positive_count = int(positive.sum()) + negative_count = min( + int(negative.sum()), int(positive_count * self.negative_ratio)) + loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none') + positive_loss = loss * positive + negative_loss = loss * negative + negative_loss, _ = negative_loss.reshape([-1]).topk(negative_count) + + balance_loss = (positive_loss.sum() + negative_loss.sum()) / ( + positive_count + negative_count + self.eps) + + if return_origin: + return balance_loss, loss + return balance_loss + + +class DiceLoss(nn.Layer): + ''' + Loss function from https://arxiv.org/abs/1707.03237, + where iou computation is introduced heatmap manner to measure the + diversity bwtween tow heatmaps. + ''' + + def __init__(self, eps=1e-6): + super(DiceLoss, self).__init__() + self.eps = eps + + def forward(self, pred: paddle.Tensor, gt, mask, weights=None): + ''' + pred: one or two heatmaps of shape (N, 1, H, W), + the losses of tow heatmaps are added together. + gt: (N, 1, H, W) + mask: (N, H, W) + ''' + return self._compute(pred, gt, mask, weights) + + def _compute(self, pred, gt, mask, weights): + if len(pred.shape) == 4: + pred = pred[:, 0, :, :] + gt = gt[:, 0, :, :] + assert pred.shape == gt.shape + assert pred.shape == mask.shape + if weights is not None: + assert weights.shape == mask.shape + mask = weights * mask + intersection = (pred * gt * mask).sum() + + union = (pred * mask).sum() + (gt * mask).sum() + self.eps + loss = 1 - 2.0 * intersection / union + assert loss <= 1 + return loss + + +class MaskL1Loss(nn.Layer): + def __init__(self, eps=1e-6): + super(MaskL1Loss, self).__init__() + self.eps = eps + + def forward(self, pred: paddle.Tensor, gt, mask): + loss = (paddle.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps) + return loss diff --git a/benchmark/PaddleOCR_DBNet/models/model.py b/benchmark/PaddleOCR_DBNet/models/model.py new file mode 100644 index 0000000..ee24ff5 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/model.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/23 21:57 +# @Author : zhoujun +from addict import Dict +from paddle import nn +import paddle.nn.functional as F + +from models.backbone import build_backbone +from models.neck import build_neck +from models.head import build_head + + +class Model(nn.Layer): + def __init__(self, model_config: dict): + """ + PANnet + :param model_config: 模型配置 + """ + super().__init__() + model_config = Dict(model_config) + backbone_type = model_config.backbone.pop('type') + neck_type = model_config.neck.pop('type') + head_type = model_config.head.pop('type') + self.backbone = build_backbone(backbone_type, **model_config.backbone) + self.neck = build_neck( + neck_type, + in_channels=self.backbone.out_channels, + **model_config.neck) + self.head = build_head( + head_type, in_channels=self.neck.out_channels, **model_config.head) + self.name = f'{backbone_type}_{neck_type}_{head_type}' + + def forward(self, x): + _, _, H, W = x.shape + backbone_out = self.backbone(x) + 
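+        # backbone returns the multi-scale features (c2..c5), the FPN neck fuses
+        # them into a single feature map, and DBHead predicts the shrink /
+        # threshold maps (plus the binary map during training); the result is
+        # resized back to the input resolution (H, W) below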
neck_out = self.neck(backbone_out) + y = self.head(neck_out) + y = F.interpolate(y, size=(H, W), mode='bilinear', align_corners=True) + return y diff --git a/benchmark/PaddleOCR_DBNet/models/neck/FPN.py b/benchmark/PaddleOCR_DBNet/models/neck/FPN.py new file mode 100644 index 0000000..53a3fa4 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/neck/FPN.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/9/13 10:29 +# @Author : zhoujun +import paddle +import paddle.nn.functional as F +from paddle import nn + +from models.basic import ConvBnRelu + + +class FPN(nn.Layer): + def __init__(self, in_channels, inner_channels=256, **kwargs): + """ + :param in_channels: 基础网络输出的维度 + :param kwargs: + """ + super().__init__() + inplace = True + self.conv_out = inner_channels + inner_channels = inner_channels // 4 + # reduce layers + self.reduce_conv_c2 = ConvBnRelu( + in_channels[0], inner_channels, kernel_size=1, inplace=inplace) + self.reduce_conv_c3 = ConvBnRelu( + in_channels[1], inner_channels, kernel_size=1, inplace=inplace) + self.reduce_conv_c4 = ConvBnRelu( + in_channels[2], inner_channels, kernel_size=1, inplace=inplace) + self.reduce_conv_c5 = ConvBnRelu( + in_channels[3], inner_channels, kernel_size=1, inplace=inplace) + # Smooth layers + self.smooth_p4 = ConvBnRelu( + inner_channels, + inner_channels, + kernel_size=3, + padding=1, + inplace=inplace) + self.smooth_p3 = ConvBnRelu( + inner_channels, + inner_channels, + kernel_size=3, + padding=1, + inplace=inplace) + self.smooth_p2 = ConvBnRelu( + inner_channels, + inner_channels, + kernel_size=3, + padding=1, + inplace=inplace) + + self.conv = nn.Sequential( + nn.Conv2D( + self.conv_out, + self.conv_out, + kernel_size=3, + padding=1, + stride=1), + nn.BatchNorm2D(self.conv_out), + nn.ReLU()) + self.out_channels = self.conv_out + + def forward(self, x): + c2, c3, c4, c5 = x + # Top-down + p5 = self.reduce_conv_c5(c5) + p4 = self._upsample_add(p5, self.reduce_conv_c4(c4)) + p4 = self.smooth_p4(p4) + p3 = self._upsample_add(p4, self.reduce_conv_c3(c3)) + p3 = self.smooth_p3(p3) + p2 = self._upsample_add(p3, self.reduce_conv_c2(c2)) + p2 = self.smooth_p2(p2) + + x = self._upsample_cat(p2, p3, p4, p5) + x = self.conv(x) + return x + + def _upsample_add(self, x, y): + return F.interpolate(x, size=y.shape[2:]) + y + + def _upsample_cat(self, p2, p3, p4, p5): + h, w = p2.shape[2:] + p3 = F.interpolate(p3, size=(h, w)) + p4 = F.interpolate(p4, size=(h, w)) + p5 = F.interpolate(p5, size=(h, w)) + return paddle.concat([p2, p3, p4, p5], axis=1) diff --git a/benchmark/PaddleOCR_DBNet/models/neck/__init__.py b/benchmark/PaddleOCR_DBNet/models/neck/__init__.py new file mode 100644 index 0000000..7655341 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/models/neck/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# @Time : 2020/6/5 11:34 +# @Author : zhoujun +from .FPN import FPN + +__all__ = ['build_neck'] +support_neck = ['FPN'] + + +def build_neck(neck_name, **kwargs): + assert neck_name in support_neck, f'all support neck is {support_neck}' + neck = eval(neck_name)(**kwargs) + return neck diff --git a/benchmark/PaddleOCR_DBNet/multi_gpu_train.sh b/benchmark/PaddleOCR_DBNet/multi_gpu_train.sh new file mode 100644 index 0000000..b49a73f --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/multi_gpu_train.sh @@ -0,0 +1,2 @@ +# export NCCL_P2P_DISABLE=1 +CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m paddle.distributed.launch tools/train.py --config_file "config/icdar2015_resnet50_FPN_DBhead_polyLR.yaml" \ No newline at end of file diff --git 
a/benchmark/PaddleOCR_DBNet/post_processing/__init__.py b/benchmark/PaddleOCR_DBNet/post_processing/__init__.py new file mode 100644 index 0000000..2f8e432 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/post_processing/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/5 15:17 +# @Author : zhoujun + +from .seg_detector_representer import SegDetectorRepresenter + + +def get_post_processing(config): + try: + cls = eval(config['type'])(**config['args']) + return cls + except: + return None \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/post_processing/seg_detector_representer.py b/benchmark/PaddleOCR_DBNet/post_processing/seg_detector_representer.py new file mode 100644 index 0000000..f1273dc --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/post_processing/seg_detector_representer.py @@ -0,0 +1,192 @@ +import cv2 +import numpy as np +import pyclipper +import paddle +from shapely.geometry import Polygon + + +class SegDetectorRepresenter(): + def __init__(self, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=1.5): + self.min_size = 3 + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + + def __call__(self, batch, pred, is_output_polygon=False): + ''' + batch: (image, polygons, ignore_tags + batch: a dict produced by dataloaders. + image: tensor of shape (N, C, H, W). + polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions. + ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not. + shape: the original shape of images. + filename: the original filenames of images. + pred: + binary: text region segmentation map, with shape (N, H, W) + thresh: [if exists] thresh hold prediction with shape (N, H, W) + thresh_binary: [if exists] binarized with threshhold, (N, H, W) + ''' + if isinstance(pred, paddle.Tensor): + pred = pred.numpy() + pred = pred[:, 0, :, :] + segmentation = self.binarize(pred) + boxes_batch = [] + scores_batch = [] + for batch_index in range(pred.shape[0]): + height, width = batch['shape'][batch_index] + if is_output_polygon: + boxes, scores = self.polygons_from_bitmap( + pred[batch_index], segmentation[batch_index], width, height) + else: + boxes, scores = self.boxes_from_bitmap( + pred[batch_index], segmentation[batch_index], width, height) + boxes_batch.append(boxes) + scores_batch.append(scores) + return boxes_batch, scores_batch + + def binarize(self, pred): + return pred > self.thresh + + def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (H, W), + whose values are binarized as {0, 1} + ''' + + assert len(_bitmap.shape) == 2 + bitmap = _bitmap # The first channel + height, width = bitmap.shape + boxes = [] + scores = [] + + contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), + cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + for contour in contours[:self.max_candidates]: + epsilon = 0.005 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, epsilon, True) + points = approx.reshape((-1, 2)) + if points.shape[0] < 4: + continue + # _, sside = self.get_mini_boxes(contour) + # if sside < self.min_size: + # continue + score = self.box_score_fast(pred, contour.squeeze(1)) + if self.box_thresh > score: + continue + + if points.shape[0] > 2: + box = self.unclip(points, unclip_ratio=self.unclip_ratio) + if len(box) > 1: + continue + else: + continue + box = box.reshape(-1, 2) + _, sside = self.get_mini_boxes(box.reshape((-1, 
1, 2))) + if sside < self.min_size + 2: + continue + + if not isinstance(dest_width, int): + dest_width = dest_width.item() + dest_height = dest_height.item() + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box) + scores.append(score) + return boxes, scores + + def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (H, W), + whose values are binarized as {0, 1} + ''' + + assert len(_bitmap.shape) == 2 + bitmap = _bitmap # The first channel + height, width = bitmap.shape + contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), + cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + num_contours = min(len(contours), self.max_candidates) + boxes = np.zeros((num_contours, 4, 2), dtype=np.int16) + scores = np.zeros((num_contours, ), dtype=np.float32) + + for index in range(num_contours): + contour = contours[index].squeeze(1) + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + score = self.box_score_fast(pred, contour) + if self.box_thresh > score: + continue + + box = self.unclip( + points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self.min_size + 2: + continue + box = np.array(box) + if not isinstance(dest_width, int): + dest_width = dest_width.item() + dest_height = dest_height.item() + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes[index, :, :] = box.astype(np.int16) + scores[index] = score + return boxes, scores + + def unclip(self, box, unclip_ratio=1.5): + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [ + points[index_1], points[index_2], points[index_3], points[index_4] + ] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box): + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] diff --git a/benchmark/PaddleOCR_DBNet/predict.sh b/benchmark/PaddleOCR_DBNet/predict.sh new file mode 100644 index 0000000..37ab148 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/predict.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python tools/predict.py --model_path model_best.pth --input_folder ./input --output_folder ./output --thre 0.7 
--polygon --show --save_result \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/requirement.txt b/benchmark/PaddleOCR_DBNet/requirement.txt new file mode 100644 index 0000000..191819f --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/requirement.txt @@ -0,0 +1,13 @@ +anyconfig +future +imgaug +matplotlib +numpy +opencv-python +Polygon3 +pyclipper +PyYAML +scikit-image +Shapely +tqdm +addict \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/singlel_gpu_train.sh b/benchmark/PaddleOCR_DBNet/singlel_gpu_train.sh new file mode 100644 index 0000000..f8b9f0e --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/singlel_gpu_train.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python3 tools/train.py --config_file "config/icdar2015_resnet50_FPN_DBhead_polyLR.yaml" \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/test/README.MD b/benchmark/PaddleOCR_DBNet/test/README.MD new file mode 100644 index 0000000..b43c6e9 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/test/README.MD @@ -0,0 +1,8 @@ +Place the images that you want to detect here. You better named them as such: +img_10.jpg +img_11.jpg +img_{img_id}.jpg + +For predicting single images, you can change the `img_path` in the `/tools/predict.py` to your image number. + +The result will be saved in the output_folder(default is test/output) you give in predict.sh \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/test_tipc/benchmark_train.sh b/benchmark/PaddleOCR_DBNet/test_tipc/benchmark_train.sh new file mode 100644 index 0000000..d94dac2 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/test_tipc/benchmark_train.sh @@ -0,0 +1,287 @@ +#!/bin/bash +source test_tipc/common_func.sh + +# run benchmark sh +# Usage: +# bash run_benchmark_train.sh config.txt params +# or +# bash run_benchmark_train.sh config.txt + +function func_parser_params(){ + strs=$1 + IFS="=" + array=(${strs}) + tmp=${array[1]} + echo ${tmp} +} + +function set_dynamic_epoch(){ + string=$1 + num=$2 + _str=${string:1:6} + IFS="C" + arr=(${_str}) + M=${arr[0]} + P=${arr[1]} + ep=`expr $num \* $M \* $P` + echo $ep +} + +function func_sed_params(){ + filename=$1 + line=$2 + param_value=$3 + params=`sed -n "${line}p" $filename` + IFS=":" + array=(${params}) + key=${array[0]} + value=${array[1]} + + new_params="${key}:${param_value}" + IFS=";" + cmd="sed -i '${line}s/.*/${new_params}/' '${filename}'" + eval $cmd +} + +function set_gpu_id(){ + string=$1 + _str=${string:1:6} + IFS="C" + arr=(${_str}) + M=${arr[0]} + P=${arr[1]} + gn=`expr $P - 1` + gpu_num=`expr $gn / $M` + seq=`seq -s "," 0 $gpu_num` + echo $seq +} + +function get_repo_name(){ + IFS=";" + cur_dir=$(pwd) + IFS="/" + arr=(${cur_dir}) + echo ${arr[-1]} +} + +FILENAME=$1 +# copy FILENAME as new +new_filename="./test_tipc/benchmark_train.txt" +cmd=`yes|cp $FILENAME $new_filename` +FILENAME=$new_filename +# MODE must be one of ['benchmark_train'] +MODE=$2 +PARAMS=$3 + +to_static="" +# parse "to_static" options and modify trainer into "to_static_trainer" +if [[ $PARAMS =~ "dynamicTostatic" ]] ;then + to_static="d2sT_" + sed -i 's/trainer:norm_train/trainer:to_static_train/g' $FILENAME + # clear PARAM contents + if [ $PARAMS = "to_static" ] ;then + PARAMS="" + fi +fi +# bash test_tipc/benchmark_train.sh test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt benchmark_train dynamic_bs8_fp32_DP_N1C8 +# bash test_tipc/benchmark_train.sh test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt benchmark_train dynamicTostatic_bs8_fp32_DP_N1C8 +# bash test_tipc/benchmark_train.sh 
test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt benchmark_train dynamic_bs8_null_DP_N1C1 +IFS=$'\n' +# parser params from train_benchmark.txt +dataline=`cat $FILENAME` +# parser params +IFS=$'\n' +lines=(${dataline}) +model_name=$(func_parser_value "${lines[1]}") +python_name=$(func_parser_value "${lines[2]}") + +# set env +python=${python_name} +export str_tmp=$(echo `pip list|grep paddlepaddle-gpu|awk -F ' ' '{print $2}'`) +export frame_version=${str_tmp%%.post*} +export frame_commit=$(echo `${python} -c "import paddle;print(paddle.version.commit)"`) + +# 获取benchmark_params所在的行数 +line_num=`grep -n -w "train_benchmark_params" $FILENAME | cut -d ":" -f 1` +# for train log parser +batch_size=$(func_parser_value "${lines[line_num]}") +line_num=`expr $line_num + 1` +fp_items=$(func_parser_value "${lines[line_num]}") +line_num=`expr $line_num + 1` +epoch=$(func_parser_value "${lines[line_num]}") + +line_num=`expr $line_num + 1` +profile_option_key=$(func_parser_key "${lines[line_num]}") +profile_option_params=$(func_parser_value "${lines[line_num]}") +profile_option="${profile_option_key}:${profile_option_params}" + +line_num=`expr $line_num + 1` +flags_value=$(func_parser_value "${lines[line_num]}") +# set flags +IFS=";" +flags_list=(${flags_value}) +for _flag in ${flags_list[*]}; do + cmd="export ${_flag}" + eval $cmd +done + +# set log_name +repo_name=$(get_repo_name ) +SAVE_LOG=${BENCHMARK_LOG_DIR:-$(pwd)} # */benchmark_log +mkdir -p "${SAVE_LOG}/benchmark_log/" +status_log="${SAVE_LOG}/benchmark_log/results.log" + +# The number of lines in which train params can be replaced. +line_python=3 +line_gpuid=4 +line_precision=6 +line_epoch=7 +line_batchsize=9 +line_profile=13 +line_eval_py=24 +line_export_py=30 + +func_sed_params "$FILENAME" "${line_eval_py}" "null" +func_sed_params "$FILENAME" "${line_export_py}" "null" +func_sed_params "$FILENAME" "${line_python}" "$python" + +# if params +if [ ! -n "$PARAMS" ] ;then + # PARAMS input is not a word. 
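+    # no PARAMS given: sweep every batch size / precision listed in the config
+    # on the default single-node 4-GPU data-parallel setting (N1C4, DP)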
+ IFS="|" + batch_size_list=(${batch_size}) + fp_items_list=(${fp_items}) + device_num_list=(N1C4) + run_mode="DP" +elif [[ ${PARAMS} = "dynamicTostatic" ]];then + IFS="|" + model_type=$PARAMS + batch_size_list=(${batch_size}) + fp_items_list=(${fp_items}) + device_num_list=(N1C4) + run_mode="DP" +else + # parser params from input: modeltype_bs${bs_item}_${fp_item}_${run_mode}_${device_num} + IFS="_" + params_list=(${PARAMS}) + model_type=${params_list[0]} + batch_size=${params_list[1]} + batch_size=`echo ${batch_size} | tr -cd "[0-9]" ` + precision=${params_list[2]} + run_mode=${params_list[3]} + device_num=${params_list[4]} + IFS=";" + + if [ ${precision} = "fp16" ];then + precision="amp" + fi + + epoch=$(set_dynamic_epoch $device_num $epoch) + fp_items_list=($precision) + batch_size_list=($batch_size) + device_num_list=($device_num) +fi + +IFS="|" +for batch_size in ${batch_size_list[*]}; do + for train_precision in ${fp_items_list[*]}; do + for device_num in ${device_num_list[*]}; do + # sed batchsize and precision + if [ ${train_precision} = "amp" ];then + precision="fp16" + else + precision="fp32" + fi + + func_sed_params "$FILENAME" "${line_precision}" "$train_precision" + func_sed_params "$FILENAME" "${line_batchsize}" "$MODE=$batch_size" + func_sed_params "$FILENAME" "${line_epoch}" "$MODE=$epoch" + gpu_id=$(set_gpu_id $device_num) + + if [ ${#gpu_id} -le 1 ];then + log_path="$SAVE_LOG/profiling_log" + mkdir -p $log_path + log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}profiling" + func_sed_params "$FILENAME" "${line_gpuid}" "0" # sed used gpu_id + # set profile_option params + tmp=`sed -i "${line_profile}s/.*/${profile_option}/" "${FILENAME}"` + + # run test_train_inference_python.sh + cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} benchmark_train > ${log_path}/${log_name} 2>&1 " + echo $cmd + eval $cmd + eval "cat ${log_path}/${log_name}" + + # without profile + log_path="$SAVE_LOG/train_log" + speed_log_path="$SAVE_LOG/index" + mkdir -p $log_path + mkdir -p $speed_log_path + log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}log" + speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}speed" + func_sed_params "$FILENAME" "${line_profile}" "null" # sed profile_id as null + cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} benchmark_train > ${log_path}/${log_name} 2>&1 " + echo $cmd + job_bt=`date '+%Y%m%d%H%M%S'` + eval $cmd + job_et=`date '+%Y%m%d%H%M%S'` + export model_run_time=$((${job_et}-${job_bt})) + eval "cat ${log_path}/${log_name}" + + # parser log + _model_name="${model_name}_bs${batch_size}_${precision}_${run_mode}" + cmd="${python} ${BENCHMARK_ROOT}/scripts/analysis.py --filename ${log_path}/${log_name} \ + --speed_log_file '${speed_log_path}/${speed_log_name}' \ + --model_name ${_model_name} \ + --base_batch_size ${batch_size} \ + --run_mode ${run_mode} \ + --fp_item ${precision} \ + --keyword ips: \ + --skip_steps 2 \ + --device_num ${device_num} \ + --speed_unit samples/s \ + --convergence_key loss: " + echo $cmd + eval $cmd + last_status=${PIPESTATUS[0]} + status_check $last_status "${cmd}" "${status_log}" + else + IFS=";" + unset_env=`unset CUDA_VISIBLE_DEVICES` + log_path="$SAVE_LOG/train_log" + speed_log_path="$SAVE_LOG/index" + mkdir -p $log_path + mkdir -p $speed_log_path + 
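+            # multi-GPU case: skip the profiling pass, write the full gpu_id
+            # list back into the config, run training once and parse the speed log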
log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}log" + speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}speed" + func_sed_params "$FILENAME" "${line_gpuid}" "$gpu_id" # sed used gpu_id + func_sed_params "$FILENAME" "${line_profile}" "null" # sed --profile_option as null + cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} benchmark_train > ${log_path}/${log_name} 2>&1 " + echo $cmd + job_bt=`date '+%Y%m%d%H%M%S'` + eval $cmd + job_et=`date '+%Y%m%d%H%M%S'` + export model_run_time=$((${job_et}-${job_bt})) + eval "cat ${log_path}/${log_name}" + # parser log + _model_name="${model_name}_bs${batch_size}_${precision}_${run_mode}" + + cmd="${python} ${BENCHMARK_ROOT}/scripts/analysis.py --filename ${log_path}/${log_name} \ + --speed_log_file '${speed_log_path}/${speed_log_name}' \ + --model_name ${_model_name} \ + --base_batch_size ${batch_size} \ + --run_mode ${run_mode} \ + --fp_item ${precision} \ + --keyword ips: \ + --skip_steps 2 \ + --device_num ${device_num} \ + --speed_unit images/s \ + --convergence_key loss: " + echo $cmd + eval $cmd + last_status=${PIPESTATUS[0]} + status_check $last_status "${cmd}" "${status_log}" + fi + done + done +done diff --git a/benchmark/PaddleOCR_DBNet/test_tipc/common_func.sh b/benchmark/PaddleOCR_DBNet/test_tipc/common_func.sh new file mode 100644 index 0000000..c123d3c --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/test_tipc/common_func.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +function func_parser_key(){ + strs=$1 + IFS=":" + array=(${strs}) + tmp=${array[0]} + echo ${tmp} +} + +function func_parser_value(){ + strs=$1 + IFS=":" + array=(${strs}) + tmp=${array[1]} + echo ${tmp} +} + +function func_set_params(){ + key=$1 + value=$2 + if [ ${key}x = "null"x ];then + echo " " + elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then + echo " " + else + echo "${key}=${value}" + fi +} + +function func_parser_params(){ + strs=$1 + MODE=$2 + IFS=":" + array=(${strs}) + key=${array[0]} + tmp=${array[1]} + IFS="|" + res="" + for _params in ${tmp[*]}; do + IFS="=" + array=(${_params}) + mode=${array[0]} + value=${array[1]} + if [[ ${mode} = ${MODE} ]]; then + IFS="|" + #echo $(func_set_params "${mode}" "${value}") + echo $value + break + fi + IFS="|" + done + echo ${res} +} + +function status_check(){ + last_status=$1 # the exit code + run_command=$2 + run_log=$3 + model_name=$4 + log_path=$5 + if [ $last_status -eq 0 ]; then + echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log} + else + echo -e "\033[33m Run failed with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log} + fi +} \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/test_tipc/configs/det_res50_db/train_infer_python.txt b/benchmark/PaddleOCR_DBNet/test_tipc/configs/det_res50_db/train_infer_python.txt new file mode 100644 index 0000000..7dc3da3 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/test_tipc/configs/det_res50_db/train_infer_python.txt @@ -0,0 +1,61 @@ +===========================train_params=========================== +model_name:det_res50_db +python:python +gpu_list:0|0,1 +trainer.use_gpu:True|True +amp:null +trainer.epochs:lite_train_lite_infer=1|whole_train_whole_infer=300 +trainer.output_dir:./output/ +dataset.train.loader.batch_size:lite_train_lite_infer=8|whole_train_lite_infer=8 +trainer.finetune_checkpoint:null 
+train_model_name:checkpoint/model_latest.pth +train_infer_img_dir:imgs/paper/db.jpg +null:null +## +trainer:norm_train +norm_train:tools/train.py --config_file config/icdar2015_resnet50_FPN_DBhead_polyLR.yaml -o trainer.log_iter=1 trainer.enable_eval=False dataset.train.loader.shuffle=false arch.backbone.pretrained=False +quant_export:null +fpgm_export:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +trainer.output_dir:./output/ +trainer.resume_checkpoint: +norm_export:tools/export_model.py --config_file config/icdar2015_resnet50_FPN_DBhead_polyLR.yaml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +train_model:./inference/det_r50_vd_db_v2.0_train/best_accuracy +infer_export:tools/export_model.py --config_file config/icdar2015_resnet50_FPN_DBhead_polyLR.yaml -o +infer_quant:False +inference:tools/infer.py +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--batch_size:1 +--use_tensorrt:False +--precision:fp32 +--model_dir: +--img_path:imgs/paper/db.jpg +--save_log_path:null +--benchmark:True +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] +===========================train_benchmark_params========================== +batch_size:8 +fp_items:fp32|fp16 +epoch:2 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 +===========================to_static_train_benchmark_params=========================== +to_static_train:trainer.to_static=true diff --git a/benchmark/PaddleOCR_DBNet/test_tipc/prepare.sh b/benchmark/PaddleOCR_DBNet/test_tipc/prepare.sh new file mode 100644 index 0000000..cd8f56f --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/test_tipc/prepare.sh @@ -0,0 +1,54 @@ +#!/bin/bash +source test_tipc/common_func.sh + +FILENAME=$1 + +# MODE be one of ['lite_train_lite_infer' 'lite_train_whole_infer' 'whole_train_whole_infer', +# 'whole_infer', 'klquant_whole_infer', +# 'cpp_infer', 'serving_infer'] + +MODE=$2 + +dataline=$(cat ${FILENAME}) + +# parser params +IFS=$'\n' +lines=(${dataline}) + +# The training params +model_name=$(func_parser_value "${lines[1]}") + +trainer_list=$(func_parser_value "${lines[14]}") + +if [ ${MODE} = "lite_train_lite_infer" ];then + python_name_list=$(func_parser_value "${lines[2]}") + array=(${python_name_list}) + python_name=${array[0]} + ${python_name} -m pip install -r requirement.txt + if [[ ${model_name} =~ "det_res50_db" ]];then + wget -nc https://paddle-wheel.bj.bcebos.com/benchmark/resnet50-19c8e357.pth -O /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth + + # 下载数据集并解压 + rm -rf datasets + wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/benchmark_train/datasets.tar + tar xf datasets.tar + fi +elif [ ${MODE} = "benchmark_train" ];then + python_name_list=$(func_parser_value "${lines[2]}") + array=(${python_name_list}) + python_name=${array[0]} + ${python_name} -m pip install -r requirement.txt + if [[ ${model_name} =~ "det_res50_db" ]];then + wget -nc https://paddle-wheel.bj.bcebos.com/benchmark/resnet50-19c8e357.pth -O /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth + + # 下载数据集并解压 + rm -rf datasets + wget -nc 
https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/benchmark_train/datasets.tar + tar xf datasets.tar + # expand gt.txt 2 times + # cd ./train_data/icdar2015/text_localization + # for i in `seq 2`;do cp train_icdar2015_label.txt dup$i.txt;done + # cat dup* > train_icdar2015_label.txt && rm -rf dup* + # cd ../../../ + fi +fi \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/test_tipc/test_train_inference_python.sh b/benchmark/PaddleOCR_DBNet/test_tipc/test_train_inference_python.sh new file mode 100644 index 0000000..a54591a --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/test_tipc/test_train_inference_python.sh @@ -0,0 +1,343 @@ +#!/bin/bash +source test_tipc/common_func.sh + +FILENAME=$1 +# MODE be one of ['lite_train_lite_infer' 'lite_train_whole_infer' 'whole_train_whole_infer', 'whole_infer'] +MODE=$2 + +dataline=$(awk 'NR>=1{print}' $FILENAME) + +# parser params +IFS=$'\n' +lines=(${dataline}) + +# The training params +model_name=$(func_parser_value "${lines[1]}") +python=$(func_parser_value "${lines[2]}") +gpu_list=$(func_parser_value "${lines[3]}") +train_use_gpu_key=$(func_parser_key "${lines[4]}") +train_use_gpu_value=$(func_parser_value "${lines[4]}") +autocast_list=$(func_parser_value "${lines[5]}") +autocast_key=$(func_parser_key "${lines[5]}") +epoch_key=$(func_parser_key "${lines[6]}") +epoch_num=$(func_parser_params "${lines[6]}" "${MODE}") +save_model_key=$(func_parser_key "${lines[7]}") +train_batch_key=$(func_parser_key "${lines[8]}") +train_batch_value=$(func_parser_params "${lines[8]}" "${MODE}") +pretrain_model_key=$(func_parser_key "${lines[9]}") +pretrain_model_value=$(func_parser_value "${lines[9]}") +train_model_name=$(func_parser_value "${lines[10]}") +train_infer_img_dir=$(func_parser_value "${lines[11]}") +train_param_key1=$(func_parser_key "${lines[12]}") +train_param_value1=$(func_parser_value "${lines[12]}") + +trainer_list=$(func_parser_value "${lines[14]}") +trainer_norm=$(func_parser_key "${lines[15]}") +norm_trainer=$(func_parser_value "${lines[15]}") +pact_key=$(func_parser_key "${lines[16]}") +pact_trainer=$(func_parser_value "${lines[16]}") +fpgm_key=$(func_parser_key "${lines[17]}") +fpgm_trainer=$(func_parser_value "${lines[17]}") +distill_key=$(func_parser_key "${lines[18]}") +distill_trainer=$(func_parser_value "${lines[18]}") +trainer_key1=$(func_parser_key "${lines[19]}") +trainer_value1=$(func_parser_value "${lines[19]}") +trainer_key2=$(func_parser_key "${lines[20]}") +trainer_value2=$(func_parser_value "${lines[20]}") + +eval_py=$(func_parser_value "${lines[23]}") +eval_key1=$(func_parser_key "${lines[24]}") +eval_value1=$(func_parser_value "${lines[24]}") + +save_infer_key=$(func_parser_key "${lines[27]}") +export_weight=$(func_parser_key "${lines[28]}") +norm_export=$(func_parser_value "${lines[29]}") +pact_export=$(func_parser_value "${lines[30]}") +fpgm_export=$(func_parser_value "${lines[31]}") +distill_export=$(func_parser_value "${lines[32]}") +export_key1=$(func_parser_key "${lines[33]}") +export_value1=$(func_parser_value "${lines[33]}") +export_key2=$(func_parser_key "${lines[34]}") +export_value2=$(func_parser_value "${lines[34]}") +inference_dir=$(func_parser_value "${lines[35]}") + +# parser inference model +infer_model_dir_list=$(func_parser_value "${lines[36]}") +infer_export_list=$(func_parser_value "${lines[37]}") +infer_is_quant=$(func_parser_value "${lines[38]}") +# parser inference +inference_py=$(func_parser_value "${lines[39]}") +use_gpu_key=$(func_parser_key "${lines[40]}") +use_gpu_list=$(func_parser_value 
"${lines[40]}") +use_mkldnn_key=$(func_parser_key "${lines[41]}") +use_mkldnn_list=$(func_parser_value "${lines[41]}") +cpu_threads_key=$(func_parser_key "${lines[42]}") +cpu_threads_list=$(func_parser_value "${lines[42]}") +batch_size_key=$(func_parser_key "${lines[43]}") +batch_size_list=$(func_parser_value "${lines[43]}") +use_trt_key=$(func_parser_key "${lines[44]}") +use_trt_list=$(func_parser_value "${lines[44]}") +precision_key=$(func_parser_key "${lines[45]}") +precision_list=$(func_parser_value "${lines[45]}") +infer_model_key=$(func_parser_key "${lines[46]}") +image_dir_key=$(func_parser_key "${lines[47]}") +infer_img_dir=$(func_parser_value "${lines[47]}") +save_log_key=$(func_parser_key "${lines[48]}") +benchmark_key=$(func_parser_key "${lines[49]}") +benchmark_value=$(func_parser_value "${lines[49]}") +infer_key1=$(func_parser_key "${lines[50]}") +infer_value1=$(func_parser_value "${lines[50]}") + +LOG_PATH="./test_tipc/output/${model_name}/${MODE}" +mkdir -p ${LOG_PATH} +status_log="${LOG_PATH}/results_python.log" + +line_num=`grep -n -w "to_static_train_benchmark_params" $FILENAME | cut -d ":" -f 1` +to_static_key=$(func_parser_key "${lines[line_num]}") +to_static_trainer=$(func_parser_value "${lines[line_num]}") + +function func_inference(){ + IFS='|' + _python=$1 + _script=$2 + _model_dir=$3 + _log_path=$4 + _img_dir=$5 + _flag_quant=$6 + _gpu=$7 + # inference + for use_gpu in ${use_gpu_list[*]}; do + if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then + for use_mkldnn in ${use_mkldnn_list[*]}; do + # if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then + # continue + # fi + for threads in ${cpu_threads_list[*]}; do + for batch_size in ${batch_size_list[*]}; do + for precision in ${precision_list[*]}; do + if [ ${use_mkldnn} = "False" ] && [ ${precision} = "fp16" ]; then + continue + fi # skip when enable fp16 but disable mkldnn + if [ ${_flag_quant} = "True" ] && [ ${precision} != "int8" ]; then + continue + fi # skip when quant model inference but precision is not int8 + set_precision=$(func_set_params "${precision_key}" "${precision}") + + _save_log_path="${_log_path}/python_infer_cpu_gpus_${_gpu}_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log" + set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}") + set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}") + set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}") + set_mkldnn=$(func_set_params "${use_mkldnn_key}" "${use_mkldnn}") + set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}") + set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}") + set_infer_params0=$(func_set_params "${save_log_key}" "${save_log_value}") + set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}") + command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_params0} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 " + eval $command + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" + done + done + done + done + elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then + for use_trt in ${use_trt_list[*]}; do + for precision in ${precision_list[*]}; do + if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then + continue + fi + if [[ ${precision} 
=~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then + continue + fi + if [[ ${use_trt} = "False" && ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then + continue + fi + for batch_size in ${batch_size_list[*]}; do + _save_log_path="${_log_path}/python_infer_gpu_gpus_${_gpu}_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log" + set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}") + set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}") + set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}") + set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}") + set_precision=$(func_set_params "${precision_key}" "${precision}") + set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}") + set_infer_params0=$(func_set_params "${save_log_key}" "${save_log_value}") + set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}") + command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} ${set_infer_params0} > ${_save_log_path} 2>&1 " + eval $command + last_status=${PIPESTATUS[0]} + eval "cat ${_save_log_path}" + status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}" + + done + done + done + else + echo "Does not support hardware other than CPU and GPU Currently!" + fi + done +} + +if [ ${MODE} = "whole_infer" ]; then + GPUID=$3 + if [ ${#GPUID} -le 0 ];then + env=" " + else + env="export CUDA_VISIBLE_DEVICES=${GPUID}" + fi + # set CUDA_VISIBLE_DEVICES + eval $env + export Count=0 + gpu=0 + IFS="|" + infer_run_exports=(${infer_export_list}) + infer_quant_flag=(${infer_is_quant}) + for infer_model in ${infer_model_dir_list[*]}; do + # run export + if [ ${infer_run_exports[Count]} != "null" ];then + save_infer_dir="${infer_model}" + set_export_weight=$(func_set_params "${export_weight}" "${infer_model}") + set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}") + export_log_path="${LOG_PATH}_export_${Count}.log" + export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 " + echo ${infer_run_exports[Count]} + echo $export_cmd + eval $export_cmd + status_export=$? 
+ status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}" + else + save_infer_dir=${infer_model} + fi + #run inference + is_quant=${infer_quant_flag[Count]} + func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant} "${gpu}" + Count=$(($Count + 1)) + done +else + IFS="|" + export Count=0 + USE_GPU_KEY=(${train_use_gpu_value}) + for gpu in ${gpu_list[*]}; do + train_use_gpu=${USE_GPU_KEY[Count]} + Count=$(($Count + 1)) + ips="" + if [ ${gpu} = "-1" ];then + env="" + elif [ ${#gpu} -le 1 ];then + env="export CUDA_VISIBLE_DEVICES=${gpu}" + elif [ ${#gpu} -le 15 ];then + IFS="," + array=(${gpu}) + env="export CUDA_VISIBLE_DEVICES=${array[0]}" + IFS="|" + else + IFS=";" + array=(${gpu}) + ips=${array[0]} + gpu=${array[1]} + IFS="|" + env=" " + fi + for autocast in ${autocast_list[*]}; do + if [ ${autocast} = "amp" ]; then + set_amp_config="amp.scale_loss=1024.0 amp.use_dynamic_loss_scaling=True amp.amp_level=O2" + else + set_amp_config="amp=None" + fi + for trainer in ${trainer_list[*]}; do + flag_quant=False + if [ ${trainer} = ${pact_key} ]; then + run_train=${pact_trainer} + run_export=${pact_export} + flag_quant=True + elif [ ${trainer} = "${fpgm_key}" ]; then + run_train=${fpgm_trainer} + run_export=${fpgm_export} + elif [ ${trainer} = "${distill_key}" ]; then + run_train=${distill_trainer} + run_export=${distill_export} + elif [ ${trainer} = "${to_static_key}" ]; then + run_train="${norm_trainer} ${to_static_trainer}" + run_export=${norm_export} + elif [[ ${trainer} = ${trainer_key2} ]]; then + run_train=${trainer_value2} + run_export=${export_value2} + else + run_train=${norm_trainer} + run_export=${norm_export} + fi + + if [ ${run_train} = "null" ]; then + continue + fi + + set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}") + set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}") + set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}") + set_train_params1=$(func_set_params "${train_param_key1}" "${train_param_value1}") + set_use_gpu=$(func_set_params "${train_use_gpu_key}" "${train_use_gpu}") + # if length of ips >= 15, then it is seen as multi-machine + # 15 is the min length of ips info for multi-machine: 0.0.0.0,0.0.0.0 + if [ ${#ips} -le 15 ];then + save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}" + nodes=1 + else + IFS="," + ips_array=(${ips}) + IFS="|" + nodes=${#ips_array[@]} + save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}" + fi + + + set_save_model=$(func_set_params "${save_model_key}" "${save_log}") + if [ ${#gpu} -le 2 ];then # train with cpu or single gpu + cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_amp_config} ${set_train_params1}" + elif [ ${#ips} -le 15 ];then # train with multi-gpu + cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_amp_config} ${set_train_params1}" + else # train with multi-machine + cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_batchsize} ${set_amp_config} ${set_train_params1}" + fi + # run train + eval $cmd + eval "cat ${save_log}/train.log >> ${save_log}.log" + status_check $? 
"${cmd}" "${status_log}" "${model_name}" "${save_log}.log" + + set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}") + + # run eval + if [ ${eval_py} != "null" ]; then + eval ${env} + set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}") + eval_log_path="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_eval.log" + eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1} > ${eval_log_path} 2>&1 " + eval $eval_cmd + status_check $? "${eval_cmd}" "${status_log}" "${model_name}" "${eval_log_path}" + fi + # run export model + if [ ${run_export} != "null" ]; then + # run export model + save_infer_path="${save_log}" + export_log_path="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_export.log" + set_export_weight=$(func_set_params "${export_weight}" "${save_log}/${train_model_name}") + set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}") + export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 " + eval $export_cmd + status_check $? "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}" + + #run inference + eval $env + save_infer_path="${save_log}" + if [[ ${inference_dir} != "null" ]] && [[ ${inference_dir} != '##' ]]; then + infer_model_dir="${save_infer_path}/${inference_dir}" + else + infer_model_dir=${save_infer_path} + fi + func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${LOG_PATH}" "${train_infer_img_dir}" "${flag_quant}" "${gpu}" + + eval "unset CUDA_VISIBLE_DEVICES" + fi + done # done with: for trainer in ${trainer_list[*]}; do + done # done with: for autocast in ${autocast_list[*]}; do + done # done with: for gpu in ${gpu_list[*]}; do +fi # end if [ ${MODE} = "infer" ]; then \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/tools/__init__.py b/benchmark/PaddleOCR_DBNet/tools/__init__.py new file mode 100644 index 0000000..7cbf835 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/tools/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/8 13:14 +# @Author : zhoujun \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/tools/eval.py b/benchmark/PaddleOCR_DBNet/tools/eval.py new file mode 100644 index 0000000..fe514dd --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/tools/eval.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# @Time : 2018/6/11 15:54 +# @Author : zhoujun +import os +import sys +import pathlib +__dir__ = pathlib.Path(os.path.abspath(__file__)) +sys.path.append(str(__dir__)) +sys.path.append(str(__dir__.parent.parent)) + +import argparse +import time +import paddle +from tqdm.auto import tqdm + + +class EVAL(): + def __init__(self, model_path, gpu_id=0): + from models import build_model + from data_loader import get_dataloader + from post_processing import get_post_processing + from utils import get_metric + self.gpu_id = gpu_id + if self.gpu_id is not None and isinstance( + self.gpu_id, int) and paddle.device.is_compiled_with_cuda(): + paddle.device.set_device("gpu:{}".format(self.gpu_id)) + else: + paddle.device.set_device("cpu") + checkpoint = paddle.load(model_path) + config = checkpoint['config'] + config['arch']['backbone']['pretrained'] = False + + self.validate_loader = get_dataloader(config['dataset']['validate'], + config['distributed']) + + self.model = build_model(config['arch']) + self.model.set_state_dict(checkpoint['state_dict']) + + self.post_process = 
get_post_processing(config['post_processing']) + self.metric_cls = get_metric(config['metric']) + + def eval(self): + self.model.eval() + raw_metrics = [] + total_frame = 0.0 + total_time = 0.0 + for i, batch in tqdm( + enumerate(self.validate_loader), + total=len(self.validate_loader), + desc='test model'): + with paddle.no_grad(): + start = time.time() + preds = self.model(batch['img']) + boxes, scores = self.post_process( + batch, + preds, + is_output_polygon=self.metric_cls.is_output_polygon) + total_frame += batch['img'].shape[0] + total_time += time.time() - start + raw_metric = self.metric_cls.validate_measure(batch, + (boxes, scores)) + raw_metrics.append(raw_metric) + metrics = self.metric_cls.gather_measure(raw_metrics) + print('FPS:{}'.format(total_frame / total_time)) + return { + 'recall': metrics['recall'].avg, + 'precision': metrics['precision'].avg, + 'fmeasure': metrics['fmeasure'].avg + } + + +def init_args(): + parser = argparse.ArgumentParser(description='DBNet.paddle') + parser.add_argument( + '--model_path', + required=False, + default='output/DBNet_resnet18_FPN_DBHead/checkpoint/1.pth', + type=str) + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = init_args() + eval = EVAL(args.model_path) + result = eval.eval() + print(result) diff --git a/benchmark/PaddleOCR_DBNet/tools/export_model.py b/benchmark/PaddleOCR_DBNet/tools/export_model.py new file mode 100644 index 0000000..59a318a --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/tools/export_model.py @@ -0,0 +1,57 @@ +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) + +import argparse + +import paddle +from paddle.jit import to_static + +from models import build_model +from utils import Config, ArgsParser + + +def init_args(): + parser = ArgsParser() + args = parser.parse_args() + return args + + +def load_checkpoint(model, checkpoint_path): + """ + load checkpoints + :param checkpoint_path: Checkpoint path to be loaded + """ + checkpoint = paddle.load(checkpoint_path) + model.set_state_dict(checkpoint['state_dict']) + print('load checkpoint from {}'.format(checkpoint_path)) + + +def main(config): + model = build_model(config['arch']) + load_checkpoint(model, config['trainer']['resume_checkpoint']) + model.eval() + + save_path = config["trainer"]["output_dir"] + save_path = os.path.join(save_path, "inference") + infer_shape = [3, -1, -1] + model = to_static( + model, + input_spec=[ + paddle.static.InputSpec( + shape=[None] + infer_shape, dtype="float32") + ]) + + paddle.jit.save(model, save_path) + print("inference model is saved to {}".format(save_path)) + + +if __name__ == "__main__": + args = init_args() + assert os.path.exists(args.config_file) + config = Config(args.config_file) + config.merge_dict(args.opt) + main(config.cfg) diff --git a/benchmark/PaddleOCR_DBNet/tools/infer.py b/benchmark/PaddleOCR_DBNet/tools/infer.py new file mode 100644 index 0000000..24e919c --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/tools/infer.py @@ -0,0 +1,298 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import pathlib +__dir__ = pathlib.Path(os.path.abspath(__file__)) +sys.path.append(str(__dir__)) +sys.path.append(str(__dir__.parent.parent)) + +import cv2 +import paddle +from paddle import inference +import numpy as np +from PIL import Image + +from paddle.vision import transforms +from tools.predict import resize_image +from post_processing import get_post_processing +from utils.util import draw_bbox, save_result + + +class InferenceEngine(object): + """InferenceEngine + + Inference engina class which contains preprocess, run, postprocess + """ + + def __init__(self, args): + """ + Args: + args: Parameters generated using argparser. + Returns: None + """ + super().__init__() + self.args = args + + # init inference engine + self.predictor, self.config, self.input_tensor, self.output_tensor = self.load_predictor( + os.path.join(args.model_dir, "inference.pdmodel"), + os.path.join(args.model_dir, "inference.pdiparams")) + + # build transforms + self.transforms = transforms.Compose([ + transforms.ToTensor(), transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # wamrup + if self.args.warmup > 0: + for idx in range(args.warmup): + print(idx) + x = np.random.rand(1, 3, self.args.crop_size, + self.args.crop_size).astype("float32") + self.input_tensor.copy_from_cpu(x) + self.predictor.run() + self.output_tensor.copy_to_cpu() + + self.post_process = get_post_processing({ + 'type': 'SegDetectorRepresenter', + 'args': { + 'thresh': 0.3, + 'box_thresh': 0.7, + 'max_candidates': 1000, + 'unclip_ratio': 1.5 + } + }) + + def load_predictor(self, model_file_path, params_file_path): + """load_predictor + initialize the inference engine + Args: + model_file_path: inference model path (*.pdmodel) + model_file_path: inference parmaeter path (*.pdiparams) + Return: + predictor: Predictor created using Paddle Inference. + config: Configuration of the predictor. + input_tensor: Input tensor of the predictor. + output_tensor: Output tensor of the predictor. + """ + args = self.args + config = inference.Config(model_file_path, params_file_path) + if args.use_gpu: + config.enable_use_gpu(1000, 0) + if args.use_tensorrt: + config.enable_tensorrt_engine( + workspace_size=1 << 30, + precision_mode=precision, + max_batch_size=args.max_batch_size, + min_subgraph_size=args. + min_subgraph_size, # skip the minmum trt subgraph + use_calib_mode=False) + + # collect shape + trt_shape_f = os.path.join(model_dir, "_trt_dynamic_shape.txt") + + if not os.path.exists(trt_shape_f): + config.collect_shape_range_info(trt_shape_f) + logger.info( + f"collect dynamic shape info into : {trt_shape_f}") + try: + config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, + True) + except Exception as E: + logger.info(E) + logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!") + else: + config.disable_gpu() + # The thread num should not be greater than the number of cores in the CPU. 
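+            # Illustrative CPU-side invocation (flags are those defined in get_args() below,
+            # values only an example): --use_gpu False --enable_mkldnn True --cpu_threads 6
+            # --precision fp16. Note that "fp16" on CPU maps to the MKLDNN bfloat16 path via
+            # enable_mkldnn_bfloat16(), not to true half precision.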
+ if args.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + if args.precision == "fp16": + config.enable_mkldnn_bfloat16() + if hasattr(args, "cpu_threads"): + config.set_cpu_math_library_num_threads(args.cpu_threads) + else: + # default cpu threads as 10 + config.set_cpu_math_library_num_threads(10) + + # enable memory optim + config.enable_memory_optim() + config.disable_glog_info() + + config.switch_use_feed_fetch_ops(False) + config.switch_ir_optim(True) + + # create predictor + predictor = inference.create_predictor(config) + + # get input and output tensor property + input_names = predictor.get_input_names() + input_tensor = predictor.get_input_handle(input_names[0]) + + output_names = predictor.get_output_names() + output_tensor = predictor.get_output_handle(output_names[0]) + + return predictor, config, input_tensor, output_tensor + + def preprocess(self, img_path, short_size): + """preprocess + Preprocess to the input. + Args: + img_path: Image path. + Returns: Input data after preprocess. + """ + img = cv2.imread(img_path, 1) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + h, w = img.shape[:2] + img = resize_image(img, short_size) + img = self.transforms(img) + img = np.expand_dims(img, axis=0) + shape_info = {'shape': [(h, w)]} + return img, shape_info + + def postprocess(self, x, shape_info, is_output_polygon): + """postprocess + Postprocess to the inference engine output. + Args: + x: Inference engine output. + Returns: Output data after argmax. + """ + box_list, score_list = self.post_process( + shape_info, x, is_output_polygon=is_output_polygon) + box_list, score_list = box_list[0], score_list[0] + if len(box_list) > 0: + if is_output_polygon: + idx = [x.sum() > 0 for x in box_list] + box_list = [box_list[i] for i, v in enumerate(idx) if v] + score_list = [score_list[i] for i, v in enumerate(idx) if v] + else: + idx = box_list.reshape(box_list.shape[0], -1).sum( + axis=1) > 0 # 去掉全为0的框 + box_list, score_list = box_list[idx], score_list[idx] + else: + box_list, score_list = [], [] + return box_list, score_list + + def run(self, x): + """run + Inference process using inference engine. + Args: + x: Input data after preprocess. 
+ Returns: Inference engine output + """ + self.input_tensor.copy_from_cpu(x) + self.predictor.run() + output = self.output_tensor.copy_to_cpu() + return output + + +def get_args(add_help=True): + """ + parse args + """ + import argparse + + def str2bool(v): + return v.lower() in ("true", "t", "1") + + parser = argparse.ArgumentParser( + description="PaddlePaddle Classification Training", add_help=add_help) + + parser.add_argument("--model_dir", default=None, help="inference model dir") + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument( + "--short_size", default=1024, type=int, help="short size") + parser.add_argument("--img_path", default="./images/demo.jpg") + + parser.add_argument( + "--benchmark", default=False, type=str2bool, help="benchmark") + parser.add_argument("--warmup", default=0, type=int, help="warmup iter") + parser.add_argument( + '--polygon', action='store_true', help='output polygon or box') + + parser.add_argument("--use_gpu", type=str2bool, default=True) + parser.add_argument("--use_tensorrt", type=str2bool, default=False) + parser.add_argument("--precision", type=str, default="fp32") + parser.add_argument("--gpu_mem", type=int, default=500) + parser.add_argument("--gpu_id", type=int, default=0) + parser.add_argument("--enable_mkldnn", type=str2bool, default=False) + parser.add_argument("--cpu_threads", type=int, default=10) + + args = parser.parse_args() + return args + + +def main(args): + """ + Main inference function. + Args: + args: Parameters generated using argparser. + Returns: + class_id: Class index of the input. + prob: : Probability of the input. + """ + inference_engine = InferenceEngine(args) + + # init benchmark + if args.benchmark: + import auto_log + autolog = auto_log.AutoLogger( + model_name="db", + batch_size=args.batch_size, + inference_config=inference_engine.config, + gpu_ids="auto" if args.use_gpu else None) + + # enable benchmark + if args.benchmark: + autolog.times.start() + + # preprocess + img, shape_info = inference_engine.preprocess(args.img_path, + args.short_size) + + if args.benchmark: + autolog.times.stamp() + + output = inference_engine.run(img) + + if args.benchmark: + autolog.times.stamp() + + # postprocess + box_list, score_list = inference_engine.postprocess(output, shape_info, + args.polygon) + + if args.benchmark: + autolog.times.stamp() + autolog.times.end(stamp=True) + autolog.report() + + img = draw_bbox(cv2.imread(args.img_path)[:, :, ::-1], box_list) + # 保存结果到路径 + os.makedirs('output', exist_ok=True) + img_path = pathlib.Path(args.img_path) + output_path = os.path.join('output', img_path.stem + '_infer_result.jpg') + cv2.imwrite(output_path, img[:, :, ::-1]) + save_result( + output_path.replace('_infer_result.jpg', '.txt'), box_list, score_list, + args.polygon) + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/benchmark/PaddleOCR_DBNet/tools/predict.py b/benchmark/PaddleOCR_DBNet/tools/predict.py new file mode 100644 index 0000000..51beffd --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/tools/predict.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/24 12:06 +# @Author : zhoujun + +import os +import sys +import pathlib +__dir__ = pathlib.Path(os.path.abspath(__file__)) +sys.path.append(str(__dir__)) +sys.path.append(str(__dir__.parent.parent)) + +import time +import cv2 +import paddle + +from data_loader import get_transforms +from models import build_model +from post_processing import get_post_processing + + +def resize_image(img, short_size): + height, 
width, _ = img.shape + if height < width: + new_height = short_size + new_width = new_height / height * width + else: + new_width = short_size + new_height = new_width / width * height + new_height = int(round(new_height / 32) * 32) + new_width = int(round(new_width / 32) * 32) + resized_img = cv2.resize(img, (new_width, new_height)) + return resized_img + + +class PaddleModel: + def __init__(self, model_path, post_p_thre=0.7, gpu_id=None): + ''' + 初始化模型 + :param model_path: 模型地址(可以是模型的参数或者参数和计算图一起保存的文件) + :param gpu_id: 在哪一块gpu上运行 + ''' + self.gpu_id = gpu_id + + if self.gpu_id is not None and isinstance( + self.gpu_id, int) and paddle.device.is_compiled_with_cuda(): + paddle.device.set_device("gpu:{}".format(self.gpu_id)) + else: + paddle.device.set_device("cpu") + checkpoint = paddle.load(model_path) + + config = checkpoint['config'] + config['arch']['backbone']['pretrained'] = False + self.model = build_model(config['arch']) + self.post_process = get_post_processing(config['post_processing']) + self.post_process.box_thresh = post_p_thre + self.img_mode = config['dataset']['train']['dataset']['args'][ + 'img_mode'] + self.model.set_state_dict(checkpoint['state_dict']) + self.model.eval() + + self.transform = [] + for t in config['dataset']['train']['dataset']['args']['transforms']: + if t['type'] in ['ToTensor', 'Normalize']: + self.transform.append(t) + self.transform = get_transforms(self.transform) + + def predict(self, + img_path: str, + is_output_polygon=False, + short_size: int=1024): + ''' + 对传入的图像进行预测,支持图像地址,opecv 读取图片,偏慢 + :param img_path: 图像地址 + :param is_numpy: + :return: + ''' + assert os.path.exists(img_path), 'file is not exists' + img = cv2.imread(img_path, 1 if self.img_mode != 'GRAY' else 0) + if self.img_mode == 'RGB': + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + h, w = img.shape[:2] + img = resize_image(img, short_size) + # 将图片由(w,h)变为(1,img_channel,h,w) + tensor = self.transform(img) + tensor = tensor.unsqueeze_(0) + + batch = {'shape': [(h, w)]} + with paddle.no_grad(): + start = time.time() + preds = self.model(tensor) + box_list, score_list = self.post_process( + batch, preds, is_output_polygon=is_output_polygon) + box_list, score_list = box_list[0], score_list[0] + if len(box_list) > 0: + if is_output_polygon: + idx = [x.sum() > 0 for x in box_list] + box_list = [box_list[i] for i, v in enumerate(idx) if v] + score_list = [score_list[i] for i, v in enumerate(idx) if v] + else: + idx = box_list.reshape(box_list.shape[0], -1).sum( + axis=1) > 0 # 去掉全为0的框 + box_list, score_list = box_list[idx], score_list[idx] + else: + box_list, score_list = [], [] + t = time.time() - start + return preds[0, 0, :, :].detach().cpu().numpy(), box_list, score_list, t + + +def save_depoly(net, input, save_path): + input_spec = [ + paddle.static.InputSpec( + shape=[None, 3, None, None], dtype="float32") + ] + net = paddle.jit.to_static(net, input_spec=input_spec) + + # save static model for inference directly + paddle.jit.save(net, save_path) + + +def init_args(): + import argparse + parser = argparse.ArgumentParser(description='DBNet.paddle') + parser.add_argument('--model_path', default=r'model_best.pth', type=str) + parser.add_argument( + '--input_folder', + default='./test/input', + type=str, + help='img path for predict') + parser.add_argument( + '--output_folder', + default='./test/output', + type=str, + help='img path for output') + parser.add_argument('--gpu', default=0, type=int, help='gpu for inference') + parser.add_argument( + '--thre', default=0.3, type=float, help='the 
thresh of post_processing') + parser.add_argument( + '--polygon', action='store_true', help='output polygon or box') + parser.add_argument('--show', action='store_true', help='show result') + parser.add_argument( + '--save_result', + action='store_true', + help='save box and score to txt file') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + import pathlib + from tqdm import tqdm + import matplotlib.pyplot as plt + from utils.util import show_img, draw_bbox, save_result, get_image_file_list + + args = init_args() + print(args) + # 初始化网络 + model = PaddleModel(args.model_path, post_p_thre=args.thre, gpu_id=args.gpu) + img_folder = pathlib.Path(args.input_folder) + for img_path in tqdm(get_image_file_list(args.input_folder)): + preds, boxes_list, score_list, t = model.predict( + img_path, is_output_polygon=args.polygon) + img = draw_bbox(cv2.imread(img_path)[:, :, ::-1], boxes_list) + if args.show: + show_img(preds) + show_img(img, title=os.path.basename(img_path)) + plt.show() + # 保存结果到路径 + os.makedirs(args.output_folder, exist_ok=True) + img_path = pathlib.Path(img_path) + output_path = os.path.join(args.output_folder, + img_path.stem + '_result.jpg') + pred_path = os.path.join(args.output_folder, + img_path.stem + '_pred.jpg') + cv2.imwrite(output_path, img[:, :, ::-1]) + cv2.imwrite(pred_path, preds * 255) + save_result( + output_path.replace('_result.jpg', '.txt'), boxes_list, score_list, + args.polygon) diff --git a/benchmark/PaddleOCR_DBNet/tools/train.py b/benchmark/PaddleOCR_DBNet/tools/train.py new file mode 100644 index 0000000..403d618 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/tools/train.py @@ -0,0 +1,61 @@ +import os +import sys +import pathlib +__dir__ = pathlib.Path(os.path.abspath(__file__)) +sys.path.append(str(__dir__)) +sys.path.append(str(__dir__.parent.parent)) + +import paddle +import paddle.distributed as dist +from utils import Config, ArgsParser + + +def init_args(): + parser = ArgsParser() + args = parser.parse_args() + return args + + +def main(config, profiler_options): + from models import build_model, build_loss + from data_loader import get_dataloader + from trainer import Trainer + from post_processing import get_post_processing + from utils import get_metric + if paddle.device.cuda.device_count() > 1: + dist.init_parallel_env() + config['distributed'] = True + else: + config['distributed'] = False + train_loader = get_dataloader(config['dataset']['train'], + config['distributed']) + assert train_loader is not None + if 'validate' in config['dataset']: + validate_loader = get_dataloader(config['dataset']['validate'], False) + else: + validate_loader = None + criterion = build_loss(config['loss']) + config['arch']['backbone']['in_channels'] = 3 if config['dataset']['train'][ + 'dataset']['args']['img_mode'] != 'GRAY' else 1 + model = build_model(config['arch']) + # set @to_static for benchmark, skip this by default. 
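+    # The to-static switch is driven by config rather than hard-coded here: the TIPC benchmark
+    # config passes trainer.to_static=true (see to_static_train_benchmark_params in
+    # test_tipc/configs/det_res50_db/train_infer_python.txt). A minimal sketch of honouring
+    # that flag, assuming it is not already handled inside Trainer, would be:
+    #     if config['trainer'].get('to_static', False):
+    #         model = paddle.jit.to_static(model)
+    # This is only an illustration; the exact handling may differ.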
+ post_p = get_post_processing(config['post_processing']) + metric = get_metric(config['metric']) + trainer = Trainer( + config=config, + model=model, + criterion=criterion, + train_loader=train_loader, + post_process=post_p, + metric_cls=metric, + validate_loader=validate_loader, + profiler_options=profiler_options) + trainer.train() + + +if __name__ == '__main__': + args = init_args() + assert os.path.exists(args.config_file) + config = Config(args.config_file) + config.merge_dict(args.opt) + main(config.cfg, args.profiler_options) diff --git a/benchmark/PaddleOCR_DBNet/trainer/__init__.py b/benchmark/PaddleOCR_DBNet/trainer/__init__.py new file mode 100644 index 0000000..76c7392 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/trainer/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/23 21:58 +# @Author : zhoujun +from .trainer import Trainer \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/trainer/trainer.py b/benchmark/PaddleOCR_DBNet/trainer/trainer.py new file mode 100644 index 0000000..34b259f --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/trainer/trainer.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/23 21:58 +# @Author : zhoujun +import time + +import paddle +from tqdm import tqdm + +from base import BaseTrainer +from utils import runningScore, cal_text_score, Polynomial, profiler + + +class Trainer(BaseTrainer): + def __init__(self, + config, + model, + criterion, + train_loader, + validate_loader, + metric_cls, + post_process=None, + profiler_options=None): + super(Trainer, self).__init__(config, model, criterion, train_loader, + validate_loader, metric_cls, post_process) + self.profiler_options = profiler_options + self.enable_eval = config['trainer'].get('enable_eval', True) + + def _train_epoch(self, epoch): + self.model.train() + total_samples = 0 + train_reader_cost = 0.0 + train_batch_cost = 0.0 + reader_start = time.time() + epoch_start = time.time() + train_loss = 0. 
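+        # Per-step bookkeeping used by the benchmark logs: train_reader_cost accumulates time
+        # spent waiting on the dataloader, train_batch_cost the full step time, and
+        # total_samples the number of images processed. Every log_iter steps the reported ips
+        # is total_samples / train_batch_cost (samples/sec), after which the counters reset.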
+ running_metric_text = runningScore(2) + + for i, batch in enumerate(self.train_loader): + profiler.add_profiler_step(self.profiler_options) + if i >= self.train_loader_len: + break + self.global_step += 1 + lr = self.optimizer.get_lr() + + cur_batch_size = batch['img'].shape[0] + + train_reader_cost += time.time() - reader_start + if self.amp: + with paddle.amp.auto_cast( + enable='gpu' in paddle.device.get_device(), + custom_white_list=self.amp.get('custom_white_list', []), + custom_black_list=self.amp.get('custom_black_list', []), + level=self.amp.get('level', 'O2')): + preds = self.model(batch['img']) + loss_dict = self.criterion(preds.astype(paddle.float32), batch) + scaled_loss = self.amp['scaler'].scale(loss_dict['loss']) + scaled_loss.backward() + self.amp['scaler'].minimize(self.optimizer, scaled_loss) + else: + preds = self.model(batch['img']) + loss_dict = self.criterion(preds, batch) + # backward + loss_dict['loss'].backward() + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.clear_grad() + + train_batch_time = time.time() - reader_start + train_batch_cost += train_batch_time + total_samples += cur_batch_size + + # acc iou + score_shrink_map = cal_text_score( + preds[:, 0, :, :], + batch['shrink_map'], + batch['shrink_mask'], + running_metric_text, + thred=self.config['post_processing']['args']['thresh']) + + # loss 和 acc 记录到日志 + loss_str = 'loss: {:.4f}, '.format(loss_dict['loss'].item()) + for idx, (key, value) in enumerate(loss_dict.items()): + loss_dict[key] = value.item() + if key == 'loss': + continue + loss_str += '{}: {:.4f}'.format(key, loss_dict[key]) + if idx < len(loss_dict) - 1: + loss_str += ', ' + + train_loss += loss_dict['loss'] + acc = score_shrink_map['Mean Acc'] + iou_shrink_map = score_shrink_map['Mean IoU'] + + if self.global_step % self.log_iter == 0: + self.logger_info( + '[{}/{}], [{}/{}], global_step: {}, ips: {:.1f} samples/sec, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, acc: {:.4f}, iou_shrink_map: {:.4f}, {}lr:{:.6}, time:{:.2f}'. 
+ format(epoch, self.epochs, i + 1, self.train_loader_len, + self.global_step, total_samples / train_batch_cost, + train_reader_cost / self.log_iter, train_batch_cost / + self.log_iter, total_samples / self.log_iter, acc, + iou_shrink_map, loss_str, lr, train_batch_cost)) + total_samples = 0 + train_reader_cost = 0.0 + train_batch_cost = 0.0 + + if self.visualdl_enable and paddle.distributed.get_rank() == 0: + # write tensorboard + for key, value in loss_dict.items(): + self.writer.add_scalar('TRAIN/LOSS/{}'.format(key), value, + self.global_step) + self.writer.add_scalar('TRAIN/ACC_IOU/acc', acc, + self.global_step) + self.writer.add_scalar('TRAIN/ACC_IOU/iou_shrink_map', + iou_shrink_map, self.global_step) + self.writer.add_scalar('TRAIN/lr', lr, self.global_step) + reader_start = time.time() + return { + 'train_loss': train_loss / self.train_loader_len, + 'lr': lr, + 'time': time.time() - epoch_start, + 'epoch': epoch + } + + def _eval(self, epoch): + self.model.eval() + raw_metrics = [] + total_frame = 0.0 + total_time = 0.0 + for i, batch in tqdm( + enumerate(self.validate_loader), + total=len(self.validate_loader), + desc='test model'): + with paddle.no_grad(): + start = time.time() + if self.amp: + with paddle.amp.auto_cast( + enable='gpu' in paddle.device.get_device(), + custom_white_list=self.amp.get('custom_white_list', + []), + custom_black_list=self.amp.get('custom_black_list', + []), + level=self.amp.get('level', 'O2')): + preds = self.model(batch['img']) + preds = preds.astype(paddle.float32) + else: + preds = self.model(batch['img']) + boxes, scores = self.post_process( + batch, + preds, + is_output_polygon=self.metric_cls.is_output_polygon) + total_frame += batch['img'].shape[0] + total_time += time.time() - start + raw_metric = self.metric_cls.validate_measure(batch, + (boxes, scores)) + raw_metrics.append(raw_metric) + metrics = self.metric_cls.gather_measure(raw_metrics) + self.logger_info('FPS:{}'.format(total_frame / total_time)) + return metrics['recall'].avg, metrics['precision'].avg, metrics[ + 'fmeasure'].avg + + def _on_epoch_finish(self): + self.logger_info('[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'. + format(self.epoch_result['epoch'], self.epochs, self. + epoch_result['train_loss'], self.epoch_result[ + 'time'], self.epoch_result['lr'])) + net_save_path = '{}/model_latest.pth'.format(self.checkpoint_dir) + net_save_path_best = '{}/model_best.pth'.format(self.checkpoint_dir) + + if paddle.distributed.get_rank() == 0: + self._save_checkpoint(self.epoch_result['epoch'], net_save_path) + save_best = False + if self.validate_loader is not None and self.metric_cls is not None and self.enable_eval: # 使用f1作为最优模型指标 + recall, precision, hmean = self._eval(self.epoch_result[ + 'epoch']) + + if self.visualdl_enable: + self.writer.add_scalar('EVAL/recall', recall, + self.global_step) + self.writer.add_scalar('EVAL/precision', precision, + self.global_step) + self.writer.add_scalar('EVAL/hmean', hmean, + self.global_step) + self.logger_info( + 'test: recall: {:.6f}, precision: {:.6f}, hmean: {:.6f}'. 
+ format(recall, precision, hmean)) + + if hmean >= self.metrics['hmean']: + save_best = True + self.metrics['train_loss'] = self.epoch_result['train_loss'] + self.metrics['hmean'] = hmean + self.metrics['precision'] = precision + self.metrics['recall'] = recall + self.metrics['best_model_epoch'] = self.epoch_result[ + 'epoch'] + else: + if self.epoch_result['train_loss'] <= self.metrics[ + 'train_loss']: + save_best = True + self.metrics['train_loss'] = self.epoch_result['train_loss'] + self.metrics['best_model_epoch'] = self.epoch_result[ + 'epoch'] + best_str = 'current best, ' + for k, v in self.metrics.items(): + best_str += '{}: {:.6f}, '.format(k, v) + self.logger_info(best_str) + if save_best: + import shutil + shutil.copy(net_save_path, net_save_path_best) + self.logger_info("Saving current best: {}".format( + net_save_path_best)) + else: + self.logger_info("Saving checkpoint: {}".format(net_save_path)) + + def _on_train_finish(self): + if self.enable_eval: + for k, v in self.metrics.items(): + self.logger_info('{}:{}'.format(k, v)) + self.logger_info('finish train') + + def _initialize_scheduler(self): + if self.config['lr_scheduler']['type'] == 'Polynomial': + self.config['lr_scheduler']['args']['epochs'] = self.config[ + 'trainer']['epochs'] + self.config['lr_scheduler']['args']['step_each_epoch'] = len( + self.train_loader) + self.lr_scheduler = Polynomial( + **self.config['lr_scheduler']['args'])() + else: + self.lr_scheduler = self._initialize('lr_scheduler', + paddle.optimizer.lr) diff --git a/benchmark/PaddleOCR_DBNet/utils/__init__.py b/benchmark/PaddleOCR_DBNet/utils/__init__.py new file mode 100644 index 0000000..194e0b8 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/23 21:58 +# @Author : zhoujun +from .util import * +from .metrics import * +from .schedulers import * +from .cal_recall.script import cal_recall_precison_f1 +from .ocr_metric import get_metric diff --git a/benchmark/PaddleOCR_DBNet/utils/cal_recall/__init__.py b/benchmark/PaddleOCR_DBNet/utils/cal_recall/__init__.py new file mode 100644 index 0000000..0db38a8 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/cal_recall/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +# @Time : 1/16/19 6:40 AM +# @Author : zhoujun +from .script import cal_recall_precison_f1 +__all__ = ['cal_recall_precison_f1'] diff --git a/benchmark/PaddleOCR_DBNet/utils/cal_recall/rrc_evaluation_funcs.py b/benchmark/PaddleOCR_DBNet/utils/cal_recall/rrc_evaluation_funcs.py new file mode 100644 index 0000000..4e12ee6 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/cal_recall/rrc_evaluation_funcs.py @@ -0,0 +1,479 @@ +#!/usr/bin/env python2 +#encoding: UTF-8 +import json +import sys +sys.path.append('./') +import zipfile +import re +import sys +import os +import codecs +import traceback +import numpy as np +from utils import order_points_clockwise + + +def print_help(): + sys.stdout.write( + 'Usage: python %s.py -g= -s= [-o= -p=]' + % sys.argv[0]) + sys.exit(2) + + +def load_zip_file_keys(file, fileNameRegExp=''): + """ + Returns an array with the entries of the ZIP file that match with the regular expression. 
+ The key's are the names or the file or the capturing group definied in the fileNameRegExp + """ + try: + archive = zipfile.ZipFile(file, mode='r', allowZip64=True) + except: + raise Exception('Error loading the ZIP archive.') + + pairs = [] + + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp != "": + m = re.match(fileNameRegExp, name) + if m == None: + addFile = False + else: + if len(m.groups()) > 0: + keyName = m.group(1) + + if addFile: + pairs.append(keyName) + + return pairs + + +def load_zip_file(file, fileNameRegExp='', allEntries=False): + """ + Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + allEntries validates that all entries in the ZIP file pass the fileNameRegExp + """ + try: + archive = zipfile.ZipFile(file, mode='r', allowZip64=True) + except: + raise Exception('Error loading the ZIP archive') + + pairs = [] + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp != "": + m = re.match(fileNameRegExp, name) + if m == None: + addFile = False + else: + if len(m.groups()) > 0: + keyName = m.group(1) + + if addFile: + pairs.append([keyName, archive.read(name)]) + else: + if allEntries: + raise Exception('ZIP entry not valid: %s' % name) + + return dict(pairs) + + +def load_folder_file(file, fileNameRegExp='', allEntries=False): + """ + Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + allEntries validates that all entries in the ZIP file pass the fileNameRegExp + """ + pairs = [] + for name in os.listdir(file): + addFile = True + keyName = name + if fileNameRegExp != "": + m = re.match(fileNameRegExp, name) + if m == None: + addFile = False + else: + if len(m.groups()) > 0: + keyName = m.group(1) + + if addFile: + pairs.append([keyName, open(os.path.join(file, name)).read()]) + else: + if allEntries: + raise Exception('ZIP entry not valid: %s' % name) + + return dict(pairs) + + +def decode_utf8(raw): + """ + Returns a Unicode object on success, or None on failure + """ + try: + raw = codecs.decode(raw, 'utf-8', 'replace') + #extracts BOM if exists + raw = raw.encode('utf8') + if raw.startswith(codecs.BOM_UTF8): + raw = raw.replace(codecs.BOM_UTF8, '', 1) + return raw.decode('utf-8') + except: + return None + + +def validate_lines_in_file(fileName, + file_contents, + CRLF=True, + LTRB=True, + withTranscription=False, + withConfidence=False, + imWidth=0, + imHeight=0): + """ + This function validates that all lines of the file calling the Line validation function for each line + """ + utf8File = decode_utf8(file_contents) + if (utf8File is None): + raise Exception("The file %s is not UTF-8" % fileName) + + lines = utf8File.split("\r\n" if CRLF else "\n") + for line in lines: + line = line.replace("\r", "").replace("\n", "") + if (line != ""): + try: + validate_tl_line(line, LTRB, withTranscription, withConfidence, + imWidth, imHeight) + except Exception as e: + raise Exception( + ("Line in sample not valid. Sample: %s Line: %s Error: %s" % + (fileName, line, str(e))).encode('utf-8', 'replace')) + + +def validate_tl_line(line, + LTRB=True, + withTranscription=True, + withConfidence=True, + imWidth=0, + imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. 
+ If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + """ + get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, + imHeight) + + +def get_tl_line_values(line, + LTRB=True, + withTranscription=False, + withConfidence=False, + imWidth=0, + imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + Returns values from a textline. Points , [Confidences], [Transcriptions] + """ + confidence = 0.0 + transcription = "" + points = [] + + numPoints = 4 + + if LTRB: + + numPoints = 4 + + if withTranscription and withConfidence: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', + line) + if m == None: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', + line) + raise Exception( + "Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription" + ) + elif withConfidence: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', + line) + if m == None: + raise Exception( + "Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence" + ) + elif withTranscription: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$', + line) + if m == None: + raise Exception( + "Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription" + ) + else: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$', + line) + if m == None: + raise Exception( + "Format incorrect. Should be: xmin,ymin,xmax,ymax") + + xmin = int(m.group(1)) + ymin = int(m.group(2)) + xmax = int(m.group(3)) + ymax = int(m.group(4)) + if (xmax < xmin): + raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." % (xmax)) + if (ymax < ymin): + raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." % + (ymax)) + + points = [float(m.group(i)) for i in range(1, (numPoints + 1))] + + if (imWidth > 0 and imHeight > 0): + validate_point_inside_bounds(xmin, ymin, imWidth, imHeight) + validate_point_inside_bounds(xmax, ymax, imWidth, imHeight) + + else: + + numPoints = 8 + + if withTranscription and withConfidence: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', + line) + if m == None: + raise Exception( + "Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription" + ) + elif withConfidence: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', + line) + if m == None: + raise Exception( + "Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence" + ) + elif withTranscription: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$', + line) + if m == None: + raise Exception( + "Format incorrect. 
Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription" + ) + else: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$', + line) + if m == None: + raise Exception( + "Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4") + + points = [float(m.group(i)) for i in range(1, (numPoints + 1))] + + points = order_points_clockwise(np.array(points).reshape(-1, + 2)).reshape(-1) + validate_clockwise_points(points) + + if (imWidth > 0 and imHeight > 0): + validate_point_inside_bounds(points[0], points[1], imWidth, + imHeight) + validate_point_inside_bounds(points[2], points[3], imWidth, + imHeight) + validate_point_inside_bounds(points[4], points[5], imWidth, + imHeight) + validate_point_inside_bounds(points[6], points[7], imWidth, + imHeight) + + if withConfidence: + try: + confidence = float(m.group(numPoints + 1)) + except ValueError: + raise Exception("Confidence value must be a float") + + if withTranscription: + posTranscription = numPoints + (2 if withConfidence else 1) + transcription = m.group(posTranscription) + m2 = re.match(r'^\s*\"(.*)\"\s*$', transcription) + if m2 != None: #Transcription with double quotes, we extract the value and replace escaped characters + transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", + "\"") + + return points, confidence, transcription + + +def validate_point_inside_bounds(x, y, imWidth, imHeight): + if (x < 0 or x > imWidth): + raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" % + (xmin, imWidth, imHeight)) + if (y < 0 or y > imHeight): + raise Exception( + "Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" + % (ymin, imWidth, imHeight)) + + +def validate_clockwise_points(points): + """ + Validates that the points that the 4 points that dlimite a polygon are in clockwise order. + """ + + if len(points) != 8: + raise Exception("Points list not valid." + str(len(points))) + + point = [[int(points[0]), int(points[1])], + [int(points[2]), int(points[3])], + [int(points[4]), int(points[5])], + [int(points[6]), int(points[7])]] + edge = [(point[1][0] - point[0][0]) * (point[1][1] + point[0][1]), + (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]), + (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]), + (point[0][0] - point[3][0]) * (point[0][1] + point[3][1])] + + summatory = edge[0] + edge[1] + edge[2] + edge[3] + if summatory > 0: + raise Exception( + "Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards." + ) + + +def get_tl_line_values_from_file_contents(content, + CRLF=True, + LTRB=True, + withTranscription=False, + withConfidence=False, + imWidth=0, + imHeight=0, + sort_by_confidences=True): + """ + Returns all points, confindences and transcriptions of a file in lists. 
Valid line formats: + xmin,ymin,xmax,ymax,[confidence],[transcription] + x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] + """ + pointsList = [] + transcriptionsList = [] + confidencesList = [] + + lines = content.split("\r\n" if CRLF else "\n") + for line in lines: + line = line.replace("\r", "").replace("\n", "") + if (line != ""): + points, confidence, transcription = get_tl_line_values( + line, LTRB, withTranscription, withConfidence, imWidth, + imHeight) + pointsList.append(points) + transcriptionsList.append(transcription) + confidencesList.append(confidence) + + if withConfidence and len(confidencesList) > 0 and sort_by_confidences: + import numpy as np + sorted_ind = np.argsort(-np.array(confidencesList)) + confidencesList = [confidencesList[i] for i in sorted_ind] + pointsList = [pointsList[i] for i in sorted_ind] + transcriptionsList = [transcriptionsList[i] for i in sorted_ind] + + return pointsList, confidencesList, transcriptionsList + + +def main_evaluation(p, + default_evaluation_params_fn, + validate_data_fn, + evaluate_method_fn, + show_result=True, + per_sample=True): + """ + This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. + Params: + p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results + """ + evalParams = default_evaluation_params_fn() + if 'p' in p.keys(): + evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p[ + 'p'][1:-1])) + + resDict = { + 'calculated': True, + 'Message': '', + 'method': '{}', + 'per_sample': '{}' + } + try: + # validate_data_fn(p['g'], p['s'], evalParams) + evalData = evaluate_method_fn(p['g'], p['s'], evalParams) + resDict.update(evalData) + + except Exception as e: + traceback.print_exc() + resDict['Message'] = str(e) + resDict['calculated'] = False + + if 'o' in p: + if not os.path.exists(p['o']): + os.makedirs(p['o']) + + resultsOutputname = p['o'] + '/results.zip' + outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) + + del resDict['per_sample'] + if 'output_items' in resDict.keys(): + del resDict['output_items'] + + outZip.writestr('method.json', json.dumps(resDict)) + + if not resDict['calculated']: + if show_result: + sys.stderr.write('Error!\n' + resDict['Message'] + '\n\n') + if 'o' in p: + outZip.close() + return resDict + + if 'o' in p: + if per_sample == True: + for k, v in evalData['per_sample'].iteritems(): + outZip.writestr(k + '.json', json.dumps(v)) + + if 'output_items' in evalData.keys(): + for k, v in evalData['output_items'].iteritems(): + outZip.writestr(k, v) + + outZip.close() + + if show_result: + sys.stdout.write("Calculated!") + sys.stdout.write(json.dumps(resDict['method'])) + + return resDict + + +def main_validation(default_evaluation_params_fn, validate_data_fn): + """ + This process validates a method + Params: + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + """ + try: + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + evalParams 
= default_evaluation_params_fn() + if 'p' in p.keys(): + evalParams.update(p['p'] if isinstance(p['p'], dict) else + json.loads(p['p'][1:-1])) + + validate_data_fn(p['g'], p['s'], evalParams) + print('SUCCESS') + sys.exit(0) + except Exception as e: + print(str(e)) + sys.exit(101) diff --git a/benchmark/PaddleOCR_DBNet/utils/cal_recall/script.py b/benchmark/PaddleOCR_DBNet/utils/cal_recall/script.py new file mode 100644 index 0000000..3b2f391 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/cal_recall/script.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from collections import namedtuple +from . import rrc_evaluation_funcs +import Polygon as plg +import numpy as np + + +def default_evaluation_params(): + """ + default_evaluation_params: Default parameters to use for the validation and evaluation. + """ + return { + 'IOU_CONSTRAINT': 0.5, + 'AREA_PRECISION_CONSTRAINT': 0.5, + 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt', + 'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt', + 'LTRB': + False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) + 'CRLF': False, # Lines are delimited by Windows CRLF format + 'CONFIDENCES': + False, # Detections must include confidence value. AP will be calculated + 'PER_SAMPLE_RESULTS': + True # Generate per sample results and produce data for visualization + } + + +def validate_data(gtFilePath, submFilePath, evaluationParams): + """ + Method validate_data: validates that all files in the results folder are correct (have the correct name contents). + Validates also that there are no missing files in the folder. + If some error detected, the method raises the error + """ + gt = rrc_evaluation_funcs.load_folder_file( + gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + + subm = rrc_evaluation_funcs.load_folder_file( + submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + + # Validate format of GroundTruth + for k in gt: + rrc_evaluation_funcs.validate_lines_in_file( + k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True) + + # Validate format of results + for k in subm: + if (k in gt) == False: + raise Exception("The sample %s not present in GT" % k) + + rrc_evaluation_funcs.validate_lines_in_file( + k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'], + False, evaluationParams['CONFIDENCES']) + + +def evaluate_method(gtFilePath, submFilePath, evaluationParams): + """ + Method evaluate_method: evaluate method and returns the results + Results. Dictionary with the following values: + - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } + - samples (optional) Per sample metrics. 
Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } + """ + + def polygon_from_points(points): + """ + Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 + """ + resBoxes = np.empty([1, 8], dtype='int32') + resBoxes[0, 0] = int(points[0]) + resBoxes[0, 4] = int(points[1]) + resBoxes[0, 1] = int(points[2]) + resBoxes[0, 5] = int(points[3]) + resBoxes[0, 2] = int(points[4]) + resBoxes[0, 6] = int(points[5]) + resBoxes[0, 3] = int(points[6]) + resBoxes[0, 7] = int(points[7]) + pointMat = resBoxes[0].reshape([2, 4]).T + return plg.Polygon(pointMat) + + def rectangle_to_polygon(rect): + resBoxes = np.empty([1, 8], dtype='int32') + resBoxes[0, 0] = int(rect.xmin) + resBoxes[0, 4] = int(rect.ymax) + resBoxes[0, 1] = int(rect.xmin) + resBoxes[0, 5] = int(rect.ymin) + resBoxes[0, 2] = int(rect.xmax) + resBoxes[0, 6] = int(rect.ymin) + resBoxes[0, 3] = int(rect.xmax) + resBoxes[0, 7] = int(rect.ymax) + + pointMat = resBoxes[0].reshape([2, 4]).T + + return plg.Polygon(pointMat) + + def rectangle_to_points(rect): + points = [ + int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), + int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin) + ] + return points + + def get_union(pD, pG): + areaA = pD.area() + areaB = pG.area() + return areaA + areaB - get_intersection(pD, pG) + + def get_intersection_over_union(pD, pG): + try: + return get_intersection(pD, pG) / get_union(pD, pG) + except: + return 0 + + def get_intersection(pD, pG): + pInt = pD & pG + if len(pInt) == 0: + return 0 + return pInt.area() + + def compute_ap(confList, matchList, numGtCare): + correct = 0 + AP = 0 + if len(confList) > 0: + confList = np.array(confList) + matchList = np.array(matchList) + sorted_ind = np.argsort(-confList) + confList = confList[sorted_ind] + matchList = matchList[sorted_ind] + for n in range(len(confList)): + match = matchList[n] + if match: + correct += 1 + AP += float(correct) / (n + 1) + + if numGtCare > 0: + AP /= numGtCare + + return AP + + perSampleMetrics = {} + + matchedSum = 0 + + Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + + gt = rrc_evaluation_funcs.load_folder_file( + gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + subm = rrc_evaluation_funcs.load_folder_file( + submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + + numGlobalCareGt = 0 + numGlobalCareDet = 0 + + arrGlobalConfidences = [] + arrGlobalMatches = [] + + for resFile in gt: + + gtFile = gt[resFile] # rrc_evaluation_funcs.decode_utf8(gt[resFile]) + recall = 0 + precision = 0 + hmean = 0 + + detMatched = 0 + + iouMat = np.empty([1, 1]) + + gtPols = [] + detPols = [] + + gtPolPoints = [] + detPolPoints = [] + + # Array of Ground Truth Polygons' keys marked as don't Care + gtDontCarePolsNum = [] + # Array of Detected Polygons' matched with a don't Care GT + detDontCarePolsNum = [] + + pairs = [] + detMatchedNums = [] + + arrSampleConfidences = [] + arrSampleMatch = [] + sampleAP = 0 + + evaluationLog = "" + + pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents( + gtFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True, + False) + for n in range(len(pointsList)): + points = pointsList[n] + transcription = transcriptionsList[n] + dontCare = transcription == "###" + if evaluationParams['LTRB']: + gtRect = Rectangle(*points) + gtPol = rectangle_to_polygon(gtRect) + else: + gtPol = polygon_from_points(points) + gtPols.append(gtPol) + 
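            # Editor's note (descriptive comment, not part of the original patch):
            # polygon_from_points() above packs the flat [x1, y1, ..., x4, y4] list into
            # a 2x4 matrix (row 0 = x coordinates, row 1 = y coordinates) and transposes
            # it, so gtPol is a plg.Polygon over the four (x, y) vertices in the order
            # given in the ground-truth file. Boxes whose transcription is "###" are
            # recorded as "don't care" just below and are excluded from the recall and
            # precision counts.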
gtPolPoints.append(points) + if dontCare: + gtDontCarePolsNum.append(len(gtPols) - 1) + + evaluationLog += "GT polygons: " + str(len(gtPols)) + ( + " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" + if len(gtDontCarePolsNum) > 0 else "\n") + + if resFile in subm: + + detFile = subm[ + resFile] # rrc_evaluation_funcs.decode_utf8(subm[resFile]) + + pointsList, confidencesList, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents( + detFile, evaluationParams['CRLF'], evaluationParams['LTRB'], + False, evaluationParams['CONFIDENCES']) + for n in range(len(pointsList)): + points = pointsList[n] + + if evaluationParams['LTRB']: + detRect = Rectangle(*points) + detPol = rectangle_to_polygon(detRect) + else: + detPol = polygon_from_points(points) + detPols.append(detPol) + detPolPoints.append(points) + if len(gtDontCarePolsNum) > 0: + for dontCarePol in gtDontCarePolsNum: + dontCarePol = gtPols[dontCarePol] + intersected_area = get_intersection(dontCarePol, detPol) + pdDimensions = detPol.area() + precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions + if (precision > + evaluationParams['AREA_PRECISION_CONSTRAINT']): + detDontCarePolsNum.append(len(detPols) - 1) + break + + evaluationLog += "DET polygons: " + str(len(detPols)) + ( + " (" + str(len(detDontCarePolsNum)) + " don't care)\n" + if len(detDontCarePolsNum) > 0 else "\n") + + if len(gtPols) > 0 and len(detPols) > 0: + # Calculate IoU and precision matrixs + outputShape = [len(gtPols), len(detPols)] + iouMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtPols), np.int8) + detRectMat = np.zeros(len(detPols), np.int8) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = gtPols[gtNum] + pD = detPols[detNum] + iouMat[gtNum, detNum] = get_intersection_over_union(pD, + pG) + + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum: + if iouMat[gtNum, detNum] > evaluationParams[ + 'IOU_CONSTRAINT']: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + detMatched += 1 + pairs.append({'gt': gtNum, 'det': detNum}) + detMatchedNums.append(detNum) + evaluationLog += "Match GT #" + str( + gtNum) + " with Det #" + str(detNum) + "\n" + + if evaluationParams['CONFIDENCES']: + for detNum in range(len(detPols)): + if detNum not in detDontCarePolsNum: + # we exclude the don't care detections + match = detNum in detMatchedNums + + arrSampleConfidences.append(confidencesList[detNum]) + arrSampleMatch.append(match) + + arrGlobalConfidences.append(confidencesList[detNum]) + arrGlobalMatches.append(match) + + numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) + numDetCare = (len(detPols) - len(detDontCarePolsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if numDetCare > 0 else float(1) + sampleAP = precision + else: + recall = float(detMatched) / numGtCare + precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare + if evaluationParams['CONFIDENCES'] and evaluationParams[ + 'PER_SAMPLE_RESULTS']: + sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, + numGtCare) + + hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / ( + precision + recall) + + matchedSum += detMatched + numGlobalCareGt += numGtCare + numGlobalCareDet += numDetCare + + if evaluationParams['PER_SAMPLE_RESULTS']: + perSampleMetrics[resFile] = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'AP': 
sampleAP, + 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtDontCare': gtDontCarePolsNum, + 'detDontCare': detDontCarePolsNum, + 'evaluationParams': evaluationParams, + 'evaluationLog': evaluationLog + } + + # Compute MAP and MAR + AP = 0 + if evaluationParams['CONFIDENCES']: + AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt) + + methodRecall = 0 if numGlobalCareGt == 0 else float( + matchedSum) / numGlobalCareGt + methodPrecision = 0 if numGlobalCareDet == 0 else float( + matchedSum) / numGlobalCareDet + methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / ( + methodRecall + methodPrecision) + + methodMetrics = { + 'precision': methodPrecision, + 'recall': methodRecall, + 'hmean': methodHmean, + 'AP': AP + } + + resDict = { + 'calculated': True, + 'Message': '', + 'method': methodMetrics, + 'per_sample': perSampleMetrics + } + + return resDict + + +def cal_recall_precison_f1(gt_path, result_path, show_result=False): + p = {'g': gt_path, 's': result_path} + result = rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, + validate_data, + evaluate_method, show_result) + return result['method'] diff --git a/benchmark/PaddleOCR_DBNet/utils/compute_mean_std.py b/benchmark/PaddleOCR_DBNet/utils/compute_mean_std.py new file mode 100644 index 0000000..5d0ab5c --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/compute_mean_std.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/7 14:46 +# @Author : zhoujun + +import numpy as np +import cv2 +import os +import random +from tqdm import tqdm +# calculate means and std +train_txt_path = './train_val_list.txt' + +CNum = 10000 # 挑选多少图片进行计算 + +img_h, img_w = 640, 640 +imgs = np.zeros([img_w, img_h, 3, 1]) +means, stdevs = [], [] + +with open(train_txt_path, 'r') as f: + lines = f.readlines() + random.shuffle(lines) # shuffle , 随机挑选图片 + + for i in tqdm(range(CNum)): + img_path = lines[i].split('\t')[0] + + img = cv2.imread(img_path) + img = cv2.resize(img, (img_h, img_w)) + img = img[:, :, :, np.newaxis] + + imgs = np.concatenate((imgs, img), axis=3) +# print(i) + +imgs = imgs.astype(np.float32) / 255. 
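# --- Editor's note: illustrative alternative, not part of the original patch. ---
# Concatenating CNum resized images into one (H, W, 3, N) array needs a lot of memory;
# the same per-channel statistics can be accumulated incrementally instead. This sketch
# reuses the cv2/np imports above; the function and argument names are placeholders,
# and `image_lines` is assumed to use the same "img_path\tlabel" line format.
def channel_mean_std(image_lines, num_samples, size=(640, 640)):
    px_sum = np.zeros(3, dtype=np.float64)    # per-channel sum of pixel values
    px_sqsum = np.zeros(3, dtype=np.float64)  # per-channel sum of squared values
    count = 0
    for line in image_lines[:num_samples]:
        img = cv2.imread(line.split('\t')[0])
        if img is None:
            continue
        img = cv2.resize(img, size).astype(np.float64) / 255.
        px_sum += img.sum(axis=(0, 1))
        px_sqsum += (img ** 2).sum(axis=(0, 1))
        count += img.shape[0] * img.shape[1]
    mean = px_sum / count
    std = np.sqrt(px_sqsum / count - mean ** 2)
    return mean[::-1], std[::-1]  # reverse BGR -> RGB, matching the reverse() below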
+
+for i in tqdm(range(3)):
+    pixels = imgs[:, :, i, :].ravel()  # flatten the channel into a single row
+    means.append(np.mean(pixels))
+    stdevs.append(np.std(pixels))
+
+# cv2 reads images in BGR order; PIL/Skimage return RGB, so no conversion is needed there
+means.reverse()  # BGR --> RGB
+stdevs.reverse()
+
+print("normMean = {}".format(means))
+print("normStd = {}".format(stdevs))
+print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))
\ No newline at end of file
diff --git a/benchmark/PaddleOCR_DBNet/utils/make_trainfile.py b/benchmark/PaddleOCR_DBNet/utils/make_trainfile.py
new file mode 100644
index 0000000..9b7ae70
--- /dev/null
+++ b/benchmark/PaddleOCR_DBNet/utils/make_trainfile.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# @Time : 2019/8/24 12:06
+# @Author : zhoujun
+import os
+import glob
+import pathlib
+
+data_path = r'test'
+# data_path/img holds the images
+# data_path/gt holds the label files
+
+f_w = open(os.path.join(data_path, 'test.txt'), 'w', encoding='utf8')
+for img_path in glob.glob(data_path + '/img/*.jpg', recursive=True):
+    d = pathlib.Path(img_path)
+    label_path = os.path.join(data_path, 'gt', ('gt_' + str(d.stem) + '.txt'))
+    if os.path.exists(img_path) and os.path.exists(label_path):
+        print(img_path, label_path)
+    else:
+        print('does not exist', img_path, label_path)
+    f_w.write('{}\t{}\n'.format(img_path, label_path))
+f_w.close()
\ No newline at end of file
diff --git a/benchmark/PaddleOCR_DBNet/utils/metrics.py b/benchmark/PaddleOCR_DBNet/utils/metrics.py
new file mode 100644
index 0000000..e9c54b8
--- /dev/null
+++ b/benchmark/PaddleOCR_DBNet/utils/metrics.py
@@ -0,0 +1,58 @@
+# Adapted from score written by wkentaro
+# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
+
+import numpy as np
+
+
+class runningScore(object):
+    def __init__(self, n_classes):
+        self.n_classes = n_classes
+        self.confusion_matrix = np.zeros((n_classes, n_classes))
+
+    def _fast_hist(self, label_true, label_pred, n_class):
+        mask = (label_true >= 0) & (label_true < n_class)
+
+        if np.sum((label_pred[mask] < 0)) > 0:
+            print(label_pred[label_pred < 0])
+        hist = np.bincount(
+            n_class * label_true[mask].astype(int) + label_pred[mask],
+            minlength=n_class**2).reshape(n_class, n_class)
+        return hist
+
+    def update(self, label_trues, label_preds):
+        # print label_trues.dtype, label_preds.dtype
+        for lt, lp in zip(label_trues, label_preds):
+            try:
+                self.confusion_matrix += self._fast_hist(lt.flatten(),
+                                                         lp.flatten(),
+                                                         self.n_classes)
+            except:  # silently skip samples whose labels cannot be histogrammed
+                pass
+
+    def get_scores(self):
+        """Returns accuracy score evaluation result.
+ - overall accuracy + - mean accuracy + - mean IU + - fwavacc + """ + hist = self.confusion_matrix + acc = np.diag(hist).sum() / (hist.sum() + 0.0001) + acc_cls = np.diag(hist) / (hist.sum(axis=1) + 0.0001) + acc_cls = np.nanmean(acc_cls) + iu = np.diag(hist) / ( + hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 0.0001) + mean_iu = np.nanmean(iu) + freq = hist.sum(axis=1) / (hist.sum() + 0.0001) + fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() + cls_iu = dict(zip(range(self.n_classes), iu)) + + return { + 'Overall Acc': acc, + 'Mean Acc': acc_cls, + 'FreqW Acc': fwavacc, + 'Mean IoU': mean_iu, + }, cls_iu + + def reset(self): + self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/__init__.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/__init__.py new file mode 100644 index 0000000..3e7c51c --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/__init__.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/5 15:36 +# @Author : zhoujun +from .icdar2015 import QuadMetric + + +def get_metric(config): + try: + if 'args' not in config: + args = {} + else: + args = config['args'] + if isinstance(args, dict): + cls = eval(config['type'])(**args) + else: + cls = eval(config['type'])(args) + return cls + except: + return None \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/__init__.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/__init__.py new file mode 100644 index 0000000..375ae55 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/5 15:36 +# @Author : zhoujun + +from .quad_metric import QuadMetric \ No newline at end of file diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/__init__.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/deteval.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/deteval.py new file mode 100644 index 0000000..c5dcfc4 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/deteval.py @@ -0,0 +1,389 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import math +from collections import namedtuple +import numpy as np +from shapely.geometry import Polygon + + +class DetectionDetEvalEvaluator(object): + def __init__(self, + area_recall_constraint=0.8, + area_precision_constraint=0.4, + ev_param_ind_center_diff_thr=1, + mtype_oo_o=1.0, + mtype_om_o=0.8, + mtype_om_m=1.0): + + self.area_recall_constraint = area_recall_constraint + self.area_precision_constraint = area_precision_constraint + self.ev_param_ind_center_diff_thr = ev_param_ind_center_diff_thr + self.mtype_oo_o = mtype_oo_o + self.mtype_om_o = mtype_om_o + self.mtype_om_m = mtype_om_m + + def evaluate_image(self, gt, pred): + def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + def one_to_one_match(row, col): + cont = 0 + for j in range(len(recallMat[0])): + if recallMat[row, + j] >= self.area_recall_constraint and precisionMat[ + row, j] >= self.area_precision_constraint: + cont = cont + 1 + if (cont != 1): + return False + cont = 0 + for i in 
range(len(recallMat)): + if recallMat[ + i, col] >= self.area_recall_constraint and precisionMat[ + i, col] >= self.area_precision_constraint: + cont = cont + 1 + if (cont != 1): + return False + + if recallMat[row, + col] >= self.area_recall_constraint and precisionMat[ + row, col] >= self.area_precision_constraint: + return True + return False + + def num_overlaps_gt(gtNum): + cont = 0 + for detNum in range(len(detRects)): + if detNum not in detDontCareRectsNum: + if recallMat[gtNum, detNum] > 0: + cont = cont + 1 + return cont + + def num_overlaps_det(detNum): + cont = 0 + for gtNum in range(len(recallMat)): + if gtNum not in gtDontCareRectsNum: + if recallMat[gtNum, detNum] > 0: + cont = cont + 1 + return cont + + def is_single_overlap(row, col): + if num_overlaps_gt(row) == 1 and num_overlaps_det(col) == 1: + return True + else: + return False + + def one_to_many_match(gtNum): + many_sum = 0 + detRects = [] + for detNum in range(len(recallMat[0])): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and detNum not in detDontCareRectsNum: + if precisionMat[gtNum, + detNum] >= self.area_precision_constraint: + many_sum += recallMat[gtNum, detNum] + detRects.append(detNum) + if round(many_sum, 4) >= self.area_recall_constraint: + return True, detRects + else: + return False, [] + + def many_to_one_match(detNum): + many_sum = 0 + gtRects = [] + for gtNum in range(len(recallMat)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCareRectsNum: + if recallMat[gtNum, detNum] >= self.area_recall_constraint: + many_sum += precisionMat[gtNum, detNum] + gtRects.append(gtNum) + if round(many_sum, 4) >= self.area_precision_constraint: + return True, gtRects + else: + return False, [] + + def center_distance(r1, r2): + return ((np.mean(r1, axis=0) - np.mean(r2, axis=0))**2).sum()**0.5 + + def diag(r): + r = np.array(r) + return ((r[:, 0].max() - r[:, 0].min())**2 + + (r[:, 1].max() - r[:, 1].min())**2)**0.5 + + perSampleMetrics = {} + + recall = 0 + precision = 0 + hmean = 0 + recallAccum = 0. + precisionAccum = 0. 
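        # Editor's note (descriptive comment, not part of the original patch):
        # recallAccum / precisionAccum collect the fractional DetEval credit per match:
        #   one-to-one (OO): both accumulators gain mtype_oo_o (1.0);
        #   one GT covered by k detections (OM): recall gains mtype_om_o (0.8) and
        #     precision gains mtype_om_o * k;
        #   one detection covering k GTs (MO): recall gains mtype_om_m * k and
        #     precision gains mtype_om_m (1.0);
        # the final recall/precision divide these totals by the number of "care"
        # ground-truth and detected rectangles respectively.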
+ gtRects = [] + detRects = [] + gtPolPoints = [] + detPolPoints = [] + gtDontCareRectsNum = [ + ] #Array of Ground Truth Rectangles' keys marked as don't Care + detDontCareRectsNum = [ + ] #Array of Detected Rectangles' matched with a don't Care GT + pairs = [] + evaluationLog = "" + + recallMat = np.empty([1, 1]) + precisionMat = np.empty([1, 1]) + + for n in range(len(gt)): + points = gt[n]['points'] + # transcription = gt[n]['text'] + dontCare = gt[n]['ignore'] + + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + gtRects.append(points) + gtPolPoints.append(points) + if dontCare: + gtDontCareRectsNum.append(len(gtRects) - 1) + + evaluationLog += "GT rectangles: " + str(len(gtRects)) + ( + " (" + str(len(gtDontCareRectsNum)) + " don't care)\n" + if len(gtDontCareRectsNum) > 0 else "\n") + + for n in range(len(pred)): + points = pred[n]['points'] + + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + detRect = points + detRects.append(detRect) + detPolPoints.append(points) + if len(gtDontCareRectsNum) > 0: + for dontCareRectNum in gtDontCareRectsNum: + dontCareRect = gtRects[dontCareRectNum] + intersected_area = get_intersection(dontCareRect, detRect) + rdDimensions = Polygon(detRect).area + if (rdDimensions == 0): + precision = 0 + else: + precision = intersected_area / rdDimensions + if (precision > self.area_precision_constraint): + detDontCareRectsNum.append(len(detRects) - 1) + break + + evaluationLog += "DET rectangles: " + str(len(detRects)) + ( + " (" + str(len(detDontCareRectsNum)) + " don't care)\n" + if len(detDontCareRectsNum) > 0 else "\n") + + if len(gtRects) == 0: + recall = 1 + precision = 0 if len(detRects) > 0 else 1 + + if len(detRects) > 0: + #Calculate recall and precision matrixs + outputShape = [len(gtRects), len(detRects)] + recallMat = np.empty(outputShape) + precisionMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtRects), np.int8) + detRectMat = np.zeros(len(detRects), np.int8) + for gtNum in range(len(gtRects)): + for detNum in range(len(detRects)): + rG = gtRects[gtNum] + rD = detRects[detNum] + intersected_area = get_intersection(rG, rD) + rgDimensions = Polygon(rG).area + rdDimensions = Polygon(rD).area + recallMat[ + gtNum, + detNum] = 0 if rgDimensions == 0 else intersected_area / rgDimensions + precisionMat[ + gtNum, + detNum] = 0 if rdDimensions == 0 else intersected_area / rdDimensions + + # Find one-to-one matches + evaluationLog += "Find one-to-one matches\n" + for gtNum in range(len(gtRects)): + for detNum in range(len(detRects)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCareRectsNum and detNum not in detDontCareRectsNum: + match = one_to_one_match(gtNum, detNum) + if match is True: + #in deteval we have to make other validation before mark as one-to-one + if is_single_overlap(gtNum, detNum) is True: + rG = gtRects[gtNum] + rD = detRects[detNum] + normDist = center_distance(rG, rD) + normDist /= diag(rG) + diag(rD) + normDist *= 2.0 + if normDist < self.ev_param_ind_center_diff_thr: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + recallAccum += self.mtype_oo_o + precisionAccum += self.mtype_oo_o + pairs.append({ + 'gt': gtNum, + 'det': detNum, + 'type': 'OO' + }) + evaluationLog += "Match GT #" + str( + gtNum) + " with Det #" + str( + detNum) + "\n" + else: + evaluationLog += "Match Discarded GT #" + str( + gtNum) + " with Det #" + str( + detNum) + " normDist: " + str( + normDist) + " \n" + else: + evaluationLog += "Match Discarded GT #" + 
str( + gtNum) + " with Det #" + str( + detNum) + " not single overlap\n" + # Find one-to-many matches + evaluationLog += "Find one-to-many matches\n" + for gtNum in range(len(gtRects)): + if gtNum not in gtDontCareRectsNum: + match, matchesDet = one_to_many_match(gtNum) + if match is True: + evaluationLog += "num_overlaps_gt=" + str( + num_overlaps_gt(gtNum)) + #in deteval we have to make other validation before mark as one-to-one + if num_overlaps_gt(gtNum) >= 2: + gtRectMat[gtNum] = 1 + recallAccum += (self.mtype_oo_o + if len(matchesDet) == 1 else + self.mtype_om_o) + precisionAccum += (self.mtype_oo_o + if len(matchesDet) == 1 else + self.mtype_om_o * + len(matchesDet)) + pairs.append({ + 'gt': gtNum, + 'det': matchesDet, + 'type': 'OO' if len(matchesDet) == 1 else 'OM' + }) + for detNum in matchesDet: + detRectMat[detNum] = 1 + evaluationLog += "Match GT #" + str( + gtNum) + " with Det #" + str(matchesDet) + "\n" + else: + evaluationLog += "Match Discarded GT #" + str( + gtNum) + " with Det #" + str( + matchesDet) + " not single overlap\n" + + # Find many-to-one matches + evaluationLog += "Find many-to-one matches\n" + for detNum in range(len(detRects)): + if detNum not in detDontCareRectsNum: + match, matchesGt = many_to_one_match(detNum) + if match is True: + #in deteval we have to make other validation before mark as one-to-one + if num_overlaps_det(detNum) >= 2: + detRectMat[detNum] = 1 + recallAccum += (self.mtype_oo_o + if len(matchesGt) == 1 else + self.mtype_om_m * len(matchesGt)) + precisionAccum += (self.mtype_oo_o + if len(matchesGt) == 1 else + self.mtype_om_m) + pairs.append({ + 'gt': matchesGt, + 'det': detNum, + 'type': 'OO' if len(matchesGt) == 1 else 'MO' + }) + for gtNum in matchesGt: + gtRectMat[gtNum] = 1 + evaluationLog += "Match GT #" + str( + matchesGt) + " with Det #" + str(detNum) + "\n" + else: + evaluationLog += "Match Discarded GT #" + str( + matchesGt) + " with Det #" + str( + detNum) + " not single overlap\n" + + numGtCare = (len(gtRects) - len(gtDontCareRectsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if len(detRects) > 0 else float(1) + else: + recall = float(recallAccum) / numGtCare + precision = float(0) if ( + len(detRects) - len(detDontCareRectsNum) + ) == 0 else float(precisionAccum) / ( + len(detRects) - len(detDontCareRectsNum)) + hmean = 0 if (precision + recall + ) == 0 else 2.0 * precision * recall / ( + precision + recall) + + numGtCare = len(gtRects) - len(gtDontCareRectsNum) + numDetCare = len(detRects) - len(detDontCareRectsNum) + + perSampleMetrics = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'recallMat': [] if len(detRects) > 100 else recallMat.tolist(), + 'precisionMat': [] + if len(detRects) > 100 else precisionMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtCare': numGtCare, + 'detCare': numDetCare, + 'gtDontCare': gtDontCareRectsNum, + 'detDontCare': detDontCareRectsNum, + 'recallAccum': recallAccum, + 'precisionAccum': precisionAccum, + 'evaluationLog': evaluationLog + } + + return perSampleMetrics + + def combine_results(self, results): + numGt = 0 + numDet = 0 + methodRecallSum = 0 + methodPrecisionSum = 0 + + for result in results: + numGt += result['gtCare'] + numDet += result['detCare'] + methodRecallSum += result['recallAccum'] + methodPrecisionSum += result['precisionAccum'] + + methodRecall = 0 if numGt == 0 else methodRecallSum / numGt + methodPrecision = 0 if numDet == 0 else methodPrecisionSum / numDet + methodHmean = 0 
if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / ( + methodRecall + methodPrecision) + + methodMetrics = { + 'precision': methodPrecision, + 'recall': methodRecall, + 'hmean': methodHmean + } + + return methodMetrics + + +if __name__ == '__main__': + evaluator = DetectionDetEvalEvaluator() + gts = [[{ + 'points': [(0, 0), (1, 0), (1, 1), (0, 1)], + 'text': 1234, + 'ignore': False, + }, { + 'points': [(2, 2), (3, 2), (3, 3), (2, 3)], + 'text': 5678, + 'ignore': True, + }]] + preds = [[{ + 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)], + 'text': 123, + 'ignore': False, + }]] + results = [] + for gt, pred in zip(gts, preds): + results.append(evaluator.evaluate_image(gt, pred)) + metrics = evaluator.combine_results(results) + print(metrics) diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/icdar2013.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/icdar2013.py new file mode 100644 index 0000000..7e8c86a --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/icdar2013.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import math +from collections import namedtuple +import numpy as np +from shapely.geometry import Polygon + + +class DetectionICDAR2013Evaluator(object): + def __init__(self, + area_recall_constraint=0.8, + area_precision_constraint=0.4, + ev_param_ind_center_diff_thr=1, + mtype_oo_o=1.0, + mtype_om_o=0.8, + mtype_om_m=1.0): + + self.area_recall_constraint = area_recall_constraint + self.area_precision_constraint = area_precision_constraint + self.ev_param_ind_center_diff_thr = ev_param_ind_center_diff_thr + self.mtype_oo_o = mtype_oo_o + self.mtype_om_o = mtype_om_o + self.mtype_om_m = mtype_om_m + + def evaluate_image(self, gt, pred): + def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + def one_to_one_match(row, col): + cont = 0 + for j in range(len(recallMat[0])): + if recallMat[row, + j] >= self.area_recall_constraint and precisionMat[ + row, j] >= self.area_precision_constraint: + cont = cont + 1 + if (cont != 1): + return False + cont = 0 + for i in range(len(recallMat)): + if recallMat[ + i, col] >= self.area_recall_constraint and precisionMat[ + i, col] >= self.area_precision_constraint: + cont = cont + 1 + if (cont != 1): + return False + + if recallMat[row, + col] >= self.area_recall_constraint and precisionMat[ + row, col] >= self.area_precision_constraint: + return True + return False + + def one_to_many_match(gtNum): + many_sum = 0 + detRects = [] + for detNum in range(len(recallMat[0])): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and detNum not in detDontCareRectsNum: + if precisionMat[gtNum, + detNum] >= self.area_precision_constraint: + many_sum += recallMat[gtNum, detNum] + detRects.append(detNum) + if round(many_sum, 4) >= self.area_recall_constraint: + return True, detRects + else: + return False, [] + + def many_to_one_match(detNum): + many_sum = 0 + gtRects = [] + for gtNum in range(len(recallMat)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCareRectsNum: + if recallMat[gtNum, detNum] >= self.area_recall_constraint: + many_sum += precisionMat[gtNum, detNum] + gtRects.append(gtNum) + if round(many_sum, 4) >= self.area_precision_constraint: + return True, gtRects + else: + return 
False, [] + + def center_distance(r1, r2): + return ((np.mean(r1, axis=0) - np.mean(r2, axis=0))**2).sum()**0.5 + + def diag(r): + r = np.array(r) + return ((r[:, 0].max() - r[:, 0].min())**2 + + (r[:, 1].max() - r[:, 1].min())**2)**0.5 + + perSampleMetrics = {} + + recall = 0 + precision = 0 + hmean = 0 + recallAccum = 0. + precisionAccum = 0. + gtRects = [] + detRects = [] + gtPolPoints = [] + detPolPoints = [] + gtDontCareRectsNum = [ + ] #Array of Ground Truth Rectangles' keys marked as don't Care + detDontCareRectsNum = [ + ] #Array of Detected Rectangles' matched with a don't Care GT + pairs = [] + evaluationLog = "" + + recallMat = np.empty([1, 1]) + precisionMat = np.empty([1, 1]) + + for n in range(len(gt)): + points = gt[n]['points'] + # transcription = gt[n]['text'] + dontCare = gt[n]['ignore'] + + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + gtRects.append(points) + gtPolPoints.append(points) + if dontCare: + gtDontCareRectsNum.append(len(gtRects) - 1) + + evaluationLog += "GT rectangles: " + str(len(gtRects)) + ( + " (" + str(len(gtDontCareRectsNum)) + " don't care)\n" + if len(gtDontCareRectsNum) > 0 else "\n") + + for n in range(len(pred)): + points = pred[n]['points'] + + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + detRect = points + detRects.append(detRect) + detPolPoints.append(points) + if len(gtDontCareRectsNum) > 0: + for dontCareRectNum in gtDontCareRectsNum: + dontCareRect = gtRects[dontCareRectNum] + intersected_area = get_intersection(dontCareRect, detRect) + rdDimensions = Polygon(detRect).area + if (rdDimensions == 0): + precision = 0 + else: + precision = intersected_area / rdDimensions + if (precision > self.area_precision_constraint): + detDontCareRectsNum.append(len(detRects) - 1) + break + + evaluationLog += "DET rectangles: " + str(len(detRects)) + ( + " (" + str(len(detDontCareRectsNum)) + " don't care)\n" + if len(detDontCareRectsNum) > 0 else "\n") + + if len(gtRects) == 0: + recall = 1 + precision = 0 if len(detRects) > 0 else 1 + + if len(detRects) > 0: + #Calculate recall and precision matrixs + outputShape = [len(gtRects), len(detRects)] + recallMat = np.empty(outputShape) + precisionMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtRects), np.int8) + detRectMat = np.zeros(len(detRects), np.int8) + for gtNum in range(len(gtRects)): + for detNum in range(len(detRects)): + rG = gtRects[gtNum] + rD = detRects[detNum] + intersected_area = get_intersection(rG, rD) + rgDimensions = Polygon(rG).area + rdDimensions = Polygon(rD).area + recallMat[ + gtNum, + detNum] = 0 if rgDimensions == 0 else intersected_area / rgDimensions + precisionMat[ + gtNum, + detNum] = 0 if rdDimensions == 0 else intersected_area / rdDimensions + + # Find one-to-one matches + evaluationLog += "Find one-to-one matches\n" + for gtNum in range(len(gtRects)): + for detNum in range(len(detRects)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCareRectsNum and detNum not in detDontCareRectsNum: + match = one_to_one_match(gtNum, detNum) + if match is True: + #in deteval we have to make other validation before mark as one-to-one + rG = gtRects[gtNum] + rD = detRects[detNum] + normDist = center_distance(rG, rD) + normDist /= diag(rG) + diag(rD) + normDist *= 2.0 + if normDist < self.ev_param_ind_center_diff_thr: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + recallAccum += self.mtype_oo_o + precisionAccum += self.mtype_oo_o + pairs.append({ + 'gt': gtNum, + 'det': detNum, + 
'type': 'OO' + }) + evaluationLog += "Match GT #" + str( + gtNum) + " with Det #" + str(detNum) + "\n" + else: + evaluationLog += "Match Discarded GT #" + str( + gtNum) + " with Det #" + str( + detNum) + " normDist: " + str( + normDist) + " \n" + # Find one-to-many matches + evaluationLog += "Find one-to-many matches\n" + for gtNum in range(len(gtRects)): + if gtNum not in gtDontCareRectsNum: + match, matchesDet = one_to_many_match(gtNum) + if match is True: + evaluationLog += "num_overlaps_gt=" + str( + num_overlaps_gt(gtNum)) + gtRectMat[gtNum] = 1 + recallAccum += (self.mtype_oo_o if len(matchesDet) == 1 + else self.mtype_om_o) + precisionAccum += (self.mtype_oo_o + if len(matchesDet) == 1 else + self.mtype_om_o * len(matchesDet)) + pairs.append({ + 'gt': gtNum, + 'det': matchesDet, + 'type': 'OO' if len(matchesDet) == 1 else 'OM' + }) + for detNum in matchesDet: + detRectMat[detNum] = 1 + evaluationLog += "Match GT #" + str( + gtNum) + " with Det #" + str(matchesDet) + "\n" + + # Find many-to-one matches + evaluationLog += "Find many-to-one matches\n" + for detNum in range(len(detRects)): + if detNum not in detDontCareRectsNum: + match, matchesGt = many_to_one_match(detNum) + if match is True: + detRectMat[detNum] = 1 + recallAccum += (self.mtype_oo_o if len(matchesGt) == 1 + else self.mtype_om_m * len(matchesGt)) + precisionAccum += (self.mtype_oo_o + if len(matchesGt) == 1 else + self.mtype_om_m) + pairs.append({ + 'gt': matchesGt, + 'det': detNum, + 'type': 'OO' if len(matchesGt) == 1 else 'MO' + }) + for gtNum in matchesGt: + gtRectMat[gtNum] = 1 + evaluationLog += "Match GT #" + str( + matchesGt) + " with Det #" + str(detNum) + "\n" + + numGtCare = (len(gtRects) - len(gtDontCareRectsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if len(detRects) > 0 else float(1) + else: + recall = float(recallAccum) / numGtCare + precision = float(0) if ( + len(detRects) - len(detDontCareRectsNum) + ) == 0 else float(precisionAccum) / ( + len(detRects) - len(detDontCareRectsNum)) + hmean = 0 if (precision + recall + ) == 0 else 2.0 * precision * recall / ( + precision + recall) + + numGtCare = len(gtRects) - len(gtDontCareRectsNum) + numDetCare = len(detRects) - len(detDontCareRectsNum) + + perSampleMetrics = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'recallMat': [] if len(detRects) > 100 else recallMat.tolist(), + 'precisionMat': [] + if len(detRects) > 100 else precisionMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtCare': numGtCare, + 'detCare': numDetCare, + 'gtDontCare': gtDontCareRectsNum, + 'detDontCare': detDontCareRectsNum, + 'recallAccum': recallAccum, + 'precisionAccum': precisionAccum, + 'evaluationLog': evaluationLog + } + + return perSampleMetrics + + def combine_results(self, results): + numGt = 0 + numDet = 0 + methodRecallSum = 0 + methodPrecisionSum = 0 + + for result in results: + numGt += result['gtCare'] + numDet += result['detCare'] + methodRecallSum += result['recallAccum'] + methodPrecisionSum += result['precisionAccum'] + + methodRecall = 0 if numGt == 0 else methodRecallSum / numGt + methodPrecision = 0 if numDet == 0 else methodPrecisionSum / numDet + methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / ( + methodRecall + methodPrecision) + + methodMetrics = { + 'precision': methodPrecision, + 'recall': methodRecall, + 'hmean': methodHmean + } + + return methodMetrics + + +if __name__ == '__main__': + evaluator = 
DetectionICDAR2013Evaluator() + gts = [[{ + 'points': [(0, 0), (1, 0), (1, 1), (0, 1)], + 'text': 1234, + 'ignore': False, + }, { + 'points': [(2, 2), (3, 2), (3, 3), (2, 3)], + 'text': 5678, + 'ignore': True, + }]] + preds = [[{ + 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)], + 'text': 123, + 'ignore': False, + }]] + results = [] + for gt, pred in zip(gts, preds): + results.append(evaluator.evaluate_image(gt, pred)) + metrics = evaluator.combine_results(results) + print(metrics) diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/iou.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/iou.py new file mode 100644 index 0000000..5f9533b --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/iou.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from collections import namedtuple +import numpy as np +from shapely.geometry import Polygon +import cv2 + + +def iou_rotate(box_a, box_b, method='union'): + rect_a = cv2.minAreaRect(box_a) + rect_b = cv2.minAreaRect(box_b) + r1 = cv2.rotatedRectangleIntersection(rect_a, rect_b) + if r1[0] == 0: + return 0 + else: + inter_area = cv2.contourArea(r1[1]) + area_a = cv2.contourArea(box_a) + area_b = cv2.contourArea(box_b) + union_area = area_a + area_b - inter_area + if union_area == 0 or inter_area == 0: + return 0 + if method == 'union': + iou = inter_area / union_area + elif method == 'intersection': + iou = inter_area / min(area_a, area_b) + else: + raise NotImplementedError + return iou + + +class DetectionIoUEvaluator(object): + def __init__(self, + is_output_polygon=False, + iou_constraint=0.5, + area_precision_constraint=0.5): + self.is_output_polygon = is_output_polygon + self.iou_constraint = iou_constraint + self.area_precision_constraint = area_precision_constraint + + def evaluate_image(self, gt, pred): + def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + def compute_ap(confList, matchList, numGtCare): + correct = 0 + AP = 0 + if len(confList) > 0: + confList = np.array(confList) + matchList = np.array(matchList) + sorted_ind = np.argsort(-confList) + confList = confList[sorted_ind] + matchList = matchList[sorted_ind] + for n in range(len(confList)): + match = matchList[n] + if match: + correct += 1 + AP += float(correct) / (n + 1) + + if numGtCare > 0: + AP /= numGtCare + + return AP + + perSampleMetrics = {} + + matchedSum = 0 + + Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + + numGlobalCareGt = 0 + numGlobalCareDet = 0 + + arrGlobalConfidences = [] + arrGlobalMatches = [] + + recall = 0 + precision = 0 + hmean = 0 + + detMatched = 0 + + iouMat = np.empty([1, 1]) + + gtPols = [] + detPols = [] + + gtPolPoints = [] + detPolPoints = [] + + # Array of Ground Truth Polygons' keys marked as don't Care + gtDontCarePolsNum = [] + # Array of Detected Polygons' matched with a don't Care GT + detDontCarePolsNum = [] + + pairs = [] + detMatchedNums = [] + + arrSampleConfidences = [] + arrSampleMatch = [] + + evaluationLog = "" + + for n in range(len(gt)): + points = gt[n]['points'] + # transcription = gt[n]['text'] + dontCare = gt[n]['ignore'] + + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + gtPol = points + gtPols.append(gtPol) + gtPolPoints.append(points) + if dontCare: + gtDontCarePolsNum.append(len(gtPols) 
- 1) + + evaluationLog += "GT polygons: " + str(len(gtPols)) + ( + " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" + if len(gtDontCarePolsNum) > 0 else "\n") + + for n in range(len(pred)): + points = pred[n]['points'] + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + detPol = points + detPols.append(detPol) + detPolPoints.append(points) + if len(gtDontCarePolsNum) > 0: + for dontCarePol in gtDontCarePolsNum: + dontCarePol = gtPols[dontCarePol] + intersected_area = get_intersection(dontCarePol, detPol) + pdDimensions = Polygon(detPol).area + precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions + if (precision > self.area_precision_constraint): + detDontCarePolsNum.append(len(detPols) - 1) + break + + evaluationLog += "DET polygons: " + str(len(detPols)) + ( + " (" + str(len(detDontCarePolsNum)) + " don't care)\n" + if len(detDontCarePolsNum) > 0 else "\n") + + if len(gtPols) > 0 and len(detPols) > 0: + # Calculate IoU and precision matrixs + outputShape = [len(gtPols), len(detPols)] + iouMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtPols), np.int8) + detRectMat = np.zeros(len(detPols), np.int8) + if self.is_output_polygon: + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = gtPols[gtNum] + pD = detPols[detNum] + iouMat[gtNum, detNum] = get_intersection_over_union(pD, + pG) + else: + # gtPols = np.float32(gtPols) + # detPols = np.float32(detPols) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = np.float32(gtPols[gtNum]) + pD = np.float32(detPols[detNum]) + iouMat[gtNum, detNum] = iou_rotate(pD, pG) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum: + if iouMat[gtNum, detNum] > self.iou_constraint: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + detMatched += 1 + pairs.append({'gt': gtNum, 'det': detNum}) + detMatchedNums.append(detNum) + evaluationLog += "Match GT #" + \ + str(gtNum) + " with Det #" + str(detNum) + "\n" + + numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) + numDetCare = (len(detPols) - len(detDontCarePolsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if numDetCare > 0 else float(1) + else: + recall = float(detMatched) / numGtCare + precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare + + hmean = 0 if (precision + recall) == 0 else 2.0 * \ + precision * recall / (precision + recall) + + matchedSum += detMatched + numGlobalCareGt += numGtCare + numGlobalCareDet += numDetCare + + perSampleMetrics = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtCare': numGtCare, + 'detCare': numDetCare, + 'gtDontCare': gtDontCarePolsNum, + 'detDontCare': detDontCarePolsNum, + 'detMatched': detMatched, + 'evaluationLog': evaluationLog + } + + return perSampleMetrics + + def combine_results(self, results): + numGlobalCareGt = 0 + numGlobalCareDet = 0 + matchedSum = 0 + for result in results: + numGlobalCareGt += result['gtCare'] + numGlobalCareDet += result['detCare'] + matchedSum += result['detMatched'] + + methodRecall = 0 if numGlobalCareGt == 0 else float( + matchedSum) / numGlobalCareGt + methodPrecision = 0 if numGlobalCareDet == 0 else float( + matchedSum) / numGlobalCareDet + methodHmean = 0 if methodRecall + 
methodPrecision == 0 else 2 * \ + methodRecall * methodPrecision / ( + methodRecall + methodPrecision) + + methodMetrics = { + 'precision': methodPrecision, + 'recall': methodRecall, + 'hmean': methodHmean + } + + return methodMetrics + + +if __name__ == '__main__': + evaluator = DetectionIoUEvaluator() + preds = [[{ + 'points': [(0.1, 0.1), (0.5, 0), (0.5, 1), (0, 1)], + 'text': 1234, + 'ignore': False, + }, { + 'points': [(0.5, 0.1), (1, 0), (1, 1), (0.5, 1)], + 'text': 5678, + 'ignore': False, + }]] + gts = [[{ + 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)], + 'text': 123, + 'ignore': False, + }]] + results = [] + for gt, pred in zip(gts, preds): + results.append(evaluator.evaluate_image(gt, pred)) + metrics = evaluator.combine_results(results) + print(metrics) diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/mtwi2018.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/mtwi2018.py new file mode 100644 index 0000000..8e319aa --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/detection/mtwi2018.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import math +from collections import namedtuple +import numpy as np +from shapely.geometry import Polygon + + +class DetectionMTWI2018Evaluator(object): + def __init__( + self, + area_recall_constraint=0.7, + area_precision_constraint=0.7, + ev_param_ind_center_diff_thr=1, ): + + self.area_recall_constraint = area_recall_constraint + self.area_precision_constraint = area_precision_constraint + self.ev_param_ind_center_diff_thr = ev_param_ind_center_diff_thr + + def evaluate_image(self, gt, pred): + def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + def one_to_one_match(row, col): + cont = 0 + for j in range(len(recallMat[0])): + if recallMat[row, + j] >= self.area_recall_constraint and precisionMat[ + row, j] >= self.area_precision_constraint: + cont = cont + 1 + if (cont != 1): + return False + cont = 0 + for i in range(len(recallMat)): + if recallMat[ + i, col] >= self.area_recall_constraint and precisionMat[ + i, col] >= self.area_precision_constraint: + cont = cont + 1 + if (cont != 1): + return False + + if recallMat[row, + col] >= self.area_recall_constraint and precisionMat[ + row, col] >= self.area_precision_constraint: + return True + return False + + def one_to_many_match(gtNum): + many_sum = 0 + detRects = [] + for detNum in range(len(recallMat[0])): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and detNum not in detDontCareRectsNum: + if precisionMat[gtNum, + detNum] >= self.area_precision_constraint: + many_sum += recallMat[gtNum, detNum] + detRects.append(detNum) + if round(many_sum, 4) >= self.area_recall_constraint: + return True, detRects + else: + return False, [] + + def many_to_one_match(detNum): + many_sum = 0 + gtRects = [] + for gtNum in range(len(recallMat)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCareRectsNum: + if recallMat[gtNum, detNum] >= self.area_recall_constraint: + many_sum += precisionMat[gtNum, detNum] + gtRects.append(gtNum) + if round(many_sum, 4) >= self.area_precision_constraint: + return True, gtRects + else: + return False, [] + + def center_distance(r1, r2): + return ((np.mean(r1, axis=0) - np.mean(r2, axis=0))**2).sum()**0.5 + + def diag(r): + r = np.array(r) + 
return ((r[:, 0].max() - r[:, 0].min())**2 + + (r[:, 1].max() - r[:, 1].min())**2)**0.5 + + perSampleMetrics = {} + + recall = 0 + precision = 0 + hmean = 0 + recallAccum = 0. + precisionAccum = 0. + gtRects = [] + detRects = [] + gtPolPoints = [] + detPolPoints = [] + gtDontCareRectsNum = [ + ] #Array of Ground Truth Rectangles' keys marked as don't Care + detDontCareRectsNum = [ + ] #Array of Detected Rectangles' matched with a don't Care GT + pairs = [] + evaluationLog = "" + + recallMat = np.empty([1, 1]) + precisionMat = np.empty([1, 1]) + + for n in range(len(gt)): + points = gt[n]['points'] + # transcription = gt[n]['text'] + dontCare = gt[n]['ignore'] + + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + gtRects.append(points) + gtPolPoints.append(points) + if dontCare: + gtDontCareRectsNum.append(len(gtRects) - 1) + + evaluationLog += "GT rectangles: " + str(len(gtRects)) + ( + " (" + str(len(gtDontCareRectsNum)) + " don't care)\n" + if len(gtDontCareRectsNum) > 0 else "\n") + + for n in range(len(pred)): + points = pred[n]['points'] + + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + detRect = points + detRects.append(detRect) + detPolPoints.append(points) + if len(gtDontCareRectsNum) > 0: + for dontCareRectNum in gtDontCareRectsNum: + dontCareRect = gtRects[dontCareRectNum] + intersected_area = get_intersection(dontCareRect, detRect) + rdDimensions = Polygon(detRect).area + if (rdDimensions == 0): + precision = 0 + else: + precision = intersected_area / rdDimensions + if (precision > 0.5): + detDontCareRectsNum.append(len(detRects) - 1) + break + + evaluationLog += "DET rectangles: " + str(len(detRects)) + ( + " (" + str(len(detDontCareRectsNum)) + " don't care)\n" + if len(detDontCareRectsNum) > 0 else "\n") + + if len(gtRects) == 0: + recall = 1 + precision = 0 if len(detRects) > 0 else 1 + + if len(detRects) > 0: + #Calculate recall and precision matrixs + outputShape = [len(gtRects), len(detRects)] + recallMat = np.empty(outputShape) + precisionMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtRects), np.int8) + detRectMat = np.zeros(len(detRects), np.int8) + for gtNum in range(len(gtRects)): + for detNum in range(len(detRects)): + rG = gtRects[gtNum] + rD = detRects[detNum] + intersected_area = get_intersection(rG, rD) + rgDimensions = Polygon(rG).area + rdDimensions = Polygon(rD).area + recallMat[ + gtNum, + detNum] = 0 if rgDimensions == 0 else intersected_area / rgDimensions + precisionMat[ + gtNum, + detNum] = 0 if rdDimensions == 0 else intersected_area / rdDimensions + + # Find one-to-one matches + evaluationLog += "Find one-to-one matches\n" + for gtNum in range(len(gtRects)): + for detNum in range(len(detRects)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCareRectsNum and detNum not in detDontCareRectsNum: + match = one_to_one_match(gtNum, detNum) + if match is True: + #in deteval we have to make other validation before mark as one-to-one + rG = gtRects[gtNum] + rD = detRects[detNum] + normDist = center_distance(rG, rD) + normDist /= diag(rG) + diag(rD) + normDist *= 2.0 + if normDist < self.ev_param_ind_center_diff_thr: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + recallAccum += 1.0 + precisionAccum += 1.0 + pairs.append({ + 'gt': gtNum, + 'det': detNum, + 'type': 'OO' + }) + evaluationLog += "Match GT #" + str( + gtNum) + " with Det #" + str(detNum) + "\n" + else: + evaluationLog += "Match Discarded GT #" + str( + gtNum) + " with Det #" + str( + 
detNum) + " normDist: " + str( + normDist) + " \n" + # Find one-to-many matches + evaluationLog += "Find one-to-many matches\n" + for gtNum in range(len(gtRects)): + if gtNum not in gtDontCareRectsNum: + match, matchesDet = one_to_many_match(gtNum) + if match is True: + gtRectMat[gtNum] = 1 + recallAccum += 1.0 + precisionAccum += len(matchesDet) / ( + 1 + math.log(len(matchesDet))) + pairs.append({ + 'gt': gtNum, + 'det': matchesDet, + 'type': 'OO' if len(matchesDet) == 1 else 'OM' + }) + for detNum in matchesDet: + detRectMat[detNum] = 1 + evaluationLog += "Match GT #" + str( + gtNum) + " with Det #" + str(matchesDet) + "\n" + + # Find many-to-one matches + evaluationLog += "Find many-to-one matches\n" + for detNum in range(len(detRects)): + if detNum not in detDontCareRectsNum: + match, matchesGt = many_to_one_match(detNum) + if match is True: + detRectMat[detNum] = 1 + recallAccum += len(matchesGt) / ( + 1 + math.log(len(matchesGt))) + precisionAccum += 1.0 + pairs.append({ + 'gt': matchesGt, + 'det': detNum, + 'type': 'OO' if len(matchesGt) == 1 else 'MO' + }) + for gtNum in matchesGt: + gtRectMat[gtNum] = 1 + evaluationLog += "Match GT #" + str( + matchesGt) + " with Det #" + str(detNum) + "\n" + + numGtCare = (len(gtRects) - len(gtDontCareRectsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if len(detRects) > 0 else float(1) + else: + recall = float(recallAccum) / numGtCare + precision = float(0) if ( + len(detRects) - len(detDontCareRectsNum) + ) == 0 else float(precisionAccum) / ( + len(detRects) - len(detDontCareRectsNum)) + hmean = 0 if (precision + recall + ) == 0 else 2.0 * precision * recall / ( + precision + recall) + + numGtCare = len(gtRects) - len(gtDontCareRectsNum) + numDetCare = len(detRects) - len(detDontCareRectsNum) + + perSampleMetrics = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'recallMat': [] if len(detRects) > 100 else recallMat.tolist(), + 'precisionMat': [] + if len(detRects) > 100 else precisionMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtCare': numGtCare, + 'detCare': numDetCare, + 'gtDontCare': gtDontCareRectsNum, + 'detDontCare': detDontCareRectsNum, + 'recallAccum': recallAccum, + 'precisionAccum': precisionAccum, + 'evaluationLog': evaluationLog + } + + return perSampleMetrics + + def combine_results(self, results): + numGt = 0 + numDet = 0 + methodRecallSum = 0 + methodPrecisionSum = 0 + + for result in results: + numGt += result['gtCare'] + numDet += result['detCare'] + methodRecallSum += result['recallAccum'] + methodPrecisionSum += result['precisionAccum'] + + methodRecall = 0 if numGt == 0 else methodRecallSum / numGt + methodPrecision = 0 if numDet == 0 else methodPrecisionSum / numDet + methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / ( + methodRecall + methodPrecision) + + methodMetrics = { + 'precision': methodPrecision, + 'recall': methodRecall, + 'hmean': methodHmean + } + + return methodMetrics + + +if __name__ == '__main__': + evaluator = DetectionICDAR2013Evaluator() + gts = [[{ + 'points': [(0, 0), (1, 0), (1, 1), (0, 1)], + 'text': 1234, + 'ignore': False, + }, { + 'points': [(2, 2), (3, 2), (3, 3), (2, 3)], + 'text': 5678, + 'ignore': True, + }]] + preds = [[{ + 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)], + 'text': 123, + 'ignore': False, + }]] + results = [] + for gt, pred in zip(gts, preds): + results.append(evaluator.evaluate_image(gt, pred)) + metrics = 
evaluator.combine_results(results) + print(metrics) diff --git a/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/quad_metric.py b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/quad_metric.py new file mode 100644 index 0000000..e7e403a --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/ocr_metric/icdar2015/quad_metric.py @@ -0,0 +1,98 @@ +import numpy as np + +from .detection.iou import DetectionIoUEvaluator + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + return self + + +class QuadMetric(): + def __init__(self, is_output_polygon=False): + self.is_output_polygon = is_output_polygon + self.evaluator = DetectionIoUEvaluator( + is_output_polygon=is_output_polygon) + + def measure(self, batch, output, box_thresh=0.6): + ''' + batch: (image, polygons, ignore_tags + batch: a dict produced by dataloaders. + image: tensor of shape (N, C, H, W). + polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions. + ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not. + shape: the original shape of images. + filename: the original filenames of images. + output: (polygons, ...) + ''' + results = [] + gt_polyons_batch = batch['text_polys'] + ignore_tags_batch = batch['ignore_tags'] + pred_polygons_batch = np.array(output[0]) + pred_scores_batch = np.array(output[1]) + for polygons, pred_polygons, pred_scores, ignore_tags in zip( + gt_polyons_batch, pred_polygons_batch, pred_scores_batch, + ignore_tags_batch): + gt = [ + dict( + points=np.int64(polygons[i]), ignore=ignore_tags[i]) + for i in range(len(polygons)) + ] + if self.is_output_polygon: + pred = [ + dict(points=pred_polygons[i]) + for i in range(len(pred_polygons)) + ] + else: + pred = [] + # print(pred_polygons.shape) + for i in range(pred_polygons.shape[0]): + if pred_scores[i] >= box_thresh: + # print(pred_polygons[i,:,:].tolist()) + pred.append( + dict(points=pred_polygons[i, :, :].astype(np.int))) + # pred = [dict(points=pred_polygons[i,:,:].tolist()) if pred_scores[i] >= box_thresh for i in range(pred_polygons.shape[0])] + results.append(self.evaluator.evaluate_image(gt, pred)) + return results + + def validate_measure(self, batch, output, box_thresh=0.6): + return self.measure(batch, output, box_thresh) + + def evaluate_measure(self, batch, output): + return self.measure(batch, output), np.linspace( + 0, batch['image'].shape[0]).tolist() + + def gather_measure(self, raw_metrics): + raw_metrics = [ + image_metrics + for batch_metrics in raw_metrics for image_metrics in batch_metrics + ] + + result = self.evaluator.combine_results(raw_metrics) + + precision = AverageMeter() + recall = AverageMeter() + fmeasure = AverageMeter() + + precision.update(result['precision'], n=len(raw_metrics)) + recall.update(result['recall'], n=len(raw_metrics)) + fmeasure_score = 2 * precision.val * recall.val / ( + precision.val + recall.val + 1e-8) + fmeasure.update(fmeasure_score) + + return {'precision': precision, 'recall': recall, 'fmeasure': fmeasure} diff --git a/benchmark/PaddleOCR_DBNet/utils/profiler.py b/benchmark/PaddleOCR_DBNet/utils/profiler.py new file mode 100644 index 0000000..e64afd6 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/profiler.py @@ -0,0 +1,110 @@ +# copyright (c) 2021 PaddlePaddle 
Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import paddle + +# A global variable to record the number of calling times for profiler +# functions. It is used to specify the tracing range of training steps. +_profiler_step_id = 0 + +# A global variable to avoid parsing from string every time. +_profiler_options = None + + +class ProfilerOptions(object): + ''' + Use a string to initialize a ProfilerOptions. + The string should be in the format: "key1=value1;key2=value;key3=value3". + For example: + "profile_path=model.profile" + "batch_range=[50, 60]; profile_path=model.profile" + "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile" + ProfilerOptions supports following key-value pair: + batch_range - a integer list, e.g. [100, 110]. + state - a string, the optional values are 'CPU', 'GPU' or 'All'. + sorted_key - a string, the optional values are 'calls', 'total', + 'max', 'min' or 'ave. + tracer_option - a string, the optional values are 'Default', 'OpDetail', + 'AllOpDetail'. + profile_path - a string, the path to save the serialized profile data, + which can be used to generate a timeline. + exit_on_finished - a boolean. + ''' + + def __init__(self, options_str): + assert isinstance(options_str, str) + + self._options = { + 'batch_range': [10, 20], + 'state': 'All', + 'sorted_key': 'total', + 'tracer_option': 'Default', + 'profile_path': '/tmp/profile', + 'exit_on_finished': True + } + self._parse_from_string(options_str) + + def _parse_from_string(self, options_str): + for kv in options_str.replace(' ', '').split(';'): + key, value = kv.split('=') + if key == 'batch_range': + value_list = value.replace('[', '').replace(']', '').split(',') + value_list = list(map(int, value_list)) + if len(value_list) >= 2 and value_list[0] >= 0 and value_list[ + 1] > value_list[0]: + self._options[key] = value_list + elif key == 'exit_on_finished': + self._options[key] = value.lower() in ("yes", "true", "t", "1") + elif key in [ + 'state', 'sorted_key', 'tracer_option', 'profile_path' + ]: + self._options[key] = value + + def __getitem__(self, name): + if self._options.get(name, None) is None: + raise ValueError( + "ProfilerOptions does not have an option named %s." % name) + return self._options[name] + + +def add_profiler_step(options_str=None): + ''' + Enable the operator-level timing using PaddlePaddle's profiler. + The profiler uses a independent variable to count the profiler steps. + One call of this function is treated as a profiler step. + + Args: + profiler_options - a string to initialize the ProfilerOptions. + Default is None, and the profiler is disabled. 
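+
+    Example (illustrative sketch of a per-iteration call; ``train_loader`` stands
+    in for the real data loader and the options string is just a sample value):
+
+        for step, batch in enumerate(train_loader):
+            add_profiler_step("batch_range=[10, 20];profile_path=/tmp/profile")
+            # ... forward / backward / optimizer update for this batch ...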
+ ''' + if options_str is None: + return + + global _profiler_step_id + global _profiler_options + + if _profiler_options is None: + _profiler_options = ProfilerOptions(options_str) + + if _profiler_step_id == _profiler_options['batch_range'][0]: + paddle.utils.profiler.start_profiler(_profiler_options['state'], + _profiler_options['tracer_option']) + elif _profiler_step_id == _profiler_options['batch_range'][1]: + paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'], + _profiler_options['profile_path']) + if _profiler_options['exit_on_finished']: + sys.exit(0) + + _profiler_step_id += 1 diff --git a/benchmark/PaddleOCR_DBNet/utils/schedulers.py b/benchmark/PaddleOCR_DBNet/utils/schedulers.py new file mode 100644 index 0000000..1b6fb7d --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/schedulers.py @@ -0,0 +1,64 @@ +from paddle.optimizer import lr +import logging +__all__ = ['Polynomial'] + + +class Polynomial(object): + """ + Polynomial learning rate decay + Args: + learning_rate (float): The initial learning rate. It is a python float number. + epochs(int): The decay epoch size. It determines the decay cycle, when by_epoch is set to true, it will change to epochs=epochs*step_each_epoch. + step_each_epoch: all steps in each epoch. + end_lr(float, optional): The minimum final learning rate. Default: 0.0001. + power(float, optional): Power of polynomial. Default: 1.0. + warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0, , when by_epoch is set to true, it will change to warmup_epoch=warmup_epoch*step_each_epoch. + warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + by_epoch: Whether the set parameter is based on epoch or iter, when set to true,, epochs and warmup_epoch will be automatically multiplied by step_each_epoch. Default: True + """ + + def __init__(self, + learning_rate, + epochs, + step_each_epoch, + end_lr=0.0, + power=1.0, + warmup_epoch=0, + warmup_start_lr=0.0, + last_epoch=-1, + by_epoch=True, + **kwargs): + super().__init__() + if warmup_epoch >= epochs: + msg = f"When using warm up, the value of \"epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}." + logging.warning(msg) + warmup_epoch = epochs + self.learning_rate = learning_rate + self.epochs = epochs + self.end_lr = end_lr + self.power = power + self.last_epoch = last_epoch + self.warmup_epoch = warmup_epoch + self.warmup_start_lr = warmup_start_lr + + if by_epoch: + self.epochs *= step_each_epoch + self.warmup_epoch = int(self.warmup_epoch * step_each_epoch) + + def __call__(self): + learning_rate = lr.PolynomialDecay( + learning_rate=self.learning_rate, + decay_steps=self.epochs, + end_lr=self.end_lr, + power=self.power, + last_epoch=self. 
+ last_epoch) if self.epochs > 0 else self.learning_rate + if self.warmup_epoch > 0: + learning_rate = lr.LinearWarmup( + learning_rate=learning_rate, + warmup_steps=self.warmup_epoch, + start_lr=self.warmup_start_lr, + end_lr=self.learning_rate, + last_epoch=self.last_epoch) + return learning_rate diff --git a/benchmark/PaddleOCR_DBNet/utils/util.py b/benchmark/PaddleOCR_DBNet/utils/util.py new file mode 100644 index 0000000..39bae76 --- /dev/null +++ b/benchmark/PaddleOCR_DBNet/utils/util.py @@ -0,0 +1,367 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/8/23 21:59 +# @Author : zhoujun +import json +import pathlib +import time +import os +import glob +import cv2 +import yaml +from typing import Mapping +import matplotlib.pyplot as plt +import numpy as np + +from argparse import ArgumentParser, RawDescriptionHelpFormatter + + +def _check_image_file(path): + img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'} + return any([path.lower().endswith(e) for e in img_end]) + + +def get_image_file_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + + img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'} + if os.path.isfile(img_file) and _check_image_file(img_file): + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + file_path = os.path.join(img_file, single_file) + if os.path.isfile(file_path) and _check_image_file(file_path): + imgs_lists.append(file_path) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + imgs_lists = sorted(imgs_lists) + return imgs_lists + + +def setup_logger(log_file_path: str=None): + import logging + logging._warn_preinit_stderr = 0 + logger = logging.getLogger('DBNet.paddle') + formatter = logging.Formatter( + '%(asctime)s %(name)s %(levelname)s: %(message)s') + ch = logging.StreamHandler() + ch.setFormatter(formatter) + logger.addHandler(ch) + if log_file_path is not None: + file_handle = logging.FileHandler(log_file_path) + file_handle.setFormatter(formatter) + logger.addHandler(file_handle) + logger.setLevel(logging.DEBUG) + return logger + + +# --exeTime +def exe_time(func): + def newFunc(*args, **args2): + t0 = time.time() + back = func(*args, **args2) + print("{} cost {:.3f}s".format(func.__name__, time.time() - t0)) + return back + + return newFunc + + +def load(file_path: str): + file_path = pathlib.Path(file_path) + func_dict = {'.txt': _load_txt, '.json': _load_json, '.list': _load_txt} + assert file_path.suffix in func_dict + return func_dict[file_path.suffix](file_path) + + +def _load_txt(file_path: str): + with open(file_path, 'r', encoding='utf8') as f: + content = [ + x.strip().strip('\ufeff').strip('\xef\xbb\xbf') + for x in f.readlines() + ] + return content + + +def _load_json(file_path: str): + with open(file_path, 'r', encoding='utf8') as f: + content = json.load(f) + return content + + +def save(data, file_path): + file_path = pathlib.Path(file_path) + func_dict = {'.txt': _save_txt, '.json': _save_json} + assert file_path.suffix in func_dict + return func_dict[file_path.suffix](data, file_path) + + +def _save_txt(data, file_path): + """ + 将一个list的数组写入txt文件里 + :param data: + :param file_path: + :return: + """ + if not isinstance(data, list): + data = [data] + with open(file_path, mode='w', encoding='utf8') as f: + f.write('\n'.join(data)) + + +def _save_json(data, file_path): + with open(file_path, 'w', 
encoding='utf-8') as json_file: + json.dump(data, json_file, ensure_ascii=False, indent=4) + + +def show_img(imgs: np.ndarray, title='img'): + color = (len(imgs.shape) == 3 and imgs.shape[-1] == 3) + imgs = np.expand_dims(imgs, axis=0) + for i, img in enumerate(imgs): + plt.figure() + plt.title('{}_{}'.format(title, i)) + plt.imshow(img, cmap=None if color else 'gray') + plt.show() + + +def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2): + if isinstance(img_path, str): + img_path = cv2.imread(img_path) + # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB) + img_path = img_path.copy() + for point in result: + point = point.astype(int) + cv2.polylines(img_path, [point], True, color, thickness) + return img_path + + +def cal_text_score(texts, + gt_texts, + training_masks, + running_metric_text, + thred=0.5): + training_masks = training_masks.numpy() + pred_text = texts.numpy() * training_masks + pred_text[pred_text <= thred] = 0 + pred_text[pred_text > thred] = 1 + pred_text = pred_text.astype(np.int32) + gt_text = gt_texts.numpy() * training_masks + gt_text = gt_text.astype(np.int32) + running_metric_text.update(gt_text, pred_text) + score_text, _ = running_metric_text.get_scores() + return score_text + + +def order_points_clockwise(pts): + rect = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + rect[0] = pts[np.argmin(s)] + rect[2] = pts[np.argmax(s)] + diff = np.diff(pts, axis=1) + rect[1] = pts[np.argmin(diff)] + rect[3] = pts[np.argmax(diff)] + return rect + + +def order_points_clockwise_list(pts): + pts = pts.tolist() + pts.sort(key=lambda x: (x[1], x[0])) + pts[:2] = sorted(pts[:2], key=lambda x: x[0]) + pts[2:] = sorted(pts[2:], key=lambda x: -x[0]) + pts = np.array(pts) + return pts + + +def get_datalist(train_data_path): + """ + 获取训练和验证的数据list + :param train_data_path: 训练的dataset文件列表,每个文件内以如下格式存储 ‘path/to/img\tlabel’ + :return: + """ + train_data = [] + for p in train_data_path: + with open(p, 'r', encoding='utf-8') as f: + for line in f.readlines(): + line = line.strip('\n').replace('.jpg ', '.jpg\t').split('\t') + if len(line) > 1: + img_path = pathlib.Path(line[0].strip(' ')) + label_path = pathlib.Path(line[1].strip(' ')) + if img_path.exists() and img_path.stat( + ).st_size > 0 and label_path.exists() and label_path.stat( + ).st_size > 0: + train_data.append((str(img_path), str(label_path))) + return train_data + + +def save_result(result_path, box_list, score_list, is_output_polygon): + if is_output_polygon: + with open(result_path, 'wt') as res: + for i, box in enumerate(box_list): + box = box.reshape(-1).tolist() + result = ",".join([str(int(x)) for x in box]) + score = score_list[i] + res.write(result + ',' + str(score) + "\n") + else: + with open(result_path, 'wt') as res: + for i, box in enumerate(box_list): + score = score_list[i] + box = box.reshape(-1).tolist() + result = ",".join([str(int(x)) for x in box]) + res.write(result + ',' + str(score) + "\n") + + +def expand_polygon(polygon): + """ + 对只有一个字符的框进行扩充 + """ + (x, y), (w, h), angle = cv2.minAreaRect(np.float32(polygon)) + if angle < -45: + w, h = h, w + angle += 90 + new_w = w + h + box = ((x, y), (new_w, h), angle) + points = cv2.boxPoints(box) + return order_points_clockwise(points) + + +def _merge_dict(config, merge_dct): + """ Recursive dict merge. Inspired by :meth:``dict.update()``, instead of + updating only top-level keys, dict_merge recurses down into dicts nested + to an arbitrary depth, updating keys. The ``merge_dct`` is merged into + ``dct``. 
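+
+    Example (illustrative; dotted keys descend into nested dicts):
+        _merge_dict({'a': {'b': 1}}, {'a.b': 2, 'c': 3})
+        # -> {'a': {'b': 2}, 'c': 3}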
+ Args: + config: dict onto which the merge is executed + merge_dct: dct merged into config + Returns: dct + """ + for key, value in merge_dct.items(): + sub_keys = key.split('.') + key = sub_keys[0] + if key in config and len(sub_keys) > 1: + _merge_dict(config[key], {'.'.join(sub_keys[1:]): value}) + elif key in config and isinstance(config[key], dict) and isinstance( + value, Mapping): + _merge_dict(config[key], value) + else: + config[key] = value + return config + + +def print_dict(cfg, print_func=print, delimiter=0): + """ + Recursively visualize a dict and + indenting acrrording by the relationship of keys. + """ + for k, v in sorted(cfg.items()): + if isinstance(v, dict): + print_func("{}{} : ".format(delimiter * " ", str(k))) + print_dict(v, print_func, delimiter + 4) + elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict): + print_func("{}{} : ".format(delimiter * " ", str(k))) + for value in v: + print_dict(value, print_func, delimiter + 4) + else: + print_func("{}{} : {}".format(delimiter * " ", k, v)) + + +class Config(object): + def __init__(self, config_path, BASE_KEY='base'): + self.BASE_KEY = BASE_KEY + self.cfg = self._load_config_with_base(config_path) + + def _load_config_with_base(self, file_path): + """ + Load config from file. + Args: + file_path (str): Path of the config file to be loaded. + Returns: global config + """ + _, ext = os.path.splitext(file_path) + assert ext in ['.yml', '.yaml'], "only support yaml files for now" + + with open(file_path) as f: + file_cfg = yaml.load(f, Loader=yaml.Loader) + + # NOTE: cfgs outside have higher priority than cfgs in _BASE_ + if self.BASE_KEY in file_cfg: + all_base_cfg = dict() + base_ymls = list(file_cfg[self.BASE_KEY]) + for base_yml in base_ymls: + with open(base_yml) as f: + base_cfg = self._load_config_with_base(base_yml) + all_base_cfg = _merge_dict(all_base_cfg, base_cfg) + + del file_cfg[self.BASE_KEY] + file_cfg = _merge_dict(all_base_cfg, file_cfg) + file_cfg['filename'] = os.path.splitext(os.path.split(file_path)[-1])[0] + return file_cfg + + def merge_dict(self, args): + self.cfg = _merge_dict(self.cfg, args) + + def print_cfg(self, print_func=print): + """ + Recursively visualize a dict and + indenting acrrording by the relationship of keys. + """ + print_func('----------- Config -----------') + print_dict(self.cfg, print_func) + print_func('---------------------------------------------') + + def save(self, p): + with open(p, 'w') as f: + yaml.dump( + dict(self.cfg), f, default_flow_style=False, sort_keys=False) + + +class ArgsParser(ArgumentParser): + def __init__(self): + super(ArgsParser, self).__init__( + formatter_class=RawDescriptionHelpFormatter) + self.add_argument( + "-c", "--config_file", help="configuration file to use") + self.add_argument( + "-o", "--opt", nargs='*', help="set configuration options") + self.add_argument( + '-p', + '--profiler_options', + type=str, + default=None, + help='The option of profiler, which should be in format ' \ + '\"key1=value1;key2=value2;key3=value3\".' + ) + + def parse_args(self, argv=None): + args = super(ArgsParser, self).parse_args(argv) + assert args.config_file is not None, \ + "Please specify --config_file=configure_file_path." + args.opt = self._parse_opt(args.opt) + return args + + def _parse_opt(self, opts): + config = {} + if not opts: + return config + for s in opts: + s = s.strip() + k, v = s.split('=', 1) + if '.' 
not in k: + config[k] = yaml.load(v, Loader=yaml.Loader) + else: + keys = k.split('.') + if keys[0] not in config: + config[keys[0]] = {} + cur = config[keys[0]] + for idx, key in enumerate(keys[1:]): + if idx == len(keys) - 2: + cur[key] = yaml.load(v, Loader=yaml.Loader) + else: + cur[key] = {} + cur = cur[key] + return config + + +if __name__ == '__main__': + img = np.zeros((1, 3, 640, 640)) + show_img(img[0][0]) + plt.show() diff --git a/benchmark/analysis.py b/benchmark/analysis.py new file mode 100644 index 0000000..7322f00 --- /dev/null +++ b/benchmark/analysis.py @@ -0,0 +1,346 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import json +import os +import re +import traceback + + +def parse_args(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--filename", type=str, help="The name of log which need to analysis.") + parser.add_argument( + "--log_with_profiler", + type=str, + help="The path of train log with profiler") + parser.add_argument( + "--profiler_path", type=str, help="The path of profiler timeline log.") + parser.add_argument( + "--keyword", type=str, help="Keyword to specify analysis data") + parser.add_argument( + "--separator", + type=str, + default=None, + help="Separator of different field in log") + parser.add_argument( + '--position', type=int, default=None, help='The position of data field') + parser.add_argument( + '--range', + type=str, + default="", + help='The range of data field to intercept') + parser.add_argument( + '--base_batch_size', type=int, help='base_batch size on gpu') + parser.add_argument( + '--skip_steps', + type=int, + default=0, + help='The number of steps to be skipped') + parser.add_argument( + '--model_mode', + type=int, + default=-1, + help='Analysis mode, default value is -1') + parser.add_argument('--ips_unit', type=str, default=None, help='IPS unit') + parser.add_argument( + '--model_name', + type=str, + default=0, + help='training model_name, transformer_base') + parser.add_argument( + '--mission_name', type=str, default=0, help='training mission name') + parser.add_argument( + '--direction_id', type=int, default=0, help='training direction_id') + parser.add_argument( + '--run_mode', + type=str, + default="sp", + help='multi process or single process') + parser.add_argument( + '--index', + type=int, + default=1, + help='{1: speed, 2:mem, 3:profiler, 6:max_batch_size}') + parser.add_argument( + '--gpu_num', type=int, default=1, help='nums of training gpus') + args = parser.parse_args() + args.separator = None if args.separator == "None" else args.separator + return args + + +def _is_number(num): + pattern = re.compile(r'^[-+]?[-0-9]\d*\.\d*|[-+]?\.?[0-9]\d*$') + result = pattern.match(num) + if result: + return True + else: + return False + + +class TimeAnalyzer(object): + def __init__(self, + filename, + keyword=None, + separator=None, + position=None, + range="-1"): + if filename 
is None: + raise Exception("Please specify the filename!") + + if keyword is None: + raise Exception("Please specify the keyword!") + + self.filename = filename + self.keyword = keyword + self.separator = separator + self.position = position + self.range = range + self.records = None + self._distil() + + def _distil(self): + self.records = [] + with open(self.filename, "r") as f_object: + lines = f_object.readlines() + for line in lines: + if self.keyword not in line: + continue + try: + result = None + + # Distil the string from a line. + line = line.strip() + line_words = line.split( + self.separator) if self.separator else line.split() + if args.position: + result = line_words[self.position] + else: + # Distil the string following the keyword. + for i in range(len(line_words) - 1): + if line_words[i] == self.keyword: + result = line_words[i + 1] + break + + # Distil the result from the picked string. + if not self.range: + result = result[0:] + elif _is_number(self.range): + result = result[0:int(self.range)] + else: + result = result[int(self.range.split(":")[0]):int( + self.range.split(":")[1])] + self.records.append(float(result)) + except Exception as exc: + print("line is: {}; separator={}; position={}".format( + line, self.separator, self.position)) + + print("Extract {} records: separator={}; position={}".format( + len(self.records), self.separator, self.position)) + + def _get_fps(self, + mode, + batch_size, + gpu_num, + avg_of_records, + run_mode, + unit=None): + if mode == -1 and run_mode == 'sp': + assert unit, "Please set the unit when mode is -1." + fps = gpu_num * avg_of_records + elif mode == -1 and run_mode == 'mp': + assert unit, "Please set the unit when mode is -1." + fps = gpu_num * avg_of_records #temporarily, not used now + print("------------this is mp") + elif mode == 0: + # s/step -> samples/s + fps = (batch_size * gpu_num) / avg_of_records + unit = "samples/s" + elif mode == 1: + # steps/s -> steps/s + fps = avg_of_records + unit = "steps/s" + elif mode == 2: + # s/step -> steps/s + fps = 1 / avg_of_records + unit = "steps/s" + elif mode == 3: + # steps/s -> samples/s + fps = batch_size * gpu_num * avg_of_records + unit = "samples/s" + elif mode == 4: + # s/epoch -> s/epoch + fps = avg_of_records + unit = "s/epoch" + else: + ValueError("Unsupported analysis mode.") + + return fps, unit + + def analysis(self, + batch_size, + gpu_num=1, + skip_steps=0, + mode=-1, + run_mode='sp', + unit=None): + if batch_size <= 0: + print("base_batch_size should larger than 0.") + return 0, '' + + if len( + self.records + ) <= skip_steps: # to address the condition which item of log equals to skip_steps + print("no records") + return 0, '' + + sum_of_records = 0 + sum_of_records_skipped = 0 + skip_min = self.records[skip_steps] + skip_max = self.records[skip_steps] + + count = len(self.records) + for i in range(count): + sum_of_records += self.records[i] + if i >= skip_steps: + sum_of_records_skipped += self.records[i] + if self.records[i] < skip_min: + skip_min = self.records[i] + if self.records[i] > skip_max: + skip_max = self.records[i] + + avg_of_records = sum_of_records / float(count) + avg_of_records_skipped = sum_of_records_skipped / float(count - + skip_steps) + + fps, fps_unit = self._get_fps(mode, batch_size, gpu_num, avg_of_records, + run_mode, unit) + fps_skipped, _ = self._get_fps(mode, batch_size, gpu_num, + avg_of_records_skipped, run_mode, unit) + if mode == -1: + print("average ips of %d steps, skip 0 step:" % count) + print("\tAvg: %.3f %s" % (avg_of_records, 
fps_unit)) + print("\tFPS: %.3f %s" % (fps, fps_unit)) + if skip_steps > 0: + print("average ips of %d steps, skip %d steps:" % + (count, skip_steps)) + print("\tAvg: %.3f %s" % (avg_of_records_skipped, fps_unit)) + print("\tMin: %.3f %s" % (skip_min, fps_unit)) + print("\tMax: %.3f %s" % (skip_max, fps_unit)) + print("\tFPS: %.3f %s" % (fps_skipped, fps_unit)) + elif mode == 1 or mode == 3: + print("average latency of %d steps, skip 0 step:" % count) + print("\tAvg: %.3f steps/s" % avg_of_records) + print("\tFPS: %.3f %s" % (fps, fps_unit)) + if skip_steps > 0: + print("average latency of %d steps, skip %d steps:" % + (count, skip_steps)) + print("\tAvg: %.3f steps/s" % avg_of_records_skipped) + print("\tMin: %.3f steps/s" % skip_min) + print("\tMax: %.3f steps/s" % skip_max) + print("\tFPS: %.3f %s" % (fps_skipped, fps_unit)) + elif mode == 0 or mode == 2: + print("average latency of %d steps, skip 0 step:" % count) + print("\tAvg: %.3f s/step" % avg_of_records) + print("\tFPS: %.3f %s" % (fps, fps_unit)) + if skip_steps > 0: + print("average latency of %d steps, skip %d steps:" % + (count, skip_steps)) + print("\tAvg: %.3f s/step" % avg_of_records_skipped) + print("\tMin: %.3f s/step" % skip_min) + print("\tMax: %.3f s/step" % skip_max) + print("\tFPS: %.3f %s" % (fps_skipped, fps_unit)) + + return round(fps_skipped, 3), fps_unit + + +if __name__ == "__main__": + args = parse_args() + run_info = dict() + run_info["log_file"] = args.filename + run_info["model_name"] = args.model_name + run_info["mission_name"] = args.mission_name + run_info["direction_id"] = args.direction_id + run_info["run_mode"] = args.run_mode + run_info["index"] = args.index + run_info["gpu_num"] = args.gpu_num + run_info["FINAL_RESULT"] = 0 + run_info["JOB_FAIL_FLAG"] = 0 + + try: + if args.index == 1: + if args.gpu_num == 1: + run_info["log_with_profiler"] = args.log_with_profiler + run_info["profiler_path"] = args.profiler_path + analyzer = TimeAnalyzer(args.filename, args.keyword, args.separator, + args.position, args.range) + run_info["FINAL_RESULT"], run_info["UNIT"] = analyzer.analysis( + batch_size=args.base_batch_size, + gpu_num=args.gpu_num, + skip_steps=args.skip_steps, + mode=args.model_mode, + run_mode=args.run_mode, + unit=args.ips_unit) + try: + if int(os.getenv('job_fail_flag')) == 1 or int(run_info[ + "FINAL_RESULT"]) == 0: + run_info["JOB_FAIL_FLAG"] = 1 + except: + pass + elif args.index == 3: + run_info["FINAL_RESULT"] = {} + records_fo_total = TimeAnalyzer(args.filename, 'Framework overhead', + None, 3, '').records + records_fo_ratio = TimeAnalyzer(args.filename, 'Framework overhead', + None, 5).records + records_ct_total = TimeAnalyzer(args.filename, 'Computation time', + None, 3, '').records + records_gm_total = TimeAnalyzer(args.filename, + 'GpuMemcpy Calls', + None, 4, '').records + records_gm_ratio = TimeAnalyzer(args.filename, + 'GpuMemcpy Calls', + None, 6).records + records_gmas_total = TimeAnalyzer(args.filename, + 'GpuMemcpyAsync Calls', + None, 4, '').records + records_gms_total = TimeAnalyzer(args.filename, + 'GpuMemcpySync Calls', + None, 4, '').records + run_info["FINAL_RESULT"]["Framework_Total"] = records_fo_total[ + 0] if records_fo_total else 0 + run_info["FINAL_RESULT"]["Framework_Ratio"] = records_fo_ratio[ + 0] if records_fo_ratio else 0 + run_info["FINAL_RESULT"][ + "ComputationTime_Total"] = records_ct_total[ + 0] if records_ct_total else 0 + run_info["FINAL_RESULT"]["GpuMemcpy_Total"] = records_gm_total[ + 0] if records_gm_total else 0 + 
run_info["FINAL_RESULT"]["GpuMemcpy_Ratio"] = records_gm_ratio[ + 0] if records_gm_ratio else 0 + run_info["FINAL_RESULT"][ + "GpuMemcpyAsync_Total"] = records_gmas_total[ + 0] if records_gmas_total else 0 + run_info["FINAL_RESULT"]["GpuMemcpySync_Total"] = records_gms_total[ + 0] if records_gms_total else 0 + else: + print("Not support!") + except Exception: + traceback.print_exc() + print("{}".format(json.dumps(run_info)) + ) # it's required, for the log file path insert to the database diff --git a/benchmark/readme.md b/benchmark/readme.md new file mode 100644 index 0000000..d90d214 --- /dev/null +++ b/benchmark/readme.md @@ -0,0 +1,30 @@ + +# PaddleOCR DB/EAST/PSE 算法训练benchmark测试 + +PaddleOCR/benchmark目录下的文件用于获取并分析训练日志。 +训练采用icdar2015数据集,包括1000张训练图像和500张测试图像。模型配置采用resnet18_vd作为backbone,分别训练batch_size=8和batch_size=16的情况。 + +## 运行训练benchmark + +benchmark/run_det.sh 中包含了三个过程: +- 安装依赖 +- 下载数据 +- 执行训练 +- 日志分析获取IPS + +在执行训练部分,会执行单机单卡(默认0号卡)单机多卡训练,并分别执行batch_size=8和batch_size=16的情况。所以执行完后,每种模型会得到4个日志文件。 + +run_det.sh 执行方式如下: + +``` +# cd PaddleOCR/ +bash benchmark/run_det.sh +``` + +以DB为例,将得到四个日志文件,如下: +``` +det_res18_db_v2.0_sp_bs16_fp32_1 +det_res18_db_v2.0_sp_bs8_fp32_1 +det_res18_db_v2.0_mp_bs16_fp32_1 +det_res18_db_v2.0_mp_bs8_fp32_1 +``` diff --git a/benchmark/run_benchmark_det.sh b/benchmark/run_benchmark_det.sh new file mode 100644 index 0000000..9f5b46c --- /dev/null +++ b/benchmark/run_benchmark_det.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +# 运行示例:CUDA_VISIBLE_DEVICES=0 bash run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode} +# 参数说明 +function _set_params(){ + run_mode=${1:-"sp"} # 单卡sp|多卡mp + batch_size=${2:-"64"} + fp_item=${3:-"fp32"} # fp32|fp16 + max_epoch=${4:-"10"} # 可选,如果需要修改代码提前中断 + model_item=${5:-"model_item"} + run_log_path=${TRAIN_LOG_DIR:-$(pwd)} # TRAIN_LOG_DIR 后续QA设置该参数 +# 日志解析所需参数 + base_batch_size=${batch_size} + mission_name="OCR" + direction_id="0" + ips_unit="images/sec" + skip_steps=2 # 解析日志,有些模型前几个step耗时长,需要跳过 (必填) + keyword="ips:" # 解析日志,筛选出数据所在行的关键字 (必填) + index="1" + model_name=${model_item}_bs${batch_size}_${fp_item} # model_item 用于yml文件名匹配,model_name 用于数据入库前端展示 +# 以下不用修改 + device=${CUDA_VISIBLE_DEVICES//,/ } + arr=(${device}) + num_gpu_devices=${#arr[*]} + log_file=${run_log_path}/${model_item}_${run_mode}_bs${batch_size}_${fp_item}_${num_gpu_devices} +} +function _train(){ + echo "Train on ${num_gpu_devices} GPUs" + echo "current CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES, gpus=$num_gpu_devices, batch_size=$batch_size" + + train_cmd="-c configs/det/${model_item}.yml -o Train.loader.batch_size_per_card=${batch_size} Global.epoch_num=${max_epoch} Global.eval_batch_step=[0,20000] Global.print_batch_step=2" + case ${run_mode} in + sp) + train_cmd="python tools/train.py "${train_cmd}"" + ;; + mp) + rm -rf ./mylog + train_cmd="python -m paddle.distributed.launch --log_dir=./mylog --gpus=$CUDA_VISIBLE_DEVICES tools/train.py ${train_cmd}" + ;; + *) echo "choose run_mode(sp or mp)"; exit 1; + esac +# 以下不用修改 + echo ${train_cmd} + timeout 15m ${train_cmd} > ${log_file} 2>&1 + if [ $? 
-ne 0 ];then + echo -e "${model_name}, FAIL" + export job_fail_flag=1 + else + echo -e "${model_name}, SUCCESS" + export job_fail_flag=0 + fi + + if [ $run_mode = "mp" -a -d mylog ]; then + rm ${log_file} + cp mylog/workerlog.0 ${log_file} + fi +} + +source ${BENCHMARK_ROOT}/scripts/run_model.sh # 在该脚本中会对符合benchmark规范的log使用analysis.py 脚本进行性能数据解析;该脚本在连调时可从benchmark repo中下载https://github.com/PaddlePaddle/benchmark/blob/master/scripts/run_model.sh;如果不联调只想要产出训练log可以注掉本行,提交时需打开 +_set_params $@ +#_train # 如果只想产出训练log,不解析,可取消注释 +_run # 该函数在run_model.sh中,执行时会调用_train; 如果不联调只想要产出训练log可以注掉本行,提交时需打开 + diff --git a/benchmark/run_det.sh b/benchmark/run_det.sh new file mode 100644 index 0000000..981510c --- /dev/null +++ b/benchmark/run_det.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# 提供可稳定复现性能的脚本,默认在标准docker环境内py37执行: paddlepaddle/paddle:latest-gpu-cuda10.1-cudnn7 paddle=2.1.2 py=37 +# 执行目录: ./PaddleOCR +# 1 安装该模型需要的依赖 (如需开启优化策略请注明) +log_path=${LOG_PATH_INDEX_DIR:-$(pwd)} +python -m pip install -r requirements.txt +# 2 拷贝该模型需要数据、预训练模型 +wget -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar && cd train_data && tar xf icdar2015.tar && cd ../ +wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams +wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams +wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams +# 3 批量运行(如不方便批量,1,2需放到单个模型中) + +model_mode_list=(det_res18_db_v2.0 det_r50_vd_east det_r50_vd_pse) +fp_item_list=(fp32) +for model_mode in ${model_mode_list[@]}; do + for fp_item in ${fp_item_list[@]}; do + if [ ${model_mode} == "det_r50_vd_east" ]; then + bs_list=(16) + else + bs_list=(8 16) + fi + for bs_item in ${bs_list[@]}; do + echo "index is speed, 1gpus, begin, ${model_name}" + run_mode=sp + log_name=ocr_${model_mode}_bs${bs_item}_${fp_item}_${run_mode} + CUDA_VISIBLE_DEVICES=0 bash benchmark/run_benchmark_det.sh ${run_mode} ${bs_item} ${fp_item} 1 ${model_mode} | tee ${log_path}/${log_name}_speed_1gpus 2>&1 # (5min) + sleep 60 + echo "index is speed, 8gpus, run_mode is multi_process, begin, ${model_name}" + run_mode=mp + log_name=ocr_${model_mode}_bs${bs_item}_${fp_item}_${run_mode} + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash benchmark/run_benchmark_det.sh ${run_mode} ${bs_item} ${fp_item} 2 ${model_mode} | tee ${log_path}/${log_name}_speed_8gpus8p 2>&1 + sleep 60 + done + done +done + + + diff --git a/configs/det/SAST.yml b/configs/det/SAST.yml new file mode 100644 index 0000000..1512285 --- /dev/null +++ b/configs/det/SAST.yml @@ -0,0 +1,119 @@ +Global: + debug: false + use_gpu: false + epoch_num: 300 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/SAST/ + save_epoch_step: 2 + eval_batch_step: + - 40000 + - 50000 + cal_metric_during_train: false + pretrained_model: ./pretrain/detection + checkpoints: + save_inference_dir: ./inference/SAST + use_visualdl: True + infer_img: null + save_res_path: ./output/sast.txt +Architecture: + model_type: det + algorithm: SAST + Transform: + Backbone: + name: ResNet_SAST + layers: 50 + Neck: + name: SASTFPN + with_cab: true + Head: + name: SASTHead +Loss: + name: SASTLoss +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.001 + regularizer: + name: L2 + factor: 0 +PostProcess: + name: SASTPostProcess + score_thresh: 0.5 + sample_pts_num: 2 + nms_thresh: 0.5 + expand_scale: 0.2 + 
shrink_ratio_of_width: 0.15 +Metric: + name: DetMetric + main_indicator: hmean +Train: + dataset: + name: SimpleDataSet + data_dir: ./train/vietnamese/train_images + label_file_list: [./train/train_label.txt] + ratio_list: [1] + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - DetLabelEncode: null + - SASTProcessTrain: + image_shape: + - 512 + - 512 + min_crop_side_ratio: 0.3 + min_crop_size: 24 + min_text_size: 4 + max_text_size: 512 + - KeepKeys: + keep_keys: + - image + - score_map + - border_map + - training_mask + - tvo_map + - tco_map + loader: + shuffle: false + drop_last: false + batch_size_per_card: 8 + num_workers: 1 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train/vietnamese/test_image + label_file_list: + - ./train/test_label.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - DetLabelEncode: null + - DetResizeForTest: + resize_long: 1536 + - NormalizeImage: + scale: 1./255. + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + order: hwc + - ToCHWImage: null + - KeepKeys: + keep_keys: + - image + - shape + - polys + - ignore_tags + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 1 diff --git a/dat.jpg b/dat.jpg new file mode 100644 index 0000000..31f5f52 Binary files /dev/null and b/dat.jpg differ diff --git a/deploy/Jetson/images/00057937.jpg b/deploy/Jetson/images/00057937.jpg new file mode 100644 index 0000000..a35896e Binary files /dev/null and b/deploy/Jetson/images/00057937.jpg differ diff --git a/deploy/Jetson/images/det_res_french_0.jpg b/deploy/Jetson/images/det_res_french_0.jpg new file mode 100644 index 0000000..5f0e488 Binary files /dev/null and b/deploy/Jetson/images/det_res_french_0.jpg differ diff --git a/deploy/Jetson/readme.md b/deploy/Jetson/readme.md new file mode 100644 index 0000000..14c88c3 --- /dev/null +++ b/deploy/Jetson/readme.md @@ -0,0 +1,84 @@ +English | [简体中文](readme_ch.md) + +# Jetson Deployment for PaddleOCR + +This section introduces the deployment of PaddleOCR on Jetson NX, TX2, nano, AGX and other series of hardware. + + +## 1. Prepare Environment + +You need to prepare a Jetson development hardware. If you need TensorRT, you need to prepare the TensorRT environment. It is recommended to use TensorRT version 7.1.3; + +1. Install PaddlePaddle in Jetson + +The PaddlePaddle download [link](https://www.paddlepaddle.org.cn/inference/user_guides/download_lib.html#python) +Please select the appropriate installation package for your Jetpack version, cuda version, and trt version. Here, we download paddlepaddle_gpu-2.3.0rc0-cp36-cp36m-linux_aarch64.whl. + +Install PaddlePaddle: +```shell +pip3 install -U paddlepaddle_gpu-2.3.0rc0-cp36-cp36m-linux_aarch64.whl +``` + + +2. Download PaddleOCR code and install dependencies + +Clone the PaddleOCR code: +``` +git clone https://github.com/PaddlePaddle/PaddleOCR +``` + +and install dependencies: +``` +cd PaddleOCR +pip3 install -r requirements.txt +``` + +*Note: Jetson hardware CPU is poor, dependency installation is slow, please wait patiently* + +## 2. Perform prediction + +Obtain the PPOCR model from the [document](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_en/ppocr_introduction_en.md#6-model-zoo) model library. The following takes the PP-OCRv3 model as an example to introduce the use of the PPOCR model on Jetson: + +Download and unzip the PP-OCRv3 models. 
+``` +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar +tar xf ch_PP-OCRv3_det_infer.tar +tar xf ch_PP-OCRv3_rec_infer.tar +``` + +The text detection inference: +``` +cd PaddleOCR +python3 tools/infer/predict_det.py --det_model_dir=./inference/ch_PP-OCRv2_det_infer/ --image_dir=./doc/imgs/french_0.jpg --use_gpu=True +``` + +After executing the command, the predicted information will be printed out in the terminal, and the visualization results will be saved in the `./inference_results/` directory. +![](./images/det_res_french_0.jpg) + + +The text recognition inference: +``` +python3 tools/infer/predict_det.py --rec_model_dir=./inference/ch_PP-OCRv2_rec_infer/ --image_dir=./doc/imgs_words/en/word_2.png --use_gpu=True --rec_image_shape="3,48,320" +``` + +After executing the command, the predicted information will be printed on the terminal, and the output is as follows: +``` +[2022/04/28 15:41:45] root INFO: Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.98084533) +``` + +The text detection and text recognition inference: + +``` +python3 tools/infer/predict_system.py --det_model_dir=./inference/ch_PP-OCRv2_det_infer/ --rec_model_dir=./inference/ch_PP-OCRv2_rec_infer/ --image_dir=./doc/imgs/00057937.jpg --use_gpu=True --rec_image_shape="3,48,320" +``` + +After executing the command, the predicted information will be printed out in the terminal, and the visualization results will be saved in the `./inference_results/` directory. +![](./images/00057937.jpg) + +To enable TRT prediction, you only need to set `--use_tensorrt=True` on the basis of the above command: +``` +python3 tools/infer/predict_system.py --det_model_dir=./inference/ch_PP-OCRv2_det_infer/ --rec_model_dir=./inference/ch_PP-OCRv2_rec_infer/ --image_dir=./doc/imgs/ --rec_image_shape="3,48,320" --use_gpu=True --use_tensorrt=True +``` + +For more ppocr model predictions, please refer to[document](../../doc/doc_en/models_list_en.md) diff --git a/deploy/Jetson/readme_ch.md b/deploy/Jetson/readme_ch.md new file mode 100644 index 0000000..7b0a344 --- /dev/null +++ b/deploy/Jetson/readme_ch.md @@ -0,0 +1,86 @@ +[English](readme.md) | 简体中文 + +# Jetson部署PaddleOCR模型 + +本节介绍PaddleOCR在Jetson NX、TX2、nano、AGX等系列硬件的部署。 + + +## 1. 环境准备 + +需要准备一台Jetson开发板,如果需要TensorRT预测,需准备好TensorRT环境,建议使用7.1.3版本的TensorRT; + +1. Jetson安装PaddlePaddle + +PaddlePaddle下载[链接](https://www.paddlepaddle.org.cn/inference/user_guides/download_lib.html#python) +请选择适合的您Jetpack版本、cuda版本、trt版本的安装包。 + +安装命令: +```shell +# 安装paddle,以paddlepaddle_gpu-2.3.0rc0-cp36-cp36m-linux_aarch64.whl 为例 +pip3 install -U paddlepaddle_gpu-2.3.0rc0-cp36-cp36m-linux_aarch64.whl +``` + + +2. 下载PaddleOCR代码并安装依赖 + +首先 clone PaddleOCR 代码: +``` +git clone https://github.com/PaddlePaddle/PaddleOCR +``` + +然后,安装依赖: +``` +cd PaddleOCR +pip3 install -r requirements.txt +``` + +*注:jetson硬件CPU较差,依赖安装较慢,请耐心等待* + + +## 2. 
执行预测 + +从[文档](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/ppocr_introduction.md#6-%E6%A8%A1%E5%9E%8B%E5%BA%93) 模型库中获取PPOCR模型,下面以PP-OCRv3模型为例,介绍在PPOCR模型在jetson上的使用方式: + +下载并解压PP-OCRv3模型 +``` +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar +tar xf ch_PP-OCRv3_det_infer.tar +tar xf ch_PP-OCRv3_rec_infer.tar +``` + +执行文本检测预测: +``` +cd PaddleOCR +python3 tools/infer/predict_det.py --det_model_dir=./inference/ch_PP-OCRv2_det_infer/ --image_dir=./doc/imgs/french_0.jpg --use_gpu=True +``` + +执行命令后在终端会打印出预测的信息,并在 `./inference_results/` 下保存可视化结果。 +![](./images/det_res_french_0.jpg) + + +执行文本识别预测: +``` +python3 tools/infer/predict_det.py --rec_model_dir=./inference/ch_PP-OCRv2_rec_infer/ --image_dir=./doc/imgs_words/en/word_2.png --use_gpu=True --rec_image_shape="3,48,320" +``` + +执行命令后在终端会打印出预测的信息,输出如下: +``` +[2022/04/28 15:41:45] root INFO: Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.98084533) +``` + +执行文本检测+文本识别串联预测: + +``` +python3 tools/infer/predict_system.py --det_model_dir=./inference/ch_PP-OCRv2_det_infer/ --rec_model_dir=./inference/ch_PP-OCRv2_rec_infer/ --image_dir=./doc/imgs/ --use_gpu=True --rec_image_shape="3,48,320" +``` + +执行命令后在终端会打印出预测的信息,并在 `./inference_results/` 下保存可视化结果。 +![](./images/00057937.jpg) + +开启TRT预测只需要在以上命令基础上设置`--use_tensorrt=True`即可: +``` +python3 tools/infer/predict_system.py --det_model_dir=./inference/ch_PP-OCRv2_det_infer/ --rec_model_dir=./inference/ch_PP-OCRv2_rec_infer/ --image_dir=./doc/imgs/00057937.jpg --use_gpu=True --use_tensorrt=True --rec_image_shape="3,48,320" +``` + +更多ppocr模型预测请参考[文档](../../doc/doc_ch/models_list.md) diff --git a/deploy/README.md b/deploy/README.md new file mode 100644 index 0000000..0cfb793 --- /dev/null +++ b/deploy/README.md @@ -0,0 +1,31 @@ +English | [简体中文](README_ch.md) + +# PP-OCR Deployment + +- [Paddle Deployment Introduction](#1) +- [PP-OCR Deployment](#2) + + +## Paddle Deployment Introduction + +Paddle provides a variety of deployment schemes to meet the deployment requirements of different scenarios. Please choose according to the actual situation: + +
+<!-- figure: overview of the Paddle deployment options -->
+
+ + + +## PP-OCR Deployment + +PP-OCR has supported muti deployment schemes. Click the link to get the specific tutorial. + +- [Python Inference](../doc/doc_en/inference_ppocr_en.md) +- [C++ Inference](./cpp_infer/readme.md) +- [Serving (Python/C++)](./pdserving/README.md) +- [Paddle-Lite (ARM CPU/OpenCL ARM GPU)](./lite/readme.md) +- [Paddle.js](./paddlejs/README.md) +- [Jetson Inference](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/deploy/Jetson/readme.md) +- [Paddle2ONNX](./paddle2onnx/readme.md) + +If you need the deployment tutorial of academic algorithm models other than PP-OCR, please directly enter the main page of corresponding algorithms, [entrance](../doc/doc_en/algorithm_overview_en.md)。 diff --git a/deploy/README_ch.md b/deploy/README_ch.md new file mode 100644 index 0000000..1773aed --- /dev/null +++ b/deploy/README_ch.md @@ -0,0 +1,31 @@ +[English](README.md) | 简体中文 + +# PP-OCR 模型推理部署 + +- [Paddle 推理部署方式简介](#1) +- [PP-OCR 推理部署](#2) + + +## Paddle 推理部署方式简介 + +飞桨提供多种部署方案,以满足不同场景的部署需求,请根据实际情况进行选择: + +
+<!-- 图:飞桨部署方案总览 -->
+
+ + + +## PP-OCR 推理部署 + +PP-OCR模型已打通多种场景部署方案,点击链接获取具体的使用教程。 + +- [Python 推理](../doc/doc_ch/inference_ppocr.md) +- [C++ 推理](./cpp_infer/readme_ch.md) +- [Serving 服务化部署(Python/C++)](./pdserving/README_CN.md) +- [Paddle-Lite 端侧部署(ARM CPU/OpenCL ARM GPU)](./lite/readme_ch.md) +- [Paddle.js 部署](./paddlejs/README_ch.md) +- [Jetson 推理](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/deploy/Jetson/readme_ch.md) +- [Paddle2ONNX 推理](./paddle2onnx/readme_ch.md) + +需要PP-OCR以外的学术算法模型的推理部署,请直接进入相应算法主页面,[入口](../doc/doc_ch/algorithm_overview.md)。 \ No newline at end of file diff --git a/deploy/android_demo/.gitignore b/deploy/android_demo/.gitignore new file mode 100644 index 0000000..93dcb29 --- /dev/null +++ b/deploy/android_demo/.gitignore @@ -0,0 +1,9 @@ +*.iml +.gradle +/local.properties +/.idea/* +.DS_Store +/build +/captures +.externalNativeBuild + diff --git a/deploy/android_demo/README.md b/deploy/android_demo/README.md new file mode 100644 index 0000000..ba615fb --- /dev/null +++ b/deploy/android_demo/README.md @@ -0,0 +1,118 @@ +- [Android Demo](#android-demo) + - [1. 简介](#1-简介) + - [2. 近期更新](#2-近期更新) + - [3. 快速使用](#3-快速使用) + - [3.1 环境准备](#31-环境准备) + - [3.2 导入项目](#32-导入项目) + - [3.3 运行demo](#33-运行demo) + - [3.4 运行模式](#34-运行模式) + - [3.5 设置](#35-设置) + - [4 更多支持](#4-更多支持) + +# Android Demo + +## 1. 简介 +此为PaddleOCR的Android Demo,目前支持文本检测,文本方向分类器和文本识别模型的使用。使用 [PaddleLite v2.10](https://github.com/PaddlePaddle/Paddle-Lite/tree/release/v2.10) 进行开发。 + +## 2. 近期更新 +* 2022.02.27 + * 预测库更新到PaddleLite v2.10 + * 支持6种运行模式: + * 检测+分类+识别 + * 检测+识别 + * 分类+识别 + * 检测 + * 识别 + * 分类 + +## 3. 快速使用 + +### 3.1 环境准备 +1. 在本地环境安装好 Android Studio 工具,详细安装方法请见[Android Stuido 官网](https://developer.android.com/studio)。 +2. 准备一部 Android 手机,并开启 USB 调试模式。开启方法: `手机设置 -> 查找开发者选项 -> 打开开发者选项和 USB 调试模式` + +**注意**:如果您的 Android Studio 尚未配置 NDK ,请根据 Android Studio 用户指南中的[安装及配置 NDK 和 CMake ](https://developer.android.com/studio/projects/install-ndk)内容,预先配置好 NDK 。您可以选择最新的 NDK 版本,或者使用 Paddle Lite 预测库版本一样的 NDK + +### 3.2 导入项目 + +点击 File->New->Import Project..., 然后跟着Android Studio的引导导入 +导入完成后呈现如下界面 +![](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/imgs/import_demo.jpg) + +### 3.3 运行demo +将手机连接上电脑后,点击Android Studio工具栏中的运行按钮即可运行demo。在此过程中,手机会弹出"允许从 USB 安装软件权限"的弹窗,点击允许即可。 + +软件安转到手机上后会在手机主屏最后一页看到如下app +
+<!-- 截图:手机主屏上的 app 图标 -->
+
+ +点击app图标即可启动app,启动后app主页如下 + +
+<!-- 截图:app 主页界面 -->
+
+ +app主页中有四个按钮,一个下拉列表和一个菜单按钮,他们的功能分别为 + +* 运行模型:按照已选择的模式,运行对应的模型组合 +* 拍照识别:唤起手机相机拍照并获取拍照的图像,拍照完成后需要点击运行模型进行识别 +* 选取图片:唤起手机相册拍照选择图像,选择完成后需要点击运行模型进行识别 +* 清空绘图:清空当前显示图像上绘制的文本框,以便进行下一次识别(每次识别使用的图像都是当前显示的图像) +* 下拉列表:进行运行模式的选择,目前包含6种运行模式,默认模式为**检测+分类+识别**详细说明见下一节。 +* 菜单按钮:点击后会进入菜单界面,进行模型和内置图像有关设置 + +点击运行模型后,会按照所选择的模式运行对应的模型,**检测+分类+识别**模式下运行的模型结果如下所示: + + + +模型运行完成后,模型和运行状态显示区`STATUS`字段显示了当前模型的运行状态,这里显示为`run model successed`表明模型运行成功。 + +模型的运行结果显示在运行结果显示区,显示格式为 +```text +序号:Det:(x1,y1)(x2,y2)(x3,y3)(x4,y4) Rec: 识别文本,识别置信度 Cls:分类类别,分类分时 +``` + +### 3.4 运行模式 + +PaddleOCR demo共提供了6种运行模式,如下图 +
+<!-- 截图:运行模式下拉列表(6 种运行模式) -->
+
+ +每种模式的运行结果如下表所示 + +| 检测+分类+识别 | 检测+识别 | 分类+识别 | +|------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------| +| | | | + + +| 检测 | 识别 | 分类 | +|----------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------| +| | | | + +### 3.5 设置 + +设置界面如下 + +
+<!-- 截图:设置界面 -->
+
+ +在设置界面可以进行如下几项设定: +1. 普通设置 + * Enable custom settings: 选中状态下才能更改设置 + * Model Path: 所运行的模型地址,使用默认值就好 + * Label Path: 识别模型的字典 + * Image Path: 进行识别的内置图像名 +2. 模型运行态设置,此项设置更改后返回主界面时,会自动重新加载模型 + * CPU Thread Num: 模型运行使用的CPU核心数量 + * CPU Power Mode: 模型运行模式,大小核设定 +3. 输入设置 + * det long size: DB模型预处理时图像的长边长度,超过此长度resize到该值,短边进行等比例缩放,小于此长度不进行处理。 +4. 输出设置 + * Score Threshold: DB模型后处理box的阈值,低于此阈值的box进行过滤,不显示。 + +## 4 更多支持 +1. 实时识别,更新预测库可参考 https://github.com/PaddlePaddle/Paddle-Lite-Demo/tree/develop/ocr/android/app/cxx/ppocr_demo +2. 更多Paddle-Lite相关问题可前往[Paddle-Lite](https://github.com/PaddlePaddle/Paddle-Lite) ,获得更多开发支持 diff --git a/deploy/android_demo/app/.gitignore b/deploy/android_demo/app/.gitignore new file mode 100644 index 0000000..796b96d --- /dev/null +++ b/deploy/android_demo/app/.gitignore @@ -0,0 +1 @@ +/build diff --git a/deploy/android_demo/app/build.gradle b/deploy/android_demo/app/build.gradle new file mode 100644 index 0000000..2607f32 --- /dev/null +++ b/deploy/android_demo/app/build.gradle @@ -0,0 +1,93 @@ +import java.security.MessageDigest + +apply plugin: 'com.android.application' + +android { + compileSdkVersion 29 + defaultConfig { + applicationId "com.baidu.paddle.lite.demo.ocr" + minSdkVersion 23 + targetSdkVersion 29 + versionCode 2 + versionName "2.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + externalNativeBuild { + cmake { + cppFlags "-std=c++11 -frtti -fexceptions -Wno-format" + arguments '-DANDROID_PLATFORM=android-23', '-DANDROID_STL=c++_shared' ,"-DANDROID_ARM_NEON=TRUE" + } + } + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + externalNativeBuild { + cmake { + path "src/main/cpp/CMakeLists.txt" + version "3.10.2" + } + } +} + +dependencies { + implementation fileTree(include: ['*.jar'], dir: 'libs') + implementation 'androidx.appcompat:appcompat:1.1.0' + implementation 'androidx.constraintlayout:constraintlayout:1.1.3' + testImplementation 'junit:junit:4.12' + androidTestImplementation 'com.android.support.test:runner:1.0.2' + androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2' +} + +def archives = [ + [ + 'src' : 'https://paddleocr.bj.bcebos.com/libs/paddle_lite_libs_v2_10.tar.gz', + 'dest': 'PaddleLite' + ], + [ + 'src' : 'https://paddlelite-demo.bj.bcebos.com/libs/android/opencv-4.2.0-android-sdk.tar.gz', + 'dest': 'OpenCV' + ], + [ + 'src' : 'https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_PP-OCRv2.tar.gz', + 'dest' : 'src/main/assets/models' + ], + [ + 'src' : 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_dict.tar.gz', + 'dest' : 'src/main/assets/labels' + ] +] + +task downloadAndExtractArchives(type: DefaultTask) { + doFirst { + println "Downloading and extracting archives including libs and models" + } + doLast { + // Prepare cache folder for archives + String cachePath = "cache" + if (!file("${cachePath}").exists()) { + mkdir "${cachePath}" + } + archives.eachWithIndex { archive, index -> + MessageDigest messageDigest = MessageDigest.getInstance('MD5') + messageDigest.update(archive.src.bytes) + String cacheName = new BigInteger(1, messageDigest.digest()).toString(32) + // Download the target archive if not exists + boolean copyFiles = !file("${archive.dest}").exists() + if (!file("${cachePath}/${cacheName}.tar.gz").exists()) { + ant.get(src: archive.src, dest: file("${cachePath}/${cacheName}.tar.gz")) + copyFiles = true; // force to copy files from the latest archive files + } 
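+            // At this point the archive is guaranteed to exist in the cache; copyFiles is
+            // true only when the dest folder was missing or the archive was just downloaded.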
+ // Extract the target archive if its dest path does not exists + if (copyFiles) { + copy { + from tarTree("${cachePath}/${cacheName}.tar.gz") + into "${archive.dest}" + } + } + } + } +} +preBuild.dependsOn downloadAndExtractArchives \ No newline at end of file diff --git a/deploy/android_demo/app/proguard-rules.pro b/deploy/android_demo/app/proguard-rules.pro new file mode 100644 index 0000000..f1b4245 --- /dev/null +++ b/deploy/android_demo/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/deploy/android_demo/app/src/androidTest/java/com/baidu/paddle/lite/demo/ocr/ExampleInstrumentedTest.java b/deploy/android_demo/app/src/androidTest/java/com/baidu/paddle/lite/demo/ocr/ExampleInstrumentedTest.java new file mode 100644 index 0000000..77b179d --- /dev/null +++ b/deploy/android_demo/app/src/androidTest/java/com/baidu/paddle/lite/demo/ocr/ExampleInstrumentedTest.java @@ -0,0 +1,26 @@ +package com.baidu.paddle.lite.demo.ocr; + +import android.content.Context; +import android.support.test.InstrumentationRegistry; +import android.support.test.runner.AndroidJUnit4; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import static org.junit.Assert.*; + +/** + * Instrumented test, which will execute on an Android device. + * + * @see Testing documentation + */ +@RunWith(AndroidJUnit4.class) +public class ExampleInstrumentedTest { + @Test + public void useAppContext() { + // Context of the app under test. 
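+        // Note: build.gradle sets applicationId to "com.baidu.paddle.lite.demo.ocr", so
+        // getPackageName() returns that value and the assertEquals below fails as written.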
+ Context appContext = InstrumentationRegistry.getTargetContext(); + + assertEquals("com.baidu.paddle.lite.demo", appContext.getPackageName()); + } +} diff --git a/deploy/android_demo/app/src/main/AndroidManifest.xml b/deploy/android_demo/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000..133f357 --- /dev/null +++ b/deploy/android_demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/assets/images/det_0.jpg b/deploy/android_demo/app/src/main/assets/images/det_0.jpg new file mode 100644 index 0000000..8517e12 Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/det_0.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/det_180.jpg b/deploy/android_demo/app/src/main/assets/images/det_180.jpg new file mode 100644 index 0000000..b1bb8a4 Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/det_180.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/det_270.jpg b/deploy/android_demo/app/src/main/assets/images/det_270.jpg new file mode 100644 index 0000000..5687390 Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/det_270.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/det_90.jpg b/deploy/android_demo/app/src/main/assets/images/det_90.jpg new file mode 100644 index 0000000..49e949a Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/det_90.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/rec_0.jpg b/deploy/android_demo/app/src/main/assets/images/rec_0.jpg new file mode 100644 index 0000000..2c34cd3 Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/rec_0.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/rec_0_180.jpg b/deploy/android_demo/app/src/main/assets/images/rec_0_180.jpg new file mode 100644 index 0000000..02bc3b9 Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/rec_0_180.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/rec_1.jpg b/deploy/android_demo/app/src/main/assets/images/rec_1.jpg new file mode 100644 index 0000000..22031ba Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/rec_1.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/rec_1_180.jpg b/deploy/android_demo/app/src/main/assets/images/rec_1_180.jpg new file mode 100644 index 0000000..d745530 Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/rec_1_180.jpg differ diff --git a/deploy/android_demo/app/src/main/cpp/CMakeLists.txt b/deploy/android_demo/app/src/main/cpp/CMakeLists.txt new file mode 100644 index 0000000..742786a --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/CMakeLists.txt @@ -0,0 +1,117 @@ +# For more information about using CMake with Android Studio, read the +# documentation: https://d.android.com/studio/projects/add-native-code.html + +# Sets the minimum version of CMake required to build the native library. + +cmake_minimum_required(VERSION 3.4.1) + +# Creates and names a library, sets it as either STATIC or SHARED, and provides +# the relative paths to its source code. You can define multiple libraries, and +# CMake builds them for you. Gradle automatically packages shared libraries with +# your APK. 
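+
+# Note: the PaddleLite and OpenCV directories referenced below are not part of the repo;
+# they are downloaded and extracted by the downloadAndExtractArchives task in
+# app/build.gradle (preBuild depends on it), so a Gradle build populates them first.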
+ +set(PaddleLite_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../PaddleLite") +include_directories(${PaddleLite_DIR}/cxx/include) + +set(OpenCV_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../OpenCV/sdk/native/jni") +message(STATUS "opencv dir: ${OpenCV_DIR}") +find_package(OpenCV REQUIRED) +message(STATUS "OpenCV libraries: ${OpenCV_LIBS}") +include_directories(${OpenCV_INCLUDE_DIRS}) +aux_source_directory(. SOURCES) +set(CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os" + ) +set(CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS} -fvisibility=hidden -fvisibility-inlines-hidden -fdata-sections -ffunction-sections" + ) +set(CMAKE_SHARED_LINKER_FLAGS + "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--gc-sections -Wl,-z,nocopyreloc") + +add_library( + # Sets the name of the library. + Native + # Sets the library as a shared library. + SHARED + # Provides a relative path to your source file(s). + ${SOURCES}) + +find_library( + # Sets the name of the path variable. + log-lib + # Specifies the name of the NDK library that you want CMake to locate. + log) + +add_library( + # Sets the name of the library. + paddle_light_api_shared + # Sets the library as a shared library. + SHARED + # Provides a relative path to your source file(s). + IMPORTED) + +set_target_properties( + # Specifies the target library. + paddle_light_api_shared + # Specifies the parameter you want to define. + PROPERTIES + IMPORTED_LOCATION + ${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libpaddle_light_api_shared.so + # Provides the path to the library you want to import. +) + + +# Specifies libraries CMake should link to your target library. You can link +# multiple libraries, such as libraries you define in this build script, +# prebuilt third-party libraries, or system libraries. + +target_link_libraries( + # Specifies the target library. + Native + paddle_light_api_shared + ${OpenCV_LIBS} + GLESv2 + EGL + jnigraphics + ${log-lib} +) + +add_custom_command( + TARGET Native + POST_BUILD + COMMAND + ${CMAKE_COMMAND} -E copy + ${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libc++_shared.so + ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libc++_shared.so) + +add_custom_command( + TARGET Native + POST_BUILD + COMMAND + ${CMAKE_COMMAND} -E copy + ${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libpaddle_light_api_shared.so + ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libpaddle_light_api_shared.so) + +add_custom_command( + TARGET Native + POST_BUILD + COMMAND + ${CMAKE_COMMAND} -E copy + ${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libhiai.so + ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libhiai.so) + +add_custom_command( + TARGET Native + POST_BUILD + COMMAND + ${CMAKE_COMMAND} -E copy + ${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libhiai_ir.so + ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libhiai_ir.so) + +add_custom_command( + TARGET Native + POST_BUILD + COMMAND + ${CMAKE_COMMAND} -E copy + ${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libhiai_ir_build.so + ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libhiai_ir_build.so) \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/common.h b/deploy/android_demo/app/src/main/cpp/common.h new file mode 100644 index 0000000..fc47407 --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/common.h @@ -0,0 +1,37 @@ +// +// Created by fu on 4/25/18. +// + +#pragma once +#import +#import + +#ifdef __ANDROID__ + +#include + +#define LOG_TAG "OCR_NDK" + +#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) +#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, LOG_TAG, __VA_ARGS__) +#define LOGE(...) 
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__) +#else +#include +#define LOGI(format, ...) \ + fprintf(stdout, "[" LOG_TAG "]" format "\n", ##__VA_ARGS__) +#define LOGW(format, ...) \ + fprintf(stdout, "[" LOG_TAG "]" format "\n", ##__VA_ARGS__) +#define LOGE(format, ...) \ + fprintf(stderr, "[" LOG_TAG "]Error: " format "\n", ##__VA_ARGS__) +#endif + +enum RETURN_CODE { RETURN_OK = 0 }; + +enum NET_TYPE { NET_OCR = 900100, NET_OCR_INTERNAL = 991008 }; + +template inline T product(const std::vector &vec) { + if (vec.empty()) { + return 0; + } + return std::accumulate(vec.begin(), vec.end(), 1, std::multiplies()); +} diff --git a/deploy/android_demo/app/src/main/cpp/native.cpp b/deploy/android_demo/app/src/main/cpp/native.cpp new file mode 100644 index 0000000..4961e5e --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/native.cpp @@ -0,0 +1,119 @@ +// +// Created by fujiayi on 2020/7/5. +// + +#include "native.h" +#include "ocr_ppredictor.h" +#include +#include +#include + +static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode); + +extern "C" JNIEXPORT jlong JNICALL +Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init( + JNIEnv *env, jobject thiz, jstring j_det_model_path, + jstring j_rec_model_path, jstring j_cls_model_path, jint j_use_opencl, jint j_thread_num, + jstring j_cpu_mode) { + std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path); + std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path); + std::string cls_model_path = jstring_to_cpp_string(env, j_cls_model_path); + int thread_num = j_thread_num; + std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode); + ppredictor::OCR_Config conf; + conf.use_opencl = j_use_opencl; + conf.thread_num = thread_num; + conf.mode = str_to_cpu_mode(cpu_mode); + ppredictor::OCR_PPredictor *orc_predictor = + new ppredictor::OCR_PPredictor{conf}; + orc_predictor->init_from_file(det_model_path, rec_model_path, cls_model_path); + return reinterpret_cast(orc_predictor); +} + +/** + * "LITE_POWER_HIGH" convert to paddle::lite_api::LITE_POWER_HIGH + * @param cpu_mode + * @return + */ +static paddle::lite_api::PowerMode +str_to_cpu_mode(const std::string &cpu_mode) { + static std::map cpu_mode_map{ + {"LITE_POWER_HIGH", paddle::lite_api::LITE_POWER_HIGH}, + {"LITE_POWER_LOW", paddle::lite_api::LITE_POWER_HIGH}, + {"LITE_POWER_FULL", paddle::lite_api::LITE_POWER_FULL}, + {"LITE_POWER_NO_BIND", paddle::lite_api::LITE_POWER_NO_BIND}, + {"LITE_POWER_RAND_HIGH", paddle::lite_api::LITE_POWER_RAND_HIGH}, + {"LITE_POWER_RAND_LOW", paddle::lite_api::LITE_POWER_RAND_LOW}}; + std::string upper_key; + std::transform(cpu_mode.cbegin(), cpu_mode.cend(), upper_key.begin(), + ::toupper); + auto index = cpu_mode_map.find(upper_key.c_str()); + if (index == cpu_mode_map.end()) { + LOGE("cpu_mode not found %s", upper_key.c_str()); + return paddle::lite_api::LITE_POWER_HIGH; + } else { + return index->second; + } +} + +extern "C" JNIEXPORT jfloatArray JNICALL +Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward( + JNIEnv *env, jobject thiz, jlong java_pointer, jobject original_image,jint j_max_size_len, jint j_run_det, jint j_run_cls, jint j_run_rec) { + LOGI("begin to run native forward"); + if (java_pointer == 0) { + LOGE("JAVA pointer is NULL"); + return cpp_array_to_jfloatarray(env, nullptr, 0); + } + + cv::Mat origin = bitmap_to_cv_mat(env, original_image); + if (origin.size == 0) { + LOGE("origin bitmap cannot convert to CV Mat"); + return 
cpp_array_to_jfloatarray(env, nullptr, 0); + } + + int max_size_len = j_max_size_len; + int run_det = j_run_det; + int run_cls = j_run_cls; + int run_rec = j_run_rec; + + ppredictor::OCR_PPredictor *ppredictor = + (ppredictor::OCR_PPredictor *)java_pointer; + std::vector dims_arr; + std::vector results = + ppredictor->infer_ocr(origin, max_size_len, run_det, run_cls, run_rec); + LOGI("infer_ocr finished with boxes %ld", results.size()); + + // 这里将std::vector 序列化成 + // float数组,传输到java层再反序列化 + std::vector float_arr; + for (const ppredictor::OCRPredictResult &r : results) { + float_arr.push_back(r.points.size()); + float_arr.push_back(r.word_index.size()); + float_arr.push_back(r.score); + // add det point + for (const std::vector &point : r.points) { + float_arr.push_back(point.at(0)); + float_arr.push_back(point.at(1)); + } + // add rec word idx + for (int index : r.word_index) { + float_arr.push_back(index); + } + // add cls result + float_arr.push_back(r.cls_label); + float_arr.push_back(r.cls_score); + } + return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size()); +} + +extern "C" JNIEXPORT void JNICALL +Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_release( + JNIEnv *env, jobject thiz, jlong java_pointer) { + if (java_pointer == 0) { + LOGE("JAVA pointer is NULL"); + return; + } + ppredictor::OCR_PPredictor *ppredictor = + (ppredictor::OCR_PPredictor *)java_pointer; + delete ppredictor; +} diff --git a/deploy/android_demo/app/src/main/cpp/native.h b/deploy/android_demo/app/src/main/cpp/native.h new file mode 100644 index 0000000..9b8e4e4 --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/native.h @@ -0,0 +1,137 @@ +// +// Created by fujiayi on 2020/7/5. +// + +#pragma once + +#include "common.h" +#include +#include +#include +#include +#include + +inline std::string jstring_to_cpp_string(JNIEnv *env, jstring jstr) { + // In java, a unicode char will be encoded using 2 bytes (utf16). + // so jstring will contain characters utf16. std::string in c++ is + // essentially a string of bytes, not characters, so if we want to + // pass jstring from JNI to c++, we have convert utf16 to bytes. 
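+  // The conversion below calls String.getBytes("UTF-8") through JNI, so the
+  // returned std::string holds UTF-8 encoded bytes; the byte array and the
+  // local references are released before returning.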
+ if (!jstr) { + return ""; + } + const jclass stringClass = env->GetObjectClass(jstr); + const jmethodID getBytes = + env->GetMethodID(stringClass, "getBytes", "(Ljava/lang/String;)[B"); + const jbyteArray stringJbytes = (jbyteArray)env->CallObjectMethod( + jstr, getBytes, env->NewStringUTF("UTF-8")); + + size_t length = (size_t)env->GetArrayLength(stringJbytes); + jbyte *pBytes = env->GetByteArrayElements(stringJbytes, NULL); + + std::string ret = std::string(reinterpret_cast(pBytes), length); + env->ReleaseByteArrayElements(stringJbytes, pBytes, JNI_ABORT); + + env->DeleteLocalRef(stringJbytes); + env->DeleteLocalRef(stringClass); + return ret; +} + +inline jstring cpp_string_to_jstring(JNIEnv *env, std::string str) { + auto *data = str.c_str(); + jclass strClass = env->FindClass("java/lang/String"); + jmethodID strClassInitMethodID = + env->GetMethodID(strClass, "", "([BLjava/lang/String;)V"); + + jbyteArray bytes = env->NewByteArray(strlen(data)); + env->SetByteArrayRegion(bytes, 0, strlen(data), + reinterpret_cast(data)); + + jstring encoding = env->NewStringUTF("UTF-8"); + jstring res = (jstring)( + env->NewObject(strClass, strClassInitMethodID, bytes, encoding)); + + env->DeleteLocalRef(strClass); + env->DeleteLocalRef(encoding); + env->DeleteLocalRef(bytes); + + return res; +} + +inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env, const float *buf, + int64_t len) { + if (len == 0) { + return env->NewFloatArray(0); + } + jfloatArray result = env->NewFloatArray(len); + env->SetFloatArrayRegion(result, 0, len, buf); + return result; +} + +inline jintArray cpp_array_to_jintarray(JNIEnv *env, const int *buf, + int64_t len) { + jintArray result = env->NewIntArray(len); + env->SetIntArrayRegion(result, 0, len, buf); + return result; +} + +inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env, const int8_t *buf, + int64_t len) { + jbyteArray result = env->NewByteArray(len); + env->SetByteArrayRegion(result, 0, len, buf); + return result; +} + +inline jlongArray int64_vector_to_jlongarray(JNIEnv *env, + const std::vector &vec) { + jlongArray result = env->NewLongArray(vec.size()); + jlong *buf = new jlong[vec.size()]; + for (size_t i = 0; i < vec.size(); ++i) { + buf[i] = (jlong)vec[i]; + } + env->SetLongArrayRegion(result, 0, vec.size(), buf); + delete[] buf; + return result; +} + +inline std::vector jlongarray_to_int64_vector(JNIEnv *env, + jlongArray data) { + int data_size = env->GetArrayLength(data); + jlong *data_ptr = env->GetLongArrayElements(data, nullptr); + std::vector data_vec(data_ptr, data_ptr + data_size); + env->ReleaseLongArrayElements(data, data_ptr, 0); + return data_vec; +} + +inline std::vector jfloatarray_to_float_vector(JNIEnv *env, + jfloatArray data) { + int data_size = env->GetArrayLength(data); + jfloat *data_ptr = env->GetFloatArrayElements(data, nullptr); + std::vector data_vec(data_ptr, data_ptr + data_size); + env->ReleaseFloatArrayElements(data, data_ptr, 0); + return data_vec; +} + +inline cv::Mat bitmap_to_cv_mat(JNIEnv *env, jobject bitmap) { + AndroidBitmapInfo info; + int result = AndroidBitmap_getInfo(env, bitmap, &info); + if (result != ANDROID_BITMAP_RESULT_SUCCESS) { + LOGE("AndroidBitmap_getInfo failed, result: %d", result); + return cv::Mat{}; + } + if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) { + LOGE("Bitmap format is not RGBA_8888 !"); + return cv::Mat{}; + } + unsigned char *srcData = NULL; + AndroidBitmap_lockPixels(env, bitmap, (void **)&srcData); + cv::Mat mat = cv::Mat::zeros(info.height, info.width, CV_8UC4); + 
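+  // Copy the RGBA pixels into the Mat, then convert to BGR for OpenCV. Note
+  // that the memcpy assumes a packed bitmap, i.e. info.stride == info.width * 4.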
memcpy(mat.data, srcData, info.height * info.width * 4); + AndroidBitmap_unlockPixels(env, bitmap); + cv::cvtColor(mat, mat, cv::COLOR_RGBA2BGR); + /** + if (!cv::imwrite("/sdcard/1/copy.jpg", mat)){ + LOGE("Write image failed " ); + } + */ + return mat; +} diff --git a/deploy/android_demo/app/src/main/cpp/ocr_clipper.cpp b/deploy/android_demo/app/src/main/cpp/ocr_clipper.cpp new file mode 100644 index 0000000..4a531fc --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ocr_clipper.cpp @@ -0,0 +1,4380 @@ +/******************************************************************************* +* * +* Author : Angus Johnson * +* Version : 6.4.2 * +* Date : 27 February 2017 * +* Website : http://www.angusj.com * +* Copyright : Angus Johnson 2010-2017 * +* * +* License: * +* Use, modification & distribution is subject to Boost Software License Ver 1. * +* http://www.boost.org/LICENSE_1_0.txt * +* * +* Attributions: * +* The code in this library is an extension of Bala Vatti's clipping algorithm: * +* "A generic solution to polygon clipping" * +* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. * +* http://portal.acm.org/citation.cfm?id=129906 * +* * +* Computer graphics and geometric modeling: implementation and algorithms * +* By Max K. Agoston * +* Springer; 1 edition (January 4, 2005) * +* http://books.google.com/books?q=vatti+clipping+agoston * +* * +* See also: * +* "Polygon Offsetting by Computing Winding Numbers" * +* Paper no. DETC2005-85513 pp. 565-575 * +* ASME 2005 International Design Engineering Technical Conferences * +* and Computers and Information in Engineering Conference (IDETC/CIE2005) * +* September 24-28, 2005 , Long Beach, California, USA * +* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf * +* * +*******************************************************************************/ + +/******************************************************************************* +* * +* This is a translation of the Delphi Clipper library and the naming style * +* used has retained a Delphi flavour. * +* * +*******************************************************************************/ + +#include "ocr_clipper.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ClipperLib { + +static double const pi = 3.141592653589793238; +static double const two_pi = pi * 2; +static double const def_arc_tolerance = 0.25; + +enum Direction { dRightToLeft, dLeftToRight }; + +static int const Unassigned = -1; // edge not currently 'owning' a solution +static int const Skip = -2; // edge that would otherwise close a path + +#define HORIZONTAL (-1.0E+40) +#define TOLERANCE (1.0e-20) +#define NEAR_ZERO(val) (((val) > -TOLERANCE) && ((val) < TOLERANCE)) + +struct TEdge { + IntPoint Bot; + IntPoint Curr; // current (updated for every new scanbeam) + IntPoint Top; + double Dx; + PolyType PolyTyp; + EdgeSide Side; // side only refers to current side of solution poly + int WindDelta; // 1 or -1 depending on winding direction + int WindCnt; + int WindCnt2; // winding count of the opposite polytype + int OutIdx; + TEdge *Next; + TEdge *Prev; + TEdge *NextInLML; + TEdge *NextInAEL; + TEdge *PrevInAEL; + TEdge *NextInSEL; + TEdge *PrevInSEL; +}; + +struct IntersectNode { + TEdge *Edge1; + TEdge *Edge2; + IntPoint Pt; +}; + +struct LocalMinimum { + cInt Y; + TEdge *LeftBound; + TEdge *RightBound; +}; + +struct OutPt; + +// OutRec: contains a path in the clipping solution. 
Edges in the AEL will +// carry a pointer to an OutRec when they are part of the clipping solution. +struct OutRec { + int Idx; + bool IsHole; + bool IsOpen; + OutRec *FirstLeft; // see comments in clipper.pas + PolyNode *PolyNd; + OutPt *Pts; + OutPt *BottomPt; +}; + +struct OutPt { + int Idx; + IntPoint Pt; + OutPt *Next; + OutPt *Prev; +}; + +struct Join { + OutPt *OutPt1; + OutPt *OutPt2; + IntPoint OffPt; +}; + +struct LocMinSorter { + inline bool operator()(const LocalMinimum &locMin1, + const LocalMinimum &locMin2) { + return locMin2.Y < locMin1.Y; + } +}; + +//------------------------------------------------------------------------------ +//------------------------------------------------------------------------------ + +inline cInt Round(double val) { + if ((val < 0)) + return static_cast(val - 0.5); + else + return static_cast(val + 0.5); +} +//------------------------------------------------------------------------------ + +inline cInt Abs(cInt val) { return val < 0 ? -val : val; } + +//------------------------------------------------------------------------------ +// PolyTree methods ... +//------------------------------------------------------------------------------ + +void PolyTree::Clear() { + for (PolyNodes::size_type i = 0; i < AllNodes.size(); ++i) + delete AllNodes[i]; + AllNodes.resize(0); + Childs.resize(0); +} +//------------------------------------------------------------------------------ + +PolyNode *PolyTree::GetFirst() const { + if (!Childs.empty()) + return Childs[0]; + else + return 0; +} +//------------------------------------------------------------------------------ + +int PolyTree::Total() const { + int result = (int)AllNodes.size(); + // with negative offsets, ignore the hidden outer polygon ... + if (result > 0 && Childs[0] != AllNodes[0]) + result--; + return result; +} + +//------------------------------------------------------------------------------ +// PolyNode methods ... 
+//------------------------------------------------------------------------------ + +PolyNode::PolyNode() : Parent(0), Index(0), m_IsOpen(false) {} +//------------------------------------------------------------------------------ + +int PolyNode::ChildCount() const { return (int)Childs.size(); } +//------------------------------------------------------------------------------ + +void PolyNode::AddChild(PolyNode &child) { + unsigned cnt = (unsigned)Childs.size(); + Childs.push_back(&child); + child.Parent = this; + child.Index = cnt; +} +//------------------------------------------------------------------------------ + +PolyNode *PolyNode::GetNext() const { + if (!Childs.empty()) + return Childs[0]; + else + return GetNextSiblingUp(); +} +//------------------------------------------------------------------------------ + +PolyNode *PolyNode::GetNextSiblingUp() const { + if (!Parent) // protects against PolyTree.GetNextSiblingUp() + return 0; + else if (Index == Parent->Childs.size() - 1) + return Parent->GetNextSiblingUp(); + else + return Parent->Childs[Index + 1]; +} +//------------------------------------------------------------------------------ + +bool PolyNode::IsHole() const { + bool result = true; + PolyNode *node = Parent; + while (node) { + result = !result; + node = node->Parent; + } + return result; +} +//------------------------------------------------------------------------------ + +bool PolyNode::IsOpen() const { return m_IsOpen; } +//------------------------------------------------------------------------------ + +#ifndef use_int32 + +//------------------------------------------------------------------------------ +// Int128 class (enables safe math on signed 64bit integers) +// eg Int128 val1((long64)9223372036854775807); //ie 2^63 -1 +// Int128 val2((long64)9223372036854775807); +// Int128 val3 = val1 * val2; +// val3.AsString => "85070591730234615847396907784232501249" (8.5e+37) +//------------------------------------------------------------------------------ + +class Int128 { +public: + ulong64 lo; + long64 hi; + + Int128(long64 _lo = 0) { + lo = (ulong64)_lo; + if (_lo < 0) + hi = -1; + else + hi = 0; + } + + Int128(const Int128 &val) : lo(val.lo), hi(val.hi) {} + + Int128(const long64 &_hi, const ulong64 &_lo) : lo(_lo), hi(_hi) {} + + Int128 &operator=(const long64 &val) { + lo = (ulong64)val; + if (val < 0) + hi = -1; + else + hi = 0; + return *this; + } + + bool operator==(const Int128 &val) const { + return (hi == val.hi && lo == val.lo); + } + + bool operator!=(const Int128 &val) const { return !(*this == val); } + + bool operator>(const Int128 &val) const { + if (hi != val.hi) + return hi > val.hi; + else + return lo > val.lo; + } + + bool operator<(const Int128 &val) const { + if (hi != val.hi) + return hi < val.hi; + else + return lo < val.lo; + } + + bool operator>=(const Int128 &val) const { return !(*this < val); } + + bool operator<=(const Int128 &val) const { return !(*this > val); } + + Int128 &operator+=(const Int128 &rhs) { + hi += rhs.hi; + lo += rhs.lo; + if (lo < rhs.lo) + hi++; + return *this; + } + + Int128 operator+(const Int128 &rhs) const { + Int128 result(*this); + result += rhs; + return result; + } + + Int128 &operator-=(const Int128 &rhs) { + *this += -rhs; + return *this; + } + + Int128 operator-(const Int128 &rhs) const { + Int128 result(*this); + result -= rhs; + return result; + } + + Int128 operator-() const // unary negation + { + if (lo == 0) + return Int128(-hi, 0); + else + return Int128(~hi, ~lo + 1); + } + + operator double() 
const { + const double shift64 = 18446744073709551616.0; // 2^64 + if (hi < 0) { + if (lo == 0) + return (double)hi * shift64; + else + return -(double)(~lo + ~hi * shift64); + } else + return (double)(lo + hi * shift64); + } +}; +//------------------------------------------------------------------------------ + +Int128 Int128Mul(long64 lhs, long64 rhs) { + bool negate = (lhs < 0) != (rhs < 0); + + if (lhs < 0) + lhs = -lhs; + ulong64 int1Hi = ulong64(lhs) >> 32; + ulong64 int1Lo = ulong64(lhs & 0xFFFFFFFF); + + if (rhs < 0) + rhs = -rhs; + ulong64 int2Hi = ulong64(rhs) >> 32; + ulong64 int2Lo = ulong64(rhs & 0xFFFFFFFF); + + // nb: see comments in clipper.pas + ulong64 a = int1Hi * int2Hi; + ulong64 b = int1Lo * int2Lo; + ulong64 c = int1Hi * int2Lo + int1Lo * int2Hi; + + Int128 tmp; + tmp.hi = long64(a + (c >> 32)); + tmp.lo = long64(c << 32); + tmp.lo += long64(b); + if (tmp.lo < b) + tmp.hi++; + if (negate) + tmp = -tmp; + return tmp; +}; +#endif + +//------------------------------------------------------------------------------ +// Miscellaneous global functions +//------------------------------------------------------------------------------ + +bool Orientation(const Path &poly) { return Area(poly) >= 0; } +//------------------------------------------------------------------------------ + +double Area(const Path &poly) { + int size = (int)poly.size(); + if (size < 3) + return 0; + + double a = 0; + for (int i = 0, j = size - 1; i < size; ++i) { + a += ((double)poly[j].X + poly[i].X) * ((double)poly[j].Y - poly[i].Y); + j = i; + } + return -a * 0.5; +} +//------------------------------------------------------------------------------ + +double Area(const OutPt *op) { + const OutPt *startOp = op; + if (!op) + return 0; + double a = 0; + do { + a += (double)(op->Prev->Pt.X + op->Pt.X) * + (double)(op->Prev->Pt.Y - op->Pt.Y); + op = op->Next; + } while (op != startOp); + return a * 0.5; +} +//------------------------------------------------------------------------------ + +double Area(const OutRec &outRec) { return Area(outRec.Pts); } +//------------------------------------------------------------------------------ + +bool PointIsVertex(const IntPoint &Pt, OutPt *pp) { + OutPt *pp2 = pp; + do { + if (pp2->Pt == Pt) + return true; + pp2 = pp2->Next; + } while (pp2 != pp); + return false; +} +//------------------------------------------------------------------------------ + +// See "The Point in Polygon Problem for Arbitrary Polygons" by Hormann & +// Agathos +// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.88.5498&rep=rep1&type=pdf +int PointInPolygon(const IntPoint &pt, const Path &path) { + // returns 0 if false, +1 if true, -1 if pt ON polygon boundary + int result = 0; + size_t cnt = path.size(); + if (cnt < 3) + return 0; + IntPoint ip = path[0]; + for (size_t i = 1; i <= cnt; ++i) { + IntPoint ipNext = (i == cnt ? 
path[0] : path[i]); + if (ipNext.Y == pt.Y) { + if ((ipNext.X == pt.X) || + (ip.Y == pt.Y && ((ipNext.X > pt.X) == (ip.X < pt.X)))) + return -1; + } + if ((ip.Y < pt.Y) != (ipNext.Y < pt.Y)) { + if (ip.X >= pt.X) { + if (ipNext.X > pt.X) + result = 1 - result; + else { + double d = (double)(ip.X - pt.X) * (ipNext.Y - pt.Y) - + (double)(ipNext.X - pt.X) * (ip.Y - pt.Y); + if (!d) + return -1; + if ((d > 0) == (ipNext.Y > ip.Y)) + result = 1 - result; + } + } else { + if (ipNext.X > pt.X) { + double d = (double)(ip.X - pt.X) * (ipNext.Y - pt.Y) - + (double)(ipNext.X - pt.X) * (ip.Y - pt.Y); + if (!d) + return -1; + if ((d > 0) == (ipNext.Y > ip.Y)) + result = 1 - result; + } + } + } + ip = ipNext; + } + return result; +} +//------------------------------------------------------------------------------ + +int PointInPolygon(const IntPoint &pt, OutPt *op) { + // returns 0 if false, +1 if true, -1 if pt ON polygon boundary + int result = 0; + OutPt *startOp = op; + for (;;) { + if (op->Next->Pt.Y == pt.Y) { + if ((op->Next->Pt.X == pt.X) || + (op->Pt.Y == pt.Y && ((op->Next->Pt.X > pt.X) == (op->Pt.X < pt.X)))) + return -1; + } + if ((op->Pt.Y < pt.Y) != (op->Next->Pt.Y < pt.Y)) { + if (op->Pt.X >= pt.X) { + if (op->Next->Pt.X > pt.X) + result = 1 - result; + else { + double d = (double)(op->Pt.X - pt.X) * (op->Next->Pt.Y - pt.Y) - + (double)(op->Next->Pt.X - pt.X) * (op->Pt.Y - pt.Y); + if (!d) + return -1; + if ((d > 0) == (op->Next->Pt.Y > op->Pt.Y)) + result = 1 - result; + } + } else { + if (op->Next->Pt.X > pt.X) { + double d = (double)(op->Pt.X - pt.X) * (op->Next->Pt.Y - pt.Y) - + (double)(op->Next->Pt.X - pt.X) * (op->Pt.Y - pt.Y); + if (!d) + return -1; + if ((d > 0) == (op->Next->Pt.Y > op->Pt.Y)) + result = 1 - result; + } + } + } + op = op->Next; + if (startOp == op) + break; + } + return result; +} +//------------------------------------------------------------------------------ + +bool Poly2ContainsPoly1(OutPt *OutPt1, OutPt *OutPt2) { + OutPt *op = OutPt1; + do { + // nb: PointInPolygon returns 0 if false, +1 if true, -1 if pt on polygon + int res = PointInPolygon(op->Pt, OutPt2); + if (res >= 0) + return res > 0; + op = op->Next; + } while (op != OutPt1); + return true; +} +//---------------------------------------------------------------------- + +bool SlopesEqual(const TEdge &e1, const TEdge &e2, bool UseFullInt64Range) { +#ifndef use_int32 + if (UseFullInt64Range) + return Int128Mul(e1.Top.Y - e1.Bot.Y, e2.Top.X - e2.Bot.X) == + Int128Mul(e1.Top.X - e1.Bot.X, e2.Top.Y - e2.Bot.Y); + else +#endif + return (e1.Top.Y - e1.Bot.Y) * (e2.Top.X - e2.Bot.X) == + (e1.Top.X - e1.Bot.X) * (e2.Top.Y - e2.Bot.Y); +} +//------------------------------------------------------------------------------ + +bool SlopesEqual(const IntPoint pt1, const IntPoint pt2, const IntPoint pt3, + bool UseFullInt64Range) { +#ifndef use_int32 + if (UseFullInt64Range) + return Int128Mul(pt1.Y - pt2.Y, pt2.X - pt3.X) == + Int128Mul(pt1.X - pt2.X, pt2.Y - pt3.Y); + else +#endif + return (pt1.Y - pt2.Y) * (pt2.X - pt3.X) == + (pt1.X - pt2.X) * (pt2.Y - pt3.Y); +} +//------------------------------------------------------------------------------ + +bool SlopesEqual(const IntPoint pt1, const IntPoint pt2, const IntPoint pt3, + const IntPoint pt4, bool UseFullInt64Range) { +#ifndef use_int32 + if (UseFullInt64Range) + return Int128Mul(pt1.Y - pt2.Y, pt3.X - pt4.X) == + Int128Mul(pt1.X - pt2.X, pt3.Y - pt4.Y); + else +#endif + return (pt1.Y - pt2.Y) * (pt3.X - pt4.X) == + (pt1.X - pt2.X) * (pt3.Y - pt4.Y); +} 
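+// Example usage of the helpers above -- a minimal sketch only, assuming the
+// usual Clipper 6.4.2 declarations of IntPoint and Path in "ocr_clipper.hpp".
+// With the shoelace formula used by Area(), the square below has Area() = +100
+// for this vertex order, so Orientation() reports true; PointInPolygon()
+// returns +1 inside, -1 on the boundary and 0 outside, eg:
+//
+//   #include <cstdio>
+//   #include "ocr_clipper.hpp"
+//
+//   int main() {
+//     ClipperLib::Path square;
+//     square.push_back(ClipperLib::IntPoint(0, 0));
+//     square.push_back(ClipperLib::IntPoint(10, 0));
+//     square.push_back(ClipperLib::IntPoint(10, 10));
+//     square.push_back(ClipperLib::IntPoint(0, 10));
+//
+//     std::printf("area=%.1f ccw=%d\n", ClipperLib::Area(square),
+//                 (int)ClipperLib::Orientation(square));  // area=100.0 ccw=1
+//     std::printf("in=%d edge=%d out=%d\n",
+//                 ClipperLib::PointInPolygon(ClipperLib::IntPoint(5, 5), square),
+//                 ClipperLib::PointInPolygon(ClipperLib::IntPoint(0, 5), square),
+//                 ClipperLib::PointInPolygon(ClipperLib::IntPoint(20, 5), square));
+//                                                         // in=1 edge=-1 out=0
+//     return 0;
+//   }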
+//------------------------------------------------------------------------------ + +inline bool IsHorizontal(TEdge &e) { return e.Dx == HORIZONTAL; } +//------------------------------------------------------------------------------ + +inline double GetDx(const IntPoint pt1, const IntPoint pt2) { + return (pt1.Y == pt2.Y) ? HORIZONTAL + : (double)(pt2.X - pt1.X) / (pt2.Y - pt1.Y); +} +//--------------------------------------------------------------------------- + +inline void SetDx(TEdge &e) { + cInt dy = (e.Top.Y - e.Bot.Y); + if (dy == 0) + e.Dx = HORIZONTAL; + else + e.Dx = (double)(e.Top.X - e.Bot.X) / dy; +} +//--------------------------------------------------------------------------- + +inline void SwapSides(TEdge &Edge1, TEdge &Edge2) { + EdgeSide Side = Edge1.Side; + Edge1.Side = Edge2.Side; + Edge2.Side = Side; +} +//------------------------------------------------------------------------------ + +inline void SwapPolyIndexes(TEdge &Edge1, TEdge &Edge2) { + int OutIdx = Edge1.OutIdx; + Edge1.OutIdx = Edge2.OutIdx; + Edge2.OutIdx = OutIdx; +} +//------------------------------------------------------------------------------ + +inline cInt TopX(TEdge &edge, const cInt currentY) { + return (currentY == edge.Top.Y) + ? edge.Top.X + : edge.Bot.X + Round(edge.Dx * (currentY - edge.Bot.Y)); +} +//------------------------------------------------------------------------------ + +void IntersectPoint(TEdge &Edge1, TEdge &Edge2, IntPoint &ip) { +#ifdef use_xyz + ip.Z = 0; +#endif + + double b1, b2; + if (Edge1.Dx == Edge2.Dx) { + ip.Y = Edge1.Curr.Y; + ip.X = TopX(Edge1, ip.Y); + return; + } else if (Edge1.Dx == 0) { + ip.X = Edge1.Bot.X; + if (IsHorizontal(Edge2)) + ip.Y = Edge2.Bot.Y; + else { + b2 = Edge2.Bot.Y - (Edge2.Bot.X / Edge2.Dx); + ip.Y = Round(ip.X / Edge2.Dx + b2); + } + } else if (Edge2.Dx == 0) { + ip.X = Edge2.Bot.X; + if (IsHorizontal(Edge1)) + ip.Y = Edge1.Bot.Y; + else { + b1 = Edge1.Bot.Y - (Edge1.Bot.X / Edge1.Dx); + ip.Y = Round(ip.X / Edge1.Dx + b1); + } + } else { + b1 = Edge1.Bot.X - Edge1.Bot.Y * Edge1.Dx; + b2 = Edge2.Bot.X - Edge2.Bot.Y * Edge2.Dx; + double q = (b2 - b1) / (Edge1.Dx - Edge2.Dx); + ip.Y = Round(q); + if (std::fabs(Edge1.Dx) < std::fabs(Edge2.Dx)) + ip.X = Round(Edge1.Dx * q + b1); + else + ip.X = Round(Edge2.Dx * q + b2); + } + + if (ip.Y < Edge1.Top.Y || ip.Y < Edge2.Top.Y) { + if (Edge1.Top.Y > Edge2.Top.Y) + ip.Y = Edge1.Top.Y; + else + ip.Y = Edge2.Top.Y; + if (std::fabs(Edge1.Dx) < std::fabs(Edge2.Dx)) + ip.X = TopX(Edge1, ip.Y); + else + ip.X = TopX(Edge2, ip.Y); + } + // finally, don't allow 'ip' to be BELOW curr.Y (ie bottom of scanbeam) ... + if (ip.Y > Edge1.Curr.Y) { + ip.Y = Edge1.Curr.Y; + // use the more vertical edge to derive X ... 
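+    // (Dx is the reciprocal slope dx/dy, so the edge with the smaller |Dx| is
+    // the more vertical one and its X changes least per unit change in Y.)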
+ if (std::fabs(Edge1.Dx) > std::fabs(Edge2.Dx)) + ip.X = TopX(Edge2, ip.Y); + else + ip.X = TopX(Edge1, ip.Y); + } +} +//------------------------------------------------------------------------------ + +void ReversePolyPtLinks(OutPt *pp) { + if (!pp) + return; + OutPt *pp1, *pp2; + pp1 = pp; + do { + pp2 = pp1->Next; + pp1->Next = pp1->Prev; + pp1->Prev = pp2; + pp1 = pp2; + } while (pp1 != pp); +} +//------------------------------------------------------------------------------ + +void DisposeOutPts(OutPt *&pp) { + if (pp == 0) + return; + pp->Prev->Next = 0; + while (pp) { + OutPt *tmpPp = pp; + pp = pp->Next; + delete tmpPp; + } +} +//------------------------------------------------------------------------------ + +inline void InitEdge(TEdge *e, TEdge *eNext, TEdge *ePrev, const IntPoint &Pt) { + std::memset(e, 0, sizeof(TEdge)); + e->Next = eNext; + e->Prev = ePrev; + e->Curr = Pt; + e->OutIdx = Unassigned; +} +//------------------------------------------------------------------------------ + +void InitEdge2(TEdge &e, PolyType Pt) { + if (e.Curr.Y >= e.Next->Curr.Y) { + e.Bot = e.Curr; + e.Top = e.Next->Curr; + } else { + e.Top = e.Curr; + e.Bot = e.Next->Curr; + } + SetDx(e); + e.PolyTyp = Pt; +} +//------------------------------------------------------------------------------ + +TEdge *RemoveEdge(TEdge *e) { + // removes e from double_linked_list (but without removing from memory) + e->Prev->Next = e->Next; + e->Next->Prev = e->Prev; + TEdge *result = e->Next; + e->Prev = 0; // flag as removed (see ClipperBase.Clear) + return result; +} +//------------------------------------------------------------------------------ + +inline void ReverseHorizontal(TEdge &e) { + // swap horizontal edges' Top and Bottom x's so they follow the natural + // progression of the bounds - ie so their xbots will align with the + // adjoining lower edge. [Helpful in the ProcessHorizontal() method.] + std::swap(e.Top.X, e.Bot.X); +#ifdef use_xyz + std::swap(e.Top.Z, e.Bot.Z); +#endif +} +//------------------------------------------------------------------------------ + +void SwapPoints(IntPoint &pt1, IntPoint &pt2) { + IntPoint tmp = pt1; + pt1 = pt2; + pt2 = tmp; +} +//------------------------------------------------------------------------------ + +bool GetOverlapSegment(IntPoint pt1a, IntPoint pt1b, IntPoint pt2a, + IntPoint pt2b, IntPoint &pt1, IntPoint &pt2) { + // precondition: segments are Collinear. 
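+  // The overlap is measured along the dominant axis: X ranges when the segment
+  // is closer to horizontal, Y ranges otherwise. A return value of false means
+  // the collinear segments merely touch or do not overlap at all.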
+ if (Abs(pt1a.X - pt1b.X) > Abs(pt1a.Y - pt1b.Y)) { + if (pt1a.X > pt1b.X) + SwapPoints(pt1a, pt1b); + if (pt2a.X > pt2b.X) + SwapPoints(pt2a, pt2b); + if (pt1a.X > pt2a.X) + pt1 = pt1a; + else + pt1 = pt2a; + if (pt1b.X < pt2b.X) + pt2 = pt1b; + else + pt2 = pt2b; + return pt1.X < pt2.X; + } else { + if (pt1a.Y < pt1b.Y) + SwapPoints(pt1a, pt1b); + if (pt2a.Y < pt2b.Y) + SwapPoints(pt2a, pt2b); + if (pt1a.Y < pt2a.Y) + pt1 = pt1a; + else + pt1 = pt2a; + if (pt1b.Y > pt2b.Y) + pt2 = pt1b; + else + pt2 = pt2b; + return pt1.Y > pt2.Y; + } +} +//------------------------------------------------------------------------------ + +bool FirstIsBottomPt(const OutPt *btmPt1, const OutPt *btmPt2) { + OutPt *p = btmPt1->Prev; + while ((p->Pt == btmPt1->Pt) && (p != btmPt1)) + p = p->Prev; + double dx1p = std::fabs(GetDx(btmPt1->Pt, p->Pt)); + p = btmPt1->Next; + while ((p->Pt == btmPt1->Pt) && (p != btmPt1)) + p = p->Next; + double dx1n = std::fabs(GetDx(btmPt1->Pt, p->Pt)); + + p = btmPt2->Prev; + while ((p->Pt == btmPt2->Pt) && (p != btmPt2)) + p = p->Prev; + double dx2p = std::fabs(GetDx(btmPt2->Pt, p->Pt)); + p = btmPt2->Next; + while ((p->Pt == btmPt2->Pt) && (p != btmPt2)) + p = p->Next; + double dx2n = std::fabs(GetDx(btmPt2->Pt, p->Pt)); + + if (std::max(dx1p, dx1n) == std::max(dx2p, dx2n) && + std::min(dx1p, dx1n) == std::min(dx2p, dx2n)) + return Area(btmPt1) > 0; // if otherwise identical use orientation + else + return (dx1p >= dx2p && dx1p >= dx2n) || (dx1n >= dx2p && dx1n >= dx2n); +} +//------------------------------------------------------------------------------ + +OutPt *GetBottomPt(OutPt *pp) { + OutPt *dups = 0; + OutPt *p = pp->Next; + while (p != pp) { + if (p->Pt.Y > pp->Pt.Y) { + pp = p; + dups = 0; + } else if (p->Pt.Y == pp->Pt.Y && p->Pt.X <= pp->Pt.X) { + if (p->Pt.X < pp->Pt.X) { + dups = 0; + pp = p; + } else { + if (p->Next != pp && p->Prev != pp) + dups = p; + } + } + p = p->Next; + } + if (dups) { + // there appears to be at least 2 vertices at BottomPt so ... + while (dups != p) { + if (!FirstIsBottomPt(p, dups)) + pp = dups; + dups = dups->Next; + while (dups->Pt != pp->Pt) + dups = dups->Next; + } + } + return pp; +} +//------------------------------------------------------------------------------ + +bool Pt2IsBetweenPt1AndPt3(const IntPoint pt1, const IntPoint pt2, + const IntPoint pt3) { + if ((pt1 == pt3) || (pt1 == pt2) || (pt3 == pt2)) + return false; + else if (pt1.X != pt3.X) + return (pt2.X > pt1.X) == (pt2.X < pt3.X); + else + return (pt2.Y > pt1.Y) == (pt2.Y < pt3.Y); +} +//------------------------------------------------------------------------------ + +bool HorzSegmentsOverlap(cInt seg1a, cInt seg1b, cInt seg2a, cInt seg2b) { + if (seg1a > seg1b) + std::swap(seg1a, seg1b); + if (seg2a > seg2b) + std::swap(seg2a, seg2b); + return (seg1a < seg2b) && (seg2a < seg1b); +} + +//------------------------------------------------------------------------------ +// ClipperBase class methods ... 
+//------------------------------------------------------------------------------ + +ClipperBase::ClipperBase() // constructor +{ + m_CurrentLM = m_MinimaList.begin(); // begin() == end() here + m_UseFullRange = false; +} +//------------------------------------------------------------------------------ + +ClipperBase::~ClipperBase() // destructor +{ + Clear(); +} +//------------------------------------------------------------------------------ + +void RangeTest(const IntPoint &Pt, bool &useFullRange) { + if (useFullRange) { + if (Pt.X > hiRange || Pt.Y > hiRange || -Pt.X > hiRange || -Pt.Y > hiRange) + throw clipperException("Coordinate outside allowed range"); + } else if (Pt.X > loRange || Pt.Y > loRange || -Pt.X > loRange || + -Pt.Y > loRange) { + useFullRange = true; + RangeTest(Pt, useFullRange); + } +} +//------------------------------------------------------------------------------ + +TEdge *FindNextLocMin(TEdge *E) { + for (;;) { + while (E->Bot != E->Prev->Bot || E->Curr == E->Top) + E = E->Next; + if (!IsHorizontal(*E) && !IsHorizontal(*E->Prev)) + break; + while (IsHorizontal(*E->Prev)) + E = E->Prev; + TEdge *E2 = E; + while (IsHorizontal(*E)) + E = E->Next; + if (E->Top.Y == E->Prev->Bot.Y) + continue; // ie just an intermediate horz. + if (E2->Prev->Bot.X < E->Bot.X) + E = E2; + break; + } + return E; +} +//------------------------------------------------------------------------------ + +TEdge *ClipperBase::ProcessBound(TEdge *E, bool NextIsForward) { + TEdge *Result = E; + TEdge *Horz = 0; + + if (E->OutIdx == Skip) { + // if edges still remain in the current bound beyond the skip edge then + // create another LocMin and call ProcessBound once more + if (NextIsForward) { + while (E->Top.Y == E->Next->Bot.Y) + E = E->Next; + // don't include top horizontals when parsing a bound a second time, + // they will be contained in the opposite bound ... + while (E != Result && IsHorizontal(*E)) + E = E->Prev; + } else { + while (E->Top.Y == E->Prev->Bot.Y) + E = E->Prev; + while (E != Result && IsHorizontal(*E)) + E = E->Next; + } + + if (E == Result) { + if (NextIsForward) + Result = E->Next; + else + Result = E->Prev; + } else { + // there are more edges in the bound beyond result starting with E + if (NextIsForward) + E = Result->Next; + else + E = Result->Prev; + MinimaList::value_type locMin; + locMin.Y = E->Bot.Y; + locMin.LeftBound = 0; + locMin.RightBound = E; + E->WindDelta = 0; + Result = ProcessBound(E, NextIsForward); + m_MinimaList.push_back(locMin); + } + return Result; + } + + TEdge *EStart; + + if (IsHorizontal(*E)) { + // We need to be careful with open paths because this may not be a + // true local minima (ie E may be following a skip edge). + // Also, consecutive horz. edges may start heading left before going right. 
+ if (NextIsForward) + EStart = E->Prev; + else + EStart = E->Next; + if (IsHorizontal(*EStart)) // ie an adjoining horizontal skip edge + { + if (EStart->Bot.X != E->Bot.X && EStart->Top.X != E->Bot.X) + ReverseHorizontal(*E); + } else if (EStart->Bot.X != E->Bot.X) + ReverseHorizontal(*E); + } + + EStart = E; + if (NextIsForward) { + while (Result->Top.Y == Result->Next->Bot.Y && Result->Next->OutIdx != Skip) + Result = Result->Next; + if (IsHorizontal(*Result) && Result->Next->OutIdx != Skip) { + // nb: at the top of a bound, horizontals are added to the bound + // only when the preceding edge attaches to the horizontal's left vertex + // unless a Skip edge is encountered when that becomes the top divide + Horz = Result; + while (IsHorizontal(*Horz->Prev)) + Horz = Horz->Prev; + if (Horz->Prev->Top.X > Result->Next->Top.X) + Result = Horz->Prev; + } + while (E != Result) { + E->NextInLML = E->Next; + if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Prev->Top.X) + ReverseHorizontal(*E); + E = E->Next; + } + if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Prev->Top.X) + ReverseHorizontal(*E); + Result = Result->Next; // move to the edge just beyond current bound + } else { + while (Result->Top.Y == Result->Prev->Bot.Y && Result->Prev->OutIdx != Skip) + Result = Result->Prev; + if (IsHorizontal(*Result) && Result->Prev->OutIdx != Skip) { + Horz = Result; + while (IsHorizontal(*Horz->Next)) + Horz = Horz->Next; + if (Horz->Next->Top.X == Result->Prev->Top.X || + Horz->Next->Top.X > Result->Prev->Top.X) + Result = Horz->Next; + } + + while (E != Result) { + E->NextInLML = E->Prev; + if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Next->Top.X) + ReverseHorizontal(*E); + E = E->Prev; + } + if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Next->Top.X) + ReverseHorizontal(*E); + Result = Result->Prev; // move to the edge just beyond current bound + } + + return Result; +} +//------------------------------------------------------------------------------ + +bool ClipperBase::AddPath(const Path &pg, PolyType PolyTyp, bool Closed) { +#ifdef use_lines + if (!Closed && PolyTyp == ptClip) + throw clipperException("AddPath: Open paths must be subject."); +#else + if (!Closed) + throw clipperException("AddPath: Open paths have been disabled."); +#endif + + int highI = (int)pg.size() - 1; + if (Closed) + while (highI > 0 && (pg[highI] == pg[0])) + --highI; + while (highI > 0 && (pg[highI] == pg[highI - 1])) + --highI; + if ((Closed && highI < 2) || (!Closed && highI < 1)) + return false; + + // create a new edge array ... + TEdge *edges = new TEdge[highI + 1]; + + bool IsFlat = true; + // 1. Basic (first) edge initialization ... + try { + edges[1].Curr = pg[1]; + RangeTest(pg[0], m_UseFullRange); + RangeTest(pg[highI], m_UseFullRange); + InitEdge(&edges[0], &edges[1], &edges[highI], pg[0]); + InitEdge(&edges[highI], &edges[0], &edges[highI - 1], pg[highI]); + for (int i = highI - 1; i >= 1; --i) { + RangeTest(pg[i], m_UseFullRange); + InitEdge(&edges[i], &edges[i + 1], &edges[i - 1], pg[i]); + } + } catch (...) { + delete[] edges; + throw; // range test fails + } + TEdge *eStart = &edges[0]; + + // 2. Remove duplicate vertices, and (when closed) collinear edges ... + TEdge *E = eStart, *eLoopStop = eStart; + for (;;) { + // nb: allows matching start and end points when not Closed ... 
+ if (E->Curr == E->Next->Curr && (Closed || E->Next != eStart)) { + if (E == E->Next) + break; + if (E == eStart) + eStart = E->Next; + E = RemoveEdge(E); + eLoopStop = E; + continue; + } + if (E->Prev == E->Next) + break; // only two vertices + else if (Closed && SlopesEqual(E->Prev->Curr, E->Curr, E->Next->Curr, + m_UseFullRange) && + (!m_PreserveCollinear || + !Pt2IsBetweenPt1AndPt3(E->Prev->Curr, E->Curr, E->Next->Curr))) { + // Collinear edges are allowed for open paths but in closed paths + // the default is to merge adjacent collinear edges into a single edge. + // However, if the PreserveCollinear property is enabled, only overlapping + // collinear edges (ie spikes) will be removed from closed paths. + if (E == eStart) + eStart = E->Next; + E = RemoveEdge(E); + E = E->Prev; + eLoopStop = E; + continue; + } + E = E->Next; + if ((E == eLoopStop) || (!Closed && E->Next == eStart)) + break; + } + + if ((!Closed && (E == E->Next)) || (Closed && (E->Prev == E->Next))) { + delete[] edges; + return false; + } + + if (!Closed) { + m_HasOpenPaths = true; + eStart->Prev->OutIdx = Skip; + } + + // 3. Do second stage of edge initialization ... + E = eStart; + do { + InitEdge2(*E, PolyTyp); + E = E->Next; + if (IsFlat && E->Curr.Y != eStart->Curr.Y) + IsFlat = false; + } while (E != eStart); + + // 4. Finally, add edge bounds to LocalMinima list ... + + // Totally flat paths must be handled differently when adding them + // to LocalMinima list to avoid endless loops etc ... + if (IsFlat) { + if (Closed) { + delete[] edges; + return false; + } + E->Prev->OutIdx = Skip; + MinimaList::value_type locMin; + locMin.Y = E->Bot.Y; + locMin.LeftBound = 0; + locMin.RightBound = E; + locMin.RightBound->Side = esRight; + locMin.RightBound->WindDelta = 0; + for (;;) { + if (E->Bot.X != E->Prev->Top.X) + ReverseHorizontal(*E); + if (E->Next->OutIdx == Skip) + break; + E->NextInLML = E->Next; + E = E->Next; + } + m_MinimaList.push_back(locMin); + m_edges.push_back(edges); + return true; + } + + m_edges.push_back(edges); + bool leftBoundIsForward; + TEdge *EMin = 0; + + // workaround to avoid an endless loop in the while loop below when + // open paths have matching start and end points ... + if (E->Prev->Bot == E->Prev->Top) + E = E->Next; + + for (;;) { + E = FindNextLocMin(E); + if (E == EMin) + break; + else if (!EMin) + EMin = E; + + // E and E.Prev now share a local minima (left aligned if horizontal). + // Compare their slopes to find which starts which bound ... 
+ MinimaList::value_type locMin; + locMin.Y = E->Bot.Y; + if (E->Dx < E->Prev->Dx) { + locMin.LeftBound = E->Prev; + locMin.RightBound = E; + leftBoundIsForward = false; // Q.nextInLML = Q.prev + } else { + locMin.LeftBound = E; + locMin.RightBound = E->Prev; + leftBoundIsForward = true; // Q.nextInLML = Q.next + } + + if (!Closed) + locMin.LeftBound->WindDelta = 0; + else if (locMin.LeftBound->Next == locMin.RightBound) + locMin.LeftBound->WindDelta = -1; + else + locMin.LeftBound->WindDelta = 1; + locMin.RightBound->WindDelta = -locMin.LeftBound->WindDelta; + + E = ProcessBound(locMin.LeftBound, leftBoundIsForward); + if (E->OutIdx == Skip) + E = ProcessBound(E, leftBoundIsForward); + + TEdge *E2 = ProcessBound(locMin.RightBound, !leftBoundIsForward); + if (E2->OutIdx == Skip) + E2 = ProcessBound(E2, !leftBoundIsForward); + + if (locMin.LeftBound->OutIdx == Skip) + locMin.LeftBound = 0; + else if (locMin.RightBound->OutIdx == Skip) + locMin.RightBound = 0; + m_MinimaList.push_back(locMin); + if (!leftBoundIsForward) + E = E2; + } + return true; +} +//------------------------------------------------------------------------------ + +bool ClipperBase::AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed) { + bool result = false; + for (Paths::size_type i = 0; i < ppg.size(); ++i) + if (AddPath(ppg[i], PolyTyp, Closed)) + result = true; + return result; +} +//------------------------------------------------------------------------------ + +void ClipperBase::Clear() { + DisposeLocalMinimaList(); + for (EdgeList::size_type i = 0; i < m_edges.size(); ++i) { + TEdge *edges = m_edges[i]; + delete[] edges; + } + m_edges.clear(); + m_UseFullRange = false; + m_HasOpenPaths = false; +} +//------------------------------------------------------------------------------ + +void ClipperBase::Reset() { + m_CurrentLM = m_MinimaList.begin(); + if (m_CurrentLM == m_MinimaList.end()) + return; // ie nothing to process + std::sort(m_MinimaList.begin(), m_MinimaList.end(), LocMinSorter()); + + m_Scanbeam = ScanbeamList(); // clears/resets priority_queue + // reset all edges ... 
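+  // For every local minimum below: push its Y onto the scanbeam and restore
+  // both bounds to their starting state (Curr reset to Bot, Side reset,
+  // OutIdx cleared to Unassigned).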
+ for (MinimaList::iterator lm = m_MinimaList.begin(); lm != m_MinimaList.end(); + ++lm) { + InsertScanbeam(lm->Y); + TEdge *e = lm->LeftBound; + if (e) { + e->Curr = e->Bot; + e->Side = esLeft; + e->OutIdx = Unassigned; + } + + e = lm->RightBound; + if (e) { + e->Curr = e->Bot; + e->Side = esRight; + e->OutIdx = Unassigned; + } + } + m_ActiveEdges = 0; + m_CurrentLM = m_MinimaList.begin(); +} +//------------------------------------------------------------------------------ + +void ClipperBase::DisposeLocalMinimaList() { + m_MinimaList.clear(); + m_CurrentLM = m_MinimaList.begin(); +} +//------------------------------------------------------------------------------ + +bool ClipperBase::PopLocalMinima(cInt Y, const LocalMinimum *&locMin) { + if (m_CurrentLM == m_MinimaList.end() || (*m_CurrentLM).Y != Y) + return false; + locMin = &(*m_CurrentLM); + ++m_CurrentLM; + return true; +} +//------------------------------------------------------------------------------ + +IntRect ClipperBase::GetBounds() { + IntRect result; + MinimaList::iterator lm = m_MinimaList.begin(); + if (lm == m_MinimaList.end()) { + result.left = result.top = result.right = result.bottom = 0; + return result; + } + result.left = lm->LeftBound->Bot.X; + result.top = lm->LeftBound->Bot.Y; + result.right = lm->LeftBound->Bot.X; + result.bottom = lm->LeftBound->Bot.Y; + while (lm != m_MinimaList.end()) { + // todo - needs fixing for open paths + result.bottom = std::max(result.bottom, lm->LeftBound->Bot.Y); + TEdge *e = lm->LeftBound; + for (;;) { + TEdge *bottomE = e; + while (e->NextInLML) { + if (e->Bot.X < result.left) + result.left = e->Bot.X; + if (e->Bot.X > result.right) + result.right = e->Bot.X; + e = e->NextInLML; + } + result.left = std::min(result.left, e->Bot.X); + result.right = std::max(result.right, e->Bot.X); + result.left = std::min(result.left, e->Top.X); + result.right = std::max(result.right, e->Top.X); + result.top = std::min(result.top, e->Top.Y); + if (bottomE == lm->LeftBound) + e = lm->RightBound; + else + break; + } + ++lm; + } + return result; +} +//------------------------------------------------------------------------------ + +void ClipperBase::InsertScanbeam(const cInt Y) { m_Scanbeam.push(Y); } +//------------------------------------------------------------------------------ + +bool ClipperBase::PopScanbeam(cInt &Y) { + if (m_Scanbeam.empty()) + return false; + Y = m_Scanbeam.top(); + m_Scanbeam.pop(); + while (!m_Scanbeam.empty() && Y == m_Scanbeam.top()) { + m_Scanbeam.pop(); + } // Pop duplicates. 
+ return true; +} +//------------------------------------------------------------------------------ + +void ClipperBase::DisposeAllOutRecs() { + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + DisposeOutRec(i); + m_PolyOuts.clear(); +} +//------------------------------------------------------------------------------ + +void ClipperBase::DisposeOutRec(PolyOutList::size_type index) { + OutRec *outRec = m_PolyOuts[index]; + if (outRec->Pts) + DisposeOutPts(outRec->Pts); + delete outRec; + m_PolyOuts[index] = 0; +} +//------------------------------------------------------------------------------ + +void ClipperBase::DeleteFromAEL(TEdge *e) { + TEdge *AelPrev = e->PrevInAEL; + TEdge *AelNext = e->NextInAEL; + if (!AelPrev && !AelNext && (e != m_ActiveEdges)) + return; // already deleted + if (AelPrev) + AelPrev->NextInAEL = AelNext; + else + m_ActiveEdges = AelNext; + if (AelNext) + AelNext->PrevInAEL = AelPrev; + e->NextInAEL = 0; + e->PrevInAEL = 0; +} +//------------------------------------------------------------------------------ + +OutRec *ClipperBase::CreateOutRec() { + OutRec *result = new OutRec; + result->IsHole = false; + result->IsOpen = false; + result->FirstLeft = 0; + result->Pts = 0; + result->BottomPt = 0; + result->PolyNd = 0; + m_PolyOuts.push_back(result); + result->Idx = (int)m_PolyOuts.size() - 1; + return result; +} +//------------------------------------------------------------------------------ + +void ClipperBase::SwapPositionsInAEL(TEdge *Edge1, TEdge *Edge2) { + // check that one or other edge hasn't already been removed from AEL ... + if (Edge1->NextInAEL == Edge1->PrevInAEL || + Edge2->NextInAEL == Edge2->PrevInAEL) + return; + + if (Edge1->NextInAEL == Edge2) { + TEdge *Next = Edge2->NextInAEL; + if (Next) + Next->PrevInAEL = Edge1; + TEdge *Prev = Edge1->PrevInAEL; + if (Prev) + Prev->NextInAEL = Edge2; + Edge2->PrevInAEL = Prev; + Edge2->NextInAEL = Edge1; + Edge1->PrevInAEL = Edge2; + Edge1->NextInAEL = Next; + } else if (Edge2->NextInAEL == Edge1) { + TEdge *Next = Edge1->NextInAEL; + if (Next) + Next->PrevInAEL = Edge2; + TEdge *Prev = Edge2->PrevInAEL; + if (Prev) + Prev->NextInAEL = Edge1; + Edge1->PrevInAEL = Prev; + Edge1->NextInAEL = Edge2; + Edge2->PrevInAEL = Edge1; + Edge2->NextInAEL = Next; + } else { + TEdge *Next = Edge1->NextInAEL; + TEdge *Prev = Edge1->PrevInAEL; + Edge1->NextInAEL = Edge2->NextInAEL; + if (Edge1->NextInAEL) + Edge1->NextInAEL->PrevInAEL = Edge1; + Edge1->PrevInAEL = Edge2->PrevInAEL; + if (Edge1->PrevInAEL) + Edge1->PrevInAEL->NextInAEL = Edge1; + Edge2->NextInAEL = Next; + if (Edge2->NextInAEL) + Edge2->NextInAEL->PrevInAEL = Edge2; + Edge2->PrevInAEL = Prev; + if (Edge2->PrevInAEL) + Edge2->PrevInAEL->NextInAEL = Edge2; + } + + if (!Edge1->PrevInAEL) + m_ActiveEdges = Edge1; + else if (!Edge2->PrevInAEL) + m_ActiveEdges = Edge2; +} +//------------------------------------------------------------------------------ + +void ClipperBase::UpdateEdgeIntoAEL(TEdge *&e) { + if (!e->NextInLML) + throw clipperException("UpdateEdgeIntoAEL: invalid call"); + + e->NextInLML->OutIdx = e->OutIdx; + TEdge *AelPrev = e->PrevInAEL; + TEdge *AelNext = e->NextInAEL; + if (AelPrev) + AelPrev->NextInAEL = e->NextInLML; + else + m_ActiveEdges = e->NextInLML; + if (AelNext) + AelNext->PrevInAEL = e->NextInLML; + e->NextInLML->Side = e->Side; + e->NextInLML->WindDelta = e->WindDelta; + e->NextInLML->WindCnt = e->WindCnt; + e->NextInLML->WindCnt2 = e->WindCnt2; + e = e->NextInLML; + e->Curr = e->Bot; + e->PrevInAEL = AelPrev; + 
e->NextInAEL = AelNext; + if (!IsHorizontal(*e)) + InsertScanbeam(e->Top.Y); +} +//------------------------------------------------------------------------------ + +bool ClipperBase::LocalMinimaPending() { + return (m_CurrentLM != m_MinimaList.end()); +} + +//------------------------------------------------------------------------------ +// TClipper methods ... +//------------------------------------------------------------------------------ + +Clipper::Clipper(int initOptions) + : ClipperBase() // constructor +{ + m_ExecuteLocked = false; + m_UseFullRange = false; + m_ReverseOutput = ((initOptions & ioReverseSolution) != 0); + m_StrictSimple = ((initOptions & ioStrictlySimple) != 0); + m_PreserveCollinear = ((initOptions & ioPreserveCollinear) != 0); + m_HasOpenPaths = false; +#ifdef use_xyz + m_ZFill = 0; +#endif +} +//------------------------------------------------------------------------------ + +#ifdef use_xyz +void Clipper::ZFillFunction(ZFillCallback zFillFunc) { m_ZFill = zFillFunc; } +//------------------------------------------------------------------------------ +#endif + +bool Clipper::Execute(ClipType clipType, Paths &solution, + PolyFillType fillType) { + return Execute(clipType, solution, fillType, fillType); +} +//------------------------------------------------------------------------------ + +bool Clipper::Execute(ClipType clipType, PolyTree &polytree, + PolyFillType fillType) { + return Execute(clipType, polytree, fillType, fillType); +} +//------------------------------------------------------------------------------ + +bool Clipper::Execute(ClipType clipType, Paths &solution, + PolyFillType subjFillType, PolyFillType clipFillType) { + if (m_ExecuteLocked) + return false; + if (m_HasOpenPaths) + throw clipperException( + "Error: PolyTree struct is needed for open path clipping."); + m_ExecuteLocked = true; + solution.resize(0); + m_SubjFillType = subjFillType; + m_ClipFillType = clipFillType; + m_ClipType = clipType; + m_UsingPolyTree = false; + bool succeeded = ExecuteInternal(); + if (succeeded) + BuildResult(solution); + DisposeAllOutRecs(); + m_ExecuteLocked = false; + return succeeded; +} +//------------------------------------------------------------------------------ + +bool Clipper::Execute(ClipType clipType, PolyTree &polytree, + PolyFillType subjFillType, PolyFillType clipFillType) { + if (m_ExecuteLocked) + return false; + m_ExecuteLocked = true; + m_SubjFillType = subjFillType; + m_ClipFillType = clipFillType; + m_ClipType = clipType; + m_UsingPolyTree = true; + bool succeeded = ExecuteInternal(); + if (succeeded) + BuildResult2(polytree); + DisposeAllOutRecs(); + m_ExecuteLocked = false; + return succeeded; +} +//------------------------------------------------------------------------------ + +void Clipper::FixHoleLinkage(OutRec &outrec) { + // skip OutRecs that (a) contain outermost polygons or + //(b) already have the correct owner/child linkage ... 
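+  // i.e. walk up the FirstLeft chain until an owner with the opposite IsHole
+  // state (and a non-empty point list) is found, then re-link to it.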
+ if (!outrec.FirstLeft || + (outrec.IsHole != outrec.FirstLeft->IsHole && outrec.FirstLeft->Pts)) + return; + + OutRec *orfl = outrec.FirstLeft; + while (orfl && ((orfl->IsHole == outrec.IsHole) || !orfl->Pts)) + orfl = orfl->FirstLeft; + outrec.FirstLeft = orfl; +} +//------------------------------------------------------------------------------ + +bool Clipper::ExecuteInternal() { + bool succeeded = true; + try { + Reset(); + m_Maxima = MaximaList(); + m_SortedEdges = 0; + + succeeded = true; + cInt botY, topY; + if (!PopScanbeam(botY)) + return false; + InsertLocalMinimaIntoAEL(botY); + while (PopScanbeam(topY) || LocalMinimaPending()) { + ProcessHorizontals(); + ClearGhostJoins(); + if (!ProcessIntersections(topY)) { + succeeded = false; + break; + } + ProcessEdgesAtTopOfScanbeam(topY); + botY = topY; + InsertLocalMinimaIntoAEL(botY); + } + } catch (...) { + succeeded = false; + } + + if (succeeded) { + // fix orientations ... + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) { + OutRec *outRec = m_PolyOuts[i]; + if (!outRec->Pts || outRec->IsOpen) + continue; + if ((outRec->IsHole ^ m_ReverseOutput) == (Area(*outRec) > 0)) + ReversePolyPtLinks(outRec->Pts); + } + + if (!m_Joins.empty()) + JoinCommonEdges(); + + // unfortunately FixupOutPolygon() must be done after JoinCommonEdges() + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) { + OutRec *outRec = m_PolyOuts[i]; + if (!outRec->Pts) + continue; + if (outRec->IsOpen) + FixupOutPolyline(*outRec); + else + FixupOutPolygon(*outRec); + } + + if (m_StrictSimple) + DoSimplePolygons(); + } + + ClearJoins(); + ClearGhostJoins(); + return succeeded; +} +//------------------------------------------------------------------------------ + +void Clipper::SetWindingCount(TEdge &edge) { + TEdge *e = edge.PrevInAEL; + // find the edge of the same polytype that immediately preceeds 'edge' in AEL + while (e && ((e->PolyTyp != edge.PolyTyp) || (e->WindDelta == 0))) + e = e->PrevInAEL; + if (!e) { + if (edge.WindDelta == 0) { + PolyFillType pft = + (edge.PolyTyp == ptSubject ? m_SubjFillType : m_ClipFillType); + edge.WindCnt = (pft == pftNegative ? -1 : 1); + } else + edge.WindCnt = edge.WindDelta; + edge.WindCnt2 = 0; + e = m_ActiveEdges; // ie get ready to calc WindCnt2 + } else if (edge.WindDelta == 0 && m_ClipType != ctUnion) { + edge.WindCnt = 1; + edge.WindCnt2 = e->WindCnt2; + e = e->NextInAEL; // ie get ready to calc WindCnt2 + } else if (IsEvenOddFillType(edge)) { + // EvenOdd filling ... + if (edge.WindDelta == 0) { + // are we inside a subj polygon ... + bool Inside = true; + TEdge *e2 = e->PrevInAEL; + while (e2) { + if (e2->PolyTyp == e->PolyTyp && e2->WindDelta != 0) + Inside = !Inside; + e2 = e2->PrevInAEL; + } + edge.WindCnt = (Inside ? 0 : 1); + } else { + edge.WindCnt = edge.WindDelta; + } + edge.WindCnt2 = e->WindCnt2; + e = e->NextInAEL; // ie get ready to calc WindCnt2 + } else { + // nonZero, Positive or Negative filling ... + if (e->WindCnt * e->WindDelta < 0) { + // prev edge is 'decreasing' WindCount (WC) toward zero + // so we're outside the previous polygon ... + if (Abs(e->WindCnt) > 1) { + // outside prev poly but still inside another. + // when reversing direction of prev poly use the same WC + if (e->WindDelta * edge.WindDelta < 0) + edge.WindCnt = e->WindCnt; + // otherwise continue to 'decrease' WC ... + else + edge.WindCnt = e->WindCnt + edge.WindDelta; + } else + // now outside all polys of same polytype so set own WC ... + edge.WindCnt = (edge.WindDelta == 0 ? 
1 : edge.WindDelta); + } else { + // prev edge is 'increasing' WindCount (WC) away from zero + // so we're inside the previous polygon ... + if (edge.WindDelta == 0) + edge.WindCnt = (e->WindCnt < 0 ? e->WindCnt - 1 : e->WindCnt + 1); + // if wind direction is reversing prev then use same WC + else if (e->WindDelta * edge.WindDelta < 0) + edge.WindCnt = e->WindCnt; + // otherwise add to WC ... + else + edge.WindCnt = e->WindCnt + edge.WindDelta; + } + edge.WindCnt2 = e->WindCnt2; + e = e->NextInAEL; // ie get ready to calc WindCnt2 + } + + // update WindCnt2 ... + if (IsEvenOddAltFillType(edge)) { + // EvenOdd filling ... + while (e != &edge) { + if (e->WindDelta != 0) + edge.WindCnt2 = (edge.WindCnt2 == 0 ? 1 : 0); + e = e->NextInAEL; + } + } else { + // nonZero, Positive or Negative filling ... + while (e != &edge) { + edge.WindCnt2 += e->WindDelta; + e = e->NextInAEL; + } + } +} +//------------------------------------------------------------------------------ + +bool Clipper::IsEvenOddFillType(const TEdge &edge) const { + if (edge.PolyTyp == ptSubject) + return m_SubjFillType == pftEvenOdd; + else + return m_ClipFillType == pftEvenOdd; +} +//------------------------------------------------------------------------------ + +bool Clipper::IsEvenOddAltFillType(const TEdge &edge) const { + if (edge.PolyTyp == ptSubject) + return m_ClipFillType == pftEvenOdd; + else + return m_SubjFillType == pftEvenOdd; +} +//------------------------------------------------------------------------------ + +bool Clipper::IsContributing(const TEdge &edge) const { + PolyFillType pft, pft2; + if (edge.PolyTyp == ptSubject) { + pft = m_SubjFillType; + pft2 = m_ClipFillType; + } else { + pft = m_ClipFillType; + pft2 = m_SubjFillType; + } + + switch (pft) { + case pftEvenOdd: + // return false if a subj line has been flagged as inside a subj polygon + if (edge.WindDelta == 0 && edge.WindCnt != 1) + return false; + break; + case pftNonZero: + if (Abs(edge.WindCnt) != 1) + return false; + break; + case pftPositive: + if (edge.WindCnt != 1) + return false; + break; + default: // pftNegative + if (edge.WindCnt != -1) + return false; + } + + switch (m_ClipType) { + case ctIntersection: + switch (pft2) { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 != 0); + case pftPositive: + return (edge.WindCnt2 > 0); + default: + return (edge.WindCnt2 < 0); + } + break; + case ctUnion: + switch (pft2) { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 == 0); + case pftPositive: + return (edge.WindCnt2 <= 0); + default: + return (edge.WindCnt2 >= 0); + } + break; + case ctDifference: + if (edge.PolyTyp == ptSubject) + switch (pft2) { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 == 0); + case pftPositive: + return (edge.WindCnt2 <= 0); + default: + return (edge.WindCnt2 >= 0); + } + else + switch (pft2) { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 != 0); + case pftPositive: + return (edge.WindCnt2 > 0); + default: + return (edge.WindCnt2 < 0); + } + break; + case ctXor: + if (edge.WindDelta == 0) // XOr always contributing unless open + switch (pft2) { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 == 0); + case pftPositive: + return (edge.WindCnt2 <= 0); + default: + return (edge.WindCnt2 >= 0); + } + else + return true; + break; + default: + return true; + } +} +//------------------------------------------------------------------------------ + +OutPt *Clipper::AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &Pt) { + OutPt *result; + TEdge *e, 
*prevE; + if (IsHorizontal(*e2) || (e1->Dx > e2->Dx)) { + result = AddOutPt(e1, Pt); + e2->OutIdx = e1->OutIdx; + e1->Side = esLeft; + e2->Side = esRight; + e = e1; + if (e->PrevInAEL == e2) + prevE = e2->PrevInAEL; + else + prevE = e->PrevInAEL; + } else { + result = AddOutPt(e2, Pt); + e1->OutIdx = e2->OutIdx; + e1->Side = esRight; + e2->Side = esLeft; + e = e2; + if (e->PrevInAEL == e1) + prevE = e1->PrevInAEL; + else + prevE = e->PrevInAEL; + } + + if (prevE && prevE->OutIdx >= 0 && prevE->Top.Y < Pt.Y && e->Top.Y < Pt.Y) { + cInt xPrev = TopX(*prevE, Pt.Y); + cInt xE = TopX(*e, Pt.Y); + if (xPrev == xE && (e->WindDelta != 0) && (prevE->WindDelta != 0) && + SlopesEqual(IntPoint(xPrev, Pt.Y), prevE->Top, IntPoint(xE, Pt.Y), + e->Top, m_UseFullRange)) { + OutPt *outPt = AddOutPt(prevE, Pt); + AddJoin(result, outPt, e->Top); + } + } + return result; +} +//------------------------------------------------------------------------------ + +void Clipper::AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &Pt) { + AddOutPt(e1, Pt); + if (e2->WindDelta == 0) + AddOutPt(e2, Pt); + if (e1->OutIdx == e2->OutIdx) { + e1->OutIdx = Unassigned; + e2->OutIdx = Unassigned; + } else if (e1->OutIdx < e2->OutIdx) + AppendPolygon(e1, e2); + else + AppendPolygon(e2, e1); +} +//------------------------------------------------------------------------------ + +void Clipper::AddEdgeToSEL(TEdge *edge) { + // SEL pointers in PEdge are reused to build a list of horizontal edges. + // However, we don't need to worry about order with horizontal edge + // processing. + if (!m_SortedEdges) { + m_SortedEdges = edge; + edge->PrevInSEL = 0; + edge->NextInSEL = 0; + } else { + edge->NextInSEL = m_SortedEdges; + edge->PrevInSEL = 0; + m_SortedEdges->PrevInSEL = edge; + m_SortedEdges = edge; + } +} +//------------------------------------------------------------------------------ + +bool Clipper::PopEdgeFromSEL(TEdge *&edge) { + if (!m_SortedEdges) + return false; + edge = m_SortedEdges; + DeleteFromSEL(m_SortedEdges); + return true; +} +//------------------------------------------------------------------------------ + +void Clipper::CopyAELToSEL() { + TEdge *e = m_ActiveEdges; + m_SortedEdges = e; + while (e) { + e->PrevInSEL = e->PrevInAEL; + e->NextInSEL = e->NextInAEL; + e = e->NextInAEL; + } +} +//------------------------------------------------------------------------------ + +void Clipper::AddJoin(OutPt *op1, OutPt *op2, const IntPoint OffPt) { + Join *j = new Join; + j->OutPt1 = op1; + j->OutPt2 = op2; + j->OffPt = OffPt; + m_Joins.push_back(j); +} +//------------------------------------------------------------------------------ + +void Clipper::ClearJoins() { + for (JoinList::size_type i = 0; i < m_Joins.size(); i++) + delete m_Joins[i]; + m_Joins.resize(0); +} +//------------------------------------------------------------------------------ + +void Clipper::ClearGhostJoins() { + for (JoinList::size_type i = 0; i < m_GhostJoins.size(); i++) + delete m_GhostJoins[i]; + m_GhostJoins.resize(0); +} +//------------------------------------------------------------------------------ + +void Clipper::AddGhostJoin(OutPt *op, const IntPoint OffPt) { + Join *j = new Join; + j->OutPt1 = op; + j->OutPt2 = 0; + j->OffPt = OffPt; + m_GhostJoins.push_back(j); +} +//------------------------------------------------------------------------------ + +void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) { + const LocalMinimum *lm; + while (PopLocalMinima(botY, lm)) { + TEdge *lb = lm->LeftBound; + TEdge *rb = lm->RightBound; + + OutPt 
*Op1 = 0; + if (!lb) { + // nb: don't insert LB into either AEL or SEL + InsertEdgeIntoAEL(rb, 0); + SetWindingCount(*rb); + if (IsContributing(*rb)) + Op1 = AddOutPt(rb, rb->Bot); + } else if (!rb) { + InsertEdgeIntoAEL(lb, 0); + SetWindingCount(*lb); + if (IsContributing(*lb)) + Op1 = AddOutPt(lb, lb->Bot); + InsertScanbeam(lb->Top.Y); + } else { + InsertEdgeIntoAEL(lb, 0); + InsertEdgeIntoAEL(rb, lb); + SetWindingCount(*lb); + rb->WindCnt = lb->WindCnt; + rb->WindCnt2 = lb->WindCnt2; + if (IsContributing(*lb)) + Op1 = AddLocalMinPoly(lb, rb, lb->Bot); + InsertScanbeam(lb->Top.Y); + } + + if (rb) { + if (IsHorizontal(*rb)) { + AddEdgeToSEL(rb); + if (rb->NextInLML) + InsertScanbeam(rb->NextInLML->Top.Y); + } else + InsertScanbeam(rb->Top.Y); + } + + if (!lb || !rb) + continue; + + // if any output polygons share an edge, they'll need joining later ... + if (Op1 && IsHorizontal(*rb) && m_GhostJoins.size() > 0 && + (rb->WindDelta != 0)) { + for (JoinList::size_type i = 0; i < m_GhostJoins.size(); ++i) { + Join *jr = m_GhostJoins[i]; + // if the horizontal Rb and a 'ghost' horizontal overlap, then convert + // the 'ghost' join to a real join ready for later ... + if (HorzSegmentsOverlap(jr->OutPt1->Pt.X, jr->OffPt.X, rb->Bot.X, + rb->Top.X)) + AddJoin(jr->OutPt1, Op1, jr->OffPt); + } + } + + if (lb->OutIdx >= 0 && lb->PrevInAEL && + lb->PrevInAEL->Curr.X == lb->Bot.X && lb->PrevInAEL->OutIdx >= 0 && + SlopesEqual(lb->PrevInAEL->Bot, lb->PrevInAEL->Top, lb->Curr, lb->Top, + m_UseFullRange) && + (lb->WindDelta != 0) && (lb->PrevInAEL->WindDelta != 0)) { + OutPt *Op2 = AddOutPt(lb->PrevInAEL, lb->Bot); + AddJoin(Op1, Op2, lb->Top); + } + + if (lb->NextInAEL != rb) { + + if (rb->OutIdx >= 0 && rb->PrevInAEL->OutIdx >= 0 && + SlopesEqual(rb->PrevInAEL->Curr, rb->PrevInAEL->Top, rb->Curr, + rb->Top, m_UseFullRange) && + (rb->WindDelta != 0) && (rb->PrevInAEL->WindDelta != 0)) { + OutPt *Op2 = AddOutPt(rb->PrevInAEL, rb->Bot); + AddJoin(Op1, Op2, rb->Top); + } + + TEdge *e = lb->NextInAEL; + if (e) { + while (e != rb) { + // nb: For calculating winding counts etc, IntersectEdges() assumes + // that param1 will be to the Right of param2 ABOVE the intersection + // ... + IntersectEdges(rb, e, lb->Curr); // order important here + e = e->NextInAEL; + } + } + } + } +} +//------------------------------------------------------------------------------ + +void Clipper::DeleteFromSEL(TEdge *e) { + TEdge *SelPrev = e->PrevInSEL; + TEdge *SelNext = e->NextInSEL; + if (!SelPrev && !SelNext && (e != m_SortedEdges)) + return; // already deleted + if (SelPrev) + SelPrev->NextInSEL = SelNext; + else + m_SortedEdges = SelNext; + if (SelNext) + SelNext->PrevInSEL = SelPrev; + e->NextInSEL = 0; + e->PrevInSEL = 0; +} +//------------------------------------------------------------------------------ + +#ifdef use_xyz +void Clipper::SetZ(IntPoint &pt, TEdge &e1, TEdge &e2) { + if (pt.Z != 0 || !m_ZFill) + return; + else if (pt == e1.Bot) + pt.Z = e1.Bot.Z; + else if (pt == e1.Top) + pt.Z = e1.Top.Z; + else if (pt == e2.Bot) + pt.Z = e2.Bot.Z; + else if (pt == e2.Top) + pt.Z = e2.Top.Z; + else + (*m_ZFill)(e1.Bot, e1.Top, e2.Bot, e2.Top, pt); +} +//------------------------------------------------------------------------------ +#endif + +void Clipper::IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &Pt) { + bool e1Contributing = (e1->OutIdx >= 0); + bool e2Contributing = (e2->OutIdx >= 0); + +#ifdef use_xyz + SetZ(Pt, *e1, *e2); +#endif + +#ifdef use_lines + // if either edge is on an OPEN path ... 
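+  // (open-path edges have WindDelta == 0; this branch only emits OutPts and
+  //  returns before the winding-count bookkeeping further below)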
+ if (e1->WindDelta == 0 || e2->WindDelta == 0) { + // ignore subject-subject open path intersections UNLESS they + // are both open paths, AND they are both 'contributing maximas' ... + if (e1->WindDelta == 0 && e2->WindDelta == 0) + return; + + // if intersecting a subj line with a subj poly ... + else if (e1->PolyTyp == e2->PolyTyp && e1->WindDelta != e2->WindDelta && + m_ClipType == ctUnion) { + if (e1->WindDelta == 0) { + if (e2Contributing) { + AddOutPt(e1, Pt); + if (e1Contributing) + e1->OutIdx = Unassigned; + } + } else { + if (e1Contributing) { + AddOutPt(e2, Pt); + if (e2Contributing) + e2->OutIdx = Unassigned; + } + } + } else if (e1->PolyTyp != e2->PolyTyp) { + // toggle subj open path OutIdx on/off when Abs(clip.WndCnt) == 1 ... + if ((e1->WindDelta == 0) && abs(e2->WindCnt) == 1 && + (m_ClipType != ctUnion || e2->WindCnt2 == 0)) { + AddOutPt(e1, Pt); + if (e1Contributing) + e1->OutIdx = Unassigned; + } else if ((e2->WindDelta == 0) && (abs(e1->WindCnt) == 1) && + (m_ClipType != ctUnion || e1->WindCnt2 == 0)) { + AddOutPt(e2, Pt); + if (e2Contributing) + e2->OutIdx = Unassigned; + } + } + return; + } +#endif + + // update winding counts... + // assumes that e1 will be to the Right of e2 ABOVE the intersection + if (e1->PolyTyp == e2->PolyTyp) { + if (IsEvenOddFillType(*e1)) { + int oldE1WindCnt = e1->WindCnt; + e1->WindCnt = e2->WindCnt; + e2->WindCnt = oldE1WindCnt; + } else { + if (e1->WindCnt + e2->WindDelta == 0) + e1->WindCnt = -e1->WindCnt; + else + e1->WindCnt += e2->WindDelta; + if (e2->WindCnt - e1->WindDelta == 0) + e2->WindCnt = -e2->WindCnt; + else + e2->WindCnt -= e1->WindDelta; + } + } else { + if (!IsEvenOddFillType(*e2)) + e1->WindCnt2 += e2->WindDelta; + else + e1->WindCnt2 = (e1->WindCnt2 == 0) ? 1 : 0; + if (!IsEvenOddFillType(*e1)) + e2->WindCnt2 -= e1->WindDelta; + else + e2->WindCnt2 = (e2->WindCnt2 == 0) ? 1 : 0; + } + + PolyFillType e1FillType, e2FillType, e1FillType2, e2FillType2; + if (e1->PolyTyp == ptSubject) { + e1FillType = m_SubjFillType; + e1FillType2 = m_ClipFillType; + } else { + e1FillType = m_ClipFillType; + e1FillType2 = m_SubjFillType; + } + if (e2->PolyTyp == ptSubject) { + e2FillType = m_SubjFillType; + e2FillType2 = m_ClipFillType; + } else { + e2FillType = m_ClipFillType; + e2FillType2 = m_SubjFillType; + } + + cInt e1Wc, e2Wc; + switch (e1FillType) { + case pftPositive: + e1Wc = e1->WindCnt; + break; + case pftNegative: + e1Wc = -e1->WindCnt; + break; + default: + e1Wc = Abs(e1->WindCnt); + } + switch (e2FillType) { + case pftPositive: + e2Wc = e2->WindCnt; + break; + case pftNegative: + e2Wc = -e2->WindCnt; + break; + default: + e2Wc = Abs(e2->WindCnt); + } + + if (e1Contributing && e2Contributing) { + if ((e1Wc != 0 && e1Wc != 1) || (e2Wc != 0 && e2Wc != 1) || + (e1->PolyTyp != e2->PolyTyp && m_ClipType != ctXor)) { + AddLocalMaxPoly(e1, e2, Pt); + } else { + AddOutPt(e1, Pt); + AddOutPt(e2, Pt); + SwapSides(*e1, *e2); + SwapPolyIndexes(*e1, *e2); + } + } else if (e1Contributing) { + if (e2Wc == 0 || e2Wc == 1) { + AddOutPt(e1, Pt); + SwapSides(*e1, *e2); + SwapPolyIndexes(*e1, *e2); + } + } else if (e2Contributing) { + if (e1Wc == 0 || e1Wc == 1) { + AddOutPt(e2, Pt); + SwapSides(*e1, *e2); + SwapPolyIndexes(*e1, *e2); + } + } else if ((e1Wc == 0 || e1Wc == 1) && (e2Wc == 0 || e2Wc == 1)) { + // neither edge is currently contributing ... 
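+    // (both wind counts are 0 or 1 here, so depending on the clip operation
+    //  and the opposite-polytype wind counts computed below, a new local
+    //  minimum polygon may be started at Pt)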
+ + cInt e1Wc2, e2Wc2; + switch (e1FillType2) { + case pftPositive: + e1Wc2 = e1->WindCnt2; + break; + case pftNegative: + e1Wc2 = -e1->WindCnt2; + break; + default: + e1Wc2 = Abs(e1->WindCnt2); + } + switch (e2FillType2) { + case pftPositive: + e2Wc2 = e2->WindCnt2; + break; + case pftNegative: + e2Wc2 = -e2->WindCnt2; + break; + default: + e2Wc2 = Abs(e2->WindCnt2); + } + + if (e1->PolyTyp != e2->PolyTyp) { + AddLocalMinPoly(e1, e2, Pt); + } else if (e1Wc == 1 && e2Wc == 1) + switch (m_ClipType) { + case ctIntersection: + if (e1Wc2 > 0 && e2Wc2 > 0) + AddLocalMinPoly(e1, e2, Pt); + break; + case ctUnion: + if (e1Wc2 <= 0 && e2Wc2 <= 0) + AddLocalMinPoly(e1, e2, Pt); + break; + case ctDifference: + if (((e1->PolyTyp == ptClip) && (e1Wc2 > 0) && (e2Wc2 > 0)) || + ((e1->PolyTyp == ptSubject) && (e1Wc2 <= 0) && (e2Wc2 <= 0))) + AddLocalMinPoly(e1, e2, Pt); + break; + case ctXor: + AddLocalMinPoly(e1, e2, Pt); + } + else + SwapSides(*e1, *e2); + } +} +//------------------------------------------------------------------------------ + +void Clipper::SetHoleState(TEdge *e, OutRec *outrec) { + TEdge *e2 = e->PrevInAEL; + TEdge *eTmp = 0; + while (e2) { + if (e2->OutIdx >= 0 && e2->WindDelta != 0) { + if (!eTmp) + eTmp = e2; + else if (eTmp->OutIdx == e2->OutIdx) + eTmp = 0; + } + e2 = e2->PrevInAEL; + } + if (!eTmp) { + outrec->FirstLeft = 0; + outrec->IsHole = false; + } else { + outrec->FirstLeft = m_PolyOuts[eTmp->OutIdx]; + outrec->IsHole = !outrec->FirstLeft->IsHole; + } +} +//------------------------------------------------------------------------------ + +OutRec *GetLowermostRec(OutRec *outRec1, OutRec *outRec2) { + // work out which polygon fragment has the correct hole state ... + if (!outRec1->BottomPt) + outRec1->BottomPt = GetBottomPt(outRec1->Pts); + if (!outRec2->BottomPt) + outRec2->BottomPt = GetBottomPt(outRec2->Pts); + OutPt *OutPt1 = outRec1->BottomPt; + OutPt *OutPt2 = outRec2->BottomPt; + if (OutPt1->Pt.Y > OutPt2->Pt.Y) + return outRec1; + else if (OutPt1->Pt.Y < OutPt2->Pt.Y) + return outRec2; + else if (OutPt1->Pt.X < OutPt2->Pt.X) + return outRec1; + else if (OutPt1->Pt.X > OutPt2->Pt.X) + return outRec2; + else if (OutPt1->Next == OutPt1) + return outRec2; + else if (OutPt2->Next == OutPt2) + return outRec1; + else if (FirstIsBottomPt(OutPt1, OutPt2)) + return outRec1; + else + return outRec2; +} +//------------------------------------------------------------------------------ + +bool OutRec1RightOfOutRec2(OutRec *outRec1, OutRec *outRec2) { + do { + outRec1 = outRec1->FirstLeft; + if (outRec1 == outRec2) + return true; + } while (outRec1); + return false; +} +//------------------------------------------------------------------------------ + +OutRec *Clipper::GetOutRec(int Idx) { + OutRec *outrec = m_PolyOuts[Idx]; + while (outrec != m_PolyOuts[outrec->Idx]) + outrec = m_PolyOuts[outrec->Idx]; + return outrec; +} +//------------------------------------------------------------------------------ + +void Clipper::AppendPolygon(TEdge *e1, TEdge *e2) { + // get the start and ends of both output polygons ... + OutRec *outRec1 = m_PolyOuts[e1->OutIdx]; + OutRec *outRec2 = m_PolyOuts[e2->OutIdx]; + + OutRec *holeStateRec; + if (OutRec1RightOfOutRec2(outRec1, outRec2)) + holeStateRec = outRec2; + else if (OutRec1RightOfOutRec2(outRec2, outRec1)) + holeStateRec = outRec1; + else + holeStateRec = GetLowermostRec(outRec1, outRec2); + + // get the start and ends of both output polygons and + // join e2 poly onto e1 poly and delete pointers to e2 ... 
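+  // (outRec1->Pts is the left-most OutPt of its ring and outRec1->Pts->Prev
+  //  the right-most; the Side combinations below splice outRec2's circular
+  //  list onto outRec1's accordingly)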
+ + OutPt *p1_lft = outRec1->Pts; + OutPt *p1_rt = p1_lft->Prev; + OutPt *p2_lft = outRec2->Pts; + OutPt *p2_rt = p2_lft->Prev; + + // join e2 poly onto e1 poly and delete pointers to e2 ... + if (e1->Side == esLeft) { + if (e2->Side == esLeft) { + // z y x a b c + ReversePolyPtLinks(p2_lft); + p2_lft->Next = p1_lft; + p1_lft->Prev = p2_lft; + p1_rt->Next = p2_rt; + p2_rt->Prev = p1_rt; + outRec1->Pts = p2_rt; + } else { + // x y z a b c + p2_rt->Next = p1_lft; + p1_lft->Prev = p2_rt; + p2_lft->Prev = p1_rt; + p1_rt->Next = p2_lft; + outRec1->Pts = p2_lft; + } + } else { + if (e2->Side == esRight) { + // a b c z y x + ReversePolyPtLinks(p2_lft); + p1_rt->Next = p2_rt; + p2_rt->Prev = p1_rt; + p2_lft->Next = p1_lft; + p1_lft->Prev = p2_lft; + } else { + // a b c x y z + p1_rt->Next = p2_lft; + p2_lft->Prev = p1_rt; + p1_lft->Prev = p2_rt; + p2_rt->Next = p1_lft; + } + } + + outRec1->BottomPt = 0; + if (holeStateRec == outRec2) { + if (outRec2->FirstLeft != outRec1) + outRec1->FirstLeft = outRec2->FirstLeft; + outRec1->IsHole = outRec2->IsHole; + } + outRec2->Pts = 0; + outRec2->BottomPt = 0; + outRec2->FirstLeft = outRec1; + + int OKIdx = e1->OutIdx; + int ObsoleteIdx = e2->OutIdx; + + e1->OutIdx = + Unassigned; // nb: safe because we only get here via AddLocalMaxPoly + e2->OutIdx = Unassigned; + + TEdge *e = m_ActiveEdges; + while (e) { + if (e->OutIdx == ObsoleteIdx) { + e->OutIdx = OKIdx; + e->Side = e1->Side; + break; + } + e = e->NextInAEL; + } + + outRec2->Idx = outRec1->Idx; +} +//------------------------------------------------------------------------------ + +OutPt *Clipper::AddOutPt(TEdge *e, const IntPoint &pt) { + if (e->OutIdx < 0) { + OutRec *outRec = CreateOutRec(); + outRec->IsOpen = (e->WindDelta == 0); + OutPt *newOp = new OutPt; + outRec->Pts = newOp; + newOp->Idx = outRec->Idx; + newOp->Pt = pt; + newOp->Next = newOp; + newOp->Prev = newOp; + if (!outRec->IsOpen) + SetHoleState(e, outRec); + e->OutIdx = outRec->Idx; + return newOp; + } else { + OutRec *outRec = m_PolyOuts[e->OutIdx]; + // OutRec.Pts is the 'Left-most' point & OutRec.Pts.Prev is the 'Right-most' + OutPt *op = outRec->Pts; + + bool ToFront = (e->Side == esLeft); + if (ToFront && (pt == op->Pt)) + return op; + else if (!ToFront && (pt == op->Prev->Pt)) + return op->Prev; + + OutPt *newOp = new OutPt; + newOp->Idx = outRec->Idx; + newOp->Pt = pt; + newOp->Next = op; + newOp->Prev = op->Prev; + newOp->Prev->Next = newOp; + op->Prev = newOp; + if (ToFront) + outRec->Pts = newOp; + return newOp; + } +} +//------------------------------------------------------------------------------ + +OutPt *Clipper::GetLastOutPt(TEdge *e) { + OutRec *outRec = m_PolyOuts[e->OutIdx]; + if (e->Side == esLeft) + return outRec->Pts; + else + return outRec->Pts->Prev; +} +//------------------------------------------------------------------------------ + +void Clipper::ProcessHorizontals() { + TEdge *horzEdge; + while (PopEdgeFromSEL(horzEdge)) + ProcessHorizontal(horzEdge); +} +//------------------------------------------------------------------------------ + +inline bool IsMinima(TEdge *e) { + return e && (e->Prev->NextInLML != e) && (e->Next->NextInLML != e); +} +//------------------------------------------------------------------------------ + +inline bool IsMaxima(TEdge *e, const cInt Y) { + return e && e->Top.Y == Y && !e->NextInLML; +} +//------------------------------------------------------------------------------ + +inline bool IsIntermediate(TEdge *e, const cInt Y) { + return e->Top.Y == Y && e->NextInLML; +} 
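+// nb: the three predicates above classify an edge at scanline Y: IsMinima()
+// is true when the edge is not the NextInLML continuation of either
+// neighbour, IsMaxima() when the edge terminates at Y with no NextInLML,
+// and IsIntermediate() when it reaches Y but continues via NextInLML.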
+//------------------------------------------------------------------------------ + +TEdge *GetMaximaPair(TEdge *e) { + if ((e->Next->Top == e->Top) && !e->Next->NextInLML) + return e->Next; + else if ((e->Prev->Top == e->Top) && !e->Prev->NextInLML) + return e->Prev; + else + return 0; +} +//------------------------------------------------------------------------------ + +TEdge *GetMaximaPairEx(TEdge *e) { + // as GetMaximaPair() but returns 0 if MaxPair isn't in AEL (unless it's + // horizontal) + TEdge *result = GetMaximaPair(e); + if (result && + (result->OutIdx == Skip || + (result->NextInAEL == result->PrevInAEL && !IsHorizontal(*result)))) + return 0; + return result; +} +//------------------------------------------------------------------------------ + +void Clipper::SwapPositionsInSEL(TEdge *Edge1, TEdge *Edge2) { + if (!(Edge1->NextInSEL) && !(Edge1->PrevInSEL)) + return; + if (!(Edge2->NextInSEL) && !(Edge2->PrevInSEL)) + return; + + if (Edge1->NextInSEL == Edge2) { + TEdge *Next = Edge2->NextInSEL; + if (Next) + Next->PrevInSEL = Edge1; + TEdge *Prev = Edge1->PrevInSEL; + if (Prev) + Prev->NextInSEL = Edge2; + Edge2->PrevInSEL = Prev; + Edge2->NextInSEL = Edge1; + Edge1->PrevInSEL = Edge2; + Edge1->NextInSEL = Next; + } else if (Edge2->NextInSEL == Edge1) { + TEdge *Next = Edge1->NextInSEL; + if (Next) + Next->PrevInSEL = Edge2; + TEdge *Prev = Edge2->PrevInSEL; + if (Prev) + Prev->NextInSEL = Edge1; + Edge1->PrevInSEL = Prev; + Edge1->NextInSEL = Edge2; + Edge2->PrevInSEL = Edge1; + Edge2->NextInSEL = Next; + } else { + TEdge *Next = Edge1->NextInSEL; + TEdge *Prev = Edge1->PrevInSEL; + Edge1->NextInSEL = Edge2->NextInSEL; + if (Edge1->NextInSEL) + Edge1->NextInSEL->PrevInSEL = Edge1; + Edge1->PrevInSEL = Edge2->PrevInSEL; + if (Edge1->PrevInSEL) + Edge1->PrevInSEL->NextInSEL = Edge1; + Edge2->NextInSEL = Next; + if (Edge2->NextInSEL) + Edge2->NextInSEL->PrevInSEL = Edge2; + Edge2->PrevInSEL = Prev; + if (Edge2->PrevInSEL) + Edge2->PrevInSEL->NextInSEL = Edge2; + } + + if (!Edge1->PrevInSEL) + m_SortedEdges = Edge1; + else if (!Edge2->PrevInSEL) + m_SortedEdges = Edge2; +} +//------------------------------------------------------------------------------ + +TEdge *GetNextInAEL(TEdge *e, Direction dir) { + return dir == dLeftToRight ? e->NextInAEL : e->PrevInAEL; +} +//------------------------------------------------------------------------------ + +void GetHorzDirection(TEdge &HorzEdge, Direction &Dir, cInt &Left, + cInt &Right) { + if (HorzEdge.Bot.X < HorzEdge.Top.X) { + Left = HorzEdge.Bot.X; + Right = HorzEdge.Top.X; + Dir = dLeftToRight; + } else { + Left = HorzEdge.Top.X; + Right = HorzEdge.Bot.X; + Dir = dRightToLeft; + } +} +//------------------------------------------------------------------------ + +/******************************************************************************* +* Notes: Horizontal edges (HEs) at scanline intersections (ie at the Top or * +* Bottom of a scanbeam) are processed as if layered. The order in which HEs * +* are processed doesn't matter. HEs intersect with other HE Bot.Xs only [#] * +* (or they could intersect with Top.Xs only, ie EITHER Bot.Xs OR Top.Xs), * +* and with other non-horizontal edges [*]. Once these intersections are * +* processed, intermediate HEs then 'promote' the Edge above (NextInLML) into * +* the AEL. These 'promoted' edges may in turn intersect [%] with other HEs. 
* +*******************************************************************************/ + +void Clipper::ProcessHorizontal(TEdge *horzEdge) { + Direction dir; + cInt horzLeft, horzRight; + bool IsOpen = (horzEdge->WindDelta == 0); + + GetHorzDirection(*horzEdge, dir, horzLeft, horzRight); + + TEdge *eLastHorz = horzEdge, *eMaxPair = 0; + while (eLastHorz->NextInLML && IsHorizontal(*eLastHorz->NextInLML)) + eLastHorz = eLastHorz->NextInLML; + if (!eLastHorz->NextInLML) + eMaxPair = GetMaximaPair(eLastHorz); + + MaximaList::const_iterator maxIt; + MaximaList::const_reverse_iterator maxRit; + if (m_Maxima.size() > 0) { + // get the first maxima in range (X) ... + if (dir == dLeftToRight) { + maxIt = m_Maxima.begin(); + while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X) + maxIt++; + if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X) + maxIt = m_Maxima.end(); + } else { + maxRit = m_Maxima.rbegin(); + while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X) + maxRit++; + if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X) + maxRit = m_Maxima.rend(); + } + } + + OutPt *op1 = 0; + + for (;;) // loop through consec. horizontal edges + { + + bool IsLastHorz = (horzEdge == eLastHorz); + TEdge *e = GetNextInAEL(horzEdge, dir); + while (e) { + + // this code block inserts extra coords into horizontal edges (in output + // polygons) whereever maxima touch these horizontal edges. This helps + //'simplifying' polygons (ie if the Simplify property is set). + if (m_Maxima.size() > 0) { + if (dir == dLeftToRight) { + while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) { + if (horzEdge->OutIdx >= 0 && !IsOpen) + AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y)); + maxIt++; + } + } else { + while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) { + if (horzEdge->OutIdx >= 0 && !IsOpen) + AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y)); + maxRit++; + } + } + }; + + if ((dir == dLeftToRight && e->Curr.X > horzRight) || + (dir == dRightToLeft && e->Curr.X < horzLeft)) + break; + + // Also break if we've got to the end of an intermediate horizontal edge + // ... + // nb: Smaller Dx's are to the right of larger Dx's ABOVE the horizontal. + if (e->Curr.X == horzEdge->Top.X && horzEdge->NextInLML && + e->Dx < horzEdge->NextInLML->Dx) + break; + + if (horzEdge->OutIdx >= 0 && !IsOpen) // note: may be done multiple times + { +#ifdef use_xyz + if (dir == dLeftToRight) + SetZ(e->Curr, *horzEdge, *e); + else + SetZ(e->Curr, *e, *horzEdge); +#endif + op1 = AddOutPt(horzEdge, e->Curr); + TEdge *eNextHorz = m_SortedEdges; + while (eNextHorz) { + if (eNextHorz->OutIdx >= 0 && + HorzSegmentsOverlap(horzEdge->Bot.X, horzEdge->Top.X, + eNextHorz->Bot.X, eNextHorz->Top.X)) { + OutPt *op2 = GetLastOutPt(eNextHorz); + AddJoin(op2, op1, eNextHorz->Top); + } + eNextHorz = eNextHorz->NextInSEL; + } + AddGhostJoin(op1, horzEdge->Bot); + } + + // OK, so far we're still in range of the horizontal Edge but make sure + // we're at the last of consec. 
horizontals when matching with eMaxPair + if (e == eMaxPair && IsLastHorz) { + if (horzEdge->OutIdx >= 0) + AddLocalMaxPoly(horzEdge, eMaxPair, horzEdge->Top); + DeleteFromAEL(horzEdge); + DeleteFromAEL(eMaxPair); + return; + } + + if (dir == dLeftToRight) { + IntPoint Pt = IntPoint(e->Curr.X, horzEdge->Curr.Y); + IntersectEdges(horzEdge, e, Pt); + } else { + IntPoint Pt = IntPoint(e->Curr.X, horzEdge->Curr.Y); + IntersectEdges(e, horzEdge, Pt); + } + TEdge *eNext = GetNextInAEL(e, dir); + SwapPositionsInAEL(horzEdge, e); + e = eNext; + } // end while(e) + + // Break out of loop if HorzEdge.NextInLML is not also horizontal ... + if (!horzEdge->NextInLML || !IsHorizontal(*horzEdge->NextInLML)) + break; + + UpdateEdgeIntoAEL(horzEdge); + if (horzEdge->OutIdx >= 0) + AddOutPt(horzEdge, horzEdge->Bot); + GetHorzDirection(*horzEdge, dir, horzLeft, horzRight); + + } // end for (;;) + + if (horzEdge->OutIdx >= 0 && !op1) { + op1 = GetLastOutPt(horzEdge); + TEdge *eNextHorz = m_SortedEdges; + while (eNextHorz) { + if (eNextHorz->OutIdx >= 0 && + HorzSegmentsOverlap(horzEdge->Bot.X, horzEdge->Top.X, + eNextHorz->Bot.X, eNextHorz->Top.X)) { + OutPt *op2 = GetLastOutPt(eNextHorz); + AddJoin(op2, op1, eNextHorz->Top); + } + eNextHorz = eNextHorz->NextInSEL; + } + AddGhostJoin(op1, horzEdge->Top); + } + + if (horzEdge->NextInLML) { + if (horzEdge->OutIdx >= 0) { + op1 = AddOutPt(horzEdge, horzEdge->Top); + UpdateEdgeIntoAEL(horzEdge); + if (horzEdge->WindDelta == 0) + return; + // nb: HorzEdge is no longer horizontal here + TEdge *ePrev = horzEdge->PrevInAEL; + TEdge *eNext = horzEdge->NextInAEL; + if (ePrev && ePrev->Curr.X == horzEdge->Bot.X && + ePrev->Curr.Y == horzEdge->Bot.Y && ePrev->WindDelta != 0 && + (ePrev->OutIdx >= 0 && ePrev->Curr.Y > ePrev->Top.Y && + SlopesEqual(*horzEdge, *ePrev, m_UseFullRange))) { + OutPt *op2 = AddOutPt(ePrev, horzEdge->Bot); + AddJoin(op1, op2, horzEdge->Top); + } else if (eNext && eNext->Curr.X == horzEdge->Bot.X && + eNext->Curr.Y == horzEdge->Bot.Y && eNext->WindDelta != 0 && + eNext->OutIdx >= 0 && eNext->Curr.Y > eNext->Top.Y && + SlopesEqual(*horzEdge, *eNext, m_UseFullRange)) { + OutPt *op2 = AddOutPt(eNext, horzEdge->Bot); + AddJoin(op1, op2, horzEdge->Top); + } + } else + UpdateEdgeIntoAEL(horzEdge); + } else { + if (horzEdge->OutIdx >= 0) + AddOutPt(horzEdge, horzEdge->Top); + DeleteFromAEL(horzEdge); + } +} +//------------------------------------------------------------------------------ + +bool Clipper::ProcessIntersections(const cInt topY) { + if (!m_ActiveEdges) + return true; + try { + BuildIntersectList(topY); + size_t IlSize = m_IntersectList.size(); + if (IlSize == 0) + return true; + if (IlSize == 1 || FixupIntersectionOrder()) + ProcessIntersectList(); + else + return false; + } catch (...) { + m_SortedEdges = 0; + DisposeIntersectNodes(); + throw clipperException("ProcessIntersections error"); + } + m_SortedEdges = 0; + return true; +} +//------------------------------------------------------------------------------ + +void Clipper::DisposeIntersectNodes() { + for (size_t i = 0; i < m_IntersectList.size(); ++i) + delete m_IntersectList[i]; + m_IntersectList.clear(); +} +//------------------------------------------------------------------------------ + +void Clipper::BuildIntersectList(const cInt topY) { + if (!m_ActiveEdges) + return; + + // prepare for sorting ... 
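+  // (the AEL is copied into the SEL and each edge's Curr.X is set to its X at
+  //  the top of the scanbeam; the bubble sort that follows records an
+  //  IntersectNode for every pair of edges that must swap order)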
+ TEdge *e = m_ActiveEdges; + m_SortedEdges = e; + while (e) { + e->PrevInSEL = e->PrevInAEL; + e->NextInSEL = e->NextInAEL; + e->Curr.X = TopX(*e, topY); + e = e->NextInAEL; + } + + // bubblesort ... + bool isModified; + do { + isModified = false; + e = m_SortedEdges; + while (e->NextInSEL) { + TEdge *eNext = e->NextInSEL; + IntPoint Pt; + if (e->Curr.X > eNext->Curr.X) { + IntersectPoint(*e, *eNext, Pt); + if (Pt.Y < topY) + Pt = IntPoint(TopX(*e, topY), topY); + IntersectNode *newNode = new IntersectNode; + newNode->Edge1 = e; + newNode->Edge2 = eNext; + newNode->Pt = Pt; + m_IntersectList.push_back(newNode); + + SwapPositionsInSEL(e, eNext); + isModified = true; + } else + e = eNext; + } + if (e->PrevInSEL) + e->PrevInSEL->NextInSEL = 0; + else + break; + } while (isModified); + m_SortedEdges = 0; // important +} +//------------------------------------------------------------------------------ + +void Clipper::ProcessIntersectList() { + for (size_t i = 0; i < m_IntersectList.size(); ++i) { + IntersectNode *iNode = m_IntersectList[i]; + { + IntersectEdges(iNode->Edge1, iNode->Edge2, iNode->Pt); + SwapPositionsInAEL(iNode->Edge1, iNode->Edge2); + } + delete iNode; + } + m_IntersectList.clear(); +} +//------------------------------------------------------------------------------ + +bool IntersectListSort(IntersectNode *node1, IntersectNode *node2) { + return node2->Pt.Y < node1->Pt.Y; +} +//------------------------------------------------------------------------------ + +inline bool EdgesAdjacent(const IntersectNode &inode) { + return (inode.Edge1->NextInSEL == inode.Edge2) || + (inode.Edge1->PrevInSEL == inode.Edge2); +} +//------------------------------------------------------------------------------ + +bool Clipper::FixupIntersectionOrder() { + // pre-condition: intersections are sorted Bottom-most first. + // Now it's crucial that intersections are made only between adjacent edges, + // so to ensure this the order of intersections may need adjusting ... 
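+  // (when the node at position i joins edges that are not adjacent in the
+  //  SEL, a later node whose edges are adjacent is swapped into its place
+  //  and processed first; if no such node exists the fixup fails)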
+ CopyAELToSEL(); + std::sort(m_IntersectList.begin(), m_IntersectList.end(), IntersectListSort); + size_t cnt = m_IntersectList.size(); + for (size_t i = 0; i < cnt; ++i) { + if (!EdgesAdjacent(*m_IntersectList[i])) { + size_t j = i + 1; + while (j < cnt && !EdgesAdjacent(*m_IntersectList[j])) + j++; + if (j == cnt) + return false; + std::swap(m_IntersectList[i], m_IntersectList[j]); + } + SwapPositionsInSEL(m_IntersectList[i]->Edge1, m_IntersectList[i]->Edge2); + } + return true; +} +//------------------------------------------------------------------------------ + +void Clipper::DoMaxima(TEdge *e) { + TEdge *eMaxPair = GetMaximaPairEx(e); + if (!eMaxPair) { + if (e->OutIdx >= 0) + AddOutPt(e, e->Top); + DeleteFromAEL(e); + return; + } + + TEdge *eNext = e->NextInAEL; + while (eNext && eNext != eMaxPair) { + IntersectEdges(e, eNext, e->Top); + SwapPositionsInAEL(e, eNext); + eNext = e->NextInAEL; + } + + if (e->OutIdx == Unassigned && eMaxPair->OutIdx == Unassigned) { + DeleteFromAEL(e); + DeleteFromAEL(eMaxPair); + } else if (e->OutIdx >= 0 && eMaxPair->OutIdx >= 0) { + if (e->OutIdx >= 0) + AddLocalMaxPoly(e, eMaxPair, e->Top); + DeleteFromAEL(e); + DeleteFromAEL(eMaxPair); + } +#ifdef use_lines + else if (e->WindDelta == 0) { + if (e->OutIdx >= 0) { + AddOutPt(e, e->Top); + e->OutIdx = Unassigned; + } + DeleteFromAEL(e); + + if (eMaxPair->OutIdx >= 0) { + AddOutPt(eMaxPair, e->Top); + eMaxPair->OutIdx = Unassigned; + } + DeleteFromAEL(eMaxPair); + } +#endif + else + throw clipperException("DoMaxima error"); +} +//------------------------------------------------------------------------------ + +void Clipper::ProcessEdgesAtTopOfScanbeam(const cInt topY) { + TEdge *e = m_ActiveEdges; + while (e) { + // 1. process maxima, treating them as if they're 'bent' horizontal edges, + // but exclude maxima with horizontal edges. nb: e can't be a horizontal. + bool IsMaximaEdge = IsMaxima(e, topY); + + if (IsMaximaEdge) { + TEdge *eMaxPair = GetMaximaPairEx(e); + IsMaximaEdge = (!eMaxPair || !IsHorizontal(*eMaxPair)); + } + + if (IsMaximaEdge) { + if (m_StrictSimple) + m_Maxima.push_back(e->Top.X); + TEdge *ePrev = e->PrevInAEL; + DoMaxima(e); + if (!ePrev) + e = m_ActiveEdges; + else + e = ePrev->NextInAEL; + } else { + // 2. promote horizontal edges, otherwise update Curr.X and Curr.Y ... + if (IsIntermediate(e, topY) && IsHorizontal(*e->NextInLML)) { + UpdateEdgeIntoAEL(e); + if (e->OutIdx >= 0) + AddOutPt(e, e->Bot); + AddEdgeToSEL(e); + } else { + e->Curr.X = TopX(*e, topY); + e->Curr.Y = topY; +#ifdef use_xyz + e->Curr.Z = + topY == e->Top.Y ? e->Top.Z : (topY == e->Bot.Y ? e->Bot.Z : 0); +#endif + } + + // When StrictlySimple and 'e' is being touched by another edge, then + // make sure both edges have a vertex here ... + if (m_StrictSimple) { + TEdge *ePrev = e->PrevInAEL; + if ((e->OutIdx >= 0) && (e->WindDelta != 0) && ePrev && + (ePrev->OutIdx >= 0) && (ePrev->Curr.X == e->Curr.X) && + (ePrev->WindDelta != 0)) { + IntPoint pt = e->Curr; +#ifdef use_xyz + SetZ(pt, *ePrev, *e); +#endif + OutPt *op = AddOutPt(ePrev, pt); + OutPt *op2 = AddOutPt(e, pt); + AddJoin(op, op2, pt); // StrictlySimple (type-3) join + } + } + + e = e->NextInAEL; + } + } + + // 3. Process horizontals at the Top of the scanbeam ... + m_Maxima.sort(); + ProcessHorizontals(); + m_Maxima.clear(); + + // 4. Promote intermediate vertices ... 
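+  // (each intermediate edge is replaced by its NextInLML; if a neighbouring
+  //  AEL edge touches the promoted vertex with an equal slope, a join is
+  //  queued for JoinCommonEdges)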
+ e = m_ActiveEdges; + while (e) { + if (IsIntermediate(e, topY)) { + OutPt *op = 0; + if (e->OutIdx >= 0) + op = AddOutPt(e, e->Top); + UpdateEdgeIntoAEL(e); + + // if output polygons share an edge, they'll need joining later ... + TEdge *ePrev = e->PrevInAEL; + TEdge *eNext = e->NextInAEL; + if (ePrev && ePrev->Curr.X == e->Bot.X && ePrev->Curr.Y == e->Bot.Y && + op && ePrev->OutIdx >= 0 && ePrev->Curr.Y > ePrev->Top.Y && + SlopesEqual(e->Curr, e->Top, ePrev->Curr, ePrev->Top, + m_UseFullRange) && + (e->WindDelta != 0) && (ePrev->WindDelta != 0)) { + OutPt *op2 = AddOutPt(ePrev, e->Bot); + AddJoin(op, op2, e->Top); + } else if (eNext && eNext->Curr.X == e->Bot.X && + eNext->Curr.Y == e->Bot.Y && op && eNext->OutIdx >= 0 && + eNext->Curr.Y > eNext->Top.Y && + SlopesEqual(e->Curr, e->Top, eNext->Curr, eNext->Top, + m_UseFullRange) && + (e->WindDelta != 0) && (eNext->WindDelta != 0)) { + OutPt *op2 = AddOutPt(eNext, e->Bot); + AddJoin(op, op2, e->Top); + } + } + e = e->NextInAEL; + } +} +//------------------------------------------------------------------------------ + +void Clipper::FixupOutPolyline(OutRec &outrec) { + OutPt *pp = outrec.Pts; + OutPt *lastPP = pp->Prev; + while (pp != lastPP) { + pp = pp->Next; + if (pp->Pt == pp->Prev->Pt) { + if (pp == lastPP) + lastPP = pp->Prev; + OutPt *tmpPP = pp->Prev; + tmpPP->Next = pp->Next; + pp->Next->Prev = tmpPP; + delete pp; + pp = tmpPP; + } + } + + if (pp == pp->Prev) { + DisposeOutPts(pp); + outrec.Pts = 0; + return; + } +} +//------------------------------------------------------------------------------ + +void Clipper::FixupOutPolygon(OutRec &outrec) { + // FixupOutPolygon() - removes duplicate points and simplifies consecutive + // parallel edges by removing the middle vertex. + OutPt *lastOK = 0; + outrec.BottomPt = 0; + OutPt *pp = outrec.Pts; + bool preserveCol = m_PreserveCollinear || m_StrictSimple; + + for (;;) { + if (pp->Prev == pp || pp->Prev == pp->Next) { + DisposeOutPts(pp); + outrec.Pts = 0; + return; + } + + // test for duplicate points and collinear edges ... + if ((pp->Pt == pp->Next->Pt) || (pp->Pt == pp->Prev->Pt) || + (SlopesEqual(pp->Prev->Pt, pp->Pt, pp->Next->Pt, m_UseFullRange) && + (!preserveCol || + !Pt2IsBetweenPt1AndPt3(pp->Prev->Pt, pp->Pt, pp->Next->Pt)))) { + lastOK = 0; + OutPt *tmp = pp; + pp->Prev->Next = pp->Next; + pp->Next->Prev = pp->Prev; + pp = pp->Prev; + delete tmp; + } else if (pp == lastOK) + break; + else { + if (!lastOK) + lastOK = pp; + pp = pp->Next; + } + } + outrec.Pts = pp; +} +//------------------------------------------------------------------------------ + +int PointCount(OutPt *Pts) { + if (!Pts) + return 0; + int result = 0; + OutPt *p = Pts; + do { + result++; + p = p->Next; + } while (p != Pts); + return result; +} +//------------------------------------------------------------------------------ + +void Clipper::BuildResult(Paths &polys) { + polys.reserve(m_PolyOuts.size()); + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) { + if (!m_PolyOuts[i]->Pts) + continue; + Path pg; + OutPt *p = m_PolyOuts[i]->Pts->Prev; + int cnt = PointCount(p); + if (cnt < 2) + continue; + pg.reserve(cnt); + for (int i = 0; i < cnt; ++i) { + pg.push_back(p->Pt); + p = p->Prev; + } + polys.push_back(pg); + } +} +//------------------------------------------------------------------------------ + +void Clipper::BuildResult2(PolyTree &polytree) { + polytree.Clear(); + polytree.AllNodes.reserve(m_PolyOuts.size()); + // add each output polygon/contour to polytree ... 
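+  // (every OutRec with enough points gets a PolyNode holding its contour;
+  //  parent/child links are assigned in the second pass below from each
+  //  OutRec's FirstLeft)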
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); i++) { + OutRec *outRec = m_PolyOuts[i]; + int cnt = PointCount(outRec->Pts); + if ((outRec->IsOpen && cnt < 2) || (!outRec->IsOpen && cnt < 3)) + continue; + FixHoleLinkage(*outRec); + PolyNode *pn = new PolyNode(); + // nb: polytree takes ownership of all the PolyNodes + polytree.AllNodes.push_back(pn); + outRec->PolyNd = pn; + pn->Parent = 0; + pn->Index = 0; + pn->Contour.reserve(cnt); + OutPt *op = outRec->Pts->Prev; + for (int j = 0; j < cnt; j++) { + pn->Contour.push_back(op->Pt); + op = op->Prev; + } + } + + // fixup PolyNode links etc ... + polytree.Childs.reserve(m_PolyOuts.size()); + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); i++) { + OutRec *outRec = m_PolyOuts[i]; + if (!outRec->PolyNd) + continue; + if (outRec->IsOpen) { + outRec->PolyNd->m_IsOpen = true; + polytree.AddChild(*outRec->PolyNd); + } else if (outRec->FirstLeft && outRec->FirstLeft->PolyNd) + outRec->FirstLeft->PolyNd->AddChild(*outRec->PolyNd); + else + polytree.AddChild(*outRec->PolyNd); + } +} +//------------------------------------------------------------------------------ + +void SwapIntersectNodes(IntersectNode &int1, IntersectNode &int2) { + // just swap the contents (because fIntersectNodes is a single-linked-list) + IntersectNode inode = int1; // gets a copy of Int1 + int1.Edge1 = int2.Edge1; + int1.Edge2 = int2.Edge2; + int1.Pt = int2.Pt; + int2.Edge1 = inode.Edge1; + int2.Edge2 = inode.Edge2; + int2.Pt = inode.Pt; +} +//------------------------------------------------------------------------------ + +inline bool E2InsertsBeforeE1(TEdge &e1, TEdge &e2) { + if (e2.Curr.X == e1.Curr.X) { + if (e2.Top.Y > e1.Top.Y) + return e2.Top.X < TopX(e1, e2.Top.Y); + else + return e1.Top.X > TopX(e2, e1.Top.Y); + } else + return e2.Curr.X < e1.Curr.X; +} +//------------------------------------------------------------------------------ + +bool GetOverlap(const cInt a1, const cInt a2, const cInt b1, const cInt b2, + cInt &Left, cInt &Right) { + if (a1 < a2) { + if (b1 < b2) { + Left = std::max(a1, b1); + Right = std::min(a2, b2); + } else { + Left = std::max(a1, b2); + Right = std::min(a2, b1); + } + } else { + if (b1 < b2) { + Left = std::max(a2, b1); + Right = std::min(a1, b2); + } else { + Left = std::max(a2, b2); + Right = std::min(a1, b1); + } + } + return Left < Right; +} +//------------------------------------------------------------------------------ + +inline void UpdateOutPtIdxs(OutRec &outrec) { + OutPt *op = outrec.Pts; + do { + op->Idx = outrec.Idx; + op = op->Prev; + } while (op != outrec.Pts); +} +//------------------------------------------------------------------------------ + +void Clipper::InsertEdgeIntoAEL(TEdge *edge, TEdge *startEdge) { + if (!m_ActiveEdges) { + edge->PrevInAEL = 0; + edge->NextInAEL = 0; + m_ActiveEdges = edge; + } else if (!startEdge && E2InsertsBeforeE1(*m_ActiveEdges, *edge)) { + edge->PrevInAEL = 0; + edge->NextInAEL = m_ActiveEdges; + m_ActiveEdges->PrevInAEL = edge; + m_ActiveEdges = edge; + } else { + if (!startEdge) + startEdge = m_ActiveEdges; + while (startEdge->NextInAEL && + !E2InsertsBeforeE1(*startEdge->NextInAEL, *edge)) + startEdge = startEdge->NextInAEL; + edge->NextInAEL = startEdge->NextInAEL; + if (startEdge->NextInAEL) + startEdge->NextInAEL->PrevInAEL = edge; + edge->PrevInAEL = startEdge; + startEdge->NextInAEL = edge; + } +} +//---------------------------------------------------------------------- + +OutPt *DupOutPt(OutPt *outPt, bool InsertAfter) { + OutPt *result = new OutPt; + 
result->Pt = outPt->Pt; + result->Idx = outPt->Idx; + if (InsertAfter) { + result->Next = outPt->Next; + result->Prev = outPt; + outPt->Next->Prev = result; + outPt->Next = result; + } else { + result->Prev = outPt->Prev; + result->Next = outPt; + outPt->Prev->Next = result; + outPt->Prev = result; + } + return result; +} +//------------------------------------------------------------------------------ + +bool JoinHorz(OutPt *op1, OutPt *op1b, OutPt *op2, OutPt *op2b, + const IntPoint Pt, bool DiscardLeft) { + Direction Dir1 = (op1->Pt.X > op1b->Pt.X ? dRightToLeft : dLeftToRight); + Direction Dir2 = (op2->Pt.X > op2b->Pt.X ? dRightToLeft : dLeftToRight); + if (Dir1 == Dir2) + return false; + + // When DiscardLeft, we want Op1b to be on the Left of Op1, otherwise we + // want Op1b to be on the Right. (And likewise with Op2 and Op2b.) + // So, to facilitate this while inserting Op1b and Op2b ... + // when DiscardLeft, make sure we're AT or RIGHT of Pt before adding Op1b, + // otherwise make sure we're AT or LEFT of Pt. (Likewise with Op2b.) + if (Dir1 == dLeftToRight) { + while (op1->Next->Pt.X <= Pt.X && op1->Next->Pt.X >= op1->Pt.X && + op1->Next->Pt.Y == Pt.Y) + op1 = op1->Next; + if (DiscardLeft && (op1->Pt.X != Pt.X)) + op1 = op1->Next; + op1b = DupOutPt(op1, !DiscardLeft); + if (op1b->Pt != Pt) { + op1 = op1b; + op1->Pt = Pt; + op1b = DupOutPt(op1, !DiscardLeft); + } + } else { + while (op1->Next->Pt.X >= Pt.X && op1->Next->Pt.X <= op1->Pt.X && + op1->Next->Pt.Y == Pt.Y) + op1 = op1->Next; + if (!DiscardLeft && (op1->Pt.X != Pt.X)) + op1 = op1->Next; + op1b = DupOutPt(op1, DiscardLeft); + if (op1b->Pt != Pt) { + op1 = op1b; + op1->Pt = Pt; + op1b = DupOutPt(op1, DiscardLeft); + } + } + + if (Dir2 == dLeftToRight) { + while (op2->Next->Pt.X <= Pt.X && op2->Next->Pt.X >= op2->Pt.X && + op2->Next->Pt.Y == Pt.Y) + op2 = op2->Next; + if (DiscardLeft && (op2->Pt.X != Pt.X)) + op2 = op2->Next; + op2b = DupOutPt(op2, !DiscardLeft); + if (op2b->Pt != Pt) { + op2 = op2b; + op2->Pt = Pt; + op2b = DupOutPt(op2, !DiscardLeft); + }; + } else { + while (op2->Next->Pt.X >= Pt.X && op2->Next->Pt.X <= op2->Pt.X && + op2->Next->Pt.Y == Pt.Y) + op2 = op2->Next; + if (!DiscardLeft && (op2->Pt.X != Pt.X)) + op2 = op2->Next; + op2b = DupOutPt(op2, DiscardLeft); + if (op2b->Pt != Pt) { + op2 = op2b; + op2->Pt = Pt; + op2b = DupOutPt(op2, DiscardLeft); + }; + }; + + if ((Dir1 == dLeftToRight) == DiscardLeft) { + op1->Prev = op2; + op2->Next = op1; + op1b->Next = op2b; + op2b->Prev = op1b; + } else { + op1->Next = op2; + op2->Prev = op1; + op1b->Prev = op2b; + op2b->Next = op1b; + } + return true; +} +//------------------------------------------------------------------------------ + +bool Clipper::JoinPoints(Join *j, OutRec *outRec1, OutRec *outRec2) { + OutPt *op1 = j->OutPt1, *op1b; + OutPt *op2 = j->OutPt2, *op2b; + + // There are 3 kinds of joins for output polygons ... + // 1. Horizontal joins where Join.OutPt1 & Join.OutPt2 are vertices anywhere + // along (horizontal) collinear edges (& Join.OffPt is on the same + // horizontal). + // 2. Non-horizontal joins where Join.OutPt1 & Join.OutPt2 are at the same + // location at the Bottom of the overlapping segment (& Join.OffPt is above). + // 3. StrictSimple joins where edges touch but are not collinear and where + // Join.OutPt1, Join.OutPt2 & Join.OffPt all share the same point. + bool isHorizontal = (j->OutPt1->Pt.Y == j->OffPt.Y); + + if (isHorizontal && (j->OffPt == j->OutPt1->Pt) && + (j->OffPt == j->OutPt2->Pt)) { + // Strictly Simple join ... 
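+    // (OutPt1, OutPt2 and OffPt all coincide here; op1b/op2b step past any
+    //  duplicate points, and the join is only valid when the two rings leave
+    //  the shared point in opposite vertical directions, i.e. reverse1 !=
+    //  reverse2)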
+ if (outRec1 != outRec2) + return false; + op1b = j->OutPt1->Next; + while (op1b != op1 && (op1b->Pt == j->OffPt)) + op1b = op1b->Next; + bool reverse1 = (op1b->Pt.Y > j->OffPt.Y); + op2b = j->OutPt2->Next; + while (op2b != op2 && (op2b->Pt == j->OffPt)) + op2b = op2b->Next; + bool reverse2 = (op2b->Pt.Y > j->OffPt.Y); + if (reverse1 == reverse2) + return false; + if (reverse1) { + op1b = DupOutPt(op1, false); + op2b = DupOutPt(op2, true); + op1->Prev = op2; + op2->Next = op1; + op1b->Next = op2b; + op2b->Prev = op1b; + j->OutPt1 = op1; + j->OutPt2 = op1b; + return true; + } else { + op1b = DupOutPt(op1, true); + op2b = DupOutPt(op2, false); + op1->Next = op2; + op2->Prev = op1; + op1b->Prev = op2b; + op2b->Next = op1b; + j->OutPt1 = op1; + j->OutPt2 = op1b; + return true; + } + } else if (isHorizontal) { + // treat horizontal joins differently to non-horizontal joins since with + // them we're not yet sure where the overlapping is. OutPt1.Pt & OutPt2.Pt + // may be anywhere along the horizontal edge. + op1b = op1; + while (op1->Prev->Pt.Y == op1->Pt.Y && op1->Prev != op1b && + op1->Prev != op2) + op1 = op1->Prev; + while (op1b->Next->Pt.Y == op1b->Pt.Y && op1b->Next != op1 && + op1b->Next != op2) + op1b = op1b->Next; + if (op1b->Next == op1 || op1b->Next == op2) + return false; // a flat 'polygon' + + op2b = op2; + while (op2->Prev->Pt.Y == op2->Pt.Y && op2->Prev != op2b && + op2->Prev != op1b) + op2 = op2->Prev; + while (op2b->Next->Pt.Y == op2b->Pt.Y && op2b->Next != op2 && + op2b->Next != op1) + op2b = op2b->Next; + if (op2b->Next == op2 || op2b->Next == op1) + return false; // a flat 'polygon' + + cInt Left, Right; + // Op1 --> Op1b & Op2 --> Op2b are the extremites of the horizontal edges + if (!GetOverlap(op1->Pt.X, op1b->Pt.X, op2->Pt.X, op2b->Pt.X, Left, Right)) + return false; + + // DiscardLeftSide: when overlapping edges are joined, a spike will created + // which needs to be cleaned up. However, we don't want Op1 or Op2 caught up + // on the discard Side as either may still be needed for other joins ... + IntPoint Pt; + bool DiscardLeftSide; + if (op1->Pt.X >= Left && op1->Pt.X <= Right) { + Pt = op1->Pt; + DiscardLeftSide = (op1->Pt.X > op1b->Pt.X); + } else if (op2->Pt.X >= Left && op2->Pt.X <= Right) { + Pt = op2->Pt; + DiscardLeftSide = (op2->Pt.X > op2b->Pt.X); + } else if (op1b->Pt.X >= Left && op1b->Pt.X <= Right) { + Pt = op1b->Pt; + DiscardLeftSide = op1b->Pt.X > op1->Pt.X; + } else { + Pt = op2b->Pt; + DiscardLeftSide = (op2b->Pt.X > op2->Pt.X); + } + j->OutPt1 = op1; + j->OutPt2 = op2; + return JoinHorz(op1, op1b, op2, op2b, Pt, DiscardLeftSide); + } else { + // nb: For non-horizontal joins ... + // 1. Jr.OutPt1.Pt.Y == Jr.OutPt2.Pt.Y + // 2. Jr.OutPt1.Pt > Jr.OffPt.Y + + // make sure the polygons are correctly oriented ... 
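+    // (walk from op1/op2 to the first distinct vertex in the Next direction;
+    //  if that vertex's Y exceeds op1's/op2's own Y or its slope to OffPt
+    //  doesn't match, the Prev direction is tried instead)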
+ op1b = op1->Next; + while ((op1b->Pt == op1->Pt) && (op1b != op1)) + op1b = op1b->Next; + bool Reverse1 = ((op1b->Pt.Y > op1->Pt.Y) || + !SlopesEqual(op1->Pt, op1b->Pt, j->OffPt, m_UseFullRange)); + if (Reverse1) { + op1b = op1->Prev; + while ((op1b->Pt == op1->Pt) && (op1b != op1)) + op1b = op1b->Prev; + if ((op1b->Pt.Y > op1->Pt.Y) || + !SlopesEqual(op1->Pt, op1b->Pt, j->OffPt, m_UseFullRange)) + return false; + }; + op2b = op2->Next; + while ((op2b->Pt == op2->Pt) && (op2b != op2)) + op2b = op2b->Next; + bool Reverse2 = ((op2b->Pt.Y > op2->Pt.Y) || + !SlopesEqual(op2->Pt, op2b->Pt, j->OffPt, m_UseFullRange)); + if (Reverse2) { + op2b = op2->Prev; + while ((op2b->Pt == op2->Pt) && (op2b != op2)) + op2b = op2b->Prev; + if ((op2b->Pt.Y > op2->Pt.Y) || + !SlopesEqual(op2->Pt, op2b->Pt, j->OffPt, m_UseFullRange)) + return false; + } + + if ((op1b == op1) || (op2b == op2) || (op1b == op2b) || + ((outRec1 == outRec2) && (Reverse1 == Reverse2))) + return false; + + if (Reverse1) { + op1b = DupOutPt(op1, false); + op2b = DupOutPt(op2, true); + op1->Prev = op2; + op2->Next = op1; + op1b->Next = op2b; + op2b->Prev = op1b; + j->OutPt1 = op1; + j->OutPt2 = op1b; + return true; + } else { + op1b = DupOutPt(op1, true); + op2b = DupOutPt(op2, false); + op1->Next = op2; + op2->Prev = op1; + op1b->Prev = op2b; + op2b->Next = op1b; + j->OutPt1 = op1; + j->OutPt2 = op1b; + return true; + } + } +} +//---------------------------------------------------------------------- + +static OutRec *ParseFirstLeft(OutRec *FirstLeft) { + while (FirstLeft && !FirstLeft->Pts) + FirstLeft = FirstLeft->FirstLeft; + return FirstLeft; +} +//------------------------------------------------------------------------------ + +void Clipper::FixupFirstLefts1(OutRec *OldOutRec, OutRec *NewOutRec) { + // tests if NewOutRec contains the polygon before reassigning FirstLeft + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) { + OutRec *outRec = m_PolyOuts[i]; + OutRec *firstLeft = ParseFirstLeft(outRec->FirstLeft); + if (outRec->Pts && firstLeft == OldOutRec) { + if (Poly2ContainsPoly1(outRec->Pts, NewOutRec->Pts)) + outRec->FirstLeft = NewOutRec; + } + } +} +//---------------------------------------------------------------------- + +void Clipper::FixupFirstLefts2(OutRec *InnerOutRec, OutRec *OuterOutRec) { + // A polygon has split into two such that one is now the inner of the other. + // It's possible that these polygons now wrap around other polygons, so check + // every polygon that's also contained by OuterOutRec's FirstLeft container + //(including 0) to see if they've become inner to the new inner polygon ... 
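+  // (the loop below only re-tests OutRecs whose resolved FirstLeft is
+  //  OuterOutRec's own container orfl, InnerOutRec, or OuterOutRec)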
+ OutRec *orfl = OuterOutRec->FirstLeft; + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) { + OutRec *outRec = m_PolyOuts[i]; + + if (!outRec->Pts || outRec == OuterOutRec || outRec == InnerOutRec) + continue; + OutRec *firstLeft = ParseFirstLeft(outRec->FirstLeft); + if (firstLeft != orfl && firstLeft != InnerOutRec && + firstLeft != OuterOutRec) + continue; + if (Poly2ContainsPoly1(outRec->Pts, InnerOutRec->Pts)) + outRec->FirstLeft = InnerOutRec; + else if (Poly2ContainsPoly1(outRec->Pts, OuterOutRec->Pts)) + outRec->FirstLeft = OuterOutRec; + else if (outRec->FirstLeft == InnerOutRec || + outRec->FirstLeft == OuterOutRec) + outRec->FirstLeft = orfl; + } +} +//---------------------------------------------------------------------- +void Clipper::FixupFirstLefts3(OutRec *OldOutRec, OutRec *NewOutRec) { + // reassigns FirstLeft WITHOUT testing if NewOutRec contains the polygon + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) { + OutRec *outRec = m_PolyOuts[i]; + OutRec *firstLeft = ParseFirstLeft(outRec->FirstLeft); + if (outRec->Pts && firstLeft == OldOutRec) + outRec->FirstLeft = NewOutRec; + } +} +//---------------------------------------------------------------------- + +void Clipper::JoinCommonEdges() { + for (JoinList::size_type i = 0; i < m_Joins.size(); i++) { + Join *join = m_Joins[i]; + + OutRec *outRec1 = GetOutRec(join->OutPt1->Idx); + OutRec *outRec2 = GetOutRec(join->OutPt2->Idx); + + if (!outRec1->Pts || !outRec2->Pts) + continue; + if (outRec1->IsOpen || outRec2->IsOpen) + continue; + + // get the polygon fragment with the correct hole state (FirstLeft) + // before calling JoinPoints() ... + OutRec *holeStateRec; + if (outRec1 == outRec2) + holeStateRec = outRec1; + else if (OutRec1RightOfOutRec2(outRec1, outRec2)) + holeStateRec = outRec2; + else if (OutRec1RightOfOutRec2(outRec2, outRec1)) + holeStateRec = outRec1; + else + holeStateRec = GetLowermostRec(outRec1, outRec2); + + if (!JoinPoints(join, outRec1, outRec2)) + continue; + + if (outRec1 == outRec2) { + // instead of joining two polygons, we've just created a new one by + // splitting one polygon into two. + outRec1->Pts = join->OutPt1; + outRec1->BottomPt = 0; + outRec2 = CreateOutRec(); + outRec2->Pts = join->OutPt2; + + // update all OutRec2.Pts Idx's ... + UpdateOutPtIdxs(*outRec2); + + if (Poly2ContainsPoly1(outRec2->Pts, outRec1->Pts)) { + // outRec1 contains outRec2 ... + outRec2->IsHole = !outRec1->IsHole; + outRec2->FirstLeft = outRec1; + + if (m_UsingPolyTree) + FixupFirstLefts2(outRec2, outRec1); + + if ((outRec2->IsHole ^ m_ReverseOutput) == (Area(*outRec2) > 0)) + ReversePolyPtLinks(outRec2->Pts); + + } else if (Poly2ContainsPoly1(outRec1->Pts, outRec2->Pts)) { + // outRec2 contains outRec1 ... + outRec2->IsHole = outRec1->IsHole; + outRec1->IsHole = !outRec2->IsHole; + outRec2->FirstLeft = outRec1->FirstLeft; + outRec1->FirstLeft = outRec2; + + if (m_UsingPolyTree) + FixupFirstLefts2(outRec1, outRec2); + + if ((outRec1->IsHole ^ m_ReverseOutput) == (Area(*outRec1) > 0)) + ReversePolyPtLinks(outRec1->Pts); + } else { + // the 2 polygons are completely separate ... + outRec2->IsHole = outRec1->IsHole; + outRec2->FirstLeft = outRec1->FirstLeft; + + // fixup FirstLeft pointers that may need reassigning to OutRec2 + if (m_UsingPolyTree) + FixupFirstLefts1(outRec1, outRec2); + } + + } else { + // joined 2 polygons together ... 
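+      // (outRec2's points now belong to outRec1; outRec2 is left as an empty
+      //  record whose Idx and FirstLeft redirect to outRec1)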
+ + outRec2->Pts = 0; + outRec2->BottomPt = 0; + outRec2->Idx = outRec1->Idx; + + outRec1->IsHole = holeStateRec->IsHole; + if (holeStateRec == outRec2) + outRec1->FirstLeft = outRec2->FirstLeft; + outRec2->FirstLeft = outRec1; + + if (m_UsingPolyTree) + FixupFirstLefts3(outRec2, outRec1); + } + } +} + +//------------------------------------------------------------------------------ +// ClipperOffset support functions ... +//------------------------------------------------------------------------------ + +DoublePoint GetUnitNormal(const IntPoint &pt1, const IntPoint &pt2) { + if (pt2.X == pt1.X && pt2.Y == pt1.Y) + return DoublePoint(0, 0); + + double Dx = (double)(pt2.X - pt1.X); + double dy = (double)(pt2.Y - pt1.Y); + double f = 1 * 1.0 / std::sqrt(Dx * Dx + dy * dy); + Dx *= f; + dy *= f; + return DoublePoint(dy, -Dx); +} + +//------------------------------------------------------------------------------ +// ClipperOffset class +//------------------------------------------------------------------------------ + +ClipperOffset::ClipperOffset(double miterLimit, double arcTolerance) { + this->MiterLimit = miterLimit; + this->ArcTolerance = arcTolerance; + m_lowest.X = -1; +} +//------------------------------------------------------------------------------ + +ClipperOffset::~ClipperOffset() { Clear(); } +//------------------------------------------------------------------------------ + +void ClipperOffset::Clear() { + for (int i = 0; i < m_polyNodes.ChildCount(); ++i) + delete m_polyNodes.Childs[i]; + m_polyNodes.Childs.clear(); + m_lowest.X = -1; +} +//------------------------------------------------------------------------------ + +void ClipperOffset::AddPath(const Path &path, JoinType joinType, + EndType endType) { + int highI = (int)path.size() - 1; + if (highI < 0) + return; + PolyNode *newNode = new PolyNode(); + newNode->m_jointype = joinType; + newNode->m_endtype = endType; + + // strip duplicate points from path and also get index to the lowest point ... 
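+  // (for closed end types, trailing points equal to path[0] are dropped
+  //  first; in the loop below j indexes the last point kept and k the
+  //  'lowest' point, i.e. the vertex with the largest Y, ties broken by the
+  //  smaller X)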
+ if (endType == etClosedLine || endType == etClosedPolygon) + while (highI > 0 && path[0] == path[highI]) + highI--; + newNode->Contour.reserve(highI + 1); + newNode->Contour.push_back(path[0]); + int j = 0, k = 0; + for (int i = 1; i <= highI; i++) + if (newNode->Contour[j] != path[i]) { + j++; + newNode->Contour.push_back(path[i]); + if (path[i].Y > newNode->Contour[k].Y || + (path[i].Y == newNode->Contour[k].Y && + path[i].X < newNode->Contour[k].X)) + k = j; + } + if (endType == etClosedPolygon && j < 2) { + delete newNode; + return; + } + m_polyNodes.AddChild(*newNode); + + // if this path's lowest pt is lower than all the others then update m_lowest + if (endType != etClosedPolygon) + return; + if (m_lowest.X < 0) + m_lowest = IntPoint(m_polyNodes.ChildCount() - 1, k); + else { + IntPoint ip = m_polyNodes.Childs[(int)m_lowest.X]->Contour[(int)m_lowest.Y]; + if (newNode->Contour[k].Y > ip.Y || + (newNode->Contour[k].Y == ip.Y && newNode->Contour[k].X < ip.X)) + m_lowest = IntPoint(m_polyNodes.ChildCount() - 1, k); + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::AddPaths(const Paths &paths, JoinType joinType, + EndType endType) { + for (Paths::size_type i = 0; i < paths.size(); ++i) + AddPath(paths[i], joinType, endType); +} +//------------------------------------------------------------------------------ + +void ClipperOffset::FixOrientations() { + // fixup orientations of all closed paths if the orientation of the + // closed path with the lowermost vertex is wrong ... + if (m_lowest.X >= 0 && + !Orientation(m_polyNodes.Childs[(int)m_lowest.X]->Contour)) { + for (int i = 0; i < m_polyNodes.ChildCount(); ++i) { + PolyNode &node = *m_polyNodes.Childs[i]; + if (node.m_endtype == etClosedPolygon || + (node.m_endtype == etClosedLine && Orientation(node.Contour))) + ReversePath(node.Contour); + } + } else { + for (int i = 0; i < m_polyNodes.ChildCount(); ++i) { + PolyNode &node = *m_polyNodes.Childs[i]; + if (node.m_endtype == etClosedLine && !Orientation(node.Contour)) + ReversePath(node.Contour); + } + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::Execute(Paths &solution, double delta) { + solution.clear(); + FixOrientations(); + DoOffset(delta); + + // now clean up 'corners' ... + Clipper clpr; + clpr.AddPaths(m_destPolys, ptSubject, true); + if (delta > 0) { + clpr.Execute(ctUnion, solution, pftPositive, pftPositive); + } else { + IntRect r = clpr.GetBounds(); + Path outer(4); + outer[0] = IntPoint(r.left - 10, r.bottom + 10); + outer[1] = IntPoint(r.right + 10, r.bottom + 10); + outer[2] = IntPoint(r.right + 10, r.top - 10); + outer[3] = IntPoint(r.left - 10, r.top - 10); + + clpr.AddPath(outer, ptSubject, true); + clpr.ReverseSolution(true); + clpr.Execute(ctUnion, solution, pftNegative, pftNegative); + if (solution.size() > 0) + solution.erase(solution.begin()); + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::Execute(PolyTree &solution, double delta) { + solution.Clear(); + FixOrientations(); + DoOffset(delta); + + // now clean up 'corners' ... 
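+  // (offsetting can leave self-intersections at the joins; the union pass
+  // below removes them. For a negative delta the offset shapes come out
+  // inverted, so a large bounding rectangle is unioned in with Negative
+  // filling and then stripped from the result again afterwards.)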
+ Clipper clpr; + clpr.AddPaths(m_destPolys, ptSubject, true); + if (delta > 0) { + clpr.Execute(ctUnion, solution, pftPositive, pftPositive); + } else { + IntRect r = clpr.GetBounds(); + Path outer(4); + outer[0] = IntPoint(r.left - 10, r.bottom + 10); + outer[1] = IntPoint(r.right + 10, r.bottom + 10); + outer[2] = IntPoint(r.right + 10, r.top - 10); + outer[3] = IntPoint(r.left - 10, r.top - 10); + + clpr.AddPath(outer, ptSubject, true); + clpr.ReverseSolution(true); + clpr.Execute(ctUnion, solution, pftNegative, pftNegative); + // remove the outer PolyNode rectangle ... + if (solution.ChildCount() == 1 && solution.Childs[0]->ChildCount() > 0) { + PolyNode *outerNode = solution.Childs[0]; + solution.Childs.reserve(outerNode->ChildCount()); + solution.Childs[0] = outerNode->Childs[0]; + solution.Childs[0]->Parent = outerNode->Parent; + for (int i = 1; i < outerNode->ChildCount(); ++i) + solution.AddChild(*outerNode->Childs[i]); + } else + solution.Clear(); + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoOffset(double delta) { + m_destPolys.clear(); + m_delta = delta; + + // if Zero offset, just copy any CLOSED polygons to m_p and return ... + if (NEAR_ZERO(delta)) { + m_destPolys.reserve(m_polyNodes.ChildCount()); + for (int i = 0; i < m_polyNodes.ChildCount(); i++) { + PolyNode &node = *m_polyNodes.Childs[i]; + if (node.m_endtype == etClosedPolygon) + m_destPolys.push_back(node.Contour); + } + return; + } + + // see offset_triginometry3.svg in the documentation folder ... + if (MiterLimit > 2) + m_miterLim = 2 / (MiterLimit * MiterLimit); + else + m_miterLim = 0.5; + + double y; + if (ArcTolerance <= 0.0) + y = def_arc_tolerance; + else if (ArcTolerance > std::fabs(delta) * def_arc_tolerance) + y = std::fabs(delta) * def_arc_tolerance; + else + y = ArcTolerance; + // see offset_triginometry2.svg in the documentation folder ... + double steps = pi / std::acos(1 - y / std::fabs(delta)); + if (steps > std::fabs(delta) * pi) + steps = std::fabs(delta) * pi; // ie excessive precision check + m_sin = std::sin(two_pi / steps); + m_cos = std::cos(two_pi / steps); + m_StepsPerRad = steps / two_pi; + if (delta < 0.0) + m_sin = -m_sin; + + m_destPolys.reserve(m_polyNodes.ChildCount() * 2); + for (int i = 0; i < m_polyNodes.ChildCount(); i++) { + PolyNode &node = *m_polyNodes.Childs[i]; + m_srcPoly = node.Contour; + + int len = (int)m_srcPoly.size(); + if (len == 0 || + (delta <= 0 && (len < 3 || node.m_endtype != etClosedPolygon))) + continue; + + m_destPoly.clear(); + if (len == 1) { + if (node.m_jointype == jtRound) { + double X = 1.0, Y = 0.0; + for (cInt j = 1; j <= steps; j++) { + m_destPoly.push_back(IntPoint(Round(m_srcPoly[0].X + X * delta), + Round(m_srcPoly[0].Y + Y * delta))); + double X2 = X; + X = X * m_cos - m_sin * Y; + Y = X2 * m_sin + Y * m_cos; + } + } else { + double X = -1.0, Y = -1.0; + for (int j = 0; j < 4; ++j) { + m_destPoly.push_back(IntPoint(Round(m_srcPoly[0].X + X * delta), + Round(m_srcPoly[0].Y + Y * delta))); + if (X < 0) + X = 1; + else if (Y < 0) + Y = 1; + else + X = -1; + } + } + m_destPolys.push_back(m_destPoly); + continue; + } + // build m_normals ... 
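+    // (m_normals[j] holds the unit normal of the edge running from vertex j
+    // to vertex j+1; open paths simply repeat the previous normal at the end.)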
+ m_normals.clear(); + m_normals.reserve(len); + for (int j = 0; j < len - 1; ++j) + m_normals.push_back(GetUnitNormal(m_srcPoly[j], m_srcPoly[j + 1])); + if (node.m_endtype == etClosedLine || node.m_endtype == etClosedPolygon) + m_normals.push_back(GetUnitNormal(m_srcPoly[len - 1], m_srcPoly[0])); + else + m_normals.push_back(DoublePoint(m_normals[len - 2])); + + if (node.m_endtype == etClosedPolygon) { + int k = len - 1; + for (int j = 0; j < len; ++j) + OffsetPoint(j, k, node.m_jointype); + m_destPolys.push_back(m_destPoly); + } else if (node.m_endtype == etClosedLine) { + int k = len - 1; + for (int j = 0; j < len; ++j) + OffsetPoint(j, k, node.m_jointype); + m_destPolys.push_back(m_destPoly); + m_destPoly.clear(); + // re-build m_normals ... + DoublePoint n = m_normals[len - 1]; + for (int j = len - 1; j > 0; j--) + m_normals[j] = DoublePoint(-m_normals[j - 1].X, -m_normals[j - 1].Y); + m_normals[0] = DoublePoint(-n.X, -n.Y); + k = 0; + for (int j = len - 1; j >= 0; j--) + OffsetPoint(j, k, node.m_jointype); + m_destPolys.push_back(m_destPoly); + } else { + int k = 0; + for (int j = 1; j < len - 1; ++j) + OffsetPoint(j, k, node.m_jointype); + + IntPoint pt1; + if (node.m_endtype == etOpenButt) { + int j = len - 1; + pt1 = IntPoint((cInt)Round(m_srcPoly[j].X + m_normals[j].X * delta), + (cInt)Round(m_srcPoly[j].Y + m_normals[j].Y * delta)); + m_destPoly.push_back(pt1); + pt1 = IntPoint((cInt)Round(m_srcPoly[j].X - m_normals[j].X * delta), + (cInt)Round(m_srcPoly[j].Y - m_normals[j].Y * delta)); + m_destPoly.push_back(pt1); + } else { + int j = len - 1; + k = len - 2; + m_sinA = 0; + m_normals[j] = DoublePoint(-m_normals[j].X, -m_normals[j].Y); + if (node.m_endtype == etOpenSquare) + DoSquare(j, k); + else + DoRound(j, k); + } + + // re-build m_normals ... + for (int j = len - 1; j > 0; j--) + m_normals[j] = DoublePoint(-m_normals[j - 1].X, -m_normals[j - 1].Y); + m_normals[0] = DoublePoint(-m_normals[1].X, -m_normals[1].Y); + + k = len - 1; + for (int j = k - 1; j > 0; --j) + OffsetPoint(j, k, node.m_jointype); + + if (node.m_endtype == etOpenButt) { + pt1 = IntPoint((cInt)Round(m_srcPoly[0].X - m_normals[0].X * delta), + (cInt)Round(m_srcPoly[0].Y - m_normals[0].Y * delta)); + m_destPoly.push_back(pt1); + pt1 = IntPoint((cInt)Round(m_srcPoly[0].X + m_normals[0].X * delta), + (cInt)Round(m_srcPoly[0].Y + m_normals[0].Y * delta)); + m_destPoly.push_back(pt1); + } else { + k = 1; + m_sinA = 0; + if (node.m_endtype == etOpenSquare) + DoSquare(0, 1); + else + DoRound(0, 1); + } + m_destPolys.push_back(m_destPoly); + } + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::OffsetPoint(int j, int &k, JoinType jointype) { + // cross product ... + m_sinA = (m_normals[k].X * m_normals[j].Y - m_normals[j].X * m_normals[k].Y); + if (std::fabs(m_sinA * m_delta) < 1.0) { + // dot product ... 
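+    // (the two edge normals are nearly parallel here; if they also point the
+    // same way (cosA > 0) the join collapses to a single offset vertex,
+    // otherwise the turn is handled as an (almost) 180 degree reversal.)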
+ double cosA = + (m_normals[k].X * m_normals[j].X + m_normals[j].Y * m_normals[k].Y); + if (cosA > 0) // angle => 0 degrees + { + m_destPoly.push_back( + IntPoint(Round(m_srcPoly[j].X + m_normals[k].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[k].Y * m_delta))); + return; + } + // else angle => 180 degrees + } else if (m_sinA > 1.0) + m_sinA = 1.0; + else if (m_sinA < -1.0) + m_sinA = -1.0; + + if (m_sinA * m_delta < 0) { + m_destPoly.push_back( + IntPoint(Round(m_srcPoly[j].X + m_normals[k].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[k].Y * m_delta))); + m_destPoly.push_back(m_srcPoly[j]); + m_destPoly.push_back( + IntPoint(Round(m_srcPoly[j].X + m_normals[j].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[j].Y * m_delta))); + } else + switch (jointype) { + case jtMiter: { + double r = 1 + (m_normals[j].X * m_normals[k].X + + m_normals[j].Y * m_normals[k].Y); + if (r >= m_miterLim) + DoMiter(j, k, r); + else + DoSquare(j, k); + break; + } + case jtSquare: + DoSquare(j, k); + break; + case jtRound: + DoRound(j, k); + break; + } + k = j; +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoSquare(int j, int k) { + double dx = std::tan(std::atan2(m_sinA, m_normals[k].X * m_normals[j].X + + m_normals[k].Y * m_normals[j].Y) / + 4); + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[j].X + m_delta * (m_normals[k].X - m_normals[k].Y * dx)), + Round(m_srcPoly[j].Y + + m_delta * (m_normals[k].Y + m_normals[k].X * dx)))); + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[j].X + m_delta * (m_normals[j].X + m_normals[j].Y * dx)), + Round(m_srcPoly[j].Y + + m_delta * (m_normals[j].Y - m_normals[j].X * dx)))); +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoMiter(int j, int k, double r) { + double q = m_delta / r; + m_destPoly.push_back( + IntPoint(Round(m_srcPoly[j].X + (m_normals[k].X + m_normals[j].X) * q), + Round(m_srcPoly[j].Y + (m_normals[k].Y + m_normals[j].Y) * q))); +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoRound(int j, int k) { + double a = std::atan2(m_sinA, m_normals[k].X * m_normals[j].X + + m_normals[k].Y * m_normals[j].Y); + int steps = std::max((int)Round(m_StepsPerRad * std::fabs(a)), 1); + + double X = m_normals[k].X, Y = m_normals[k].Y, X2; + for (int i = 0; i < steps; ++i) { + m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + X * m_delta), + Round(m_srcPoly[j].Y + Y * m_delta))); + X2 = X; + X = X * m_cos - m_sin * Y; + Y = X2 * m_sin + Y * m_cos; + } + m_destPoly.push_back( + IntPoint(Round(m_srcPoly[j].X + m_normals[j].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[j].Y * m_delta))); +} + +//------------------------------------------------------------------------------ +// Miscellaneous public functions +//------------------------------------------------------------------------------ + +void Clipper::DoSimplePolygons() { + PolyOutList::size_type i = 0; + while (i < m_PolyOuts.size()) { + OutRec *outrec = m_PolyOuts[i++]; + OutPt *op = outrec->Pts; + if (!op || outrec->IsOpen) + continue; + do // for each Pt in Polygon until duplicate found do ... + { + OutPt *op2 = op->Next; + while (op2 != outrec->Pts) { + if ((op->Pt == op2->Pt) && op2->Next != op && op2->Prev != op) { + // split the polygon into two ... 
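+          // (op and op2 have identical coordinates, so the ring touches itself
+          // here; re-linking Prev/Next below splits it into two separate
+          // rings - one stays in outrec, the other goes into a new OutRec.)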
+          OutPt *op3 = op->Prev;
+          OutPt *op4 = op2->Prev;
+          op->Prev = op4;
+          op4->Next = op;
+          op2->Prev = op3;
+          op3->Next = op2;
+
+          outrec->Pts = op;
+          OutRec *outrec2 = CreateOutRec();
+          outrec2->Pts = op2;
+          UpdateOutPtIdxs(*outrec2);
+          if (Poly2ContainsPoly1(outrec2->Pts, outrec->Pts)) {
+            // OutRec2 is contained by OutRec1 ...
+            outrec2->IsHole = !outrec->IsHole;
+            outrec2->FirstLeft = outrec;
+            if (m_UsingPolyTree)
+              FixupFirstLefts2(outrec2, outrec);
+          } else if (Poly2ContainsPoly1(outrec->Pts, outrec2->Pts)) {
+            // OutRec1 is contained by OutRec2 ...
+            outrec2->IsHole = outrec->IsHole;
+            outrec->IsHole = !outrec2->IsHole;
+            outrec2->FirstLeft = outrec->FirstLeft;
+            outrec->FirstLeft = outrec2;
+            if (m_UsingPolyTree)
+              FixupFirstLefts2(outrec, outrec2);
+          } else {
+            // the 2 polygons are separate ...
+            outrec2->IsHole = outrec->IsHole;
+            outrec2->FirstLeft = outrec->FirstLeft;
+            if (m_UsingPolyTree)
+              FixupFirstLefts1(outrec, outrec2);
+          }
+          op2 = op; // ie get ready for the Next iteration
+        }
+        op2 = op2->Next;
+      }
+      op = op->Next;
+    } while (op != outrec->Pts);
+  }
+}
+//------------------------------------------------------------------------------
+
+void ReversePath(Path &p) { std::reverse(p.begin(), p.end()); }
+//------------------------------------------------------------------------------
+
+void ReversePaths(Paths &p) {
+  for (Paths::size_type i = 0; i < p.size(); ++i)
+    ReversePath(p[i]);
+}
+//------------------------------------------------------------------------------
+
+void SimplifyPolygon(const Path &in_poly, Paths &out_polys,
+                     PolyFillType fillType) {
+  Clipper c;
+  c.StrictlySimple(true);
+  c.AddPath(in_poly, ptSubject, true);
+  c.Execute(ctUnion, out_polys, fillType, fillType);
+}
+//------------------------------------------------------------------------------
+
+void SimplifyPolygons(const Paths &in_polys, Paths &out_polys,
+                      PolyFillType fillType) {
+  Clipper c;
+  c.StrictlySimple(true);
+  c.AddPaths(in_polys, ptSubject, true);
+  c.Execute(ctUnion, out_polys, fillType, fillType);
+}
+//------------------------------------------------------------------------------
+
+void SimplifyPolygons(Paths &polys, PolyFillType fillType) {
+  SimplifyPolygons(polys, polys, fillType);
+}
+//------------------------------------------------------------------------------
+
+inline double DistanceSqrd(const IntPoint &pt1, const IntPoint &pt2) {
+  double Dx = ((double)pt1.X - pt2.X);
+  double dy = ((double)pt1.Y - pt2.Y);
+  return (Dx * Dx + dy * dy);
+}
+//------------------------------------------------------------------------------
+
+double DistanceFromLineSqrd(const IntPoint &pt, const IntPoint &ln1,
+                            const IntPoint &ln2) {
+  // The equation of a line in general form (Ax + By + C = 0)
+  // given 2 points (x1,y1) & (x2,y2) is ...
+  // (y1 - y2)x + (x2 - x1)y + (x1y2 - x2y1) = 0
+  // A = (y1 - y2); B = (x2 - x1); C = (x1y2 - x2y1)
+  // perpendicular distance of point (x3,y3) = (Ax3 + By3 + C)/Sqrt(A^2 + B^2)
+  // see http://en.wikipedia.org/wiki/Perpendicular_distance
+  double A = double(ln1.Y - ln2.Y);
+  double B = double(ln2.X - ln1.X);
+  double C = A * ln1.X + B * ln1.Y;
+  C = A * pt.X + B * pt.Y - C;
+  return (C * C) / (A * A + B * B);
+}
+//---------------------------------------------------------------------------
+
+bool SlopesNearCollinear(const IntPoint &pt1, const IntPoint &pt2,
+                         const IntPoint &pt3, double distSqrd) {
+  // this function is more accurate when the point that's geometrically
+  // between the other 2 points is the one that's tested for distance.
+  // ie makes it more likely to pick up 'spikes' ...
+  if (Abs(pt1.X - pt2.X) > Abs(pt1.Y - pt2.Y)) {
+    if ((pt1.X > pt2.X) == (pt1.X < pt3.X))
+      return DistanceFromLineSqrd(pt1, pt2, pt3) < distSqrd;
+    else if ((pt2.X > pt1.X) == (pt2.X < pt3.X))
+      return DistanceFromLineSqrd(pt2, pt1, pt3) < distSqrd;
+    else
+      return DistanceFromLineSqrd(pt3, pt1, pt2) < distSqrd;
+  } else {
+    if ((pt1.Y > pt2.Y) == (pt1.Y < pt3.Y))
+      return DistanceFromLineSqrd(pt1, pt2, pt3) < distSqrd;
+    else if ((pt2.Y > pt1.Y) == (pt2.Y < pt3.Y))
+      return DistanceFromLineSqrd(pt2, pt1, pt3) < distSqrd;
+    else
+      return DistanceFromLineSqrd(pt3, pt1, pt2) < distSqrd;
+  }
+}
+//------------------------------------------------------------------------------
+
+bool PointsAreClose(IntPoint pt1, IntPoint pt2, double distSqrd) {
+  double Dx = (double)pt1.X - pt2.X;
+  double dy = (double)pt1.Y - pt2.Y;
+  return ((Dx * Dx) + (dy * dy) <= distSqrd);
+}
+//------------------------------------------------------------------------------
+
+OutPt *ExcludeOp(OutPt *op) {
+  OutPt *result = op->Prev;
+  result->Next = op->Next;
+  op->Next->Prev = result;
+  result->Idx = 0;
+  return result;
+}
+//------------------------------------------------------------------------------
+
+void CleanPolygon(const Path &in_poly, Path &out_poly, double distance) {
+  // distance = proximity in units/pixels below which vertices
+  // will be stripped. Default ~= sqrt(2).
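+  // (with the default of ~1.415, just over sqrt(2), a vertex is removed when
+  // it sits within roughly one pixel - including diagonally - of its
+  // neighbour, or when it forms a near-collinear "spike" with its neighbours.)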
+ + size_t size = in_poly.size(); + + if (size == 0) { + out_poly.clear(); + return; + } + + OutPt *outPts = new OutPt[size]; + for (size_t i = 0; i < size; ++i) { + outPts[i].Pt = in_poly[i]; + outPts[i].Next = &outPts[(i + 1) % size]; + outPts[i].Next->Prev = &outPts[i]; + outPts[i].Idx = 0; + } + + double distSqrd = distance * distance; + OutPt *op = &outPts[0]; + while (op->Idx == 0 && op->Next != op->Prev) { + if (PointsAreClose(op->Pt, op->Prev->Pt, distSqrd)) { + op = ExcludeOp(op); + size--; + } else if (PointsAreClose(op->Prev->Pt, op->Next->Pt, distSqrd)) { + ExcludeOp(op->Next); + op = ExcludeOp(op); + size -= 2; + } else if (SlopesNearCollinear(op->Prev->Pt, op->Pt, op->Next->Pt, + distSqrd)) { + op = ExcludeOp(op); + size--; + } else { + op->Idx = 1; + op = op->Next; + } + } + + if (size < 3) + size = 0; + out_poly.resize(size); + for (size_t i = 0; i < size; ++i) { + out_poly[i] = op->Pt; + op = op->Next; + } + delete[] outPts; +} +//------------------------------------------------------------------------------ + +void CleanPolygon(Path &poly, double distance) { + CleanPolygon(poly, poly, distance); +} +//------------------------------------------------------------------------------ + +void CleanPolygons(const Paths &in_polys, Paths &out_polys, double distance) { + out_polys.resize(in_polys.size()); + for (Paths::size_type i = 0; i < in_polys.size(); ++i) + CleanPolygon(in_polys[i], out_polys[i], distance); +} +//------------------------------------------------------------------------------ + +void CleanPolygons(Paths &polys, double distance) { + CleanPolygons(polys, polys, distance); +} +//------------------------------------------------------------------------------ + +void Minkowski(const Path &poly, const Path &path, Paths &solution, bool isSum, + bool isClosed) { + int delta = (isClosed ? 
1 : 0); + size_t polyCnt = poly.size(); + size_t pathCnt = path.size(); + Paths pp; + pp.reserve(pathCnt); + if (isSum) + for (size_t i = 0; i < pathCnt; ++i) { + Path p; + p.reserve(polyCnt); + for (size_t j = 0; j < poly.size(); ++j) + p.push_back(IntPoint(path[i].X + poly[j].X, path[i].Y + poly[j].Y)); + pp.push_back(p); + } + else + for (size_t i = 0; i < pathCnt; ++i) { + Path p; + p.reserve(polyCnt); + for (size_t j = 0; j < poly.size(); ++j) + p.push_back(IntPoint(path[i].X - poly[j].X, path[i].Y - poly[j].Y)); + pp.push_back(p); + } + + solution.clear(); + solution.reserve((pathCnt + delta) * (polyCnt + 1)); + for (size_t i = 0; i < pathCnt - 1 + delta; ++i) + for (size_t j = 0; j < polyCnt; ++j) { + Path quad; + quad.reserve(4); + quad.push_back(pp[i % pathCnt][j % polyCnt]); + quad.push_back(pp[(i + 1) % pathCnt][j % polyCnt]); + quad.push_back(pp[(i + 1) % pathCnt][(j + 1) % polyCnt]); + quad.push_back(pp[i % pathCnt][(j + 1) % polyCnt]); + if (!Orientation(quad)) + ReversePath(quad); + solution.push_back(quad); + } +} +//------------------------------------------------------------------------------ + +void MinkowskiSum(const Path &pattern, const Path &path, Paths &solution, + bool pathIsClosed) { + Minkowski(pattern, path, solution, true, pathIsClosed); + Clipper c; + c.AddPaths(solution, ptSubject, true); + c.Execute(ctUnion, solution, pftNonZero, pftNonZero); +} +//------------------------------------------------------------------------------ + +void TranslatePath(const Path &input, Path &output, const IntPoint delta) { + // precondition: input != output + output.resize(input.size()); + for (size_t i = 0; i < input.size(); ++i) + output[i] = IntPoint(input[i].X + delta.X, input[i].Y + delta.Y); +} +//------------------------------------------------------------------------------ + +void MinkowskiSum(const Path &pattern, const Paths &paths, Paths &solution, + bool pathIsClosed) { + Clipper c; + for (size_t i = 0; i < paths.size(); ++i) { + Paths tmp; + Minkowski(pattern, paths[i], tmp, true, pathIsClosed); + c.AddPaths(tmp, ptSubject, true); + if (pathIsClosed) { + Path tmp2; + TranslatePath(paths[i], tmp2, pattern[0]); + c.AddPath(tmp2, ptClip, true); + } + } + c.Execute(ctUnion, solution, pftNonZero, pftNonZero); +} +//------------------------------------------------------------------------------ + +void MinkowskiDiff(const Path &poly1, const Path &poly2, Paths &solution) { + Minkowski(poly1, poly2, solution, false, true); + Clipper c; + c.AddPaths(solution, ptSubject, true); + c.Execute(ctUnion, solution, pftNonZero, pftNonZero); +} +//------------------------------------------------------------------------------ + +enum NodeType { ntAny, ntOpen, ntClosed }; + +void AddPolyNodeToPaths(const PolyNode &polynode, NodeType nodetype, + Paths &paths) { + bool match = true; + if (nodetype == ntClosed) + match = !polynode.IsOpen(); + else if (nodetype == ntOpen) + return; + + if (!polynode.Contour.empty() && match) + paths.push_back(polynode.Contour); + for (int i = 0; i < polynode.ChildCount(); ++i) + AddPolyNodeToPaths(*polynode.Childs[i], nodetype, paths); +} +//------------------------------------------------------------------------------ + +void PolyTreeToPaths(const PolyTree &polytree, Paths &paths) { + paths.resize(0); + paths.reserve(polytree.Total()); + AddPolyNodeToPaths(polytree, ntAny, paths); +} +//------------------------------------------------------------------------------ + +void ClosedPathsFromPolyTree(const PolyTree &polytree, Paths &paths) { + paths.resize(0); + 
paths.reserve(polytree.Total()); + AddPolyNodeToPaths(polytree, ntClosed, paths); +} +//------------------------------------------------------------------------------ + +void OpenPathsFromPolyTree(PolyTree &polytree, Paths &paths) { + paths.resize(0); + paths.reserve(polytree.Total()); + // Open paths are top level only, so ... + for (int i = 0; i < polytree.ChildCount(); ++i) + if (polytree.Childs[i]->IsOpen()) + paths.push_back(polytree.Childs[i]->Contour); +} +//------------------------------------------------------------------------------ + +std::ostream &operator<<(std::ostream &s, const IntPoint &p) { + s << "(" << p.X << "," << p.Y << ")"; + return s; +} +//------------------------------------------------------------------------------ + +std::ostream &operator<<(std::ostream &s, const Path &p) { + if (p.empty()) + return s; + Path::size_type last = p.size() - 1; + for (Path::size_type i = 0; i < last; i++) + s << "(" << p[i].X << "," << p[i].Y << "), "; + s << "(" << p[last].X << "," << p[last].Y << ")\n"; + return s; +} +//------------------------------------------------------------------------------ + +std::ostream &operator<<(std::ostream &s, const Paths &p) { + for (Paths::size_type i = 0; i < p.size(); i++) + s << p[i]; + s << "\n"; + return s; +} +//------------------------------------------------------------------------------ + +} // ClipperLib namespace diff --git a/deploy/android_demo/app/src/main/cpp/ocr_clipper.hpp b/deploy/android_demo/app/src/main/cpp/ocr_clipper.hpp new file mode 100644 index 0000000..60af2bb --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ocr_clipper.hpp @@ -0,0 +1,544 @@ +/******************************************************************************* +* * +* Author : Angus Johnson * +* Version : 6.4.2 * +* Date : 27 February 2017 * +* Website : http://www.angusj.com * +* Copyright : Angus Johnson 2010-2017 * +* * +* License: * +* Use, modification & distribution is subject to Boost Software License Ver 1. * +* http://www.boost.org/LICENSE_1_0.txt * +* * +* Attributions: * +* The code in this library is an extension of Bala Vatti's clipping algorithm: * +* "A generic solution to polygon clipping" * +* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. * +* http://portal.acm.org/citation.cfm?id=129906 * +* * +* Computer graphics and geometric modeling: implementation and algorithms * +* By Max K. Agoston * +* Springer; 1 edition (January 4, 2005) * +* http://books.google.com/books?q=vatti+clipping+agoston * +* * +* See also: * +* "Polygon Offsetting by Computing Winding Numbers" * +* Paper no. DETC2005-85513 pp. 565-575 * +* ASME 2005 International Design Engineering Technical Conferences * +* and Computers and Information in Engineering Conference (IDETC/CIE2005) * +* September 24-28, 2005 , Long Beach, California, USA * +* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf * +* * +*******************************************************************************/ + +#ifndef clipper_hpp +#define clipper_hpp + +#define CLIPPER_VERSION "6.4.2" + +// use_int32: When enabled 32bit ints are used instead of 64bit ints. This +// improve performance but coordinate values are limited to the range +/- 46340 +//#define use_int32 + +// use_xyz: adds a Z member to IntPoint. Adds a minor cost to perfomance. +//#define use_xyz + +// use_lines: Enables line clipping. Adds a very minor cost to performance. 
+#define use_lines + +// use_deprecated: Enables temporary support for the obsolete functions +//#define use_deprecated + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ClipperLib { + +enum ClipType { ctIntersection, ctUnion, ctDifference, ctXor }; +enum PolyType { ptSubject, ptClip }; +// By far the most widely used winding rules for polygon filling are +// EvenOdd & NonZero (GDI, GDI+, XLib, OpenGL, Cairo, AGG, Quartz, SVG, Gr32) +// Others rules include Positive, Negative and ABS_GTR_EQ_TWO (only in OpenGL) +// see http://glprogramming.com/red/chapter11.html +enum PolyFillType { pftEvenOdd, pftNonZero, pftPositive, pftNegative }; + +#ifdef use_int32 +typedef int cInt; +static cInt const loRange = 0x7FFF; +static cInt const hiRange = 0x7FFF; +#else +typedef signed long long cInt; +static cInt const loRange = 0x3FFFFFFF; +static cInt const hiRange = 0x3FFFFFFFFFFFFFFFLL; +typedef signed long long long64; // used by Int128 class +typedef unsigned long long ulong64; + +#endif + +struct IntPoint { + cInt X; + cInt Y; +#ifdef use_xyz + cInt Z; + IntPoint(cInt x = 0, cInt y = 0, cInt z = 0) : X(x), Y(y), Z(z){}; +#else + + IntPoint(cInt x = 0, cInt y = 0) : X(x), Y(y){}; +#endif + + friend inline bool operator==(const IntPoint &a, const IntPoint &b) { + return a.X == b.X && a.Y == b.Y; + } + + friend inline bool operator!=(const IntPoint &a, const IntPoint &b) { + return a.X != b.X || a.Y != b.Y; + } +}; +//------------------------------------------------------------------------------ + +typedef std::vector Path; +typedef std::vector Paths; + +inline Path &operator<<(Path &poly, const IntPoint &p) { + poly.push_back(p); + return poly; +} + +inline Paths &operator<<(Paths &polys, const Path &p) { + polys.push_back(p); + return polys; +} + +std::ostream &operator<<(std::ostream &s, const IntPoint &p); + +std::ostream &operator<<(std::ostream &s, const Path &p); + +std::ostream &operator<<(std::ostream &s, const Paths &p); + +struct DoublePoint { + double X; + double Y; + + DoublePoint(double x = 0, double y = 0) : X(x), Y(y) {} + + DoublePoint(IntPoint ip) : X((double)ip.X), Y((double)ip.Y) {} +}; +//------------------------------------------------------------------------------ + +#ifdef use_xyz +typedef void (*ZFillCallback)(IntPoint &e1bot, IntPoint &e1top, IntPoint &e2bot, + IntPoint &e2top, IntPoint &pt); +#endif + +enum InitOptions { + ioReverseSolution = 1, + ioStrictlySimple = 2, + ioPreserveCollinear = 4 +}; +enum JoinType { jtSquare, jtRound, jtMiter }; +enum EndType { + etClosedPolygon, + etClosedLine, + etOpenButt, + etOpenSquare, + etOpenRound +}; + +class PolyNode; + +typedef std::vector PolyNodes; + +class PolyNode { +public: + PolyNode(); + + virtual ~PolyNode(){}; + Path Contour; + PolyNodes Childs; + PolyNode *Parent; + + PolyNode *GetNext() const; + + bool IsHole() const; + + bool IsOpen() const; + + int ChildCount() const; + +private: + // PolyNode& operator =(PolyNode& other); + unsigned Index; // node index in Parent.Childs + bool m_IsOpen; + JoinType m_jointype; + EndType m_endtype; + + PolyNode *GetNextSiblingUp() const; + + void AddChild(PolyNode &child); + + friend class Clipper; // to access Index + friend class ClipperOffset; +}; + +class PolyTree : public PolyNode { +public: + ~PolyTree() { Clear(); }; + + PolyNode *GetFirst() const; + + void Clear(); + + int Total() const; + +private: + // PolyTree& operator =(PolyTree& other); + PolyNodes AllNodes; + + friend class Clipper; // to access AllNodes +}; + +bool 
Orientation(const Path &poly); + +double Area(const Path &poly); + +int PointInPolygon(const IntPoint &pt, const Path &path); + +void SimplifyPolygon(const Path &in_poly, Paths &out_polys, + PolyFillType fillType = pftEvenOdd); + +void SimplifyPolygons(const Paths &in_polys, Paths &out_polys, + PolyFillType fillType = pftEvenOdd); + +void SimplifyPolygons(Paths &polys, PolyFillType fillType = pftEvenOdd); + +void CleanPolygon(const Path &in_poly, Path &out_poly, double distance = 1.415); + +void CleanPolygon(Path &poly, double distance = 1.415); + +void CleanPolygons(const Paths &in_polys, Paths &out_polys, + double distance = 1.415); + +void CleanPolygons(Paths &polys, double distance = 1.415); + +void MinkowskiSum(const Path &pattern, const Path &path, Paths &solution, + bool pathIsClosed); + +void MinkowskiSum(const Path &pattern, const Paths &paths, Paths &solution, + bool pathIsClosed); + +void MinkowskiDiff(const Path &poly1, const Path &poly2, Paths &solution); + +void PolyTreeToPaths(const PolyTree &polytree, Paths &paths); + +void ClosedPathsFromPolyTree(const PolyTree &polytree, Paths &paths); + +void OpenPathsFromPolyTree(PolyTree &polytree, Paths &paths); + +void ReversePath(Path &p); + +void ReversePaths(Paths &p); + +struct IntRect { + cInt left; + cInt top; + cInt right; + cInt bottom; +}; + +// enums that are used internally ... +enum EdgeSide { esLeft = 1, esRight = 2 }; + +// forward declarations (for stuff used internally) ... +struct TEdge; +struct IntersectNode; +struct LocalMinimum; +struct OutPt; +struct OutRec; +struct Join; + +typedef std::vector PolyOutList; +typedef std::vector EdgeList; +typedef std::vector JoinList; +typedef std::vector IntersectList; + +//------------------------------------------------------------------------------ + +// ClipperBase is the ancestor to the Clipper class. It should not be +// instantiated directly. This class simply abstracts the conversion of sets of +// polygon coordinates into edge objects that are stored in a LocalMinima list. 
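+//
+// Typical use of the derived Clipper class declared further below (a sketch;
+// subject and clip are caller-supplied Path values):
+//
+//   ClipperLib::Clipper c;
+//   c.AddPath(subject, ClipperLib::ptSubject, true);
+//   c.AddPath(clip, ClipperLib::ptClip, true);
+//   ClipperLib::Paths solution;
+//   c.Execute(ClipperLib::ctIntersection, solution,
+//             ClipperLib::pftEvenOdd, ClipperLib::pftEvenOdd);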
+class ClipperBase { +public: + ClipperBase(); + + virtual ~ClipperBase(); + + virtual bool AddPath(const Path &pg, PolyType PolyTyp, bool Closed); + + bool AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed); + + virtual void Clear(); + + IntRect GetBounds(); + + bool PreserveCollinear() { return m_PreserveCollinear; }; + + void PreserveCollinear(bool value) { m_PreserveCollinear = value; }; + +protected: + void DisposeLocalMinimaList(); + + TEdge *AddBoundsToLML(TEdge *e, bool IsClosed); + + virtual void Reset(); + + TEdge *ProcessBound(TEdge *E, bool IsClockwise); + + void InsertScanbeam(const cInt Y); + + bool PopScanbeam(cInt &Y); + + bool LocalMinimaPending(); + + bool PopLocalMinima(cInt Y, const LocalMinimum *&locMin); + + OutRec *CreateOutRec(); + + void DisposeAllOutRecs(); + + void DisposeOutRec(PolyOutList::size_type index); + + void SwapPositionsInAEL(TEdge *edge1, TEdge *edge2); + + void DeleteFromAEL(TEdge *e); + + void UpdateEdgeIntoAEL(TEdge *&e); + + typedef std::vector MinimaList; + MinimaList::iterator m_CurrentLM; + MinimaList m_MinimaList; + + bool m_UseFullRange; + EdgeList m_edges; + bool m_PreserveCollinear; + bool m_HasOpenPaths; + PolyOutList m_PolyOuts; + TEdge *m_ActiveEdges; + + typedef std::priority_queue ScanbeamList; + ScanbeamList m_Scanbeam; +}; +//------------------------------------------------------------------------------ + +class Clipper : public virtual ClipperBase { +public: + Clipper(int initOptions = 0); + + bool Execute(ClipType clipType, Paths &solution, + PolyFillType fillType = pftEvenOdd); + + bool Execute(ClipType clipType, Paths &solution, PolyFillType subjFillType, + PolyFillType clipFillType); + + bool Execute(ClipType clipType, PolyTree &polytree, + PolyFillType fillType = pftEvenOdd); + + bool Execute(ClipType clipType, PolyTree &polytree, PolyFillType subjFillType, + PolyFillType clipFillType); + + bool ReverseSolution() { return m_ReverseOutput; }; + + void ReverseSolution(bool value) { m_ReverseOutput = value; }; + + bool StrictlySimple() { return m_StrictSimple; }; + + void StrictlySimple(bool value) { m_StrictSimple = value; }; +// set the callback function for z value filling on intersections (otherwise Z +// is 0) +#ifdef use_xyz + void ZFillFunction(ZFillCallback zFillFunc); +#endif +protected: + virtual bool ExecuteInternal(); + +private: + JoinList m_Joins; + JoinList m_GhostJoins; + IntersectList m_IntersectList; + ClipType m_ClipType; + typedef std::list MaximaList; + MaximaList m_Maxima; + TEdge *m_SortedEdges; + bool m_ExecuteLocked; + PolyFillType m_ClipFillType; + PolyFillType m_SubjFillType; + bool m_ReverseOutput; + bool m_UsingPolyTree; + bool m_StrictSimple; +#ifdef use_xyz + ZFillCallback m_ZFill; // custom callback +#endif + + void SetWindingCount(TEdge &edge); + + bool IsEvenOddFillType(const TEdge &edge) const; + + bool IsEvenOddAltFillType(const TEdge &edge) const; + + void InsertLocalMinimaIntoAEL(const cInt botY); + + void InsertEdgeIntoAEL(TEdge *edge, TEdge *startEdge); + + void AddEdgeToSEL(TEdge *edge); + + bool PopEdgeFromSEL(TEdge *&edge); + + void CopyAELToSEL(); + + void DeleteFromSEL(TEdge *e); + + void SwapPositionsInSEL(TEdge *edge1, TEdge *edge2); + + bool IsContributing(const TEdge &edge) const; + + bool IsTopHorz(const cInt XPos); + + void DoMaxima(TEdge *e); + + void ProcessHorizontals(); + + void ProcessHorizontal(TEdge *horzEdge); + + void AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &pt); + + OutPt *AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &pt); + + OutRec 
*GetOutRec(int idx); + + void AppendPolygon(TEdge *e1, TEdge *e2); + + void IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &pt); + + OutPt *AddOutPt(TEdge *e, const IntPoint &pt); + + OutPt *GetLastOutPt(TEdge *e); + + bool ProcessIntersections(const cInt topY); + + void BuildIntersectList(const cInt topY); + + void ProcessIntersectList(); + + void ProcessEdgesAtTopOfScanbeam(const cInt topY); + + void BuildResult(Paths &polys); + + void BuildResult2(PolyTree &polytree); + + void SetHoleState(TEdge *e, OutRec *outrec); + + void DisposeIntersectNodes(); + + bool FixupIntersectionOrder(); + + void FixupOutPolygon(OutRec &outrec); + + void FixupOutPolyline(OutRec &outrec); + + bool IsHole(TEdge *e); + + bool FindOwnerFromSplitRecs(OutRec &outRec, OutRec *&currOrfl); + + void FixHoleLinkage(OutRec &outrec); + + void AddJoin(OutPt *op1, OutPt *op2, const IntPoint offPt); + + void ClearJoins(); + + void ClearGhostJoins(); + + void AddGhostJoin(OutPt *op, const IntPoint offPt); + + bool JoinPoints(Join *j, OutRec *outRec1, OutRec *outRec2); + + void JoinCommonEdges(); + + void DoSimplePolygons(); + + void FixupFirstLefts1(OutRec *OldOutRec, OutRec *NewOutRec); + + void FixupFirstLefts2(OutRec *InnerOutRec, OutRec *OuterOutRec); + + void FixupFirstLefts3(OutRec *OldOutRec, OutRec *NewOutRec); + +#ifdef use_xyz + void SetZ(IntPoint &pt, TEdge &e1, TEdge &e2); +#endif +}; +//------------------------------------------------------------------------------ + +class ClipperOffset { +public: + ClipperOffset(double miterLimit = 2.0, double roundPrecision = 0.25); + + ~ClipperOffset(); + + void AddPath(const Path &path, JoinType joinType, EndType endType); + + void AddPaths(const Paths &paths, JoinType joinType, EndType endType); + + void Execute(Paths &solution, double delta); + + void Execute(PolyTree &solution, double delta); + + void Clear(); + + double MiterLimit; + double ArcTolerance; + +private: + Paths m_destPolys; + Path m_srcPoly; + Path m_destPoly; + std::vector m_normals; + double m_delta, m_sinA, m_sin, m_cos; + double m_miterLim, m_StepsPerRad; + IntPoint m_lowest; + PolyNode m_polyNodes; + + void FixOrientations(); + + void DoOffset(double delta); + + void OffsetPoint(int j, int &k, JoinType jointype); + + void DoSquare(int j, int k); + + void DoMiter(int j, int k, double r); + + void DoRound(int j, int k); +}; +//------------------------------------------------------------------------------ + +class clipperException : public std::exception { +public: + clipperException(const char *description) : m_descr(description) {} + + virtual ~clipperException() throw() {} + + virtual const char *what() const throw() { return m_descr.c_str(); } + +private: + std::string m_descr; +}; +//------------------------------------------------------------------------------ + +} // ClipperLib namespace + +#endif // clipper_hpp diff --git a/deploy/android_demo/app/src/main/cpp/ocr_cls_process.cpp b/deploy/android_demo/app/src/main/cpp/ocr_cls_process.cpp new file mode 100644 index 0000000..e7de9b0 --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ocr_cls_process.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ocr_cls_process.h" +#include +#include +#include +#include +#include +#include + +const std::vector CLS_IMAGE_SHAPE = {3, 48, 192}; + +cv::Mat cls_resize_img(const cv::Mat &img) { + int imgC = CLS_IMAGE_SHAPE[0]; + int imgW = CLS_IMAGE_SHAPE[2]; + int imgH = CLS_IMAGE_SHAPE[1]; + + float ratio = float(img.cols) / float(img.rows); + int resize_w = 0; + if (ceilf(imgH * ratio) > imgW) + resize_w = imgW; + else + resize_w = int(ceilf(imgH * ratio)); + + cv::Mat resize_img; + cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, + cv::INTER_CUBIC); + + if (resize_w < imgW) { + cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, int(imgW - resize_w), + cv::BORDER_CONSTANT, {0, 0, 0}); + } + return resize_img; +} \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/ocr_cls_process.h b/deploy/android_demo/app/src/main/cpp/ocr_cls_process.h new file mode 100644 index 0000000..1c30ee1 --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ocr_cls_process.h @@ -0,0 +1,23 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "common.h" +#include +#include + +extern const std::vector CLS_IMAGE_SHAPE; + +cv::Mat cls_resize_img(const cv::Mat &img); \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/ocr_crnn_process.cpp b/deploy/android_demo/app/src/main/cpp/ocr_crnn_process.cpp new file mode 100644 index 0000000..44c34a2 --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ocr_crnn_process.cpp @@ -0,0 +1,142 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
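+
+// CRNN recognition pre-processing helpers: crop the detected quad with a
+// perspective transform, resize it to a fixed height of 32 px with the width
+// following the aspect ratio, and (for the normalized variant) map pixel
+// values from [0, 1] to [-1, 1].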
+ +#include "ocr_crnn_process.h" +#include +#include +#include +#include +#include +#include + +const std::string CHARACTER_TYPE = "ch"; +const int MAX_DICT_LENGTH = 6624; +const std::vector REC_IMAGE_SHAPE = {3, 32, 320}; + +static cv::Mat crnn_resize_norm_img(cv::Mat img, float wh_ratio) { + int imgC = REC_IMAGE_SHAPE[0]; + int imgW = REC_IMAGE_SHAPE[2]; + int imgH = REC_IMAGE_SHAPE[1]; + + if (CHARACTER_TYPE == "ch") + imgW = int(32 * wh_ratio); + + float ratio = float(img.cols) / float(img.rows); + int resize_w = 0; + if (ceilf(imgH * ratio) > imgW) + resize_w = imgW; + else + resize_w = int(ceilf(imgH * ratio)); + cv::Mat resize_img; + cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, + cv::INTER_CUBIC); + + resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f); + + for (int h = 0; h < resize_img.rows; h++) { + for (int w = 0; w < resize_img.cols; w++) { + resize_img.at(h, w)[0] = + (resize_img.at(h, w)[0] - 0.5) * 2; + resize_img.at(h, w)[1] = + (resize_img.at(h, w)[1] - 0.5) * 2; + resize_img.at(h, w)[2] = + (resize_img.at(h, w)[2] - 0.5) * 2; + } + } + + cv::Mat dist; + cv::copyMakeBorder(resize_img, dist, 0, 0, 0, int(imgW - resize_w), + cv::BORDER_CONSTANT, {0, 0, 0}); + + return dist; +} + +cv::Mat crnn_resize_img(const cv::Mat &img, float wh_ratio) { + int imgC = REC_IMAGE_SHAPE[0]; + int imgW = REC_IMAGE_SHAPE[2]; + int imgH = REC_IMAGE_SHAPE[1]; + + if (CHARACTER_TYPE == "ch") { + imgW = int(32 * wh_ratio); + } + + float ratio = float(img.cols) / float(img.rows); + int resize_w = 0; + if (ceilf(imgH * ratio) > imgW) + resize_w = imgW; + else + resize_w = int(ceilf(imgH * ratio)); + cv::Mat resize_img; + cv::resize(img, resize_img, cv::Size(resize_w, imgH)); + return resize_img; +} + +cv::Mat get_rotate_crop_image(const cv::Mat &srcimage, + const std::vector> &box) { + + std::vector> points = box; + + int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]}; + int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]}; + int left = int(*std::min_element(x_collect, x_collect + 4)); + int right = int(*std::max_element(x_collect, x_collect + 4)); + int top = int(*std::min_element(y_collect, y_collect + 4)); + int bottom = int(*std::max_element(y_collect, y_collect + 4)); + + cv::Mat img_crop; + srcimage(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop); + + for (int i = 0; i < points.size(); i++) { + points[i][0] -= left; + points[i][1] -= top; + } + + int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) + + pow(points[0][1] - points[1][1], 2))); + int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) + + pow(points[0][1] - points[3][1], 2))); + + cv::Point2f pts_std[4]; + pts_std[0] = cv::Point2f(0., 0.); + pts_std[1] = cv::Point2f(img_crop_width, 0.); + pts_std[2] = cv::Point2f(img_crop_width, img_crop_height); + pts_std[3] = cv::Point2f(0.f, img_crop_height); + + cv::Point2f pointsf[4]; + pointsf[0] = cv::Point2f(points[0][0], points[0][1]); + pointsf[1] = cv::Point2f(points[1][0], points[1][1]); + pointsf[2] = cv::Point2f(points[2][0], points[2][1]); + pointsf[3] = cv::Point2f(points[3][0], points[3][1]); + + cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std); + + cv::Mat dst_img; + cv::warpPerspective(img_crop, dst_img, M, + cv::Size(img_crop_width, img_crop_height), + cv::BORDER_REPLICATE); + + if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) { + /* + cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth()); + cv::transpose(dst_img, srcCopy); + cv::flip(srcCopy, srcCopy, 0); + 
return srcCopy; + */ + cv::transpose(dst_img, dst_img); + cv::flip(dst_img, dst_img, 0); + return dst_img; + } else { + return dst_img; + } +} diff --git a/deploy/android_demo/app/src/main/cpp/ocr_crnn_process.h b/deploy/android_demo/app/src/main/cpp/ocr_crnn_process.h new file mode 100644 index 0000000..0346afe --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ocr_crnn_process.h @@ -0,0 +1,20 @@ +// +// Created by fujiayi on 2020/7/3. +// +#pragma once + +#include "common.h" +#include +#include + +extern const std::vector REC_IMAGE_SHAPE; + +cv::Mat get_rotate_crop_image(const cv::Mat &srcimage, + const std::vector> &box); + +cv::Mat crnn_resize_img(const cv::Mat &img, float wh_ratio); + +template +inline size_t argmax(ForwardIterator first, ForwardIterator last) { + return std::distance(first, std::max_element(first, last)); +} \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/ocr_db_post_process.cpp b/deploy/android_demo/app/src/main/cpp/ocr_db_post_process.cpp new file mode 100644 index 0000000..9816ea4 --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ocr_db_post_process.cpp @@ -0,0 +1,342 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
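+
+// DB detection post-processing: take the predicted probability map and its
+// binarized bitmap, find contours, fit minimum-area boxes, expand them with
+// the Clipper polygon offset ("unclip"), score and filter them, and rescale
+// the surviving boxes back to the original image size.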
+ +#include "ocr_clipper.hpp" +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include +#include +#include + +static void getcontourarea(float **box, float unclip_ratio, float &distance) { + int pts_num = 4; + float area = 0.0f; + float dist = 0.0f; + for (int i = 0; i < pts_num; i++) { + area += box[i][0] * box[(i + 1) % pts_num][1] - + box[i][1] * box[(i + 1) % pts_num][0]; + dist += sqrtf((box[i][0] - box[(i + 1) % pts_num][0]) * + (box[i][0] - box[(i + 1) % pts_num][0]) + + (box[i][1] - box[(i + 1) % pts_num][1]) * + (box[i][1] - box[(i + 1) % pts_num][1])); + } + area = fabs(float(area / 2.0)); + + distance = area * unclip_ratio / dist; +} + +static cv::RotatedRect unclip(float **box) { + float unclip_ratio = 2.0; + float distance = 1.0; + + getcontourarea(box, unclip_ratio, distance); + + ClipperLib::ClipperOffset offset; + ClipperLib::Path p; + p << ClipperLib::IntPoint(int(box[0][0]), int(box[0][1])) + << ClipperLib::IntPoint(int(box[1][0]), int(box[1][1])) + << ClipperLib::IntPoint(int(box[2][0]), int(box[2][1])) + << ClipperLib::IntPoint(int(box[3][0]), int(box[3][1])); + offset.AddPath(p, ClipperLib::jtRound, ClipperLib::etClosedPolygon); + + ClipperLib::Paths soln; + offset.Execute(soln, distance); + std::vector points; + + for (int j = 0; j < soln.size(); j++) { + for (int i = 0; i < soln[soln.size() - 1].size(); i++) { + points.emplace_back(soln[j][i].X, soln[j][i].Y); + } + } + cv::RotatedRect res = cv::minAreaRect(points); + + return res; +} + +static float **Mat2Vec(cv::Mat mat) { + auto **array = new float *[mat.rows]; + for (int i = 0; i < mat.rows; ++i) { + array[i] = new float[mat.cols]; + } + for (int i = 0; i < mat.rows; ++i) { + for (int j = 0; j < mat.cols; ++j) { + array[i][j] = mat.at(i, j); + } + } + + return array; +} + +static void quickSort(float **s, int l, int r) { + if (l < r) { + int i = l, j = r; + float x = s[l][0]; + float *xp = s[l]; + while (i < j) { + while (i < j && s[j][0] >= x) { + j--; + } + if (i < j) { + std::swap(s[i++], s[j]); + } + while (i < j && s[i][0] < x) { + i++; + } + if (i < j) { + std::swap(s[j--], s[i]); + } + } + s[i] = xp; + quickSort(s, l, i - 1); + quickSort(s, i + 1, r); + } +} + +static void quickSort_vector(std::vector> &box, int l, int r, + int axis) { + if (l < r) { + int i = l, j = r; + int x = box[l][axis]; + std::vector xp(box[l]); + while (i < j) { + while (i < j && box[j][axis] >= x) { + j--; + } + if (i < j) { + std::swap(box[i++], box[j]); + } + while (i < j && box[i][axis] < x) { + i++; + } + if (i < j) { + std::swap(box[j--], box[i]); + } + } + box[i] = xp; + quickSort_vector(box, l, i - 1, axis); + quickSort_vector(box, i + 1, r, axis); + } +} + +static std::vector> +order_points_clockwise(std::vector> pts) { + std::vector> box = pts; + quickSort_vector(box, 0, int(box.size() - 1), 0); + std::vector> leftmost = {box[0], box[1]}; + std::vector> rightmost = {box[2], box[3]}; + + if (leftmost[0][1] > leftmost[1][1]) { + std::swap(leftmost[0], leftmost[1]); + } + + if (rightmost[0][1] > rightmost[1][1]) { + std::swap(rightmost[0], rightmost[1]); + } + + std::vector> rect = {leftmost[0], rightmost[0], rightmost[1], + leftmost[1]}; + return rect; +} + +static float **get_mini_boxes(cv::RotatedRect box, float &ssid) { + ssid = box.size.width >= box.size.height ? 
box.size.height : box.size.width; + + cv::Mat points; + cv::boxPoints(box, points); + // sorted box points + auto array = Mat2Vec(points); + quickSort(array, 0, 3); + + float *idx1 = array[0], *idx2 = array[1], *idx3 = array[2], *idx4 = array[3]; + if (array[3][1] <= array[2][1]) { + idx2 = array[3]; + idx3 = array[2]; + } else { + idx2 = array[2]; + idx3 = array[3]; + } + if (array[1][1] <= array[0][1]) { + idx1 = array[1]; + idx4 = array[0]; + } else { + idx1 = array[0]; + idx4 = array[1]; + } + + array[0] = idx1; + array[1] = idx2; + array[2] = idx3; + array[3] = idx4; + + return array; +} + +template T clamp(T x, T min, T max) { + if (x > max) { + return max; + } + if (x < min) { + return min; + } + return x; +} + +static float clampf(float x, float min, float max) { + if (x > max) + return max; + if (x < min) + return min; + return x; +} + +float box_score_fast(float **box_array, cv::Mat pred) { + auto array = box_array; + int width = pred.cols; + int height = pred.rows; + + float box_x[4] = {array[0][0], array[1][0], array[2][0], array[3][0]}; + float box_y[4] = {array[0][1], array[1][1], array[2][1], array[3][1]}; + + int xmin = clamp(int(std::floorf(*(std::min_element(box_x, box_x + 4)))), 0, + width - 1); + int xmax = clamp(int(std::ceilf(*(std::max_element(box_x, box_x + 4)))), 0, + width - 1); + int ymin = clamp(int(std::floorf(*(std::min_element(box_y, box_y + 4)))), 0, + height - 1); + int ymax = clamp(int(std::ceilf(*(std::max_element(box_y, box_y + 4)))), 0, + height - 1); + + cv::Mat mask; + mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1); + + cv::Point root_point[4]; + root_point[0] = cv::Point(int(array[0][0]) - xmin, int(array[0][1]) - ymin); + root_point[1] = cv::Point(int(array[1][0]) - xmin, int(array[1][1]) - ymin); + root_point[2] = cv::Point(int(array[2][0]) - xmin, int(array[2][1]) - ymin); + root_point[3] = cv::Point(int(array[3][0]) - xmin, int(array[3][1]) - ymin); + const cv::Point *ppt[1] = {root_point}; + int npt[] = {4}; + cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1)); + + cv::Mat croppedImg; + pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1)) + .copyTo(croppedImg); + + auto score = cv::mean(croppedImg, mask)[0]; + return score; +} + +std::vector>> +boxes_from_bitmap(const cv::Mat &pred, const cv::Mat &bitmap) { + const int min_size = 3; + const int max_candidates = 1000; + const float box_thresh = 0.5; + + int width = bitmap.cols; + int height = bitmap.rows; + + std::vector> contours; + std::vector hierarchy; + + cv::findContours(bitmap, contours, hierarchy, cv::RETR_LIST, + cv::CHAIN_APPROX_SIMPLE); + + int num_contours = + contours.size() >= max_candidates ? 
max_candidates : contours.size(); + + std::vector>> boxes; + + for (int _i = 0; _i < num_contours; _i++) { + float ssid; + cv::RotatedRect box = cv::minAreaRect(contours[_i]); + auto array = get_mini_boxes(box, ssid); + + auto box_for_unclip = array; + // end get_mini_box + + if (ssid < min_size) { + continue; + } + + float score; + score = box_score_fast(array, pred); + // end box_score_fast + if (score < box_thresh) { + continue; + } + + // start for unclip + cv::RotatedRect points = unclip(box_for_unclip); + // end for unclip + + cv::RotatedRect clipbox = points; + auto cliparray = get_mini_boxes(clipbox, ssid); + + if (ssid < min_size + 2) + continue; + + int dest_width = pred.cols; + int dest_height = pred.rows; + std::vector> intcliparray; + + for (int num_pt = 0; num_pt < 4; num_pt++) { + std::vector a{int(clampf(roundf(cliparray[num_pt][0] / float(width) * + float(dest_width)), + 0, float(dest_width))), + int(clampf(roundf(cliparray[num_pt][1] / + float(height) * float(dest_height)), + 0, float(dest_height)))}; + intcliparray.emplace_back(std::move(a)); + } + boxes.emplace_back(std::move(intcliparray)); + + } // end for + return boxes; +} + +int _max(int a, int b) { return a >= b ? a : b; } + +int _min(int a, int b) { return a >= b ? b : a; } + +std::vector>> +filter_tag_det_res(const std::vector>> &o_boxes, + float ratio_h, float ratio_w, const cv::Mat &srcimg) { + int oriimg_h = srcimg.rows; + int oriimg_w = srcimg.cols; + std::vector>> boxes{o_boxes}; + std::vector>> root_points; + for (int n = 0; n < boxes.size(); n++) { + boxes[n] = order_points_clockwise(boxes[n]); + for (int m = 0; m < boxes[0].size(); m++) { + boxes[n][m][0] /= ratio_w; + boxes[n][m][1] /= ratio_h; + + boxes[n][m][0] = int(_min(_max(boxes[n][m][0], 0), oriimg_w - 1)); + boxes[n][m][1] = int(_min(_max(boxes[n][m][1], 0), oriimg_h - 1)); + } + } + + for (int n = 0; n < boxes.size(); n++) { + int rect_width, rect_height; + rect_width = int(sqrt(pow(boxes[n][0][0] - boxes[n][1][0], 2) + + pow(boxes[n][0][1] - boxes[n][1][1], 2))); + rect_height = int(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) + + pow(boxes[n][0][1] - boxes[n][3][1], 2))); + if (rect_width <= 10 || rect_height <= 10) + continue; + root_points.push_back(boxes[n]); + } + return root_points; +} \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/ocr_db_post_process.h b/deploy/android_demo/app/src/main/cpp/ocr_db_post_process.h new file mode 100644 index 0000000..327da36 --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ocr_db_post_process.h @@ -0,0 +1,13 @@ +// +// Created by fujiayi on 2020/7/2. +// +#pragma once +#include +#include + +std::vector>> +boxes_from_bitmap(const cv::Mat &pred, const cv::Mat &bitmap); + +std::vector>> +filter_tag_det_res(const std::vector>> &o_boxes, + float ratio_h, float ratio_w, const cv::Mat &srcimg); \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp new file mode 100644 index 0000000..1bd989c --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp @@ -0,0 +1,350 @@ +// +// Created by fujiayi on 2020/7/1. 
+// + +#include "ocr_ppredictor.h" +#include "common.h" +#include "ocr_cls_process.h" +#include "ocr_crnn_process.h" +#include "ocr_db_post_process.h" +#include "preprocess.h" + +namespace ppredictor { + +OCR_PPredictor::OCR_PPredictor(const OCR_Config &config) : _config(config) {} + +int OCR_PPredictor::init(const std::string &det_model_content, + const std::string &rec_model_content, + const std::string &cls_model_content) { + _det_predictor = std::unique_ptr( + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR, _config.mode}); + _det_predictor->init_nb(det_model_content); + + _rec_predictor = std::unique_ptr( + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + _rec_predictor->init_nb(rec_model_content); + + _cls_predictor = std::unique_ptr( + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + _cls_predictor->init_nb(cls_model_content); + return RETURN_OK; +} + +int OCR_PPredictor::init_from_file(const std::string &det_model_path, + const std::string &rec_model_path, + const std::string &cls_model_path) { + _det_predictor = std::unique_ptr( + new PPredictor{_config.use_opencl, _config.thread_num, NET_OCR, _config.mode}); + _det_predictor->init_from_file(det_model_path); + + + _rec_predictor = std::unique_ptr( + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + _rec_predictor->init_from_file(rec_model_path); + + _cls_predictor = std::unique_ptr( + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + _cls_predictor->init_from_file(cls_model_path); + return RETURN_OK; +} +/** + * for debug use, show result of First Step + * @param filter_boxes + * @param boxes + * @param srcimg + */ +static void +visual_img(const std::vector>> &filter_boxes, + const std::vector>> &boxes, + const cv::Mat &srcimg) { + // visualization + cv::Point rook_points[filter_boxes.size()][4]; + for (int n = 0; n < filter_boxes.size(); n++) { + for (int m = 0; m < filter_boxes[0].size(); m++) { + rook_points[n][m] = + cv::Point(int(filter_boxes[n][m][0]), int(filter_boxes[n][m][1])); + } + } + + cv::Mat img_vis; + srcimg.copyTo(img_vis); + for (int n = 0; n < boxes.size(); n++) { + const cv::Point *ppt[1] = {rook_points[n]}; + int npt[] = {4}; + cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0); + } + // 调试用,自行替换需要修改的路径 + cv::imwrite("/sdcard/1/vis.png", img_vis); +} + +std::vector +OCR_PPredictor::infer_ocr(cv::Mat &origin,int max_size_len, int run_det, int run_cls, int run_rec) { + LOGI("ocr cpp start *****************"); + LOGI("ocr cpp det: %d, cls: %d, rec: %d", run_det, run_cls, run_rec); + std::vector ocr_results; + if(run_det){ + infer_det(origin, max_size_len, ocr_results); + } + if(run_rec){ + if(ocr_results.size()==0){ + OCRPredictResult res; + ocr_results.emplace_back(std::move(res)); + } + for(int i = 0; i < ocr_results.size();i++) { + infer_rec(origin, run_cls, ocr_results[i]); + } + }else if(run_cls){ + ClsPredictResult cls_res = infer_cls(origin); + OCRPredictResult res; + res.cls_score = cls_res.cls_score; + res.cls_label = cls_res.cls_label; + ocr_results.push_back(res); + } + + LOGI("ocr cpp end *****************"); + return ocr_results; +} + +cv::Mat DetResizeImg(const cv::Mat img, int max_size_len, + std::vector &ratio_hw) { + int w = img.cols; + int h = img.rows; + + float ratio = 1.f; + int max_wh = w >= h ? 
      w : h;
+  if (max_wh > max_size_len) {
+    if (h > w) {
+      ratio = static_cast<float>(max_size_len) / static_cast<float>(h);
+    } else {
+      ratio = static_cast<float>(max_size_len) / static_cast<float>(w);
+    }
+  }
+
+  int resize_h = static_cast<int>(float(h) * ratio);
+  int resize_w = static_cast<int>(float(w) * ratio);
+  if (resize_h % 32 == 0)
+    resize_h = resize_h;
+  else if (resize_h / 32 < 1 + 1e-5)
+    resize_h = 32;
+  else
+    resize_h = (resize_h / 32 - 1) * 32;
+
+  if (resize_w % 32 == 0)
+    resize_w = resize_w;
+  else if (resize_w / 32 < 1 + 1e-5)
+    resize_w = 32;
+  else
+    resize_w = (resize_w / 32 - 1) * 32;
+
+  cv::Mat resize_img;
+  cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
+
+  ratio_hw.push_back(static_cast<float>(resize_h) / static_cast<float>(h));
+  ratio_hw.push_back(static_cast<float>(resize_w) / static_cast<float>(w));
+  return resize_img;
+}
+
+void OCR_PPredictor::infer_det(cv::Mat &origin, int max_size_len, std::vector<OCRPredictResult> &ocr_results) {
+  std::vector<float> mean = {0.485f, 0.456f, 0.406f};
+  std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
+
+  PredictorInput input = _det_predictor->get_first_input();
+
+  std::vector<float> ratio_hw;
+  cv::Mat input_image = DetResizeImg(origin, max_size_len, ratio_hw);
+  input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
+  const float *dimg = reinterpret_cast<const float *>(input_image.data);
+  int input_size = input_image.rows * input_image.cols;
+
+  input.set_dims({1, 3, input_image.rows, input_image.cols});
+
+  neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
+                  scale);
+  LOGI("ocr cpp det shape %d,%d", input_image.rows, input_image.cols);
+  std::vector<PredictorOutput> results = _det_predictor->infer();
+  PredictorOutput &res = results.at(0);
+  std::vector<std::vector<std::vector<int>>> filtered_box = calc_filtered_boxes(
+      res.get_float_data(), res.get_size(), input_image.rows, input_image.cols, origin);
+  LOGI("ocr cpp det Filter_box size %ld", filtered_box.size());
+
+  for (int i = 0; i < filtered_box.size(); i++) {
+    OCRPredictResult det_res;
+    det_res.points = filtered_box[i];
+    ocr_results.push_back(det_res);
+  }
+}
+
+void OCR_PPredictor::infer_rec(const cv::Mat &origin_img, int run_cls, OCRPredictResult &ocr_result) {
+  std::vector<float> mean = {0.5f, 0.5f, 0.5f};
+  std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
+  std::vector<int64_t> dims = {1, 3, 0, 0};
+
+  PredictorInput input = _rec_predictor->get_first_input();
+
+  const std::vector<std::vector<int>> &box = ocr_result.points;
+  cv::Mat crop_img;
+  if (box.size() > 0) {
+    crop_img = get_rotate_crop_image(origin_img, box);
+  } else {
+    crop_img = origin_img;
+  }
+
+  if (run_cls) {
+    ClsPredictResult cls_res = infer_cls(crop_img);
+    crop_img = cls_res.img;
+    ocr_result.cls_score = cls_res.cls_score;
+    ocr_result.cls_label = cls_res.cls_label;
+  }
+
+  float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
+  cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio);
+  input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
+  const float *dimg = reinterpret_cast<const float *>(input_image.data);
+  int input_size = input_image.rows * input_image.cols;
+
+  dims[2] = input_image.rows;
+  dims[3] = input_image.cols;
+  input.set_dims(dims);
+
+  neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
+                  scale);
+
+  std::vector<PredictorOutput> results = _rec_predictor->infer();
+  const float *predict_batch = results.at(0).get_float_data();
+  const std::vector<int64_t> predict_shape = results.at(0).get_shape();
+
+  // ctc decode
+  int argmax_idx;
+  int last_index = 0;
+  float score = 0.f;
+  int count = 0;
+  float max_value = 0.0f;
+
+  for (int n = 0; n < predict_shape[1]; n++) {
+    argmax_idx = int(argmax(&predict_batch[n * predict_shape[2]],
+                            &predict_batch[(n + 1) * predict_shape[2]]));
+    max_value =
+        float(*std::max_element(&predict_batch[n * predict_shape[2]],
+                                &predict_batch[(n + 1) * predict_shape[2]]));
+    if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
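+      // Greedy CTC decoding: index 0 is the CTC blank label; a character is
+      // emitted only when the argmax index is non-blank and differs from the
+      // index seen at the previous time step (repeats are collapsed).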
+      score += max_value;
+      count += 1;
+      ocr_result.word_index.push_back(argmax_idx);
+    }
+    last_index = argmax_idx;
+  }
+  score /= count;
+  ocr_result.score = score;
+  LOGI("ocr cpp rec word size %ld", count);
+}
+
+ClsPredictResult OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
+  std::vector<float> mean = {0.5f, 0.5f, 0.5f};
+  std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
+  std::vector<int64_t> dims = {1, 3, 0, 0};
+
+  PredictorInput input = _cls_predictor->get_first_input();
+
+  cv::Mat input_image = cls_resize_img(img);
+  input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
+  const float *dimg = reinterpret_cast<const float *>(input_image.data);
+  int input_size = input_image.rows * input_image.cols;
+
+  dims[2] = input_image.rows;
+  dims[3] = input_image.cols;
+  input.set_dims(dims);
+
+  neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
+                  scale);
+
+  std::vector<PredictorOutput> results = _cls_predictor->infer();
+
+  const float *scores = results.at(0).get_float_data();
+  float score = 0;
+  int label = 0;
+  for (int64_t i = 0; i < results.at(0).get_size(); i++) {
+    LOGI("ocr cpp cls output scores [%f]", scores[i]);
+    if (scores[i] > score) {
+      score = scores[i];
+      label = i;
+    }
+  }
+  cv::Mat srcimg;
+  img.copyTo(srcimg);
+  if (label % 2 == 1 && score > thresh) {
+    cv::rotate(srcimg, srcimg, 1);
+  }
+  ClsPredictResult res;
+  res.cls_label = label;
+  res.cls_score = score;
+  res.img = srcimg;
+  LOGI("ocr cpp cls word cls %ld, %f", label, score);
+  return res;
+}
+
+std::vector<std::vector<std::vector<int>>>
+OCR_PPredictor::calc_filtered_boxes(const float *pred, int pred_size,
+                                    int output_height, int output_width,
+                                    const cv::Mat &origin) {
+  const double threshold = 0.3;
+  const double maxvalue = 1;
+
+  cv::Mat pred_map = cv::Mat::zeros(output_height, output_width, CV_32F);
+  memcpy(pred_map.data, pred, pred_size * sizeof(float));
+  cv::Mat cbuf_map;
+  pred_map.convertTo(cbuf_map, CV_8UC1);
+
+  cv::Mat bit_map;
+  cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
+
+  std::vector<std::vector<std::vector<int>>> boxes =
+      boxes_from_bitmap(pred_map, bit_map);
+  float ratio_h = output_height * 1.0f / origin.rows;
+  float ratio_w = output_width * 1.0f / origin.cols;
+  std::vector<std::vector<std::vector<int>>> filter_boxes =
+      filter_tag_det_res(boxes, ratio_h, ratio_w, origin);
+  return filter_boxes;
+}
+
+std::vector<int>
+OCR_PPredictor::postprocess_rec_word_index(const PredictorOutput &res) {
+  const int *rec_idx = res.get_int_data();
+  const std::vector<std::vector<uint64_t>> rec_idx_lod = res.get_lod();
+
+  std::vector<int> pred_idx;
+  for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1] * 2); n += 2) {
+    pred_idx.emplace_back(rec_idx[n]);
+  }
+  return pred_idx;
+}
+
+float OCR_PPredictor::postprocess_rec_score(const PredictorOutput &res) {
+  const float *predict_batch = res.get_float_data();
+  const std::vector<int64_t> predict_shape = res.get_shape();
+  const std::vector<std::vector<uint64_t>> predict_lod = res.get_lod();
+  int blank = predict_shape[1];
+  float score = 0.f;
+  int count = 0;
+  for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) {
+    int argmax_idx = argmax(predict_batch + n * predict_shape[1],
+                            predict_batch + (n + 1) * predict_shape[1]);
+    float max_value = predict_batch[n * predict_shape[1] + argmax_idx];
+    if (blank - 1 - argmax_idx > 1e-5) {
+      score += max_value;
+      count += 1;
+    }
+  }
+  if (count == 0) {
+    LOGE("calc score count 0");
+  } else {
+    score /= count;
+  }
+  LOGI("calc score: %f", score);
+  return score;
+}
+
+NET_TYPE OCR_PPredictor::get_net_flag() const { return NET_OCR; }
+}
\ No newline at end of file
diff --git
a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h
new file mode 100644
index 0000000..f0bff93
--- /dev/null
+++ b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h
@@ -0,0 +1,130 @@
+//
+// Created by fujiayi on 2020/7/1.
+//
+
+#pragma once
+
+#include "ppredictor.h"
+#include <opencv2/opencv.hpp>
+#include <string>
+#include <vector>
+
+namespace ppredictor {
+
+/**
+ * Config
+ */
+struct OCR_Config {
+  int use_opencl = 0;
+  int thread_num = 4; // Thread num
+  paddle::lite_api::PowerMode mode =
+      paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode
+};
+
+/**
+ * Polygon result for one detected text region
+ */
+struct OCRPredictResult {
+  std::vector<int> word_index;
+  std::vector<std::vector<int>> points;
+  float score;
+  float cls_score;
+  int cls_label = -1;
+};
+
+struct ClsPredictResult {
+  float cls_score;
+  int cls_label = -1;
+  cv::Mat img;
+};
+/**
+ * OCR pipeline built from two models:
+ * 1. The detection model (det) outputs polygons marking where the text is.
+ * 2. Those polygons are cropped from the original image and passed to the
+ *    recognition model for inference.
+ */
+class OCR_PPredictor : public PPredictor_Interface {
+public:
+  OCR_PPredictor(const OCR_Config &config);
+
+  virtual ~OCR_PPredictor() {}
+
+  /**
+   * Initialize the predictors for the detection, recognition and
+   * classification models.
+   * @param det_model_content
+   * @param rec_model_content
+   * @return
+   */
+  int init(const std::string &det_model_content,
+           const std::string &rec_model_content,
+           const std::string &cls_model_content);
+  int init_from_file(const std::string &det_model_path,
+                     const std::string &rec_model_path,
+                     const std::string &cls_model_path);
+  /**
+   * Return OCR result
+   * @param dims
+   * @param input_data
+   * @param input_len
+   * @param net_flag
+   * @param origin
+   * @return
+   */
+  virtual std::vector<OCRPredictResult>
+  infer_ocr(cv::Mat &origin, int max_size_len, int run_det, int run_cls, int run_rec);
+
+  virtual NET_TYPE get_net_flag() const;
+
+private:
+  /**
+   * Compute polygons from the output map of the detection (first) model.
+   * @param pred
+   * @param output_height
+   * @param output_width
+   * @param origin
+   * @return
+   */
+  std::vector<std::vector<std::vector<int>>>
+  calc_filtered_boxes(const float *pred, int pred_size, int output_height,
+                      int output_width, const cv::Mat &origin);
+
+  void
+  infer_det(cv::Mat &origin, int max_side_len, std::vector<OCRPredictResult> &ocr_results);
+  /**
+   * Run inference with the recognition (rec) model.
+   *
+   * @param boxes
+   * @param origin
+   * @return
+   */
+  void
+  infer_rec(const cv::Mat &origin, int run_cls, OCRPredictResult &ocr_result);
+
+  /**
+   * Run inference with the angle classification (cls) model.
+   *
+   * @param boxes
+   * @param origin
+   * @return
+   */
+  ClsPredictResult infer_cls(const cv::Mat &origin, float thresh = 0.9);
+
+  /**
+   * Postprocess the output of the second (recognition) model to extract the
+   * text indices.
+   * @param res
+   * @return
+   */
+  std::vector<int> postprocess_rec_word_index(const PredictorOutput &res);
+
+  /**
+   * Calculate the confidence of the recognition result.
+   * @param res
+   * @return
+   */
+  float postprocess_rec_score(const PredictorOutput &res);
+
+  std::unique_ptr<PPredictor> _det_predictor;
+  std::unique_ptr<PPredictor> _rec_predictor;
+  std::unique_ptr<PPredictor> _cls_predictor;
+  OCR_Config _config;
+};
+}
diff --git a/deploy/android_demo/app/src/main/cpp/ppredictor.cpp b/deploy/android_demo/app/src/main/cpp/ppredictor.cpp
new file mode 100644
index 0000000..a40fe5e
--- /dev/null
+++ b/deploy/android_demo/app/src/main/cpp/ppredictor.cpp
@@ -0,0 +1,95 @@
+#include "ppredictor.h"
+#include "common.h"
+
+namespace ppredictor {
+PPredictor::PPredictor(int use_opencl, int thread_num, int net_flag,
+                       paddle::lite_api::PowerMode mode)
+    : _use_opencl(use_opencl), _thread_num(thread_num), _net_flag(net_flag), _mode(mode) {}
+
+int PPredictor::init_nb(const std::string
&model_content) { + paddle::lite_api::MobileConfig config; + config.set_model_from_buffer(model_content); + return _init(config); +} + +int PPredictor::init_from_file(const std::string &model_content) { + paddle::lite_api::MobileConfig config; + config.set_model_from_file(model_content); + return _init(config); +} + +template int PPredictor::_init(ConfigT &config) { + bool is_opencl_backend_valid = paddle::lite_api::IsOpenCLBackendValid(/*check_fp16_valid = false*/); + if (is_opencl_backend_valid) { + if (_use_opencl != 0) { + // Make sure you have write permission of the binary path. + // We strongly recommend each model has a unique binary name. + const std::string bin_path = "/data/local/tmp/"; + const std::string bin_name = "lite_opencl_kernel.bin"; + config.set_opencl_binary_path_name(bin_path, bin_name); + + // opencl tune option + // CL_TUNE_NONE: 0 + // CL_TUNE_RAPID: 1 + // CL_TUNE_NORMAL: 2 + // CL_TUNE_EXHAUSTIVE: 3 + const std::string tuned_path = "/data/local/tmp/"; + const std::string tuned_name = "lite_opencl_tuned.bin"; + config.set_opencl_tune(paddle::lite_api::CL_TUNE_NORMAL, tuned_path, tuned_name); + + // opencl precision option + // CL_PRECISION_AUTO: 0, first fp16 if valid, default + // CL_PRECISION_FP32: 1, force fp32 + // CL_PRECISION_FP16: 2, force fp16 + config.set_opencl_precision(paddle::lite_api::CL_PRECISION_FP32); + LOGI("ocr cpp device: running on gpu."); + } + } else { + LOGI("ocr cpp device: running on cpu."); + // you can give backup cpu nb model instead + // config.set_model_from_file(cpu_nb_model_dir); + } + config.set_threads(_thread_num); + config.set_power_mode(_mode); + _predictor = paddle::lite_api::CreatePaddlePredictor(config); + LOGI("ocr cpp paddle instance created"); + return RETURN_OK; +} + +PredictorInput PPredictor::get_input(int index) { + PredictorInput input{_predictor->GetInput(index), index, _net_flag}; + _is_input_get = true; + return input; +} + +std::vector PPredictor::get_inputs(int num) { + std::vector results; + for (int i = 0; i < num; i++) { + results.emplace_back(get_input(i)); + } + return results; +} + +PredictorInput PPredictor::get_first_input() { return get_input(0); } + +std::vector PPredictor::infer() { + LOGI("ocr cpp infer Run start %d", _net_flag); + std::vector results; + if (!_is_input_get) { + return results; + } + _predictor->Run(); + LOGI("ocr cpp infer Run end"); + + for (int i = 0; i < _predictor->GetOutputNames().size(); i++) { + std::unique_ptr output_tensor = + _predictor->GetOutput(i); + LOGI("ocr cpp output tensor[%d] size %ld", i, product(output_tensor->shape())); + PredictorOutput result{std::move(output_tensor), i, _net_flag}; + results.emplace_back(std::move(result)); + } + return results; +} + +NET_TYPE PPredictor::get_net_flag() const { return (NET_TYPE)_net_flag; } +} \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/ppredictor.h b/deploy/android_demo/app/src/main/cpp/ppredictor.h new file mode 100644 index 0000000..4025076 --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ppredictor.h @@ -0,0 +1,64 @@ +#pragma once + +#include "paddle_api.h" +#include "predictor_input.h" +#include "predictor_output.h" + +namespace ppredictor { + +/** + * PaddleLite Preditor Common Interface + */ +class PPredictor_Interface { +public: + virtual ~PPredictor_Interface() {} + + virtual NET_TYPE get_net_flag() const = 0; +}; + +/** + * Common Predictor + */ +class PPredictor : public PPredictor_Interface { +public: + PPredictor( + int use_opencl, int thread_num, int net_flag = 0, + 
paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH); + + virtual ~PPredictor() {} + + /** + * init paddlitelite opt model,nb format ,or use ini_paddle + * @param model_content + * @return 0 + */ + virtual int init_nb(const std::string &model_content); + + virtual int init_from_file(const std::string &model_content); + + std::vector infer(); + + std::shared_ptr get_predictor() { + return _predictor; + } + + virtual std::vector get_inputs(int num); + + virtual PredictorInput get_input(int index); + + virtual PredictorInput get_first_input(); + + virtual NET_TYPE get_net_flag() const; + +protected: + template int _init(ConfigT &config); + +private: + int _use_opencl; + int _thread_num; + paddle::lite_api::PowerMode _mode; + std::shared_ptr _predictor; + bool _is_input_get = false; + int _net_flag; +}; +} diff --git a/deploy/android_demo/app/src/main/cpp/predictor_input.cpp b/deploy/android_demo/app/src/main/cpp/predictor_input.cpp new file mode 100644 index 0000000..f0b4bf8 --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/predictor_input.cpp @@ -0,0 +1,28 @@ +#include "predictor_input.h" + +namespace ppredictor { + +void PredictorInput::set_dims(std::vector dims) { + // yolov3 + if (_net_flag == 101 && _index == 1) { + _tensor->Resize({1, 2}); + _tensor->mutable_data()[0] = (int)dims.at(2); + _tensor->mutable_data()[1] = (int)dims.at(3); + } else { + _tensor->Resize(dims); + } + _is_dims_set = true; +} + +float *PredictorInput::get_mutable_float_data() { + if (!_is_dims_set) { + LOGE("PredictorInput::set_dims is not called"); + } + return _tensor->mutable_data(); +} + +void PredictorInput::set_data(const float *input_data, int input_float_len) { + float *input_raw_data = get_mutable_float_data(); + memcpy(input_raw_data, input_data, input_float_len * sizeof(float)); +} +} \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/predictor_input.h b/deploy/android_demo/app/src/main/cpp/predictor_input.h new file mode 100644 index 0000000..f3fd6cf --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/predictor_input.h @@ -0,0 +1,26 @@ +#pragma once + +#include "common.h" +#include +#include + +namespace ppredictor { +class PredictorInput { +public: + PredictorInput(std::unique_ptr &&tensor, int index, + int net_flag) + : _tensor(std::move(tensor)), _index(index), _net_flag(net_flag) {} + + void set_dims(std::vector dims); + + float *get_mutable_float_data(); + + void set_data(const float *input_data, int input_float_len); + +private: + std::unique_ptr _tensor; + bool _is_dims_set = false; + int _index; + int _net_flag; +}; +} diff --git a/deploy/android_demo/app/src/main/cpp/predictor_output.cpp b/deploy/android_demo/app/src/main/cpp/predictor_output.cpp new file mode 100644 index 0000000..e9cfdbc --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/predictor_output.cpp @@ -0,0 +1,26 @@ +#include "predictor_output.h" +namespace ppredictor { +const float *PredictorOutput::get_float_data() const { + return _tensor->data(); +} + +const int *PredictorOutput::get_int_data() const { + return _tensor->data(); +} + +const std::vector> PredictorOutput::get_lod() const { + return _tensor->lod(); +} + +int64_t PredictorOutput::get_size() const { + if (_net_flag == NET_OCR) { + return _tensor->shape().at(2) * _tensor->shape().at(3); + } else { + return product(_tensor->shape()); + } +} + +const std::vector PredictorOutput::get_shape() const { + return _tensor->shape(); +} +} \ No newline at end of file diff --git 
a/deploy/android_demo/app/src/main/cpp/predictor_output.h b/deploy/android_demo/app/src/main/cpp/predictor_output.h new file mode 100644 index 0000000..8e8c9ba --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/predictor_output.h @@ -0,0 +1,31 @@ +#pragma once + +#include "common.h" +#include +#include + +namespace ppredictor { +class PredictorOutput { +public: + PredictorOutput() {} + PredictorOutput(std::unique_ptr &&tensor, + int index, int net_flag) + : _tensor(std::move(tensor)), _index(index), _net_flag(net_flag) {} + + const float *get_float_data() const; + const int *get_int_data() const; + int64_t get_size() const; + const std::vector> get_lod() const; + const std::vector get_shape() const; + + std::vector data; // return float, or use data_int + std::vector data_int; // several layers return int ,or use data + std::vector shape; // PaddleLite output shape + std::vector> lod; // PaddleLite output lod + +private: + std::unique_ptr _tensor; + int _index; + int _net_flag; +}; +} diff --git a/deploy/android_demo/app/src/main/cpp/preprocess.cpp b/deploy/android_demo/app/src/main/cpp/preprocess.cpp new file mode 100644 index 0000000..e99b2cd --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/preprocess.cpp @@ -0,0 +1,82 @@ +#include "preprocess.h" +#include + +cv::Mat bitmap_to_cv_mat(JNIEnv *env, jobject bitmap) { + AndroidBitmapInfo info; + int result = AndroidBitmap_getInfo(env, bitmap, &info); + if (result != ANDROID_BITMAP_RESULT_SUCCESS) { + LOGE("AndroidBitmap_getInfo failed, result: %d", result); + return cv::Mat{}; + } + if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) { + LOGE("Bitmap format is not RGBA_8888 !"); + return cv::Mat{}; + } + unsigned char *srcData = NULL; + AndroidBitmap_lockPixels(env, bitmap, (void **)&srcData); + cv::Mat mat = cv::Mat::zeros(info.height, info.width, CV_8UC4); + memcpy(mat.data, srcData, info.height * info.width * 4); + AndroidBitmap_unlockPixels(env, bitmap); + cv::cvtColor(mat, mat, cv::COLOR_RGBA2BGR); + /** + if (!cv::imwrite("/sdcard/1/copy.jpg", mat)){ + LOGE("Write image failed " ); + } + */ + + return mat; +} + +cv::Mat resize_img(const cv::Mat &img, int height, int width) { + if (img.rows == height && img.cols == width) { + return img; + } + cv::Mat new_img; + cv::resize(img, new_img, cv::Size(height, width)); + return new_img; +} + +// fill tensor with mean and scale and trans layout: nhwc -> nchw, neon speed up +void neon_mean_scale(const float *din, float *dout, int size, + const std::vector &mean, + const std::vector &scale) { + if (mean.size() != 3 || scale.size() != 3) { + LOGE("[ERROR] mean or scale size must equal to 3"); + return; + } + + float32x4_t vmean0 = vdupq_n_f32(mean[0]); + float32x4_t vmean1 = vdupq_n_f32(mean[1]); + float32x4_t vmean2 = vdupq_n_f32(mean[2]); + float32x4_t vscale0 = vdupq_n_f32(scale[0]); + float32x4_t vscale1 = vdupq_n_f32(scale[1]); + float32x4_t vscale2 = vdupq_n_f32(scale[2]); + + float *dout_c0 = dout; + float *dout_c1 = dout + size; + float *dout_c2 = dout + size * 2; + + int i = 0; + for (; i < size - 3; i += 4) { + float32x4x3_t vin3 = vld3q_f32(din); + float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0); + float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1); + float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2); + float32x4_t vs0 = vmulq_f32(vsub0, vscale0); + float32x4_t vs1 = vmulq_f32(vsub1, vscale1); + float32x4_t vs2 = vmulq_f32(vsub2, vscale2); + vst1q_f32(dout_c0, vs0); + vst1q_f32(dout_c1, vs1); + vst1q_f32(dout_c2, vs2); + + din += 12; + dout_c0 += 4; + dout_c1 += 4; + 
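+    // One iteration consumes 4 interleaved RGB pixels (12 floats) and advances
+    // each planar (NCHW) channel pointer by 4 floats; the scalar loop below
+    // handles the remaining size % 4 pixels.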
dout_c2 += 4; + } + for (; i < size; i++) { + *(dout_c0++) = (*(din++) - mean[0]) * scale[0]; + *(dout_c1++) = (*(din++) - mean[1]) * scale[1]; + *(dout_c2++) = (*(din++) - mean[2]) * scale[2]; + } +} \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/preprocess.h b/deploy/android_demo/app/src/main/cpp/preprocess.h new file mode 100644 index 0000000..7909152 --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/preprocess.h @@ -0,0 +1,12 @@ +#pragma once + +#include "common.h" +#include +#include +cv::Mat bitmap_to_cv_mat(JNIEnv *env, jobject bitmap); + +cv::Mat resize_img(const cv::Mat &img, int height, int width); + +void neon_mean_scale(const float *din, float *dout, int size, + const std::vector &mean, + const std::vector &scale); diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/AppCompatPreferenceActivity.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/AppCompatPreferenceActivity.java new file mode 100644 index 0000000..49af0af --- /dev/null +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/AppCompatPreferenceActivity.java @@ -0,0 +1,128 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.baidu.paddle.lite.demo.ocr; + +import android.content.res.Configuration; +import android.os.Bundle; +import android.preference.PreferenceActivity; +import android.view.MenuInflater; +import android.view.View; +import android.view.ViewGroup; + +import androidx.annotation.LayoutRes; +import androidx.annotation.Nullable; +import androidx.appcompat.app.ActionBar; +import androidx.appcompat.app.AppCompatDelegate; +import androidx.appcompat.widget.Toolbar; + +/** + * A {@link PreferenceActivity} which implements and proxies the necessary calls + * to be used with AppCompat. + *

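+ * Every lifecycle and content-view override below simply forwards to an
+ * {@link AppCompatDelegate} that is created lazily in {@code getDelegate()}.
+ *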
+ * This technique can be used with an {@link android.app.Activity} class, not just + * {@link PreferenceActivity}. + */ +public abstract class AppCompatPreferenceActivity extends PreferenceActivity { + private AppCompatDelegate mDelegate; + + @Override + protected void onCreate(Bundle savedInstanceState) { + getDelegate().installViewFactory(); + getDelegate().onCreate(savedInstanceState); + super.onCreate(savedInstanceState); + } + + @Override + protected void onPostCreate(Bundle savedInstanceState) { + super.onPostCreate(savedInstanceState); + getDelegate().onPostCreate(savedInstanceState); + } + + public ActionBar getSupportActionBar() { + return getDelegate().getSupportActionBar(); + } + + public void setSupportActionBar(@Nullable Toolbar toolbar) { + getDelegate().setSupportActionBar(toolbar); + } + + @Override + public MenuInflater getMenuInflater() { + return getDelegate().getMenuInflater(); + } + + @Override + public void setContentView(@LayoutRes int layoutResID) { + getDelegate().setContentView(layoutResID); + } + + @Override + public void setContentView(View view) { + getDelegate().setContentView(view); + } + + @Override + public void setContentView(View view, ViewGroup.LayoutParams params) { + getDelegate().setContentView(view, params); + } + + @Override + public void addContentView(View view, ViewGroup.LayoutParams params) { + getDelegate().addContentView(view, params); + } + + @Override + protected void onPostResume() { + super.onPostResume(); + getDelegate().onPostResume(); + } + + @Override + protected void onTitleChanged(CharSequence title, int color) { + super.onTitleChanged(title, color); + getDelegate().setTitle(title); + } + + @Override + public void onConfigurationChanged(Configuration newConfig) { + super.onConfigurationChanged(newConfig); + getDelegate().onConfigurationChanged(newConfig); + } + + @Override + protected void onStop() { + super.onStop(); + getDelegate().onStop(); + } + + @Override + protected void onDestroy() { + super.onDestroy(); + getDelegate().onDestroy(); + } + + public void invalidateOptionsMenu() { + getDelegate().invalidateOptionsMenu(); + } + + private AppCompatDelegate getDelegate() { + if (mDelegate == null) { + mDelegate = AppCompatDelegate.create(this, null); + } + return mDelegate; + } +} diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java new file mode 100644 index 0000000..f932718 --- /dev/null +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java @@ -0,0 +1,520 @@ +package com.baidu.paddle.lite.demo.ocr; + +import android.Manifest; +import android.app.ProgressDialog; +import android.content.ContentResolver; +import android.content.Context; +import android.content.Intent; +import android.content.SharedPreferences; +import android.content.pm.PackageManager; +import android.database.Cursor; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.graphics.drawable.BitmapDrawable; +import android.media.ExifInterface; +import android.content.res.AssetManager; +import android.media.FaceDetector; +import android.net.Uri; +import android.os.Bundle; +import android.os.Environment; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Message; +import android.preference.PreferenceManager; +import android.provider.MediaStore; +import android.text.method.ScrollingMovementMethod; +import android.util.Log; +import 
android.view.Menu; +import android.view.MenuInflater; +import android.view.MenuItem; +import android.view.View; +import android.widget.CheckBox; +import android.widget.ImageView; +import android.widget.Spinner; +import android.widget.TextView; +import android.widget.Toast; + +import androidx.annotation.NonNull; +import androidx.appcompat.app.AppCompatActivity; +import androidx.core.app.ActivityCompat; +import androidx.core.content.ContextCompat; +import androidx.core.content.FileProvider; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.text.SimpleDateFormat; +import java.util.Date; + +public class MainActivity extends AppCompatActivity { + private static final String TAG = MainActivity.class.getSimpleName(); + public static final int OPEN_GALLERY_REQUEST_CODE = 0; + public static final int TAKE_PHOTO_REQUEST_CODE = 1; + + public static final int REQUEST_LOAD_MODEL = 0; + public static final int REQUEST_RUN_MODEL = 1; + public static final int RESPONSE_LOAD_MODEL_SUCCESSED = 0; + public static final int RESPONSE_LOAD_MODEL_FAILED = 1; + public static final int RESPONSE_RUN_MODEL_SUCCESSED = 2; + public static final int RESPONSE_RUN_MODEL_FAILED = 3; + + protected ProgressDialog pbLoadModel = null; + protected ProgressDialog pbRunModel = null; + + protected Handler receiver = null; // Receive messages from worker thread + protected Handler sender = null; // Send command to worker thread + protected HandlerThread worker = null; // Worker thread to load&run model + + // UI components of object detection + protected TextView tvInputSetting; + protected TextView tvStatus; + protected ImageView ivInputImage; + protected TextView tvOutputResult; + protected TextView tvInferenceTime; + protected CheckBox cbOpencl; + protected Spinner spRunMode; + + // Model settings of ocr + protected String modelPath = ""; + protected String labelPath = ""; + protected String imagePath = ""; + protected int cpuThreadNum = 1; + protected String cpuPowerMode = ""; + protected int detLongSize = 960; + protected float scoreThreshold = 0.1f; + private String currentPhotoPath; + private AssetManager assetManager = null; + + protected Predictor predictor = new Predictor(); + + private Bitmap cur_predict_image = null; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + // Clear all setting items to avoid app crashing due to the incorrect settings + SharedPreferences sharedPreferences = PreferenceManager.getDefaultSharedPreferences(this); + SharedPreferences.Editor editor = sharedPreferences.edit(); + editor.clear(); + editor.apply(); + + // Setup the UI components + tvInputSetting = findViewById(R.id.tv_input_setting); + cbOpencl = findViewById(R.id.cb_opencl); + tvStatus = findViewById(R.id.tv_model_img_status); + ivInputImage = findViewById(R.id.iv_input_image); + tvInferenceTime = findViewById(R.id.tv_inference_time); + tvOutputResult = findViewById(R.id.tv_output_result); + spRunMode = findViewById(R.id.sp_run_mode); + tvInputSetting.setMovementMethod(ScrollingMovementMethod.getInstance()); + tvOutputResult.setMovementMethod(ScrollingMovementMethod.getInstance()); + + // Prepare the worker thread for mode loading and inference + receiver = new Handler() { + @Override + public void handleMessage(Message msg) { + switch (msg.what) { + case RESPONSE_LOAD_MODEL_SUCCESSED: + if (pbLoadModel != null && pbLoadModel.isShowing()) { + pbLoadModel.dismiss(); + } + 
onLoadModelSuccessed(); + break; + case RESPONSE_LOAD_MODEL_FAILED: + if (pbLoadModel != null && pbLoadModel.isShowing()) { + pbLoadModel.dismiss(); + } + Toast.makeText(MainActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show(); + onLoadModelFailed(); + break; + case RESPONSE_RUN_MODEL_SUCCESSED: + if (pbRunModel != null && pbRunModel.isShowing()) { + pbRunModel.dismiss(); + } + onRunModelSuccessed(); + break; + case RESPONSE_RUN_MODEL_FAILED: + if (pbRunModel != null && pbRunModel.isShowing()) { + pbRunModel.dismiss(); + } + Toast.makeText(MainActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show(); + onRunModelFailed(); + break; + default: + break; + } + } + }; + + worker = new HandlerThread("Predictor Worker"); + worker.start(); + sender = new Handler(worker.getLooper()) { + public void handleMessage(Message msg) { + switch (msg.what) { + case REQUEST_LOAD_MODEL: + // Load model and reload test image + if (onLoadModel()) { + receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_SUCCESSED); + } else { + receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_FAILED); + } + break; + case REQUEST_RUN_MODEL: + // Run model if model is loaded + if (onRunModel()) { + receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_SUCCESSED); + } else { + receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_FAILED); + } + break; + default: + break; + } + } + }; + } + + @Override + protected void onResume() { + super.onResume(); + SharedPreferences sharedPreferences = PreferenceManager.getDefaultSharedPreferences(this); + boolean settingsChanged = false; + boolean model_settingsChanged = false; + String model_path = sharedPreferences.getString(getString(R.string.MODEL_PATH_KEY), + getString(R.string.MODEL_PATH_DEFAULT)); + String label_path = sharedPreferences.getString(getString(R.string.LABEL_PATH_KEY), + getString(R.string.LABEL_PATH_DEFAULT)); + String image_path = sharedPreferences.getString(getString(R.string.IMAGE_PATH_KEY), + getString(R.string.IMAGE_PATH_DEFAULT)); + model_settingsChanged |= !model_path.equalsIgnoreCase(modelPath); + settingsChanged |= !label_path.equalsIgnoreCase(labelPath); + settingsChanged |= !image_path.equalsIgnoreCase(imagePath); + int cpu_thread_num = Integer.parseInt(sharedPreferences.getString(getString(R.string.CPU_THREAD_NUM_KEY), + getString(R.string.CPU_THREAD_NUM_DEFAULT))); + model_settingsChanged |= cpu_thread_num != cpuThreadNum; + String cpu_power_mode = + sharedPreferences.getString(getString(R.string.CPU_POWER_MODE_KEY), + getString(R.string.CPU_POWER_MODE_DEFAULT)); + model_settingsChanged |= !cpu_power_mode.equalsIgnoreCase(cpuPowerMode); + + int det_long_size = Integer.parseInt(sharedPreferences.getString(getString(R.string.DET_LONG_SIZE_KEY), + getString(R.string.DET_LONG_SIZE_DEFAULT))); + settingsChanged |= det_long_size != detLongSize; + float score_threshold = + Float.parseFloat(sharedPreferences.getString(getString(R.string.SCORE_THRESHOLD_KEY), + getString(R.string.SCORE_THRESHOLD_DEFAULT))); + settingsChanged |= scoreThreshold != score_threshold; + if (settingsChanged) { + labelPath = label_path; + imagePath = image_path; + detLongSize = det_long_size; + scoreThreshold = score_threshold; + set_img(); + } + if (model_settingsChanged) { + modelPath = model_path; + cpuThreadNum = cpu_thread_num; + cpuPowerMode = cpu_power_mode; + // Update UI + tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\nOPENCL: " + cbOpencl.isChecked() + "\nCPU Thread Num: " + cpuThreadNum + "\nCPU Power Mode: " + cpuPowerMode); + 
tvInputSetting.scrollTo(0, 0); + // Reload model if configure has been changed + loadModel(); + } + } + + public void loadModel() { + pbLoadModel = ProgressDialog.show(this, "", "loading model...", false, false); + sender.sendEmptyMessage(REQUEST_LOAD_MODEL); + } + + public void runModel() { + pbRunModel = ProgressDialog.show(this, "", "running model...", false, false); + sender.sendEmptyMessage(REQUEST_RUN_MODEL); + } + + public boolean onLoadModel() { + if (predictor.isLoaded()) { + predictor.releaseModel(); + } + return predictor.init(MainActivity.this, modelPath, labelPath, cbOpencl.isChecked() ? 1 : 0, cpuThreadNum, + cpuPowerMode, + detLongSize, scoreThreshold); + } + + public boolean onRunModel() { + String run_mode = spRunMode.getSelectedItem().toString(); + int run_det = run_mode.contains("检测") ? 1 : 0; + int run_cls = run_mode.contains("分类") ? 1 : 0; + int run_rec = run_mode.contains("识别") ? 1 : 0; + return predictor.isLoaded() && predictor.runModel(run_det, run_cls, run_rec); + } + + public void onLoadModelSuccessed() { + // Load test image from path and run model + tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\nOPENCL: " + cbOpencl.isChecked() + "\nCPU Thread Num: " + cpuThreadNum + "\nCPU Power Mode: " + cpuPowerMode); + tvInputSetting.scrollTo(0, 0); + tvStatus.setText("STATUS: load model successed"); + + } + + public void onLoadModelFailed() { + tvStatus.setText("STATUS: load model failed"); + } + + public void onRunModelSuccessed() { + tvStatus.setText("STATUS: run model successed"); + // Obtain results and update UI + tvInferenceTime.setText("Inference time: " + predictor.inferenceTime() + " ms"); + Bitmap outputImage = predictor.outputImage(); + if (outputImage != null) { + ivInputImage.setImageBitmap(outputImage); + } + tvOutputResult.setText(predictor.outputResult()); + tvOutputResult.scrollTo(0, 0); + } + + public void onRunModelFailed() { + tvStatus.setText("STATUS: run model failed"); + } + + public void set_img() { + // Load test image from path and run model + try { + assetManager = getAssets(); + InputStream in = assetManager.open(imagePath); + Bitmap bmp = BitmapFactory.decodeStream(in); + cur_predict_image = bmp; + ivInputImage.setImageBitmap(bmp); + } catch (IOException e) { + Toast.makeText(MainActivity.this, "Load image failed!", Toast.LENGTH_SHORT).show(); + e.printStackTrace(); + } + } + + public void onSettingsClicked() { + startActivity(new Intent(MainActivity.this, SettingsActivity.class)); + } + + @Override + public boolean onCreateOptionsMenu(Menu menu) { + MenuInflater inflater = getMenuInflater(); + inflater.inflate(R.menu.menu_action_options, menu); + return true; + } + + public boolean onPrepareOptionsMenu(Menu menu) { + boolean isLoaded = predictor.isLoaded(); + return super.onPrepareOptionsMenu(menu); + } + + @Override + public boolean onOptionsItemSelected(MenuItem item) { + switch (item.getItemId()) { + case android.R.id.home: + finish(); + break; + case R.id.settings: + if (requestAllPermissions()) { + // Make sure we have SDCard r&w permissions to load model from SDCard + onSettingsClicked(); + } + break; + } + return super.onOptionsItemSelected(item); + } + + @Override + public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, + @NonNull int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + if (grantResults[0] != PackageManager.PERMISSION_GRANTED || grantResults[1] != PackageManager.PERMISSION_GRANTED) { + 
Toast.makeText(this, "Permission Denied", Toast.LENGTH_SHORT).show(); + } + } + + private boolean requestAllPermissions() { + if (ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE) + != PackageManager.PERMISSION_GRANTED || ContextCompat.checkSelfPermission(this, + Manifest.permission.CAMERA) + != PackageManager.PERMISSION_GRANTED) { + ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.WRITE_EXTERNAL_STORAGE, + Manifest.permission.CAMERA}, + 0); + return false; + } + return true; + } + + private void openGallery() { + Intent intent = new Intent(Intent.ACTION_PICK, null); + intent.setDataAndType(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, "image/*"); + startActivityForResult(intent, OPEN_GALLERY_REQUEST_CODE); + } + + private void takePhoto() { + Intent takePictureIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE); + // Ensure that there's a camera activity to handle the intent + if (takePictureIntent.resolveActivity(getPackageManager()) != null) { + // Create the File where the photo should go + File photoFile = null; + try { + photoFile = createImageFile(); + } catch (IOException ex) { + Log.e("MainActitity", ex.getMessage(), ex); + Toast.makeText(MainActivity.this, + "Create Camera temp file failed: " + ex.getMessage(), Toast.LENGTH_SHORT).show(); + } + // Continue only if the File was successfully created + if (photoFile != null) { + Log.i(TAG, "FILEPATH " + getExternalFilesDir("Pictures").getAbsolutePath()); + Uri photoURI = FileProvider.getUriForFile(this, + "com.baidu.paddle.lite.demo.ocr.fileprovider", + photoFile); + currentPhotoPath = photoFile.getAbsolutePath(); + takePictureIntent.putExtra(MediaStore.EXTRA_OUTPUT, photoURI); + startActivityForResult(takePictureIntent, TAKE_PHOTO_REQUEST_CODE); + Log.i(TAG, "startActivityForResult finished"); + } + } + + } + + private File createImageFile() throws IOException { + // Create an image file name + String timeStamp = new SimpleDateFormat("yyyyMMdd_HHmmss").format(new Date()); + String imageFileName = "JPEG_" + timeStamp + "_"; + File storageDir = getExternalFilesDir(Environment.DIRECTORY_PICTURES); + File image = File.createTempFile( + imageFileName, /* prefix */ + ".bmp", /* suffix */ + storageDir /* directory */ + ); + + return image; + } + + @Override + protected void onActivityResult(int requestCode, int resultCode, Intent data) { + super.onActivityResult(requestCode, resultCode, data); + if (resultCode == RESULT_OK) { + switch (requestCode) { + case OPEN_GALLERY_REQUEST_CODE: + if (data == null) { + break; + } + try { + ContentResolver resolver = getContentResolver(); + Uri uri = data.getData(); + Bitmap image = MediaStore.Images.Media.getBitmap(resolver, uri); + String[] proj = {MediaStore.Images.Media.DATA}; + Cursor cursor = managedQuery(uri, proj, null, null, null); + cursor.moveToFirst(); + if (image != null) { + cur_predict_image = image; + ivInputImage.setImageBitmap(image); + } + } catch (IOException e) { + Log.e(TAG, e.toString()); + } + break; + case TAKE_PHOTO_REQUEST_CODE: + if (currentPhotoPath != null) { + ExifInterface exif = null; + try { + exif = new ExifInterface(currentPhotoPath); + } catch (IOException e) { + e.printStackTrace(); + } + int orientation = exif.getAttributeInt(ExifInterface.TAG_ORIENTATION, + ExifInterface.ORIENTATION_UNDEFINED); + Log.i(TAG, "rotation " + orientation); + Bitmap image = BitmapFactory.decodeFile(currentPhotoPath); + image = Utils.rotateBitmap(image, orientation); + if (image != null) { + cur_predict_image = image; + 
ivInputImage.setImageBitmap(image); + } + } else { + Log.e(TAG, "currentPhotoPath is null"); + } + break; + default: + break; + } + } + } + + public void btn_reset_img_click(View view) { + ivInputImage.setImageBitmap(cur_predict_image); + } + + public void cb_opencl_click(View view) { + tvStatus.setText("STATUS: load model ......"); + loadModel(); + } + + public void btn_run_model_click(View view) { + Bitmap image = ((BitmapDrawable) ivInputImage.getDrawable()).getBitmap(); + if (image == null) { + tvStatus.setText("STATUS: image is not exists"); + } else if (!predictor.isLoaded()) { + tvStatus.setText("STATUS: model is not loaded"); + } else { + tvStatus.setText("STATUS: run model ...... "); + predictor.setInputImage(image); + runModel(); + } + } + + public void btn_choice_img_click(View view) { + if (requestAllPermissions()) { + openGallery(); + } + } + + public void btn_take_photo_click(View view) { + if (requestAllPermissions()) { + takePhoto(); + } + } + + @Override + protected void onDestroy() { + if (predictor != null) { + predictor.releaseModel(); + } + worker.quit(); + super.onDestroy(); + } + + public int get_run_mode() { + String run_mode = spRunMode.getSelectedItem().toString(); + int mode; + switch (run_mode) { + case "检测+分类+识别": + mode = 1; + break; + case "检测+识别": + mode = 2; + break; + case "识别+分类": + mode = 3; + break; + case "检测": + mode = 4; + break; + case "识别": + mode = 5; + break; + case "分类": + mode = 6; + break; + default: + mode = 1; + } + return mode; + } +} diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java new file mode 100644 index 0000000..41fa183 --- /dev/null +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java @@ -0,0 +1,105 @@ +package com.baidu.paddle.lite.demo.ocr; + +import android.graphics.Bitmap; +import android.util.Log; + +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicBoolean; + +public class OCRPredictorNative { + + private static final AtomicBoolean isSOLoaded = new AtomicBoolean(); + + public static void loadLibrary() throws RuntimeException { + if (!isSOLoaded.get() && isSOLoaded.compareAndSet(false, true)) { + try { + System.loadLibrary("Native"); + } catch (Throwable e) { + RuntimeException exception = new RuntimeException( + "Load libNative.so failed, please check it exists in apk file.", e); + throw exception; + } + } + } + + private Config config; + + private long nativePointer = 0; + + public OCRPredictorNative(Config config) { + this.config = config; + loadLibrary(); + nativePointer = init(config.detModelFilename, config.recModelFilename, config.clsModelFilename, config.useOpencl, + config.cpuThreadNum, config.cpuPower); + Log.i("OCRPredictorNative", "load success " + nativePointer); + + } + + + public ArrayList runImage(Bitmap originalImage, int max_size_len, int run_det, int run_cls, int run_rec) { + Log.i("OCRPredictorNative", "begin to run image "); + float[] rawResults = forward(nativePointer, originalImage, max_size_len, run_det, run_cls, run_rec); + ArrayList results = postprocess(rawResults); + return results; + } + + public static class Config { + public int useOpencl; + public int cpuThreadNum; + public String cpuPower; + public String detModelFilename; + public String recModelFilename; + public String clsModelFilename; + + } + + public void destory() { + if (nativePointer != 0) { + release(nativePointer); + 
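+            // Reset the cached native handle so a second destory() call becomes a no-op.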
nativePointer = 0; + } + } + + protected native long init(String detModelPath, String recModelPath, String clsModelPath, int useOpencl, int threadNum, String cpuMode); + + protected native float[] forward(long pointer, Bitmap originalImage,int max_size_len, int run_det, int run_cls, int run_rec); + + protected native void release(long pointer); + + private ArrayList postprocess(float[] raw) { + ArrayList results = new ArrayList(); + int begin = 0; + + while (begin < raw.length) { + int point_num = Math.round(raw[begin]); + int word_num = Math.round(raw[begin + 1]); + OcrResultModel res = parse(raw, begin + 2, point_num, word_num); + begin += 2 + 1 + point_num * 2 + word_num + 2; + results.add(res); + } + + return results; + } + + private OcrResultModel parse(float[] raw, int begin, int pointNum, int wordNum) { + int current = begin; + OcrResultModel res = new OcrResultModel(); + res.setConfidence(raw[current]); + current++; + for (int i = 0; i < pointNum; i++) { + res.addPoints(Math.round(raw[current + i * 2]), Math.round(raw[current + i * 2 + 1])); + } + current += (pointNum * 2); + for (int i = 0; i < wordNum; i++) { + int index = Math.round(raw[current + i]); + res.addWordIndex(index); + } + current += wordNum; + res.setClsIdx(raw[current]); + res.setClsConfidence(raw[current + 1]); + Log.i("OCRPredictorNative", "word finished " + wordNum); + return res; + } + + +} diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java new file mode 100644 index 0000000..1bccbc7 --- /dev/null +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java @@ -0,0 +1,79 @@ +package com.baidu.paddle.lite.demo.ocr; + +import android.graphics.Point; + +import java.util.ArrayList; +import java.util.List; + +public class OcrResultModel { + private List points; + private List wordIndex; + private String label; + private float confidence; + private float cls_idx; + private String cls_label; + private float cls_confidence; + + public OcrResultModel() { + super(); + points = new ArrayList<>(); + wordIndex = new ArrayList<>(); + } + + public void addPoints(int x, int y) { + Point point = new Point(x, y); + points.add(point); + } + + public void addWordIndex(int index) { + wordIndex.add(index); + } + + public List getPoints() { + return points; + } + + public List getWordIndex() { + return wordIndex; + } + + public String getLabel() { + return label; + } + + public void setLabel(String label) { + this.label = label; + } + + public float getConfidence() { + return confidence; + } + + public void setConfidence(float confidence) { + this.confidence = confidence; + } + + public float getClsIdx() { + return cls_idx; + } + + public void setClsIdx(float idx) { + this.cls_idx = idx; + } + + public String getClsLabel() { + return cls_label; + } + + public void setClsLabel(String label) { + this.cls_label = label; + } + + public float getClsConfidence() { + return cls_confidence; + } + + public void setClsConfidence(float confidence) { + this.cls_confidence = confidence; + } +} diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java new file mode 100644 index 0000000..ab31216 --- /dev/null +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java @@ -0,0 +1,278 @@ +package 
com.baidu.paddle.lite.demo.ocr; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Paint; +import android.graphics.Path; +import android.graphics.Point; +import android.util.Log; + +import java.io.File; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Vector; + +import static android.graphics.Color.*; + +public class Predictor { + private static final String TAG = Predictor.class.getSimpleName(); + public boolean isLoaded = false; + public int warmupIterNum = 1; + public int inferIterNum = 1; + public int cpuThreadNum = 4; + public String cpuPowerMode = "LITE_POWER_HIGH"; + public String modelPath = ""; + public String modelName = ""; + protected OCRPredictorNative paddlePredictor = null; + protected float inferenceTime = 0; + // Only for object detection + protected Vector wordLabels = new Vector(); + protected int detLongSize = 960; + protected float scoreThreshold = 0.1f; + protected Bitmap inputImage = null; + protected Bitmap outputImage = null; + protected volatile String outputResult = ""; + protected float postprocessTime = 0; + + + public Predictor() { + } + + public boolean init(Context appCtx, String modelPath, String labelPath, int useOpencl, int cpuThreadNum, String cpuPowerMode) { + isLoaded = loadModel(appCtx, modelPath, useOpencl, cpuThreadNum, cpuPowerMode); + if (!isLoaded) { + return false; + } + isLoaded = loadLabel(appCtx, labelPath); + return isLoaded; + } + + + public boolean init(Context appCtx, String modelPath, String labelPath, int useOpencl, int cpuThreadNum, String cpuPowerMode, + int detLongSize, float scoreThreshold) { + boolean isLoaded = init(appCtx, modelPath, labelPath, useOpencl, cpuThreadNum, cpuPowerMode); + if (!isLoaded) { + return false; + } + this.detLongSize = detLongSize; + this.scoreThreshold = scoreThreshold; + return true; + } + + protected boolean loadModel(Context appCtx, String modelPath, int useOpencl, int cpuThreadNum, String cpuPowerMode) { + // Release model if exists + releaseModel(); + + // Load model + if (modelPath.isEmpty()) { + return false; + } + String realPath = modelPath; + if (!modelPath.substring(0, 1).equals("/")) { + // Read model files from custom path if the first character of mode path is '/' + // otherwise copy model to cache from assets + realPath = appCtx.getCacheDir() + "/" + modelPath; + Utils.copyDirectoryFromAssets(appCtx, modelPath, realPath); + } + if (realPath.isEmpty()) { + return false; + } + + OCRPredictorNative.Config config = new OCRPredictorNative.Config(); + config.useOpencl = useOpencl; + config.cpuThreadNum = cpuThreadNum; + config.cpuPower = cpuPowerMode; + config.detModelFilename = realPath + File.separator + "det_db.nb"; + config.recModelFilename = realPath + File.separator + "rec_crnn.nb"; + config.clsModelFilename = realPath + File.separator + "cls.nb"; + Log.i("Predictor", "model path" + config.detModelFilename + " ; " + config.recModelFilename + ";" + config.clsModelFilename); + paddlePredictor = new OCRPredictorNative(config); + + this.cpuThreadNum = cpuThreadNum; + this.cpuPowerMode = cpuPowerMode; + this.modelPath = realPath; + this.modelName = realPath.substring(realPath.lastIndexOf("/") + 1); + return true; + } + + public void releaseModel() { + if (paddlePredictor != null) { + paddlePredictor.destory(); + paddlePredictor = null; + } + isLoaded = false; + cpuThreadNum = 1; + cpuPowerMode = "LITE_POWER_HIGH"; + 
modelPath = ""; + modelName = ""; + } + + protected boolean loadLabel(Context appCtx, String labelPath) { + wordLabels.clear(); + wordLabels.add("black"); + // Load word labels from file + try { + InputStream assetsInputStream = appCtx.getAssets().open(labelPath); + int available = assetsInputStream.available(); + byte[] lines = new byte[available]; + assetsInputStream.read(lines); + assetsInputStream.close(); + String words = new String(lines); + String[] contents = words.split("\n"); + for (String content : contents) { + wordLabels.add(content); + } + wordLabels.add(" "); + Log.i(TAG, "Word label size: " + wordLabels.size()); + } catch (Exception e) { + Log.e(TAG, e.getMessage()); + return false; + } + return true; + } + + + public boolean runModel(int run_det, int run_cls, int run_rec) { + if (inputImage == null || !isLoaded()) { + return false; + } + + // Warm up + for (int i = 0; i < warmupIterNum; i++) { + paddlePredictor.runImage(inputImage, detLongSize, run_det, run_cls, run_rec); + } + warmupIterNum = 0; // do not need warm + // Run inference + Date start = new Date(); + ArrayList results = paddlePredictor.runImage(inputImage, detLongSize, run_det, run_cls, run_rec); + Date end = new Date(); + inferenceTime = (end.getTime() - start.getTime()) / (float) inferIterNum; + + results = postprocess(results); + Log.i(TAG, "[stat] Inference Time: " + inferenceTime + " ;Box Size " + results.size()); + drawResults(results); + + return true; + } + + public boolean isLoaded() { + return paddlePredictor != null && isLoaded; + } + + public String modelPath() { + return modelPath; + } + + public String modelName() { + return modelName; + } + + public int cpuThreadNum() { + return cpuThreadNum; + } + + public String cpuPowerMode() { + return cpuPowerMode; + } + + public float inferenceTime() { + return inferenceTime; + } + + public Bitmap inputImage() { + return inputImage; + } + + public Bitmap outputImage() { + return outputImage; + } + + public String outputResult() { + return outputResult; + } + + public float postprocessTime() { + return postprocessTime; + } + + + public void setInputImage(Bitmap image) { + if (image == null) { + return; + } + this.inputImage = image.copy(Bitmap.Config.ARGB_8888, true); + } + + private ArrayList postprocess(ArrayList results) { + for (OcrResultModel r : results) { + StringBuffer word = new StringBuffer(); + for (int index : r.getWordIndex()) { + if (index >= 0 && index < wordLabels.size()) { + word.append(wordLabels.get(index)); + } else { + Log.e(TAG, "Word index is not in label list:" + index); + word.append("×"); + } + } + r.setLabel(word.toString()); + r.setClsLabel(r.getClsIdx() == 1 ? 
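+                // cls label 1 means the angle classifier judged the crop to be
+                // rotated by 180 degrees; label 0 means it is upright.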
"180" : "0"); + } + return results; + } + + private void drawResults(ArrayList results) { + StringBuffer outputResultSb = new StringBuffer(""); + for (int i = 0; i < results.size(); i++) { + OcrResultModel result = results.get(i); + StringBuilder sb = new StringBuilder(""); + if(result.getPoints().size()>0){ + sb.append("Det: "); + for (Point p : result.getPoints()) { + sb.append("(").append(p.x).append(",").append(p.y).append(") "); + } + } + if(result.getLabel().length() > 0){ + sb.append("\n Rec: ").append(result.getLabel()); + sb.append(",").append(result.getConfidence()); + } + if(result.getClsIdx()!=-1){ + sb.append(" Cls: ").append(result.getClsLabel()); + sb.append(",").append(result.getClsConfidence()); + } + Log.i(TAG, sb.toString()); // show LOG in Logcat panel + outputResultSb.append(i + 1).append(": ").append(sb.toString()).append("\n"); + } + outputResult = outputResultSb.toString(); + outputImage = inputImage; + Canvas canvas = new Canvas(outputImage); + Paint paintFillAlpha = new Paint(); + paintFillAlpha.setStyle(Paint.Style.FILL); + paintFillAlpha.setColor(Color.parseColor("#3B85F5")); + paintFillAlpha.setAlpha(50); + + Paint paint = new Paint(); + paint.setColor(Color.parseColor("#3B85F5")); + paint.setStrokeWidth(5); + paint.setStyle(Paint.Style.STROKE); + + for (OcrResultModel result : results) { + Path path = new Path(); + List points = result.getPoints(); + if(points.size()==0){ + continue; + } + path.moveTo(points.get(0).x, points.get(0).y); + for (int i = points.size() - 1; i >= 0; i--) { + Point p = points.get(i); + path.lineTo(p.x, p.y); + } + canvas.drawPath(path, paint); + canvas.drawPath(path, paintFillAlpha); + } + } + +} diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/SettingsActivity.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/SettingsActivity.java new file mode 100644 index 0000000..477cd5d --- /dev/null +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/SettingsActivity.java @@ -0,0 +1,172 @@ +package com.baidu.paddle.lite.demo.ocr; + +import android.content.SharedPreferences; +import android.os.Bundle; +import android.preference.CheckBoxPreference; +import android.preference.EditTextPreference; +import android.preference.ListPreference; + +import androidx.appcompat.app.ActionBar; + +import java.util.ArrayList; +import java.util.List; + + +public class SettingsActivity extends AppCompatPreferenceActivity implements SharedPreferences.OnSharedPreferenceChangeListener { + ListPreference lpChoosePreInstalledModel = null; + CheckBoxPreference cbEnableCustomSettings = null; + EditTextPreference etModelPath = null; + EditTextPreference etLabelPath = null; + ListPreference etImagePath = null; + ListPreference lpCPUThreadNum = null; + ListPreference lpCPUPowerMode = null; + EditTextPreference etDetLongSize = null; + EditTextPreference etScoreThreshold = null; + + List preInstalledModelPaths = null; + List preInstalledLabelPaths = null; + List preInstalledImagePaths = null; + List preInstalledDetLongSizes = null; + List preInstalledCPUThreadNums = null; + List preInstalledCPUPowerModes = null; + List preInstalledInputColorFormats = null; + List preInstalledInputMeans = null; + List preInstalledInputStds = null; + List preInstalledScoreThresholds = null; + + @Override + public void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + addPreferencesFromResource(R.xml.settings); + ActionBar supportActionBar = getSupportActionBar(); + if (supportActionBar 
!= null) { + supportActionBar.setDisplayHomeAsUpEnabled(true); + } + + // Initialize the pre-installed model settings + preInstalledModelPaths = new ArrayList<String>(); + preInstalledLabelPaths = new ArrayList<String>(); + preInstalledImagePaths = new ArrayList<String>(); + preInstalledDetLongSizes = new ArrayList<String>(); + preInstalledCPUThreadNums = new ArrayList<String>(); + preInstalledCPUPowerModes = new ArrayList<String>(); + preInstalledInputColorFormats = new ArrayList<String>(); + preInstalledInputMeans = new ArrayList<String>(); + preInstalledInputStds = new ArrayList<String>(); + preInstalledScoreThresholds = new ArrayList<String>(); + // Register the default model, label, image and runtime settings + preInstalledModelPaths.add(getString(R.string.MODEL_PATH_DEFAULT)); + preInstalledLabelPaths.add(getString(R.string.LABEL_PATH_DEFAULT)); + preInstalledImagePaths.add(getString(R.string.IMAGE_PATH_DEFAULT)); + preInstalledCPUThreadNums.add(getString(R.string.CPU_THREAD_NUM_DEFAULT)); + preInstalledCPUPowerModes.add(getString(R.string.CPU_POWER_MODE_DEFAULT)); + preInstalledDetLongSizes.add(getString(R.string.DET_LONG_SIZE_DEFAULT)); + preInstalledScoreThresholds.add(getString(R.string.SCORE_THRESHOLD_DEFAULT)); + + // Setup UI components + lpChoosePreInstalledModel = + (ListPreference) findPreference(getString(R.string.CHOOSE_PRE_INSTALLED_MODEL_KEY)); + String[] preInstalledModelNames = new String[preInstalledModelPaths.size()]; + for (int i = 0; i < preInstalledModelPaths.size(); i++) { + preInstalledModelNames[i] = + preInstalledModelPaths.get(i).substring(preInstalledModelPaths.get(i).lastIndexOf("/") + 1); + } + lpChoosePreInstalledModel.setEntries(preInstalledModelNames); + lpChoosePreInstalledModel.setEntryValues(preInstalledModelPaths.toArray(new String[preInstalledModelPaths.size()])); + cbEnableCustomSettings = + (CheckBoxPreference) findPreference(getString(R.string.ENABLE_CUSTOM_SETTINGS_KEY)); + etModelPath = (EditTextPreference) findPreference(getString(R.string.MODEL_PATH_KEY)); + etModelPath.setTitle("Model Path (SDCard: " + Utils.getSDCardDirectory() + ")"); + etLabelPath = (EditTextPreference) findPreference(getString(R.string.LABEL_PATH_KEY)); + etImagePath = (ListPreference) findPreference(getString(R.string.IMAGE_PATH_KEY)); + lpCPUThreadNum = + (ListPreference) findPreference(getString(R.string.CPU_THREAD_NUM_KEY)); + lpCPUPowerMode = + (ListPreference) findPreference(getString(R.string.CPU_POWER_MODE_KEY)); + etDetLongSize = (EditTextPreference) findPreference(getString(R.string.DET_LONG_SIZE_KEY)); + etScoreThreshold = (EditTextPreference) findPreference(getString(R.string.SCORE_THRESHOLD_KEY)); + } + + private void reloadPreferenceAndUpdateUI() { + SharedPreferences sharedPreferences = getPreferenceScreen().getSharedPreferences(); + boolean enableCustomSettings = + sharedPreferences.getBoolean(getString(R.string.ENABLE_CUSTOM_SETTINGS_KEY), false); + String modelPath = sharedPreferences.getString(getString(R.string.CHOOSE_PRE_INSTALLED_MODEL_KEY), + getString(R.string.MODEL_PATH_DEFAULT)); + int modelIdx = lpChoosePreInstalledModel.findIndexOfValue(modelPath); + if (modelIdx >= 0 && modelIdx < preInstalledModelPaths.size()) { + if (!enableCustomSettings) { + SharedPreferences.Editor editor = sharedPreferences.edit(); + editor.putString(getString(R.string.MODEL_PATH_KEY), preInstalledModelPaths.get(modelIdx)); + editor.putString(getString(R.string.LABEL_PATH_KEY), preInstalledLabelPaths.get(modelIdx)); + editor.putString(getString(R.string.IMAGE_PATH_KEY), preInstalledImagePaths.get(modelIdx)); + editor.putString(getString(R.string.CPU_THREAD_NUM_KEY),
preInstalledCPUThreadNums.get(modelIdx)); + editor.putString(getString(R.string.CPU_POWER_MODE_KEY), preInstalledCPUPowerModes.get(modelIdx)); + editor.putString(getString(R.string.DET_LONG_SIZE_KEY), preInstalledDetLongSizes.get(modelIdx)); + editor.putString(getString(R.string.SCORE_THRESHOLD_KEY), + preInstalledScoreThresholds.get(modelIdx)); + editor.apply(); + } + lpChoosePreInstalledModel.setSummary(modelPath); + } + cbEnableCustomSettings.setChecked(enableCustomSettings); + etModelPath.setEnabled(enableCustomSettings); + etLabelPath.setEnabled(enableCustomSettings); + etImagePath.setEnabled(enableCustomSettings); + lpCPUThreadNum.setEnabled(enableCustomSettings); + lpCPUPowerMode.setEnabled(enableCustomSettings); + etDetLongSize.setEnabled(enableCustomSettings); + etScoreThreshold.setEnabled(enableCustomSettings); + modelPath = sharedPreferences.getString(getString(R.string.MODEL_PATH_KEY), + getString(R.string.MODEL_PATH_DEFAULT)); + String labelPath = sharedPreferences.getString(getString(R.string.LABEL_PATH_KEY), + getString(R.string.LABEL_PATH_DEFAULT)); + String imagePath = sharedPreferences.getString(getString(R.string.IMAGE_PATH_KEY), + getString(R.string.IMAGE_PATH_DEFAULT)); + String cpuThreadNum = sharedPreferences.getString(getString(R.string.CPU_THREAD_NUM_KEY), + getString(R.string.CPU_THREAD_NUM_DEFAULT)); + String cpuPowerMode = sharedPreferences.getString(getString(R.string.CPU_POWER_MODE_KEY), + getString(R.string.CPU_POWER_MODE_DEFAULT)); + String detLongSize = sharedPreferences.getString(getString(R.string.DET_LONG_SIZE_KEY), + getString(R.string.DET_LONG_SIZE_DEFAULT)); + String scoreThreshold = sharedPreferences.getString(getString(R.string.SCORE_THRESHOLD_KEY), + getString(R.string.SCORE_THRESHOLD_DEFAULT)); + etModelPath.setSummary(modelPath); + etModelPath.setText(modelPath); + etLabelPath.setSummary(labelPath); + etLabelPath.setText(labelPath); + etImagePath.setSummary(imagePath); + etImagePath.setValue(imagePath); + lpCPUThreadNum.setValue(cpuThreadNum); + lpCPUThreadNum.setSummary(cpuThreadNum); + lpCPUPowerMode.setValue(cpuPowerMode); + lpCPUPowerMode.setSummary(cpuPowerMode); + etDetLongSize.setSummary(detLongSize); + etDetLongSize.setText(detLongSize); + etScoreThreshold.setText(scoreThreshold); + etScoreThreshold.setSummary(scoreThreshold); + } + + @Override + protected void onResume() { + super.onResume(); + getPreferenceScreen().getSharedPreferences().registerOnSharedPreferenceChangeListener(this); + reloadPreferenceAndUpdateUI(); + } + + @Override + protected void onPause() { + super.onPause(); + getPreferenceScreen().getSharedPreferences().unregisterOnSharedPreferenceChangeListener(this); + } + + @Override + public void onSharedPreferenceChanged(SharedPreferences sharedPreferences, String key) { + if (key.equals(getString(R.string.CHOOSE_PRE_INSTALLED_MODEL_KEY))) { + SharedPreferences.Editor editor = sharedPreferences.edit(); + editor.putBoolean(getString(R.string.ENABLE_CUSTOM_SETTINGS_KEY), false); + editor.commit(); + } + reloadPreferenceAndUpdateUI(); + } +} diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Utils.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Utils.java new file mode 100644 index 0000000..ef46805 --- /dev/null +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Utils.java @@ -0,0 +1,159 @@ +package com.baidu.paddle.lite.demo.ocr; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Matrix; +import 
android.media.ExifInterface; + import android.os.Environment; + + import java.io.*; + + public class Utils { + private static final String TAG = Utils.class.getSimpleName(); + + public static void copyFileFromAssets(Context appCtx, String srcPath, String dstPath) { + if (srcPath.isEmpty() || dstPath.isEmpty()) { + return; + } + InputStream is = null; + OutputStream os = null; + try { + is = new BufferedInputStream(appCtx.getAssets().open(srcPath)); + os = new BufferedOutputStream(new FileOutputStream(new File(dstPath))); + byte[] buffer = new byte[1024]; + int length = 0; + while ((length = is.read(buffer)) != -1) { + os.write(buffer, 0, length); + } + } catch (FileNotFoundException e) { + e.printStackTrace(); + } catch (IOException e) { + e.printStackTrace(); + } finally { + // Only close streams that were actually opened to avoid a NullPointerException + try { + if (os != null) { + os.close(); + } + if (is != null) { + is.close(); + } + } catch (IOException e) { + e.printStackTrace(); + } + } + } + + public static void copyDirectoryFromAssets(Context appCtx, String srcDir, String dstDir) { + if (srcDir.isEmpty() || dstDir.isEmpty()) { + return; + } + try { + if (!new File(dstDir).exists()) { + new File(dstDir).mkdirs(); + } + for (String fileName : appCtx.getAssets().list(srcDir)) { + String srcSubPath = srcDir + File.separator + fileName; + String dstSubPath = dstDir + File.separator + fileName; + if (new File(srcSubPath).isDirectory()) { + copyDirectoryFromAssets(appCtx, srcSubPath, dstSubPath); + } else { + copyFileFromAssets(appCtx, srcSubPath, dstSubPath); + } + } + } catch (Exception e) { + e.printStackTrace(); + } + } + + public static float[] parseFloatsFromString(String string, String delimiter) { + String[] pieces = string.trim().toLowerCase().split(delimiter); + float[] floats = new float[pieces.length]; + for (int i = 0; i < pieces.length; i++) { + floats[i] = Float.parseFloat(pieces[i].trim()); + } + return floats; + } + + public static long[] parseLongsFromString(String string, String delimiter) { + String[] pieces = string.trim().toLowerCase().split(delimiter); + long[] longs = new long[pieces.length]; + for (int i = 0; i < pieces.length; i++) { + longs[i] = Long.parseLong(pieces[i].trim()); + } + return longs; + } + + public static String getSDCardDirectory() { + return Environment.getExternalStorageDirectory().getAbsolutePath(); + } + + public static boolean isSupportedNPU() { + return false; + // String hardware = android.os.Build.HARDWARE; + // return hardware.equalsIgnoreCase("kirin810") || hardware.equalsIgnoreCase("kirin990"); + } + + public static Bitmap resizeWithStep(Bitmap bitmap, int maxLength, int step) { + int width = bitmap.getWidth(); + int height = bitmap.getHeight(); + int maxWH = Math.max(width, height); + float ratio = 1; + int newWidth = width; + int newHeight = height; + if (maxWH > maxLength) { + ratio = maxLength * 1.0f / maxWH; + newWidth = (int) Math.floor(ratio * width); + newHeight = (int) Math.floor(ratio * height); + } + + newWidth = newWidth - newWidth % step; + if (newWidth == 0) { + newWidth = step; + } + newHeight = newHeight - newHeight % step; + if (newHeight == 0) { + newHeight = step; + } + return Bitmap.createScaledBitmap(bitmap, newWidth, newHeight, true); + } + + public static Bitmap rotateBitmap(Bitmap bitmap, int orientation) { + + Matrix matrix = new Matrix(); + switch (orientation) { + case ExifInterface.ORIENTATION_NORMAL: + return bitmap; + case ExifInterface.ORIENTATION_FLIP_HORIZONTAL: + matrix.setScale(-1, 1); + break; + case ExifInterface.ORIENTATION_ROTATE_180: + matrix.setRotate(180); + break; + case ExifInterface.ORIENTATION_FLIP_VERTICAL:
matrix.setRotate(180); + matrix.postScale(-1, 1); + break; + case ExifInterface.ORIENTATION_TRANSPOSE: + matrix.setRotate(90); + matrix.postScale(-1, 1); + break; + case ExifInterface.ORIENTATION_ROTATE_90: + matrix.setRotate(90); + break; + case ExifInterface.ORIENTATION_TRANSVERSE: + matrix.setRotate(-90); + matrix.postScale(-1, 1); + break; + case ExifInterface.ORIENTATION_ROTATE_270: + matrix.setRotate(-90); + break; + default: + return bitmap; + } + try { + Bitmap bmRotated = Bitmap.createBitmap(bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true); + bitmap.recycle(); + return bmRotated; + } catch (OutOfMemoryError e) { + e.printStackTrace(); + return null; + } + } +} diff --git a/deploy/android_demo/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/deploy/android_demo/app/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 0000000..1f6bb29 --- /dev/null +++ b/deploy/android_demo/app/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,34 @@ diff --git a/deploy/android_demo/app/src/main/res/drawable/ic_launcher_background.xml b/deploy/android_demo/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 0000000..0d025f9 --- /dev/null +++ b/deploy/android_demo/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ diff --git a/deploy/android_demo/app/src/main/res/layout/activity_main.xml b/deploy/android_demo/app/src/main/res/layout/activity_main.xml new file mode 100644 index 0000000..e90c99a --- /dev/null +++ b/deploy/android_demo/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,180 @@
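For reference, the sketch below shows one way the Utils helpers from this patch could be combined to prepare a photo for detection: read the EXIF orientation, rotate the bitmap upright, then resize it so both sides are multiples of a stride step. This is an illustrative sketch only, not part of the demo sources; the class name OcrInputHelper, the example path, and the 960/32 values are assumptions rather than values taken from this patch.

    // Illustrative helper, assuming the Utils class from this patch is in the same package.
    package com.baidu.paddle.lite.demo.ocr;

    import android.graphics.Bitmap;
    import android.graphics.BitmapFactory;
    import android.media.ExifInterface;

    import java.io.IOException;

    public class OcrInputHelper {
        // Decode a photo, undo its EXIF rotation, then shrink it so the longest side
        // is at most longSize and both sides are multiples of step.
        public static Bitmap loadForDetection(String imagePath, int longSize, int step) throws IOException {
            Bitmap bitmap = BitmapFactory.decodeFile(imagePath);
            int orientation = new ExifInterface(imagePath)
                    .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
            Bitmap upright = Utils.rotateBitmap(bitmap, orientation);
            if (upright == null) {
                // rotateBitmap returns null on OutOfMemoryError; fall back to the original bitmap
                upright = bitmap;
            }
            return Utils.resizeWithStep(upright, longSize, step);
        }
    }

    // Example call (path and sizes are hypothetical):
    // Bitmap input = OcrInputHelper.loadForDetection("/sdcard/ocr/test.jpg", 960, 32);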