Update converter to sync with upstream tensorflow
lgeiger committed Apr 8, 2022
1 parent c23eaf6 commit 930116a
Showing 9 changed files with 140 additions and 63 deletions.
8 changes: 4 additions & 4 deletions larq_compute_engine/mlir/BUILD
@@ -473,12 +473,16 @@ cc_library(
"tf_to_tfl_flatbuffer.h",
],
deps = [
":lce_tfl_passes",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@org_tensorflow//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_export",
"@org_tensorflow//tensorflow/compiler/mlir/lite/metrics:error_collector",
"@org_tensorflow//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:error_util",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_saved_model_freeze_variables",
"@org_tensorflow//tensorflow/stream_executor/lib",
],
)
@@ -488,11 +492,7 @@ cc_library(
srcs = ["python/common.cc"],
hdrs = ["python/common.h"],
deps = [
":lce_tfl_passes",
":tf_to_tfl_flatbuffer",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
"@org_tensorflow//tensorflow/compiler/mlir/lite/python:tf_tfl_flatbuffer_helpers",
"@org_tensorflow//tensorflow/core:ops",
"@pybind11",
],
25 changes: 7 additions & 18 deletions larq_compute_engine/mlir/python/common.cc
@@ -18,14 +18,9 @@ limitations under the License.

#include <exception>

#include "larq_compute_engine/mlir/tf_tfl_passes.h"
#include "larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h"
#include "larq_compute_engine/mlir/transforms/passes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"

@@ -77,8 +72,10 @@ Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs) {
pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
mlir::OwningModuleRef* module, mlir::MLIRContext& context,
const LCETarget target, const pybind11::object& default_ranges,
const int num_inputs, const bool should_quantize,
const bool mark_as_post_training_quant) {
const std::unordered_set<std::string>& saved_model_tags,
llvm::StringRef saved_model_dir,
llvm::Optional<tensorflow::Session*> session, const int num_inputs,
const bool should_quantize, const bool mark_as_post_training_quant) {
mlir::TFL::QuantizationSpecs quant_specs;
if (should_quantize) {
// Normally we'd only set `inference_type` to QINT8 when there are
@@ -118,18 +115,10 @@ pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
}
}

mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
tensorflow::SetCrashReproducer(pm);

tensorflow::AddTFToLCETFLConversionPasses(quant_specs, &pm, target);

// Convert back to outlined while format for export back to flatbuffer.
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

std::string result;
auto status = ConvertTFExecutorToFlatbuffer(
module->get(), /*export_to_mlir=*/false, &result, &pm);
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module->get(), /*export_to_mlir=*/false, target, quant_specs,
saved_model_tags, saved_model_dir, session, &result);

if (!status.ok()) {
throw std::runtime_error("Could not translate to flatbuffer.");
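The new ConvertTFExecutorToTFLOrFlatbuffer helper replaces the hand-built pass pipeline deleted above. Its declaration lives in tf_to_tfl_flatbuffer.h, whose diff is not rendered on this page, so the following is a minimal sketch of the signature inferred purely from this call site; parameter names and const-ness are assumptions.

// Hypothetical declaration, reconstructed from the call above; the real
// header (not shown in this diff) may differ.
tensorflow::Status ConvertTFExecutorToTFLOrFlatbuffer(
    mlir::ModuleOp module, bool export_to_mlir, LCETarget target,
    const mlir::TFL::QuantizationSpecs& quant_specs,
    const std::unordered_set<std::string>& saved_model_tags,
    llvm::StringRef saved_model_dir,
    llvm::Optional<tensorflow::Session*> session, std::string* result);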
7 changes: 5 additions & 2 deletions larq_compute_engine/mlir/python/common.h
@@ -3,6 +3,7 @@
#include "mlir/Pass/Pass.h"
#include "pybind11/pybind11.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

@@ -13,7 +14,9 @@ Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs);
pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
mlir::OwningModuleRef* module, mlir::MLIRContext& context,
const LCETarget target, const pybind11::object& default_ranges,
const int num_inputs, const bool should_quantize,
const bool mark_as_post_training_quant);
const std::unordered_set<std::string>& saved_model_tags,
llvm::StringRef saved_model_dir,
llvm::Optional<tensorflow::Session*> session, const int num_inputs,
const bool should_quantize, const bool mark_as_post_training_quant);

} // namespace tensorflow
4 changes: 3 additions & 1 deletion larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc
@@ -64,7 +64,9 @@ pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer(

return ConvertMLIRModuleToTFLiteFlatBuffer(
&module.ValueOrDie(), context, target, default_ranges,
input_arrays.size(), should_quantize,
/*saved_model_tags=*/{},
/*saved_model_dir=*/"", /*session=*/llvm::None, input_arrays.size(),
should_quantize,
/*mark_as_post_training_quant=*/false);
}

7 changes: 4 additions & 3 deletions larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc
@@ -42,8 +42,8 @@ pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer(

auto target = GetLCETarget(target_str);

if (exported_names.size() != 1) {
throw std::runtime_error("Only a single exported name is supported.");
if (exported_names.empty()) {
throw std::runtime_error("Need at least one exported name.");
}

tensorflow::GraphImportConfig specs;
@@ -84,7 +84,8 @@ pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer(
}

return ConvertMLIRModuleToTFLiteFlatBuffer(
&module.ValueOrDie(), context, target, default_ranges, num_inputs,
&module.ValueOrDie(), context, target, default_ranges, tags,
saved_model_dir, bundle ? bundle->GetSession() : nullptr, num_inputs,
/*should_quantize=*/true,
/*mark_as_post_training_quant=*/true);
}
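A subtlety at this call site: the session parameter is an llvm::Optional<tensorflow::Session*>, and the raw Session* produced by `bundle ? bundle->GetSession() : nullptr` converts implicitly to an engaged Optional. A missing bundle therefore arrives as an Optional holding nullptr, not as llvm::None like the GraphDef path above passes. A minimal sketch of that conversion, as an observation rather than part of the commit:

// Both branches yield a raw Session*, which wraps into an engaged Optional:
// hasValue() is true even when the contained pointer is null, so the
// freezing code presumably has to null-check the pointer as well.
llvm::Optional<tensorflow::Session*> session =
    bundle ? bundle->GetSession() : nullptr;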
54 changes: 40 additions & 14 deletions larq_compute_engine/mlir/tf_tfl_passes.cc
@@ -27,30 +27,33 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager) {
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreatePrepareQuantizePass(quant_specs));
pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLCEQuantizePass());
if (quant_specs.default_ranges.first.hasValue() ||
quant_specs.default_ranges.second.hasValue()) {
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0),
quant_specs.IsSignedInferenceType()));
pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateLCEQuantizePass());
}
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateQuantizePass());
bool emit_quant_adaptor_ops =
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLCEQuantizePass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}
} // namespace

void AddTFToLCETFLConversionPasses(
const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager, const LCETarget target) {
// This is the early part of the conversion in isolation. This enables a caller
// to inject more information in the middle of the conversion before resuming
// it.
void AddPreVariableFreezingTFToLCETFLConversionPasses(
mlir::OpPassManager* pass_manager) {
// This pass wraps all the tf.FakeQuant ops in a custom op so they are not
// folded before being converted to tfl.quantize and tfl.dequantize ops.
auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps();
@@ -81,7 +84,15 @@ void AddTFToLCETFLConversionPasses(
// during which resources don't get frozen in the python layer.
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFDevice::CreateDecomposeResourceOpsPass());
}

// This is the later part of the conversion in isolation. This enables a caller
// to resume the conversion after injecting more information in the middle of
// it.
void AddPostVariableFreezingTFToLCETFLConversionPasses(
llvm::StringRef saved_model_dir,
const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager, const LCETarget target) {
// Note:
// We need to fuse composite ops before LowerStaticTensorList pass.
// The tensorflow list is not supported right now by that pass.
@@ -102,7 +113,7 @@

// Set the batch size of the function input to 1 and let shape inference
// propagate this in the next pass.
pass_manager->addPass(mlir::CreateSetBatchSizePass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::CreateSetBatchSizePass());

// Add a shape inference pass to optimize away the unnecessary casts.
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
@@ -120,7 +131,7 @@

// Remove passthrough ops early so constant folding can happen before
// LCE ops are injected
pass_manager->addPass(mlir::TFL::CreateOpRemovalPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateOpRemovalPass());

// The following pass used to be just after createSymbolDCEPass but we move it
// before createCanonicalizerPass because without it, the tf.Sign op is not
@@ -130,6 +141,13 @@
// constant ops.
pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());

if (!saved_model_dir.empty()) {
// This pass 'freezes' tf saved model asset ops and inlines them as string
// values in the format of tf constant ops.
pass_manager->addPass(
mlir::tf_saved_model::CreateFreezeAssetsPass(saved_model_dir.str()));
}

// Reduce operands of TFL::While without changing the outcome.
// It needs to stay here because:
// 1. WhileOps are in TFL dialect.
@@ -156,7 +174,8 @@
mlir::TF::CreateLayoutOptimizationPipeline(pass_manager->nest<mlir::FuncOp>(),
layout_optimization_options);
// Inject Larq Compute Engine Ops
pass_manager->addPass(mlir::TFL::CreatePrepareLCEPass(target));
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreatePrepareLCEPass(target));
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
// the TFLite dialect.
pass_manager->addNestedPass<mlir::FuncOp>(
@@ -172,6 +191,10 @@
// control flow ops (IfOp, CaseOp).
pass_manager->addPass(mlir::createInlinerPass());

// This pass removes the asset file dependencies in hash table use cases.
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TF::CreateInitTextFileToImportPass(saved_model_dir.str()));

// This pass removes the asset file dependencies in hash table use cases.
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TF::CreateInitTextFileToImportPass());
@@ -181,11 +204,14 @@
pass_manager->addPass(mlir::TFL::CreateAnalyzeVariablesPass());
pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass());
pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass());
pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(target));
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateOptimizeLCEPass(target));
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateOptimizePass(true));
pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(target));
pass_manager->addPass(mlir::TFL::CreateBitpackWeightsLCEPass());
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateOptimizeLCEPass(target));
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateBitpackWeightsLCEPass());

// This pass operates on TensorFlow ops but is triggered after legalization
// so that it can target constants introduced once TensorFlow Identity ops
@@ -198,7 +224,7 @@ void AddTFToLCETFLConversionPasses(
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());

pass_manager->addPass(mlir::TFL::CreateFusePaddingPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateFusePaddingPass());

// Run quantization after all the floating point model conversion is
// completed.
@@ -226,7 +252,7 @@ void AddTFToLCETFLConversionPasses(
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateRuntimeVerifyPass());

pass_manager->addPass(mlir::TFL::CreateLegalizeLCEPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLegalizeLCEPass());
}

} // namespace tensorflow
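Many edits in this file swap addPass for addNestedPass<mlir::FuncOp>. The distinction: addPass anchors a pass on the pass manager's root op (here the module), while addNestedPass schedules it on every matching op nested inside, so it runs per function and may not touch anything outside that function. A minimal generic sketch (plain MLIR of this era, where mlir::FuncOp is the builtin function op; not code from this repo):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

void BuildExamplePipeline(mlir::OpPassManager* pass_manager) {
  // Runs once with the whole module as its root; may rewrite anything.
  pass_manager->addPass(mlir::createInlinerPass());
  // Runs once per FuncOp in the module; instances can run in parallel but
  // are confined to their own function.
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
}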
7 changes: 5 additions & 2 deletions larq_compute_engine/mlir/tf_tfl_passes.h
@@ -9,8 +9,11 @@

namespace tensorflow {

// Add the TF to TFLite passes into a pass_manager.
void AddTFToLCETFLConversionPasses(
void AddPreVariableFreezingTFToLCETFLConversionPasses(
mlir::OpPassManager* pass_manager);

void AddPostVariableFreezingTFToLCETFLConversionPasses(
llvm::StringRef saved_model_dir,
const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager, const LCETarget target = LCETarget::ARM);

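The single AddTFToLCETFLConversionPasses entry point is now split so a caller can inject information, namely frozen variables, mid-conversion. A sketch of the intended driving sequence, assuming it mirrors upstream TensorFlow's converter; the real sequence lives in tf_to_tfl_flatbuffer.cc, which is not rendered in this diff, and every name below other than the two pass-registration functions and the (assumed) FreezeVariables helper is illustrative:

#include "larq_compute_engine/mlir/tf_tfl_passes.h"
#include "mlir/Pass/PassManager.h"
// Header path assumed from the tf_saved_model_freeze_variables BUILD dep:
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"
#include "tensorflow/core/lib/core/errors.h"

tensorflow::Status RunSplitConversion(
    mlir::ModuleOp module, mlir::MLIRContext& context, LCETarget target,
    const mlir::TFL::QuantizationSpecs& quant_specs,
    llvm::StringRef saved_model_dir,
    llvm::Optional<tensorflow::Session*> session) {
  mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
  tensorflow::AddPreVariableFreezingTFToLCETFLConversionPasses(&pm);
  if (failed(pm.run(module)))
    return tensorflow::errors::Internal("pre-freezing passes failed");

  // The injection point: variables captured by the SavedModel session are
  // frozen into constants between the two halves of the pipeline.
  if (session.hasValue() && session.getValue() &&
      failed(mlir::tf_saved_model::FreezeVariables(module,
                                                   session.getValue())))
    return tensorflow::errors::Internal("variable freezing failed");

  mlir::PassManager post_pm(&context, mlir::OpPassManager::Nesting::Implicit);
  tensorflow::AddPostVariableFreezingTFToLCETFLConversionPasses(
      saved_model_dir, quant_specs, &post_pm, target);
  if (failed(post_pm.run(module)))
    return tensorflow::errors::Internal("post-freezing passes failed");
  return tensorflow::Status::OK();
}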
(The diffs for the remaining two changed files, presumably tf_to_tfl_flatbuffer.cc and tf_to_tfl_flatbuffer.h given the BUILD changes above, are not rendered.)
