Skip to content

Commit

Permalink
Fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Feb 12, 2024
1 parent 08178bd commit e6e118e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
8 changes: 6 additions & 2 deletions models/turbine_models/custom_models/sdxl_inference/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,12 @@ def main(self, inp=AbstractTensor(1, max_length, dtype=torch.int64)):
args.iree_target_triple,
args.vulkan_max_allocation,
)
safe_name_1 = safe_name = utils.create_safe_name(args.hf_model_name, f"_{str(args.max_length)}_clip_1")
safe_name_2 = safe_name = utils.create_safe_name(args.hf_model_name, f"_{str(args.max_length)}_clip_2")
safe_name_1 = safe_name = utils.create_safe_name(
args.hf_model_name, f"_{str(args.max_length)}_clip_1"
)
safe_name_2 = safe_name = utils.create_safe_name(
args.hf_model_name, f"_{str(args.max_length)}_clip_2"
)
with open(f"{safe_name_1}.mlir", "w+") as f:
f.write(mod_1_str)
print("Saved to", safe_name_1 + ".mlir")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
parser.add_argument(
"--precision", type=str, default="fp32", help="Precision of Stable Diffusion"
)
parser.add_argument("--max_length", type=int, default=77, help="Max input length of Stable Diffusion")
parser.add_argument(
"--max_length", type=int, default=77, help="Max input length of Stable Diffusion"
)


def run_unet(
Expand Down
12 changes: 9 additions & 3 deletions models/turbine_models/tests/sdxl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@
vae_model = vae.VaeModel(
# This is a public model, so no auth required
arguments["hf_model_name"],
custom_vae="madebyollin/sdxl-vae-fp16-fix" if arguments.precision == "fp16" else None,
custom_vae="madebyollin/sdxl-vae-fp16-fix"
if arguments.precision == "fp16"
else None,
)


Expand All @@ -73,8 +75,12 @@ def test01_ExportClipModels(self):
f"{arguments['safe_model_name']}" + "_clip",
"cpu",
)
assert os.path.exists(f"{arguments['safe_model_name']}_{str(arguments['max_length'])}_clip_1.vmfb")
assert os.path.exists(f"{arguments['safe_model_name']}_{str(arguments['max_length'])}_clip_2.vmfb")
assert os.path.exists(
f"{arguments['safe_model_name']}_{str(arguments['max_length'])}_clip_1.vmfb"
)
assert os.path.exists(
f"{arguments['safe_model_name']}_{str(arguments['max_length'])}_clip_2.vmfb"
)
arguments[
"external_weight_path_1"
] = f"{arguments['safe_model_name']}_clip_1.safetensors"
Expand Down

0 comments on commit e6e118e

Please sign in to comment.