From 1cec8ddbae5d91531208ac20ed5cb1c96f944e92 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 19 Jan 2024 12:19:43 -0600 Subject: [PATCH] WIP: fp16 and guidance scale fixes --- .../custom_models/sd_inference/unet.py | 7 +- .../custom_models/sd_inference/unet_runner.py | 14 +- python/turbine_models/tests/sd_test.py | 244 +++++++++--------- 3 files changed, 132 insertions(+), 133 deletions(-) diff --git a/python/turbine_models/custom_models/sd_inference/unet.py b/python/turbine_models/custom_models/sd_inference/unet.py index 7733b552e..81aa841ae 100644 --- a/python/turbine_models/custom_models/sd_inference/unet.py +++ b/python/turbine_models/custom_models/sd_inference/unet.py @@ -69,8 +69,7 @@ def __init__(self, hf_model_name): subfolder="unet", ) - def forward(self, sample, timestep, encoder_hidden_states): - guidance_scale = 7.5 + def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): samples = torch.cat([sample] * 2) unet_out = self.unet.forward( samples, timestep, encoder_hidden_states, return_dict=False @@ -127,10 +126,10 @@ def main( encoder_hidden_states=AbstractTensor( *encoder_hidden_states_sizes, dtype=dtype ), - #guidance_scale=AbstractTensor(1, dtype=dtype), + guidance_scale=AbstractTensor(1, dtype=dtype), ): return jittable(unet_model.forward)( - sample, timestep, encoder_hidden_states, # guidance_scale + sample, timestep, encoder_hidden_states, guidance_scale ) import_to = "INPUT" if compile_to == "linalg" else "IMPORT" diff --git a/python/turbine_models/custom_models/sd_inference/unet_runner.py b/python/turbine_models/custom_models/sd_inference/unet_runner.py index 3f78c7753..1b8c5d101 100644 --- a/python/turbine_models/custom_models/sd_inference/unet_runner.py +++ b/python/turbine_models/custom_models/sd_inference/unet_runner.py @@ -52,7 +52,7 @@ def run_unet( sample, timestep, encoder_hidden_states, - # guidance_scale, + guidance_scale, vmfb_path, hf_model_name, hf_auth_token, @@ -64,7 +64,7 @@ def run_unet( ireert.asdevicearray(runner.config.device, sample), ireert.asdevicearray(runner.config.device, timestep), ireert.asdevicearray(runner.config.device, encoder_hidden_states), - # ireert.asdevicearray(runner.config.device, guidance_scale), + ireert.asdevicearray(runner.config.device, guidance_scale), ] results = runner.ctx.modules.compiled_unet["main"](*inputs) return results @@ -90,7 +90,7 @@ def __init__(self, hf_model_name, hf_auth_token): ) self.guidance_scale = 7.5 - def forward(self, sample, timestep, encoder_hidden_states): #, guidance_scale): + def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): samples = torch.cat([sample] * 2) unet_out = self.unet.forward( samples, timestep, encoder_hidden_states, return_dict=False @@ -106,7 +106,7 @@ def forward(self, sample, timestep, encoder_hidden_states): #, guidance_scale): hf_auth_token, ) results = unet_model.forward( - sample, timestep, encoder_hidden_states, #guidance_scale + sample, timestep, encoder_hidden_states, guidance_scale ) np_torch_output = results.detach().cpu().numpy() return np_torch_output @@ -118,7 +118,7 @@ def forward(self, sample, timestep, encoder_hidden_states): #, guidance_scale): args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 ) timestep = torch.zeros(1, dtype=torch.float32) - # guidance_scale = torch.Tensor([7.5], dtype=torch.float32) + guidance_scale = torch.Tensor([7.5], dtype=torch.float32) if args.hf_model_name == "CompVis/stable-diffusion-v1-4": encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": @@ -129,7 +129,7 @@ def forward(self, sample, timestep, encoder_hidden_states): #, guidance_scale): sample, timestep, encoder_hidden_states, - # guidance_scale, + guidance_scale, args.vmfb_path, args.hf_model_name, args.hf_auth_token, @@ -152,7 +152,7 @@ def forward(self, sample, timestep, encoder_hidden_states): #, guidance_scale): sample, timestep, encoder_hidden_states, - # guidance_scale, + guidance_scale, ) print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) err = utils.largest_error(torch_output, turbine_output) diff --git a/python/turbine_models/tests/sd_test.py b/python/turbine_models/tests/sd_test.py index c8e62cf04..cc9800612 100644 --- a/python/turbine_models/tests/sd_test.py +++ b/python/turbine_models/tests/sd_test.py @@ -49,43 +49,43 @@ arguments["hf_model_name"], ) -# vae_model = vae.VaeModel( -# # This is a public model, so no auth required -# arguments["hf_model_name"], -# custom_vae=None, -# ) +vae_model = vae.VaeModel( + # This is a public model, so no auth required + arguments["hf_model_name"], + custom_vae=None, +) class StableDiffusionTest(unittest.TestCase): -# def testExportClipModel(self): -# with self.assertRaises(SystemExit) as cm: -# clip.export_clip_model( -# # This is a public model, so no auth required -# arguments["hf_model_name"], -# None, -# "vmfb", -# "safetensors", -# f"{arguments['safe_model_name']}_clip.safetensors", -# "cpu", -# ) -# self.assertEqual(cm.exception.code, None) -# arguments["external_weight_path"] = f"{arguments['safe_model_name']}_clip.safetensors" -# arguments["vmfb_path"] = f"{arguments['safe_model_name']}_clip.vmfb" -# turbine = clip_runner.run_clip( -# arguments["device"], -# arguments["prompt"], -# arguments["vmfb_path"], -# arguments["hf_model_name"], -# arguments["hf_auth_token"], -# arguments["external_weight_path"], -# ) -# torch_output = clip_runner.run_torch_clip( -# arguments["hf_model_name"], arguments["hf_auth_token"], arguments["prompt"] -# ) -# err = utils.largest_error(torch_output, turbine[0]) -# assert err < 9e-5 -# #os.remove(f"{arguments['safe_model_name']}_clip.safetensors") -# #os.remove(f"{arguments['safe_model_name']}_clip.vmfb") + def testExportClipModel(self): + with self.assertRaises(SystemExit) as cm: + clip.export_clip_model( + # This is a public model, so no auth required + arguments["hf_model_name"], + None, + "vmfb", + "safetensors", + f"{arguments['safe_model_name']}_clip.safetensors", + "cpu", + ) + self.assertEqual(cm.exception.code, None) + arguments["external_weight_path"] = f"{arguments['safe_model_name']}_clip.safetensors" + arguments["vmfb_path"] = f"{arguments['safe_model_name']}_clip.vmfb" + turbine = clip_runner.run_clip( + arguments["device"], + arguments["prompt"], + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = clip_runner.run_torch_clip( + arguments["hf_model_name"], arguments["hf_auth_token"], arguments["prompt"] + ) + err = utils.largest_error(torch_output, turbine[0]) + assert err < 9e-5 + #os.remove(f"{arguments['safe_model_name']}_clip.safetensors") + #os.remove(f"{arguments['safe_model_name']}_clip.vmfb") def testExportUnetModel(self): with self.assertRaises(SystemExit) as cm: @@ -124,7 +124,7 @@ def testExportUnetModel(self): sample, timestep, encoder_hidden_states, - # guidance_scale, + guidance_scale, arguments["vmfb_path"], arguments["hf_model_name"], arguments["hf_auth_token"], @@ -136,100 +136,100 @@ def testExportUnetModel(self): sample, timestep, encoder_hidden_states, - # guidance_scale, + guidance_scale, ) err = utils.largest_error(torch_output, turbine) assert err < 9e-5 #os.remove(f"{arguments['safe_model_name']}_unet.safetensors") #os.remove(f"{arguments['safe_model_name']}_unet.vmfb") - # def testExportVaeModelDecode(self): - # with self.assertRaises(SystemExit) as cm: - # vae.export_vae_model( - # vae_model, - # # This is a public model, so no auth required - # arguments["hf_model_name"], - # arguments["batch_size"], - # arguments["height"], - # arguments["width"], - # arguments["precision"], - # compile_to="vmfb", - # external_weights="safetensors", - # external_weight_path=f"{arguments['safe_model_name']}_vae.safetensors", - # device="cpu", - # variant="decode", - # ) - # self.assertEqual(cm.exception.code, None) - # arguments["external_weight_path"] = f"{arguments['safe_model_name']}_vae.safetensors" - # arguments["vmfb_path"] = f"{arguments['safe_model_name']}_vae.vmfb" - # dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 - # example_input = torch.rand( - # arguments["batch_size"], - # 4, - # arguments["height"] // 8, - # arguments["width"] // 8, - # dtype=dtype, - # ) - # turbine = vae_runner.run_vae( - # arguments["device"], - # example_input, - # arguments["vmfb_path"], - # arguments["hf_model_name"], - # arguments["external_weight_path"], - # ) - # torch_output = vae_runner.run_torch_vae( - # arguments["hf_model_name"], - # "decode", - # example_input, - # ) - # err = utils.largest_error(torch_output, turbine) - # assert err < 9e-5 - # #os.remove(f"{arguments['safe_model_name']}_vae.safetensors") - # #os.remove(f"{arguments['safe_model_name']}_vae.vmfb") + def testExportVaeModelDecode(self): + with self.assertRaises(SystemExit) as cm: + vae.export_vae_model( + vae_model, + # This is a public model, so no auth required + arguments["hf_model_name"], + arguments["batch_size"], + arguments["height"], + arguments["width"], + arguments["precision"], + compile_to="vmfb", + external_weights="safetensors", + external_weight_path=f"{arguments['safe_model_name']}_vae.safetensors", + device="cpu", + variant="decode", + ) + self.assertEqual(cm.exception.code, None) + arguments["external_weight_path"] = f"{arguments['safe_model_name']}_vae.safetensors" + arguments["vmfb_path"] = f"{arguments['safe_model_name']}_vae.vmfb" + dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 + example_input = torch.rand( + arguments["batch_size"], + 4, + arguments["height"] // 8, + arguments["width"] // 8, + dtype=dtype, + ) + turbine = vae_runner.run_vae( + arguments["device"], + example_input, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["external_weight_path"], + ) + torch_output = vae_runner.run_torch_vae( + arguments["hf_model_name"], + "decode", + example_input, + ) + err = utils.largest_error(torch_output, turbine) + assert err < 9e-5 + #os.remove(f"{arguments['safe_model_name']}_vae.safetensors") + #os.remove(f"{arguments['safe_model_name']}_vae.vmfb") - # def testExportVaeModelEncode(self): - # with self.assertRaises(SystemExit) as cm: - # vae.export_vae_model( - # vae_model, - # # This is a public model, so no auth required - # arguments["hf_model_name"], - # arguments["batch_size"], - # arguments["height"], - # arguments["width"], - # arguments["precision"], - # "vmfb", - # external_weights="safetensors", - # external_weight_path=f"{arguments['safe_model_name']}_vae.safetensors", - # device="cpu", - # variant="encode", - # ) - # self.assertEqual(cm.exception.code, None) - # arguments["external_weight_path"] = f"{arguments['safe_model_name']}_vae.safetensors" - # arguments["vmfb_path"] = f"{arguments['safe_model_name']}_vae.vmfb" - # dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 - # example_input = torch.rand( - # arguments["batch_size"], - # 3, - # arguments["height"], - # arguments["width"], - # dtype=dtype, - # ) - # turbine = vae_runner.run_vae( - # arguments["device"], - # example_input, - # arguments["vmfb_path"], - # arguments["hf_model_name"], - # arguments["external_weight_path"], - # ) - # torch_output = vae_runner.run_torch_vae( - # arguments["hf_model_name"], - # "encode", - # example_input, - # ) - # err = utils.largest_error(torch_output, turbine) - # assert err < 2e-3 - # #os.remove(f"{arguments['safe_model_name']}_vae.safetensors") - # #os.remove(f"{arguments['safe_model_name']}_vae.vmfb") + def testExportVaeModelEncode(self): + with self.assertRaises(SystemExit) as cm: + vae.export_vae_model( + vae_model, + # This is a public model, so no auth required + arguments["hf_model_name"], + arguments["batch_size"], + arguments["height"], + arguments["width"], + arguments["precision"], + "vmfb", + external_weights="safetensors", + external_weight_path=f"{arguments['safe_model_name']}_vae.safetensors", + device="cpu", + variant="encode", + ) + self.assertEqual(cm.exception.code, None) + arguments["external_weight_path"] = f"{arguments['safe_model_name']}_vae.safetensors" + arguments["vmfb_path"] = f"{arguments['safe_model_name']}_vae.vmfb" + dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 + example_input = torch.rand( + arguments["batch_size"], + 3, + arguments["height"], + arguments["width"], + dtype=dtype, + ) + turbine = vae_runner.run_vae( + arguments["device"], + example_input, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["external_weight_path"], + ) + torch_output = vae_runner.run_torch_vae( + arguments["hf_model_name"], + "encode", + example_input, + ) + err = utils.largest_error(torch_output, turbine) + assert err < 2e-3 + #os.remove(f"{arguments['safe_model_name']}_vae.safetensors") + #os.remove(f"{arguments['safe_model_name']}_vae.vmfb") if __name__ == "__main__":