Skip to content

Commit

Permalink
Add precision to unet, vae and guidance scale as input to unet
Browse files Browse the repository at this point in the history
Formatting

Update sd_test.py

Update sd_test.py

Tweaks to VaeModel forward and instantiation.

Fix guidance scale arg in tests.

fix typo in unet_runner

formatting

Update VAE baseline for tests.

Update VAE baseline for tests.

Tweaks to vae model def and args.

Fixes to encoder hidden states, remove cpu vae path
  • Loading branch information
monorimet committed Jan 5, 2024
1 parent 18e8a41 commit a6471c6
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 55 deletions.
37 changes: 26 additions & 11 deletions python/turbine_models/custom_models/sd_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@
"--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument(
"--precision", type=str, default="fp16", help="Precision of Stable Diffusion"
)
parser.add_argument(
"--max_length", type=int, default=77, help="Sequence Length of Stable Diffusion"
)
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_path", type=str, default="")
parser.add_argument(
Expand Down Expand Up @@ -63,15 +69,14 @@ def __init__(self, hf_model_name, hf_auth_token):
subfolder="unet",
token=hf_auth_token,
)
self.guidance_scale = 7.5

def forward(self, sample, timestep, encoder_hidden_states):
def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
samples = torch.cat([sample] * 2)
unet_out = self.unet.forward(
samples, timestep, encoder_hidden_states, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
Expand All @@ -83,6 +88,8 @@ def export_unet_model(
batch_size,
height,
width,
precision="fp32",
max_length=77,
hf_auth_token=None,
compile_to="torch",
external_weights=None,
Expand All @@ -92,13 +99,16 @@ def export_unet_model(
max_alloc=None,
):
mapper = {}
dtype = torch.float16 if precision == "fp16" else torch.float32
unet_model = unet_model.to(dtype)
utils.save_external_weights(
mapper, unet_model, external_weights, external_weight_path
)

encoder_hidden_states_sizes = (2, 77, 768)
if hf_model_name == "stabilityai/stable-diffusion-2-1-base":
encoder_hidden_states_sizes = (2, 77, 1024)
encoder_hidden_states_sizes = (
unet_model.unet.config.layers_per_block,
max_length,
unet_model.unet.config.cross_attention_dim,
)

sample = (batch_size, unet_model.unet.config.in_channels, height // 8, width // 8)

Expand All @@ -112,13 +122,16 @@ class CompiledUnet(CompiledModule):

def main(
self,
sample=AbstractTensor(*sample, dtype=torch.float32),
timestep=AbstractTensor(1, dtype=torch.float32),
sample=AbstractTensor(*sample, dtype=dtype),
timestep=AbstractTensor(1, dtype=dtype),
encoder_hidden_states=AbstractTensor(
*encoder_hidden_states_sizes, dtype=torch.float32
*encoder_hidden_states_sizes, dtype=dtype
),
guidance_scale=AbstractTensor(1, dtype=dtype),
):
return jittable(unet_model.forward)(sample, timestep, encoder_hidden_states)
return jittable(unet_model.forward)(
sample, timestep, encoder_hidden_states, guidance_scale
)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = CompiledUnet(context=Context(), import_to=import_to)
Expand All @@ -143,6 +156,8 @@ def main(
args.batch_size,
args.height,
args.width,
args.precision,
args.max_length,
args.hf_auth_token,
args.compile_to,
args.external_weights,
Expand Down
20 changes: 16 additions & 4 deletions python/turbine_models/custom_models/sd_inference/unet_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def run_unet(
sample,
timestep,
encoder_hidden_states,
guidance_scale,
vmfb_path,
hf_model_name,
hf_auth_token,
Expand All @@ -63,13 +64,19 @@ def run_unet(
ireert.asdevicearray(runner.config.device, sample),
ireert.asdevicearray(runner.config.device, timestep),
ireert.asdevicearray(runner.config.device, encoder_hidden_states),
ireert.asdevicearray(runner.config.device, guidance_scale),
]
results = runner.ctx.modules.compiled_unet["main"](*inputs)
return results


def run_torch_unet(
hf_model_name, hf_auth_token, sample, timestep, encoder_hidden_states
hf_model_name,
hf_auth_token,
sample,
timestep,
encoder_hidden_states,
guidance_scale,
):
from diffusers import UNet2DConditionModel

Expand All @@ -83,13 +90,13 @@ def __init__(self, hf_model_name, hf_auth_token):
)
self.guidance_scale = 7.5

def forward(self, sample, timestep, encoder_hidden_states):
def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
samples = torch.cat([sample] * 2)
unet_out = self.unet.forward(
samples, timestep, encoder_hidden_states, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
Expand All @@ -98,7 +105,9 @@ def forward(self, sample, timestep, encoder_hidden_states):
hf_model_name,
hf_auth_token,
)
results = unet_model.forward(sample, timestep, encoder_hidden_states)
results = unet_model.forward(
sample, timestep, encoder_hidden_states, guidance_scale
)
np_torch_output = results.detach().cpu().numpy()
return np_torch_output

Expand All @@ -109,6 +118,7 @@ def forward(self, sample, timestep, encoder_hidden_states):
args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32
)
timestep = torch.zeros(1, dtype=torch.float32)
guidance_scale = torch.Tensor([7.5], dtype=torch.float32)
if args.hf_model_name == "CompVis/stable-diffusion-v1-4":
encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32)
elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base":
Expand All @@ -119,6 +129,7 @@ def forward(self, sample, timestep, encoder_hidden_states):
sample,
timestep,
encoder_hidden_states,
guidance_scale,
args.vmfb_path,
args.hf_model_name,
args.hf_auth_token,
Expand All @@ -141,6 +152,7 @@ def forward(self, sample, timestep, encoder_hidden_states):
sample,
timestep,
encoder_hidden_states,
guidance_scale,
)
print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
err = utils.largest_error(torch_output, turbine_output)
Expand Down
60 changes: 42 additions & 18 deletions python/turbine_models/custom_models/sd_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,9 @@
import torch
import torch._dynamo as dynamo
from diffusers import AutoencoderKL

import safetensors
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
"--hf_auth_token", type=str, help="The Hugging Face auth token, required"
)
parser.add_argument(
"--hf_model_name",
type=str,
Expand All @@ -36,6 +31,9 @@
"--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument(
"--precision", type=str, default="fp32", help="Precision of Stable Diffusion"
)
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_path", type=str, default="")
parser.add_argument(
Expand All @@ -57,31 +55,56 @@


class VaeModel(torch.nn.Module):
def __init__(self, hf_model_name, hf_auth_token):
def __init__(
self,
hf_model_name,
custom_vae="",
):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
token=hf_auth_token,
)
self.model = None
if custom_vae == "":
self.model = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
)
elif not isinstance(custom_vae, dict):
try:
# custom HF repo with no vae subfolder
self.model = AutoencoderKL.from_pretrained(
custom_vae,
)
except:
# some larger repo with vae subfolder
self.model = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
)
else:
# custom vae as a HF state dict
self.model = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
)
self.model.load_state_dict(custom_vae)

