Skip to content

Commit

Permalink
[onert] Remove shape conversion function duplication (#13813)
Browse files Browse the repository at this point in the history
This commit remove duplicated shape conversion functions and change parameter to use permute type.

ONE-DCO-1.0-Signed-off-by: Hyeongseok Oh <[email protected]>
Co-authored-by: Jiyoung Giuliana Yun <[email protected]>
  • Loading branch information
hseok-oh and jyoungyun authored Aug 29, 2024
1 parent d37c688 commit 7278225
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 120 deletions.
8 changes: 7 additions & 1 deletion runtime/onert/core/include/ir/Shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,13 @@ struct Shape
inline bool operator==(const Shape &lhs, const Shape &rhs) { return lhs.dims() == rhs.dims(); }
inline bool operator!=(const Shape &lhs, const Shape &rhs) { return lhs.dims() != rhs.dims(); }

Shape permuteShape(const Shape &shape, Layout frontend_layout, Layout backend_layout);
/**
* @brief Converts shape when its rank is 4
*
* @return Return a shape based on permutation type.
* If rank is not 4, input shape is returned without conversion.
*/
Shape convertShape(const Shape &shape, const PermuteType &type);

/**
* @brief Find out if tha rank in this shape is "maybe" unspecified.
Expand Down
9 changes: 3 additions & 6 deletions runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

#include "PermuteLayer.h"

#include "../../../exec/ShapeConverter.h"

#include <ruy/context.h> // from @ruy

namespace onert
Expand Down Expand Up @@ -203,14 +201,14 @@ void PermuteLayer::run()
{
auto dst_tensor = _dst_tensors.at(i);
auto src_tensor = _src_tensors.at(i);
auto permute_type = _permute_types.at(i);
if (src_tensor->is_dynamic() || dst_tensor->is_dynamic())
{
// getting output shape
auto src_shape = src_tensor->getShape();

// set output shape and output buffer
ir::Shape new_shape =
exec::convertShape(src_shape, src_tensor->layout(), dst_tensor->layout());
ir::Shape new_shape = ir::convertShape(src_shape, permute_type);

try
{
Expand All @@ -227,8 +225,7 @@ void PermuteLayer::run()
throw;
}
}
assert(exec::convertShape(src_tensor->getShape(), src_tensor->layout(), dst_tensor->layout()) ==
dst_tensor->getShape());
assert(ir::convertShape(src_tensor->getShape(), permute_type) == dst_tensor->getShape());
}
assert(_src_tensors.size() == _dst_tensors.size());
assert(_src_tensors.size() == _src_tensors_offsets.size());
Expand Down
2 changes: 0 additions & 2 deletions runtime/onert/core/src/exec/ExecutorBase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

#include "ExecutorBase.h"

#include "ShapeConverter.h"

#include "util/ConfigSource.h"
#include <misc/polymorphic_downcast.h>

Expand Down
60 changes: 0 additions & 60 deletions runtime/onert/core/src/exec/ShapeConverter.cc

This file was deleted.

39 changes: 0 additions & 39 deletions runtime/onert/core/src/exec/ShapeConverter.h

This file was deleted.

24 changes: 12 additions & 12 deletions runtime/onert/core/src/ir/Shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,29 +66,29 @@ uint64_t Shape::num_elements() const
std::multiplies<uint64_t>());
}

Shape permuteShape(const Shape &shape, Layout from, Layout to)
Shape convertShape(const Shape &shape, const PermuteType &type)
{
assert(shape.rank() <= Shape::kMaxRank);
Shape ret{shape};
if (from == to)
return ret;
if (shape.rank() < 4)

if (type == ir::PermuteType::COPY || shape.rank() < 4)
return ret;

// Permutation changing layout beyond 4-D is not supported yet
assert(shape.rank() <= 4);
if (from == Layout::NHWC && to == Layout::NCHW)

if (type == ir::PermuteType::NHWC_TO_NCHW)
{
ret.dim(1) = shape.dim(3);
ret.dim(2) = shape.dim(1);
ret.dim(3) = shape.dim(2);
return ret;
}
else if (from == Layout::NCHW && to == Layout::NHWC)
{
ret.dim(1) = shape.dim(2);
ret.dim(2) = shape.dim(3);
ret.dim(3) = shape.dim(1);
}
// Other cases(either `from` or `to` is UNKNOWN), just return the original shape

assert(type == ir::PermuteType::NCHW_TO_NHWC);
ret.dim(1) = shape.dim(2);
ret.dim(2) = shape.dim(3);
ret.dim(3) = shape.dim(1);
return ret;
}

Expand Down

0 comments on commit 7278225

Please sign in to comment.