diff --git a/runtime/onert/core/include/ir/IOperation.h b/runtime/onert/core/include/ir/IOperation.h index be0dd939da6..4cf4bbeb9ad 100644 --- a/runtime/onert/core/include/ir/IOperation.h +++ b/runtime/onert/core/include/ir/IOperation.h @@ -38,6 +38,7 @@ struct IOperation virtual std::string name() const { return std::string{toString(opcode())}; } virtual OpCode opcode() const = 0; + virtual void replaceInput(size_t pos, const OperandIndex &index) = 0; virtual void replaceInputs(const OperandIndex &from, const OperandIndex &to) = 0; virtual void replaceOutputs(const OperandIndex &from, const OperandIndex &to) = 0; virtual const OperandIndexSequence &getInputs() const = 0; diff --git a/runtime/onert/core/include/ir/OperandIndexSequence.h b/runtime/onert/core/include/ir/OperandIndexSequence.h index 66d00761ba9..018e1e30740 100644 --- a/runtime/onert/core/include/ir/OperandIndexSequence.h +++ b/runtime/onert/core/include/ir/OperandIndexSequence.h @@ -50,6 +50,7 @@ class OperandIndexSequence const OperandIndex &at(IOIndex set_index) const { return _vec.at(set_index.value()); } const OperandIndex &at(uint32_t index) const { return _vec.at(index); } bool contains(const OperandIndex &index) const; + void replace(size_t pos, const OperandIndex &index); void replace(const OperandIndex &from, const OperandIndex &to); OperandIndexSequence operator|(ir::Remove filter) const { diff --git a/runtime/onert/core/include/ir/Operation.h b/runtime/onert/core/include/ir/Operation.h index 06ab29ecb19..5e5584c0c8a 100644 --- a/runtime/onert/core/include/ir/Operation.h +++ b/runtime/onert/core/include/ir/Operation.h @@ -48,6 +48,7 @@ class Operation : virtual public IOperation virtual ~Operation(); public: + void replaceInput(size_t pos, const OperandIndex &index) override; void replaceInputs(const OperandIndex &from, const OperandIndex &to) override; void replaceOutputs(const OperandIndex &from, const OperandIndex &to) override; OperandIndexSequence &getInputs() { return _inputs; } diff --git a/runtime/onert/core/src/ir/OperandIndexSequence.cc b/runtime/onert/core/src/ir/OperandIndexSequence.cc index a15b6d0d69f..96a08f8ce59 100644 --- a/runtime/onert/core/src/ir/OperandIndexSequence.cc +++ b/runtime/onert/core/src/ir/OperandIndexSequence.cc @@ -17,6 +17,7 @@ #include "ir/OperandIndexSequence.h" #include +#include #include namespace onert @@ -50,6 +51,12 @@ bool OperandIndexSequence::contains(const OperandIndex &index) const return std::find(_vec.begin(), _vec.end(), index) != _vec.end(); } +void OperandIndexSequence::replace(size_t pos, const OperandIndex &index) +{ + assert(pos < _vec.size() && "OperandIndexSequence: Out of range"); + _vec.at(pos) = index; +} + void OperandIndexSequence::replace(const OperandIndex &from, const OperandIndex &to) { std::replace(_vec.begin(), _vec.end(), from, to); diff --git a/runtime/onert/core/src/ir/Operation.cc b/runtime/onert/core/src/ir/Operation.cc index 64792525dd7..1079738e04c 100644 --- a/runtime/onert/core/src/ir/Operation.cc +++ b/runtime/onert/core/src/ir/Operation.cc @@ -52,6 +52,8 @@ void Operation::setOutputs(const OperandIndexSequence &indexes) _outputs = indexes; } +void Operation::replaceInput(size_t pos, const OperandIndex &index) { _inputs.replace(pos, index); } + void Operation::replaceInputs(const OperandIndex &from, const OperandIndex &to) { _inputs.replace(from, to);