diff --git a/configs/cls/cls_mv3.yml b/configs/cls/cls_mv3.yml
index 0c46ff5602..8016f91c6a 100644
--- a/configs/cls/cls_mv3.yml
+++ b/configs/cls/cls_mv3.yml
@@ -23,10 +23,12 @@ Architecture:
     name: MobileNetV3
     scale: 0.35
     model_name: small
+    data_format: NHWC
   Neck:
   Head:
     name: ClsHead
     class_dim: 2
+    data_format: NHWC
 
 Loss:
   name: ClsLoss
diff --git a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
index 252d159977..7123dd1c10 100644
--- a/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
+++ b/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
@@ -35,13 +35,16 @@ Architecture:
       scale: 0.5
       model_name: large
       disable_se: true
+      data_format: NHWC
     Neck:
       name: RSEFPN
       out_channels: 96
       shortcut: True
+      data_format: NHWC
     Head:
       name: DBHead
       k: 50
+      data_format: NHWC
   Student2:
     pretrained:
     model_type: det
@@ -52,13 +55,16 @@ Architecture:
       scale: 0.5
       model_name: large
       disable_se: true
+      data_format: NHWC
     Neck:
       name: RSEFPN
       out_channels: 96
       shortcut: True
+      data_format: NHWC
     Head:
      name: DBHead
       k: 50
+      data_format: NHWC
   Teacher:
     freeze_params: true
     return_all_feats: false
@@ -68,13 +74,16 @@ Architecture:
       name: ResNet_vd
       in_channels: 3
       layers: 50
+      data_format: NHWC
     Neck:
       name: LKPAN
       out_channels: 256
+      data_format: NHWC
     Head:
       name: DBHead
       kernel_list: [7,2,2]
       k: 50
+      data_format: NHWC
 
 Loss:
   name: CombinedLoss
diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml
index 8f5685ec2a..4455730728 100644
--- a/configs/det/det_mv3_db.yml
+++ b/configs/det/det_mv3_db.yml
@@ -25,12 +25,15 @@ Architecture:
     name: MobileNetV3
     scale: 0.5
     model_name: large
+    data_format: NHWC
   Neck:
     name: DBFPN
     out_channels: 256
+    data_format: NHWC
   Head:
     name: DBHead
     k: 50
+    data_format: NHWC
 
 Loss:
   name: DBLoss
@@ -64,7 +67,7 @@ Metric:
 Train:
   dataset:
     name: SimpleDataSet
-    data_dir: ./train_data/icdar2015/text_localization/
+    data_dir: ./
     label_file_list:
       - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
     ratio_list: [1.0]
@@ -107,7 +110,7 @@ Train:
 Eval:
   dataset:
     name: SimpleDataSet
-    data_dir: ./train_data/icdar2015/text_localization/
+    data_dir: ./
     label_file_list:
       - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
     transforms:
diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
index fd15873fbf..39ca94f06b 100644
--- a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
+++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
@@ -44,6 +44,7 @@ Architecture:
     last_conv_stride: [1, 2]
     last_pool_type: avg
     last_pool_kernel_size: [2, 2]
+    data_format: 'NHWC'
   Head:
     name: MultiHead
     head_list:
@@ -59,6 +60,7 @@ Architecture:
       - SARHead:
           enc_dim: 512
          max_text_length: *max_text_length
+    data_format: 'NHWC'
 
 Loss:
   name: MultiLoss
diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
old mode 100644
new mode 100755
index 3b82ef857f..31fac2e585
--- a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
+++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
@@ -54,6 +54,7 @@ Architecture:
         last_conv_stride: [1, 2]
         last_pool_type: avg
         last_pool_kernel_size: [2, 2]
+        data_format: 'NHWC'
       Head:
         name: MultiHead
         head_list:
@@ -69,6 +70,7 @@ Architecture:
           - SARHead:
               enc_dim: 512
               max_text_length: *max_text_length
+        data_format: 'NHWC'
       Student:
         pretrained:
         freeze_params: false
@@ -82,6 +84,7 @@ Architecture:
         last_conv_stride: [1, 2]
         last_pool_type: avg
         last_pool_kernel_size: [2, 2]
+        data_format: 'NHWC'
       Head:
         name: MultiHead
         head_list:
@@ -97,6 +100,7 @@ Architecture:
           - SARHead:
               enc_dim: 512
               max_text_length: *max_text_length
+        data_format: 'NHWC'
 Loss:
   name: CombinedLoss
   loss_config_list:
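The five configs above opt whole model stacks into NHWC by threading a data_format flag through Backbone, Neck, and Head. A quick way to check that a config wires the flag all the way through is to build the architecture and push one dummy batch. This is a minimal sketch, not a shipped test: build_model and the config path are real repo names, but the shapes and the printed output are illustrative only.

import paddle
import yaml

from ppocr.modeling.architectures import build_model

with open("configs/det/det_mv3_db.yml") as f:
    cfg = yaml.safe_load(f)

model = build_model(cfg["Architecture"])
model.eval()

# The dataloader interface stays NCHW; the NHWC backbones transpose internally.
x = paddle.randn([1, 3, 640, 640])
with paddle.no_grad():
    out = model(x)
print({k: v.shape for k, v in out.items()})  # DBHead hands "maps" back in NCHW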
diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py
index 98db44b691..46c088163f 100755
--- a/ppocr/modeling/backbones/det_mobilenet_v3.py
+++ b/ppocr/modeling/backbones/det_mobilenet_v3.py
@@ -35,7 +35,13 @@ def make_divisible(v, divisor=8, min_value=None):
 
 class MobileNetV3(nn.Layer):
     def __init__(
-        self, in_channels=3, model_name="large", scale=0.5, disable_se=False, **kwargs
+        self,
+        in_channels=3,
+        model_name="large",
+        scale=0.5,
+        disable_se=False,
+        data_format="NCHW",
+        **kwargs,
     ):
         """
         the MobilenetV3 backbone network for detection module.
@@ -46,6 +52,7 @@ def __init__(
 
         self.disable_se = disable_se
+        self.nchw = data_format == "NCHW"
 
         if model_name == "large":
             cfg = [
                 # k, exp, c, se, nl, s,
@@ -102,6 +109,7 @@ def __init__(
             groups=1,
             if_act=True,
             act="hardswish",
+            data_format=data_format,
         )
 
         self.stages = []
@@ -125,6 +133,7 @@ def __init__(
                         stride=s,
                         use_se=se,
                         act=nl,
+                        data_format=data_format,
                     )
                 )
                 inplanes = make_divisible(scale * c)
@@ -139,6 +148,7 @@ def __init__(
                 groups=1,
                 if_act=True,
                 act="hardswish",
+                data_format=data_format,
             )
         )
         self.stages.append(nn.Sequential(*block_list))
@@ -147,6 +157,8 @@ def __init__(
             self.add_sublayer(sublayer=stage, name="stage{}".format(i))
 
     def forward(self, x):
+        if not self.nchw:
+            x = x.transpose([0, 2, 3, 1])
         x = self.conv(x)
         out_list = []
         for stage in self.stages:
@@ -166,6 +178,7 @@ def __init__(
         groups=1,
         if_act=True,
         act=None,
+        data_format="NCHW",
     ):
         super(ConvBNLayer, self).__init__()
         self.if_act = if_act
@@ -178,9 +191,12 @@ def __init__(
             padding=padding,
             groups=groups,
             bias_attr=False,
+            data_format=data_format,
         )
 
-        self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
+        self.bn = nn.BatchNorm(
+            num_channels=out_channels, act=None, data_layout=data_format
+        )
 
     def forward(self, x):
         x = self.conv(x)
@@ -210,6 +226,7 @@ def __init__(
         stride,
         use_se,
         act=None,
+        data_format="NCHW",
     ):
         super(ResidualUnit, self).__init__()
         self.if_shortcut = stride == 1 and in_channels == out_channels
@@ -223,6 +240,7 @@ def __init__(
             padding=0,
             if_act=True,
             act=act,
+            data_format=data_format,
         )
         self.bottleneck_conv = ConvBNLayer(
             in_channels=mid_channels,
@@ -233,9 +251,10 @@ def __init__(
             groups=mid_channels,
             if_act=True,
             act=act,
+            data_format=data_format,
         )
         if self.if_se:
-            self.mid_se = SEModule(mid_channels)
+            self.mid_se = SEModule(mid_channels, data_format=data_format)
         self.linear_conv = ConvBNLayer(
             in_channels=mid_channels,
             out_channels=out_channels,
@@ -244,6 +263,7 @@ def __init__(
             padding=0,
             if_act=False,
             act=None,
+            data_format=data_format,
         )
 
     def forward(self, inputs):
@@ -258,15 +278,16 @@ def forward(self, inputs):
 
 class SEModule(nn.Layer):
-    def __init__(self, in_channels, reduction=4):
+    def __init__(self, in_channels, reduction=4, data_format="NCHW"):
         super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2D(1)
+        self.avg_pool = nn.AdaptiveAvgPool2D(1, data_format=data_format)
         self.conv1 = nn.Conv2D(
             in_channels=in_channels,
             out_channels=in_channels // reduction,
             kernel_size=1,
             stride=1,
             padding=0,
+            data_format=data_format,
         )
         self.conv2 = nn.Conv2D(
             in_channels=in_channels // reduction,
@@ -274,6 +295,7 @@ def __init__(self, in_channels, reduction=4):
             kernel_size=1,
             stride=1,
             padding=0,
+            data_format=data_format,
         )
 
     def forward(self, inputs):
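The backbone follows a transpose-once contract: the trainer still feeds NCHW, forward flips to NHWC at the entry, and every ConvBNLayer/SEModule below runs natively channels-last. A standalone sketch (plain Paddle, nothing repo-specific) of why that round-trip is numerically safe; note conv weights stay OIHW in both layouts, so they can be shared directly:

import paddle
import paddle.nn as nn

paddle.seed(0)
x = paddle.randn([2, 3, 32, 32])  # NCHW, as the dataloader provides

conv_nchw = nn.Conv2D(3, 8, 3, padding=1, bias_attr=False)
conv_nhwc = nn.Conv2D(3, 8, 3, padding=1, bias_attr=False, data_format="NHWC")
conv_nhwc.weight.set_value(conv_nchw.weight)  # same OIHW kernels for both

y_ref = conv_nchw(x)
y_alt = conv_nhwc(x.transpose([0, 2, 3, 1]))  # NCHW -> NHWC once, at the boundary
print(float((y_ref - y_alt.transpose([0, 3, 1, 2])).abs().max()))  # ~1e-6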
diff --git a/ppocr/modeling/backbones/det_resnet_vd.py b/ppocr/modeling/backbones/det_resnet_vd.py
index 070ba3c978..4113c11678 100644
--- a/ppocr/modeling/backbones/det_resnet_vd.py
+++ b/ppocr/modeling/backbones/det_resnet_vd.py
@@ -45,6 +45,7 @@ def __init__(
         skip_quant=False,
         dcn_bias_regularizer=L2Decay(0.0),
         dcn_bias_lr_scale=2.0,
+        data_format="NCHW",
     ):
         super(DeformableConvV2, self).__init__()
         self.offset_channel = 2 * kernel_size**2 * groups
@@ -70,6 +71,7 @@ def __init__(
             deformable_groups=groups,
             weight_attr=weight_attr,
             bias_attr=dcn_bias_attr,
+            data_format=data_format,
         )
 
         if lr_scale == 1 and regularizer is None:
@@ -88,6 +90,7 @@ def __init__(
             padding=(kernel_size - 1) // 2,
             weight_attr=ParamAttr(initializer=Constant(0.0)),
             bias_attr=offset_bias_attr,
+            data_format=data_format,
         )
         if skip_quant:
             self.conv_offset.skip_quant = True
@@ -116,12 +119,13 @@ def __init__(
         is_vd_mode=False,
         act=None,
         is_dcn=False,
+        data_format="NCHW",
     ):
         super(ConvBNLayer, self).__init__()
         self.is_vd_mode = is_vd_mode
         self._pool2d_avg = nn.AvgPool2D(
-            kernel_size=2, stride=2, padding=0, ceil_mode=True
+            kernel_size=2, stride=2, padding=0, ceil_mode=True, data_format=data_format
         )
         if not is_dcn:
             self._conv = nn.Conv2D(
@@ -132,6 +136,7 @@ def __init__(
                 padding=(kernel_size - 1) // 2,
                 groups=groups,
                 bias_attr=False,
+                data_format=data_format,
             )
         else:
             self._conv = DeformableConvV2(
@@ -142,8 +147,9 @@ def __init__(
                 padding=(kernel_size - 1) // 2,
                 groups=dcn_groups,  # groups,
                 bias_attr=False,
+                data_format=data_format,
             )
-        self._batch_norm = nn.BatchNorm(out_channels, act=act)
+        self._batch_norm = nn.BatchNorm(out_channels, act=act, data_layout=data_format)
 
     def forward(self, inputs):
         if self.is_vd_mode:
@@ -162,6 +168,7 @@ def __init__(
         shortcut=True,
         if_first=False,
         is_dcn=False,
+        data_format="NCHW",
     ):
         super(BottleneckBlock, self).__init__()
 
@@ -170,6 +177,7 @@ def __init__(
             out_channels=out_channels,
             kernel_size=1,
             act="relu",
+            data_format=data_format,
         )
         self.conv1 = ConvBNLayer(
             in_channels=out_channels,
@@ -179,12 +187,14 @@ def __init__(
             act="relu",
             is_dcn=is_dcn,
             dcn_groups=2,
+            data_format=data_format,
         )
         self.conv2 = ConvBNLayer(
             in_channels=out_channels,
             out_channels=out_channels * 4,
             kernel_size=1,
             act=None,
+            data_format=data_format,
         )
 
         if not shortcut:
@@ -194,6 +204,7 @@ def __init__(
                 kernel_size=1,
                 stride=1,
                 is_vd_mode=False if if_first else True,
+                data_format=data_format,
             )
 
         self.shortcut = shortcut
@@ -220,6 +231,7 @@ def __init__(
         stride,
         shortcut=True,
         if_first=False,
+        data_format="NCHW",
     ):
         super(BasicBlock, self).__init__()
         self.stride = stride
@@ -229,9 +241,14 @@ def __init__(
             kernel_size=3,
             stride=stride,
             act="relu",
+            data_format=data_format,
         )
         self.conv1 = ConvBNLayer(
-            in_channels=out_channels, out_channels=out_channels, kernel_size=3, act=None
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            act=None,
+            data_format=data_format,
         )
 
         if not shortcut:
@@ -241,6 +258,7 @@ def __init__(
                 kernel_size=1,
                 stride=1,
                 is_vd_mode=False if if_first else True,
+                data_format=data_format,
            )
 
         self.shortcut = shortcut
@@ -260,7 +278,14 @@ def forward(self, inputs):
 
 class ResNet_vd(nn.Layer):
     def __init__(
-        self, in_channels=3, layers=50, dcn_stage=None, out_indices=None, **kwargs
+        self,
+        in_channels=3,
+        layers=50,
+        dcn_stage=None,
+        out_indices=None,
+        data_format="NCHW",
+        **kwargs,
     ):
         super(ResNet_vd, self).__init__()
+        self.nchw = data_format == "NCHW"
 
@@ -296,14 +321,27 @@ def __init__(
             kernel_size=3,
             stride=2,
             act="relu",
+            data_format=data_format,
         )
         self.conv1_2 = ConvBNLayer(
-            in_channels=32, out_channels=32, kernel_size=3, stride=1, act="relu"
+            in_channels=32,
+            out_channels=32,
+            kernel_size=3,
+            stride=1,
+            act="relu",
+            data_format=data_format,
         )
         self.conv1_3 = ConvBNLayer(
-            in_channels=32, out_channels=64, kernel_size=3, stride=1, act="relu"
+            in_channels=32,
+            out_channels=64,
+            kernel_size=3,
+            stride=1,
+            act="relu",
+            data_format=data_format,
+        )
+        self.pool2d_max = nn.MaxPool2D(
+            kernel_size=3, stride=2, padding=1, data_format=data_format
         )
-        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
 
         self.stages = []
         self.out_channels = []
@@ -326,6 +364,7 @@ def __init__(
                         shortcut=shortcut,
                         if_first=block == i == 0,
                         is_dcn=is_dcn,
+                        data_format=data_format,
                     ),
                 )
                 shortcut = True
@@ -348,6 +387,7 @@ def __init__(
                         stride=2 if i == 0 and block != 0 else 1,
                         shortcut=shortcut,
                         if_first=block == i == 0,
+                        data_format=data_format,
                     ),
                 )
                 shortcut = True
@@ -357,6 +397,8 @@ def __init__(
             self.stages.append(nn.Sequential(*block_list))
 
     def forward(self, inputs):
+        if not self.nchw:
+            inputs = inputs.transpose([0, 2, 3, 1])  # NCHW -> NHWC
         y = self.conv1_1(inputs)
         y = self.conv1_2(y)
         y = self.conv1_3(y)
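forward relies on self.nchw, so __init__ records it alongside the new parameter. With NHWC enabled, the pyramid features come out channels-last, which is what the patched FPNs further down expect. A hypothetical shape walk-through, assuming the patched class above; the concrete channel counts are illustrative:

import paddle
from ppocr.modeling.backbones.det_resnet_vd import ResNet_vd

net = ResNet_vd(in_channels=3, layers=50, data_format="NHWC")
feats = net(paddle.randn([1, 3, 640, 640]))  # input is still NCHW
for f in feats:
    print(f.shape)  # e.g. [1, 160, 160, 256] at stride 4: channels on the last axis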
diff --git a/ppocr/modeling/backbones/rec_mobilenet_v3.py b/ppocr/modeling/backbones/rec_mobilenet_v3.py
index 00ee5a3da0..348bb0331d 100644
--- a/ppocr/modeling/backbones/rec_mobilenet_v3.py
+++ b/ppocr/modeling/backbones/rec_mobilenet_v3.py
@@ -32,10 +32,12 @@ def __init__(
         large_stride=None,
         small_stride=None,
         disable_se=False,
+        data_format="NCHW",
         **kwargs,
     ):
         super(MobileNetV3, self).__init__()
         self.disable_se = disable_se
+        self.nchw = data_format == "NCHW"
 
         if small_stride is None:
             small_stride = [2, 2, 2, 2]
         if large_stride is None:
@@ -113,6 +115,7 @@ def __init__(
             groups=1,
             if_act=True,
             act="hardswish",
+            data_format=data_format,
         )
         i = 0
         block_list = []
@@ -128,6 +131,7 @@ def __init__(
                     stride=s,
                     use_se=se,
                     act=nl,
+                    data_format=data_format,
                 )
             )
             inplanes = make_divisible(scale * c)
@@ -143,12 +147,17 @@ def __init__(
             groups=1,
             if_act=True,
             act="hardswish",
+            data_format=data_format,
         )
 
-        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+        self.pool = nn.MaxPool2D(
+            kernel_size=2, stride=2, padding=0, data_format=data_format
+        )
         self.out_channels = make_divisible(scale * cls_ch_squeeze)
 
     def forward(self, x):
+        if not self.nchw:
+            x = x.transpose([0, 2, 3, 1])  # NCHW -> NHWC
         x = self.conv1(x)
         x = self.blocks(x)
         x = self.conv2(x)
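Recognition backbones collapse the height axis, so the layout decides where the 1 lands. A hypothetical shape check, assuming the patched class above; the exact W' depends on the stride config:

import paddle
from ppocr.modeling.backbones.rec_mobilenet_v3 import MobileNetV3

net = MobileNetV3(in_channels=3, model_name="small", data_format="NHWC")
y = net(paddle.randn([1, 3, 32, 320]))  # rec inputs are 32-pixel-high strips, NCHW
print(y.shape)  # NHWC: [B, 1, W', C]; Im2Seq later squeezes axis 1 for free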
diff --git a/ppocr/modeling/backbones/rec_mv1_enhance.py b/ppocr/modeling/backbones/rec_mv1_enhance.py
index f20fa4cb1d..ab6934f395 100644
--- a/ppocr/modeling/backbones/rec_mv1_enhance.py
+++ b/ppocr/modeling/backbones/rec_mv1_enhance.py
@@ -42,6 +42,7 @@ def __init__(
         channels=None,
         num_groups=1,
         act="hard_swish",
+        data_format="NCHW",
     ):
         super(ConvBNLayer, self).__init__()
 
@@ -54,6 +55,7 @@ def __init__(
             groups=num_groups,
             weight_attr=ParamAttr(initializer=KaimingNormal()),
             bias_attr=False,
+            data_format=data_format,
         )
 
         self._batch_norm = BatchNorm(
@@ -61,6 +63,7 @@ def __init__(
             act=act,
             param_attr=ParamAttr(regularizer=L2Decay(0.0)),
             bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            data_layout=data_format,
         )
 
     def forward(self, inputs):
@@ -81,6 +84,7 @@ def __init__(
         dw_size=3,
         padding=1,
         use_se=False,
+        data_format="NCHW",
     ):
         super(DepthwiseSeparable, self).__init__()
         self.use_se = use_se
@@ -91,15 +95,17 @@ def __init__(
             stride=stride,
             padding=padding,
             num_groups=int(num_groups * scale),
+            data_format=data_format,
         )
         if use_se:
-            self._se = SEModule(int(num_filters1 * scale))
+            self._se = SEModule(int(num_filters1 * scale), data_format=data_format)
         self._pointwise_conv = ConvBNLayer(
             num_channels=int(num_filters1 * scale),
             filter_size=1,
             num_filters=int(num_filters2 * scale),
             stride=1,
             padding=0,
+            data_format=data_format,
         )
 
     def forward(self, inputs):
@@ -118,6 +124,8 @@ def __init__(
         last_conv_stride=1,
         last_pool_type="max",
         last_pool_kernel_size=[3, 2],
+        data_format="NCHW",
         **kwargs,
     ):
         super().__init__()
+        self.data_format = data_format
@@ -131,6 +139,7 @@ def __init__(
             num_filters=int(32 * scale),
             stride=2,
             padding=1,
+            data_format=data_format,
         )
 
         conv2_1 = DepthwiseSeparable(
@@ -140,6 +149,7 @@ def __init__(
             num_groups=32,
             stride=1,
             scale=scale,
+            data_format=data_format,
         )
         self.block_list.append(conv2_1)
 
@@ -150,6 +160,7 @@ def __init__(
             num_groups=64,
             stride=1,
             scale=scale,
+            data_format=data_format,
         )
         self.block_list.append(conv2_2)
 
@@ -160,6 +171,7 @@ def __init__(
             num_groups=128,
             stride=1,
             scale=scale,
+            data_format=data_format,
         )
         self.block_list.append(conv3_1)
 
@@ -170,6 +182,7 @@ def __init__(
             num_groups=128,
             stride=(2, 1),
             scale=scale,
+            data_format=data_format,
         )
         self.block_list.append(conv3_2)
 
@@ -180,6 +193,7 @@ def __init__(
             num_groups=256,
             stride=1,
             scale=scale,
+            data_format=data_format,
         )
         self.block_list.append(conv4_1)
 
@@ -190,6 +204,7 @@ def __init__(
             num_groups=256,
             stride=(2, 1),
             scale=scale,
+            data_format=data_format,
         )
         self.block_list.append(conv4_2)
 
@@ -204,6 +219,7 @@ def __init__(
                 padding=2,
                 scale=scale,
                 use_se=False,
+                data_format=data_format,
             )
             self.block_list.append(conv5)
 
@@ -217,6 +233,7 @@ def __init__(
             padding=2,
             scale=scale,
             use_se=True,
+            data_format=data_format,
         )
         self.block_list.append(conv5_6)
 
@@ -230,6 +247,7 @@ def __init__(
             padding=2,
             use_se=True,
             scale=scale,
+            data_format=data_format,
        )
         self.block_list.append(conv6)
 
@@ -239,12 +257,17 @@ def __init__(
                 kernel_size=last_pool_kernel_size,
                 stride=last_pool_kernel_size,
                 padding=0,
+                data_format=data_format,
             )
         else:
-            self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+            self.pool = nn.MaxPool2D(
+                kernel_size=2, stride=2, padding=0, data_format=data_format
+            )
         self.out_channels = int(1024 * scale)
 
     def forward(self, inputs):
+        if self.data_format == "NHWC":
+            inputs = paddle.tensor.transpose(inputs, [0, 2, 3, 1])
         y = self.conv1(inputs)
         y = self.block_list(y)
         y = self.pool(y)
@@ -252,9 +275,9 @@ def forward(self, inputs):
 
 
 class SEModule(nn.Layer):
-    def __init__(self, channel, reduction=4):
+    def __init__(self, channel, reduction=4, data_format="NCHW"):
         super(SEModule, self).__init__()
-        self.avg_pool = AdaptiveAvgPool2D(1)
+        self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format)
         self.conv1 = Conv2D(
             in_channels=channel,
             out_channels=channel // reduction,
@@ -263,6 +286,7 @@ def __init__(self, channel, reduction=4):
             kernel_size=1,
             stride=1,
             padding=0,
             weight_attr=ParamAttr(),
             bias_attr=ParamAttr(),
+            data_format=data_format,
         )
         self.conv2 = Conv2D(
             in_channels=channel // reduction,
@@ -272,6 +296,7 @@ def __init__(self, channel, reduction=4):
             kernel_size=1,
             stride=1,
             padding=0,
             weight_attr=ParamAttr(),
             bias_attr=ParamAttr(),
+            data_format=data_format,
         )
 
     def forward(self, inputs):
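As with ResNet_vd, forward reads an attribute (self.data_format) that must be stored in __init__, so the constructor change above assigns it. The SEModule edits matter because squeeze-excite must pool over H/W and then broadcast back over them; a standalone NHWC sketch of that squeeze/excite shape contract:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

x = paddle.randn([2, 16, 16, 32])                   # NHWC features
pool = nn.AdaptiveAvgPool2D(1, data_format="NHWC")
w = F.sigmoid(pool(x))                              # [2, 1, 1, 32]
print((x * w).shape)                                # broadcasts over H and W: [2, 16, 16, 32]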
diff --git a/ppocr/modeling/backbones/rec_svtrnet.py b/ppocr/modeling/backbones/rec_svtrnet.py
index 427c87b324..d5bb6e507d 100644
--- a/ppocr/modeling/backbones/rec_svtrnet.py
+++ b/ppocr/modeling/backbones/rec_svtrnet.py
@@ -51,6 +51,7 @@ def __init__(
         bias_attr=False,
         groups=1,
         act=nn.GELU,
+        data_format="NCHW",
     ):
         super().__init__()
         self.conv = nn.Conv2D(
@@ -62,8 +63,9 @@ def __init__(
             groups=groups,
             weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
             bias_attr=bias_attr,
+            data_format=data_format,
         )
-        self.norm = nn.BatchNorm2D(out_channels)
+        self.norm = nn.BatchNorm2D(out_channels, data_format=data_format)
         self.act = act()
 
     def forward(self, inputs):
diff --git a/ppocr/modeling/heads/cls_head.py b/ppocr/modeling/heads/cls_head.py
index 867e960182..381f0f3e6b 100644
--- a/ppocr/modeling/heads/cls_head.py
+++ b/ppocr/modeling/heads/cls_head.py
@@ -31,9 +31,10 @@ class ClsHead(nn.Layer):
         params(dict): super parameters for build Class network
     """
 
-    def __init__(self, in_channels, class_dim, **kwargs):
+    def __init__(self, in_channels, class_dim, data_format="NCHW", **kwargs):
         super(ClsHead, self).__init__()
-        self.pool = nn.AdaptiveAvgPool2D(1)
+        self.pool = nn.AdaptiveAvgPool2D(1, data_format=data_format)
+        self.nchw = data_format == "NCHW"
         stdv = 1.0 / math.sqrt(in_channels * 1.0)
         self.fc = nn.Linear(
             in_channels,
@@ -46,7 +47,7 @@ def __init__(self, in_channels, class_dim, **kwargs):
 
     def forward(self, x, targets=None):
         x = self.pool(x)
-        x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
+        x = paddle.reshape(x, shape=[x.shape[0], x.shape[1 if self.nchw else 3]])
         x = self.fc(x)
         if not self.training:
             x = F.softmax(x, axis=1)
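The reshape picks shape[1] vs shape[3] because the pooled tensor is [B, C, 1, 1] in NCHW but [B, 1, 1, C] in NHWC, which is also why the AdaptiveAvgPool2D itself has to be layout-aware. Standalone illustration of the NHWC branch:

import paddle

x_nhwc = paddle.randn([4, 1, 1, 200])  # output of AdaptiveAvgPool2D(1, data_format="NHWC")
flat = paddle.reshape(x_nhwc, [x_nhwc.shape[0], x_nhwc.shape[3]])
print(flat.shape)  # [4, 200], ready for the nn.Linear classifier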
diff --git a/ppocr/modeling/heads/det_db_head.py b/ppocr/modeling/heads/det_db_head.py
index 8f41a25b01..dc11c4dc11 100644
--- a/ppocr/modeling/heads/det_db_head.py
+++ b/ppocr/modeling/heads/det_db_head.py
@@ -32,7 +32,9 @@ def get_bias_attr(k):
 
 class Head(nn.Layer):
-    def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs):
+    def __init__(
+        self, in_channels, kernel_list=[3, 2, 2], data_format="NCHW", **kwargs
+    ):
         super(Head, self).__init__()
 
         self.conv1 = nn.Conv2D(
@@ -42,12 +44,14 @@ def __init__(
             padding=int(kernel_list[0] // 2),
             weight_attr=ParamAttr(),
             bias_attr=False,
+            data_format=data_format,
         )
         self.conv_bn1 = nn.BatchNorm(
             num_channels=in_channels // 4,
             param_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1.0)),
             bias_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1e-4)),
             act="relu",
+            data_layout=data_format,
         )
 
         self.conv2 = nn.Conv2DTranspose(
@@ -57,12 +61,14 @@ def __init__(
             stride=2,
             weight_attr=ParamAttr(initializer=paddle.nn.initializer.KaimingUniform()),
             bias_attr=get_bias_attr(in_channels // 4),
+            data_format=data_format,
         )
         self.conv_bn2 = nn.BatchNorm(
             num_channels=in_channels // 4,
             param_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1.0)),
             bias_attr=ParamAttr(initializer=paddle.nn.initializer.Constant(value=1e-4)),
             act="relu",
+            data_layout=data_format,
         )
         self.conv3 = nn.Conv2DTranspose(
             in_channels=in_channels // 4,
@@ -71,6 +77,7 @@ def __init__(
             stride=2,
             weight_attr=ParamAttr(initializer=paddle.nn.initializer.KaimingUniform()),
             bias_attr=get_bias_attr(in_channels // 4),
+            data_format=data_format,
         )
 
     def forward(self, x, return_f=False):
@@ -95,11 +102,12 @@ class DBHead(nn.Layer):
         params(dict): super parameters for build DB network
     """
 
-    def __init__(self, in_channels, k=50, **kwargs):
+    def __init__(self, in_channels, k=50, data_format="NCHW", **kwargs):
         super(DBHead, self).__init__()
         self.k = k
-        self.binarize = Head(in_channels, **kwargs)
-        self.thresh = Head(in_channels, **kwargs)
+        self.binarize = Head(in_channels, data_format=data_format, **kwargs)
+        self.thresh = Head(in_channels, data_format=data_format, **kwargs)
+        self.data_format = data_format
 
     def step_function(self, x, y):
         return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
@@ -107,11 +115,18 @@ def forward(self, x, targets=None):
         shrink_maps = self.binarize(x)
         if not self.training:
+            if "NHWC" == self.data_format:
+                shrink_maps = paddle.tensor.transpose(shrink_maps, [0, 3, 1, 2])
             return {"maps": shrink_maps}
 
         threshold_maps = self.thresh(x)
         binary_maps = self.step_function(shrink_maps, threshold_maps)
-        y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
+        y = paddle.concat(
+            [shrink_maps, threshold_maps, binary_maps],
+            axis=1 if "NCHW" == self.data_format else 3,
+        )
+        if "NHWC" == self.data_format:
+            y = paddle.tensor.transpose(y, [0, 3, 1, 2])
         return {"maps": y}
diff --git a/ppocr/modeling/heads/rec_multi_head.py b/ppocr/modeling/heads/rec_multi_head.py
index be4461567b..ca4f3826e9 100644
--- a/ppocr/modeling/heads/rec_multi_head.py
+++ b/ppocr/modeling/heads/rec_multi_head.py
@@ -65,14 +65,17 @@ def forward(self, x):
 
 class MultiHead(nn.Layer):
-    def __init__(self, in_channels, out_channels_list, **kwargs):
+    def __init__(self, in_channels, out_channels_list, data_format="NCHW", **kwargs):
         super().__init__()
         self.head_list = kwargs.pop("head_list")
         self.use_pool = kwargs.get("use_pool", False)
         self.use_pos = kwargs.get("use_pos", False)
         self.in_channels = in_channels
+        self.nchw = data_format == "NCHW"
         if self.use_pool:
-            self.pool = nn.AvgPool2D(kernel_size=[3, 2], stride=[3, 2], padding=0)
+            self.pool = nn.AvgPool2D(
+                kernel_size=[3, 2], stride=[3, 2], padding=0, data_format=data_format
+            )
         self.gtc_head = "sar"
         assert len(self.head_list) >= 2
         for idx, head_name in enumerate(self.head_list):
@@ -113,18 +116,22 @@ def __init__(self, in_channels, out_channels_list, **kwargs):
                 )
             elif name == "CTCHead":
                 # ctc neck
-                self.encoder_reshape = Im2Seq(in_channels)
+                self.encoder_reshape = Im2Seq(in_channels, data_format=data_format)
                 neck_args = self.head_list[idx][name]["Neck"]
                 encoder_type = neck_args.pop("name")
                 self.ctc_encoder = SequenceEncoder(
-                    in_channels=in_channels, encoder_type=encoder_type, **neck_args
+                    in_channels=in_channels,
+                    encoder_type=encoder_type,
+                    data_format=data_format,
+                    **neck_args,
                 )
                 # ctc head
                 head_args = self.head_list[idx][name]["Head"]
                 self.ctc_head = eval(name)(
                     in_channels=self.ctc_encoder.out_channels,
                     out_channels=out_channels_list["CTCLabelDecode"],
+                    data_format=data_format,
                     **head_args,
                 )
             else:
                 raise NotImplementedError(
@@ -144,6 +151,8 @@ def forward(self, x, targets=None):
         # eval mode
         if not self.training:
             return ctc_out
+        if not self.nchw:
+            x = x.transpose([0, 3, 1, 2])
         if self.gtc_head == "sar":
             sar_out = self.sar_head(x, targets[1:])
             head_out["sar"] = sar_out
diff --git a/ppocr/modeling/heads/rec_sar_head.py b/ppocr/modeling/heads/rec_sar_head.py
old mode 100644
new mode 100755
index 9c646a1d67..c5d44282b5
--- a/ppocr/modeling/heads/rec_sar_head.py
+++ b/ppocr/modeling/heads/rec_sar_head.py
@@ -70,7 +70,7 @@ def __init__(
         kwargs = dict(
             input_size=d_model,
             hidden_size=d_enc,
-            num_layers=2,
+            num_layers=1,
             time_major=False,
             dropout=enc_drop_rnn,
             direction=direction,
@@ -197,7 +197,7 @@ def __init__(
         kwargs = dict(
             input_size=encoder_rnn_out_size,
             hidden_size=encoder_rnn_out_size,
-            num_layers=2,
+            num_layers=1,
             time_major=False,
             dropout=dec_drop_rnn,
             direction=direction,
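SARHead still consumes NCHW feature maps, so MultiHead only flips the layout back on the training branch (eval returns the CTC result before the transpose ever runs). A standalone sketch of that branch; to_sar_input is a hypothetical helper, not repo code:

import paddle

def to_sar_input(x, nchw):
    # x: backbone features, [B, C, 1, W] if nchw else [B, 1, W, C]
    return x if nchw else x.transpose([0, 3, 1, 2])

x = paddle.randn([2, 1, 80, 64])           # NHWC features
print(to_sar_input(x, nchw=False).shape)   # [2, 64, 1, 80]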
diff --git a/ppocr/modeling/necks/db_fpn.py b/ppocr/modeling/necks/db_fpn.py
index 5c1674b434..48c7ccd263 100644
--- a/ppocr/modeling/necks/db_fpn.py
+++ b/ppocr/modeling/necks/db_fpn.py
@@ -42,6 +42,7 @@ def __init__(
         groups=None,
         if_act=True,
         act="relu",
+        data_format="NCHW",
         **kwargs,
     ):
         super(DSConv, self).__init__()
@@ -57,9 +58,12 @@ def __init__(
             padding=padding,
             groups=groups,
             bias_attr=False,
+            data_format=data_format,
         )
 
-        self.bn1 = nn.BatchNorm(num_channels=in_channels, act=None)
+        self.bn1 = nn.BatchNorm(
+            num_channels=in_channels, act=None, data_layout=data_format
+        )
 
         self.conv2 = nn.Conv2D(
             in_channels=in_channels,
@@ -67,9 +71,12 @@ def __init__(
             kernel_size=1,
             stride=1,
             bias_attr=False,
+            data_format=data_format,
         )
 
-        self.bn2 = nn.BatchNorm(num_channels=int(in_channels * 4), act=None)
+        self.bn2 = nn.BatchNorm(
+            num_channels=int(in_channels * 4), act=None, data_layout=data_format
+        )
 
         self.conv3 = nn.Conv2D(
             in_channels=int(in_channels * 4),
@@ -77,6 +84,7 @@ def __init__(
             kernel_size=1,
             stride=1,
             bias_attr=False,
+            data_format=data_format,
         )
         self._c = [in_channels, out_channels]
         if in_channels != out_channels:
@@ -86,6 +94,7 @@ def __init__(
                 kernel_size=1,
                 stride=1,
                 bias_attr=False,
+                data_format=data_format,
             )
 
     def forward(self, inputs):
@@ -114,11 +123,14 @@ def forward(self, inputs):
 
 
 class DBFPN(nn.Layer):
-    def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
+    def __init__(
+        self, in_channels, out_channels, use_asf=False, data_format="NCHW", **kwargs
+    ):
         super(DBFPN, self).__init__()
         self.out_channels = out_channels
         self.use_asf = use_asf
         weight_attr = paddle.nn.initializer.KaimingUniform()
+        self.data_format = data_format
 
         self.in2_conv = nn.Conv2D(
             in_channels=in_channels[0],
             out_channels=self.out_channels,
             kernel_size=1,
             weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False,
+            data_format=data_format,
         )
         self.in3_conv = nn.Conv2D(
             in_channels=in_channels[1],
             out_channels=self.out_channels,
             kernel_size=1,
             weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False,
+            data_format=data_format,
         )
         self.in4_conv = nn.Conv2D(
             in_channels=in_channels[2],
             out_channels=self.out_channels,
             kernel_size=1,
             weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False,
+            data_format=data_format,
         )
         self.in5_conv = nn.Conv2D(
             in_channels=in_channels[3],
             out_channels=self.out_channels,
             kernel_size=1,
             weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False,
+            data_format=data_format,
         )
         self.p5_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
             weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False,
+            data_format=data_format,
         )
         self.p4_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
             weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False,
+            data_format=data_format,
         )
         self.p3_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
             weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False,
+            data_format=data_format,
         )
         self.p2_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
             weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False,
+            data_format=data_format,
         )
 
         if self.use_asf is True:
@@ -193,24 +213,56 @@ def forward(self, x):
         in2 = self.in2_conv(c2)
 
         out4 = in4 + F.upsample(
-            in5, scale_factor=2, mode="nearest", align_mode=1
+            in5,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
         )  # 1/16
         out3 = in3 + F.upsample(
-            out4, scale_factor=2, mode="nearest", align_mode=1
+            out4,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
         )  # 1/8
         out2 = in2 + F.upsample(
-            out3, scale_factor=2, mode="nearest", align_mode=1
+            out3,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
         )  # 1/4
 
         p5 = self.p5_conv(in5)
         p4 = self.p4_conv(out4)
         p3 = self.p3_conv(out3)
         p2 = self.p2_conv(out2)
-        p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
-        p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
-        p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
+        p5 = F.upsample(
+            p5,
+            scale_factor=8,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
+        )
+        p4 = F.upsample(
+            p4,
+            scale_factor=4,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
+        )
+        p3 = F.upsample(
+            p3,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
+        )
 
-        fuse = paddle.concat([p5, p4, p3, p2], axis=1)
+        fuse = paddle.concat(
+            [p5, p4, p3, p2], axis=1 if "NCHW" == self.data_format else 3
+        )
         if self.use_asf is True:
             fuse = self.asf(fuse, [p5, p4, p3, p2])
 
@@ -219,7 +271,9 @@ def forward(self, x):
 
 
 class RSELayer(nn.Layer):
-    def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
+    def __init__(
+        self, in_channels, out_channels, kernel_size, shortcut=True, data_format="NCHW"
+    ):
         super(RSELayer, self).__init__()
         weight_attr = paddle.nn.initializer.KaimingUniform()
         self.out_channels = out_channels
@@ -230,8 +284,9 @@ def __init__(
             padding=int(kernel_size // 2),
             weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False,
+            data_format=data_format,
         )
-        self.se_block = SEModule(self.out_channels)
+        self.se_block = SEModule(self.out_channels, data_format=data_format)
         self.shortcut = shortcut
 
     def forward(self, ins):
@@ -244,26 +299,48 @@ def forward(self, ins):
 
 
 class RSEFPN(nn.Layer):
-    def __init__(self, in_channels, out_channels, shortcut=True, **kwargs):
+    def __init__(
+        self, in_channels, out_channels, shortcut=True, data_format="NCHW", **kwargs
+    ):
         super(RSEFPN, self).__init__()
         self.out_channels = out_channels
+        self.nchw = data_format == "NCHW"
+        self.data_format = data_format
         self.ins_conv = nn.LayerList()
         self.inp_conv = nn.LayerList()
         self.intracl = False
         if "intracl" in kwargs.keys() and kwargs["intracl"] is True:
             self.intracl = kwargs["intracl"]
-            self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
-            self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
-            self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
-            self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
+            self.incl1 = IntraCLBlock(
+                self.out_channels // 4, reduce_factor=2, data_format=data_format
+            )
+            self.incl2 = IntraCLBlock(
+                self.out_channels // 4, reduce_factor=2, data_format=data_format
+            )
+            self.incl3 = IntraCLBlock(
+                self.out_channels // 4, reduce_factor=2, data_format=data_format
+            )
+            self.incl4 = IntraCLBlock(
+                self.out_channels // 4, reduce_factor=2, data_format=data_format
+            )
 
         for i in range(len(in_channels)):
             self.ins_conv.append(
-                RSELayer(in_channels[i], out_channels, kernel_size=1, shortcut=shortcut)
+                RSELayer(
+                    in_channels[i],
+                    out_channels,
+                    kernel_size=1,
+                    shortcut=shortcut,
+                    data_format=data_format,
+                )
             )
             self.inp_conv.append(
                 RSELayer(
-                    out_channels, out_channels // 4, kernel_size=3, shortcut=shortcut
+                    out_channels,
+                    out_channels // 4,
+                    kernel_size=3,
+                    shortcut=shortcut,
+                    data_format=data_format,
                 )
             )
 
@@ -276,13 +353,25 @@ def forward(self, x):
         in2 = self.ins_conv[0](c2)
 
         out4 = in4 + F.upsample(
-            in5, scale_factor=2, mode="nearest", align_mode=1
+            in5,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
         )  # 1/16
         out3 = in3 + F.upsample(
-            out4, scale_factor=2, mode="nearest", align_mode=1
+            out4,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
         )  # 1/8
         out2 = in2 + F.upsample(
-            out3, scale_factor=2, mode="nearest", align_mode=1
+            out3,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
         )  # 1/4
 
         p5 = self.inp_conv[3](in5)
@@ -296,18 +385,40 @@ def forward(self, x):
             p3 = self.incl2(p3)
             p2 = self.incl1(p2)
 
-        p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
-        p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
-        p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
+        p5 = F.upsample(
+            p5,
+            scale_factor=8,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
+        )
+        p4 = F.upsample(
+            p4,
+            scale_factor=4,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
+        )
+        p3 = F.upsample(
+            p3,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
+        )
 
-        fuse = paddle.concat([p5, p4, p3, p2], axis=1)
+        fuse = paddle.concat([p5, p4, p3, p2], axis=1 if self.nchw else 3)
         return fuse
 
 
 class LKPAN(nn.Layer):
-    def __init__(self, in_channels, out_channels, mode="large", **kwargs):
+    def __init__(
+        self, in_channels, out_channels, mode="large", data_format="NCHW", **kwargs
+    ):
         super(LKPAN, self).__init__()
         self.out_channels = out_channels
+        self.nchw = data_format == "NCHW"
+        self.data_format = data_format
         weight_attr = paddle.nn.initializer.KaimingUniform()
 
         self.ins_conv = nn.LayerList()
@@ -335,6 +446,7 @@ def __init__(
                     kernel_size=1,
                     weight_attr=ParamAttr(initializer=weight_attr),
                     bias_attr=False,
+                    data_format=data_format,
                 )
             )
 
@@ -346,6 +458,7 @@ def __init__(
                     padding=4,
                     weight_attr=ParamAttr(initializer=weight_attr),
                    bias_attr=False,
+                    data_format=data_format,
                 )
             )
 
@@ -359,6 +472,7 @@ def __init__(
                     stride=2,
                     weight_attr=ParamAttr(initializer=weight_attr),
                     bias_attr=False,
+                    data_format=data_format,
                 )
             )
             self.pan_lat_conv.append(
@@ -369,16 +483,25 @@ def __init__(
                     padding=4,
                     weight_attr=ParamAttr(initializer=weight_attr),
                     bias_attr=False,
+                    data_format=data_format,
                 )
             )
 
         self.intracl = False
         if "intracl" in kwargs.keys() and kwargs["intracl"] is True:
             self.intracl = kwargs["intracl"]
-            self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
-            self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
-            self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
-            self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
+            self.incl1 = IntraCLBlock(
+                self.out_channels // 4, reduce_factor=2, data_format=data_format
+            )
+            self.incl2 = IntraCLBlock(
+                self.out_channels // 4, reduce_factor=2, data_format=data_format
+            )
+            self.incl3 = IntraCLBlock(
+                self.out_channels // 4, reduce_factor=2, data_format=data_format
+            )
+            self.incl4 = IntraCLBlock(
+                self.out_channels // 4, reduce_factor=2, data_format=data_format
+            )
 
     def forward(self, x):
         c2, c3, c4, c5 = x
@@ -389,13 +512,25 @@ def forward(self, x):
         in2 = self.ins_conv[0](c2)
 
         out4 = in4 + F.upsample(
-            in5, scale_factor=2, mode="nearest", align_mode=1
+            in5,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
         )  # 1/16
         out3 = in3 + F.upsample(
-            out4, scale_factor=2, mode="nearest", align_mode=1
+            out4,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
         )  # 1/8
         out2 = in2 + F.upsample(
-            out3, scale_factor=2, mode="nearest", align_mode=1
+            out3,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
         )  # 1/4
 
         f5 = self.inp_conv[3](in5)
@@ -418,11 +553,29 @@ def forward(self, x):
             p3 = self.incl2(p3)
             p2 = self.incl1(p2)
 
-        p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1)
-        p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1)
-        p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
+        p5 = F.upsample(
+            p5,
+            scale_factor=8,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
+        )
+        p4 = F.upsample(
+            p4,
+            scale_factor=4,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
+        )
+        p3 = F.upsample(
+            p3,
+            scale_factor=2,
+            mode="nearest",
+            align_mode=1,
+            data_format=self.data_format,
+        )
 
-        fuse = paddle.concat([p5, p4, p3, p2], axis=1)
+        fuse = paddle.concat([p5, p4, p3, p2], axis=1 if self.nchw else 3)
         return fuse
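Across all three necks the NHWC changes are mechanical: F.upsample gets data_format and the channel concat moves from axis 1 to axis 3; everything else in the fuse step is layout-neutral. Standalone check (plain Paddle, illustrative shapes):

import paddle
import paddle.nn.functional as F

p5 = paddle.randn([1, 20, 20, 24])  # NHWC, stride-32 pyramid slice
p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1, data_format="NHWC")
p_rest = [paddle.randn([1, 160, 160, 24]) for _ in range(3)]
fuse = paddle.concat([p5] + p_rest, axis=3)  # axis=1 under NCHW
print(fuse.shape)  # [1, 160, 160, 96]: H and W scaled, channels stay last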
diff --git a/ppocr/modeling/necks/intracl.py b/ppocr/modeling/necks/intracl.py
index 2c4809cb12..10a60a2e42 100644
--- a/ppocr/modeling/necks/intracl.py
+++ b/ppocr/modeling/necks/intracl.py
@@ -5,16 +5,26 @@ class IntraCLBlock(nn.Layer):
-    def __init__(self, in_channels=96, reduce_factor=4):
+    def __init__(self, in_channels=96, reduce_factor=4, data_format="NCHW"):
         super(IntraCLBlock, self).__init__()
         self.channels = in_channels
         self.rf = reduce_factor
         weight_attr = paddle.nn.initializer.KaimingUniform()
         self.conv1x1_reduce_channel = nn.Conv2D(
-            self.channels, self.channels // self.rf, kernel_size=1, stride=1, padding=0
+            self.channels,
+            self.channels // self.rf,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            data_format=data_format,
         )
         self.conv1x1_return_channel = nn.Conv2D(
-            self.channels // self.rf, self.channels, kernel_size=1, stride=1, padding=0
+            self.channels // self.rf,
+            self.channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            data_format=data_format,
         )
 
         self.v_layer_7x1 = nn.Conv2D(
@@ -23,6 +33,7 @@ def __init__(self, in_channels=96, reduce_factor=4):
             kernel_size=(7, 1),
             stride=(1, 1),
             padding=(3, 0),
+            data_format=data_format,
         )
         self.v_layer_5x1 = nn.Conv2D(
             self.channels // self.rf,
@@ -30,6 +41,7 @@ def __init__(self, in_channels=96, reduce_factor=4):
             kernel_size=(5, 1),
             stride=(1, 1),
             padding=(2, 0),
+            data_format=data_format,
         )
         self.v_layer_3x1 = nn.Conv2D(
             self.channels // self.rf,
@@ -37,6 +49,7 @@ def __init__(self, in_channels=96, reduce_factor=4):
             kernel_size=(3, 1),
             stride=(1, 1),
             padding=(1, 0),
+            data_format=data_format,
         )
 
         self.q_layer_1x7 = nn.Conv2D(
@@ -45,6 +58,7 @@ def __init__(self, in_channels=96, reduce_factor=4):
             kernel_size=(1, 7),
             stride=(1, 1),
             padding=(0, 3),
+            data_format=data_format,
         )
         self.q_layer_1x5 = nn.Conv2D(
             self.channels // self.rf,
@@ -52,6 +66,7 @@ def __init__(self, in_channels=96, reduce_factor=4):
             kernel_size=(1, 5),
             stride=(1, 1),
             padding=(0, 2),
+            data_format=data_format,
         )
         self.q_layer_1x3 = nn.Conv2D(
             self.channels // self.rf,
@@ -59,6 +74,7 @@ def __init__(self, in_channels=96, reduce_factor=4):
             kernel_size=(1, 3),
             stride=(1, 1),
             padding=(0, 1),
+            data_format=data_format,
         )
 
         # base
@@ -68,6 +84,7 @@ def __init__(self, in_channels=96, reduce_factor=4):
             kernel_size=(7, 7),
             stride=(1, 1),
             padding=(3, 3),
+            data_format=data_format,
         )
         self.c_layer_5x5 = nn.Conv2D(
             self.channels // self.rf,
@@ -82,9 +99,10 @@ def __init__(self, in_channels=96, reduce_factor=4):
             kernel_size=(3, 3),
             stride=(1, 1),
             padding=(1, 1),
+            data_format=data_format,
        )
 
-        self.bn = nn.BatchNorm2D(self.channels)
+        self.bn = nn.BatchNorm2D(self.channels, data_format=data_format)
         self.relu = nn.ReLU()
 
     def forward(self, x):
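Kernel, stride, and padding tuples are always (H, W) pairs; data_format only moves the channel axis, so the (7, 1)/(1, 7) branches keep their paddings unchanged. One caveat worth flagging in review: c_layer_5x5, which sits between the last two hunks, is not given data_format and would still build as NCHW. Standalone check of an asymmetric branch:

import paddle
import paddle.nn as nn

conv = nn.Conv2D(24, 24, kernel_size=(7, 1), padding=(3, 0), data_format="NHWC")
print(conv(paddle.randn([1, 40, 40, 24])).shape)  # [1, 40, 40, 24]: spatial dims preserved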
diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py
index fa7b8a1f1a..9727568072 100644
--- a/ppocr/modeling/necks/rnn.py
+++ b/ppocr/modeling/necks/rnn.py
@@ -30,15 +30,21 @@ class Im2Seq(nn.Layer):
-    def __init__(self, in_channels, **kwargs):
+    def __init__(self, in_channels, data_format="NCHW", **kwargs):
         super().__init__()
         self.out_channels = in_channels
+        self.nchw = data_format == "NCHW"
 
     def forward(self, x):
-        B, C, H, W = x.shape
-        assert H == 1
-        x = x.squeeze(axis=2)
-        x = x.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
+        if self.nchw:
+            B, C, H, W = x.shape
+            assert H == 1
+            x = x.squeeze(axis=2)
+            x = x.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
+        else:
+            B, H, W, C = x.shape
+            assert H == 1
+            x = x.squeeze(axis=1)
         return x
@@ -152,19 +158,26 @@ def __init__(
         drop_path=0.0,
         kernel_size=[3, 3],
         qk_scale=None,
+        data_format="NCHW",
     ):
         super(EncoderWithSVTR, self).__init__()
         self.depth = depth
         self.use_guide = use_guide
+        self.nchw = data_format == "NCHW"
         self.conv1 = ConvBNLayer(
             in_channels,
             in_channels // 8,
             kernel_size=kernel_size,
             padding=[kernel_size[0] // 2, kernel_size[1] // 2],
             act=nn.Swish,
+            data_format=data_format,
         )
         self.conv2 = ConvBNLayer(
-            in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish
+            in_channels // 8,
+            hidden_dims,
+            kernel_size=1,
+            act=nn.Swish,
+            data_format=data_format,
         )
 
         self.svtr_block = nn.LayerList(
@@ -189,7 +202,13 @@ def __init__(
             ]
         )
         self.norm = nn.LayerNorm(hidden_dims, epsilon=1e-6)
-        self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act=nn.Swish)
+        self.conv3 = ConvBNLayer(
+            hidden_dims,
+            in_channels,
+            kernel_size=1,
+            act=nn.Swish,
+            data_format=data_format,
+        )
         # last conv-nxn, the input is concat of input tensor and conv3 output tensor
         self.conv4 = ConvBNLayer(
             2 * in_channels,
@@ -197,9 +216,12 @@ def __init__(
             kernel_size=kernel_size,
             padding=[kernel_size[0] // 2, kernel_size[1] // 2],
             act=nn.Swish,
+            data_format=data_format,
         )
-        self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act=nn.Swish)
+        self.conv1x1 = ConvBNLayer(
+            in_channels // 8, dims, kernel_size=1, act=nn.Swish, data_format=data_format
+        )
         self.out_channels = dims
         self.apply(self._init_weights)
@@ -225,23 +247,33 @@ def forward(self, x):
         z = self.conv1(z)
         z = self.conv2(z)
         # SVTR global block
-        B, C, H, W = z.shape
-        z = z.flatten(2).transpose([0, 2, 1])
+        if self.nchw:
+            B, C, H, W = z.shape
+            z = z.flatten(2).transpose([0, 2, 1])
+        else:
+            B, H, W, C = z.shape
+            z = z.flatten(start_axis=1, stop_axis=2)
+
         for blk in self.svtr_block:
             z = blk(z)
         z = self.norm(z)
         # last stage
-        z = z.reshape([0, H, W, C]).transpose([0, 3, 1, 2])
+        if self.nchw:
+            z = z.reshape([0, H, W, C]).transpose([0, 3, 1, 2])
+        else:
+            z = z.reshape([0, H, W, C])
         z = self.conv3(z)
-        z = paddle.concat((h, z), axis=1)
+        z = paddle.concat((h, z), axis=1 if self.nchw else 3)
         z = self.conv1x1(self.conv4(z))
         return z
 
 
 class SequenceEncoder(nn.Layer):
-    def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
+    def __init__(
+        self, in_channels, encoder_type, hidden_size=48, data_format="NCHW", **kwargs
+    ):
         super(SequenceEncoder, self).__init__()
-        self.encoder_reshape = Im2Seq(in_channels)
+        self.encoder_reshape = Im2Seq(in_channels, data_format=data_format)
         self.out_channels = self.encoder_reshape.out_channels
         self.encoder_type = encoder_type
         if encoder_type == "reshape":
@@ -259,15 +291,20 @@ def __init__(
             )
         if encoder_type == "svtr":
             self.encoder = support_encoder_dict[encoder_type](
-                self.encoder_reshape.out_channels, **kwargs
+                self.encoder_reshape.out_channels, data_format=data_format, **kwargs
             )
         elif encoder_type == "cascadernn":
             self.encoder = support_encoder_dict[encoder_type](
-                self.encoder_reshape.out_channels, hidden_size, **kwargs
+                self.encoder_reshape.out_channels,
+                hidden_size,
+                data_format=data_format,
+                **kwargs,
             )
         else:
             self.encoder = support_encoder_dict[encoder_type](
-                self.encoder_reshape.out_channels, hidden_size
+                self.encoder_reshape.out_channels,
+                hidden_size,
+                data_format=data_format,
             )
             self.out_channels = self.encoder.out_channels
             self.only_reshape = False
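Im2Seq is where NHWC pays off for recognition: both branches end at the same (B, W, C) sequence for the RNN/CTC stack, but the channels-last path is a plain squeeze with no transpose. Standalone sketch mirroring the two branches:

import paddle

x_nchw = paddle.randn([2, 64, 1, 80])
seq = x_nchw.squeeze(axis=2).transpose([0, 2, 1])  # NCHW: squeeze H, then move C last
x_nhwc = paddle.randn([2, 1, 80, 64])
seq2 = x_nhwc.squeeze(axis=1)                      # NHWC: squeeze H, already (B, W, C)
print(seq.shape, seq2.shape)                       # [2, 80, 64] [2, 80, 64]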