Consolidates SD pipelines and adds support for sharktank unet. #766
Conversation
):
    fxb = FxProgramsBuilder(vae_model)

    # @fxb.export_program(args=(encode_args,))
Remove the commented-out code.
        base_params=None,
    )
else:
    ds = None  # get_punet_dataset(results["config.json"], results["params.safetensors"], base_params=None)
Why is this commented out?
Haven't implemented it yet; I just ran i8 straight through and left placeholders for the other configurations. I'll replace these with TODOs or finish them after higher-priority items.
output = export(
    unet_model,
    kwargs=example_forward_args_dict,
    module_name="compiled_unet",
Can this be `compiled_punet`, or no?
It can.
    quant_params_struct=None,
    base_params=None,
):
    from sharktank.models.punet.tools.import_brevitas_dataset import (
Please don't do it this way. Just import the `main` function and pass it CL args rather than using the private internals.
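A minimal sketch of what that could look like, assuming the importer module exposes a `main(argv)`-style entrypoint (the `--output-irpa-file` flag name is an assumption; the other flags mirror the call site further down in this diff):

```python
# Sketch only: drives the importer through its CLI-style entrypoint instead of
# reaching into private internals. main(argv) and --output-irpa-file are
# assumptions based on the suggestion above.
from sharktank.models.punet.tools import import_brevitas_dataset


def build_punet_irpa(config_json_path, params_path, quant_params_path, output_path):
    import_brevitas_dataset.main(
        [
            f"--config-json={config_json_path}",
            f"--params={params_path}",
            f"--quant-params={quant_params_path}",
            f"--output-irpa-file={output_path}",
        ]
    )
    return output_path
```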
    # TODO: Post-process to introduce fused cross-layer connections.

    ds.save(output_path, io_report_callback=print)
There is a bug in the Dataset handling code where, if you construct a Dataset on the fly and then use it without loading it from disk, it does not annotate the imported tensors, resulting in them being inlined rather than referenced. I need to fix that bug, but the workflow here shouldn't be doing this anyway.
If you follow the advice in the comment above and use the main() entrypoint to generate the IRPA on disk, then do Dataset.load(irpa_path), that will work.
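A short sketch of that flow (the `sharktank.types` import path is an assumption):

```python
# Sketch: generate the IRPA on disk first (e.g. via the importer's main()
# entrypoint as sketched above), then load it back so the imported tensors are
# referenced rather than inlined.
from sharktank.types import Dataset  # import path is an assumption

irpa_path = "punet_dataset_i8.irpa"
ds = Dataset.load(irpa_path)
```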
Needs iree-org/iree-turbine#40 to land for some metadata pass fixes on CLIP.
Left a number of comments that would be good to address here or in a follow-up.
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.model_max_length,
    )
    print("The following text was removed due to truncation:", removed_text)
Did you mean to use `print` like this? Generally not great to spew to stdout from a library. Maybe use `warnings.warn(...)` instead?
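A sketch of the suggested change, reusing the `removed_text` variable from the snippet above:

```python
import warnings

# Warn instead of printing, so library consumers can filter or silence it.
if removed_text:
    warnings.warn(
        f"The following text was removed due to truncation: {removed_text}"
    )
```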
        torch.empty(1, dtype=dtype),
    ]
    decomp_list = []
    if decomp_attn == True:
Why not just `if decomp_attn:`?
@@ -331,7 +340,8 @@ def get_wmma_spec_path(target_chip, save_dir, masked_attention=False):
    else:
        return None
    attn_spec = urlopen(url).read().decode("utf-8")
I suppose this isn't the only place where we are fetching network resources, but I wish we weren't doing this... some day.
    repo_id = "amd-shark/sdxl-quant-models"
    subfolder = "unet/int8"
    revision = "82e06d6ea22ac78102a9aded69e8ddfb9fa4ae37"
elif precision in ["fp16", "fp32"]:
FYI - the sdxl-quant-models repo does have an fp16 variant IRPA file. I use that for various comparisons, etc.
https://huggingface.co/amd-shark/sdxl-quant-models/tree/main/unet/fp16/export
I guess we can decide whether we want to do the full model download/quantization here or just fetch the IRPA file from the export subdir. If we did the latter, you could drop get_punet_dataset(), which actually builds it, and just use the one we publish (roughly as sketched below). In that case, you can eliminate this branch and just change the subfolder and filename or something.
Fine to check in like this since it is working, but can simplify by leaning more on published artifacts.
Now that I'm thinking of it, though, we're going to need a procedure for ML-Perf to rely completely on generated artifacts. Currently, that is being done by providing scripts to regenerate, but then we probably also need to support flags here and elsewhere to side-load.
In any case, let's land and discuss/iterate.
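If the published-artifact route is taken, fetching the fp16 IRPA directly could look roughly like this (the filename inside the export subfolder is a placeholder, not confirmed):

```python
from huggingface_hub import hf_hub_download

# Sketch: pull the published IRPA instead of rebuilding it locally.
# "punet_fp16.irpa" is a placeholder; check the repo for the actual filename.
irpa_path = hf_hub_download(
    repo_id="amd-shark/sdxl-quant-models",
    subfolder="unet/fp16/export",
    filename="punet_fp16.irpa",
)
```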
}
if precision == "i8":
    results["quant_params.json"] = download("quant_params.json")
    output_path = external_weight_path.split("unet")[0] + "punet_dataset_i8.irpa"
What would happen if this were in the directory `C:\Users\unet\Downloads`? It would be better to use pathlib for manipulation of just the stem or something.
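A rough pathlib-based alternative; note it writes the IRPA next to the weight file rather than reproducing the string-splitting behavior exactly:

```python
from pathlib import Path

# Sketch: derive the output location from the weight file's directory instead of
# splitting the path string on "unet", which misbehaves when "unet" appears in a
# parent directory (e.g. C:\Users\unet\Downloads\...).
output_path = str(Path(external_weight_path).parent / "punet_dataset_i8.irpa")
```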
def get_punet_dataset(
    config_json_path,
    params_path,
    output_path="./punet_dataset_i8.irpa",
Remove the default. Should always be provided, I think.
    output_path="./punet_dataset_i8.irpa",
    quant_params_path=None,
    quant_params_struct=None,
    base_params=None,
Seems unused; remove (this was a hack/prototype thing that didn't work out anyway).
    params_path,
    output_path="./punet_dataset_i8.irpa",
    quant_params_path=None,
    quant_params_struct=None,
Not used; remove?
        [
            f"--config-json={config_json_path}",
            f"--params={params_path}",
            f"--quant-params={quant_params_path}",
I don't think this will work for the fp16 version like this (the literal value "None" will try to load a file named "None").
There is an importer right next to this for a stock Hugging Face dataset that was intended to be used for this case. We should probably switch here, but judging from the above, I'm guessing that fp16 punet has not been tested yet. We can flesh that out when it is.
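One way to avoid passing the literal string "None" is to build the argument list conditionally, e.g.:

```python
# Sketch: only forward --quant-params when a path was actually provided, so
# fp16 runs (with no quant params file) don't try to open a file named "None".
argv = [
    f"--config-json={config_json_path}",
    f"--params={params_path}",
]
if quant_params_path is not None:
    argv.append(f"--quant-params={quant_params_path}")
```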
elif (not decomp_attn) and "gfx9" in target:
    attn_spec = "mfma"
elif (not decomp_attn) and "gfx11" in target:
    attn_spec = "wmma"
This kind of thing makes me sad. Nothing to do about it now.