Skip to content

Commit

Permalink
Unifies SD pipeline APIs, adds sd3 support, punet integration (#706)
Browse files Browse the repository at this point in the history
- Introduces a new sd_pipeline.py that handles inference for sd1.5, sd2.1, sdxl, sdxl-turbo, sd3. The pipeline is a child of the new pipeline_base.py that provides a comprehensive starting point to bringing up new pipelines.
- Generally moves SDXL away from the "scheduled unet" approach, instead compiling small scheduler models that fit around a standalone unet module.
- Reworks pipeline API to enable deployment / compatibility APIs
- Adds multi-device pipelining support to SD pipeline
- Carries flag updates for key targets
- file management improvements
- integrates sharktank int8 partitioned unet.

Signed-off-by: aviator19941 <[email protected]>
Signed-off-by: monorimet <[email protected]>
Co-authored-by: Ian <[email protected]>
Co-authored-by: dan <[email protected]>
Co-authored-by: IanNod <[email protected]>
Co-authored-by: aviator19941 <[email protected]>
Co-authored-by: saienduri <[email protected]>
  • Loading branch information
6 people authored Jul 12, 2024
1 parent 4f5f31f commit e46a2a2
Show file tree
Hide file tree
Showing 61 changed files with 10,347 additions and 3,135 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
run: git fetch --no-tags --prune --depth=1 origin "${GITHUB_BASE_REF?}:${GITHUB_BASE_REF?}"
- name: Install black
run: |
python3 -m pip install black==23.3
python3 -m pip install black
- name: Check if modified files are formatted
run: |
# The filter lowercase `d` means to exclude deleted files.
Expand Down
9 changes: 5 additions & 4 deletions .github/workflows/test_models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt
pip install --no-compile --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt
pip install --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt
pip install --no-compile --pre -e ${{ github.workspace }}/iree-turbine[testing]
pip install --upgrade --pre --no-cache-dir iree-compiler iree-runtime -f https://iree.dev/pip-release-links.html
pip install --no-compile --pre --upgrade -e models -r models/requirements.txt
Expand All @@ -69,7 +69,8 @@ jobs:
source turbine_venv/bin/activate
pytest -v models/turbine_models/tests/sd_test.py
pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu
pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5
pytest -v models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16
pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default
pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default --batch_size 2
pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5
3 changes: 1 addition & 2 deletions .github/workflows/test_shark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
strategy:
matrix:
version: [3.11]
os: [nodai-ubuntu-builder-large]
os: [nodai-amdgpu-mi250-x86-64]

runs-on: ${{matrix.os}}
steps:
Expand Down Expand Up @@ -49,7 +49,6 @@ jobs:
cd $GITHUB_WORKSPACE/SHARK
python${{ matrix.version }} -m venv shark.venv
source shark.venv/bin/activate
sed -i 's/SHARK-Turbine#/SHARK-Turbine.git@${{github.sha}}#/g' requirements.txt
pip install -r requirements.txt --no-cache-dir
pip install -e .
python apps/shark_studio/tests/api_test.py
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ wheelhouse
*.safetensors
*.gguf
*.vmfb
*.mlir
*.npy
*.png
*tmp*
11 changes: 8 additions & 3 deletions models/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
protobuf
sentencepiece
shark_turbine
gguf
transformers==4.37.1
torchsde
accelerate
diffusers @ git+https://github.com/nod-ai/[email protected]
peft
diffusers @ git+https://github.com/nod-ai/[email protected]
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
# turbine tank downloading/uploading
azure-storage-blob
# microsoft/phi model
einops
pytest
scipy
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
-e git+https://github.com/nod-ai/sharktank.git@main#egg=sharktank&subdirectory=sharktank
5 changes: 2 additions & 3 deletions models/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,11 @@ def load_version_info():
),
install_requires=[
"Shark-Turbine",
"brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b",
"protobuf",
"sentencepiece",
"transformers==4.37.1",
"transformers>=4.37.1",
"accelerate",
"diffusers==0.24.0",
"diffusers==0.29.0.dev0",
"azure-storage-blob",
"einops",
],
Expand Down
169 changes: 169 additions & 0 deletions models/turbine_models/custom_models/llama_argmax_td_spec.mlir

Large diffs are not rendered by default.

Loading

0 comments on commit e46a2a2

Please sign in to comment.