diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 00b5ae7d4..7fa99582c 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -280,20 +280,24 @@ def test02_ExportUnetModel(self): text_embeds = torch.rand(2 * arguments["batch_size"], 1280, dtype=dtype) time_ids = torch.zeros(2 * arguments["batch_size"], 6, dtype=dtype) guidance_scale = torch.Tensor([arguments["guidance_scale"]]).to(dtype) - - turbine = unet_runner.run_unet( - arguments["rt_device"], - sample, - timestep, - prompt_embeds, - text_embeds, - time_ids, - guidance_scale, - arguments["vmfb_path"], - arguments["hf_model_name"], - arguments["hf_auth_token"], - arguments["external_weight_path"], - ) + try: + turbine = unet_runner.run_unet( + arguments["rt_device"], + sample, + timestep, + prompt_embeds, + text_embeds, + time_ids, + guidance_scale, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + except err as e: + print(e) + sys.exit() + torch_output = unet_runner.run_torch_unet( arguments["hf_model_name"], arguments["hf_auth_token"],