Skip to content

Commit

Permalink
WIP: fp16 and guidance scale fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jan 19, 2024
1 parent d1fda83 commit c1aebad
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 133 deletions.
7 changes: 3 additions & 4 deletions python/turbine_models/custom_models/sd_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def __init__(self, hf_model_name):
subfolder="unet",
)

def forward(self, sample, timestep, encoder_hidden_states):
guidance_scale = 7.5
def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
samples = torch.cat([sample] * 2)
unet_out = self.unet.forward(
samples, timestep, encoder_hidden_states, return_dict=False
Expand Down Expand Up @@ -127,10 +126,10 @@ def main(
encoder_hidden_states=AbstractTensor(
*encoder_hidden_states_sizes, dtype=dtype
),
#guidance_scale=AbstractTensor(1, dtype=dtype),
guidance_scale=AbstractTensor(1, dtype=dtype),
):
return jittable(unet_model.forward)(
sample, timestep, encoder_hidden_states, # guidance_scale
sample, timestep, encoder_hidden_states, guidance_scale
)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
Expand Down
14 changes: 7 additions & 7 deletions python/turbine_models/custom_models/sd_inference/unet_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def run_unet(
sample,
timestep,
encoder_hidden_states,
# guidance_scale,
guidance_scale,
vmfb_path,
hf_model_name,
hf_auth_token,
Expand All @@ -64,7 +64,7 @@ def run_unet(
ireert.asdevicearray(runner.config.device, sample),
ireert.asdevicearray(runner.config.device, timestep),
ireert.asdevicearray(runner.config.device, encoder_hidden_states),
# ireert.asdevicearray(runner.config.device, guidance_scale),
ireert.asdevicearray(runner.config.device, guidance_scale),
]
results = runner.ctx.modules.compiled_unet["main"](*inputs)
return results
Expand All @@ -90,7 +90,7 @@ def __init__(self, hf_model_name, hf_auth_token):
)
self.guidance_scale = 7.5

def forward(self, sample, timestep, encoder_hidden_states): #, guidance_scale):
def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
samples = torch.cat([sample] * 2)
unet_out = self.unet.forward(
samples, timestep, encoder_hidden_states, return_dict=False
Expand All @@ -106,7 +106,7 @@ def forward(self, sample, timestep, encoder_hidden_states): #, guidance_scale):
hf_auth_token,
)
results = unet_model.forward(
sample, timestep, encoder_hidden_states, #guidance_scale
sample, timestep, encoder_hidden_states, guidance_scale
)
np_torch_output = results.detach().cpu().numpy()
return np_torch_output
Expand All @@ -118,7 +118,7 @@ def forward(self, sample, timestep, encoder_hidden_states): #, guidance_scale):
args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32
)
timestep = torch.zeros(1, dtype=torch.float32)
# guidance_scale = torch.Tensor([7.5], dtype=torch.float32)
# Use the torch.tensor factory: the legacy torch.Tensor constructor does not
# accept a dtype keyword and raises TypeError when one is passed.
guidance_scale = torch.tensor([7.5], dtype=torch.float32)
if args.hf_model_name == "CompVis/stable-diffusion-v1-4":
encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32)
elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base":
Expand All @@ -129,7 +129,7 @@ def forward(self, sample, timestep, encoder_hidden_states): #, guidance_scale):
sample,
timestep,
encoder_hidden_states,
# guidance_scale,
guidance_scale,
args.vmfb_path,
args.hf_model_name,
args.hf_auth_token,
Expand All @@ -152,7 +152,7 @@ def forward(self, sample, timestep, encoder_hidden_states): #, guidance_scale):
sample,
timestep,
encoder_hidden_states,
# guidance_scale,
guidance_scale,
)
print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
err = utils.largest_error(torch_output, turbine_output)
Expand Down
244 changes: 122 additions & 122 deletions python/turbine_models/tests/sd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,43 +49,43 @@
arguments["hf_model_name"],
)

# vae_model = vae.VaeModel(
# # This is a public model, so no auth required
# arguments["hf_model_name"],
# custom_vae=None,
# )
# Shared VAE model instance; the VAE export tests pass it as the first
# argument to vae.export_vae_model. No custom VAE override is supplied.
vae_model = vae.VaeModel(
# This is a public model, so no auth required
arguments["hf_model_name"],
custom_vae=None,
)


