Skip to content

Commit

Permalink
[record-minmax] Introduce DirectoryIterator (#14262)
Browse files Browse the repository at this point in the history
This introduces an iterator for directory format.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Oct 28, 2024
1 parent b443c68 commit 094bd4a
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 0 deletions.
53 changes: 53 additions & 0 deletions compiler/record-minmax/include/DirectoryIterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __RECORD_MINMAX_DIRECTORY_ITERATOR_H__
#define __RECORD_MINMAX_DIRECTORY_ITERATOR_H__

#include "DataBuffer.h"
#include "DataSetIterator.h"

#include <luci/IR/Module.h>
#include <luci/IR/CircleNodes.h>

#include <string>
#include <vector>
#include <dirent.h>

namespace record_minmax
{

// DataSetIterator implementation whose data set is a directory: each regular
// file collected from the directory is served as one record by next().
class DirectoryIterator final : public DataSetIterator
{
public:
// Collect the directory entries of 'dir_path' that are reported as regular
// files and the input nodes of the graph returned by module->graph().
// Throws std::runtime_error if 'dir_path' cannot be opened as a directory.
DirectoryIterator(const std::string &dir_path, luci::Module *module);

// Return true while at least one collected entry has not been served yet
bool hasNext() const override;

// Read the file of the current entry, advance to the following entry and
// return one DataBuffer per input node (in the order of _input_nodes).
// The file must contain exactly the concatenated bytes of all inputs;
// otherwise the read helper throws std::runtime_error.
std::vector<DataBuffer> next() override;

// Always returns false for this iterator
bool check_type_shape() const override;

private:
// Entries collected by the constructor (raw pointers returned by readdir()).
// NOTE(review): readdir() results may be overwritten by later readdir() calls
// on the same stream - confirm these pointers stay valid for large directories.
std::vector<dirent *> _entries;
// Index into _entries of the entry that the following next() call will read
uint32_t _curr_idx = 0;
// Directory path given to the constructor; prefixed to entry names in next()
std::string _dir_path;
// Input nodes of the graph, used to size and split the data read from a file
std::vector<const luci::CircleInput *> _input_nodes;
};

} // namespace record_minmax

#endif // __RECORD_MINMAX_DIRECTORY_ITERATOR_H__
3 changes: 3 additions & 0 deletions compiler/record-minmax/include/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ uint32_t numElements(const luci::CircleNode *node);
// Return the node's output tensor size in bytes
// (the result is numElements(node) multiplied by an element size)
size_t getTensorSize(const luci::CircleNode *node);

// Read data from file into buffer with specified size in bytes
//
// 'data' must already hold exactly 'data_size' elements (checked by assert).
// Throws std::runtime_error when the file cannot be opened, when 'data_size'
// bytes cannot be read from it, or when bytes remain in the file after
// 'data_size' bytes have been read.
void readDataFromFile(const std::string &filename, std::vector<char> &data, size_t data_size);

} // namespace record_minmax

#endif // __RECORD_MINMAX_UTILS_H__
102 changes: 102 additions & 0 deletions compiler/record-minmax/src/DirectoryIterator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "DirectoryIterator.h"
#include "DataBuffer.h"
#include "Utils.h"

#include <luci/IR/Module.h>

#include <cassert>
#include <cstring> // For memcpy
#include <stdexcept>
#include <string>
#include <vector>

#include <dirent.h>
#include <sys/stat.h>

