Consolidates SD pipelines and adds support for sharktank unet. #766

Merged
merged 12 commits into ean-unify-sd from ean-unify-sd-staging on Jul 10, 2024

Conversation

monorimet
Contributor

No description provided.

):
fxb = FxProgramsBuilder(vae_model)

# @fxb.export_program(args=(encode_args,))
Contributor

remove commented code

base_params=None,
)
else:
ds = None # get_punet_dataset(results["config.json"], results["params.safetensors"], base_params=None)
Contributor

Why is this commented out?

Contributor Author

Haven't implemented it yet. Just ran i8 straight through and left some spots for the other configurations. Will replace with TODOs or just finish it after higher-priority items.

output = export(
unet_model,
kwargs=example_forward_args_dict,
module_name="compiled_unet",
Contributor

can this be compiled_punet or no?

Contributor Author

It can.

quant_params_struct=None,
base_params=None,
):
from sharktank.models.punet.tools.import_brevitas_dataset import (
Contributor

Please don't do it this way. Just import the main function and pass it CL args vs using the private internals.
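
A minimal sketch of what's being suggested, assuming import_brevitas_dataset exposes a main() that accepts an argv-style list (the flag names mirror those used elsewhere in this PR; --output-irpa-file is an assumed flag name):

```python
from sharktank.models.punet.tools import import_brevitas_dataset

# Drive the importer through its CLI entrypoint rather than its private
# internals, so only the tool's public interface is relied on.
import_brevitas_dataset.main(
    [
        f"--config-json={config_json_path}",
        f"--params={params_path}",
        f"--quant-params={quant_params_path}",
        f"--output-irpa-file={output_path}",  # assumed flag name
    ]
)
```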


# TODO: Post-process to introduce fused cross-layer connections.

ds.save(output_path, io_report_callback=print)
Contributor

There is a bug in the Dataset handling code where if you construct a Dataset on the fly and then use that without loading it from disk, it does not annotate the imported tensors, resulting in them being inlined vs referenced. I need to fix that bug, but the workflow in this case shouldn't be doing this anyway.

If you follow the advice in the comment above and use the main() entrypoint to generate the IRPA on disk, then do Dataset.load(irpa_path), that will work.
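
A sketch of that flow, assuming Dataset lives at sharktank.types (the import path is an assumption) and a hypothetical IRPA location:

```python
from sharktank.types import Dataset  # assumed import path

# Generate the IRPA on disk first via the importer's main() entrypoint
# (see above), then load it back; tensors loaded from disk are annotated
# and emitted as external references instead of being inlined.
irpa_path = "punet_dataset_i8.irpa"  # hypothetical path
ds = Dataset.load(irpa_path)
```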

monorimet changed the title from "sharktank UNet model support." to "Consolidates unet pipelines and adds support for sharktank unet." on Jul 10, 2024
monorimet changed the title from "Consolidates unet pipelines and adds support for sharktank unet." to "Consolidates SD pipelines and adds support for sharktank unet." on Jul 10, 2024
monorimet
Contributor Author

Needs iree-org/iree-turbine#40 to land for some metadata pass fixes on clip.

stellaraccident (Contributor) left a comment

Left a number of comments that would be good to address here or in a followup.

no_boseos_middle=no_boseos_middle,
chunk_length=pipe.model_max_length,
)
print("The following text was removed due to truncation:", removed_text)
Contributor

Did you mean to use print like this? Generally not great to spew to stdout from a library. Maybe use warnings.warn(...) instead?
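
e.g., a minimal sketch with the stdlib warnings module, keeping the message from the snippet above:

```python
import warnings

# Surface the truncation to the caller without spewing to stdout.
warnings.warn(f"The following text was removed due to truncation: {removed_text}")
```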

torch.empty(1, dtype=dtype),
]
decomp_list = []
if decomp_attn == True:
Contributor

Why not just if decomp_attn:?
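
That is, the idiomatic truthiness check:

```python
decomp_list = []
if decomp_attn:  # for a bool flag this behaves the same as "== True"
    ...
```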

@@ -331,7 +340,8 @@ def get_wmma_spec_path(target_chip, save_dir, masked_attention=False):
else:
return None
attn_spec = urlopen(url).read().decode("utf-8")
Contributor

I suppose this isn't the only place where we are fetching network resources, but wish we weren't doing this... some day.

repo_id = "amd-shark/sdxl-quant-models"
subfolder = "unet/int8"
revision = "82e06d6ea22ac78102a9aded69e8ddfb9fa4ae37"
elif precision in ["fp16", "fp32"]:
Contributor

FYI - the sdxl-quant-models do have an fp16 variant IRPA file. I use that for various comparisons, etc.

https://huggingface.co/amd-shark/sdxl-quant-models/tree/main/unet/fp16/export

Contributor

I guess we can decide whether we want to do the full model download/quantization here or just fetch the IRPA file from the export subdir. If we did the latter, you could drop get_punet_dataset() that is actually building it and just use the one we publish. If doing that, you can eliminate this branch and just change the subfolder and filename or something.

Fine to check in like this since it is working, but can simplify by leaning more on published artifacts.

Now that I'm thinking of it, though, we're going to need a procedure for ML-Perf to rely completely on generated artifacts. Currently, that is being done by providing scripts to regenerate, but then we probably also need to support flags here and elsewhere to side-load.

In any case, let's land and discuss/iterate.
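
For reference, a sketch of the fetch-the-published-artifact option using huggingface_hub; the exact IRPA filename in the export subdir is a placeholder, not confirmed here:

```python
from huggingface_hub import hf_hub_download

# Fetch the published IRPA artifact instead of rebuilding it locally.
irpa_path = hf_hub_download(
    repo_id="amd-shark/sdxl-quant-models",
    subfolder="unet/int8/export",          # or "unet/fp16/export"
    filename="punet_dataset_i8.irpa",      # placeholder filename
    revision="82e06d6ea22ac78102a9aded69e8ddfb9fa4ae37",
)
```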

}
if precision == "i8":
results["quant_params.json"] = download("quant_params.json")
output_path = external_weight_path.split("unet")[0] + "punet_dataset_i8.irpa"
Contributor

What would happen if this were in the directory C:\Users\unet\Downloads? Would be better to use pathlib for manipulation of just the stem or something.
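
A sketch of the pathlib version, keeping the output filename from the snippet above (placing the dataset next to the weights file is an assumption):

```python
from pathlib import Path

# Swap out just the filename component instead of string-splitting on
# "unet", which would misfire on a path like C:\Users\unet\Downloads.
output_path = str(Path(external_weight_path).with_name("punet_dataset_i8.irpa"))
```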

def get_punet_dataset(
config_json_path,
params_path,
output_path="./punet_dataset_i8.irpa",
Contributor

Remove the default. Should always be provided, I think.

output_path="./punet_dataset_i8.irpa",
quant_params_path=None,
quant_params_struct=None,
base_params=None,
Contributor

Seems unused; remove (this was a hack/prototype thing that didn't work out anyway).

params_path,
output_path="./punet_dataset_i8.irpa",
quant_params_path=None,
quant_params_struct=None,
Contributor

Not used; remove?

[
f"--config-json={config_json_path}",
f"--params={params_path}",
f"--quant-params={quant_params_path}",
Contributor

I don't think this will work for the fp16 version like this (the literal value "None" will try to load a file named "None").

There is an importer right next to this for a stock huggingface dataset which was intended to be used for this case. Should probably switch here, but judging from the above, I'm guessing that fp16 punet has not been tested yet. Can flush that out when it is.
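
One way to handle the fp16 case would be to build the argv list conditionally, so the flag is omitted when there are no quant params (a sketch, not the importer's confirmed interface):

```python
argv = [
    f"--config-json={config_json_path}",
    f"--params={params_path}",
]
# For fp16 there are no quant params; omit the flag entirely rather than
# passing the literal string "None" as a filename to be loaded.
if quant_params_path is not None:
    argv.append(f"--quant-params={quant_params_path}")
```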

elif (not decomp_attn) and "gfx9" in target:
attn_spec = "mfma"
elif (not decomp_attn) and "gfx11" in target:
attn_spec = "wmma"
Contributor

This kind of thing makes me sad. Nothing to do about it now.

monorimet merged commit d534cd4 into ean-unify-sd on Jul 10, 2024
2 of 3 checks passed
monorimet deleted the ean-unify-sd-staging branch on July 10, 2024 at 08:39