update flops_mem_count
luxincn committed Sep 10, 2024
1 parent 1031119 commit 62eb648
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions tools/pnnx/src/ir.cpp
@@ -2910,6 +2910,106 @@ pnnx::ModelInfo Graph::flops_mem_count()
}
m.memory_access += in + out + in * out;
}
else if (op->type == "nn.MultiheadAttention")
{
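// Count the q/k/v data inputs only; an optional attn_mask input is
// not a data tensor for this estimate, so it is excluded below.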
int in_size = op->inputs.size();

if (std::find(op->inputnames.begin(), op->inputnames.end(), "attn_mask") != op->inputnames.end())
{
in_size -= 1;
}

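// Read the sequence lengths from dim 1 when batch_first, otherwise
// from dim 0. With two inputs, key and value share a tensor; with
// one input, q, k and v all come from the same tensor.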
int q_l, k_s, v_s;
bool batch_first = op->params.find("batch_first") != op->params.end() && op->params.at("batch_first").b;

if (in_size == 3)
{
q_l = op->inputs[0]->shape[batch_first ? 1 : 0];
k_s = op->inputs[1]->shape[batch_first ? 1 : 0];
v_s = op->inputs[2]->shape[batch_first ? 1 : 0];
}
else if (in_size == 2)
{
q_l = op->inputs[0]->shape[batch_first ? 1 : 0];
k_s = op->inputs[1]->shape[batch_first ? 1 : 0];
v_s = k_s;
}
else
{
q_l = op->inputs[0]->shape[batch_first ? 1 : 0];
k_s = q_l;
v_s = q_l;
}

int num_heads = op->params.at("num_heads").i;
int embed_dim = op->params.at("embed_dim").i;
int kdim = op->params.at("kdim").i;
int vdim = op->params.at("vdim").i;

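// FLOPs: the q/k/v input projections (linear1), the attention itself
// (q*k^T scores, per-head scale/softmax, weighted sum over v), and
// the output projection (linear2).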
long long linear1 = q_l * embed_dim * embed_dim + k_s * embed_dim * kdim + v_s * embed_dim * vdim;
long long attention = q_l * k_s * embed_dim + 2 * q_l * k_s * num_heads + q_l * v_s * embed_dim;
long long linear2 = q_l * embed_dim * embed_dim;
m.flops += linear1 + attention + linear2;

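// Memory access: the projection weights, the q/k/v input tensors,
// the intermediates around the score matrix (the 2 * q_l * k_s term
// covers its write and read), and the final output.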
long long weights = embed_dim * embed_dim + embed_dim * kdim + embed_dim * vdim + num_heads * vdim * embed_dim;
long long in = q_l * embed_dim + k_s * kdim + v_s * vdim;
long long attention_m = q_l * embed_dim + k_s * kdim + 2 * q_l * k_s + v_s * vdim;
long long out = q_l * embed_dim;
m.memory_access += weights + in + attention_m + out;
}
else if (op->type == "nn.MaxPool2d")
{
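// Max pooling is counted as pure memory traffic here (the compare
// operations are not added to flops). When return_indices is set,
// the indices tensor doubles the output traffic, hence num_o.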
int num_o = op->params.at("return_indices").b ? 2 : 1;
int batch_size = 0, in_c = 0, in_h = 0, in_w = 0, out_h = 0, out_w = 0;
if (op->inputs[0]->shape.size() == 4)
{
batch_size = op->inputs[0]->shape[0];
in_c = op->inputs[0]->shape[1];
in_h = op->inputs[0]->shape[2];
in_w = op->inputs[0]->shape[3];
out_h = op->outputs[0]->shape[2];
out_w = op->outputs[0]->shape[3];
}
else if (op->inputs[0]->shape.size() == 3)
{
batch_size = 1;
in_c = op->inputs[0]->shape[0];
in_h = op->inputs[0]->shape[1];
in_w = op->inputs[0]->shape[2];
out_h = op->outputs[0]->shape[1];
out_w = op->outputs[0]->shape[2];
}
m.memory_access += batch_size * in_c * ( in_h * in_w + out_h * out_w * num_o );
}
else if (op->type == "nn.AvgPool2d")
{
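// Average pooling counts both FLOPs and memory traffic; the input
// may be 4-d (NCHW) or 3-d (CHW).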
int batch_size = 0, in_c = 0, in_h = 0, in_w = 0, out_h = 0, out_w = 0;
int k_h, k_w, kernel_add, kernel_avg;
if (op->inputs[0]->shape.size() == 4)
{
batch_size = op->inputs[0]->shape[0];
in_c = op->inputs[0]->shape[1];
in_h = op->inputs[0]->shape[2];
in_w = op->inputs[0]->shape[3];
out_h = op->outputs[0]->shape[2];
out_w = op->outputs[0]->shape[3];
}
else if (op->inputs[0]->shape.size() == 3)
{
batch_size = 1;
in_c = op->inputs[0]->shape[0];
in_h = op->inputs[0]->shape[1];
in_w = op->inputs[0]->shape[2];
out_h = op->outputs[0]->shape[1];
out_w = op->outputs[0]->shape[2];
}
k_h = op->params.at("kernel_size").ai[0];
k_w = op->params.at("kernel_size").ai[1];

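// Each output element needs k_h * k_w - 1 additions and one division
// for the average.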
kernel_add = k_h * k_w - 1;
kernel_avg = 1;
m.flops += ( kernel_add + kernel_avg ) * ( out_h * out_w ) * in_c;
m.memory_access += batch_size * in_c * ( in_h * in_w + out_h * out_w );
}
else
{
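// operator types not handled above contribute nothing to the counts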
}
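As a quick sanity check, the new counters can be queried roughly as follows. This is a minimal sketch, not part of this commit: it assumes pnnx::Graph::load(parampath, binpath) and that ModelInfo exposes flops and memory_access as long long, as suggested by the accumulation types above.

// load a pnnx model and print the estimated FLOPs and memory access
pnnx::Graph graph;
graph.load("model.pnnx.param", "model.pnnx.bin");
pnnx::ModelInfo info = graph.flops_mem_count();
fprintf(stderr, "flops = %lld  memory_access = %lld\n", info.flops, info.memory_access);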