namespace record_minmax
{

// Build an iterator over the regular files of 'dir_path'.
//
// dir_path: directory whose regular files are the records of the data set
// module  : module whose graph (module->graph()) provides the input nodes
//
// Throws std::runtime_error if 'dir_path' cannot be opened as a directory.
DirectoryIterator::DirectoryIterator(const std::string &dir_path, luci::Module *module)
  : _dir_path(dir_path)
{
  auto dir = opendir(dir_path.c_str());
  if (not dir)
    throw std::runtime_error("Cannot open directory. Please check \"" + _dir_path +
                             "\" is a directory.\n");

  // NOTE(review): readdir() documents that the returned dirent may be
  // overwritten by a later readdir() call on the same stream, yet the raw
  // pointers are kept in _entries and dereferenced later by next(). Keeping
  // copies of the file names instead requires changing the type of _entries in
  // DirectoryIterator.h, so it is not done here. closedir() is not called
  // either - TODO revisit both together with the header.
  dirent *entry = nullptr;
  while ((entry = readdir(dir)))
  {
    bool is_regular = (entry->d_type == DT_REG);

    // Not every filesystem fills in d_type. DT_UNKNOWN means "ask the
    // filesystem", so fall back to lstat() instead of dropping the file.
    // lstat() (not stat()) keeps symbolic links excluded, as with d_type.
    if (entry->d_type == DT_UNKNOWN)
    {
      const std::string entry_path = _dir_path + "/" + entry->d_name;
      struct stat entry_stat;
      is_regular = (lstat(entry_path.c_str(), &entry_stat) == 0) && S_ISREG(entry_stat.st_mode);
    }

    if (not is_regular)
      continue;

    _entries.emplace_back(entry);
  }

  // Remember the graph inputs; next() uses them to size and split file data
  const auto input_nodes = loco::input_nodes(module->graph());
  _input_nodes.reserve(input_nodes.size());
  for (auto input_node : input_nodes)
  {
    const auto cnode = loco::must_cast<const luci::CircleInput *>(input_node);
    _input_nodes.emplace_back(cnode);
  }
}

bool DirectoryIterator::hasNext() const { return _curr_idx < _entries.size(); }

std::vector<DataBuffer> DirectoryIterator::next()
{
auto entry = _entries.at(_curr_idx++);
assert(entry); // FIX_ME_UNLESS

// Get total input size
uint32_t total_input_size = 0;
for (auto input : _input_nodes)
{
const auto *input_node = loco::must_cast<const luci::CircleInput *>(input);
total_input_size += getTensorSize(input_node);
}

const std::string filename = entry->d_name;

// Read data from file to buffer
// Assumption: For a multi-input model, the binary file should have inputs concatenated in the
// same order with the input index.
std::vector<char> input_data(total_input_size);
readDataFromFile(_dir_path + "/" + filename, input_data, total_input_size);

std::vector<DataBuffer> res;

uint32_t offset = 0;
for (auto input_node : _input_nodes)
{
DataBuffer buf;

const auto input_size = getTensorSize(input_node);

buf.data.resize(input_size);
memcpy(buf.data.data(), input_data.data() + offset, input_size);

offset += input_size;

res.emplace_back(buf);
}

return res;
}

// This iterator always answers false
bool DirectoryIterator::check_type_shape() const
{
  return false;
}

} // namespace record_minmax
17 changes: 17 additions & 0 deletions compiler/record-minmax/src/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
#include <luci/IR/CircleNodes.h>
#include <luci/IR/DataTypeHelper.h>

#include <vector>
#include <string>
#include <fstream>

namespace record_minmax
{

Expand Down Expand Up @@ -46,4 +50,17 @@ size_t getTensorSize(const luci::CircleNode *node)
return numElements(node) * elem_size;
}

// Fill 'data' with exactly 'data_size' bytes taken from the binary file 'filename'.
//
// The caller must have sized 'data' to 'data_size' beforehand (asserted).
// Throws std::runtime_error when the file cannot be opened, when it holds
// fewer than 'data_size' bytes, or when it holds more than 'data_size' bytes.
void readDataFromFile(const std::string &filename, std::vector<char> &data, size_t data_size)
{
  assert(data.size() == data_size); // FIX_CALLER_UNLESS

  std::ifstream input(filename, std::ifstream::binary);
  if (!input)
    throw std::runtime_error("Cannot open file \"" + filename + "\".\n");

  input.read(data.data(), data_size);
  if (!input)
    throw std::runtime_error("Failed to read data from file \"" + filename + "\".\n");

  // Anything other than end-of-file here means the file is larger than expected
  const bool has_trailing_bytes = (input.peek() != EOF);
  if (has_trailing_bytes)
    throw std::runtime_error("Input tensor size mismatches with \"" + filename + "\".\n");
}

} // namespace record_minmax

0 comments on commit 094bd4a

Please sign in to comment.