Commit

Adds SDXL support and CI testing, benchmarks. (#271)
- Updates turbine-models requirements to use nod-ai fork of diffusers
- Adds SDXL implementations and tests, benchmarks
- Updates sd_inference/utils with newest flags and adds scheduler scaffolding to sd1.5/2.1

Co-authored-by: jinchen62 <[email protected]>
Co-authored-by: jinchen <[email protected]>
Co-authored-by: PhaneeshB <[email protected]>
Co-authored-by: Avinash Sharma <[email protected]>
Co-authored-by: gpetters94 <[email protected]>
Co-authored-by: George Petterson <[email protected]>
Co-authored-by: aviator19941 <[email protected]>
Co-authored-by: saienduri <[email protected]>
9 people authored Apr 11, 2024
1 parent 9484484 commit 1dea19e
Showing 45 changed files with 6,286 additions and 297 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/test_models.yml
@@ -34,6 +34,7 @@ jobs:
uses: actions/checkout@v2

- name: Sync source deps
+ # build IREE from source with -DIREE_BUILD_TRACY=ON if getting tracy profile
run: |
python -m venv turbine_venv
source turbine_venv/bin/activate
@@ -44,7 +45,7 @@ jobs:
pip install -r core/pytorch-cpu-requirements.txt
pip install --pre --upgrade -r core/requirements.txt
pip install --pre -e core[testing]
- pip install --pre -e models
+ pip install --pre --upgrade -e models -r models/requirements.txt
- name: Show current free memory
run: |
@@ -59,3 +60,6 @@ jobs:
run: |
source turbine_venv/bin/activate
pytest models/turbine_models/tests/sd_test.py
+ pytest models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu
+ pytest models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
+ pytest models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16
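
Note: --device, --rt_device, --iree_target_triple, and --precision are custom pytest flags consumed by the new SDXL test suite. A minimal sketch of how flags like these are typically registered in a conftest.py (option names taken from the commands above; defaults are illustrative assumptions, not the repo's actual values):

    # conftest.py -- hypothetical sketch; defaults are placeholders.
    import pytest

    def pytest_addoption(parser):
        parser.addoption("--device", action="store", default="cpu",
                         help="Compile-time device backend (cpu, vulkan, rocm, ...)")
        parser.addoption("--rt_device", action="store", default="local-task",
                         help="IREE runtime driver (local-task, vulkan, hip, ...)")
        parser.addoption("--iree_target_triple", action="store",
                         default="x86_64-linux-gnu",
                         help="Target triple passed to the IREE compiler")
        parser.addoption("--precision", action="store", default="fp32",
                         help="Model precision (fp32 or fp16)")

    @pytest.fixture(scope="session")
    def device(request):
        # Tests read each flag through a fixture like this one.
        return request.config.getoption("--device")

With a setup along these lines, the three CI invocations above simply select different backend/driver/triple combinations against the same test file.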
2 changes: 1 addition & 1 deletion core/iree-requirements.txt
@@ -1,2 +1,2 @@
iree-compiler==20240410.859
- iree-runtime==20240410.859
+ iree-runtime==20240410.859
2 changes: 1 addition & 1 deletion core/shark_turbine/aot/builtins/jittable.py
@@ -226,7 +226,7 @@ def flat_wrapped_f(*args):
)
logger.debug("Invoking dynamo trace")
gm, guards = exported_f(*flat_pytorch_args)
logger.debug("Dyanmo trace complete")
logger.debug("Dynamo trace complete")

# TODO: Add debug logging for the exported graph module.
# gm.print_readable()
2 changes: 0 additions & 2 deletions core/shark_turbine/dynamo/decompositions.py
@@ -115,8 +115,6 @@ def _get_default_decomposition_ops() -> DecompositionOpsList:
aten.lift_fresh_copy.default,
aten._unsafe_index.Tensor,
aten.unbind.int,
- # decompositions added manually in this file
- aten._scaled_dot_product_flash_attention.default,
]


2 changes: 1 addition & 1 deletion models/requirements.txt
@@ -3,7 +3,7 @@ sentencepiece
shark_turbine
transformers==4.37.1
accelerate
- diffusers==0.24.0
+ diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
# turbine tank downloading/uploading
azure-storage-blob
(file header not captured; this section edits the stateless llama benchmark script)
@@ -4,19 +4,17 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

- import sys
+ import argparse
import numpy as np
- import re
import os
+ import re
+ import sys

from transformers import AutoTokenizer
from iree import runtime as ireert
+ from turbine_models.utils.benchmark import benchmark_module
import turbine_models.custom_models.stateless_llama as llama

- import argparse
-
- import subprocess
- from collections import namedtuple

parser = argparse.ArgumentParser()
parser.add_argument(
@@ -71,16 +69,20 @@ def run_benchmark(args):
input.append(temp)
input.append(np.array(args.steps))

+ vmfbs = []
+ vmfbs.append(args.llama_vmfb_path)
+ vmfbs.append(args.benchmark_vmfb_path)
+
if args.external_weight_file:
results = benchmark_module(
benchmark_mod,
- args,
"run",
+ vmfbs,
input,
parameters=f"model={args.external_weight_file}",
)
else:
- results = benchmark_module(benchmark_mod, args, "run", input)
+ results = benchmark_module(benchmark_mod, "run", vmfbs, input)

for benchmark_result in results:
print(
@@ -146,16 +148,20 @@ def run_forward_benchmark(args):
input.append(temp)
input.append(np.array(args.steps))

+ vmfbs = []
+ vmfbs.append(args.llama_vmfb_path)
+ vmfbs.append(args.benchmark_vmfb_path)
+
if args.external_weight_file:
results = benchmark_module(
benchmark_mod,
- args,
"run",
+ vmfbs,
input,
parameters=f"model={args.external_weight_file}",
)
else:
- results = benchmark_module(benchmark_mod, args, "run", input)
+ results = benchmark_module(benchmark_mod, "run", vmfbs, input)

for benchmark_result in results:
print(
@@ -178,110 +184,6 @@
np.dtype(np.bool_): "i1",
}

- BenchmarkResult = namedtuple(
- "BenchmarkResult", "benchmark_name time cpu_time iterations user_counters"
- )
-
-
- class BenchmarkToolError(Exception):
- """Benchmark exception that preserves the command line and error output."""
-
- def __init__(self, message):
- self.message = message
- super().__init__(self.message)
-
-
- class BenchmarkTimeoutError(Exception):
- """Exception raised if the benchmark is cancelled by the user specified timeout."""
-
- pass
-
-
- def benchmark_module(
- module, bench_args, entry_function=None, inputs=[], timeout=None, **kwargs
- ):
- funcs = [a for a in module.function_names if a != "__init"]
- if entry_function is None:
- if len(funcs) > 1:
- raise ValueError(f"No function specified with multiple options {funcs}")
- entry_function = funcs[0]
- if entry_function not in funcs:
- raise ValueError(
- f"Attempted to benchmark unknown function {entry_function} of options {funcs}"
- )
-
- args = [ireert.benchmark_exe()]
- args.append(f"--function={entry_function}")
-
- for inp in inputs:
- if isinstance(inp, str):
- args.append(f"--input={inp}")
- continue
- shape = "x".join([str(d) for d in inp.shape])
- abitype = DTYPE_TO_ABI_TYPE[inp.dtype]
- values = inp.flatten()
- if np.all(values[0] == values):
- values = str(values[0])
- else:
- values = ",".join([str(v) for v in values])
-
- args.append(f"--input={shape}x{abitype}={values}")
-
- for k in kwargs:
- v = kwargs[k]
- args.append(f"--{k}={v}")
-
- args.append(f"--module={bench_args.llama_vmfb_path}")
- args.append(f"--module={bench_args.benchmark_vmfb_path}")
-
- try:
- benchmark_process = subprocess.run(
- args=args,
- # input=flatbuffer,
- timeout=timeout,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- )
- except subprocess.TimeoutExpired:
- raise BenchmarkTimeoutError(f"Benchmark timed out after {timeout} seconds")
- out = benchmark_process.stdout
- err = benchmark_process.stderr
-
- err = err.decode()
- if "INVALID_ARGUMENT;" in err:
- raise ValueError("Invalid inputs specified for benchmarking")
-
- # In the event benchmarking runs but encounteres an internal error,
- # return the internal error instead of benchmark results.
- if "INTERNAL; CUDA driver error" in str(out):
- raise BenchmarkToolError(str(out))
-
- # Grab individual results by line (skip header lines)
- bench_lines = out.decode().split("\n")[3:]
- benchmark_results = []
- for line in bench_lines:
- split = line.split()
- if len(split) == 0:
- continue
- benchmark_name = split[0]
- time = " ".join(split[1:3])
- cpu_time = " ".join(split[3:5])
- iterations = split[5]
- user_counters = None
- if len(split) > 5:
- user_counters = split[6]
- benchmark_results.append(
- BenchmarkResult(
- benchmark_name=benchmark_name,
- time=time,
- cpu_time=cpu_time,
- iterations=iterations,
- user_counters=user_counters,
- )
- )
-
- return benchmark_results
-

if __name__ == "__main__":
args = parser.parse_args()
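
The roughly one hundred deleted lines above are the old in-file benchmarking helper; per the new import (from turbine_models.utils.benchmark import benchmark_module), it now lives in a shared utils module. Reconstructed from the deleted code and the new call sites (benchmark_module(benchmark_mod, "run", vmfbs, input, parameters=...)), the relocated helper plausibly looks like the sketch below. This is an inference from this diff, not the verbatim contents of turbine_models/utils/benchmark.py:

    # Hypothetical sketch of turbine_models/utils/benchmark.py, reconstructed
    # from the code deleted above and the new call sites in this file.
    import subprocess
    from collections import namedtuple

    import numpy as np
    from iree import runtime as ireert

    DTYPE_TO_ABI_TYPE = {
        # Abridged; the real table maps every dtype the callers feed in.
        np.dtype(np.float32): "f32",
        np.dtype(np.int64): "i64",
        np.dtype(np.bool_): "i1",
    }

    BenchmarkResult = namedtuple(
        "BenchmarkResult", "benchmark_name time cpu_time iterations user_counters"
    )

    class BenchmarkTimeoutError(Exception):
        """Raised when the benchmark exceeds the user-specified timeout."""

    def benchmark_module(module, entry_function=None, vmfbs=[], inputs=[],
                         timeout=None, **kwargs):
        funcs = [a for a in module.function_names if a != "__init"]
        if entry_function is None:
            if len(funcs) > 1:
                raise ValueError(f"No function specified with multiple options {funcs}")
            entry_function = funcs[0]

        args = [ireert.benchmark_exe(), f"--function={entry_function}"]
        for inp in inputs:
            # Strings pass through; ndarrays serialize as shape x type = values.
            if isinstance(inp, str):
                args.append(f"--input={inp}")
                continue
            shape = "x".join(str(d) for d in inp.shape)
            abitype = DTYPE_TO_ABI_TYPE[inp.dtype]
            values = inp.flatten()
            if np.all(values[0] == values):
                values = str(values[0])
            else:
                values = ",".join(str(v) for v in values)
            args.append(f"--input={shape}x{abitype}={values}")
        # Extra keyword args (e.g. parameters="model=weights.safetensors")
        # become additional iree-benchmark-module flags.
        for k, v in kwargs.items():
            args.append(f"--{k}={v}")
        # Each vmfb becomes its own --module flag, replacing the hard-coded
        # llama/benchmark vmfb paths of the old in-file implementation.
        for vmfb in vmfbs:
            args.append(f"--module={vmfb}")

        try:
            proc = subprocess.run(args=args, timeout=timeout,
                                  stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.TimeoutExpired:
            raise BenchmarkTimeoutError(f"Benchmark timed out after {timeout} seconds")

        # Skip the three table-header lines, then parse one result per row.
        benchmark_results = []
        for line in proc.stdout.decode().split("\n")[3:]:
            split = line.split()
            if not split:
                continue
            benchmark_results.append(BenchmarkResult(
                benchmark_name=split[0],
                time=" ".join(split[1:3]),
                cpu_time=" ".join(split[3:5]),
                iterations=split[5],
                user_counters=split[6] if len(split) > 6 else None,
            ))
        return benchmark_results

The old version's extra error handling (INVALID_ARGUMENT and CUDA driver checks) is elided here for brevity; the key design change visible in this diff is that callers now pass vmfb paths explicitly instead of the helper reading them off a parsed-args object.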
(remaining file diffs not loaded)
