
Commit

add flag for toggling vulkan validation layers (#624)
* add vulkan_validation_layers flag

* categorize SD flags

* stringify true and false for flag
PhaneeshB authored Dec 16, 2022
1 parent a14c53a commit 7345733
Showing 2 changed files with 51 additions and 28 deletions.
78 changes: 50 additions & 28 deletions shark/examples/shark_inference/stable_diffusion/stable_args.py
@@ -4,6 +4,10 @@
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

##############################################################################
### Stable Diffusion Params
##############################################################################

p.add_argument(
"--prompts",
nargs="+",
@@ -18,24 +22,13 @@
help="text you don't want to see in the generated image.",
)

p.add_argument(
"--device", type=str, default="cpu", help="device to run the model."
)

p.add_argument(
"--steps",
type=int,
default=50,
help="the no. of steps to do the sampling.",
)

p.add_argument(
"--version",
type=str,
default="v2.1base",
help="Specify version of stable diffusion model",
)

p.add_argument(
"--seed",
type=int,
@@ -51,21 +44,36 @@
)

p.add_argument(
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
"--max_length",
type=int,
default=77,
help="max length of the tokenizer output.",
)

##############################################################################
### Model Config and Usage Params
##############################################################################

p.add_argument(
"--device", type=str, default="cpu", help="device to run the model."
)

p.add_argument(
"--version",
type=str,
default="v2.1base",
help="Specify version of stable diffusion model",
)

p.add_argument(
"--precision", type=str, default="fp16", help="precision to run the model."
)

p.add_argument(
"--max_length",
type=int,
default=77,
help="max length of the tokenizer output.",
"--import_mlir",
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)

p.add_argument(
@@ -82,6 +90,17 @@
help="saves the compiled flatbuffer to the local directory",
)

p.add_argument(
"--use_tuned",
default=False,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)

##############################################################################
### IREE - Vulkan supported flags
##############################################################################

p.add_argument(
"--iree-vulkan-target-triple",
type=str,
@@ -97,12 +116,22 @@
)

p.add_argument(
"--use_tuned",
"--vulkan_large_heap_block_size",
default="4294967296",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)

p.add_argument(
"--vulkan_validation_layers",
default=False,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
help="flag for disabling vulkan validation layers when benchmarking",
)

##############################################################################
### Misc. Debug and Optimization flags
##############################################################################

p.add_argument(
"--local_tank_cache",
default="",
@@ -128,17 +157,10 @@
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
)

p.add_argument(
"--vulkan_large_heap_block_size",
default="4294967296",
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
)

p.add_argument(
"--enable_rgp",
default=False,
action=argparse.BooleanOptionalAction,
help="flag for inserting debug frames between iterations for use with rgp.",
)

args = p.parse_args()
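
Because the new option uses argparse.BooleanOptionalAction, argparse also generates a matching --no-... form on the command line. A minimal, self-contained sketch of that behavior (standalone parser for illustration only, not code from this repository; BooleanOptionalAction needs Python 3.9+):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--vulkan_validation_layers",
    default=False,
    action=argparse.BooleanOptionalAction,
)

# Default: validation layers stay disabled.
print(parser.parse_args([]).vulkan_validation_layers)  # False
# Enable explicitly.
print(parser.parse_args(["--vulkan_validation_layers"]).vulkan_validation_layers)  # True
# BooleanOptionalAction adds the negated form automatically.
print(parser.parse_args(["--no-vulkan_validation_layers"]).vulkan_validation_layers)  # False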
1 change: 1 addition & 0 deletions shark/examples/shark_inference/stable_diffusion/utils.py
@@ -75,6 +75,7 @@ def set_iree_runtime_flags():

vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
vulkan_runtime_flags += [

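The "stringify true and false" part of the commit shows up in the utils.py line above: the parsed Python bool is rendered as lowercase true/false in the flag string handed to the Vulkan runtime, since str(args.vulkan_validation_layers) would yield "True"/"False". A small sketch of the same pattern, using an illustrative helper name rather than the repository's set_iree_runtime_flags:

def build_vulkan_runtime_flags(validation_layers: bool, large_heap_block_size: str = "4294967296"):
    # The conditional expression produces lowercase "true"/"false", which the
    # runtime flag presumably expects, instead of Python's "True"/"False".
    return [
        f"--vulkan_large_heap_block_size={large_heap_block_size}",
        f"--vulkan_validation_layers={'true' if validation_layers else 'false'}",
    ]

print(build_vulkan_runtime_flags(False))
# ['--vulkan_large_heap_block_size=4294967296', '--vulkan_validation_layers=false']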