Add paths to downloads for specs without masked attention.
monorimet committed Jun 24, 2024
1 parent bac7c63 · commit 4eca3b2
Showing 3 changed files with 34 additions and 20 deletions.
models/turbine_models/custom_models/pipeline_base.py (4 changes: 2 additions & 2 deletions)
```diff
@@ -392,7 +392,7 @@ def export_submodel(
             self.map[submodel]["mlir"] = input_mlir

         match submodel:
-            case "unetloop": #SDXL ONLY FOR NOW
+            case "unetloop":  # SDXL ONLY FOR NOW
                 pipeline_file = get_pipeline_ir(
                     self.width,
                     self.height,
@@ -420,7 +420,7 @@ def export_submodel(
                 )
                 self.map[submodel]["vmfb"] = vmfb_path
                 self.map[submodel]["weights"] = None
-            case "fullpipeline": #SDXL ONLY FOR NOW
+            case "fullpipeline":  # SDXL ONLY FOR NOW
                 pipeline_file = get_pipeline_ir(
                     self.width,
                     self.height,
```
models/turbine_models/custom_models/sd_inference/utils.py (21 changes: 15 additions & 6 deletions)
```diff
@@ -221,10 +221,14 @@ def compile_to_vmfb(
     # This is a temporary solution, and should be removed or largely disabled once the functionality of
     # the TD spec is implemented in C++.
     if attn_spec in ["default", "mfma"]:
-        attn_spec = get_mfma_spec_path(target_triple, os.path.dirname(safe_name))
+        attn_spec = get_mfma_spec_path(
+            target_triple, os.path.dirname(safe_name), masked_attention
+        )
         flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
     elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec):
-        attn_spec = get_wmma_spec_path(target_triple, os.path.dirname(safe_name))
+        attn_spec = get_wmma_spec_path(
+            target_triple, os.path.dirname(safe_name), masked_attention
+        )
         if attn_spec:
             flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
     elif attn_spec and attn_spec != "None":
@@ -294,8 +298,11 @@ def create_safe_name(hf_model_name, model_name_str):
     return safe_name


-def get_mfma_spec_path(target_chip, save_dir):
-    url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir"
+def get_mfma_spec_path(target_chip, save_dir, masked_attention=False):
+    if not masked_attention:
+        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_mfma.mlir"
+    else:
+        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir"
     attn_spec = urlopen(url).read().decode("utf-8")
     spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
     if os.path.exists(spec_path):
@@ -305,8 +312,10 @@ def get_mfma_spec_path(target_chip, save_dir):
     return spec_path


-def get_wmma_spec_path(target_chip, save_dir):
-    if target_chip == "gfx1100":
+def get_wmma_spec_path(target_chip, save_dir, masked_attention=False):
+    if not masked_attention:
+        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_wmma.mlir"
+    elif target_chip == "gfx1100":
         url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1100.mlir"
     elif target_chip in ["gfx1103", "gfx1150"]:
         url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1150.mlir"
```
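Condensed, the utils.py change threads a `masked_attention` flag from `compile_to_vmfb` into both spec downloaders: unmasked callers now fetch the generic `no_pad` spec variants, while masked callers keep the per-chip `latest` URLs. The sketch below restates the mfma path as a standalone helper. It is illustrative only; `fetch_mfma_spec` is a hypothetical name, and the cache-reuse behavior after the `os.path.exists` check is an assumption, since that branch is truncated in the diff above.

```python
# Minimal sketch of the download logic added in this commit (hypothetical
# helper name; URLs copied from the diff, caching behavior assumed).
import os
from urllib.request import urlopen

SPEC_BASE = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs"


def fetch_mfma_spec(save_dir, masked_attention=False):
    """Download the mfma attention/matmul TD spec and return its local path."""
    if not masked_attention:
        # New in this commit: unmasked attention uses the no_pad spec.
        url = f"{SPEC_BASE}/no_pad/attention_and_matmul_spec_mfma.mlir"
    else:
        # Masked attention keeps the pre-existing gfx942 spec.
        url = f"{SPEC_BASE}/latest/attention_and_matmul_spec_gfx942.mlir"
    spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
    if os.path.exists(spec_path):
        return spec_path  # assumption: the truncated branch reuses the cache
    with open(spec_path, "w") as f:
        f.write(urlopen(url).read().decode("utf-8"))
    return spec_path


# The returned path feeds IREE the same way compile_to_vmfb does above:
# flags.extend(["--iree-codegen-transform-dialect-library=" + spec_path])
```

One consequence worth noting: both URL branches cache to the same local filename, so whichever variant a given save directory downloads first is the one later calls will pick up (under the same assumption about the truncated `os.path.exists` branch).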
models/turbine_models/tests/pipeline_test.py (29 changes: 17 additions & 12 deletions)
```diff
@@ -13,6 +13,8 @@
 from iree.compiler.ir import Context
 from shark_turbine.aot import *
 from turbine_models.custom_models.sd_inference import utils
+from shark_turbine.transforms import FuncOpMatcher, Pass
+

 class TestModule(torch.nn.Module):
     def __init__(self):
@@ -25,22 +27,25 @@ def forward(self, x):
         x = self.fc2(x)
         return x

+
 torch.no_grad()
+
+
 def export_dummy_model():
     model = TestModule()
     target = "x86_64-unknown-linux-gnu"
     device = "llvm-cpu"
     model_metadata = {
-        'model_name': "TestModel2xLinear",
-        'input_shapes': [(10,)],
-        'input_dtypes': ["float32"],
-        'output_shapes': [(10,)],
-        'output_dtypes': ["float32"],
-        'test_kwarg_1': 'test_kwarg_1_value',
-        'test_kwarg_2': 'test_kwarg_2_value',
+        "model_name": "TestModel2xLinear",
+        "input_shapes": [(10,)],
+        "input_dtypes": ["float32"],
+        "output_shapes": [(10,)],
+        "output_dtypes": ["float32"],
+        "test_kwarg_1": "test_kwarg_1_value",
+        "test_kwarg_2": "test_kwarg_2_value",
     }
     dummy_input = torch.empty(10)
-    safe_name = model_metadata['model_name'].replace('/', '_')
+    safe_name = model_metadata["model_name"].replace("/", "_")
     vmfb_path = f"./{safe_name}.vmfb"

     fxb = FxProgramsBuilder(model)
@@ -51,19 +56,19 @@ def _forward(module, inputs):

     class CompiledTester(CompiledModule):
         forward = _forward

     inst = CompiledTester(context=Context(), import_to="IMPORT")
     mlir_module = CompiledModule.get_mlir_module(inst)
+    funcop_pass = Pass(mlir_module.operation)

     breakpoint()



 # class PipelineTest(unittest.TestCase):
 #     def setUp(self):
 #         model_map = {
 #             'test_model_1':
 #         }

 if __name__ == "__main__":
-    export_dummy_model()
\ No newline at end of file
+    export_dummy_model()
```
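Two observations on the test as committed: the bare `torch.no_grad()` call at module scope is a no-op as written (as a decorator it would need the `@` prefix), and the function retrieves the MLIR module but then stops at a `breakpoint()`, so the `vmfb_path` it names is never written. Below is a minimal sketch of how that module could be lowered with the IREE compiler API, reusing the `target` and `device` values from `export_dummy_model`; this step is not part of the commit, and the keyword and flag choices are assumptions about a reasonable continuation.

```python
# Hypothetical continuation after CompiledModule.get_mlir_module(inst):
# compile the exported MLIR for llvm-cpu and write the named vmfb.
import iree.compiler as ireec

flatbuffer = ireec.compile_str(
    str(mlir_module),
    target_backends=["llvm-cpu"],  # matches device = "llvm-cpu" above
    extra_args=["--iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu"],
)
with open(vmfb_path, "wb") as f:
    f.write(flatbuffer)
```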
