diff --git a/larq_compute_engine/mlir/BUILD b/larq_compute_engine/mlir/BUILD
index 57060244..cd1e4f6b 100644
--- a/larq_compute_engine/mlir/BUILD
+++ b/larq_compute_engine/mlir/BUILD
@@ -473,12 +473,16 @@ cc_library(
         "tf_to_tfl_flatbuffer.h",
     ],
     deps = [
+        ":lce_tfl_passes",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@org_tensorflow//tensorflow/compiler/mlir:op_or_arg_name_mapper",
        "@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_export",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite/metrics:error_collector",
+        "@org_tensorflow//tensorflow/compiler/mlir/lite/quantization:quantization_config",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:error_util",
+        "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_saved_model_freeze_variables",
         "@org_tensorflow//tensorflow/stream_executor/lib",
     ],
 )
@@ -488,11 +492,7 @@ cc_library(
     srcs = ["python/common.cc"],
     hdrs = ["python/common.h"],
     deps = [
-        ":lce_tfl_passes",
         ":tf_to_tfl_flatbuffer",
-        "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
-        "@org_tensorflow//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
-        "@org_tensorflow//tensorflow/compiler/mlir/lite/python:tf_tfl_flatbuffer_helpers",
         "@org_tensorflow//tensorflow/core:ops",
         "@pybind11",
     ],
diff --git a/larq_compute_engine/mlir/python/common.cc b/larq_compute_engine/mlir/python/common.cc
index 0934e652..9f3c2548 100644
--- a/larq_compute_engine/mlir/python/common.cc
+++ b/larq_compute_engine/mlir/python/common.cc
@@ -18,14 +18,9 @@ limitations under the License.
 
 #include
 
-#include "larq_compute_engine/mlir/tf_tfl_passes.h"
 #include "larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h"
-#include "larq_compute_engine/mlir/transforms/passes.h"
 #include "mlir/IR/MLIRContext.h"
-#include "mlir/Pass/Pass.h"
 #include "pybind11/pybind11.h"
-#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/status.h"
@@ -77,8 +72,10 @@ Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs) {
 pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
     mlir::OwningModuleRef* module, mlir::MLIRContext& context,
     const LCETarget target, const pybind11::object& default_ranges,
-    const int num_inputs, const bool should_quantize,
-    const bool mark_as_post_training_quant) {
+    const std::unordered_set<std::string>& saved_model_tags,
+    llvm::StringRef saved_model_dir,
+    llvm::Optional<tensorflow::Session*> session, const int num_inputs,
+    const bool should_quantize, const bool mark_as_post_training_quant) {
   mlir::TFL::QuantizationSpecs quant_specs;
   if (should_quantize) {
     // Normally we'd only set `inference_type` to QINT8 when there are
@@ -118,18 +115,10 @@ pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
     }
   }
 
-  mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
-  tensorflow::SetCrashReproducer(pm);
-
-  tensorflow::AddTFToLCETFLConversionPasses(quant_specs, &pm, target);
-
-  // Convert back to outlined while format for export back to flatbuffer.
-  pm.addPass(mlir::TFL::CreateWhileOutlinePass());
-  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
-
   std::string result;
-  auto status = ConvertTFExecutorToFlatbuffer(
-      module->get(), /*export_to_mlir=*/false, &result, &pm);
+  auto status = ConvertTFExecutorToTFLOrFlatbuffer(
+      module->get(), /*export_to_mlir=*/false, target, quant_specs,
+      saved_model_tags, saved_model_dir, session, &result);
 
   if (!status.ok()) {
     throw std::runtime_error("Could not translate to flatbuffer.");
diff --git a/larq_compute_engine/mlir/python/common.h b/larq_compute_engine/mlir/python/common.h
index 19403d14..72e66119 100644
--- a/larq_compute_engine/mlir/python/common.h
+++ b/larq_compute_engine/mlir/python/common.h
@@ -3,6 +3,7 @@
 #include "mlir/Pass/Pass.h"
 #include "pybind11/pybind11.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/public/session.h"
 
 namespace tensorflow {
 
@@ -13,7 +14,9 @@ Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs);
 pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
     mlir::OwningModuleRef* module, mlir::MLIRContext& context,
     const LCETarget target, const pybind11::object& default_ranges,
-    const int num_inputs, const bool should_quantize,
-    const bool mark_as_post_training_quant);
+    const std::unordered_set<std::string>& saved_model_tags,
+    llvm::StringRef saved_model_dir,
+    llvm::Optional<tensorflow::Session*> session, const int num_inputs,
+    const bool should_quantize, const bool mark_as_post_training_quant);
 
 }  // namespace tensorflow
