Skip to content

Commit

Permalink
[onert] Support RoPE in circle loader (#14209)
Browse files Browse the repository at this point in the history
This commit supports RoPE op in circle loader

ONE-DCO-1.0-Signed-off-by: youngsik kim <[email protected]>
  • Loading branch information
ys44kim authored Oct 15, 2024
1 parent ef7418b commit 21496b0
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions runtime/onert/core/src/loader/CircleLoader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class CircleLoader final : public loader::BaseLoader<LoaderDomain>
void loadBCQFullyConnected(const Operator *op, ir::Graph &subg);
void loadBCQGather(const Operator *op, ir::Graph &subg);
void loadRmsNorm(const Operator *op, ir::Graph &subg);
void loadRoPE(const Operator *op, ir::Graph &subg);

public:
using BaseLoader::BaseLoader;
Expand Down Expand Up @@ -99,6 +100,20 @@ class CircleLoader final : public loader::BaseLoader<LoaderDomain>
return BaseLoader::tensorTypeToDataType(type);
}

ir::operation::RoPE::RoPEMode convertRoPEMode(const circle::RoPEMode mode)
{
switch (mode)
{
case circle::RoPEMode::RoPEMode_GPT_NEOX:
return ir::operation::RoPE::RoPEMode::GPT_NEOX;
case circle::RoPEMode::RoPEMode_GPT_J:
return ir::operation::RoPE::RoPEMode::GPT_J;
default:
throw std::runtime_error(std::string("Unsupported RoPE mode: ") +
std::to_string(static_cast<int>(mode)));
}
}

private:
std::unique_ptr<ir::Graph> loadSubgraph(const circle::SubGraph *circle_subg) override
{
Expand Down Expand Up @@ -154,6 +169,9 @@ class CircleLoader final : public loader::BaseLoader<LoaderDomain>
case circle::BuiltinOperator::BuiltinOperator_RMS_NORM:
loadRmsNorm(op, subg);
return;
case circle::BuiltinOperator::BuiltinOperator_ROPE:
loadRoPE(op, subg);
return;
default:
BaseLoader::loadOperation(op, subg);
return;
Expand Down Expand Up @@ -246,6 +264,22 @@ void CircleLoader::loadRmsNorm(const Operator *op, ir::Graph &subg)
subg.addOperation(std::move(new_op));
}

void CircleLoader::loadRoPE(const Operator *op, ir::Graph &subg)
{
ir::OperandIndexSequence inputs;
ir::OperandIndexSequence outputs;

loadOperationIO(op, inputs, outputs);

ir::operation::RoPE::Param param;
const auto *options = op->builtin_options_as_RoPEOptions();

param.mode = convertRoPEMode(options->mode());

std::unique_ptr<ir::Operation> new_op(new ir::operation::RoPE(inputs, outputs, param));
subg.addOperation(std::move(new_op));
}

} // namespace

std::unique_ptr<ir::Model> loadCircleModel(const std::string &filename)
Expand Down

0 comments on commit 21496b0

Please sign in to comment.