Skip to content

Commit

Permalink
Flag updates and parametrize a few more args.
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Aug 13, 2024
1 parent e554da8 commit d23a45b
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 22 deletions.
1 change: 0 additions & 1 deletion models/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ diffusers @ git+https://github.com/nod-ai/[email protected]
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
# turbine tank downloading/uploading
azure-storage-blob
# microsoft/phi model
einops
pytest
scipy
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def __init__(
target: str | dict[str],
ireec_flags: str | dict[str] = None,
precision: str | dict[str] = "fp16",
td_spec: str | dict[str] = None,
attn_spec: str | dict[str] = None,
decomp_attn: bool | dict[bool] = False,
external_weights: str | dict[str] = None,
pipeline_dir: str = "./shark_vmfbs",
Expand Down Expand Up @@ -396,7 +396,7 @@ def __init__(
map_arguments = {
"ireec_flags": ireec_flags,
"precision": precision,
"td_spec": td_spec,
"attn_spec": attn_spec,
"decomp_attn": decomp_attn,
"external_weights": external_weights,
"hf_model_name": hf_model_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ class TextEncoderModule(torch.nn.Module):
@torch.no_grad()
def __init__(
self,
precision,
):
super().__init__()
self.dtype = torch.float16
self.dtype = torch.float16 if precision == "fp16" else torch.float32
self.clip_l = SDClipModel(
layer="hidden",
layer_idx=-2,
Expand All @@ -65,21 +66,25 @@ def __init__(
layer_norm_hidden_state=False,
return_projected_pooled=False,
textmodel_json_config=CLIPL_CONFIG,
).half()
)
if precision == "fp16":
self.clip_l = self.clip_l.half()
clip_l_weights = hf_hub_download(
repo_id="stabilityai/stable-diffusion-3-medium",
filename="text_encoders/clip_l.safetensors",
)
with safe_open(clip_l_weights, framework="pt", device="cpu") as f:
load_into(f, self.clip_l.transformer, "", "cpu", self.dtype)
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype).half()
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype)
if precision == "fp16":
self.clip_l = self.clip_g.half()
clip_g_weights = hf_hub_download(
repo_id="stabilityai/stable-diffusion-3-medium",
filename="text_encoders/clip_g.safetensors",
)
with safe_open(clip_g_weights, framework="pt", device="cpu") as f:
load_into(f, self.clip_g.transformer, "", "cpu", self.dtype)
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=self.dtype).half()
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float16)
t5_weights = hf_hub_download(
repo_id="stabilityai/stable-diffusion-3-medium",
filename="text_encoders/t5xxl_fp16.safetensors",
Expand Down Expand Up @@ -150,7 +155,7 @@ def export_text_encoders(
attn_spec=attn_spec,
)
return vmfb_path
model = TextEncoderModule(hf_model_name)
model = TextEncoderModule(precision)
mapper = {}

assert (
Expand Down
32 changes: 29 additions & 3 deletions models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,21 +188,47 @@ def is_valid_file(arg):
"--unet_precision",
type=str,
default=None,
help="Precision of CLIP weights and graph.",
help="Precision of UNet weights and graph.",
)
p.add_argument(
"--mmdit_precision",
type=str,
default=None,
help="Precision of CLIP weights and graph.",
help="Precision of mmdit weights and graph.",
)
p.add_argument(
"--vae_precision",
type=str,
default=None,
help="Precision of CLIP weights and graph.",
help="Precision of vae weights and graph.",
)

p.add_argument(
"--clip_spec",
type=str,
default=None,
help="transform dialect spec for the given submodel.",
)
p.add_argument(
"--unet_spec",
type=str,
default=None,
help="transform dialect spec for the given submodel.",
)
p.add_argument(
"--mmdit_spec",
type=str,
default=None,
help="transform dialect spec for the given submodel.",
)
p.add_argument(
"--vae_spec",
type=str,
default=None,
help="transform dialect spec for the given submodel.",
)


p.add_argument(
"--max_length", type=int, default=64, help="Sequence Length of Stable Diffusion"
)
Expand Down
10 changes: 8 additions & 2 deletions models/turbine_models/custom_models/sd_inference/sd_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,12 @@ def numpy_to_pil_image(images):
"mmdit": args.mmdit_precision if args.mmdit_precision else args.precision,
"vae": args.vae_precision if args.vae_precision else args.precision,
}
specs = {
"text_encoder": args.clip_spec if args.clip_spec else args.attn_spec,
"unet": args.unet_spec if args.unet_spec else args.attn_spec,
"mmdit": args.mmdit_spec if args.mmdit_spec else args.attn_spec,
"vae": args.vae_spec if args.vae_spec else args.attn_spec,
}
if not args.pipeline_dir:
args.pipeline_dir = utils.create_safe_name(args.hf_model_name, "")
benchmark = {}
Expand Down Expand Up @@ -908,11 +914,11 @@ def numpy_to_pil_image(images):
args.width,
args.batch_size,
args.max_length,
args.precision,
precisions,
devices,
targets,
ireec_flags,
args.attn_spec,
specs,
args.decomp_attn,
args.pipeline_dir,
args.external_weights_dir,
Expand Down
18 changes: 9 additions & 9 deletions models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@
"--iree-execution-model=async-external",
],
"masked_attention": [
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics, util.func(iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))",
],
"punet": [
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
],
"vae_preprocess": [
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
],
"preprocess_default": [
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)",
],
"unet": [
"--iree-flow-enable-aggressive-fusion",
Expand All @@ -44,7 +44,7 @@
],
"clip": [
"--iree-flow-enable-aggressive-fusion",
"--iree-global-opt-enable-fuse-horizontal-contractions=true",
"--iree-flow-enable-fuse-horizontal-contractions=true",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-opt-outer-dim-concat=true",
"--iree-rocm-waves-per-eu=2",
Expand All @@ -71,19 +71,19 @@
"--iree-opt-const-eval=false",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-flow-enable-aggressive-fusion",
"--iree-global-opt-enable-fuse-horizontal-contractions=true",
"--iree-flow-enable-fuse-horizontal-contractions=true",
"--iree-codegen-gpu-native-math-precision=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
],
"masked_attention": [
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics, util.func(iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))",
],
"punet": [
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
],
"preprocess_default": [
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics)",
],
"unet": [""],
"clip": [""],
Expand Down

0 comments on commit d23a45b

Please sign in to comment.