Skip to content

Commit

Permalink
Tweaks after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jan 5, 2024
1 parent a6471c6 commit 155ae04
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
4 changes: 2 additions & 2 deletions python/turbine_models/custom_models/sd_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ class CompiledUnet(CompiledModule):
def main(
self,
sample=AbstractTensor(*sample, dtype=dtype),
timestep=AbstractTensor(1, dtype=dtype),
timestep=AbstractTensor([1], dtype=dtype),
encoder_hidden_states=AbstractTensor(
*encoder_hidden_states_sizes, dtype=dtype
),
guidance_scale=AbstractTensor(1, dtype=dtype),
guidance_scale=AbstractTensor([1], dtype=dtype),
):
return jittable(unet_model.forward)(
sample, timestep, encoder_hidden_states, guidance_scale
Expand Down
15 changes: 8 additions & 7 deletions python/turbine_models/custom_models/sd_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,31 +61,32 @@ def __init__(
custom_vae="",
):
super().__init__()
self.model = None
self.vae = None
self.base_vae = False
if custom_vae == "":
self.model = AutoencoderKL.from_pretrained(
self.vae = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
)
elif not isinstance(custom_vae, dict):
try:
# custom HF repo with no vae subfolder
self.model = AutoencoderKL.from_pretrained(
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
)
except:
# some larger repo with vae subfolder
self.model = AutoencoderKL.from_pretrained(
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
)
else:
# custom vae as a HF state dict
self.model = AutoencoderKL.from_pretrained(
self.vae = AutoencoderKL.from_pretrained(
hf_model_name,
subfolder="vae",
)
self.model.load_state_dict(custom_vae)
self.vae.load_state_dict(custom_vae)

def decode_inp(self, inp):
if not self.base_vae:
Expand Down Expand Up @@ -127,7 +128,7 @@ def export_vae_model(
class CompiledVae(CompiledModule):
params = export_parameters(vae_model)

def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
def main(self, inp=AbstractTensor(*sample, dtype=dtype)):
if variant == "decode":
return jittable(vae_model.decode_inp)(inp)
elif variant == "encode":
Expand Down

0 comments on commit 155ae04

Please sign in to comment.