From 1001baddfc95ee211ff03d3df6e9b8c15042f7d5 Mon Sep 17 00:00:00 2001 From: Xiaoyu Date: Thu, 23 Jan 2025 22:12:57 +0000 Subject: [PATCH 1/5] Add generate_config_file option to cli --- olive/cli/auto_opt.py | 5 +++++ olive/cli/base.py | 18 ++++++++++++++++++ olive/cli/capture_onnx.py | 5 +++++ olive/cli/finetune.py | 5 +++++ olive/cli/generate_adapter.py | 5 +++++ olive/cli/quantize.py | 5 +++++ olive/cli/session_params_tuning.py | 5 +++++ 7 files changed, 48 insertions(+) diff --git a/olive/cli/auto_opt.py b/olive/cli/auto_opt.py index 2c358fdc3..bc623e80c 100644 --- a/olive/cli/auto_opt.py +++ b/olive/cli/auto_opt.py @@ -14,9 +14,11 @@ add_input_model_options, add_logging_options, add_remote_options, + add_save_config_file_options, add_search_options, add_shared_cache_options, get_input_model_config, + save_config_file, update_accelerator_options, update_remote_options, update_search_options, @@ -171,6 +173,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_shared_cache_options(sub_parser) add_logging_options(sub_parser) + add_save_config_file_options(sub_parser) sub_parser.set_defaults(func=AutoOptCommand) def run(self): @@ -178,6 +181,8 @@ def run(self): with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: run_config = self.get_run_config(tempdir) + if self.args.generate_config_file: + save_config_file(run_config) olive_run(run_config) def get_run_config(self, tempdir) -> Dict: diff --git a/olive/cli/base.py b/olive/cli/base.py index aec4cfd70..ed023e15c 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -247,6 +247,16 @@ def add_logging_options(sub_parser: ArgumentParser): return sub_parser +def add_save_config_file_options(sub_parser: ArgumentParser): + """Add save config file options to the sub_parser.""" + sub_parser.add_argument( + "--generate_config_file", + action="store_true", + help="Generate a config file for the command.", + ) + return sub_parser + + def add_remote_options(sub_parser: ArgumentParser): """Add remote options to the sub_parser.""" remote_group = sub_parser @@ -354,6 +364,14 @@ def add_input_model_options( return model_group +def save_config_file(config: Dict): + """Save the config file.""" + config_file_path = Path(config["output_dir"]) / "config.json" + with open(config_file_path, "w") as f: + json.dump(config, f, indent=4) + print(f"Config file saved at {config_file_path}") + + def output_path_type(path: str) -> str: """Resolve the output path and mkdir if it doesn't exist.""" path = Path(path).resolve() diff --git a/olive/cli/capture_onnx.py b/olive/cli/capture_onnx.py index 062ca78ba..eea06b406 100644 --- a/olive/cli/capture_onnx.py +++ b/olive/cli/capture_onnx.py @@ -12,8 +12,10 @@ add_input_model_options, add_logging_options, add_remote_options, + add_save_config_file_options, add_shared_cache_options, get_input_model_config, + save_config_file, update_remote_options, update_shared_cache_options, ) @@ -142,6 +144,7 @@ def register_subcommand(parser: ArgumentParser): # remote options add_remote_options(sub_parser) add_logging_options(sub_parser) + add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) sub_parser.set_defaults(func=CaptureOnnxGraphCommand) @@ -150,6 +153,8 @@ def run(self): with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: run_config = self.get_run_config(tempdir) + if self.args.generate_config_file: + save_config_file(run_config) olive_run(run_config) def get_run_config(self, tempdir: str) -> Dict: diff --git a/olive/cli/finetune.py b/olive/cli/finetune.py index b43162d29..21791ba18 100644 --- a/olive/cli/finetune.py +++ b/olive/cli/finetune.py @@ -13,8 +13,10 @@ add_input_model_options, add_logging_options, add_remote_options, + add_save_config_file_options, add_shared_cache_options, get_input_model_config, + save_config_file, update_dataset_options, update_remote_options, update_shared_cache_options, @@ -76,6 +78,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_shared_cache_options(sub_parser) add_logging_options(sub_parser) + add_save_config_file_options(sub_parser) sub_parser.set_defaults(func=FineTuneCommand) def run(self): @@ -83,6 +86,8 @@ def run(self): with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: run_config = self.get_run_config(tempdir) + if self.args.generate_config_file: + save_config_file(run_config) olive_run(run_config) def parse_training_args(self) -> Dict: diff --git a/olive/cli/generate_adapter.py b/olive/cli/generate_adapter.py index e68b2f591..c318c5bfd 100644 --- a/olive/cli/generate_adapter.py +++ b/olive/cli/generate_adapter.py @@ -12,8 +12,10 @@ add_input_model_options, add_logging_options, add_remote_options, + add_save_config_file_options, add_shared_cache_options, get_input_model_config, + save_config_file, update_remote_options, update_shared_cache_options, ) @@ -40,6 +42,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_logging_options(sub_parser) + add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) sub_parser.set_defaults(func=GenerateAdapterCommand) @@ -48,6 +51,8 @@ def run(self): with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: run_config = self.get_run_config(tempdir) + if self.args.generate_config_file: + save_config_file(run_config) olive_run(run_config) def get_run_config(self, tempdir: str) -> Dict: diff --git a/olive/cli/quantize.py b/olive/cli/quantize.py index ca6b8ac62..8012ded8e 100644 --- a/olive/cli/quantize.py +++ b/olive/cli/quantize.py @@ -17,7 +17,9 @@ add_input_model_options, add_logging_options, add_remote_options, + add_save_config_file_options, add_shared_cache_options, + save_config_file, update_dataset_options, update_input_model_options, update_shared_cache_options, @@ -74,6 +76,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_shared_cache_options(sub_parser) add_logging_options(sub_parser) + add_save_config_file_options(sub_parser) sub_parser.set_defaults(func=QuantizeCommand) def _get_run_config(self, tempdir: str) -> Dict[str, Any]: @@ -147,6 +150,8 @@ def run(self): with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: run_config = self._get_run_config(tempdir) + if self.args.generate_config_file: + save_config_file(run_config) olive_run(run_config) diff --git a/olive/cli/session_params_tuning.py b/olive/cli/session_params_tuning.py index 53c605cbe..7466ff6d9 100644 --- a/olive/cli/session_params_tuning.py +++ b/olive/cli/session_params_tuning.py @@ -15,9 +15,11 @@ add_input_model_options, add_logging_options, add_remote_options, + add_save_config_file_options, add_shared_cache_options, get_input_model_config, is_remote_run, + save_config_file, update_accelerator_options, update_remote_options, update_shared_cache_options, @@ -100,6 +102,7 @@ def register_subcommand(parser: ArgumentParser): add_accelerator_options(sub_parser, single_provider=False) add_remote_options(sub_parser) add_logging_options(sub_parser) + add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) sub_parser.set_defaults(func=SessionParamsTuningCommand) @@ -146,6 +149,8 @@ def run(self): with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: run_config = self.get_run_config(tempdir) + if self.args.generate_config_file: + save_config_file(run_config) output = olive_run(run_config) if is_remote_run(self.args): From a9cd6b7e078061ab7038343532bb6d9579e09c1a Mon Sep 17 00:00:00 2001 From: Xiaoyu Date: Thu, 23 Jan 2025 23:52:27 +0000 Subject: [PATCH 2/5] simplify logic --- olive/cli/auto_opt.py | 10 +------- olive/cli/base.py | 27 ++++++++++++++------ olive/cli/capture_onnx.py | 10 +------- olive/cli/finetune.py | 10 +------- olive/cli/generate_adapter.py | 10 +------- olive/cli/quantize.py | 10 +------- olive/cli/session_params_tuning.py | 40 ++++++++++++------------------ 7 files changed, 40 insertions(+), 77 deletions(-) diff --git a/olive/cli/auto_opt.py b/olive/cli/auto_opt.py index bc623e80c..9809f4fd3 100644 --- a/olive/cli/auto_opt.py +++ b/olive/cli/auto_opt.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -import tempfile from argparse import ArgumentParser from collections import OrderedDict from copy import deepcopy @@ -18,7 +17,6 @@ add_search_options, add_shared_cache_options, get_input_model_config, - save_config_file, update_accelerator_options, update_remote_options, update_search_options, @@ -177,13 +175,7 @@ def register_subcommand(parser: ArgumentParser): sub_parser.set_defaults(func=AutoOptCommand) def run(self): - from olive.workflows import run as olive_run - - with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: - run_config = self.get_run_config(tempdir) - if self.args.generate_config_file: - save_config_file(run_config) - olive_run(run_config) + self._run_workflow() def get_run_config(self, tempdir) -> Dict: config = deepcopy(TEMPLATE) diff --git a/olive/cli/base.py b/olive/cli/base.py index ed023e15c..fac9a0e67 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -30,6 +30,25 @@ def __init__(self, parser: ArgumentParser, args: Namespace, unknown_args: Option if unknown_args and not self.allow_unknown_args: parser.error(f"Unknown arguments: {unknown_args}") + def _run_workflow(self): + import tempfile + + from olive.workflows import run as olive_run + + with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: + run_config = self._get_run_config(tempdir) + if self.args.generate_config_file: + self._save_config_file(run_config) + return olive_run(run_config) + + @staticmethod + def _save_config_file(config: Dict): + """Save the config file.""" + config_file_path = Path(config["output_dir"]) / "config.json" + with open(config_file_path, "w") as f: + json.dump(config, f, indent=4) + print(f"Config file saved at {config_file_path}") + @staticmethod @abstractmethod def register_subcommand(parser: ArgumentParser): @@ -364,14 +383,6 @@ def add_input_model_options( return model_group -def save_config_file(config: Dict): - """Save the config file.""" - config_file_path = Path(config["output_dir"]) / "config.json" - with open(config_file_path, "w") as f: - json.dump(config, f, indent=4) - print(f"Config file saved at {config_file_path}") - - def output_path_type(path: str) -> str: """Resolve the output path and mkdir if it doesn't exist.""" path = Path(path).resolve() diff --git a/olive/cli/capture_onnx.py b/olive/cli/capture_onnx.py index eea06b406..43c7f0802 100644 --- a/olive/cli/capture_onnx.py +++ b/olive/cli/capture_onnx.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -import tempfile from argparse import ArgumentParser from copy import deepcopy from typing import Dict @@ -15,7 +14,6 @@ add_save_config_file_options, add_shared_cache_options, get_input_model_config, - save_config_file, update_remote_options, update_shared_cache_options, ) @@ -149,13 +147,7 @@ def register_subcommand(parser: ArgumentParser): sub_parser.set_defaults(func=CaptureOnnxGraphCommand) def run(self): - from olive.workflows import run as olive_run - - with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: - run_config = self.get_run_config(tempdir) - if self.args.generate_config_file: - save_config_file(run_config) - olive_run(run_config) + self._run_workflow() def get_run_config(self, tempdir: str) -> Dict: config = deepcopy(TEMPLATE) diff --git a/olive/cli/finetune.py b/olive/cli/finetune.py index 21791ba18..421610a33 100644 --- a/olive/cli/finetune.py +++ b/olive/cli/finetune.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -import tempfile from argparse import ArgumentParser from copy import deepcopy from typing import ClassVar, Dict @@ -16,7 +15,6 @@ add_save_config_file_options, add_shared_cache_options, get_input_model_config, - save_config_file, update_dataset_options, update_remote_options, update_shared_cache_options, @@ -82,13 +80,7 @@ def register_subcommand(parser: ArgumentParser): sub_parser.set_defaults(func=FineTuneCommand) def run(self): - from olive.workflows import run as olive_run - - with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: - run_config = self.get_run_config(tempdir) - if self.args.generate_config_file: - save_config_file(run_config) - olive_run(run_config) + self._run_workflow() def parse_training_args(self) -> Dict: if not self.unknown_args: diff --git a/olive/cli/generate_adapter.py b/olive/cli/generate_adapter.py index c318c5bfd..8ff750e43 100644 --- a/olive/cli/generate_adapter.py +++ b/olive/cli/generate_adapter.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -import tempfile from argparse import ArgumentParser from copy import deepcopy from typing import Dict @@ -15,7 +14,6 @@ add_save_config_file_options, add_shared_cache_options, get_input_model_config, - save_config_file, update_remote_options, update_shared_cache_options, ) @@ -47,13 +45,7 @@ def register_subcommand(parser: ArgumentParser): sub_parser.set_defaults(func=GenerateAdapterCommand) def run(self): - from olive.workflows import run as olive_run - - with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: - run_config = self.get_run_config(tempdir) - if self.args.generate_config_file: - save_config_file(run_config) - olive_run(run_config) + self._run_workflow() def get_run_config(self, tempdir: str) -> Dict: input_model_config = get_input_model_config(self.args) diff --git a/olive/cli/quantize.py b/olive/cli/quantize.py index 8012ded8e..ff823c09e 100644 --- a/olive/cli/quantize.py +++ b/olive/cli/quantize.py @@ -6,7 +6,6 @@ # ruff: noqa: T201 # ruff: noqa: RUF012 -import tempfile from argparse import ArgumentParser from copy import deepcopy from typing import Any, Dict @@ -19,7 +18,6 @@ add_remote_options, add_save_config_file_options, add_shared_cache_options, - save_config_file, update_dataset_options, update_input_model_options, update_shared_cache_options, @@ -143,16 +141,10 @@ def _get_run_config(self, tempdir: str) -> Dict[str, Any]: return config def run(self): - from olive.workflows import run as olive_run - if ("gptq" in self.args.algorithm) and (not self.args.data_name): raise ValueError("data_name is required to use gptq.") - with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: - run_config = self._get_run_config(tempdir) - if self.args.generate_config_file: - save_config_file(run_config) - olive_run(run_config) + self._run_workflow() TEMPLATE = { diff --git a/olive/cli/session_params_tuning.py b/olive/cli/session_params_tuning.py index 7466ff6d9..4d5b12342 100644 --- a/olive/cli/session_params_tuning.py +++ b/olive/cli/session_params_tuning.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import json -import tempfile from argparse import ArgumentParser from copy import deepcopy from pathlib import Path @@ -19,7 +18,6 @@ add_shared_cache_options, get_input_model_config, is_remote_run, - save_config_file, update_accelerator_options, update_remote_options, update_shared_cache_options, @@ -145,28 +143,22 @@ def get_run_config(self, tempdir) -> Dict: return config def run(self): - from olive.workflows import run as olive_run - - with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: - run_config = self.get_run_config(tempdir) - if self.args.generate_config_file: - save_config_file(run_config) - output = olive_run(run_config) - - if is_remote_run(self.args): - return - - output_path = Path(self.args.output_path).resolve() - for key, value in output.items(): - if len(value.nodes) < 1: - print(f"Tuning for {key} failed. Please set the log_level to 1 for more detailed logs.") - continue - - infer_setting_output_path = output_path / f"{key}.json" - infer_settings = value.get_model_inference_config(value.get_output_model_id()) - with infer_setting_output_path.open("w") as f: - json.dump(infer_settings, f, indent=4) - print(f"Inference session parameters are saved to {output_path}.") + output = self._run_workflow() + + if is_remote_run(self.args): + return + + output_path = Path(self.args.output_path).resolve() + for key, value in output.items(): + if len(value.nodes) < 1: + print(f"Tuning for {key} failed. Please set the log_level to 1 for more detailed logs.") + continue + + infer_setting_output_path = output_path / f"{key}.json" + infer_settings = value.get_model_inference_config(value.get_output_model_id()) + with infer_setting_output_path.open("w") as f: + json.dump(infer_settings, f, indent=4) + print(f"Inference session parameters are saved to {output_path}.") TEMPLATE = { From 12fb5fdf977fc51adfae00564c500d7bc3b91314 Mon Sep 17 00:00:00 2001 From: Xiaoyu Date: Fri, 24 Jan 2025 00:08:49 +0000 Subject: [PATCH 3/5] rename save_config_file --- olive/cli/auto_opt.py | 4 ++-- olive/cli/base.py | 8 ++++---- olive/cli/capture_onnx.py | 4 ++-- olive/cli/finetune.py | 4 ++-- olive/cli/generate_adapter.py | 4 ++-- olive/cli/quantize.py | 4 ++-- olive/cli/session_params_tuning.py | 4 ++-- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/olive/cli/auto_opt.py b/olive/cli/auto_opt.py index 9809f4fd3..6eb61c91b 100644 --- a/olive/cli/auto_opt.py +++ b/olive/cli/auto_opt.py @@ -13,7 +13,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file_options, + add_save_config_file, add_search_options, add_shared_cache_options, get_input_model_config, @@ -171,7 +171,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_shared_cache_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file_options(sub_parser) + add_save_config_file(sub_parser) sub_parser.set_defaults(func=AutoOptCommand) def run(self): diff --git a/olive/cli/base.py b/olive/cli/base.py index fac9a0e67..7702442bb 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -37,7 +37,7 @@ def _run_workflow(self): with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir: run_config = self._get_run_config(tempdir) - if self.args.generate_config_file: + if self.args.save_config_file: self._save_config_file(run_config) return olive_run(run_config) @@ -266,12 +266,12 @@ def add_logging_options(sub_parser: ArgumentParser): return sub_parser -def add_save_config_file_options(sub_parser: ArgumentParser): +def add_save_config_file(sub_parser: ArgumentParser): """Add save config file options to the sub_parser.""" sub_parser.add_argument( - "--generate_config_file", + "--save_config_file", action="store_true", - help="Generate a config file for the command.", + help="Generate and save the config file for the command.", ) return sub_parser diff --git a/olive/cli/capture_onnx.py b/olive/cli/capture_onnx.py index 43c7f0802..2ac5cbdbc 100644 --- a/olive/cli/capture_onnx.py +++ b/olive/cli/capture_onnx.py @@ -11,7 +11,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file_options, + add_save_config_file, add_shared_cache_options, get_input_model_config, update_remote_options, @@ -142,7 +142,7 @@ def register_subcommand(parser: ArgumentParser): # remote options add_remote_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file_options(sub_parser) + add_save_config_file(sub_parser) add_shared_cache_options(sub_parser) sub_parser.set_defaults(func=CaptureOnnxGraphCommand) diff --git a/olive/cli/finetune.py b/olive/cli/finetune.py index 421610a33..0f896bb9f 100644 --- a/olive/cli/finetune.py +++ b/olive/cli/finetune.py @@ -12,7 +12,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file_options, + add_save_config_file, add_shared_cache_options, get_input_model_config, update_dataset_options, @@ -76,7 +76,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_shared_cache_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file_options(sub_parser) + add_save_config_file(sub_parser) sub_parser.set_defaults(func=FineTuneCommand) def run(self): diff --git a/olive/cli/generate_adapter.py b/olive/cli/generate_adapter.py index 8ff750e43..c87a0a43b 100644 --- a/olive/cli/generate_adapter.py +++ b/olive/cli/generate_adapter.py @@ -11,7 +11,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file_options, + add_save_config_file, add_shared_cache_options, get_input_model_config, update_remote_options, @@ -40,7 +40,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file_options(sub_parser) + add_save_config_file(sub_parser) add_shared_cache_options(sub_parser) sub_parser.set_defaults(func=GenerateAdapterCommand) diff --git a/olive/cli/quantize.py b/olive/cli/quantize.py index ff823c09e..33877f207 100644 --- a/olive/cli/quantize.py +++ b/olive/cli/quantize.py @@ -16,7 +16,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file_options, + add_save_config_file, add_shared_cache_options, update_dataset_options, update_input_model_options, @@ -74,7 +74,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_shared_cache_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file_options(sub_parser) + add_save_config_file(sub_parser) sub_parser.set_defaults(func=QuantizeCommand) def _get_run_config(self, tempdir: str) -> Dict[str, Any]: diff --git a/olive/cli/session_params_tuning.py b/olive/cli/session_params_tuning.py index 4d5b12342..868b444e8 100644 --- a/olive/cli/session_params_tuning.py +++ b/olive/cli/session_params_tuning.py @@ -14,7 +14,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file_options, + add_save_config_file, add_shared_cache_options, get_input_model_config, is_remote_run, @@ -100,7 +100,7 @@ def register_subcommand(parser: ArgumentParser): add_accelerator_options(sub_parser, single_provider=False) add_remote_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file_options(sub_parser) + add_save_config_file(sub_parser) add_shared_cache_options(sub_parser) sub_parser.set_defaults(func=SessionParamsTuningCommand) From 740fc991e402ea1734ba7149c9a1275b25f767e0 Mon Sep 17 00:00:00 2001 From: Xiaoyu Date: Fri, 24 Jan 2025 00:22:08 +0000 Subject: [PATCH 4/5] fix nit --- olive/cli/auto_opt.py | 4 ++-- olive/cli/base.py | 2 +- olive/cli/capture_onnx.py | 4 ++-- olive/cli/finetune.py | 4 ++-- olive/cli/generate_adapter.py | 4 ++-- olive/cli/quantize.py | 4 ++-- olive/cli/session_params_tuning.py | 4 ++-- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/olive/cli/auto_opt.py b/olive/cli/auto_opt.py index 6eb61c91b..9809f4fd3 100644 --- a/olive/cli/auto_opt.py +++ b/olive/cli/auto_opt.py @@ -13,7 +13,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file, + add_save_config_file_options, add_search_options, add_shared_cache_options, get_input_model_config, @@ -171,7 +171,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_shared_cache_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file(sub_parser) + add_save_config_file_options(sub_parser) sub_parser.set_defaults(func=AutoOptCommand) def run(self): diff --git a/olive/cli/base.py b/olive/cli/base.py index 7702442bb..55f2fdaaf 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -266,7 +266,7 @@ def add_logging_options(sub_parser: ArgumentParser): return sub_parser -def add_save_config_file(sub_parser: ArgumentParser): +def add_save_config_file_options(sub_parser: ArgumentParser): """Add save config file options to the sub_parser.""" sub_parser.add_argument( "--save_config_file", diff --git a/olive/cli/capture_onnx.py b/olive/cli/capture_onnx.py index 2ac5cbdbc..43c7f0802 100644 --- a/olive/cli/capture_onnx.py +++ b/olive/cli/capture_onnx.py @@ -11,7 +11,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file, + add_save_config_file_options, add_shared_cache_options, get_input_model_config, update_remote_options, @@ -142,7 +142,7 @@ def register_subcommand(parser: ArgumentParser): # remote options add_remote_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file(sub_parser) + add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) sub_parser.set_defaults(func=CaptureOnnxGraphCommand) diff --git a/olive/cli/finetune.py b/olive/cli/finetune.py index 0f896bb9f..421610a33 100644 --- a/olive/cli/finetune.py +++ b/olive/cli/finetune.py @@ -12,7 +12,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file, + add_save_config_file_options, add_shared_cache_options, get_input_model_config, update_dataset_options, @@ -76,7 +76,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_shared_cache_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file(sub_parser) + add_save_config_file_options(sub_parser) sub_parser.set_defaults(func=FineTuneCommand) def run(self): diff --git a/olive/cli/generate_adapter.py b/olive/cli/generate_adapter.py index c87a0a43b..8ff750e43 100644 --- a/olive/cli/generate_adapter.py +++ b/olive/cli/generate_adapter.py @@ -11,7 +11,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file, + add_save_config_file_options, add_shared_cache_options, get_input_model_config, update_remote_options, @@ -40,7 +40,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file(sub_parser) + add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) sub_parser.set_defaults(func=GenerateAdapterCommand) diff --git a/olive/cli/quantize.py b/olive/cli/quantize.py index 33877f207..ff823c09e 100644 --- a/olive/cli/quantize.py +++ b/olive/cli/quantize.py @@ -16,7 +16,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file, + add_save_config_file_options, add_shared_cache_options, update_dataset_options, update_input_model_options, @@ -74,7 +74,7 @@ def register_subcommand(parser: ArgumentParser): add_remote_options(sub_parser) add_shared_cache_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file(sub_parser) + add_save_config_file_options(sub_parser) sub_parser.set_defaults(func=QuantizeCommand) def _get_run_config(self, tempdir: str) -> Dict[str, Any]: diff --git a/olive/cli/session_params_tuning.py b/olive/cli/session_params_tuning.py index 868b444e8..4d5b12342 100644 --- a/olive/cli/session_params_tuning.py +++ b/olive/cli/session_params_tuning.py @@ -14,7 +14,7 @@ add_input_model_options, add_logging_options, add_remote_options, - add_save_config_file, + add_save_config_file_options, add_shared_cache_options, get_input_model_config, is_remote_run, @@ -100,7 +100,7 @@ def register_subcommand(parser: ArgumentParser): add_accelerator_options(sub_parser, single_provider=False) add_remote_options(sub_parser) add_logging_options(sub_parser) - add_save_config_file(sub_parser) + add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) sub_parser.set_defaults(func=SessionParamsTuningCommand) From b5481fe6c8e22ec8ca88db97e6b8f99b689d1de9 Mon Sep 17 00:00:00 2001 From: Xiaoyu Date: Fri, 24 Jan 2025 00:51:15 +0000 Subject: [PATCH 5/5] Unify _get_run_config --- olive/cli/auto_opt.py | 2 +- olive/cli/capture_onnx.py | 2 +- olive/cli/finetune.py | 2 +- olive/cli/generate_adapter.py | 2 +- olive/cli/session_params_tuning.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/olive/cli/auto_opt.py b/olive/cli/auto_opt.py index 9809f4fd3..664fae4cc 100644 --- a/olive/cli/auto_opt.py +++ b/olive/cli/auto_opt.py @@ -177,7 +177,7 @@ def register_subcommand(parser: ArgumentParser): def run(self): self._run_workflow() - def get_run_config(self, tempdir) -> Dict: + def _get_run_config(self, tempdir) -> Dict: config = deepcopy(TEMPLATE) olive_config = OlivePackageConfig.load_default_config() diff --git a/olive/cli/capture_onnx.py b/olive/cli/capture_onnx.py index 43c7f0802..f8fd78dad 100644 --- a/olive/cli/capture_onnx.py +++ b/olive/cli/capture_onnx.py @@ -149,7 +149,7 @@ def register_subcommand(parser: ArgumentParser): def run(self): self._run_workflow() - def get_run_config(self, tempdir: str) -> Dict: + def _get_run_config(self, tempdir: str) -> Dict: config = deepcopy(TEMPLATE) input_model_config = get_input_model_config(self.args) diff --git a/olive/cli/finetune.py b/olive/cli/finetune.py index 421610a33..792c7caf8 100644 --- a/olive/cli/finetune.py +++ b/olive/cli/finetune.py @@ -97,7 +97,7 @@ def parse_training_args(self) -> Dict: return {k: v for k, v in vars(training_args).items() if k in arg_keys} - def get_run_config(self, tempdir: str) -> Dict: + def _get_run_config(self, tempdir: str) -> Dict: input_model_config = get_input_model_config(self.args) assert input_model_config["type"].lower() == "hfmodel", "Only HfModel is supported in finetune command." diff --git a/olive/cli/generate_adapter.py b/olive/cli/generate_adapter.py index 8ff750e43..f28555c53 100644 --- a/olive/cli/generate_adapter.py +++ b/olive/cli/generate_adapter.py @@ -47,7 +47,7 @@ def register_subcommand(parser: ArgumentParser): def run(self): self._run_workflow() - def get_run_config(self, tempdir: str) -> Dict: + def _get_run_config(self, tempdir: str) -> Dict: input_model_config = get_input_model_config(self.args) assert ( input_model_config["type"].lower() == "onnxmodel" diff --git a/olive/cli/session_params_tuning.py b/olive/cli/session_params_tuning.py index 4d5b12342..147a64292 100644 --- a/olive/cli/session_params_tuning.py +++ b/olive/cli/session_params_tuning.py @@ -121,7 +121,7 @@ def _update_pass_config(self, default_pass_config) -> Dict: pass_config.update({k: args_dict[k] for k in pass_config_keys if args_dict[k] is not None}) return pass_config - def get_run_config(self, tempdir) -> Dict: + def _get_run_config(self, tempdir) -> Dict: config = deepcopy(TEMPLATE) session_params_tuning_key = ("passes", "session_params_tuning")