Skip to content

Commit

Permalink
[onert] Introduce CheckpointLoader to load checkpoint file
Browse files Browse the repository at this point in the history
This commit introduces CheckpointLoader class. It loads the checkpoint
file data and updates the tensor data and training information.

ONE-DCO-1.0-Signed-off-by: Jiyoung Yun <[email protected]>
  • Loading branch information
jyoungyun committed Aug 28, 2024
1 parent 523d27b commit 84551f6
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 1 deletion.
3 changes: 2 additions & 1 deletion runtime/onert/api/nnfw/src/nnfw_api_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "loader/ModelLoader.h"
#include "loader/TFLiteLoader.h"
#include "loader/TrainInfoLoader.h"
#include "loader/train/CheckpointLoader.h"
#include "exporter/CircleExporter.h"
#include "exporter/train/CheckpointExporter.h"
#include "json/json.h"
Expand Down Expand Up @@ -1718,7 +1719,7 @@ NNFW_STATUS nnfw_session::train_import_checkpoint(const char *path)

try
{
// TODO Implement importing checkpoint
onert::loader::train::loadCheckpoint(path, _train_info, _execution);
}
catch (const std::exception &e)
{
Expand Down
53 changes: 53 additions & 0 deletions runtime/onert/core/include/loader/train/CheckpointLoader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_LOADER_TRAIN_CHECKPOINT_LOADER_H__
#define __ONERT_LOADER_TRAIN_CHECKPOINT_LOADER_H__

#include <string>
#include <memory>

namespace onert
{
namespace exec
{
class Execution;
} // namespace exec
namespace ir
{
namespace train
{
class TrainingInfo;
} // namespace train
} // namespace ir
} // namespace onert

namespace onert
{
namespace loader
{
namespace train
{

void loadCheckpoint(const std::string &filename,
const std::unique_ptr<ir::train::TrainingInfo> &train_info,
const std::unique_ptr<exec::Execution> &exec);

} // namespace train
} // namespace loader
} // namespace onert

#endif // __ONERT_LOADER_TRAIN_CHECKPOINT_LOADER_H__
108 changes: 108 additions & 0 deletions runtime/onert/core/src/loader/train/CheckpointLoader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "loader/train/CheckpointLoader.h"

#include "exec/Execution.h"
#include "ir/train/Checkpoint.h"
#include "ir/train/TrainingInfo.h"

#include <fstream>
#include <filesystem>

namespace
{

using namespace onert;
using namespace ir;
using namespace train;
using namespace checkpoint;
using namespace exec;

class CheckpointLoader final
{
public:
CheckpointLoader(const std::string &filename)
{
if (filename.empty() || !std::filesystem::exists(filename))
throw std::runtime_error{"Invalid checkpoint file"};

_file.open(filename.c_str(), std::ios::binary | std::ios::in);
if (!_file.good())
throw std::runtime_error{"Failed to open checkpoint file"};

_file.seekg(0, std::ios::end);
const unsigned long filesize = _file.tellg();
_file.seekg(0, std::ios::beg);

if (filesize < sizeof(_header))
throw std::runtime_error{"Invalid checkpoint file data"};

memset(reinterpret_cast<char *>(&_header), 0, sizeof(_header));
_file.read(reinterpret_cast<char *>(&_header), sizeof(_header));
if (_file.fail())
throw std::runtime_error{"Failed to load header data"};

if (_header.magic != checkpoint::MAGIC_NUMBER)
throw std::runtime_error{"Invalid MAGIC NUMBER"};

if (_header.schema != checkpoint::SCHEMA_VERSION)
throw std::runtime_error{"Invalid SCHEMA VERSION"};

if ((filesize - _header.other_offset) != sizeof(_footer))
throw std::runtime_error{"Invalid checkpoint file footer data"};

memset(reinterpret_cast<char *>(&_footer), 0, sizeof(_footer));
_file.seekg(_header.other_offset, std::ios::beg);
_file.read(reinterpret_cast<char *>(&_footer), sizeof(_footer));
}

~CheckpointLoader()
{
if (_file.is_open())
_file.close();
}

private:
std::ifstream _file;
checkpoint::Header _header;
checkpoint::Footer _footer;
};

} // namespace

namespace onert
{
namespace loader
{
namespace train
{

void loadCheckpoint(const std::string &filename,
const std::unique_ptr<ir::train::TrainingInfo> &train_info,
const std::unique_ptr<onert::exec::Execution> &exec)
{
CheckpointLoader loader(filename);

// TODO Load tensor data
UNUSED_RELEASE(exec);
// TODO Update step in train_info
UNUSED_RELEASE(train_info);
}

} // namespace train
} // namespace loader
} // namespace onert

0 comments on commit 84551f6

Please sign in to comment.