diff --git a/compiler/luci/export/src/CircleExporterUtils.cpp b/compiler/luci/export/src/CircleExporterUtils.cpp index 6678b0dc328..f6e380d7872 100644 --- a/compiler/luci/export/src/CircleExporterUtils.cpp +++ b/compiler/luci/export/src/CircleExporterUtils.cpp @@ -293,6 +293,12 @@ void set_tensor_index(loco::Node *node, const CircleTensorIndex &tensor_id) node->annot(std::make_unique(tensor_id)); } +void clear_tensor_index(loco::Node *node) +{ + if (node->annot() != nullptr) + node->annot(nullptr); +} + CircleTensorIndex get_tensor_index(loco::Node *node) { assert(node->annot() != nullptr); diff --git a/compiler/luci/export/src/CircleExporterUtils.h b/compiler/luci/export/src/CircleExporterUtils.h index 4a4c54a695a..83b040753dc 100644 --- a/compiler/luci/export/src/CircleExporterUtils.h +++ b/compiler/luci/export/src/CircleExporterUtils.h @@ -57,6 +57,7 @@ circle::Padding getOpPadding(const luci::Padding pad); using CircleTensorIndex = int32_t; void set_tensor_index(loco::Node *node, const CircleTensorIndex &tensor_id); +void clear_tensor_index(loco::Node *node); CircleTensorIndex get_tensor_index(loco::Node *node); } // namespace luci diff --git a/compiler/luci/export/src/CircleTensorExporter.cpp b/compiler/luci/export/src/CircleTensorExporter.cpp index 133b8f4c4ba..57ae160bd54 100644 --- a/compiler/luci/export/src/CircleTensorExporter.cpp +++ b/compiler/luci/export/src/CircleTensorExporter.cpp @@ -697,4 +697,14 @@ void exportOpDefinedTensors(loco::Graph *g, FlatBufferBuilder &builder, Serializ } } +void clearExportInfo(loco::Graph *g) +{ + auto nodes = g->nodes(); + for (uint32_t n = 0; n < nodes->size(); ++n) + { + auto node = loco::must_cast(nodes->at(n)); + clear_tensor_index(node); + } +} + } // namespace luci diff --git a/compiler/luci/export/src/CircleTensorExporter.h b/compiler/luci/export/src/CircleTensorExporter.h index f9d6107b4d6..8c2a1eb2113 100644 --- a/compiler/luci/export/src/CircleTensorExporter.h +++ b/compiler/luci/export/src/CircleTensorExporter.h @@ -39,6 +39,11 @@ void prepareModelData(flatbuffers::FlatBufferBuilder &builder, SerializedModelDa void exportOpDefinedTensors(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &md, SerializedGraphData &gd); +/** + * @brief clear temporary export information annotated to graph nodes + */ +void clearExportInfo(loco::Graph *g); + } // namespace luci #endif // __CIRCLE_TENSOR_EXPORTER_H__