def decode_inp(self, inp):
with torch.no_grad():
x = self.vae.decode(inp, return_dict=False)[0]
return x
if not self.base_vae:
inp = 1 / 0.18215 * inp
x = self.vae.decode(inp, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
return x

def encode_inp(self, inp):
latents = self.vae.encode(inp).latent_dist.sample()
return 0.18215 * latents


def export_vae_model(
vae_model,
hf_model_name,
batch_size,
height,
width,
hf_auth_token=None,
precision,
compile_to="torch",
external_weights=None,
external_weight_path=None,
Expand All @@ -91,6 +114,8 @@ def export_vae_model(
variant="decode",
):
mapper = {}
dtype = torch.float16 if precision == "fp16" else torch.float32
vae_model = vae_model.to(dtype)
utils.save_external_weights(
mapper, vae_model, external_weights, external_weight_path
)
Expand Down Expand Up @@ -123,15 +148,14 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
args = parser.parse_args()
vae_model = VaeModel(
args.hf_model_name,
args.hf_auth_token,
)
mod_str = export_vae_model(
vae_model,
args.hf_model_name,
args.batch_size,
args.height,
args.width,
args.hf_auth_token,
args.precision,
args.compile_to,
args.external_weights,
args.external_weight_path,
Expand Down
59 changes: 41 additions & 18 deletions python/turbine_models/custom_models/sd_inference/vae_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@
help="HF model name",
default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument(
"--hf_auth_token",
type=str,
help="The Hugging face auth token, required for some models",
)
parser.add_argument(
"--device",
type=str,
Expand All @@ -48,9 +43,7 @@
parser.add_argument("--variant", type=str, default="decode")


def run_vae(
device, example_input, vmfb_path, hf_model_name, hf_auth_token, external_weight_path
):
def run_vae(device, example_input, vmfb_path, hf_model_name, external_weight_path):
runner = vmfbRunner(device, vmfb_path, external_weight_path)

inputs = [ireert.asdevicearray(runner.config.device, example_input)]
Expand All @@ -62,26 +55,57 @@ def run_torch_vae(hf_model_name, hf_auth_token, variant, example_input):
from diffusers import AutoencoderKL

class VaeModel(torch.nn.Module):
def __init__(self, hf_model_name, hf_auth_token):
def __init__(
self,
hf_model_name,
base_vae=False,
custom_vae="",
low_cpu_mem_usage=False,
hf_auth_token="",
):
super().__init__()
self.vae = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
token=hf_auth_token,
)
self.vae = None
if custom_vae == "":
self.vae = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
hf_auth_token=hf_auth_token,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
hf_auth_token=hf_auth_token,
)
else:
self.vae = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
hf_auth_token=hf_auth_token,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae

def decode_inp(self, inp):
with torch.no_grad():
x = self.vae.decode(inp, return_dict=False)[0]
return x
if not self.base_vae:
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
if self.base_vae:
return x
x = x * 255.0
return x.round()

def encode_inp(self, inp):
latents = self.vae.encode(inp).latent_dist.sample()
return 0.18215 * latents

vae_model = VaeModel(
hf_model_name,
hf_auth_token,
)

if variant == "decode":
Expand All @@ -108,7 +132,6 @@ def encode_inp(self, inp):
example_input,
args.vmfb_path,
args.hf_model_name,
args.hf_auth_token,
args.external_weight_path,
)
print(
Expand Down
Loading

0 comments on commit a6471c6

Please sign in to comment.