Skip to content

Commit

Permalink
DRAFT CFe to support huge 2G model
Browse files Browse the repository at this point in the history
Ongoing draft to support huge models larger than 2 GB.

Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark committed Sep 2, 2024
1 parent d96ddd3 commit c1802c2
Show file tree
Hide file tree
Showing 14 changed files with 278 additions and 40 deletions.
1 change: 0 additions & 1 deletion compiler/fme-apply/driver/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <foder/FileLoader.h>
#include <luci/CircleExporter.h>
#include <luci/CircleFileExpContract.h>
#include <luci/Importer.h>
#include <luci/ImporterEx.h>
#include <luci/Service/Validate.h>

Expand Down
1 change: 0 additions & 1 deletion compiler/fme-detect/driver/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <foder/FileLoader.h>
#include <luci/CircleExporter.h>
#include <luci/CircleFileExpContract.h>
#include <luci/Importer.h>
#include <luci/ImporterEx.h>
#include <luci/Service/Validate.h>

Expand Down
104 changes: 101 additions & 3 deletions compiler/luci/export/src/CircleExporterImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,31 @@ void CircleExporterImpl::exportModule(Module *module)
// prepare model data
prepareModelData(_builder, md);

exportModuleData(module, md);
// if source is extended buffer mode, force export to use extended buffer
md._ext_buffer = module->ext_buffer();

if (!exportModuleData(module, md) && md._require_ext_buffer)
{
assert(md._ext_buffer == false);

// do some cleanups for re-run
_builder.Clear();
for (size_t g = 0; g < module->size(); ++g)
{
auto graph = module->graph(g);
clearExportInfo(graph);
}
prepareModelData(_builder, md);

// run again with ext_buffer mode
md._ext_buffer = true;
exportModuleData(module, md);
}

finalizeWithExtendedBuffer(md);
}

void CircleExporterImpl::exportModuleData(Module *module, SerializedModelData &md)
bool CircleExporterImpl::exportModuleData(Module *module, SerializedModelData &md)
{
std::vector<flatbuffers::Offset<circle::SubGraph>> subgraph_vec;

Expand Down Expand Up @@ -208,20 +229,97 @@ void CircleExporterImpl::exportModuleData(Module *module, SerializedModelData &m
// create array of buffers
auto buffers = _builder.CreateVector(md._buffers);

// check current total size exceeds limit
if (check_size_limit(_builder, 0))
{
md._require_ext_buffer = true;
return false;
}

// This version is taken from comment in fbs
constexpr uint32_t version = 0;

// Model
auto model_offset = CreateModel(_builder, version, operator_codes, subgraphs, description,
buffers, 0 /* metadata_buffer */, metadata);
FinishModelBuffer(_builder, model_offset);

return true;
}

/**
 * @brief Build the final output image when extended-buffer mode is active.
 *
 * Copies md._ext_buffer into _ext_buffer. When it is false nothing else is
 * done and the builder's own buffer remains the export result. When it is
 * true, the serialized flatbuffer is copied into a byte string, every entry of
 * md._buffer_data_map is appended after it at a 16-byte aligned position, and
 * the matching circle::Buffer record inside the copied flatbuffer is patched
 * (mutate_offset / mutate_size) with the position and length of its payload.
 * The finished image is stored in _fb_data_with_ext.
 *
 * @param md  serialized model data; _ext_buffer selects the mode and
 *            _buffer_data_map supplies the payloads keyed by buffer index
 */
void CircleExporterImpl::finalizeWithExtendedBuffer(SerializedModelData &md)
{
  _ext_buffer = md._ext_buffer;
  if (!_ext_buffer)
    return;

  _fb_data_with_ext.clear();

  // round a byte count up to the next multiple of 16
  auto align16 = [](size_t v) -> size_t { return (v + 15) / 16 * 16; };

  // zero-fill the string up to the next 16-byte boundary
  auto pad16 = [&align16](std::string &str) { str.resize(align16(str.size()), '\0'); };

  // total memory for flatbuffer + all buffer_data, each part 16-byte aligned
  size_t result_size = align16(_builder.GetSize());
  for (const auto &it : md._buffer_data_map)
    result_size = align16(result_size + it.second.size());
  result_size += 16; // for safety

  std::string result;
  const char *buff_ptr = reinterpret_cast<const char *>(_builder.GetBufferPointer());

  // NOTE: this reserve is load-bearing. mutable_buffers below points into
  // result's storage, so result must never reallocate while payloads are
  // appended; result_size is computed with the same alignment steps as the
  // appends to guarantee that.
  result.reserve(result_size);
  result.append(buff_ptr, _builder.GetSize());

  auto mutable_model = circle::GetMutableModel(result.data());
  auto mutable_buffers = mutable_model->mutable_buffers();

  // pad to be 16 bytes aligned
  pad16(result);
  for (const auto &it : md._buffer_data_map)
  {
    const int32_t buffer_index = it.first;
    const SerializedModelData::BufferData &buffer_data = it.second;
    // payload starts at the current (already aligned) end of the image
    const uint64_t offset = result.size();
    const uint64_t size = buffer_data.size();

    // NOTE(review): the results of mutate_offset/mutate_size are not checked;
    // presumably they can fail when the field is absent from the table - confirm.
    circle::Buffer *mutable_buffer = mutable_buffers->GetMutableObject(buffer_index);
    mutable_buffer->mutate_offset(offset);
    mutable_buffer->mutate_size(size);

    result.append(buffer_data.begin(), buffer_data.end());
    pad16(result);
  }

  // hand the image over without copying it; it can be larger than 2 GB
  _fb_data_with_ext.swap(result);
}

/**
 * @brief Start of the exported bytes: the extended-buffer image when
 *        _ext_buffer is set, otherwise the flatbuffer builder's own buffer.
 */
const char *CircleExporterImpl::getBufferPointer() const
{
  if (!_ext_buffer)
    return reinterpret_cast<const char *>(_builder.GetBufferPointer());

  return _fb_data_with_ext.data();
}

size_t CircleExporterImpl::getBufferSize() const { return _builder.GetSize(); }
/**
 * @brief Number of exported bytes: size of the extended-buffer image when
 *        _ext_buffer is set, otherwise the flatbuffer builder's size.
 */
size_t CircleExporterImpl::getBufferSize() const
{
  const size_t fb_size = _builder.GetSize();
  return _ext_buffer ? _fb_data_with_ext.size() : fb_size;
}

} // namespace luci
9 changes: 8 additions & 1 deletion compiler/luci/export/src/CircleExporterImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,17 @@ class CircleExporterImpl
/**
* @brief implementation that writes Module into internal buffer
*/
void exportModuleData(Module *module, SerializedModelData &md);
bool exportModuleData(Module *module, SerializedModelData &md);