class StableDiffusionTest(unittest.TestCase):
# def testExportClipModel(self):
# with self.assertRaises(SystemExit) as cm:
# clip.export_clip_model(
# # This is a public model, so no auth required
# arguments["hf_model_name"],
# None,
# "vmfb",
# "safetensors",
# f"{arguments['safe_model_name']}_clip.safetensors",
# "cpu",
# )
# self.assertEqual(cm.exception.code, None)
# arguments["external_weight_path"] = f"{arguments['safe_model_name']}_clip.safetensors"
# arguments["vmfb_path"] = f"{arguments['safe_model_name']}_clip.vmfb"
# turbine = clip_runner.run_clip(
# arguments["device"],
# arguments["prompt"],
# arguments["vmfb_path"],
# arguments["hf_model_name"],
# arguments["hf_auth_token"],
# arguments["external_weight_path"],
# )
# torch_output = clip_runner.run_torch_clip(
# arguments["hf_model_name"], arguments["hf_auth_token"], arguments["prompt"]
# )
# err = utils.largest_error(torch_output, turbine[0])
# assert err < 9e-5
# #os.remove(f"{arguments['safe_model_name']}_clip.safetensors")
# #os.remove(f"{arguments['safe_model_name']}_clip.vmfb")
def testExportClipModel(self):
    """Export CLIP to a vmfb, run it, and compare with the torch reference.

    The export call is expected to finish by raising ``SystemExit`` with an
    exit code of ``None``. The first output of the compiled module must then
    match the torch output with a largest error below 9e-5.
    """
    weights_file = f"{arguments['safe_model_name']}_clip.safetensors"

    with self.assertRaises(SystemExit) as exit_ctx:
        # This is a public model, so no auth required
        clip.export_clip_model(
            arguments["hf_model_name"],
            None,
            "vmfb",
            "safetensors",
            weights_file,
            "cpu",
        )
    self.assertEqual(exit_ctx.exception.code, None)

    # Publish the artifact locations through the shared arguments dict so
    # the runner call below picks them up.
    arguments["external_weight_path"] = weights_file
    arguments["vmfb_path"] = f"{arguments['safe_model_name']}_clip.vmfb"

    runner_keys = (
        "device",
        "prompt",
        "vmfb_path",
        "hf_model_name",
        "hf_auth_token",
        "external_weight_path",
    )
    turbine_outputs = clip_runner.run_clip(
        *[arguments[key] for key in runner_keys]
    )
    reference = clip_runner.run_torch_clip(
        arguments["hf_model_name"],
        arguments["hf_auth_token"],
        arguments["prompt"],
    )

    max_err = utils.largest_error(reference, turbine_outputs[0])
    assert max_err < 9e-5
    # NOTE: removal of the generated artifacts is currently disabled.
    #os.remove(f"{arguments['safe_model_name']}_clip.safetensors")
    #os.remove(f"{arguments['safe_model_name']}_clip.vmfb")

def testExportUnetModel(self):
with self.assertRaises(SystemExit) as cm:
Expand Down Expand Up @@ -124,7 +124,7 @@ def testExportUnetModel(self):
sample,
timestep,
encoder_hidden_states,
# guidance_scale,
guidance_scale,
arguments["vmfb_path"],
arguments["hf_model_name"],
arguments["hf_auth_token"],
Expand All @@ -136,100 +136,100 @@ def testExportUnetModel(self):
sample,
timestep,
encoder_hidden_states,
# guidance_scale,
guidance_scale,
)
err = utils.largest_error(torch_output, turbine)
assert err < 9e-5
#os.remove(f"{arguments['safe_model_name']}_unet.safetensors")
#os.remove(f"{arguments['safe_model_name']}_unet.vmfb")

# def testExportVaeModelDecode(self):
# with self.assertRaises(SystemExit) as cm:
# vae.export_vae_model(
# vae_model,
# # This is a public model, so no auth required
# arguments["hf_model_name"],
# arguments["batch_size"],
# arguments["height"],
# arguments["width"],
# arguments["precision"],
# compile_to="vmfb",
# external_weights="safetensors",
# external_weight_path=f"{arguments['safe_model_name']}_vae.safetensors",
# device="cpu",
# variant="decode",
# )
# self.assertEqual(cm.exception.code, None)
# arguments["external_weight_path"] = f"{arguments['safe_model_name']}_vae.safetensors"
# arguments["vmfb_path"] = f"{arguments['safe_model_name']}_vae.vmfb"
# dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32
# example_input = torch.rand(
# arguments["batch_size"],
# 4,
# arguments["height"] // 8,
# arguments["width"] // 8,
# dtype=dtype,
# )
# turbine = vae_runner.run_vae(
# arguments["device"],
# example_input,
# arguments["vmfb_path"],
# arguments["hf_model_name"],
# arguments["external_weight_path"],
# )
# torch_output = vae_runner.run_torch_vae(
# arguments["hf_model_name"],
# "decode",
# example_input,
# )
# err = utils.largest_error(torch_output, turbine)
# assert err < 9e-5
# #os.remove(f"{arguments['safe_model_name']}_vae.safetensors")
# #os.remove(f"{arguments['safe_model_name']}_vae.vmfb")
def testExportVaeModelDecode(self):
    """Export the VAE "decode" variant to a vmfb and check it against torch.

    The export call is expected to finish by raising ``SystemExit`` with an
    exit code of ``None``. A random input of shape
    ``(batch_size, 4, height // 8, width // 8)`` is then run through both the
    compiled module and the torch reference; the largest error must be
    below 9e-5.
    """
    weights_file = f"{arguments['safe_model_name']}_vae.safetensors"

    with self.assertRaises(SystemExit) as exit_ctx:
        # This is a public model, so no auth required
        vae.export_vae_model(
            vae_model,
            arguments["hf_model_name"],
            arguments["batch_size"],
            arguments["height"],
            arguments["width"],
            arguments["precision"],
            compile_to="vmfb",
            external_weights="safetensors",
            external_weight_path=weights_file,
            device="cpu",
            variant="decode",
        )
    self.assertEqual(exit_ctx.exception.code, None)

    # Publish the artifact locations through the shared arguments dict so
    # the runner call below picks them up.
    arguments["external_weight_path"] = weights_file
    arguments["vmfb_path"] = f"{arguments['safe_model_name']}_vae.vmfb"

    # The input precision follows the configured export precision.
    if arguments["precision"] == "fp16":
        dtype = torch.float16
    else:
        dtype = torch.float32

    input_shape = (
        arguments["batch_size"],
        4,
        arguments["height"] // 8,
        arguments["width"] // 8,
    )
    example_input = torch.rand(*input_shape, dtype=dtype)

    turbine_output = vae_runner.run_vae(
        arguments["device"],
        example_input,
        arguments["vmfb_path"],
        arguments["hf_model_name"],
        arguments["external_weight_path"],
    )
    reference = vae_runner.run_torch_vae(
        arguments["hf_model_name"],
        "decode",
        example_input,
    )

    max_err = utils.largest_error(reference, turbine_output)
    assert max_err < 9e-5
    # NOTE: removal of the generated artifacts is currently disabled.
    #os.remove(f"{arguments['safe_model_name']}_vae.safetensors")
    #os.remove(f"{arguments['safe_model_name']}_vae.vmfb")

