diff --git a/compiler/luci/lang/include/luci/IR/Module.h b/compiler/luci/lang/include/luci/IR/Module.h index 75cf67905e7..dd4df3340be 100644 --- a/compiler/luci/lang/include/luci/IR/Module.h +++ b/compiler/luci/lang/include/luci/IR/Module.h @@ -51,6 +51,13 @@ class Module final */ loco::Graph *graph(void) const; + /** + * @brief remove graph at index + * + * @note graph(0) is interpreted as a main graph and cannot be deleted + */ + void removeGraphByIndex(size_t idx); + /** * @brief provide graph with an index * diff --git a/compiler/luci/lang/src/Module.cpp b/compiler/luci/lang/src/Module.cpp index 80ef61910f2..16c5a8277d8 100644 --- a/compiler/luci/lang/src/Module.cpp +++ b/compiler/luci/lang/src/Module.cpp @@ -35,6 +35,14 @@ loco::Graph *Module::graph(void) const return graph.get(); } +void Module::removeGraphByIndex(size_t idx) +{ + if (idx >= _graphs.size() or idx == 0) + throw std::invalid_argument("Module: Invalid graph index to be deleted"); + + _graphs.erase(_graphs.begin() + idx); +} + loco::Graph *Module::graph(size_t idx) const { auto &graph = _graphs.at(idx); diff --git a/compiler/luci/lang/src/Module.test.cpp b/compiler/luci/lang/src/Module.test.cpp index a5973e52dad..16c93250eb3 100644 --- a/compiler/luci/lang/src/Module.test.cpp +++ b/compiler/luci/lang/src/Module.test.cpp @@ -37,6 +37,33 @@ TEST(ModuleTest, add) ASSERT_EQ(g_ptr, m->graph(0)); } +TEST(ModuleTest, remove) +{ + auto m = luci::make_module(); + auto g1 = loco::make_graph(); + auto g2 = loco::make_graph(); + auto g3 = loco::make_graph(); + auto g1_ptr = g1.get(); + auto g2_ptr = g2.get(); + auto g3_ptr = g3.get(); + + m->add(std::move(g1)); + m->add(std::move(g2)); + m->add(std::move(g3)); + + ASSERT_EQ(3, m->size()); + ASSERT_EQ(g1_ptr, m->graph()); + ASSERT_EQ(g1_ptr, m->graph(0)); + ASSERT_EQ(g2_ptr, m->graph(1)); + ASSERT_EQ(g3_ptr, m->graph(2)); + + // Let's delete graph at second position + m->removeGraphByIndex(1); + ASSERT_EQ(2, m->size()); + ASSERT_EQ(g1_ptr, m->graph(0)); + ASSERT_EQ(g3_ptr, m->graph(1)); +} + TEST(ModuleTest, add_more) { auto m = luci::make_module(); @@ -65,6 +92,13 @@ TEST(ModuleTest, add_nullptr_NEG) EXPECT_THROW(m->add(nullptr), std::invalid_argument); } +TEST(ModuleTest, remove_index_overflow_NEG) +{ + auto m = luci::make_module(); + + EXPECT_THROW(m->removeGraphByIndex(10), std::invalid_argument); +} + TEST(ModuleTest, graph_index_overflow_NEG) { auto m = luci::make_module();