/**
* @brief finalizes file stream with extended buffer from internal buffer
*/
void finalizeWithExtendedBuffer(SerializedModelData &md);

private:
flatbuffers::FlatBufferBuilder _builder;
bool _ext_buffer = false;
std::string _fb_data_with_ext;
};

} // namespace luci
Expand Down
9 changes: 9 additions & 0 deletions compiler/luci/export/src/CircleExporterUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

#include <mio/circle/schema_generated.h>

// Upper bound (2 GiB, i.e. 2^31 bytes) used as the size limit of a flatbuffers file
inline constexpr unsigned int FLATBUFFERS_SIZE_MAX = 1u << 31;

namespace luci
{

Expand Down Expand Up @@ -60,6 +63,12 @@ void set_tensor_index(loco::Node *node, const CircleTensorIndex &tensor_id);
void clear_tensor_index(loco::Node *node);
CircleTensorIndex get_tensor_index(loco::Node *node);

// Check whether the Flatbuffer builder can no longer hold the given amount of data.
//
// Returns true when fb.GetSize() plus data_size would exceed FLATBUFFERS_SIZE_MAX,
// and also when the builder has already grown past that limit. The comparison is
// done in 64-bit with an explicit guard so that the subtraction cannot wrap around
// and report free space that does not exist.
inline bool check_size_limit(const flatbuffers::FlatBufferBuilder &fb, const uint64_t data_size)
{
  const uint64_t limit = FLATBUFFERS_SIZE_MAX;
  const uint64_t used = fb.GetSize();

  // already over the limit: nothing more fits, whatever data_size is
  if (used > limit)
    return true;

  return data_size > limit - used;
}

} // namespace luci

#endif // __CIRCLE_EXPORTER_UTILS_H__
29 changes: 26 additions & 3 deletions compiler/luci/export/src/CircleTensorExporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "CircleTensorExporter.h"
#include "CircleExporterUtils.h"

