Skip to content

Commit

Permalink
[record-minmax] Introduce HDF5Iterator (#14261)
Browse files Browse the repository at this point in the history
This introduces an iterator for HDF5 format.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Oct 28, 2024
1 parent 094bd4a commit 043e9e2
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 0 deletions.
54 changes: 54 additions & 0 deletions compiler/record-minmax/include/HDF5Iterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __RECORD_MINMAX_HDF5_ITERATOR_H__
#define __RECORD_MINMAX_HDF5_ITERATOR_H__

#include "DataBuffer.h"
#include "DataSetIterator.h"

#include <luci/IR/Module.h>
#include <luci/IR/CircleNodes.h>
#include <dio_hdf5/HDF5Importer.h>

#include <string>
#include <vector>

namespace record_minmax
{

/**
 * @brief DataSetIterator implementation that reads its data from an HDF5 file
 *        through dio::hdf5::HDF5Importer
 */
class HDF5Iterator final : public DataSetIterator
{
public:
/**
 * @brief Open the HDF5 file and collect the input nodes of the module
 * @param file_path Path passed to the HDF5 importer
 * @param module Module whose graph (module->graph()) provides the input nodes
 * @throws std::runtime_error if an H5::Exception is raised while the "value"
 *         group is imported or its metadata is queried
 */
HDF5Iterator(const std::string &file_path, luci::Module *module);

/// @brief Return true while the current index is below the number of data
bool hasNext() const override;

/**
 * @brief Read the tensors stored at the current index, then advance the index
 * @return One DataBuffer per input stored at the current index
 * @throws std::runtime_error if an H5::Exception is raised while reading
 */
std::vector<DataBuffer> next() override;

/// @brief Return true unless the file was detected as raw data
bool check_type_shape() const override;

private:
// Importer bound to the file given to the constructor
dio::hdf5::HDF5Importer _importer;
// Input nodes of the module's graph, in the order loco::input_nodes returned them
std::vector<const luci::CircleInput *> _input_nodes;
// Value of _importer.isRawData() captured in the constructor
bool _is_raw_data = false;
// Index of the data that the next call to next() will read
uint32_t _curr_idx = 0;
// Value of _importer.numData() captured in the constructor
uint32_t _num_data = 0;
};

} // namespace record_minmax

#endif // __RECORD_MINMAX_HDF5_ITERATOR_H__
100 changes: 100 additions & 0 deletions compiler/record-minmax/src/HDF5Iterator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "HDF5Iterator.h"
#include "DataBuffer.h"
#include "Utils.h"

#include <luci/IR/Module.h>

#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace record_minmax
{

/**
 * @brief Open the HDF5 file and collect the input nodes of the module
 *
 * The importer is bound to file_path, the "value" group is imported, and the
 * raw-data flag and the number of data are cached for later calls.
 *
 * @param file_path Path passed to the HDF5 importer
 * @param module Module whose graph (module->graph()) provides the input nodes;
 *               must not be null
 * @throws std::runtime_error if module is null, or if an H5::Exception is
 *         raised while the group is imported or its metadata is queried
 */
HDF5Iterator::HDF5Iterator(const std::string &file_path, luci::Module *module)
  : _importer(file_path)
{
  // module is dereferenced below, so reject a null pointer instead of crashing
  if (module == nullptr)
  {
    throw std::runtime_error("Invalid module: null pointer.");
  }

  try
  {
    _importer.importGroup("value");

    _is_raw_data = _importer.isRawData();

    _num_data = _importer.numData();
  }
  catch (const H5::Exception &)
  {
    // Dump the HDF5 error stack, then report a generic error to the caller
    H5::Exception::printErrorStack();
    throw std::runtime_error("HDF5 error occurred during initialization.");
  }

  const auto graph_inputs = loco::input_nodes(module->graph());
  _input_nodes.reserve(graph_inputs.size());
  for (auto graph_input : graph_inputs)
  {
    const auto cnode = loco::must_cast<const luci::CircleInput *>(graph_input);
    _input_nodes.emplace_back(cnode);
  }
}

/// Return true while there is still data left to read
bool HDF5Iterator::hasNext() const
{
  const bool exhausted = (_curr_idx >= _num_data);
  return not exhausted;
}

std::vector<DataBuffer> HDF5Iterator::next()
{
std::vector<DataBuffer> res;

try
{
for (int32_t input_idx = 0; input_idx < _importer.numInputs(_curr_idx); input_idx++)
{
DataBuffer buf;

const auto input_node = _input_nodes.at(input_idx);
const auto input_size = getTensorSize(input_node);
buf.data.resize(input_size);

if (check_type_shape())
{
_importer.readTensor(_curr_idx, input_idx, &buf.dtype, &buf.shape, buf.data.data(),
input_size);
}
else
{
_importer.readTensor(_curr_idx, input_idx, buf.data.data(), input_size);
}

res.emplace_back(buf);
}
}
catch (const H5::Exception &e)
{
H5::Exception::printErrorStack();
throw std::runtime_error("HDF5 error occurred during iteration.");
}

_curr_idx++; // move to the next index

return res;
}

/// Tell whether type and shape have to be checked for the data of this file
bool HDF5Iterator::check_type_shape() const
{
  // Type and shape checking is skipped only when the file holds raw data
  const bool needs_check = !_is_raw_data;
  return needs_check;
}

} // namespace record_minmax

0 comments on commit 043e9e2

Please sign in to comment.