Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(SD) Add benchmark option and add a printer. #773

Merged
merged 5 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 64 additions & 15 deletions models/turbine_models/custom_models/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,15 @@ class PipelineComponent:
This aims to make new pipelines and execution modes easier to write, manage, and debug.
"""

def __init__(self, dest_type="devicearray", dest_dtype="float16"):
def __init__(
self, printer, dest_type="devicearray", dest_dtype="float16", benchmark=False
):
self.runner = None
self.module_name = None
self.device = None
self.metadata = None
self.benchmark = False
self.printer = printer
self.benchmark = benchmark
self.dest_type = dest_type
self.dest_dtype = dest_dtype

Expand All @@ -101,7 +104,7 @@ def load(
extra_plugin=None,
):
self.module_name = module_name
print(
self.printer.print(
f"Loading {module_name} from {vmfb_path} with external weights: {external_weight_path}."
)
self.runner = vmfbRunner(
Expand Down Expand Up @@ -222,7 +225,9 @@ def _run_and_benchmark(self, function_name, inputs: list):
start_time = time.time()
output = self._run(function_name, inputs)
latency = time.time() - start_time
print(f"Latency for {self.module_name}['{function_name}']: {latency}sec")
self.printer.print(
f"Latency for {self.module_name}['{function_name}']: {latency}sec"
)
return output

def __call__(self, function_name, inputs: list):
Expand All @@ -238,6 +243,41 @@ def __call__(self, function_name, inputs: list):
return output


class Printer:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to use this instead of just import logging and use that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can set it up as a logger, I used this since we had it setup nicely for tresleches full_runner.py

def __init__(self, verbose, start_time, print_time):
"""
verbose: 0 for silence, 1 for print
start_time: time of construction (or reset) of this Printer
last_print: time of last call to 'print' method
print_time: 1 to print with time prefix, 0 to not
"""
self.verbose = verbose
self.start_time = start_time
self.last_print = start_time
self.print_time = print_time

def reset(self):
if self.print_time:
if self.verbose:
self.print("Will now reset clock for printer to 0.0 [s].")
self.last_print = time.time()
self.start_time = time.time()
if self.verbose:
self.print("Clock for printer reset to t = 0.0 [s].")

def print(self, message):
if self.verbose:
# Print something like "[t=0.123 dt=0.004] 'message'"
if self.print_time:
time_now = time.time()
print(
f"[t={time_now - self.start_time:.3f} dt={time_now - self.last_print:.3f}] {message}"
)
self.last_print = time_now
else:
print(f"{message}")


class TurbinePipelineBase:
"""
This class is a lightweight base for Stable Diffusion
Expand Down Expand Up @@ -298,9 +338,13 @@ def __init__(
pipeline_dir: str = "./shark_vmfbs",
external_weights_dir: str = "./shark_weights",
hf_model_name: str | dict[str] = None,
benchmark: bool | dict[bool] = False,
verbose: bool = False,
common_export_args: dict = {},
):
self.map = model_map
self.verbose = verbose
self.printer = Printer(self.verbose, time.time(), True)
if isinstance(device, dict):
assert isinstance(
target, dict
Expand Down Expand Up @@ -329,6 +373,7 @@ def __init__(
"decomp_attn": decomp_attn,
"external_weights": external_weights,
"hf_model_name": hf_model_name,
"benchmark": benchmark,
}
for arg in map_arguments.keys():
self.map = merge_arg_into_map(self.map, map_arguments[arg], arg)
Expand Down Expand Up @@ -396,7 +441,7 @@ def prepare_all(
ready = self.is_prepared(vmfbs, weights)
match ready:
case True:
print("All necessary files found.")
self.printer.print("All necessary files found.")
return
case False:
if interactive:
Expand All @@ -407,7 +452,7 @@ def prepare_all(
exit()
for submodel in self.map.keys():
if not self.map[submodel].get("vmfb"):
print("Fetching: ", submodel)
self.printer.print("Fetching: ", submodel)
self.export_submodel(
submodel, input_mlir=self.map[submodel].get("mlir")
)
Expand Down Expand Up @@ -456,8 +501,6 @@ def is_prepared(self, vmfbs, weights):
mlir_keywords.remove(kw)
avail_files = os.listdir(pipeline_dir)
candidates = []
# print("MLIR KEYS: ", mlir_keywords)
# print("AVAILABLE FILES: ", avail_files)
for filename in avail_files:
if all(str(x) in filename for x in keywords) and not any(
x in filename for x in neg_keywords
Expand All @@ -470,8 +513,8 @@ def is_prepared(self, vmfbs, weights):
if len(candidates) == 1:
self.map[key]["vmfb"] = candidates[0]
elif len(candidates) > 1:
print(f"Multiple files found for {key}: {candidates}")
print(f"Choosing {candidates[0]} for {key}.")
self.printer.print(f"Multiple files found for {key}: {candidates}")
self.printer.print(f"Choosing {candidates[0]} for {key}.")
self.map[key]["vmfb"] = candidates[0]
else:
# vmfb not found in pipeline_dir. Add to list of files to generate.
Expand Down Expand Up @@ -503,16 +546,18 @@ def is_prepared(self, vmfbs, weights):
if len(candidates) == 1:
self.map[key]["weights"] = candidates[0]
elif len(candidates) > 1:
print(f"Multiple weight files found for {key}: {candidates}")
print(f"Choosing {candidates[0]} for {key}.")
self.printer.print(
f"Multiple weight files found for {key}: {candidates}"
)
self.printer.print(f"Choosing {candidates[0]} for {key}.")
self.map[key][weights] = candidates[0]
elif self.map[key].get("external_weights"):
# weights not found in external_weights_dir. Add to list of files to generate.
missing[key].append("weights")
if not any(x for x in missing.values()):
ready = True
else:
print("Missing files: ", missing)
self.printer.print("Missing files: ", missing)
ready = False
return ready

Expand Down Expand Up @@ -678,7 +723,7 @@ def export_submodel(
def load_map(self):
for submodel in self.map.keys():
if not self.map[submodel]["load"]:
print("Skipping load for ", submodel)
self.printer.print("Skipping load for ", submodel)
continue
self.load_submodel(submodel)

Expand All @@ -690,7 +735,11 @@ def load_submodel(self, submodel):
):
raise ValueError(f"Weights not found for {submodel}.")
dest_type = self.map[submodel].get("dest_type", "devicearray")
self.map[submodel]["runner"] = PipelineComponent(dest_type=dest_type)
self.map[submodel]["runner"] = PipelineComponent(
printer=self.printer,
dest_type=dest_type,
benchmark=self.map[submodel].get("benchmark", False),
)
self.map[submodel]["runner"].load(
self.map[submodel]["driver"],
self.map[submodel]["vmfb"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ def is_valid_file(arg):
help="Run scheduling on native pytorch CPU backend.",
)

p.add_argument(
"--benchmark",
type=str,
default=None,
help="A comma-separated list of submodel IDs for which to report benchmarks for, or 'all' for all components.",
)

##############################################################################
# SDXL Modelling Options
# These options are used to control model defining parameters for SDXL.
Expand Down Expand Up @@ -198,6 +205,7 @@ def is_valid_file(arg):

p.add_argument("--compile_to", type=str, default="mlir", help="torch, linalg, vmfb")

p.add_argument("--verbose", "-v", action="store_true")
p.add_argument(
"--external_weights",
type=str,
Expand Down
23 changes: 21 additions & 2 deletions models/turbine_models/custom_models/sd_inference/sd_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ def __init__(
scheduler_id: str = None, # compatibility only
shift: float = 1.0, # compatibility only
use_i8_punet: bool = False,
benchmark: bool | dict[bool] = False,
verbose: bool = False,
batch_prompts: bool = False,
):
common_export_args = {
"hf_model_name": None,
Expand Down Expand Up @@ -276,6 +279,8 @@ def __init__(
pipeline_dir,
external_weights_dir,
hf_model_name,
benchmark,
verbose,
common_export_args,
)
for submodel in sd_model_map:
Expand Down Expand Up @@ -329,6 +334,7 @@ def __init__(
self.base_model_name, subfolder="tokenizer_2"
),
]
self.map["text_encoder"]["export_args"]["batch_input"] = batch_prompts
self.latents_precision = self.map["unet"]["precision"]
self.scheduler_device = self.map["unet"]["device"]
self.scheduler_driver = self.map["unet"]["driver"]
Expand Down Expand Up @@ -559,7 +565,10 @@ def _produce_latents_sdxl(
[guidance_scale],
dtype=self.map["unet"]["np_dtype"],
)
for i, t in tqdm(enumerate(timesteps)):
for i, t in tqdm(
enumerate(timesteps),
disable=(self.map["unet"].get("benchmark") and self.verbose),
):
if self.cpu_scheduling:
latent_model_input, t = self.scheduler.scale_model_input(
latents,
Expand All @@ -571,7 +580,6 @@ def _produce_latents_sdxl(
latent_model_input, t = self.scheduler(
"run_scale", [latents, step, timesteps]
)

unet_inputs = [
latent_model_input,
t,
Expand Down Expand Up @@ -703,6 +711,15 @@ def numpy_to_pil_image(images):
}
if not args.pipeline_dir:
args.pipeline_dir = utils.create_safe_name(args.hf_model_name, "")
benchmark = {}
if args.benchmark:
if args.benchmark.lower() == "all":
benchmark = True
else:
for i in args.benchmark.split(","):
benchmark[i] = True
else:
benchmark = False
if any(x for x in [args.vae_decomp_attn, args.unet_decomp_attn]):
args.decomp_attn = {
"text_encoder": args.decomp_attn,
Expand Down Expand Up @@ -731,6 +748,8 @@ def numpy_to_pil_image(images):
args.scheduler_id,
None,
args.use_i8_punet,
benchmark,
args.verbose,
)
sd_pipe.prepare_all()
sd_pipe.load_map()
Expand Down
Loading