Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generate_config_file option to cli #1568

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions olive/cli/auto_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import tempfile
from argparse import ArgumentParser
from collections import OrderedDict
from copy import deepcopy
Expand All @@ -14,6 +13,7 @@
add_input_model_options,
add_logging_options,
add_remote_options,
add_save_config_file_options,
add_search_options,
add_shared_cache_options,
get_input_model_config,
Expand Down Expand Up @@ -171,16 +171,13 @@ def register_subcommand(parser: ArgumentParser):
add_remote_options(sub_parser)
add_shared_cache_options(sub_parser)
add_logging_options(sub_parser)
add_save_config_file_options(sub_parser)
sub_parser.set_defaults(func=AutoOptCommand)

def run(self):
from olive.workflows import run as olive_run
self._run_workflow()

with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir:
run_config = self.get_run_config(tempdir)
olive_run(run_config)

def get_run_config(self, tempdir) -> Dict:
def _get_run_config(self, tempdir) -> Dict:
config = deepcopy(TEMPLATE)
olive_config = OlivePackageConfig.load_default_config()

Expand Down
29 changes: 29 additions & 0 deletions olive/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,25 @@ def __init__(self, parser: ArgumentParser, args: Namespace, unknown_args: Option
if unknown_args and not self.allow_unknown_args:
parser.error(f"Unknown arguments: {unknown_args}")

def _run_workflow(self):
import tempfile

from olive.workflows import run as olive_run

with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir:
run_config = self._get_run_config(tempdir)
if self.args.save_config_file:
self._save_config_file(run_config)
return olive_run(run_config)

@staticmethod
def _save_config_file(config: Dict):
"""Save the config file."""
config_file_path = Path(config["output_dir"]) / "config.json"
with open(config_file_path, "w") as f:
json.dump(config, f, indent=4)
print(f"Config file saved at {config_file_path}")

@staticmethod
@abstractmethod
def register_subcommand(parser: ArgumentParser):
Expand Down Expand Up @@ -247,6 +266,16 @@ def add_logging_options(sub_parser: ArgumentParser):
return sub_parser


def add_save_config_file_options(sub_parser: ArgumentParser):
"""Add save config file options to the sub_parser."""
sub_parser.add_argument(
"--save_config_file",
action="store_true",
xiaoyu-work marked this conversation as resolved.
Show resolved Hide resolved
help="Generate and save the config file for the command.",
)
return sub_parser


def add_remote_options(sub_parser: ArgumentParser):
"""Add remote options to the sub_parser."""
remote_group = sub_parser
Expand Down
11 changes: 4 additions & 7 deletions olive/cli/capture_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import tempfile
from argparse import ArgumentParser
from copy import deepcopy
from typing import Dict
Expand All @@ -12,6 +11,7 @@
add_input_model_options,
add_logging_options,
add_remote_options,
add_save_config_file_options,
add_shared_cache_options,
get_input_model_config,
update_remote_options,
Expand Down Expand Up @@ -142,17 +142,14 @@ def register_subcommand(parser: ArgumentParser):
# remote options
add_remote_options(sub_parser)
add_logging_options(sub_parser)
add_save_config_file_options(sub_parser)
add_shared_cache_options(sub_parser)
sub_parser.set_defaults(func=CaptureOnnxGraphCommand)

def run(self):
from olive.workflows import run as olive_run
self._run_workflow()

with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir:
run_config = self.get_run_config(tempdir)
olive_run(run_config)

def get_run_config(self, tempdir: str) -> Dict:
def _get_run_config(self, tempdir: str) -> Dict:
config = deepcopy(TEMPLATE)

input_model_config = get_input_model_config(self.args)
Expand Down
11 changes: 4 additions & 7 deletions olive/cli/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import tempfile
from argparse import ArgumentParser
from copy import deepcopy
from typing import ClassVar, Dict
Expand All @@ -13,6 +12,7 @@
add_input_model_options,
add_logging_options,
add_remote_options,
add_save_config_file_options,
add_shared_cache_options,
get_input_model_config,
update_dataset_options,
Expand Down Expand Up @@ -76,14 +76,11 @@ def register_subcommand(parser: ArgumentParser):
add_remote_options(sub_parser)
add_shared_cache_options(sub_parser)
add_logging_options(sub_parser)
add_save_config_file_options(sub_parser)
sub_parser.set_defaults(func=FineTuneCommand)

def run(self):
from olive.workflows import run as olive_run

with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir:
run_config = self.get_run_config(tempdir)
olive_run(run_config)
self._run_workflow()

def parse_training_args(self) -> Dict:
if not self.unknown_args:
Expand All @@ -100,7 +97,7 @@ def parse_training_args(self) -> Dict:

return {k: v for k, v in vars(training_args).items() if k in arg_keys}

def get_run_config(self, tempdir: str) -> Dict:
def _get_run_config(self, tempdir: str) -> Dict:
input_model_config = get_input_model_config(self.args)
assert input_model_config["type"].lower() == "hfmodel", "Only HfModel is supported in finetune command."

Expand Down
11 changes: 4 additions & 7 deletions olive/cli/generate_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import tempfile
from argparse import ArgumentParser
from copy import deepcopy
from typing import Dict
Expand All @@ -12,6 +11,7 @@
add_input_model_options,
add_logging_options,
add_remote_options,
add_save_config_file_options,
add_shared_cache_options,
get_input_model_config,
update_remote_options,
Expand Down Expand Up @@ -40,17 +40,14 @@ def register_subcommand(parser: ArgumentParser):

add_remote_options(sub_parser)
add_logging_options(sub_parser)
add_save_config_file_options(sub_parser)
add_shared_cache_options(sub_parser)
sub_parser.set_defaults(func=GenerateAdapterCommand)

def run(self):
from olive.workflows import run as olive_run
self._run_workflow()

with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir:
run_config = self.get_run_config(tempdir)
olive_run(run_config)

def get_run_config(self, tempdir: str) -> Dict:
def _get_run_config(self, tempdir: str) -> Dict:
input_model_config = get_input_model_config(self.args)
assert (
input_model_config["type"].lower() == "onnxmodel"
Expand Down
9 changes: 3 additions & 6 deletions olive/cli/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# ruff: noqa: T201
# ruff: noqa: RUF012

import tempfile
from argparse import ArgumentParser
from copy import deepcopy
from typing import Any, Dict
Expand All @@ -17,6 +16,7 @@
add_input_model_options,
add_logging_options,
add_remote_options,
add_save_config_file_options,
add_shared_cache_options,
update_dataset_options,
update_input_model_options,
Expand Down Expand Up @@ -74,6 +74,7 @@ def register_subcommand(parser: ArgumentParser):
add_remote_options(sub_parser)
add_shared_cache_options(sub_parser)
add_logging_options(sub_parser)
add_save_config_file_options(sub_parser)
sub_parser.set_defaults(func=QuantizeCommand)

def _get_run_config(self, tempdir: str) -> Dict[str, Any]:
Expand Down Expand Up @@ -140,14 +141,10 @@ def _get_run_config(self, tempdir: str) -> Dict[str, Any]:
return config

def run(self):
from olive.workflows import run as olive_run

if ("gptq" in self.args.algorithm) and (not self.args.data_name):
raise ValueError("data_name is required to use gptq.")

with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir:
run_config = self._get_run_config(tempdir)
olive_run(run_config)
self._run_workflow()


TEMPLATE = {
Expand Down
41 changes: 19 additions & 22 deletions olive/cli/session_params_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import json
import tempfile
from argparse import ArgumentParser
from copy import deepcopy
from pathlib import Path
Expand All @@ -15,6 +14,7 @@
add_input_model_options,
add_logging_options,
add_remote_options,
add_save_config_file_options,
add_shared_cache_options,
get_input_model_config,
is_remote_run,
Expand Down Expand Up @@ -100,6 +100,7 @@ def register_subcommand(parser: ArgumentParser):
add_accelerator_options(sub_parser, single_provider=False)
add_remote_options(sub_parser)
add_logging_options(sub_parser)
add_save_config_file_options(sub_parser)
add_shared_cache_options(sub_parser)
sub_parser.set_defaults(func=SessionParamsTuningCommand)

Expand All @@ -120,7 +121,7 @@ def _update_pass_config(self, default_pass_config) -> Dict:
pass_config.update({k: args_dict[k] for k in pass_config_keys if args_dict[k] is not None})
return pass_config

def get_run_config(self, tempdir) -> Dict:
def _get_run_config(self, tempdir) -> Dict:
config = deepcopy(TEMPLATE)

session_params_tuning_key = ("passes", "session_params_tuning")
Expand All @@ -142,26 +143,22 @@ def get_run_config(self, tempdir) -> Dict:
return config

def run(self):
from olive.workflows import run as olive_run

with tempfile.TemporaryDirectory(prefix="olive-cli-tmp-", dir=self.args.output_path) as tempdir:
run_config = self.get_run_config(tempdir)
output = olive_run(run_config)

if is_remote_run(self.args):
return

output_path = Path(self.args.output_path).resolve()
for key, value in output.items():
if len(value.nodes) < 1:
print(f"Tuning for {key} failed. Please set the log_level to 1 for more detailed logs.")
continue

infer_setting_output_path = output_path / f"{key}.json"
infer_settings = value.get_model_inference_config(value.get_output_model_id())
with infer_setting_output_path.open("w") as f:
json.dump(infer_settings, f, indent=4)
print(f"Inference session parameters are saved to {output_path}.")
output = self._run_workflow()

if is_remote_run(self.args):
return

output_path = Path(self.args.output_path).resolve()
for key, value in output.items():
if len(value.nodes) < 1:
print(f"Tuning for {key} failed. Please set the log_level to 1 for more detailed logs.")
continue

infer_setting_output_path = output_path / f"{key}.json"
infer_settings = value.get_model_inference_config(value.get_output_model_id())
with infer_setting_output_path.open("w") as f:
json.dump(infer_settings, f, indent=4)
print(f"Inference session parameters are saved to {output_path}.")


TEMPLATE = {
Expand Down
Loading