Skip to content

Commit

Permalink
added multiple_reshape_input_consumers test + refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mbencer committed Oct 30, 2024
1 parent 333b5e1 commit 89271b8
Showing 1 changed file with 73 additions and 32 deletions.
105 changes: 73 additions & 32 deletions runtime/onert/backend/cpu/SharedMemoryOperands.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,25 @@
using namespace onert::backend::cpu;
using namespace onert::ir;

// Add node other than Reshape/ExpandDims/Squeeze.
// It is used for cases where Reshape input/output is not input/output on the whole model.
namespace
{
void addNotOptimizedNode(Graph *graph, const OperandIndex &input, const OperandIndex &output)
{
graph->addOperation(std::make_unique<operation::Permute>(input, output));
}
} // namespace

TEST(SharedMemoryOperands, no_shared_memory_graph)
{
auto graph = std::make_unique<Graph>();
TypeInfo data_type{DataType::FLOAT32};
const auto perm_input = graph->addOperand({4}, data_type);
const auto perm_output = graph->addOperand({4}, data_type);
graph->addOperation(std::make_unique<operation::Permute>(perm_input, perm_output));
graph->addInput(perm_input);
graph->addOutput(perm_output);
const auto not_optim_in = graph->addOperand({4}, data_type);
const auto not_optim_out = graph->addOperand({4}, data_type);
addNotOptimizedNode(graph.get(), not_optim_in, not_optim_out);
graph->addInput(not_optim_in);
graph->addOutput(not_optim_out);
graph->verify();

const auto indexes_map = findSharedMemoryOperandIndexes(*graph);
Expand All @@ -48,9 +58,9 @@ TEST(SharedMemoryOperands, single_reshape_graph)
{
auto graph = std::make_unique<Graph>();
TypeInfo data_type{DataType::FLOAT32};
const auto perm_input = graph->addOperand({4}, data_type);
const auto not_optim_in = graph->addOperand({4}, data_type);
const auto reshape_input = graph->addOperand({4}, data_type);
graph->addOperation(std::make_unique<operation::Permute>(perm_input, reshape_input));
addNotOptimizedNode(graph.get(), not_optim_in, reshape_input);
const auto reshape_output = graph->addOperand({2, 2}, data_type);
operation::Reshape::Param shape;
shape.new_shape = {2, 2};
Expand All @@ -59,10 +69,10 @@ TEST(SharedMemoryOperands, single_reshape_graph)
graph->addOperation(
std::make_unique<operation::Reshape>(OperandIndexSequence{reshape_input, reshape_shape},
OperandIndexSequence{reshape_output}, shape));
const auto perm2_output = graph->addOperand({2, 2}, data_type);
graph->addOperation(std::make_unique<operation::Permute>(reshape_output, perm2_output));
graph->addInput(perm_input);
graph->addOutput(perm2_output);
const auto not_optim_out_2 = graph->addOperand({2, 2}, data_type);
addNotOptimizedNode(graph.get(), reshape_output, not_optim_out_2);
graph->addInput(not_optim_in);
graph->addOutput(not_optim_out_2);
graph->verify();

const auto indexes_map = findSharedMemoryOperandIndexes(*graph);
Expand All @@ -76,9 +86,9 @@ TEST(SharedMemoryOperands, double_reshape_graph)
{
auto graph = std::make_unique<Graph>();
TypeInfo data_type{DataType::FLOAT32};
const auto perm_input = graph->addOperand({4}, data_type);
const auto not_optim_in = graph->addOperand({4}, data_type);
const auto reshape1_input = graph->addOperand({4}, data_type);
graph->addOperation(std::make_unique<operation::Permute>(perm_input, reshape1_input));
addNotOptimizedNode(graph.get(), not_optim_in, reshape1_input);
const auto reshape1_output = graph->addOperand({2, 2}, data_type);
operation::Reshape::Param shape;
shape.new_shape = {2, 2};
Expand All @@ -91,10 +101,10 @@ TEST(SharedMemoryOperands, double_reshape_graph)
graph->addOperation(
std::make_unique<operation::Reshape>(OperandIndexSequence{reshape1_output, reshape_shape},
OperandIndexSequence{reshape2_output}, shape));
const auto perm2_output = graph->addOperand({2, 2}, data_type);
graph->addOperation(std::make_unique<operation::Permute>(reshape2_output, perm2_output));
graph->addInput(perm_input);
graph->addOutput(perm2_output);
const auto not_optim_out_2 = graph->addOperand({2, 2}, data_type);
addNotOptimizedNode(graph.get(), reshape2_output, not_optim_out_2);
graph->addInput(not_optim_in);
graph->addOutput(not_optim_out_2);
graph->verify();

const auto indexes_map = findSharedMemoryOperandIndexes(*graph);
Expand All @@ -112,9 +122,9 @@ TEST(SharedMemoryOperands, dyn_output_reshape_graph)
{
auto graph = std::make_unique<Graph>();
TypeInfo data_type{DataType::FLOAT32};
const auto perm_input = graph->addOperand({4}, data_type);
const auto not_optim_in = graph->addOperand({4}, data_type);
const auto reshape_input = graph->addOperand({4}, data_type);
graph->addOperation(std::make_unique<operation::Permute>(perm_input, reshape_input));
addNotOptimizedNode(graph.get(), not_optim_in, reshape_input);
const auto reshape_output = graph->addOperand({}, data_type);
graph->operands().at(reshape_output).info().setDynamic();
operation::Reshape::Param shape;
Expand All @@ -123,10 +133,10 @@ TEST(SharedMemoryOperands, dyn_output_reshape_graph)
graph->addOperation(
std::make_unique<operation::Reshape>(OperandIndexSequence{reshape_input, reshape_shape},
OperandIndexSequence{reshape_output}, shape));
const auto perm2_output = graph->addOperand({}, data_type);
graph->addOperation(std::make_unique<operation::Permute>(reshape_output, perm2_output));
graph->addInput(perm_input);
graph->addOutput(perm2_output);
const auto not_optim_out_2 = graph->addOperand({}, data_type);
addNotOptimizedNode(graph.get(), reshape_output, not_optim_out_2);
graph->addInput(not_optim_in);
graph->addOutput(not_optim_out_2);
graph->verify();

const auto indexes_map = findSharedMemoryOperandIndexes(*graph);
Expand All @@ -147,10 +157,10 @@ TEST(SharedMemoryOperands, model_input_reshape_graph)
graph->addOperation(
std::make_unique<operation::Reshape>(OperandIndexSequence{reshape_input, reshape_shape},
OperandIndexSequence{reshape_output}, shape));
const auto perm_output = graph->addOperand({2, 2}, data_type);
graph->addOperation(std::make_unique<operation::Permute>(reshape_output, perm_output));
const auto not_optim_out = graph->addOperand({2, 2}, data_type);
addNotOptimizedNode(graph.get(), reshape_output, not_optim_out);
graph->addInput(reshape_input);
graph->addOutput(perm_output);
graph->addOutput(not_optim_out);
graph->verify();

const auto indexes_map = findSharedMemoryOperandIndexes(*graph);
Expand All @@ -162,19 +172,50 @@ TEST(SharedMemoryOperands, single_squeeze_graph)
{
auto graph = std::make_unique<Graph>();
TypeInfo data_type{DataType::FLOAT32};
const auto perm_input = graph->addOperand({4, 1}, data_type);
const auto not_optim_in = graph->addOperand({4, 1}, data_type);
const auto squeeze_input = graph->addOperand({4, 1}, data_type);
graph->addOperation(std::make_unique<operation::Permute>(perm_input, squeeze_input));
addNotOptimizedNode(graph.get(), not_optim_in, squeeze_input);
const auto squeeze_output = graph->addOperand({4}, data_type);
operation::Squeeze::Param axes;
axes.dims[0] = 1;
axes.ndim = 1;
graph->addOperation(std::make_unique<operation::Squeeze>(
OperandIndexSequence{squeeze_input}, OperandIndexSequence{squeeze_output}, axes));
const auto perm2_output = graph->addOperand({4}, data_type);
graph->addOperation(std::make_unique<operation::Permute>(squeeze_output, perm2_output));
graph->addInput(perm_input);
graph->addOutput(perm2_output);
const auto not_optim_out_2 = graph->addOperand({4}, data_type);
addNotOptimizedNode(graph.get(), squeeze_output, not_optim_out_2);
graph->addInput(not_optim_in);
graph->addOutput(not_optim_out_2);
graph->verify();

const auto indexes_map = findSharedMemoryOperandIndexes(*graph);

ASSERT_EQ(indexes_map.size(), 1);
EXPECT_EQ(indexes_map.begin()->first, 2);
EXPECT_EQ(indexes_map.begin()->second, 1);
}

TEST(SharedMemoryOperands, multiple_reshape_input_consumers)
{
auto graph = std::make_unique<Graph>();
TypeInfo data_type{DataType::FLOAT32};
const auto not_optim_in = graph->addOperand({4}, data_type);
const auto reshape_input = graph->addOperand({4}, data_type);
addNotOptimizedNode(graph.get(), not_optim_in, reshape_input);
const auto reshape_output = graph->addOperand({2, 2}, data_type);
operation::Reshape::Param shape;
shape.new_shape = {2, 2};
TypeInfo shape_type{DataType::INT32};
const auto reshape_shape = graph->addOperand({2}, shape_type);
graph->addOperation(
std::make_unique<operation::Reshape>(OperandIndexSequence{reshape_input, reshape_shape},
OperandIndexSequence{reshape_output}, shape));
const auto not_optim_out_2 = graph->addOperand({2, 2}, data_type);
addNotOptimizedNode(graph.get(), reshape_output, not_optim_out_2);
const auto not_optim_out_3 = graph->addOperand({4}, data_type);
addNotOptimizedNode(graph.get(), reshape_input, not_optim_out_3);
graph->addInput(not_optim_in);
graph->addOutput(not_optim_out_2);
graph->addOutput(not_optim_out_3);
graph->verify();

const auto indexes_map = findSharedMemoryOperandIndexes(*graph);
Expand Down

0 comments on commit 89271b8

Please sign in to comment.