TF 2.8: Update converter to sync with upstream tensorflow #723

Merged: 2 commits, Apr 8, 2022
8 changes: 4 additions & 4 deletions larq_compute_engine/mlir/BUILD
@@ -473,12 +473,16 @@ cc_library(
"tf_to_tfl_flatbuffer.h",
],
deps = [
":lce_tfl_passes",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@org_tensorflow//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_export",
"@org_tensorflow//tensorflow/compiler/mlir/lite/metrics:error_collector",
"@org_tensorflow//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:error_util",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_saved_model_freeze_variables",
"@org_tensorflow//tensorflow/stream_executor/lib",
],
)
@@ -488,11 +492,7 @@ cc_library(
srcs = ["python/common.cc"],
hdrs = ["python/common.h"],
deps = [
":lce_tfl_passes",
":tf_to_tfl_flatbuffer",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
"@org_tensorflow//tensorflow/compiler/mlir/lite/python:tf_tfl_flatbuffer_helpers",
"@org_tensorflow//tensorflow/core:ops",
"@pybind11",
],
25 changes: 7 additions & 18 deletions larq_compute_engine/mlir/python/common.cc
@@ -18,14 +18,9 @@ limitations under the License.

#include <exception>

#include "larq_compute_engine/mlir/tf_tfl_passes.h"
#include "larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h"
#include "larq_compute_engine/mlir/transforms/passes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"

@@ -77,8 +72,10 @@ Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs) {
pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
mlir::OwningModuleRef* module, mlir::MLIRContext& context,
const LCETarget target, const pybind11::object& default_ranges,
const int num_inputs, const bool should_quantize,
const bool mark_as_post_training_quant) {
const std::unordered_set<std::string>& saved_model_tags,
llvm::StringRef saved_model_dir,
llvm::Optional<tensorflow::Session*> session, const int num_inputs,
const bool should_quantize, const bool mark_as_post_training_quant) {
mlir::TFL::QuantizationSpecs quant_specs;
if (should_quantize) {
// Normally we'd only set `inference_type` to QINT8 when there are
@@ -118,18 +115,10 @@ pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
}
}

mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
tensorflow::SetCrashReproducer(pm);

tensorflow::AddTFToLCETFLConversionPasses(quant_specs, &pm, target);

// Convert back to outlined while format for export back to flatbuffer.
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

Comment on lines -121 to -129

Collaborator:

I see that this block has moved into ConvertTFExecutorToFlatbuffer, and AddTFToLCETFLConversionPasses is now split in two. However, tensorflow::SetCrashReproducer(pm);, CreateWhileOutlinePass, and CreateRuntimeVerifyPass were not moved there; is that intentional? I see that CreateRuntimeVerifyPass is one of the passes inside TFToLCETFL, so that's probably fine (although it's no longer the final pass).

Member Author:

> However CreateWhileOutlinePass and CreateRuntimeVerifyPass were not moved there, is that intentional

These were moved to tf_tfl_passes.cc following tensorflow/tensorflow@6347e46.

> However tensorflow::SetCrashReproducer(pm); were not moved there, is that intentional

This was done following tensorflow/tensorflow@a68046e, so I assume that this functionality is now handled somewhere else internally.
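For readers tracing where those passes went, below is a rough sketch of how the conversion is presumably driven now that the pipeline is split around variable freezing. It mirrors the upstream TensorFlow flow this PR syncs with; the freezing helper, error strings, and exact ordering are assumptions for illustration, not code from this PR.

// Hypothetical sketch of the two-phase conversion assumed to sit behind
// ConvertTFExecutorToTFLOrFlatbuffer; names marked "assumed" are not taken
// verbatim from this PR.
#include "larq_compute_engine/mlir/tf_tfl_passes.h"
#include "llvm/ADT/Optional.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/session.h"

tensorflow::Status RunSplitLcePipeline(
    mlir::ModuleOp module, const LCETarget target,
    const mlir::TFL::QuantizationSpecs& quant_specs,
    llvm::StringRef saved_model_dir,
    llvm::Optional<tensorflow::Session*> session) {
  // Phase 1: everything that must run before resource variables are frozen.
  mlir::PassManager pre_pm(module.getContext(),
                           mlir::OpPassManager::Nesting::Implicit);
  tensorflow::AddPreVariableFreezingTFToLCETFLConversionPasses(&pre_pm);
  if (mlir::failed(pre_pm.run(module)))
    return tensorflow::errors::Aborted("pre-freezing passes failed");

  // Variables are frozen between the two phases, using the SavedModel session
  // when one is available (assumed helper behind the new BUILD dependency
  // tf_saved_model_freeze_variables).
  if (session.hasValue()) {
    if (mlir::failed(mlir::tf_saved_model::FreezeVariables(
            module, session.getValue())))
      return tensorflow::errors::Aborted("variable freezing failed");
  }

  // Phase 2: the remaining TF -> LCE/TFL passes; CreateWhileOutlinePass and
  // CreateRuntimeVerifyPass are now appended inside tf_tfl_passes.cc rather
  // than in common.cc.
  mlir::PassManager post_pm(module.getContext(),
                            mlir::OpPassManager::Nesting::Implicit);
  tensorflow::AddPostVariableFreezingTFToLCETFLConversionPasses(
      saved_model_dir, quant_specs, &post_pm, target);
  if (mlir::failed(post_pm.run(module)))
    return tensorflow::errors::Aborted("post-freezing passes failed");

  return tensorflow::Status::OK();
}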

std::string result;
auto status = ConvertTFExecutorToFlatbuffer(
module->get(), /*export_to_mlir=*/false, &result, &pm);
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module->get(), /*export_to_mlir=*/false, target, quant_specs,
saved_model_tags, saved_model_dir, session, &result);

if (!status.ok()) {
throw std::runtime_error("Could not translate to flatbuffer.");
7 changes: 5 additions & 2 deletions larq_compute_engine/mlir/python/common.h
@@ -3,6 +3,7 @@
#include "mlir/Pass/Pass.h"
#include "pybind11/pybind11.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

@@ -13,7 +14,9 @@ Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs);
pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
mlir::OwningModuleRef* module, mlir::MLIRContext& context,
const LCETarget target, const pybind11::object& default_ranges,
const int num_inputs, const bool should_quantize,
const bool mark_as_post_training_quant);
const std::unordered_set<std::string>& saved_model_tags,
llvm::StringRef saved_model_dir,
llvm::Optional<tensorflow::Session*> session, const int num_inputs,
const bool should_quantize, const bool mark_as_post_training_quant);

} // namespace tensorflow
4 changes: 3 additions & 1 deletion larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc
@@ -64,7 +64,9 @@ pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer(

return ConvertMLIRModuleToTFLiteFlatBuffer(
&module.ValueOrDie(), context, target, default_ranges,
input_arrays.size(), should_quantize,
/*saved_model_tags=*/{},
/*saved_model_dir=*/"", /*session=*/llvm::None, input_arrays.size(),
should_quantize,
/*mark_as_post_training_quant=*/false);
}

7 changes: 4 additions & 3 deletions larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc
@@ -42,8 +42,8 @@ pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer(

auto target = GetLCETarget(target_str);

if (exported_names.size() != 1) {
throw std::runtime_error("Only a single exported name is supported.");
if (exported_names.empty()) {
throw std::runtime_error("Need at least one exported name.");
}

tensorflow::GraphImportConfig specs;
@@ -84,7 +84,8 @@ pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer(
}

return ConvertMLIRModuleToTFLiteFlatBuffer(
&module.ValueOrDie(), context, target, default_ranges, num_inputs,
&module.ValueOrDie(), context, target, default_ranges, tags,
saved_model_dir, bundle ? bundle->GetSession() : nullptr, num_inputs,
/*should_quantize=*/true,
/*mark_as_post_training_quant=*/true);
}
54 changes: 40 additions & 14 deletions larq_compute_engine/mlir/tf_tfl_passes.cc
@@ -27,30 +27,33 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager) {
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreatePrepareQuantizePass(quant_specs));
pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLCEQuantizePass());
Member Author:

MLIR threw a nice error message that recommended changing our passes to be nested passes. Looks like the overall module structure changed, which required this update.
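For context, a minimal illustration of the difference (not code from this PR): addPass anchors a pass on the top-level ModuleOp, while addNestedPass<mlir::FuncOp> schedules it on every FuncOp inside the module, which is what MLIR now expects for these function-level passes.

// Minimal illustration of module-anchored vs function-nested pass scheduling;
// the passes used here are generic MLIR stand-ins, not the LCE passes above.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

void ScheduleExamplePasses(mlir::OpPassManager* pass_manager) {
  // Anchored on the whole ModuleOp: suitable for passes that rewrite
  // module-level structure such as symbols or global tensors.
  pass_manager->addPass(mlir::createSymbolDCEPass());

  // Anchored on each FuncOp nested inside the module: required for passes
  // whose anchor operation is FuncOp, otherwise pass scheduling fails with an
  // error like the one mentioned above.
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
}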

if (quant_specs.default_ranges.first.hasValue() ||
quant_specs.default_ranges.second.hasValue()) {
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0),
quant_specs.IsSignedInferenceType()));
pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateLCEQuantizePass());
}
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateQuantizePass());
bool emit_quant_adaptor_ops =
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLCEQuantizePass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}
} // namespace

void AddTFToLCETFLConversionPasses(
const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager, const LCETarget target) {
// This is the early part of the conversion in isolation. This enables a caller
// to inject more information in the middle of the conversion before resuming
// it.
void AddPreVariableFreezingTFToLCETFLConversionPasses(
mlir::OpPassManager* pass_manager) {
// This pass wraps all the tf.FakeQuant ops in a custom op so they are not
// folded before being converted to tfl.quantize and tfl.dequantize ops.
auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps();
Expand Down Expand Up @@ -81,7 +84,15 @@ void AddTFToLCETFLConversionPasses(
// during which resources dont get frozen in the python layer.
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFDevice::CreateDecomposeResourceOpsPass());
}

// This is the later part of the conversion in isolation. This enables a caller
// to resume the conversion after injecting more information in the middle of
// it.
void AddPostVariableFreezingTFToLCETFLConversionPasses(
llvm::StringRef saved_model_dir,
const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager, const LCETarget target) {
// Note:
// We need to fuse composite ops before LowerStaticTensorList pass.
// The tensorflow list is not supported right now by that pass.
@@ -102,7 +113,7 @@

// Set the batch size of the function input to 1 and let shape inference
// propagate this in the next pass.
pass_manager->addPass(mlir::CreateSetBatchSizePass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::CreateSetBatchSizePass());

// Add a shape inference pass to optimize away the unnecessary casts.
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
@@ -120,7 +131,7 @@

// Remove passthrough ops early so constant folding can happen before
// LCE ops are injected
pass_manager->addPass(mlir::TFL::CreateOpRemovalPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateOpRemovalPass());

// The following pass used to be just after createSymbolDCEPass but we move it
// before createCanonicalizerPass because without it, the tf.Sign op is not
@@ -130,6 +141,13 @@
// constant ops.
pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());

if (!saved_model_dir.empty()) {
// This pass 'freezes' tf saved model asset ops and inlines as string values
// in a format of the tf constant op.
pass_manager->addPass(
mlir::tf_saved_model::CreateFreezeAssetsPass(saved_model_dir.str()));
}

// Reduce operands of TFL::While without changing the outcome.
// It needs to stay here because:
// 1. WhileOps are in TFL dialect.
@@ -156,7 +174,8 @@
mlir::TF::CreateLayoutOptimizationPipeline(pass_manager->nest<mlir::FuncOp>(),
layout_optimization_options);
// Inject Larq Compute Engine Ops
pass_manager->addPass(mlir::TFL::CreatePrepareLCEPass(target));
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreatePrepareLCEPass(target));
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
// the TFLite dialect.
pass_manager->addNestedPass<mlir::FuncOp>(
@@ -172,6 +191,10 @@
// control flow ops (IfOp, CaseOp).
pass_manager->addPass(mlir::createInlinerPass());

// This pass removes the asset file dependencies in hash table use cases.
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TF::CreateInitTextFileToImportPass(saved_model_dir.str()));

// This pass removes the asset file dependencies in hash table use cases.
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TF::CreateInitTextFileToImportPass());
@@ -181,11 +204,14 @@
pass_manager->addPass(mlir::TFL::CreateAnalyzeVariablesPass());
pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass());
pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass());
pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(target));
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateOptimizeLCEPass(target));
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateOptimizePass(true));
pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(target));
pass_manager->addPass(mlir::TFL::CreateBitpackWeightsLCEPass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateOptimizeLCEPass(target));
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateBitpackWeightsLCEPass());

// This pass operates on TensorFlow ops but is triggered after legalization
// so that it can target constants introduced once TensorFlow Identity ops
@@ -198,7 +224,7 @@
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());

pass_manager->addPass(mlir::TFL::CreateFusePaddingPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateFusePaddingPass());

// Run quantization after all the floating point model conversion is
// completed.
@@ -226,7 +252,7 @@
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateRuntimeVerifyPass());

pass_manager->addPass(mlir::TFL::CreateLegalizeLCEPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLegalizeLCEPass());
}

} // namespace tensorflow
7 changes: 5 additions & 2 deletions larq_compute_engine/mlir/tf_tfl_passes.h
@@ -9,8 +9,11 @@

namespace tensorflow {

// Add the TF to TFLite passes into a pass_manager.
void AddTFToLCETFLConversionPasses(
void AddPreVariableFreezingTFToLCETFLConversionPasses(
mlir::OpPassManager* pass_manager);

void AddPostVariableFreezingTFToLCETFLConversionPasses(
llvm::StringRef saved_model_dir,
const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager, const LCETarget target = LCETarget::ARM);