# def testExportVaeModelEncode(self):
# with self.assertRaises(SystemExit) as cm:
# vae.export_vae_model(
# vae_model,
# # This is a public model, so no auth required
# arguments["hf_model_name"],
# arguments["batch_size"],
# arguments["height"],
# arguments["width"],
# arguments["precision"],
# "vmfb",
# external_weights="safetensors",
# external_weight_path=f"{arguments['safe_model_name']}_vae.safetensors",
# device="cpu",
# variant="encode",
# )
# self.assertEqual(cm.exception.code, None)
# arguments["external_weight_path"] = f"{arguments['safe_model_name']}_vae.safetensors"
# arguments["vmfb_path"] = f"{arguments['safe_model_name']}_vae.vmfb"
# dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32
# example_input = torch.rand(
# arguments["batch_size"],
# 3,
# arguments["height"],
# arguments["width"],
# dtype=dtype,
# )
# turbine = vae_runner.run_vae(
# arguments["device"],
# example_input,
# arguments["vmfb_path"],
# arguments["hf_model_name"],
# arguments["external_weight_path"],
# )
# torch_output = vae_runner.run_torch_vae(
# arguments["hf_model_name"],
# "encode",
# example_input,
# )
# err = utils.largest_error(torch_output, turbine)
# assert err < 2e-3
# #os.remove(f"{arguments['safe_model_name']}_vae.safetensors")
# #os.remove(f"{arguments['safe_model_name']}_vae.vmfb")
def testExportVaeModelEncode(self):
    """Export the VAE "encode" variant to a vmfb and check it against torch.

    The export call is expected to finish by raising ``SystemExit`` with an
    exit code of ``None``. A random input of shape
    ``(batch_size, 3, height, width)`` is then run through both the compiled
    module and the torch reference; the largest error must be below 2e-3.
    """
    weights_file = f"{arguments['safe_model_name']}_vae.safetensors"

    with self.assertRaises(SystemExit) as exit_ctx:
        # This is a public model, so no auth required
        vae.export_vae_model(
            vae_model,
            arguments["hf_model_name"],
            arguments["batch_size"],
            arguments["height"],
            arguments["width"],
            arguments["precision"],
            "vmfb",
            external_weights="safetensors",
            external_weight_path=weights_file,
            device="cpu",
            variant="encode",
        )
    self.assertEqual(exit_ctx.exception.code, None)

    # Publish the artifact locations through the shared arguments dict so
    # the runner call below picks them up.
    arguments["external_weight_path"] = weights_file
    arguments["vmfb_path"] = f"{arguments['safe_model_name']}_vae.vmfb"

    # The input precision follows the configured export precision.
    if arguments["precision"] == "fp16":
        dtype = torch.float16
    else:
        dtype = torch.float32

    input_shape = (
        arguments["batch_size"],
        3,
        arguments["height"],
        arguments["width"],
    )
    example_input = torch.rand(*input_shape, dtype=dtype)

    turbine_output = vae_runner.run_vae(
        arguments["device"],
        example_input,
        arguments["vmfb_path"],
        arguments["hf_model_name"],
        arguments["external_weight_path"],
    )
    reference = vae_runner.run_torch_vae(
        arguments["hf_model_name"],
        "encode",
        example_input,
    )

    max_err = utils.largest_error(reference, turbine_output)
    assert max_err < 2e-3
    # NOTE: removal of the generated artifacts is currently disabled.
    #os.remove(f"{arguments['safe_model_name']}_vae.safetensors")
    #os.remove(f"{arguments['safe_model_name']}_vae.vmfb")


if __name__ == "__main__":
Expand Down

0 comments on commit c1aebad

Please sign in to comment.