Commit
add tests for mmdit and vae
dan-garvey committed Jun 19, 2024
1 parent fc6833d commit 4e2c3cc
Showing 1 changed file with 161 additions and 184 deletions.
345 changes: 161 additions & 184 deletions models/turbine_models/tests/sd3_test.py
@@ -170,96 +170,73 @@ def test01_ExportPromptEncoder(self):
        np.testing.assert_allclose(torch_output1, turbine_output1, rtol, atol)
        np.testing.assert_allclose(torch_output2, turbine_output2, rtol, atol)
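
A note on the comparison used throughout this file: np.testing.assert_allclose is called positionally, so the third and fourth arguments are rtol and atol, and the check passes when |actual - desired| <= atol + rtol * |desired| elementwise. A tiny self-contained illustration using the tolerances from the MMDiT and VAE tests below (the sample values are made up):

import numpy as np

reference = np.array([1.00, 2.00, -3.00], dtype=np.float32)  # stand-in torch output
candidate = np.array([1.03, 2.05, -3.20], dtype=np.float32)  # stand-in turbine output

rtol = 4e-2
atol = 4e-1

# Passes: every |candidate - reference| is within atol + rtol * |reference|.
np.testing.assert_allclose(candidate, reference, rtol, atol)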

# def test02_ExportUnetModel(self):
# if arguments["device"] in ["vulkan", "cuda"]:
# self.skipTest("Unknown error on vulkan; To be tested on cuda.")
# unet.export_unet_model(
# unet_model=self.unet_model,
# # This is a public model, so no auth required
# hf_model_name=arguments["hf_model_name"],
# batch_size=arguments["batch_size"],
# height=arguments["height"],
# width=arguments["width"],
# precision=arguments["precision"],
# max_length=arguments["max_length"],
# hf_auth_token=None,
# compile_to="vmfb",
# external_weights=arguments["external_weights"],
# external_weight_path=self.safe_model_name
# + "_"
# + arguments["precision"]
# + "_unet."
# + arguments["external_weights"],
# device=arguments["device"],
# target_triple=arguments["iree_target_triple"],
# ireec_flags=arguments["ireec_flags"],
# decomp_attn=arguments["decomp_attn"],
# attn_spec=arguments["attn_spec"],
# )
# arguments["external_weight_path"] = (
# self.safe_model_name
# + "_"
# + arguments["precision"]
# + "_unet."
# + arguments["external_weights"]
# )
# arguments["vmfb_path"] = (
# self.safe_model_name
# + "_"
# + str(arguments["max_length"])
# + "_"
# + str(arguments["height"])
# + "x"
# + str(arguments["width"])
# + "_"
# + arguments["precision"]
# + "_unet_"
# + arguments["device"]
# + ".vmfb"
# )
# dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32
# sample = torch.rand(
# (
# arguments["batch_size"],
# arguments["in_channels"],
# arguments["height"] // 8,
# arguments["width"] // 8,
# ),
# dtype=dtype,
# )
# timestep = torch.zeros(1, dtype=torch.int64)
# prompt_embeds = torch.rand(
# (2 * arguments["batch_size"], arguments["max_length"], 2048),
# dtype=dtype,
# )
# text_embeds = torch.rand(2 * arguments["batch_size"], 1280, dtype=dtype)
# time_ids = torch.zeros(2 * arguments["batch_size"], 6, dtype=dtype)
# guidance_scale = torch.Tensor([arguments["guidance_scale"]]).to(dtype)
#
# turbine = unet_runner.run_unet(
# arguments["rt_device"],
# sample,
# timestep,
# prompt_embeds,
# text_embeds,
# time_ids,
# guidance_scale,
# arguments["vmfb_path"],
# arguments["hf_model_name"],
# arguments["hf_auth_token"],
# arguments["external_weight_path"],
# )
# torch_output = unet_runner.run_torch_unet(
# arguments["hf_model_name"],
# arguments["hf_auth_token"],
# sample.float(),
# timestep,
# prompt_embeds.float(),
# text_embeds.float(),
# time_ids.float(),
# guidance_scale.float(),
# precision=arguments["precision"],
# )
    def test02_ExportMMDITModel(self):
        if arguments["device"] in ["vulkan", "cuda"]:
            self.skipTest("Not testing on vulkan or cuda")
        arguments["external_weight_path"] = (
            self.safe_model_name
            + "_"
            + arguments["precision"]
            + "_mmdit."
            + arguments["external_weights"]
        )
        sd3_mmdit.export_mmdit_model(
            mmdit_model=self.mmdit_model,
            # This is a public model, so no auth required
            hf_model_name=arguments["hf_model_name"],
            batch_size=arguments["batch_size"],
            height=arguments["height"],
            width=arguments["width"],
            precision=arguments["precision"],
            max_length=arguments["max_length"],
            hf_auth_token=None,
            compile_to="vmfb",
            external_weights=arguments["external_weights"],
            external_weight_path=arguments["external_weight_path"],
            device=arguments["mmdit_device"],
            target_triple=arguments["iree_target_triple"],
            ireec_flags=arguments["ireec_flags"],
            decomp_attn=arguments["decomp_attn"],
            attn_spec=arguments["attn_spec"],
        )
arguments["vmfb_path"] = (
self.safe_model_name
+ "_"
+ str(arguments["max_length"])
+ "_"
+ str(arguments["height"])
+ "x"
+ str(arguments["width"])
+ "_"
+ arguments["precision"]
+ "_unet_"
+ arguments["device"]
+ ".vmfb"
)
        dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32

        hidden_states = torch.randn(
            (arguments["batch_size"], 16, arguments["height"] // 8, arguments["width"] // 8),
            dtype=dtype,
        )
        encoder_hidden_states = torch.randn(
            (arguments["batch_size"], arguments["max_length"] * 2, 4096), dtype=dtype
        )
        pooled_projections = torch.randn((arguments["batch_size"], 2048), dtype=dtype)
        timestep = torch.tensor([0, 0], dtype=dtype)
        turbine = sd3_mmdit_runner.run_mmdit_turbine(
            hidden_states,
            encoder_hidden_states,
            pooled_projections,
            timestep,
            arguments,
        )
        torch_output = sd3_mmdit_runner.run_diffusers_mmdit(
            hidden_states,
            encoder_hidden_states,
            pooled_projections,
            timestep,
            arguments,
        )
# if arguments["benchmark"] or arguments["tracy_profile"]:
# run_benchmark(
# "unet",
@@ -274,99 +251,99 @@ def test01_ExportPromptEncoder(self):
# precision=arguments["precision"],
# tracy_profile=arguments["tracy_profile"],
# )
# rtol = 4e-2
# atol = 4e-1
#
# np.testing.assert_allclose(torch_output, turbine, rtol, atol)
#
# def test03_ExportVaeModelDecode(self):
# if arguments["device"] in ["vulkan", "cuda"]:
# self.skipTest("Compilation error on vulkan; To be tested on cuda.")
# vae.export_vae_model(
# vae_model=self.vae_model,
# # This is a public model, so no auth required
# hf_model_name=arguments["hf_model_name"],
# batch_size=arguments["batch_size"],
# height=arguments["height"],
# width=arguments["width"],
# precision=arguments["precision"],
# compile_to="vmfb",
# external_weights=arguments["external_weights"],
# external_weight_path=self.safe_model_name
# + "_"
# + arguments["precision"]
# + "_vae_decode."
# + arguments["external_weights"],
# device=arguments["device"],
# target_triple=arguments["iree_target_triple"],
# ireec_flags=arguments["ireec_flags"],
# variant="decode",
# decomp_attn=arguments["decomp_attn"],
# attn_spec=arguments["attn_spec"],
# exit_on_vmfb=True,
# )
# arguments["external_weight_path"] = (
# self.safe_model_name
# + "_"
# + arguments["precision"]
# + "_vae_decode."
# + arguments["external_weights"]
# )
# arguments["vmfb_path"] = (
# self.safe_model_name
# + "_"
# + str(arguments["height"])
# + "x"
# + str(arguments["width"])
# + "_"
# + arguments["precision"]
# + "_vae_decode_"
# + arguments["device"]
# + ".vmfb"
# )
# example_input = torch.ones(
# arguments["batch_size"],
# 4,
# arguments["height"] // 8,
# arguments["width"] // 8,
# dtype=torch.float32,
# )
# example_input_torch = example_input
# if arguments["precision"] == "fp16":
# example_input = example_input.half()
# turbine = vae_runner.run_vae(
# arguments["rt_device"],
# example_input,
# arguments["vmfb_path"],
# arguments["hf_model_name"],
# arguments["external_weight_path"],
# )
# torch_output = vae_runner.run_torch_vae(
# arguments["hf_model_name"],
# (
# "madebyollin/sdxl-vae-fp16-fix"
# if arguments["precision"] == "fp16"
# else ""
# ),
# "decode",
# example_input_torch,
# )
# if arguments["benchmark"] or arguments["tracy_profile"]:
# run_benchmark(
# "vae_decode",
# arguments["vmfb_path"],
# arguments["external_weight_path"],
# arguments["rt_device"],
# height=arguments["height"],
# width=arguments["width"],
# precision=arguments["precision"],
# tracy_profile=arguments["tracy_profile"],
# )
# rtol = 4e-2
# atol = 4e-1
#
# np.testing.assert_allclose(torch_output, turbine, rtol, atol)
#
        rtol = 4e-2
        atol = 4e-1

        np.testing.assert_allclose(torch_output, turbine, rtol, atol)
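
For reference, the dummy inputs in the MMDiT test above follow the SD3 MMDiT interface: the latent sample has 16 channels at 1/8 of the image resolution, the joint text embedding stacks the CLIP and T5 token sequences (hence max_length * 2 tokens of width 4096), and the pooled projection is the 2048-dim concatenation of the pooled CLIP-L and CLIP-G embeddings. A minimal sketch that packages those shapes into a helper (the helper name and defaults are illustrative, not part of the test file):

import torch

def make_mmdit_sample_inputs(batch_size, height, width, max_length, dtype=torch.float16):
    # Latent-space noise: SD3's VAE uses 16 latent channels at 1/8 spatial resolution.
    hidden_states = torch.randn(batch_size, 16, height // 8, width // 8, dtype=dtype)
    # Joint text embedding: CLIP and T5 sequences concatenated along the token axis.
    encoder_hidden_states = torch.randn(batch_size, max_length * 2, 4096, dtype=dtype)
    # Pooled CLIP-L + CLIP-G projections.
    pooled_projections = torch.randn(batch_size, 2048, dtype=dtype)
    # Two timestep entries, mirroring the shape used by the test above.
    timestep = torch.tensor([0, 0], dtype=dtype)
    return hidden_states, encoder_hidden_states, pooled_projections, timestep

# Example: shapes for a 512x512 image with max_length=64.
inputs = make_mmdit_sample_inputs(batch_size=1, height=512, width=512, max_length=64)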

    def test03_ExportVaeModelDecode(self):
        if arguments["device"] in ["vulkan", "cuda"]:
            self.skipTest("not testing vulkan or cuda")
        arguments["external_weight_path"] = (
            self.safe_model_name
            + "_"
            + arguments["precision"]
            + "_vae_decode."
            + arguments["external_weights"]
        )
        sd3_vae.export_vae_model(
            vae_model=self.vae_model,
            # This is a public model, so no auth required
            hf_model_name=arguments["hf_model_name"],
            batch_size=arguments["batch_size"],
            height=arguments["height"],
            width=arguments["width"],
            precision=arguments["precision"],
            compile_to="vmfb",
            external_weights=arguments["external_weights"],
            external_weight_path=arguments["external_weight_path"],
            device=arguments["device"],
            target_triple=arguments["iree_target_triple"],
            ireec_flags=arguments["ireec_flags"],
            variant="decode",
            decomp_attn=arguments["decomp_attn"],
            attn_spec=arguments["attn_spec"],
            exit_on_vmfb=True,
        )
arguments["vmfb_path"] = (
self.safe_model_name
+ "_"
+ str(arguments["height"])
+ "x"
+ str(arguments["width"])
+ "_"
+ arguments["precision"]
+ "_vae_decode_"
+ arguments["device"]
+ ".vmfb"
)
        example_input = torch.ones(
            arguments["batch_size"],
            16,
            arguments["height"] // 8,
            arguments["width"] // 8,
            dtype=torch.float32,
        )
        example_input_torch = example_input
        if arguments["precision"] == "fp16":
            example_input = example_input.half()
        turbine = sd3_vae_runner.run_vae(
            arguments["rt_device"],
            example_input,
            arguments["vmfb_path"],
            arguments["hf_model_name"],
            arguments["external_weight_path"],
        )
        torch_output = sd3_vae_runner.run_torch_vae(
            arguments["hf_model_name"],
            (
                "madebyollin/sdxl-vae-fp16-fix"
                if arguments["precision"] == "fp16"
                else ""
            ),
            "decode",
            example_input_torch,
        )
        # if arguments["benchmark"] or arguments["tracy_profile"]:
        #     run_benchmark(
        #         "vae_decode",
        #         arguments["vmfb_path"],
        #         arguments["external_weight_path"],
        #         arguments["rt_device"],
        #         height=arguments["height"],
        #         width=arguments["width"],
        #         precision=arguments["precision"],
        #         tracy_profile=arguments["tracy_profile"],
        #     )
        rtol = 4e-2
        atol = 4e-1

        np.testing.assert_allclose(torch_output, turbine, rtol, atol)
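
The torch reference in the VAE test above ultimately decodes the same all-ones latent through the checkpoint's AutoencoderKL. A plausible standalone sketch of that reference step, assuming the standard diffusers API (the "vae" subfolder layout and the absence of latent rescaling are assumptions; run_torch_vae's exact behavior is not shown in this diff):

import torch
from diffusers import AutoencoderKL

def reference_vae_decode(hf_model_name, latents):
    # Load the VAE that ships with the checkpoint (subfolder layout assumed).
    vae = AutoencoderKL.from_pretrained(hf_model_name, subfolder="vae")
    # A real SD3 pipeline rescales latents with the VAE's scaling_factor and shift_factor
    # before decoding; this sketch, like the test's example_input, feeds the tensor
    # straight through so both backends see identical raw values.
    with torch.no_grad():
        image = vae.decode(latents).sample
    return image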

# def test04_ExportVaeModelEncode(self):
# if arguments["device"] in ["cpu", "vulkan", "cuda", "rocm"]:
# self.skipTest(
@@ -454,7 +431,7 @@ def test01_ExportPromptEncoder(self):
# rtol = 4e-2
# atol = 4e-2
# np.testing.assert_allclose(torch_output, turbine, rtol, atol)
#

# def test05_t2i_generate_images(self):
# if arguments["device"] in ["vulkan", "cuda", "rocm"]:
# self.skipTest(
