Skip to content

Commit

Permalink
[onert] Introduce ExtraTensorIndex (#13605)
Browse files Browse the repository at this point in the history
This PR introduces ExtraTensorIndex.
ExtraTensorIndex will be used to identify extra tensor in tensor registry.

ONE-DCO-1.0-Signed-off-by: seunghui youn <[email protected]>
  • Loading branch information
zetwhite authored Aug 8, 2024
1 parent 0a104d3 commit 036cb5d
Showing 1 changed file with 85 additions and 0 deletions.
85 changes: 85 additions & 0 deletions runtime/onert/backend/train/ExtraTensorIndex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_BACKEND_TRAIN_EXTRA_TENSOR_INDEX_H__
#define __ONERT_BACKEND_TRAIN_EXTRA_TENSOR_INDEX_H__

#include <ir/Index.h>

#include <cassert>

namespace onert
{
namespace backend
{
namespace train
{

class ExtraTensorIndex
{
public:
ExtraTensorIndex(const ir::OperationIndex &op_index, uint32_t sub_index)
: _op_index{op_index}, _sub_index{sub_index}
{
assert(op_index.valid());
}

public:
const ir::OperationIndex &op_index() const { return _op_index; }
uint32_t sub_index() const { return _sub_index; }

bool operator==(const ExtraTensorIndex &other) const
{
return _op_index == other.op_index() && _sub_index == other.sub_index();
}
bool operator!=(const ExtraTensorIndex &other) const { return !(*this == other); }

private:
ir::OperationIndex _op_index;
uint32_t _sub_index;
};

inline std::ostream &operator<<(std::ostream &o, const ExtraTensorIndex &i)
{
o << i.op_index() << "-" << i.sub_index();
return o;
}

} // namespace train
} // namespace backend
} // namespace onert

namespace std
{

template <> struct hash<onert::backend::train::ExtraTensorIndex>
{
size_t operator()(const onert::backend::train::ExtraTensorIndex &index) const noexcept
{
const auto op_index = index.op_index();
const auto sub_index = index.sub_index();

assert(sizeof(op_index) <= sizeof(uint32_t));
static_assert(sizeof(size_t) >= sizeof(uint32_t),
"ExtraTensorIndex's hash creation error, size_t size is less than uint32_t");

return (static_cast<size_t>(op_index.value())) << 16 | static_cast<size_t>(sub_index);
}
};

} // namespace std

#endif // __ONERT_BACKEND_TRAIN_EXTRA_TENSOR_INDEX_H__

0 comments on commit 036cb5d

Please sign in to comment.