Skip to content

Commit

Permalink
[luci/service] Add sinf namespace for sinf::Algorithm::visit()
Browse files Browse the repository at this point in the history
This PR explicitly marks 'sinf namespace' for each node's sinf::Algorithm::visit() function.

ONE-DCO-1.0-Signed-off-by: seunghui youn <[email protected]>
  • Loading branch information
zetwhite committed Aug 28, 2024
1 parent c5319b6 commit f845601
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 12 deletions.
11 changes: 8 additions & 3 deletions compiler/luci/service/src/Nodes/CircleAdd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,20 @@ luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleAdd *node)
return cloned;
}

loco::TensorShape sinf::Algorithm::visit(const luci::CircleAdd *node)
namespace sinf
{

loco::TensorShape Algorithm::visit(const luci::CircleAdd *node)
{
const auto x = loco::must_cast<luci::CircleNode *>(node->x());
const auto y = loco::must_cast<luci::CircleNode *>(node->y());

const auto x_shape = sinf::circle_shape(x);
const auto y_shape = sinf::circle_shape(y);
const auto x_shape = circle_shape(x);
const auto y_shape = circle_shape(y);

return broadcast_shape(x_shape, y_shape);
}

} // namespace sinf

} // namespace luci
11 changes: 8 additions & 3 deletions compiler/luci/service/src/Nodes/CircleBatchMatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,18 @@ luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleBatchMatMul *no
return cloned;
}

namespace sinf
{

// BatchMatMulV2 supports broadcasting in the batch dimensions(BatchMatMul doesn't)
// TODO Distinguish BatchMatMul and BatchMatMulV2
loco::TensorShape sinf::Algorithm::visit(const luci::CircleBatchMatMul *node)
loco::TensorShape Algorithm::visit(const luci::CircleBatchMatMul *node)
{
const auto x = loco::must_cast<CircleNode *>(node->x());
const auto y = loco::must_cast<CircleNode *>(node->y());

const auto x_shape = sinf::circle_shape(x);
const auto y_shape = sinf::circle_shape(y);
const auto x_shape = circle_shape(x);
const auto y_shape = circle_shape(y);

uint32_t x_rank = x_shape.rank();
uint32_t y_rank = y_shape.rank();
Expand Down Expand Up @@ -144,4 +147,6 @@ loco::TensorShape sinf::Algorithm::visit(const luci::CircleBatchMatMul *node)
return output_shape;
}

} // namespace sinf

} // namespace luci
7 changes: 6 additions & 1 deletion compiler/luci/service/src/Nodes/CircleConcatenation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleConcatenation *
return cloned;
}

loco::TensorShape sinf::Algorithm::visit(const luci::CircleConcatenation *node)
namespace sinf
{

loco::TensorShape Algorithm::visit(const luci::CircleConcatenation *node)
{
// TODO Support when CircleConcatenation has 0 input
assert(node->numValues() > 0);
Expand Down Expand Up @@ -101,4 +104,6 @@ loco::TensorShape sinf::Algorithm::visit(const luci::CircleConcatenation *node)
return output_shape;
}

} // namespace sinf

} // namespace luci
11 changes: 8 additions & 3 deletions compiler/luci/service/src/Nodes/CircleDiv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,22 @@ luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleDiv *node)
return cloned;
}

loco::TensorShape sinf::Algorithm::visit(const luci::CircleDiv *node)
namespace sinf
{

loco::TensorShape Algorithm::visit(const luci::CircleDiv *node)
{
const auto x = loco::must_cast<luci::CircleNode *>(node->x());
const auto y = loco::must_cast<luci::CircleNode *>(node->y());

const auto x_shape = sinf::circle_shape(x);
const auto y_shape = sinf::circle_shape(y);
const auto x_shape = circle_shape(x);
const auto y_shape = circle_shape(y);

auto output_shape = broadcast_shape(x_shape, y_shape);

return output_shape;
}

} // namespace sinf

} // namespace luci
6 changes: 5 additions & 1 deletion compiler/luci/service/src/Nodes/CircleIfOut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,18 @@ CircleIfOutGraphs get_out_graphs(const luci::CircleIfOut *node)

namespace luci
{
namespace sinf
{

loco::TensorShape sinf::Algorithm::visit(const luci::CircleIfOut *node)
loco::TensorShape Algorithm::visit(const luci::CircleIfOut *node)
{
auto graphs = get_out_graphs(node);
assert(*graphs.then_graph_output->shape() == *graphs.else_graph_output->shape());
return *graphs.then_graph_output->shape();
}

} // namespace sinf

loco::DataType tinf::Algorithm::visit(const luci::CircleIfOut *node)
{
auto graphs = get_out_graphs(node);
Expand Down
7 changes: 6 additions & 1 deletion compiler/luci/service/src/Nodes/CircleSoftmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,16 @@ luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSoftmax *node)
return cloned;
}

loco::TensorShape sinf::Algorithm::visit(const luci::CircleSoftmax *node)
namespace sinf
{

loco::TensorShape Algorithm::visit(const luci::CircleSoftmax *node)
{
const auto logits = loco::must_cast<luci::CircleNode *>(node->logits());

return sinf::circle_shape(logits);
}

} // namespace sinf

} // namespace luci

0 comments on commit f845601

Please sign in to comment.