Skip to content

Commit

Permalink
apply code-format changes
Browse files Browse the repository at this point in the history
  • Loading branch information
luxincn authored and github-actions[bot] committed Sep 6, 2024
1 parent 8172355 commit 1031119
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
13 changes: 6 additions & 7 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2867,9 +2867,9 @@ pnnx::ModelInfo Graph::flops_mem_count()
pnnx::ModelInfo m;
for (const Operator* op : ops)
{
if(op->type == "nn.Conv2d")
if (op->type == "nn.Conv2d")
{
if(op->inputs[0]->type != 0)
if (op->inputs[0]->type != 0)
{
int ci = op->inputs[0]->shape[1];
int kw = op->params.at("kernel_size").ai[0];
Expand All @@ -2881,26 +2881,26 @@ pnnx::ModelInfo Graph::flops_mem_count()
int wi = op->inputs[0]->shape[2];
int hi = op->inputs[0]->shape[3];
int g = op->params.at("groups").i;
if(bias == 1)
if (bias == 1)
{
m.flops += 2 * ci * kw * kh * co * w * h;
}
else
{
m.flops += (2 * ci * kw * kh -1) * co * w * h;
m.flops += (2 * ci * kw * kh - 1) * co * w * h;
}
int input_m = wi * hi * ci;
int output_m = w * h * co;
int weights_m = kw * kh * ci * co;
m.memory_access += input_m + output_m + weights_m;
}
}
else if(op->type == "nn.Linear")
else if (op->type == "nn.Linear")
{
int in = op->params.at("in_features").i;
int out = op->params.at("out_features").i;
int bias = op->params.at("bias").b ? 1 : 0;
if(bias == 1)
if (bias == 1)
{
m.flops += 2 * in * out;
}
Expand All @@ -2912,7 +2912,6 @@ pnnx::ModelInfo Graph::flops_mem_count()
}
else
{

}
}

Expand Down
3 changes: 2 additions & 1 deletion tools/pnnx/src/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class OnnxAttributeProxy;

namespace pnnx {

struct ModelInfo {
struct ModelInfo
{
ModelInfo()
: flops(0), memory_access(0)
{
Expand Down
1 change: 0 additions & 1 deletion tools/pnnx/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ int main(int argc, char** argv)
pnnx::save_ncnn(pnnx_graph, ncnnparampath, ncnnbinpath, ncnnpypath, fp16);
}


// pnnx::Graph pnnx_graph2;

// pnnx_graph2.load("pnnx.param", "pnnx.bin");
Expand Down

0 comments on commit 1031119

Please sign in to comment.