diff --git a/compiler/record-minmax/include/HDF5Iterator.h b/compiler/record-minmax/include/HDF5Iterator.h new file mode 100644 index 00000000000..a810aaa3f26 --- /dev/null +++ b/compiler/record-minmax/include/HDF5Iterator.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __RECORD_MINMAX_HDF5_ITERATOR_H__ +#define __RECORD_MINMAX_HDF5_ITERATOR_H__ + +#include "DataBuffer.h" +#include "DataSetIterator.h" + +#include +#include +#include + +#include +#include + +namespace record_minmax +{ + +class HDF5Iterator final : public DataSetIterator +{ +public: + HDF5Iterator(const std::string &file_path, luci::Module *module); + + bool hasNext() const override; + + std::vector next() override; + + bool check_type_shape() const override; + +private: + dio::hdf5::HDF5Importer _importer; + std::vector _input_nodes; + bool _is_raw_data = false; + uint32_t _curr_idx = 0; + uint32_t _num_data = 0; +}; + +} // namespace record_minmax + +#endif // __RECORD_MINMAX_HDF5_ITERATOR_H__ diff --git a/compiler/record-minmax/src/HDF5Iterator.cpp b/compiler/record-minmax/src/HDF5Iterator.cpp new file mode 100644 index 00000000000..aaadc674b63 --- /dev/null +++ b/compiler/record-minmax/src/HDF5Iterator.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "HDF5Iterator.h" +#include "DataBuffer.h" +#include "Utils.h" + +#include + +#include +#include + +namespace record_minmax +{ + +HDF5Iterator::HDF5Iterator(const std::string &file_path, luci::Module *module) + : _importer(file_path) +{ + try + { + _importer.importGroup("value"); + + _is_raw_data = _importer.isRawData(); + + _num_data = _importer.numData(); + } + catch (const H5::Exception &e) + { + H5::Exception::printErrorStack(); + throw std::runtime_error("HDF5 error occurred during initialization."); + } + + auto input_nodes = loco::input_nodes(module->graph()); + for (auto input_node : input_nodes) + { + const auto cnode = loco::must_cast(input_node); + _input_nodes.emplace_back(cnode); + } +} + +bool HDF5Iterator::hasNext() const { return _curr_idx < _num_data; } + +std::vector HDF5Iterator::next() +{ + std::vector res; + + try + { + for (int32_t input_idx = 0; input_idx < _importer.numInputs(_curr_idx); input_idx++) + { + DataBuffer buf; + + const auto input_node = _input_nodes.at(input_idx); + const auto input_size = getTensorSize(input_node); + buf.data.resize(input_size); + + if (check_type_shape()) + { + _importer.readTensor(_curr_idx, input_idx, &buf.dtype, &buf.shape, buf.data.data(), + input_size); + } + else + { + _importer.readTensor(_curr_idx, input_idx, buf.data.data(), input_size); + } + + res.emplace_back(buf); + } + } + catch (const H5::Exception &e) + { + H5::Exception::printErrorStack(); + throw std::runtime_error("HDF5 error occurred during iteration."); + } + + _curr_idx++; // move to the next index + + return res; +} + +bool HDF5Iterator::check_type_shape() const +{ + // If it's raw data, we don't need to check type and shape + return not _is_raw_data; +} + +} // namespace record_minmax