Skip to content

Commit

Permalink
Fix some mismatches in VAE model comparisons.
Browse files Browse the repository at this point in the history
Co-authored-by: jinchen64 <[email protected]>
  • Loading branch information
monorimet and jinchen64 committed Feb 14, 2024
1 parent 2beae3f commit 4e2801f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 32 deletions.
2 changes: 1 addition & 1 deletion models/turbine_models/custom_models/sdxl_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def decode_inp(self, inp):
inp = inp / 0.13025
x = self.vae.decode(inp, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
return x
return x.round()

def encode_inp(self, inp):
latents = self.vae.encode(inp).latent_dist.sample()
Expand Down
53 changes: 22 additions & 31 deletions models/turbine_models/custom_models/sdxl_inference/vae_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
"--batch_size", type=int, default=1, help="Batch size for inference"
)
parser.add_argument(
"--height", type=int, default=512, help="Height of Stable Diffusion"
"--height", type=int, default=1024, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument("--width", type=int, default=1024, help="Width of Stable Diffusion")
parser.add_argument("--variant", type=str, default="decode")


Expand All @@ -58,51 +58,44 @@ class VaeModel(torch.nn.Module):
def __init__(
self,
hf_model_name,
base_vae=False,
custom_vae="",
low_cpu_mem_usage=False,
hf_auth_token="",
):
super().__init__()
self.vae = None
if custom_vae == "":
if custom_vae in ["", None]:
self.vae = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
hf_auth_token=hf_auth_token,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
hf_auth_token=hf_auth_token,
)
try:
# custom HF repo with no vae subfolder
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
)
except:
# some larger repo with vae subfolder
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
)
else:
# custom vae as a HF state dict
self.vae = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
hf_auth_token=hf_auth_token,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae

def decode_inp(self, input):
with torch.no_grad():
if not self.base_vae:
input = 1 / 0.18215 * input
x = self.vae.decode(input, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
if self.base_vae:
return x
x = x * 255.0

def decode_inp(self, inp):
inp = inp / 0.13025
x = self.vae.decode(inp, return_dict=False)[0]
x = (x / 2 + 0.5).clamp(0, 1)
return x.round()

def encode_inp(self, inp):
latents = self.vae.encode(inp).latent_dist.sample()
return 0.18215 * latents
return 0.13025 * latents

vae_model = VaeModel(
hf_model_name,
Expand Down Expand Up @@ -144,9 +137,7 @@ def encode_inp(self, inp):
print("generating torch output: ")
from turbine_models.custom_models.sd_inference import utils

torch_output = run_torch_vae(
args.hf_model_name, args.hf_auth_token, args.variant, example_input
)
torch_output = run_torch_vae(args.hf_model_name, args.variant, example_input)
print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
err = utils.largest_error(torch_output, turbine_results)
print("Largest Error: ", err)
Expand Down

0 comments on commit 4e2801f

Please sign in to comment.