Skip to content

Commit

Permalink
Tweaks after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jan 5, 2024
1 parent a6471c6 commit 4498486
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions python/turbine_models/custom_models/sd_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,31 +61,32 @@ def __init__(
custom_vae="",
):
super().__init__()
self.model = None
self.vae = None
self.base_vae = False
if custom_vae == "":
self.model = AutoencoderKL.from_pretrained(
self.vae = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
)
elif not isinstance(custom_vae, dict):
try:
# custom HF repo with no vae subfolder
self.model = AutoencoderKL.from_pretrained(
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
)
except:
# some larger repo with vae subfolder
self.model = AutoencoderKL.from_pretrained(
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
)
else:
# custom vae as a HF state dict
self.model = AutoencoderKL.from_pretrained(
self.vae = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
)
self.model.load_state_dict(custom_vae)
self.vae.load_state_dict(custom_vae)

def decode_inp(self, inp):
if not self.base_vae:
Expand Down Expand Up @@ -127,7 +128,7 @@ def export_vae_model(
class CompiledVae(CompiledModule):
params = export_parameters(vae_model)

def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
def main(self, inp=AbstractTensor(*sample, dtype=dtype)):
if variant == "decode":
return jittable(vae_model.decode_inp)(inp)
elif variant == "encode":
Expand Down

0 comments on commit 4498486

Please sign in to comment.