Flag-guard padded attention preprocessing instruction, start adding tests for abstracted pipeline
monorimet committed Jun 24, 2024
1 parent 5846d10 commit 7fabc3c
Showing 4 changed files with 80 additions and 2 deletions.
10 changes: 10 additions & 0 deletions models/turbine_models/custom_models/sd_inference/utils.py
@@ -68,8 +68,13 @@
"--iree-codegen-gpu-native-math-precision=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
],
"pad_attention": [
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))",
],
"preprocess_default": [
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))",
]
"unet": [""],
"clip": [""],
"vae": [""],
@@ -121,6 +126,7 @@ def compile_to_vmfb(
    save_mlir=True,
    attn_spec=None,
    winograd=False,
    masked_attention=False,
):
    flags = []
    if mlir_source == "file" and not isinstance(module_str, str):
@@ -205,6 +211,10 @@ def compile_to_vmfb(

if "gfx11" in target_triple:
flags.extend(GFX11_flags["all"])
if masked_attention:
flags.extend(GFX11_flags["pad_attention"])
else:
flags.extend(GFX11_flags["preprocess_default"])

    # Currently, we need a transform dialect script to be applied to the compilation through IREE in certain cases.
    # This 'attn_spec' handles a linalg_ext.attention op lowering to mfma instructions for capable targets.
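    # A minimal sketch of how such a spec could be wired in, assuming
    # attn_spec is a filesystem path to a .mlir transform script; IREE
    # consumes such scripts via its transform-dialect-library flag, but the
    # exact placement is an assumption, not shown in this diff:
    #
    #   if attn_spec is not None:
    #       flags.extend([f"--iree-codegen-transform-dialect-library={attn_spec}"])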
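For illustration, a call exercising the new guard might look as follows. Only module_str, mlir_source, target_triple, and masked_attention are visible in the hunks above, so every other keyword name and value here is an assumption, not the repository's actual signature:

# Illustrative call -- keyword names beyond those visible in the diff are assumed.
vmfb_path = utils.compile_to_vmfb(
    module_str="unet.mlir",        # assumed input artifact
    device="rocm",                 # assumed kwarg
    target_triple="gfx1100",       # matches the "gfx11" substring check above
    ireec_flags=[],                # assumed kwarg
    safe_name="unet_gfx1100",      # assumed kwarg
    mlir_source="file",
    masked_attention=True,         # selects GFX11_flags["pad_attention"]
)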
1 change: 0 additions & 1 deletion models/turbine_models/custom_models/sd_inference/vae.py
@@ -7,7 +7,6 @@
import os
import sys

from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
69 changes: 69 additions & 0 deletions models/turbine_models/tests/pipeline_test.py
@@ -0,0 +1,69 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
import pytest
import unittest
import torch
import os
import numpy as np
from iree.compiler.ir import Context
from shark_turbine.aot import *
from turbine_models.custom_models.sd_inference import utils

class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 10)
        self.fc2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

@torch.no_grad()
def export_dummy_model():
    model = TestModule()
    target = "x86_64-unknown-linux-gnu"
    device = "llvm-cpu"
    model_metadata = {
        'model_name': "TestModel2xLinear",
        'input_shapes': [(10,)],
        'input_dtypes': ["float32"],
        'output_shapes': [(10,)],
        'output_dtypes': ["float32"],
        'test_kwarg_1': 'test_kwarg_1_value',
        'test_kwarg_2': 'test_kwarg_2_value',
    }
    dummy_input = torch.empty(10)
    safe_name = model_metadata['model_name'].replace('/', '_')
    vmfb_path = f"./{safe_name}.vmfb"

    fxb = FxProgramsBuilder(model)

    @fxb.export_program(args=(dummy_input,))
    def _forward(module, inputs):
        return module.forward(inputs)

    class CompiledTester(CompiledModule):
        forward = _forward

    inst = CompiledTester(context=Context(), import_to="IMPORT")
    mlir_module = CompiledModule.get_mlir_module(inst)
    breakpoint()  # WIP: inspect the imported MLIR module before compiling

# class PipelineTest(unittest.TestCase):
#     def setUp(self):
#         model_map = {
#             'test_model_1':
#         }

if __name__ == "__main__":
    export_dummy_model()
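The test currently stops at breakpoint() once the MLIR module is materialized. A plausible continuation inside export_dummy_model, assuming compile_to_vmfb accepts the stringified module with mlir_source="str" (consistent with the isinstance check in utils.py above) and that the remaining keyword names match, might look like:

    # Hypothetical continuation -- not committed code; kwarg names are assumed.
    vmfb_path = utils.compile_to_vmfb(
        module_str=str(mlir_module),
        device=device,                 # "llvm-cpu"
        target_triple=target,          # "x86_64-unknown-linux-gnu"
        ireec_flags=[],                # assumed kwarg
        safe_name=safe_name,           # assumed kwarg
        mlir_source="str",
        masked_attention=False,        # CPU target, so the GFX11 guard is not taken
    )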
2 changes: 1 addition & 1 deletion models/turbine_models/tests/sd3_test.py
@@ -1,4 +1,4 @@
# Copyright 2023 Nod Labs, Inc
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
