Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds SDXL support and CI testing, benchmarks. #271

Merged
merged 182 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
182 commits
Select commit Hold shift + click to select a range
a29fb24
Add precision to unet, vae and guidance scale as input to unet
monorimet Dec 18, 2023
a67f255
(WIP) Add SDXL
monorimet Jan 19, 2024
95a2f7f
WIP: Add CLIP, CLIP2 and tweaks to unet for SDXL
monorimet Jan 22, 2024
8fdc639
Move SDXL scripts.
monorimet Feb 7, 2024
97ee822
Fix formatting.
monorimet Feb 7, 2024
430ef6c
Fix formatting 2
monorimet Feb 9, 2024
010649c
f32/f16 -> fp32/fp16
monorimet Feb 9, 2024
55d8c42
Cherry-pick c404693 : Add a guarded sdpa_cpu torch decomposition and …
monorimet Feb 9, 2024
a38e0e9
Tweaks to SDXL script defaults, handles periods in hf_model_name via …
monorimet Feb 9, 2024
ebce84c
Change VAE export script default model.
monorimet Feb 9, 2024
ad1a5d5
Add a line in sd_test workflow to update torch version.
monorimet Feb 9, 2024
269ffe2
Makes sequence length configurable.
monorimet Feb 12, 2024
f297e60
Add max_length to safe names for unet, clip
monorimet Feb 12, 2024
2f8132d
Fix formatting
monorimet Feb 12, 2024
c703fa9
Add sdxl_test to CI
monorimet Feb 12, 2024
34b65da
Simplify VAE and remove SDPA decompositions
monorimet Feb 14, 2024
a0cf6dd
Fix formatting.
monorimet Feb 14, 2024
8c554fc
Fix some mismatches in VAE model comparisons.
monorimet Feb 14, 2024
26ad145
Small fixes to clip runner and remove sdpa decomp implem.
monorimet Feb 15, 2024
fd5328a
Make clip, vae filenames unique to precision, vae variant
monorimet Feb 21, 2024
5ed0c8a
Update SDXL tests, scripts with many small fixes to CLIP and others
monorimet Feb 22, 2024
f143cfc
Fix formatting.
monorimet Feb 22, 2024
33bee27
Make consteval flags exposed as arg in compile_to_vmfb
monorimet Feb 22, 2024
ce064b8
Exhaustively differentiate .mlir, vmfb files by config.
monorimet Feb 22, 2024
35a7f98
SDXL test
jinchen62 Feb 22, 2024
2d80b7e
Small filename fix and compile flag tweaks.
monorimet Feb 23, 2024
4dd1e51
Bump IREE version to >=20230306.822 for fx importer
monorimet Feb 23, 2024
0f7cd00
test argparse tweaks and pin mpmath.
monorimet Feb 23, 2024
1c769b5
SDXL test and benchmark (#474)
jinchen62 Feb 23, 2024
4ce0df7
Add flag to exporters, sdxl tests to decompose sdpfa at fx
monorimet Feb 25, 2024
2f63446
Change pytorch cpu requirement to latest (>=2.3.0)
monorimet Feb 26, 2024
b6de347
Fix --decomp_attn
monorimet Feb 26, 2024
3026e3d
Fix --decomp_attn for VAE as well.
monorimet Feb 26, 2024
49c712d
Change unet_runner timestep input to int64
monorimet Feb 27, 2024
b9aa8ac
Fix CLI for vae_runner.py
monorimet Feb 27, 2024
700ecb1
Use madebyollin/sdxl-vae-fp16-fix for weights in vae/vae_runner.py if…
monorimet Feb 27, 2024
6e66dc6
Add txt2img test.
monorimet Feb 29, 2024
44983d1
(WIP): Add e2e inference test for txtimg sdxl.
monorimet Feb 29, 2024
914b73f
Separate clip tester and encode_prompt fn
monorimet Feb 29, 2024
52a28da
Fix call to clip_runner in t2i test.
monorimet Feb 29, 2024
08f9544
Fix e2e t2i test for sdxl.
monorimet Feb 29, 2024
7edc7b1
Pass command line args for sdxl pytest (#487)
jinchen62 Feb 29, 2024
ef0d929
More t2i fixes (file mgmt)
monorimet Feb 29, 2024
edaeff7
Check for vmfbs, weights or skip t2i test, small fixes to torch runners
monorimet Mar 1, 2024
fc1d673
flag tweaks, and fixes to e2e inference
monorimet Mar 1, 2024
eecab90
Explicitly set some types in pytest args.
monorimet Mar 1, 2024
1bfed12
Support for SDXL schedulers + example (#499)
monorimet Mar 1, 2024
be93f74
Explicitly set target triple flag to string type.
monorimet Mar 1, 2024
a642cb3
Fix formatting.
monorimet Mar 1, 2024
112c6ed
WIP: shrink-wrap unet+scheduler txt2img
monorimet Mar 1, 2024
bddcb08
Fix iree_target_triple pytest arg.
monorimet Mar 1, 2024
214b526
fix sd/sdxl CI (#500)
jinchen62 Mar 2, 2024
65067a0
(WIP) Move argparser and start mlir pipelining for sdxl.
monorimet Mar 4, 2024
1c2c2bd
test sdxl inference (#503)
jinchen62 Mar 4, 2024
a97b27c
fix unet script args
monorimet Mar 5, 2024
7a52bcc
Set max_model_length in CLIP tokenizers based on user spec.
monorimet Mar 5, 2024
6cd40a3
Small models and script fixes.
monorimet Mar 5, 2024
8e6f85e
Add SDXL pipeline script and unify SDXL args.
monorimet Mar 6, 2024
ab35501
Fix compiled scheduled unet pipeline.
monorimet Mar 6, 2024
efc2136
Fix formatting, ireec_flags parsing, weights naming in pipeline script
monorimet Mar 6, 2024
805f29d
fixup: remove breakpoint
monorimet Mar 6, 2024
63cc7ef
Remove windows hardcoded rocm bc dir flag.
monorimet Mar 6, 2024
5d9e19f
Fix sdxl test args (#520)
jinchen62 Mar 6, 2024
da6809a
Fixup pipeline mlir -> vmfb
monorimet Mar 6, 2024
a3c4751
Explicitly set dtypes based on precision argument
monorimet Mar 6, 2024
e87e6b1
Fixup fp16 pipeline
eagarvey-amd Mar 7, 2024
7d0caee
Fix vae decode export case returning tuple.
monorimet Mar 7, 2024
e514e3a
Fixup breakpoint
monorimet Mar 7, 2024
7f9ca66
Fix VAE export case (again)
monorimet Mar 7, 2024
0df84aa
Fix vae export function returns for vmfb.
monorimet Mar 7, 2024
d71ceb1
Remove source map stripping flag from rocm compile args
monorimet Mar 8, 2024
0ebd3fa
Add .mlir for unrolled loop, add option to have scheduled unet return…
monorimet Mar 8, 2024
424c1d5
Fix --return_path for pipeline.
monorimet Mar 8, 2024
0199fd8
Add --decomp_attn conditional back into unet.py
monorimet Mar 8, 2024
08ffad4
Add unrolled pipeline IRs
eagarvey-amd Mar 8, 2024
b771d05
Update rocm flags for sd.
monorimet Mar 9, 2024
a55200b
Switch const_expr_hoisting to true by default.
monorimet Mar 9, 2024
91687db
fix steps count output of run_initialize
monorimet Mar 9, 2024
e248cca
Add batch count to pipeline, improve benchmarking reports, explicitly…
monorimet Mar 10, 2024
56684e6
Rework timings, start simplifying prompt encoding
monorimet Mar 10, 2024
d29145d
Add a variant of the pipeline with 0 device->host after tokenization
monorimet Mar 11, 2024
d1c1f26
Fix issues with preparation of files after export
monorimet Mar 11, 2024
d514720
Fix prep for old pipeline.
monorimet Mar 11, 2024
ef7746f
Fix seed propagation and batching.
monorimet Mar 11, 2024
d976ab0
Fix formatting.
monorimet Mar 11, 2024
552798a
Fix return ordering of export_prompt_encoder call.
monorimet Mar 11, 2024
620b53c
Correct timesteps for benchmarking PNDM
monorimet Mar 12, 2024
b2d3398
Fixes to pipeline, cooler prompt, fix scheduled unet comparisons
monorimet Mar 13, 2024
b630c80
Fixups to pipeline, import examples, move unrolled loop .mlirs to Azure
monorimet Mar 14, 2024
93aaf08
formatting fixes
monorimet Mar 14, 2024
0c9c605
small fixes
monorimet Mar 14, 2024
0dd0b6b
Let the user know if comparison is OK
monorimet Mar 14, 2024
fb73926
Bake in flags to utils for MI instructions.
monorimet Mar 14, 2024
e3cd97e
Remove vector distribution from golden MI flags
monorimet Mar 14, 2024
26d1e65
Add attention spec flag and check in a default verified version.
monorimet Mar 14, 2024
75a36a4
Add attention spec flag to parser
monorimet Mar 14, 2024
7c40f02
add attn_spec to vae expoirt
monorimet Mar 14, 2024
c68fec1
Prop. attn_spec to compilation correctly.
monorimet Mar 14, 2024
9ac15e7
Setup mlir input and downloads for SDXL models, update flags for gfx9XX
monorimet Mar 14, 2024
eb139ed
Remove empty flags before parsing ireec opts.
monorimet Mar 14, 2024
630b720
Bump MI flags for SDXL branch of IREE.
monorimet Mar 14, 2024
5ae2946
Add all flags
monorimet Mar 14, 2024
06504b1
Comment out weights-only getter
monorimet Mar 14, 2024
1ef84c4
Prop attn_spec arg to unet.py
monorimet Mar 14, 2024
2e53620
Update MI flags for sdxl.
monorimet Mar 14, 2024
5b87995
The golden flag commit
monorimet Mar 14, 2024
9b35a58
Update docs.
monorimet Mar 14, 2024
98af417
Simplify some compile flags and add weights fetching option to exports
monorimet Mar 16, 2024
1a6291f
Add input mlir opt to unet.py and add winograd flag.
monorimet Mar 16, 2024
00387bf
Fix --input_mlir for unet/vae
monorimet Mar 16, 2024
c7ef8f4
Exit after .vmfb compiles if --input_mlir specified.
monorimet Mar 16, 2024
13f493e
Use --device for all runner scripts since it is unambiguous there.
monorimet Mar 16, 2024
fa2c52f
send outputs to host before output/comparison.
monorimet Mar 16, 2024
e04f5a5
Disable native math precision flag on CLIP
monorimet Mar 16, 2024
6f15574
Flags update (remove native math precision on VAE)
monorimet Mar 16, 2024
33ea878
Pipe through mlir_source in mlir input mode for Scheduled unet.
monorimet Mar 16, 2024
65a6f23
Bump spec to 1bcbef6
monorimet Mar 16, 2024
b687c2c
Make it easier to run and validate scheduled unet + pipeline wrapper.
monorimet Mar 16, 2024
5270841
Fix bug generating model artifacts with --external_weights=irpa
monorimet Mar 17, 2024
9f3a5b7
add full pipeline wrapper .mlir and compile alongside scheduled unet
monorimet Mar 17, 2024
496e126
Switch clip main function name and pipe through support for e2e onesh…
monorimet Mar 17, 2024
f02405a
fixup: differentiate pipeline filenames by mode
monorimet Mar 17, 2024
bf2afa7
Small fixes to pipeline modes
monorimet Mar 17, 2024
ca8a059
Small fixes to pipeline vmfb naming
monorimet Mar 17, 2024
10bc439
Move d2h after image completiooutside of computation timing.
monorimet Mar 18, 2024
4cd3596
Fix formatting
monorimet Mar 18, 2024
f616846
formatting with right black version
monorimet Mar 18, 2024
84e4a81
Add requests to serving setup.
monorimet Mar 18, 2024
377918d
Update and rename import_examples.md to COMMANDS.md
monorimet Mar 18, 2024
914caa6
Bypass type check on two functionalized graph method calls.
monorimet Mar 18, 2024
c72f38d
Fix formatting
eagarvey-amd Mar 18, 2024
493184b
Refactor pipeline into a class and update sdxl e2e test.
eagarvey-amd Mar 18, 2024
c088f49
Fixup args for SDXL pipeline
monorimet Mar 18, 2024
82084d9
Fix conditional logic for setting sdxl flags.
monorimet Mar 18, 2024
668eaa2
Fix formatting.
eagarvey-amd Mar 18, 2024
06ce9a1
Bump attn spec to
eagarvey-amd Mar 18, 2024
afd8c97
Update flags for MI perf.
monorimet Mar 19, 2024
b2dd042
Fixup flags.
monorimet Mar 19, 2024
e370c8f
Set latest flags and attention spec (07f52fe)
monorimet Mar 20, 2024
938c9ea
add a separate flag for decomposing attn in VAE
eagarvey-amd Mar 20, 2024
55fc076
Flags and spec update to 90bacfae
monorimet Mar 21, 2024
f0d5f5d
Update attn spec.
monorimet Mar 21, 2024
0f7ccad
Fixup weights only exports/f32 cpu route
monorimet Mar 22, 2024
c85b5d6
Fix sdpa on Vulkan for SD (#557)
gpetters94 Mar 27, 2024
d545889
Fix args in sd_test
monorimet Mar 28, 2024
c388f08
Fixup test API calls.
monorimet Mar 29, 2024
96ca856
cleanup triple default behavior
monorimet Mar 29, 2024
d2a9af5
small fixes to sd_test and sd utils.
monorimet Mar 29, 2024
cdff9f5
fixup sd_test
monorimet Mar 29, 2024
2a1bc50
Explicitly send outputs to host for test runners.
monorimet Mar 29, 2024
d4848d7
Fix latent_model_input calculation in scheduled unet w/ EulerDiscrete…
aviator19941 Mar 29, 2024
a559e57
Fix segfaults issue by disabling caching allocator on CPU
monorimet Mar 29, 2024
4c74c96
Fix formatting.
monorimet Mar 29, 2024
b2871f8
Remove redundant d2h for clip outputs
monorimet Mar 29, 2024
f5d5a3f
send correct guidance_scale value to unet runner
monorimet Mar 29, 2024
e20bd59
Fixup test file mgmt.
monorimet Mar 29, 2024
69a1bef
Remove expected system exits from testing.
monorimet Mar 29, 2024
a70e9b5
few more fixes to sdxl tests, args
monorimet Mar 29, 2024
9f73fbb
Tweak test config.
monorimet Apr 4, 2024
f02fbd3
Merge branch 'main' into ean-sd-fp16
monorimet Apr 4, 2024
8cff3d8
Fix precision for cpu test.
monorimet Apr 4, 2024
c8d62fe
Explicitly install nod-ai diffusers fork for sd tests.
monorimet Apr 4, 2024
b92be0e
Install turbine-models requirements in model testing job.
monorimet Apr 4, 2024
7bcb003
Don't specify pipeline directory for model unit tests.
monorimet Apr 4, 2024
f569cea
fix stateless llama testing (#600)
saienduri Apr 8, 2024
bc54f7b
Remove expected failure for vae encoder test.
monorimet Apr 8, 2024
fe27035
Change rocm runtime device to "hip"
monorimet Apr 8, 2024
15ddf3b
Merge branch 'main' into ean-sd-fp16
monorimet Apr 9, 2024
699ba0d
Try hip driver and tweak rocm flags.
monorimet Apr 9, 2024
0011328
cleanup pipeline test artifacts after completion.
monorimet Apr 9, 2024
4d7bfef
restrict wmma flags to gfx94X
monorimet Apr 9, 2024
946a02f
Decompose attention in CI tests.
monorimet Apr 9, 2024
77d4308
Pipe through attn spec option correctly.
monorimet Apr 9, 2024
3336b6b
Use fp16 for mi210 CI.
monorimet Apr 9, 2024
68c3c6c
Fix default attention spec behavior
monorimet Apr 9, 2024
8dc1fba
Update test_models.yml
monorimet Apr 9, 2024
bfbebef
xfail e2e on rocm shortly, pending move to nightly test
monorimet Apr 10, 2024
246d32d
Merge branch 'main' into ean-sd-fp16
monorimet Apr 10, 2024
bb2c7e0
use config A for cpu CI
monorimet Apr 10, 2024
12b91f4
Remove xfails on submodels for rocm.
monorimet Apr 11, 2024
575bcd0
Cleanup comments and redundant code.
monorimet Apr 11, 2024
eaeb646
Skip tests that crash on MI210 for now.
monorimet Apr 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/test_models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
uses: actions/checkout@v2

- name: Sync source deps
# build IREE from source with -DIREE_BUILD_TRACY=ON if getting tracy profile
run: |
python -m venv turbine_venv
source turbine_venv/bin/activate
Expand All @@ -44,7 +45,7 @@ jobs:
pip install -r core/pytorch-cpu-requirements.txt
pip install --pre --upgrade -r core/requirements.txt
pip install --pre -e core[testing]
pip install --pre -e models
pip install --pre --upgrade -e models -r models/requirements.txt

- name: Show current free memory
run: |
Expand All @@ -59,3 +60,6 @@ jobs:
run: |
source turbine_venv/bin/activate
pytest models/turbine_models/tests/sd_test.py
pytest models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu
pytest models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
pytest models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16
2 changes: 1 addition & 1 deletion core/iree-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
iree-compiler==20240410.859
iree-runtime==20240410.859
iree-runtime==20240410.859
2 changes: 1 addition & 1 deletion core/shark_turbine/aot/builtins/jittable.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def flat_wrapped_f(*args):
)
logger.debug("Invoking dynamo trace")
gm, guards = exported_f(*flat_pytorch_args)
logger.debug("Dyanmo trace complete")
logger.debug("Dynamo trace complete")

# TODO: Add debug logging for the exported graph module.
# gm.print_readable()
Expand Down
2 changes: 0 additions & 2 deletions core/shark_turbine/dynamo/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ def _get_default_decomposition_ops() -> DecompositionOpsList:
aten.lift_fresh_copy.default,
aten._unsafe_index.Tensor,
aten.unbind.int,
# decompositions added manually in this file
aten._scaled_dot_product_flash_attention.default,
]


Expand Down
2 changes: 1 addition & 1 deletion models/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ sentencepiece
shark_turbine
transformers==4.37.1
accelerate
diffusers==0.24.0
diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
# turbine tank downloading/uploading
azure-storage-blob
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,17 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import sys
import argparse
import numpy as np
import re
import os
import re
import sys

from transformers import AutoTokenizer
from iree import runtime as ireert
from turbine_models.utils.benchmark import benchmark_module
import turbine_models.custom_models.stateless_llama as llama

import argparse

import subprocess
from collections import namedtuple

parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -71,16 +69,20 @@ def run_benchmark(args):
input.append(temp)
input.append(np.array(args.steps))

vmfbs = []
vmfbs.append(args.llama_vmfb_path)
vmfbs.append(args.benchmark_vmfb_path)

if args.external_weight_file:
results = benchmark_module(
benchmark_mod,
args,
"run",
vmfbs,
input,
parameters=f"model={args.external_weight_file}",
)
else:
results = benchmark_module(benchmark_mod, args, "run", input)
results = benchmark_module(benchmark_mod, "run", vmfbs, input)

for benchmark_result in results:
print(
Expand Down Expand Up @@ -146,16 +148,20 @@ def run_forward_benchmark(args):
input.append(temp)
input.append(np.array(args.steps))

vmfbs = []
vmfbs.append(args.llama_vmfb_path)
vmfbs.append(args.benchmark_vmfb_path)

if args.external_weight_file:
results = benchmark_module(
benchmark_mod,
args,
"run",
vmfbs,
input,
parameters=f"model={args.external_weight_file}",
)
else:
results = benchmark_module(benchmark_mod, args, "run", input)
results = benchmark_module(benchmark_mod, "run", vmfbs, input)

for benchmark_result in results:
print(
Expand All @@ -178,110 +184,6 @@ def run_forward_benchmark(args):
np.dtype(np.bool_): "i1",
}

BenchmarkResult = namedtuple(
"BenchmarkResult", "benchmark_name time cpu_time iterations user_counters"
)


class BenchmarkToolError(Exception):
"""Benchmark exception that preserves the command line and error output."""

def __init__(self, message):
self.message = message
super().__init__(self.message)


class BenchmarkTimeoutError(Exception):
"""Exception raised if the benchmark is cancelled by the user specified timeout."""

pass


def benchmark_module(
module, bench_args, entry_function=None, inputs=[], timeout=None, **kwargs
):
funcs = [a for a in module.function_names if a != "__init"]
if entry_function is None:
if len(funcs) > 1:
raise ValueError(f"No function specified with multiple options {funcs}")
entry_function = funcs[0]
if entry_function not in funcs:
raise ValueError(
f"Attempted to benchmark unknown function {entry_function} of options {funcs}"
)

args = [ireert.benchmark_exe()]
args.append(f"--function={entry_function}")

for inp in inputs:
if isinstance(inp, str):
args.append(f"--input={inp}")
continue
shape = "x".join([str(d) for d in inp.shape])
abitype = DTYPE_TO_ABI_TYPE[inp.dtype]
values = inp.flatten()
if np.all(values[0] == values):
values = str(values[0])
else:
values = ",".join([str(v) for v in values])

args.append(f"--input={shape}x{abitype}={values}")

for k in kwargs:
v = kwargs[k]
args.append(f"--{k}={v}")

args.append(f"--module={bench_args.llama_vmfb_path}")
args.append(f"--module={bench_args.benchmark_vmfb_path}")

try:
benchmark_process = subprocess.run(
args=args,
# input=flatbuffer,
timeout=timeout,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.TimeoutExpired:
raise BenchmarkTimeoutError(f"Benchmark timed out after {timeout} seconds")
out = benchmark_process.stdout
err = benchmark_process.stderr

err = err.decode()
if "INVALID_ARGUMENT;" in err:
raise ValueError("Invalid inputs specified for benchmarking")

# In the event benchmarking runs but encounteres an internal error,
# return the internal error instead of benchmark results.
if "INTERNAL; CUDA driver error" in str(out):
raise BenchmarkToolError(str(out))

# Grab individual results by line (skip header lines)
bench_lines = out.decode().split("\n")[3:]
benchmark_results = []
for line in bench_lines:
split = line.split()
if len(split) == 0:
continue
benchmark_name = split[0]
time = " ".join(split[1:3])
cpu_time = " ".join(split[3:5])
iterations = split[5]
user_counters = None
if len(split) > 5:
user_counters = split[6]
benchmark_results.append(
BenchmarkResult(
benchmark_name=benchmark_name,
time=time,
cpu_time=cpu_time,
iterations=iterations,
user_counters=user_counters,
)
)

return benchmark_results


if __name__ == "__main__":
args = parser.parse_args()
Expand Down
Loading
Loading