Skip to content

Commit

Permalink
Update sdxl_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet authored Apr 11, 2024
1 parent ef84ce1 commit cfb63ef
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions models/turbine_models/tests/sdxl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,20 +280,24 @@ def test02_ExportUnetModel(self):
text_embeds = torch.rand(2 * arguments["batch_size"], 1280, dtype=dtype)
time_ids = torch.zeros(2 * arguments["batch_size"], 6, dtype=dtype)
guidance_scale = torch.Tensor([arguments["guidance_scale"]]).to(dtype)

turbine = unet_runner.run_unet(
arguments["rt_device"],
sample,
timestep,
prompt_embeds,
text_embeds,
time_ids,
guidance_scale,
arguments["vmfb_path"],
arguments["hf_model_name"],
arguments["hf_auth_token"],
arguments["external_weight_path"],
)
try:
turbine = unet_runner.run_unet(
arguments["rt_device"],
sample,
timestep,
prompt_embeds,
text_embeds,
time_ids,
guidance_scale,
arguments["vmfb_path"],
arguments["hf_model_name"],
arguments["hf_auth_token"],
arguments["external_weight_path"],
)
except err as e:
print(e)
sys.exit()

torch_output = unet_runner.run_torch_unet(
arguments["hf_model_name"],
arguments["hf_auth_token"],
Expand Down

0 comments on commit cfb63ef

Please sign in to comment.