diff --git a/larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc b/larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc
index ec2427ca..d630aabe 100644
--- a/larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc
+++ b/larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc
@@ -64,7 +64,9 @@ pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer(
 
   return ConvertMLIRModuleToTFLiteFlatBuffer(
       &module.ValueOrDie(), context, target, default_ranges,
-      input_arrays.size(), should_quantize,
+      /*saved_model_tags=*/{},
+      /*saved_model_dir=*/"", /*session=*/llvm::None, input_arrays.size(),
+      should_quantize,
       /*mark_as_post_training_quant=*/false);
 }
diff --git a/larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc b/larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc
index 0e36eb19..2247bd3c 100644
--- a/larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc
+++ b/larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc
@@ -42,8 +42,8 @@ pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer(
 
   auto target = GetLCETarget(target_str);
 
-  if (exported_names.size() != 1) {
-    throw std::runtime_error("Only a single exported name is supported.");
+  if (exported_names.empty()) {
+    throw std::runtime_error("Need at least one exported name.");
   }
 
   tensorflow::GraphImportConfig specs;
@@ -84,7 +84,8 @@ pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer(
   }
 
   return ConvertMLIRModuleToTFLiteFlatBuffer(
-      &module.ValueOrDie(), context, target, default_ranges, num_inputs,
+      &module.ValueOrDie(), context, target, default_ranges, tags,
+      saved_model_dir, bundle ? bundle->GetSession() : nullptr, num_inputs,
       /*should_quantize=*/true,
       /*mark_as_post_training_quant=*/true);
 }
diff --git a/larq_compute_engine/mlir/tf_tfl_passes.cc b/larq_compute_engine/mlir/tf_tfl_passes.cc
index 841a791c..20d62264 100644
--- a/larq_compute_engine/mlir/tf_tfl_passes.cc
+++ b/larq_compute_engine/mlir/tf_tfl_passes.cc
@@ -27,7 +27,7 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
                            mlir::OpPassManager* pass_manager) {
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFL::CreatePrepareQuantizePass(quant_specs));
-  pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLCEQuantizePass());
   if (quant_specs.default_ranges.first.hasValue() ||
       quant_specs.default_ranges.second.hasValue()) {
     pass_manager->addNestedPass<mlir::FuncOp>(
@@ -35,22 +35,25 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
             quant_specs.default_ranges.first.getValueOr(0.0),
             quant_specs.default_ranges.second.getValueOr(0.0),
             quant_specs.IsSignedInferenceType()));
-    pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
+    pass_manager->addNestedPass<mlir::FuncOp>(
+        mlir::TFL::CreateLCEQuantizePass());
   }
   pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateQuantizePass());
   bool emit_quant_adaptor_ops =
       quant_specs.inference_type != quant_specs.inference_input_type;
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
-  pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLCEQuantizePass());
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
 }
 
 }  // namespace
 
