Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add data_format selection support to ocr #13328

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions configs/cls/cls_mv3.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ Architecture:
name: MobileNetV3
scale: 0.35
model_name: small
data_format: NHWC
Neck:
Head:
name: ClsHead
class_dim: 2
data_format: NHWC

Loss:
name: ClsLoss
Expand Down
9 changes: 9 additions & 0 deletions configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ Architecture:
scale: 0.5
model_name: large
disable_se: true
data_format: NHWC
Neck:
name: RSEFPN
out_channels: 96
shortcut: True
data_format: NHWC
Head:
name: DBHead
k: 50
data_format: NHWC
Student2:
pretrained:
model_type: det
Expand All @@ -52,13 +55,16 @@ Architecture:
scale: 0.5
model_name: large
disable_se: true
data_format: NHWC
Neck:
name: RSEFPN
out_channels: 96
shortcut: True
data_format: NHWC
Head:
name: DBHead
k: 50
data_format: NHWC
Teacher:
freeze_params: true
return_all_feats: false
Expand All @@ -68,13 +74,16 @@ Architecture:
name: ResNet_vd
in_channels: 3
layers: 50
data_format: NHWC
Neck:
name: LKPAN
out_channels: 256
data_format: NHWC
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
data_format: NHWC

Loss:
name: CombinedLoss
Expand Down
7 changes: 5 additions & 2 deletions configs/det/det_mv3_db.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ Architecture:
name: MobileNetV3
scale: 0.5
model_name: large
data_format: NHWC
Neck:
name: DBFPN
out_channels: 256
data_format: NHWC
Head:
name: DBHead
k: 50
data_format: NHWC

Loss:
name: DBLoss
Expand Down Expand Up @@ -64,7 +67,7 @@ Metric:
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
data_dir: ./
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
Expand Down Expand Up @@ -107,7 +110,7 @@ Train:
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
data_dir: ./
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
Expand Down
2 changes: 2 additions & 0 deletions configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Architecture:
last_conv_stride: [1, 2]
last_pool_type: avg
last_pool_kernel_size: [2, 2]
data_format: 'NHWC'
Head:
name: MultiHead
head_list:
Expand All @@ -59,6 +60,7 @@ Architecture:
- SARHead:
enc_dim: 512
max_text_length: *max_text_length
data_format: 'NHWC'

Loss:
name: MultiLoss
Expand Down
4 changes: 4 additions & 0 deletions configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Architecture:
last_conv_stride: [1, 2]
last_pool_type: avg
last_pool_kernel_size: [2, 2]
data_format: 'NHWC'
Head:
name: MultiHead
head_list:
Expand All @@ -69,6 +70,7 @@ Architecture:
- SARHead:
enc_dim: 512
max_text_length: *max_text_length
data_format: 'NHWC'
Student:
pretrained:
freeze_params: false
Expand All @@ -82,6 +84,7 @@ Architecture:
last_conv_stride: [1, 2]
last_pool_type: avg
last_pool_kernel_size: [2, 2]
data_format: 'NHWC'
Head:
name: MultiHead
head_list:
Expand All @@ -97,6 +100,7 @@ Architecture:
- SARHead:
enc_dim: 512
max_text_length: *max_text_length
data_format: 'NHWC'
Loss:
name: CombinedLoss
loss_config_list:
Expand Down
32 changes: 27 additions & 5 deletions ppocr/modeling/backbones/det_mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ def make_divisible(v, divisor=8, min_value=None):

class MobileNetV3(nn.Layer):
def __init__(
self, in_channels=3, model_name="large", scale=0.5, disable_se=False, **kwargs
self,
in_channels=3,
model_name="large",
scale=0.5,
disable_se=False,
data_format="NCHW",
**kwargs,
):
"""
the MobilenetV3 backbone network for detection module.
Expand All @@ -46,6 +52,7 @@ def __init__(

self.disable_se = disable_se

self.nchw = data_format == "NCHW"
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
Expand Down Expand Up @@ -102,6 +109,7 @@ def __init__(
groups=1,
if_act=True,
act="hardswish",
data_format=data_format,
)

self.stages = []
Expand All @@ -125,6 +133,7 @@ def __init__(
stride=s,
use_se=se,
act=nl,
data_format=data_format,
)
)
inplanes = make_divisible(scale * c)
Expand All @@ -139,6 +148,7 @@ def __init__(
groups=1,
if_act=True,
act="hardswish",
data_format=data_format,
)
)
self.stages.append(nn.Sequential(*block_list))
Expand All @@ -147,6 +157,8 @@ def __init__(
self.add_sublayer(sublayer=stage, name="stage{}".format(i))

def forward(self, x):
if not self.nchw:
x = x.transpose([0, 2, 3, 1])
x = self.conv(x)
out_list = []
for stage in self.stages:
Expand All @@ -166,6 +178,7 @@ def __init__(
groups=1,
if_act=True,
act=None,
data_format="NCHW",
):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
Expand All @@ -178,9 +191,12 @@ def __init__(
padding=padding,
groups=groups,
bias_attr=False,
data_format=data_format,
)

self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
self.bn = nn.BatchNorm(
num_channels=out_channels, act=None, data_layout=data_format
)

def forward(self, x):
x = self.conv(x)
Expand Down Expand Up @@ -210,6 +226,7 @@ def __init__(
stride,
use_se,
act=None,
data_format="NCHW",
):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels
Expand All @@ -223,6 +240,7 @@ def __init__(
padding=0,
if_act=True,
act=act,
data_format=data_format,
)
self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels,
Expand All @@ -233,9 +251,10 @@ def __init__(
groups=mid_channels,
if_act=True,
act=act,
data_format=data_format,
)
if self.if_se:
self.mid_se = SEModule(mid_channels)
self.mid_se = SEModule(mid_channels, data_format=data_format)
self.linear_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=out_channels,
Expand All @@ -244,6 +263,7 @@ def __init__(
padding=0,
if_act=False,
act=None,
data_format=data_format,
)

def forward(self, inputs):
Expand All @@ -258,22 +278,24 @@ def forward(self, inputs):


class SEModule(nn.Layer):
def __init__(self, in_channels, reduction=4):
def __init__(self, in_channels, reduction=4, data_format="NCHW"):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.avg_pool = nn.AdaptiveAvgPool2D(1, data_format=data_format)
self.conv1 = nn.Conv2D(
in_channels=in_channels,
out_channels=in_channels // reduction,
kernel_size=1,
stride=1,
padding=0,
data_format=data_format,
)
self.conv2 = nn.Conv2D(
in_channels=in_channels // reduction,
out_channels=in_channels,
kernel_size=1,
stride=1,
padding=0,
data_format=data_format,
)

def forward(self, inputs):
Expand Down
Loading