#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
Expand Down Expand Up @@ -346,7 +347,7 @@ flatbuffers::Offset<circle::Buffer> encodeOpBuffer(FlatBufferBuilder &builder,

template <loco::DataType DT>
flatbuffers::Offset<circle::Buffer>
encodeOpBufferByDType(FlatBufferBuilder &builder, SerializedModelData &, luci::CircleConst *c)
encodeOpBufferByDType(FlatBufferBuilder &builder, SerializedModelData &md, luci::CircleConst *c)
{
using NativeType = typename loco::DataTypeImpl<DT>::Type;

Expand All @@ -358,6 +359,26 @@ encodeOpBufferByDType(FlatBufferBuilder &builder, SerializedModelData &, luci::C
raw_data.push_back(c->at<DT>(i));
}
const size_t raw_size = size * sizeof(NativeType);

if (md._ext_buffer)
{
// TODO optimize this if this operation takes long or much memory
SerializedModelData::BufferData buffer_data;
buffer_data.resize(raw_size);
std::memcpy(buffer_data.data(), raw_data.data(), raw_size);

int32_t buffer_index = md._buffers.size();
md._buffer_data_map.emplace(buffer_index, buffer_data);

// create fake indicator buffer
return circle::CreateBuffer(builder, 0 /* data */, 1 /* offset */, 1 /* size */);
}
if (check_size_limit(builder, raw_size))
{
md._require_ext_buffer = true;
return md._empty_buffer;
}

auto array_offset = builder.CreateVector(reinterpret_cast<uint8_t *>(raw_data.data()), raw_size);
return CreateBuffer(builder, array_offset);
}
Expand Down Expand Up @@ -658,14 +679,16 @@ namespace luci

void prepareModelData(FlatBufferBuilder &builder, SerializedModelData &md)
{
md.clear();

// add one empty buffer
// note: this follows TFLite
// note: there's a comment in tflite fbs file
// - Note the 0th entry of this array must be an empty buffer (sentinel).
// - This is a convention so that tensors without a buffer can provide 0 as
// - their buffer.
auto buffer = encodeOpBuffer(builder);
md._buffers.push_back(buffer);
md._empty_buffer = encodeOpBuffer(builder);
md._buffers.push_back(md._empty_buffer);
}

void exportOpDefinedTensors(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &md,
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/export/src/SerializedData.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ struct SerializedModelData final

std::unordered_map<OpCode, uint32_t> _operator_codes;
std::vector<flatbuffers::Offset<circle::Buffer>> _buffers;
flatbuffers::Offset<circle::Buffer> _empty_buffer;
CircleExportMetadata _metadata;

// This is used for removing buffers with same values
Expand Down
5 changes: 3 additions & 2 deletions compiler/luci/import/include/luci/Importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class Importer final
// DO NOTHING
}

public:
// TODO move to private
private:
std::unique_ptr<Module> importModule(const circle::Model *model) const;

public:
std::unique_ptr<Module> importModule(const uint8_t *data, size_t size);

private:
Expand Down
23 changes: 10 additions & 13 deletions compiler/luci/import/src/Importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace
{

void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &reader,
loco::Graph *graph)
loco::Graph *graph, bool &ext_buffer)
{
LOGGER(l);

Expand Down Expand Up @@ -242,6 +242,8 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
auto dtype = luci::luci_datatype(tensor->type());
graph_output->dtype(dtype);
}

ext_buffer = gb_context.ext_buffer();
}

class ValidateCollector final : public loco::ErrorListener
Expand Down Expand Up @@ -277,17 +279,8 @@ std::unique_ptr<Module> Importer::importModule(const circle::Model *model) const
}

CircleReader reader;
if (_file_data && _file_size)
{
if (!reader.parse(model, _file_data, _file_size))
return nullptr;
}
else
{
// TODO remove this
if (!reader.parse(model))
return nullptr;
}
if (!reader.parse(model, _file_data, _file_size))
return nullptr;

for (uint32_t g = 0; g < reader.num_subgraph(); ++g)
{
Expand All @@ -299,7 +292,8 @@ std::unique_ptr<Module> Importer::importModule(const circle::Model *model) const
graph->name(reader.name());

// Convert circle::Model to loco::Graph
convert_graph(*source_ptr, reader, graph.get());
bool graph_ext_buffer = false;
convert_graph(*source_ptr, reader, graph.get(), graph_ext_buffer);

LOGGER(l);
VERBOSE(l, 3) << "--- graph dump begin -------------------------------------------";
Expand All @@ -310,6 +304,9 @@ std::unique_ptr<Module> Importer::importModule(const circle::Model *model) const
assert(loco::valid(graph.get(), std::make_unique<ValidateCollector>()));

module->add(std::move(graph));

if (graph_ext_buffer)
module->ext_buffer(true);
}

post_import_graph(module.get(), reader);
Expand Down
23 changes: 14 additions & 9 deletions compiler/luci/import/src/ImporterEx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,27 @@ std::unique_ptr<Module> ImporterEx::importVerifyModule(const std::string &input_
return nullptr;
}

flatbuffers::Verifier verifier{reinterpret_cast<uint8_t *>(model_data.data()), model_data.size()};
auto data_data = reinterpret_cast<uint8_t *>(model_data.data());
auto data_size = model_data.size();

flatbuffers::Verifier verifier{data_data, data_size};
if (!circle::VerifyModelBuffer(verifier))
{
std::cerr << "ERROR: Invalid input file '" << input_path << "'" << std::endl;
return nullptr;
}

const circle::Model *circle_model = circle::GetModel(model_data.data());
if (circle_model == nullptr)
{
std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
return nullptr;
}
Importer importer(_source);
return importer.importModule(data_data, data_size);
}

std::unique_ptr<Module> ImporterEx::importModule(std::vector<char> &model_data) const
{
auto data_data = reinterpret_cast<uint8_t *>(model_data.data());
auto data_size = model_data.size();

Importer importer;
return importer.importModule(circle_model);
Importer importer(_source);
return importer.importModule(data_data, data_size);
}

std::unique_ptr<Module> ImporterEx::importModule(std::vector<char> &model_data) const
Expand Down
Loading

0 comments on commit c1802c2

Please sign in to comment.