-void AddTFToLCETFLConversionPasses(
-    const mlir::TFL::QuantizationSpecs& quant_specs,
-    mlir::OpPassManager* pass_manager, const LCETarget target) {
+// This is the early part of the conversion in isolation. This enables a caller
+// to inject more information in the middle of the conversion before resuming
+// it.
+void AddPreVariableFreezingTFToLCETFLConversionPasses(
+    mlir::OpPassManager* pass_manager) {
   // This pass wraps all the tf.FakeQuant ops in a custom op so they are not
   // folded before being converted to tfl.quantize and tfl.dequantize ops.
   auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps();
@@ -81,7 +84,15 @@ void AddTFToLCETFLConversionPasses(
   // during which resources dont get frozen in the python layer.
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFDevice::CreateDecomposeResourceOpsPass());
+}
 
+// This is the later part of the conversion in isolation. This enables a caller
+// to resume the conversion after injecting more information in the middle of
+// it.
+void AddPostVariableFreezingTFToLCETFLConversionPasses(
+    llvm::StringRef saved_model_dir,
+    const mlir::TFL::QuantizationSpecs& quant_specs,
+    mlir::OpPassManager* pass_manager, const LCETarget target) {
   // Note:
   // We need to fuse composite ops before LowerStaticTensorList pass.
   // The tensorflow list is not supported right now by that pass.
@@ -102,7 +113,7 @@ void AddTFToLCETFLConversionPasses(
 
   // Set the batch size of the function input to 1 and let shape inference
   // propagate this in the next pass.
-  pass_manager->addPass(mlir::CreateSetBatchSizePass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::CreateSetBatchSizePass());
 
   // Add a shape inference pass to optimize away the unnecessary casts.
   pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
@@ -120,7 +131,7 @@ void AddTFToLCETFLConversionPasses(
 
   // Remove passthrough ops early so constant folding can happen before
   // LCE ops are injected
-  pass_manager->addPass(mlir::TFL::CreateOpRemovalPass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateOpRemovalPass());
 
   // The following pass used to be just after createSymbolDCEPass but we move it
   // before createCanonicalizerPass because without it, the tf.Sign op is not
@@ -130,6 +141,13 @@ void AddTFToLCETFLConversionPasses(
   // constant ops.
   pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
 
+  if (!saved_model_dir.empty()) {
+    // This pass 'freezes' tf saved model asset ops and inlines as string values
+    // in a format of the tf constant op.
+    pass_manager->addPass(
+        mlir::tf_saved_model::CreateFreezeAssetsPass(saved_model_dir.str()));
+  }
+
   // Reduce operands of TFL::While without changing the outcome.
   // It needs to stay here because:
   // 1. WhileOps are in TFL dialect.
@@ -156,7 +174,8 @@ void AddTFToLCETFLConversionPasses(
   mlir::TF::CreateLayoutOptimizationPipeline(pass_manager->nest<mlir::FuncOp>(),
                                              layout_optimization_options);
   // Inject Larq Compute Engine Ops
-  pass_manager->addPass(mlir::TFL::CreatePrepareLCEPass(target));
+  pass_manager->addNestedPass<mlir::FuncOp>(
+      mlir::TFL::CreatePrepareLCEPass(target));
   // Prepare for TFLite dialect, rerun canonicalization, and then legalize to
   // the TFLite dialect.
   pass_manager->addNestedPass<mlir::FuncOp>(
@@ -172,6 +191,10 @@ void AddTFToLCETFLConversionPasses(
   // control flow ops (IfOp, CaseOp).
   pass_manager->addPass(mlir::createInlinerPass());
 
+  // This pass removes the asset file dependencies in hash table use cases.
+  pass_manager->addNestedPass<mlir::FuncOp>(
+      mlir::TF::CreateInitTextFileToImportPass(saved_model_dir.str()));
+
   // This pass removes the asset file dependencies in hash table use cases.
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TF::CreateInitTextFileToImportPass());
@@ -181,11 +204,14 @@ void AddTFToLCETFLConversionPasses(
   pass_manager->addPass(mlir::TFL::CreateAnalyzeVariablesPass());
   pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass());
   pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass());
-  pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(target));
+  pass_manager->addNestedPass<mlir::FuncOp>(
+      mlir::TFL::CreateOptimizeLCEPass(target));
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFL::CreateOptimizePass(true));
-  pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(target));
-  pass_manager->addPass(mlir::TFL::CreateBitpackWeightsLCEPass());
+  pass_manager->addNestedPass<mlir::FuncOp>(
+      mlir::TFL::CreateOptimizeLCEPass(target));
+  pass_manager->addNestedPass<mlir::FuncOp>(
+      mlir::TFL::CreateBitpackWeightsLCEPass());
 
   // This pass operates on TensorFlow ops but is triggered after legalization
   // so that it can target constants introduced once TensorFlow Identity ops
@@ -198,7 +224,7 @@ void AddTFToLCETFLConversionPasses(
   pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
 
-  pass_manager->addPass(mlir::TFL::CreateFusePaddingPass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateFusePaddingPass());
 
   // Run quantization after all the floating point model conversion is
   // completed.
@@ -226,7 +252,7 @@ void AddTFToLCETFLConversionPasses(
 
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFL::CreateRuntimeVerifyPass());
-  pass_manager->addPass(mlir::TFL::CreateLegalizeLCEPass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLegalizeLCEPass());
 }
 
 }  // namespace tensorflow
diff --git a/larq_compute_engine/mlir/tf_tfl_passes.h b/larq_compute_engine/mlir/tf_tfl_passes.h
index 3b9dccb8..f5d3d923 100644
--- a/larq_compute_engine/mlir/tf_tfl_passes.h
+++ b/larq_compute_engine/mlir/tf_tfl_passes.h
@@ -9,8 +9,11 @@ namespace tensorflow {
 
-// Add the TF to TFLite passes into a pass_manager.
-void AddTFToLCETFLConversionPasses(
+void AddPreVariableFreezingTFToLCETFLConversionPasses(
+    mlir::OpPassManager* pass_manager);
+
+void AddPostVariableFreezingTFToLCETFLConversionPasses(
+    llvm::StringRef saved_model_dir,
     const mlir::TFL::QuantizationSpecs& quant_specs,
     mlir::OpPassManager* pass_manager, const LCETarget target = LCETarget::ARM);
diff --git a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc
index c28d682e..e5a4a952 100644
--- a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc
+++ b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc
@@ -1,10 +1,14 @@
 #include "larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h"
 
+#include "larq_compute_engine/mlir/tf_tfl_passes.h"
+#include "larq_compute_engine/mlir/transforms/passes.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/PassManager.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
+#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
@@ -27,11 +31,13 @@ mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
                : mlir::WalkResult::advance();
   });
   if (result.wasInterrupted()) {
-    module.emitError(
-        "The graph has Control Flow V1 ops. TFLite converter doesn't support "
-        "Control Flow V1 ops. Consider using Control Flow V2 ops instead. See "
-        "https://www.tensorflow.org/api_docs/python/tf/compat/v1/"
-        "enable_control_flow_v2.");
+    mlir::TFL::AttachErrorCode(
+        module.emitError(
+            "The graph has Control Flow V1 ops. TFLite converter doesn't "
+            "support Control Flow V1 ops. Consider using Control Flow V2 ops "
+            "instead. See https://www.tensorflow.org/api_docs/python/tf/compat/"
+            "v1/enable_control_flow_v2."),
+        tflite::metrics::ConverterErrorData::ERROR_UNSUPPORTED_CONTROL_FLOW_V1);
     return mlir::failure();
   }
   return mlir::success();
@@ -49,10 +55,12 @@ class TruncateOpOrArgLocNameMapper : public OpOrArgLocNameMapper {
 };
 
 }  // namespace
-
-Status ConvertTFExecutorToFlatbuffer(mlir::ModuleOp module, bool export_to_mlir,
-                                     std::string* result,
-                                     mlir::PassManager* pass_manager) {
+Status ConvertTFExecutorToTFLOrFlatbuffer(
+    mlir::ModuleOp module, bool export_to_mlir, const LCETarget target,
+    mlir::TFL::QuantizationSpecs quant_specs,
+    const std::unordered_set<std::string>& saved_model_tags,
+    llvm::StringRef saved_model_dir,
+    llvm::Optional<tensorflow::Session*> session, std::string* result) {
   // Explicitly disable dumping Op details on failures.
   module.getContext()->printOpOnDiagnostic(false);
 
@@ -70,25 +78,64 @@ Status ConvertTFExecutorToFlatbuffer(mlir::ModuleOp module, bool export_to_mlir,
   mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
                                                     /*propagate=*/true);
 
+  if (failed(IsValidGraph(module))) {
+    return statusHandler.ConsumeStatus();
+  }
+
+  mlir::PassManager pass_manager(module.getContext());
+  mlir::applyPassManagerCLOptions(pass_manager);
+  pass_manager.addInstrumentation(
+      std::make_unique<mlir::TFL::ErrorCollectorInstrumentation>(
+          pass_manager.getContext()));
-  if (failed(IsValidGraph(module)) || failed(pass_manager->run(module))) {
+  tensorflow::AddPreVariableFreezingTFToLCETFLConversionPasses(&pass_manager);
+  if (failed(pass_manager.run(module))) {
     return statusHandler.ConsumeStatus();
   }
 
+  // Freeze variables if a session is provided.
+  if (session.hasValue()) {
+    mlir::TFL::ErrorCollectorInstrumentation collector(module.getContext());
+    if (failed(mlir::tf_saved_model::FreezeVariables(module,
+                                                     session.getValue()))) {
+      auto status = statusHandler.ConsumeStatus();
+      mlir::TFL::ErrorCollector* collector =
+          mlir::TFL::ErrorCollector::GetErrorCollector();
+      if (!collector->CollectedErrors().empty()) {
+        return errors::InvalidArgument("Variable constant folding has failed.");
+      }
+      return status;
+    }
+  }
+  pass_manager.clear();
+  tensorflow::AddPostVariableFreezingTFToLCETFLConversionPasses(
+      saved_model_dir, quant_specs, &pass_manager, target);
+  if (failed(pass_manager.run(module))) {
+    auto status = statusHandler.ConsumeStatus();
+    mlir::TFL::ErrorCollector* collector =
+        mlir::TFL::ErrorCollector::GetErrorCollector();
+    for (const auto& error_data : collector->CollectedErrors()) {
+      if (error_data.subcomponent() == "FreezeGlobalTensorsPass") {
+        return errors::InvalidArgument("Variable constant folding is failed.");
+      }
+    }
+    return status;
+  }
+
   if (export_to_mlir) {
     llvm::raw_string_ostream os(*result);
     module.print(os);
-    return Status::OK();
+    return statusHandler.ConsumeStatus();
   }
 
-  // This is the only modification compared to the upstream tensorflow file
-  // TODO: This is no longer the case, these files have diverged since TF2.6
+  // Write MLIR TFLite dialect into FlatBuffer
   TruncateOpOrArgLocNameMapper op_or_arg_name_mapper;
   toco::TocoFlags toco_flags;
   toco_flags.set_force_select_tf_ops(false);
   toco_flags.set_allow_custom_ops(true);
 
   tflite::FlatbufferExportOptions options;
   options.toco_flags = toco_flags;
+  options.saved_model_tags = saved_model_tags;
   options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
   if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
     return statusHandler.ConsumeStatus();
diff --git a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h
index 3f8d5416..f1aa84d1 100644
--- a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h
+++ b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h
@@ -1,18 +1,24 @@
 #ifndef LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_
 #define LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_
 
+#include <unordered_set>
+
+#include "larq_compute_engine/mlir/transforms/passes.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/PassManager.h"
+#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
+#include "tensorflow/core/public/session.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
-
 namespace tensorflow {
 
 // This is a fork of ConvertTFExecutorToTFLOrFlatbuffer to enable custom
 // OpOrArgLocNameMapper
-// https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h#L55-L69
-Status ConvertTFExecutorToFlatbuffer(mlir::ModuleOp module, bool export_to_mlir,
-                                     std::string* result,
-                                     mlir::PassManager* pass_manager);
+// https://github.com/tensorflow/tensorflow/blob/v2.8.0/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h#L60-L78
+Status ConvertTFExecutorToTFLOrFlatbuffer(
+    mlir::ModuleOp module, bool export_to_mlir, const LCETarget target,
+    mlir::TFL::QuantizationSpecs quant_specs,
+    const std::unordered_set<std::string>& saved_model_tags,
+    llvm::StringRef saved_model_dir,
+    llvm::Optional<tensorflow::Session*> session, std::string* result);
 
 }  // namespace tensorflow
 
 #endif  // LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_