Unifies SD pipeline APIs, adds sd3 support, punet integration #706
Conversation
Exporting models used "devices" instead of "driver", which caused compilation errors when a rocm://<> device was specified. Changed to fix the bug.
Missed a couple of remaining "devices"-instead-of-"driver" changes
New flag batch_prompt_input determines whether the prompt encoder uses the batch_size flag to concatenate its output, or batches the input shapes instead.
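A minimal sketch of the two modes the batch_prompt_input flag selects, using assumed shapes and stand-in names (the real encoder and its signature are not shown in this PR excerpt): either the encoder runs once on a single prompt and the output is concatenated batch_size times, or the input ids are batched up front so the encoder sees a batched input shape.

```python
import numpy as np

def encode(input_ids: np.ndarray) -> np.ndarray:
    # Stand-in for the prompt encoder: (batch, seq) -> (batch, seq, dim).
    # A real encoder is a learned model; this just fans out a feature dim.
    return np.repeat(input_ids[..., None], 4, axis=-1).astype(np.float32)

def encode_prompt(input_ids: np.ndarray, batch_size: int,
                  batch_prompt_input: bool) -> np.ndarray:
    if batch_prompt_input:
        # Batch the input shapes: the encoder runs once on a batched input.
        batched = np.concatenate([input_ids] * batch_size, axis=0)
        return encode(batched)
    # Otherwise concatenate the single-prompt output batch_size times.
    out = encode(input_ids)
    return np.concatenate([out] * batch_size, axis=0)

ids = np.ones((1, 64), dtype=np.int64)
a = encode_prompt(ids, 2, True)
b = encode_prompt(ids, 2, False)
```

Both paths yield the same batched embedding here; the flag only changes where the batching happens (input side vs. output side), which matters for the static shapes the exported module is compiled with.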
Fixed typo for sdxl_prompt_encoder arg
The reason for updating the revision hash is this PR by Stella in sharktank: nod-ai/SHARK-Platform#93. Because we are using sharktank TOM, we need to update here too so that it gives sharktank the expected quant_params.json. --------- Signed-off-by: saienduri <[email protected]>
…l_map Signed-off-by: aviator19941 <[email protected]>
There's some commented code in multiple files, did you want to keep those there?
if __name__ == "__main__":
    from turbine_models.custom_models.sd_inference.sd_cmd_opts import args
Should we use sd_cmd_opts here? The default height and width are 512.
I guess it's ok; we can just pass height and width as 1024 through the CL args for SDXL.
I kept this as a default since it is the only size supported by all the pipeline's supported models. Since the pipelines are using the same class/API now, they should either use the same args or we need a layer between command line arguments and pipeline API that sets some defaults based on a few core arguments like hf model names. If we want to do the latter it would be best as a follow-up.
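The "layer between command line arguments and pipeline API" proposed above could look roughly like the following sketch. The model names, default sizes, and function name here are illustrative assumptions, not the PR's actual implementation: per-model defaults are derived from the hf model name, and explicitly supplied CLI values win.

```python
# Hypothetical defaults layer: map hf model names to pipeline defaults,
# then overlay any values the user passed explicitly on the command line.
DEFAULTS_BY_MODEL = {
    "stabilityai/stable-diffusion-xl-base-1.0": {"height": 1024, "width": 1024},
    "stabilityai/stable-diffusion-2-1": {"height": 512, "width": 512},
}

def resolve_args(hf_model_name: str, cli_args: dict) -> dict:
    # Start from model-derived defaults (falling back to 512x512).
    resolved = dict(DEFAULTS_BY_MODEL.get(
        hf_model_name, {"height": 512, "width": 512}))
    # Explicit CLI values (non-None) override the defaults.
    resolved.update({k: v for k, v in cli_args.items() if v is not None})
    return resolved
```

With this shape, SDXL users get 1024x1024 without extra flags, while a `--height 768` style override still takes precedence.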
I plan on doing a scrub of things like print statements and commented code -- there are also a few files/functions to deprecate or remove, so that can come as a follow-up to this patch.
@@ -502,98 +400,103 @@ def test04_ExportVaeModelEncode(self):
        np.testing.assert_allclose(torch_output, turbine, rtol, atol)

    def test05_t2i_generate_images(self):
Seems like this test is failing in CI with this error:
ValueError: Expected input 4 to be of shape (2, 6) for compiled_unet['run_forward'], got (1, 6).
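A likely cause of this shape mismatch, sketched under assumptions (the input name and the guidance logic are guesses, not confirmed by the PR): with classifier-free guidance the compiled unet is exported to expect inputs doubled along the batch dimension (negative + positive conditioning), so a time-ids-style input of shape (1, 6) must be repeated to (2, 6) before the call.

```python
import numpy as np

# Hypothetical pre-call fix-up: double the batch dim when classifier-free
# guidance is enabled, so shapes match what the module was compiled for.
add_time_ids = np.zeros((1, 6), dtype=np.float32)

do_classifier_free_guidance = True
if do_classifier_free_guidance:
    add_time_ids = np.concatenate([add_time_ids] * 2, axis=0)

# Now matches the expected (2, 6) input of compiled_unet['run_forward'].
```

If the test only runs a single unbatched prompt, the alternative fix is to export the unet with a batch size of 1 (guidance folded differently), but doubling the inputs is the smaller change.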