TF 2.8: Update converter to sync with upstream tensorflow #723
Merged
@@ -27,30 +27,33 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
                            mlir::OpPassManager* pass_manager) {
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFL::CreatePrepareQuantizePass(quant_specs));
-  pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLCEQuantizePass());
Inline review comment on the change above: MLIR threw a nice error message that recommended changing our passes to be nested passes. Looks like the overall module structure changed, which required this update.
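For context, a minimal standalone sketch (not part of this PR) contrasting the two registration styles. The specific passes are arbitrary examples from upstream MLIR, and the build setup is assumed to be the MLIR of the TF 2.8 era, where mlir::FuncOp is still a builtin op:

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"

    void BuildExamplePipeline(mlir::OpPassManager* pass_manager) {
      // Module-level registration: the pass runs once on the whole ModuleOp.
      pass_manager->addPass(mlir::createSymbolDCEPass());

      // Function-nested registration: the pass is scheduled on every FuncOp
      // inside the module. Passes anchored on FuncOp must be added this way,
      // otherwise the pass manager rejects the pipeline with a nesting error
      // like the one mentioned in the comment above.
      pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    }

The remainder of the diff applies this same addPass to addNestedPass substitution to each pass in the pipeline that is anchored on functions.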
   if (quant_specs.default_ranges.first.hasValue() ||
       quant_specs.default_ranges.second.hasValue()) {
     pass_manager->addNestedPass<mlir::FuncOp>(
         mlir::TFL::CreateDefaultQuantParamsPass(
             quant_specs.default_ranges.first.getValueOr(0.0),
             quant_specs.default_ranges.second.getValueOr(0.0),
             quant_specs.IsSignedInferenceType()));
-    pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
+    pass_manager->addNestedPass<mlir::FuncOp>(
+        mlir::TFL::CreateLCEQuantizePass());
   }
   pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateQuantizePass());
   bool emit_quant_adaptor_ops =
       quant_specs.inference_type != quant_specs.inference_input_type;
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
-  pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLCEQuantizePass());
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
 }
 } // namespace

-void AddTFToLCETFLConversionPasses(
-    const mlir::TFL::QuantizationSpecs& quant_specs,
-    mlir::OpPassManager* pass_manager, const LCETarget target) {
+// This is the early part of the conversion in isolation. This enables a caller
+// to inject more information in the middle of the conversion before resuming
+// it.
+void AddPreVariableFreezingTFToLCETFLConversionPasses(
+    mlir::OpPassManager* pass_manager) {
   // This pass wraps all the tf.FakeQuant ops in a custom op so they are not
   // folded before being converted to tfl.quantize and tfl.dequantize ops.
   auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps();
@@ -81,7 +84,15 @@ void AddTFToLCETFLConversionPasses(
   // during which resources dont get frozen in the python layer.
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFDevice::CreateDecomposeResourceOpsPass());
+}

+// This is the later part of the conversion in isolation. This enables a caller
+// to resume the conversion after injecting more information in the middle of
+// it.
+void AddPostVariableFreezingTFToLCETFLConversionPasses(
+    llvm::StringRef saved_model_dir,
+    const mlir::TFL::QuantizationSpecs& quant_specs,
+    mlir::OpPassManager* pass_manager, const LCETarget target) {
   // Note:
   // We need to fuse composite ops before LowerStaticTensorList pass.
   // The tensorflow list is not supported right now by that pass.
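With the pipeline split in two, a caller builds and runs two pass managers, injecting extra information (variable freezing) in between. A hedged sketch of that calling pattern; the driver function, variable names, and the freezing step in the middle are illustrative, and the real wiring lives in the converter entry point (see the review discussion at the end of this page):

    // Hypothetical driver, for illustration only. Assumes the usual MLIR
    // headers plus this file's declarations are available.
    mlir::LogicalResult RunSplitConversion(
        mlir::ModuleOp module, llvm::StringRef saved_model_dir,
        const mlir::TFL::QuantizationSpecs& quant_specs, LCETarget target) {
      mlir::PassManager pm(module.getContext());
      tensorflow::AddPreVariableFreezingTFToLCETFLConversionPasses(&pm);
      if (mlir::failed(pm.run(module))) return mlir::failure();

      // The caller injects extra information here, e.g. freezing variables
      // captured from the TF session, before resuming the conversion.

      mlir::PassManager pm_post(module.getContext());
      tensorflow::AddPostVariableFreezingTFToLCETFLConversionPasses(
          saved_model_dir, quant_specs, &pm_post, target);
      return pm_post.run(module);
    }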
@@ -102,7 +113,7 @@ void AddTFToLCETFLConversionPasses(

   // Set the batch size of the function input to 1 and let shape inference
   // propagate this in the next pass.
-  pass_manager->addPass(mlir::CreateSetBatchSizePass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::CreateSetBatchSizePass());

   // Add a shape inference pass to optimize away the unnecessary casts.
   pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
@@ -120,7 +131,7 @@ void AddTFToLCETFLConversionPasses(

   // Remove passthrough ops early so constant folding can happen before
   // LCE ops are injected
-  pass_manager->addPass(mlir::TFL::CreateOpRemovalPass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateOpRemovalPass());

   // The following pass used to be just after createSymbolDCEPass but we move it
   // before createCanonicalizerPass because without it, the tf.Sign op is not
@@ -130,6 +141,13 @@ void AddTFToLCETFLConversionPasses(
   // constant ops.
   pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());

+  if (!saved_model_dir.empty()) {
+    // This pass 'freezes' tf saved model asset ops and inlines as string values
+    // in a format of the tf constant op.
+    pass_manager->addPass(
+        mlir::tf_saved_model::CreateFreezeAssetsPass(saved_model_dir.str()));
+  }
+
   // Reduce operands of TFL::While without changing the outcome.
   // It needs to stay here because:
   // 1. WhileOps are in TFL dialect.
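Note that the new saved_model_dir parameter only takes effect when it is non-empty: with an empty string the asset-freezing branch above is skipped entirely. A hedged usage sketch, with the surrounding pass manager and quantization variables assumed:

    // Converting a model that did not come from a SavedModel on disk: pass an
    // empty directory, so CreateFreezeAssetsPass is never added.
    tensorflow::AddPostVariableFreezingTFToLCETFLConversionPasses(
        /*saved_model_dir=*/"", quant_specs, &pass_manager, target);

    // Converting from a SavedModel: pass its directory so asset files (for
    // example vocabulary files backing hash tables) can be inlined as
    // constants during conversion.
    tensorflow::AddPostVariableFreezingTFToLCETFLConversionPasses(
        saved_model_dir, quant_specs, &pass_manager, target);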
@@ -156,7 +174,8 @@ void AddTFToLCETFLConversionPasses(
   mlir::TF::CreateLayoutOptimizationPipeline(pass_manager->nest<mlir::FuncOp>(),
                                              layout_optimization_options);
   // Inject Larq Compute Engine Ops
-  pass_manager->addPass(mlir::TFL::CreatePrepareLCEPass(target));
+  pass_manager->addNestedPass<mlir::FuncOp>(
+      mlir::TFL::CreatePrepareLCEPass(target));
   // Prepare for TFLite dialect, rerun canonicalization, and then legalize to
   // the TFLite dialect.
   pass_manager->addNestedPass<mlir::FuncOp>(
@@ -172,6 +191,10 @@ void AddTFToLCETFLConversionPasses(
   // control flow ops (IfOp, CaseOp).
   pass_manager->addPass(mlir::createInlinerPass());

+  // This pass removes the asset file dependencies in hash table use cases.
+  pass_manager->addNestedPass<mlir::FuncOp>(
+      mlir::TF::CreateInitTextFileToImportPass(saved_model_dir.str()));
+
   // This pass removes the asset file dependencies in hash table use cases.
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TF::CreateInitTextFileToImportPass());
@@ -181,11 +204,14 @@ void AddTFToLCETFLConversionPasses(
   pass_manager->addPass(mlir::TFL::CreateAnalyzeVariablesPass());
   pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass());
   pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass());
-  pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(target));
+  pass_manager->addNestedPass<mlir::FuncOp>(
+      mlir::TFL::CreateOptimizeLCEPass(target));
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFL::CreateOptimizePass(true));
-  pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(target));
-  pass_manager->addPass(mlir::TFL::CreateBitpackWeightsLCEPass());
+  pass_manager->addNestedPass<mlir::FuncOp>(
+      mlir::TFL::CreateOptimizeLCEPass(target));
+  pass_manager->addNestedPass<mlir::FuncOp>(
+      mlir::TFL::CreateBitpackWeightsLCEPass());

   // This pass operates on TensorFlow ops but is triggered after legalization
   // so that it can target constants introduced once TensorFlow Identity ops
@@ -198,7 +224,7 @@ void AddTFToLCETFLConversionPasses(
   pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());

-  pass_manager->addPass(mlir::TFL::CreateFusePaddingPass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateFusePaddingPass());

   // Run quantization after all the floating point model conversion is
   // completed.
@@ -226,7 +252,7 @@ void AddTFToLCETFLConversionPasses(
   pass_manager->addNestedPass<mlir::FuncOp>(
       mlir::TFL::CreateRuntimeVerifyPass());

-  pass_manager->addPass(mlir::TFL::CreateLegalizeLCEPass());
+  pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLegalizeLCEPass());
 }

 } // namespace tensorflow
Review comment: I see that this block has moved into ConvertTFExecutorToFlatbuffer, and AddTFToLCETFLConversionPasses is now split in two. However, tensorflow::SetCrashReproducer(pm), CreateWhileOutlinePass and CreateRuntimeVerifyPass were not moved there; is that intentional? I see that CreateRuntimeVerifyPass is one of the passes inside TFToLCETFL, so that's probably fine (although it's no longer the final pass).
Reply: These were moved to tf_tfl_passes.cc following tensorflow/tensorflow@6347e46.

This was done following tensorflow/tensorflow@a68046e, so I assume that this functionality is now handled somewhere else internally.
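For readers unfamiliar with the crash reproducer mentioned in the comment above: it is, in upstream MLIR terms, the PassManager reproducer facility. A minimal hedged sketch of enabling it directly follows; the output path is illustrative, and how the upstream tensorflow::SetCrashReproducer wrapper configures it exactly is not reproduced here:

    #include "mlir/Pass/PassManager.h"

    void EnableReproducer(mlir::PassManager& pm) {
      // If a pass crashes or fails, MLIR writes out the module plus the
      // failing pipeline to this file so the failure can be replayed offline.
      pm.enableCrashReproducerGeneration("/tmp/lce_converter_reproducer.mlir",
                                         /*genLocalReproducer=*/false);
    }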