diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 8822d0144..ec616f125 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -68,8 +68,13 @@ "--iree-codegen-gpu-native-math-precision=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", + ], + "pad_attention": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", ], + "preprocess_default": [ + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", + ] "unet": [""], "clip": [""], "vae": [""], @@ -121,6 +126,7 @@ def compile_to_vmfb( save_mlir=True, attn_spec=None, winograd=False, + masked_attention=False, ): flags = [] if mlir_source == "file" and not isinstance(module_str, str): @@ -205,6 +211,10 @@ def compile_to_vmfb( if "gfx11" in target_triple: flags.extend(GFX11_flags["all"]) + if masked_attention: + flags.extend(GFX11_flags["pad_attention"]) + else: + flags.extend(GFX11_flags["preprocess_default"]) # Currently, we need a transform dialect script to be applied to the compilation through IREE in certain cases. # This 'attn_spec' handles a linalg_ext.attention op lowering to mfma instructions for capable targets. diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index bd9e99a23..475cf1d1d 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -7,7 +7,6 @@ import os import sys -from iree import runtime as ireert from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * diff --git a/models/turbine_models/tests/pipeline_test.py b/models/turbine_models/tests/pipeline_test.py new file mode 100644 index 000000000..5c7d21011 --- /dev/null +++ b/models/turbine_models/tests/pipeline_test.py @@ -0,0 +1,69 @@ +# Copyright 2024 Advanced Micro Devices, inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import pytest +import unittest +import torch +import os +import numpy as np +from iree.compiler.ir import Context +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils + +class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 10) + self.fc2 = torch.nn.Linear(10, 10) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + +torch.no_grad() +def export_dummy_model(): + model = TestModule() + target = "x86_64-unknown-linux-gnu" + device = "llvm-cpu" + model_metadata = { + 'model_name': "TestModel2xLinear", + 'input_shapes': [(10,)], + 'input_dtypes': ["float32"], + 'output_shapes': [(10,)], + 'output_dtypes': ["float32"], + 'test_kwarg_1': 'test_kwarg_1_value', + 'test_kwarg_2': 'test_kwarg_2_value', + } + dummy_input = torch.empty(10) + safe_name = model_metadata['model_name'].replace('/', '_') + vmfb_path = f"./{safe_name}.vmfb" + + fxb = FxProgramsBuilder(model) + + @fxb.export_program(args=(dummy_input,)) + def _forward(module, inputs): + return module.forward(inputs) + + class CompiledTester(CompiledModule): + forward = _forward + + inst = CompiledTester(context=Context(), import_to="IMPORT") + mlir_module = CompiledModule.get_mlir_module(inst) + breakpoint() + + + + +# class PipelineTest(unittest.TestCase): +# def setUp(self): +# model_map = { +# 'test_model_1': +# } + +if __name__ == "__main__": + export_dummy_model() \ No newline at end of file diff --git a/models/turbine_models/tests/sd3_test.py b/models/turbine_models/tests/sd3_test.py index b1fc664ac..95309947d 100644 --- a/models/turbine_models/tests/sd3_test.py +++ b/models/turbine_models/tests/sd3_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 Nod Labs, Inc +# Copyright 2024 Advanced Micro Devices, Inc. # # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information.