diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml deleted file mode 100644 index 2828d973c..000000000 --- a/.github/workflows/test.yml +++ /dev/null @@ -1,56 +0,0 @@ -name: Turbine Unit Tests - -on: - workflow_dispatch: - pull_request: - push: - branches: - - main - -concurrency: - # A PR number if a pull request and otherwise the commit hash. This cancels - # queued and in-progress runs for the same PR (presubmit) or commit - # (postsubmit). The workflow name is prepended to avoid conflicts between - # different workflows. - group: ${{ github.workflow }}-${{ github.event.number || github.sha }} - cancel-in-progress: true - -jobs: - test: - name: "Test" - strategy: - matrix: - version: [3.11] - os: [ubuntu-latest] - - runs-on: ${{matrix.os}} - steps: - - name: "Setting up Python" - uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa # v2.3.3 - with: - python-version: ${{matrix.version}} - - - name: "Checkout Code" - uses: actions/checkout@v2 - - - name: Sync source deps - run: | - python -m pip install --upgrade pip - # Note: We install in three steps in order to satisfy requirements - # from non default locations first. Installing the PyTorch CPU - # wheels saves multiple minutes and a lot of bandwidth on runner setup. - pip install -r core/pytorch-cpu-requirements.txt - pip install --upgrade \ - -r core/requirements.txt \ - -r mypy-requirements.txt - pip install -e core[testing] - - - name: Run core tests - if: ${{ !cancelled() }} - run: | - pytest -n 4 core/ - - - name: MyPy Type Checking Core - if: ${{ !cancelled() }} - run: | - (cd core && mypy) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index e6f2d11bd..5e62f68e1 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -30,8 +30,15 @@ jobs: # with: # python-version: ${{matrix.version}} - - name: "Checkout Code" - uses: actions/checkout@v2 + - name: "Checkout This Repo" + uses: actions/checkout@v4 + + - name: "Checkout iree-turbine" + uses: actions/checkout@v4 + with: + repository: iree-org/iree-turbine + # TODO: Let the ref be passed as a parameter to run integration tests. + path: iree-turbine - name: Sync source deps # build IREE from source with -DIREE_BUILD_TRACY=ON if getting tracy profile @@ -42,10 +49,10 @@ jobs: # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. - pip install -r core/pytorch-cpu-requirements.txt - pip install --pre --upgrade -r core/requirements.txt - pip install --pre -e core[testing] - pip install --pre --upgrade -e models -r models/requirements.txt + pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt + pip install --no-compile --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt + pip install --no-compile --pre -e ${{ github.workspace }}/iree-turbine[testing] + pip install --no-compile --pre --upgrade -e models -r models/requirements.txt - name: Show current free memory run: | diff --git a/.github/workflows/test_sdxl.yml b/.github/workflows/test_sdxl.yml index 5b60acc07..bf97cc392 100644 --- a/.github/workflows/test_sdxl.yml +++ b/.github/workflows/test_sdxl.yml @@ -19,10 +19,17 @@ jobs: python-version: ${{matrix.version}} - name: "Checkout Code" - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: ref: ean-sd-fp16 + - name: "Checkout iree-turbine" + uses: actions/checkout@v4 + with: + repository: iree-org/iree-turbine + # TODO: Let the ref be passed as a parameter to run integration tests. + path: iree-turbine + - name: Sync source deps # build IREE from source with -DIREE_BUILD_TRACY=ON if getting tracy profile run: | @@ -30,12 +37,12 @@ jobs: # Note: We install in three steps in order to satisfy requirements # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. - pip install --index-url https://download.pytorch.org/whl/cpu \ - -r core/pytorch-cpu-requirements.txt - pip install --upgrade -r core/requirements.txt - pip install -e core[testing,torch-cpu-nightly] - pip install --upgrade -r models/requirements.txt - pip install -e models + pip install --no-compile --index-url https://download.pytorch.org/whl/cpu \ + -r ${{ github.workspace }}/iree-turbine//pytorch-cpu-requirements.txt + pip install --no-compile --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt + pip install --no-compile -e ${{ github.workspace }}/iree-turbine/[testing,torch-cpu-nightly] + pip install --no-compile --upgrade -r models/requirements.txt + pip install --no-compile -e models - name: Show current free memory run: | diff --git a/.github/workflows/test_shark.yml b/.github/workflows/test_shark.yml index 5f92a672d..8995439d8 100644 --- a/.github/workflows/test_shark.yml +++ b/.github/workflows/test_shark.yml @@ -36,6 +36,13 @@ jobs: path: SHARK ref: "main" + - name: "Checkout iree-turbine" + uses: actions/checkout@v4 + with: + repository: iree-org/iree-turbine + # TODO: Let the ref be passed as a parameter to run integration tests. + path: iree-turbine + # TODO: Replace with a sh script from shark repo - name: "Install SHARK" run: | diff --git a/README.md b/README.md index 42e8019dc..f1b1a674e 100644 --- a/README.md +++ b/README.md @@ -1,129 +1,30 @@ # SHARK Turbine -![image](https://netl.doe.gov/sites/default/files/2020-11/Turbine-8412270026_83cfc8ee8f_c.jpg) +This repo is Nod-AI's integration repository for various model bringup +activities and CI. In 2023 and early 2024, it played a different role +by being the place where FX/Dynamo based torch-mlir and IREE toolsets +were developed, including: -Turbine is the set of development tools that the [SHARK Team](https://github.com/nod-ai/SHARK) -is building for deploying all of our models for deployment to the cloud and devices. We -are building it as we transition from our TorchScript-era 1-off export and compilation -to a unified approach based on PyTorch 2 and Dynamo. While we use it heavily ourselves, it -is intended to be a general purpose model compilation and execution tool. +* [Torch-MLIR FxImporter](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py) +* [Torch-MLIR ONNX Importer](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/onnx_importer.py) +* [Torch-MLIR's ONNX C Importer](https://github.com/llvm/torch-mlir/tree/main/projects/onnx_c_importer) +* [IREE Turbine](https://github.com/iree-org/iree-turbine) +* [Sharktank and Shortfin](https://github.com/nod-ai/sharktank) -Turbine provides a collection of tools: +As these have all found upstream homes, this repo is a bit bare. We will +continue to use it as a staging ground for things that don't have a +more defined spot and as a way to drive certain kinds of upstreaming +activities. -* *AOT Export*: For compiling one or more `nn.Module`s to compiled, deployment - ready artifacts. This operates via both a simple one-shot export API (Already upstreamed to [torch-mlir](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py)) - for simple models and an underlying [advanced API](https://github.com/nod-ai/SHARK-Turbine/blob/main/core/shark_turbine/aot/compiled_module.py) for complicated models - and accessing the full features of the runtime. -* *Eager Execution*: A `torch.compile` backend is provided and a Turbine Tensor/Device - is available for more native, interactive use within a PyTorch session. -* *Turbine Kernels*: (coming soon) A union of the [Triton](https://github.com/openai/triton) approach and - [Pallas](https://jax.readthedocs.io/en/latest/pallas/index.html) but based on - native PyTorch constructs and tracing. It is intended to complement for simple - cases where direct emission to the underlying, cross platform, vector programming model - is desirable. -* *Turbine-LLM*: a repository of layers, model recipes, and conversion tools - from popular Large Language Model (LLM) quantization tooling. -Under the covers, Turbine is based heavily on [IREE](https://github.com/openxla/iree) and -[torch-mlir](https://github.com/llvm/torch-mlir) and we use it to drive evolution -of both, upstreaming infrastructure as it becomes timely to do so. +## Current Projects -See [the roadmap](docs/roadmap.md) for upcoming work and places to contribute. +### turbine-models -## Contact Us +The `turbine-models` project (under models/) contains ports and adaptations +of various (mostly HF) models that we use in various ways. -Turbine is under active development. If you would like to participate as it comes online, -please reach out to us on the `#turbine` channel of the -[nod-ai Discord server](https://discord.gg/QMmR6f8rGb). +### CI -## Quick Start for Users +Integration CI for a variety of projects is rooted in this repo. -1. Install from source: - -``` -pip install shark-turbine -# Or for editable: see instructions under developers -``` - -The above does install some unecessary cuda/cudnn packages for cpu use. To avoid this you -can specify pytorch-cpu and install via: -``` -pip install -r core/pytorch-cpu-requirements.txt -pip install shark-turbine -``` - -(or follow the "Developers" instructions below for installing from head/nightly) - -2. Try one of the samples: - -Generally, we use Turbine to produce valid, dynamic shaped Torch IR (from the -[`torch-mlir torch` dialect](https://github.com/llvm/torch-mlir/tree/main/include/torch-mlir/Dialect/Torch/IR) -with various approaches to handling globals). Depending on the use-case and status of the -compiler, these should be compilable via IREE with `--iree-input-type=torch` for -end to end execution. Dynamic shape support in torch-mlir is a work in progress, -and not everything works at head with release binaries at present. - - * [AOT MLP With Static Shapes](https://github.com/nod-ai/SHARK-Turbine/blob/main/core/examples/aot_mlp/mlp_export_simple.py) - * [AOT MLP with a dynamic batch size](https://github.com/nod-ai/SHARK-Turbine/blob/main/core/examples/aot_mlp/mlp_export_dynamic.py) - * [AOT llama2](https://github.com/nod-ai/SHARK-Turbine/blob/main/core/examples/llama2_inference/llama2.ipynb): - Dynamic sequence length custom compiled module with state management internal to the model. - * [Eager MNIST with `torch.compile`](https://github.com/nod-ai/SHARK-Turbine/blob/main/core/examples/eager_mlp/mlp_eager_simple.py) - -## Developers - -### Getting Up and Running - -If only looking to develop against this project, then you need to install Python -deps for the following: - -* PyTorch -* iree-compiler (with Torch input support) -* iree-runtime - -The pinned deps at HEAD require pre-release versions of all of the above, and -therefore require additional pip flags to install. Therefore, to satisfy -development, we provide a `requirements.txt` file which installs precise -versions and has all flags. This can be installed prior to the package: - -Installing into a venv is highly recommended. - -``` -pip install -r core/pytorch-cpu-requirements.txt -pip install --upgrade -r core/requirements.txt -pip install --upgrade -e "core[torch-cpu-nightly,testing]" -``` - -Run tests: - -``` -pytest core/ -``` - -### Using a development compiler - -If doing native development of the compiler, it can be useful to switch to -source builds for iree-compiler and iree-runtime. - -In order to do this, check out [IREE](https://github.com/openxla/iree) and -follow the instructions to [build from source](https://iree.dev/building-from-source/getting-started/), making -sure to specify [additional options for the Python bindings](https://iree.dev/building-from-source/getting-started/#building-with-cmake): - -``` --DIREE_BUILD_PYTHON_BINDINGS=ON -DPython3_EXECUTABLE="$(which python)" -``` - -#### Configuring Python - -Uninstall existing packages: - -``` -pip uninstall iree-compiler -pip uninstall iree-runtime -``` - -Copy the `.env` file from `iree/` to this source directory to get IDE -support and add to your path for use from your shell: - -``` -source .env && export PYTHONPATH -``` diff --git a/build_tools/build_release.py b/build_tools/build_release.py deleted file mode 100755 index fc0db7369..000000000 --- a/build_tools/build_release.py +++ /dev/null @@ -1,170 +0,0 @@ -#!/usr/bin/env python -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Fetches dependent release artifacts and builds wheels. - -See docs/releasing.md for usage. -""" - -import argparse -from datetime import date -import json -import os -from pathlib import Path -import shlex -import subprocess - - -REPO_ROOT = Path(__file__).resolve().parent.parent -VERSION_INFO_FILE = REPO_ROOT / "version_info.json" -CORE_DIR = REPO_ROOT / "core" -WHEEL_DIR = REPO_ROOT / "wheelhouse" - -# The platform flags that we will download IREE wheels for. This must match -# the platforms and Python versions we build. If it mismatches or something -# is wrong, this will error. Note that the platform and python-version -# indicates "fetch me a wheel that will install on this combo" vs "fetch me -# a specific wheel". -IREE_PLATFORM_ARGS = [ - # Linux aarch64 - ["--platform", "manylinux_2_28_aarch64", "--python-version", "3.9"], - ["--platform", "manylinux_2_28_aarch64", "--python-version", "3.10"], - ["--platform", "manylinux_2_28_aarch64", "--python-version", "3.11"], - ["--platform", "manylinux_2_28_aarch64", "--python-version", "3.12"], - # Linux x86_64 - ["--platform", "manylinux_2_28_x86_64", "--python-version", "3.9"], - ["--platform", "manylinux_2_28_x86_64", "--python-version", "3.10"], - ["--platform", "manylinux_2_28_x86_64", "--python-version", "3.11"], - ["--platform", "manylinux_2_28_x86_64", "--python-version", "3.12"], - # MacOS - ["--platform", "macosx_13_0_universal2", "--python-version", "3.11"], - # Windows - ["--platform", "win_amd64", "--python-version", "3.11"], -] - - -def eval_version(version_spec: str): - date_stamp = date.today().strftime("%Y%m%d") - return version_spec.replace("YYYYMMDD", date_stamp) - - -def write_version_info(args): - with open(VERSION_INFO_FILE, "rt") as f: - info_dict = json.load(f) - - # Compute core-version. - core_version = eval_version(args.core_version) - if args.core_pre_version: - core_version += eval_version(args.core_pre_version) - if args.core_post_version: - core_version += f".{eval_version(args.core_post_version)}" - info_dict["core-version"] = core_version - - with open(VERSION_INFO_FILE, "wt") as f: - json.dump(info_dict, f) - - print(f"Updated version_info.json:\n{json.dumps(info_dict, indent=2)}") - - -def exec(args, env=None): - args = [str(s) for s in args] - print(f": Exec: {shlex.join(args)}") - if env is not None: - full_env = dict(os.environ) - full_env.update(env) - else: - full_env = None - subprocess.check_call(args, env=full_env) - - -def download_requirements(requirements_file, platforms=()): - args = [ - "pip", - "download", - "-d", - WHEEL_DIR, - ] - if platforms: - args.append("--no-deps") - for p in platforms: - args.extend(["--platform", p]) - args += [ - "-f", - WHEEL_DIR, - "-r", - requirements_file, - ] - exec(args) - - -def download_iree_binaries(): - for platform_args in IREE_PLATFORM_ARGS: - print("Downloading for platform:", platform_args) - args = [ - "pip", - "download", - "-d", - WHEEL_DIR, - "--no-deps", - ] - args.extend(platform_args) - args += [ - "-f", - "https://iree.dev/pip-release-links.html", - "-f", - WHEEL_DIR, - "-r", - CORE_DIR / "iree-requirements.txt", - ] - exec(args) - - -def build_wheel(path, env=None): - exec( - ["pip", "wheel", "--no-index", "-f", WHEEL_DIR, "-w", WHEEL_DIR, path], env=env - ) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--core-version", help="Version for the core component", required=True - ) - parser.add_argument( - "--core-pre-version", - help="Pre-release version segment or (YYYYMMDD)", - default="", - ) - parser.add_argument( - "--core-post-version", - help="Post-release version segment or (YYYYMMDD)", - default="", - ) - parser.add_argument( - "--no-download", help="Disable dep download", action="store_true" - ) - args = parser.parse_args() - - write_version_info(args) - WHEEL_DIR.mkdir(parents=True, exist_ok=True) - - if not args.no_download: - print("Prefetching all IREE binaries") - download_iree_binaries() - print("Prefetching torch CPU") - download_requirements(CORE_DIR / "pytorch-cpu-requirements.txt") - print("Downloading remaining requirements") - download_requirements(CORE_DIR / "requirements.txt") - - print("Building shark-turbine") - build_wheel(CORE_DIR) - print("Building iree-turbine") - build_wheel(CORE_DIR, env={"TURBINE_PACKAGE_NAME": "iree-turbine"}) - - -if __name__ == "__main__": - main() diff --git a/core/README.md b/core/README.md deleted file mode 100644 index 52c232d9b..000000000 --- a/core/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Turbine-Core Sub-Project - -This directory contains the core infrastructure for the project, consisting -of the model export, runtime, and kernel development APIs. It is packaged -and released as the -[`iree-turbine` project on PyPI](https://pypi.org/project/iree-turbine/) -(previously [`SHARK-Turbine`](https://pypi.org/project/shark-turbine/)). - -It depends purely on PyTorch and the IREE compiler/runtime. - -See the repository-level README for further information. diff --git a/core/examples/aot_mlp/mlp_export_dynamic.py b/core/examples/aot_mlp/mlp_export_dynamic.py deleted file mode 100644 index cd8636554..000000000 --- a/core/examples/aot_mlp/mlp_export_dynamic.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# This sample builds a dynamic shape version of the MLP with -# a dynamic batch dimension. It uses the advanced, low-level -# API because we don't have dynamic shapes available in the -# simple API yet. - -import torch -import torch.nn as nn - -import shark_turbine.aot as aot - - -class MLP(nn.Module): - def __init__(self): - super().__init__() - self.layer0 = nn.Linear(8, 8, bias=True) - self.layer1 = nn.Linear(8, 4, bias=True) - self.layer2 = nn.Linear(4, 2, bias=True) - self.layer3 = nn.Linear(2, 2, bias=True) - - def forward(self, x: torch.Tensor): - x = self.layer0(x) - x = torch.sigmoid(x) - x = self.layer1(x) - x = torch.sigmoid(x) - x = self.layer2(x) - x = torch.sigmoid(x) - x = self.layer3(x) - return x - - -model = MLP() - - -class CompiledMLP(aot.CompiledModule): - params = aot.export_parameters(model) - - def main(self, x=aot.AbstractTensor(None, 97, 8, dtype=torch.float32)): - return aot.jittable(model.forward)( - x, - constraints=[ - x.dynamic_dim(0), - ], - ) - - -batch = torch.export.Dim("batch") -exported = aot.export( - model, - args=(torch.empty([2, 97, 8], dtype=torch.float32),), - dynamic_shapes={"x": {0: batch}}, -) -# Note that dynamic Torch IR is created below. -exported.print_readable() - - -# TODO: Enable once version roll to ToT torch-mlir with dynamic view -# op legalization fixes. -# compiled_binary = exported.compile(save_to=None) -# def infer(): -# import numpy as np -# import iree.runtime as rt - -# config = rt.Config("local-task") -# vmm = rt.load_vm_module( -# rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()), -# config, -# ) -# x = np.random.rand(10, 97, 8).astype(np.float32) -# y = vmm.main(x) -# print(y.to_host()) -# infer() diff --git a/core/examples/aot_mlp/mlp_export_simple.py b/core/examples/aot_mlp/mlp_export_simple.py deleted file mode 100644 index fed4795d4..000000000 --- a/core/examples/aot_mlp/mlp_export_simple.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest -import torch -import torch.nn as nn - -import shark_turbine.aot as aot - - -class MLP(nn.Module): - def __init__(self): - super().__init__() - self.layer0 = nn.Linear(8, 8, bias=True) - self.layer1 = nn.Linear(8, 4, bias=True) - self.layer2 = nn.Linear(4, 2, bias=True) - self.layer3 = nn.Linear(2, 2, bias=True) - - def forward(self, x: torch.Tensor): - x = self.layer0(x) - x = torch.sigmoid(x) - x = self.layer1(x) - x = torch.sigmoid(x) - x = self.layer2(x) - x = torch.sigmoid(x) - x = self.layer3(x) - return x - - -model = MLP() -example_x = torch.empty(97, 8, dtype=torch.float32) -exported = aot.export(model, example_x) -exported.print_readable() -compiled_binary = exported.compile(save_to=None) - - -def infer(): - import numpy as np - import iree.runtime as rt - - config = rt.Config("local-task") - vmm = rt.load_vm_module( - rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()), - config, - ) - x = np.random.rand(97, 8).astype(np.float32) - y = vmm.main(x) - print(y.to_host()) - - -class ModelTest(unittest.TestCase): - def testMLPExportSimple(selfs): - infer() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/examples/eager_mlp/mlp_eager_simple.py b/core/examples/eager_mlp/mlp_eager_simple.py deleted file mode 100644 index ec6715159..000000000 --- a/core/examples/eager_mlp/mlp_eager_simple.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -import torch -from torch import nn -from torch.utils.data import DataLoader -import torchvision.transforms as transforms -import torchvision.datasets as datasets - -torch._dynamo.config.dynamic_shapes = ( - False # TODO: https://github.com/nod-ai/SHARK-Turbine/issues/93 -) - - -class MNISTDataLoader: - def __init__(self, batch_size, shuffle=True): - self.batch_size = batch_size - self.shuffle = shuffle - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] - ) - self.mnist_trainset = datasets.MNIST( - root="../data", train=True, download=True, transform=transform - ) - self.mnist_testset = datasets.MNIST( - root="../data", train=False, download=True, transform=transform - ) - - def get_train_loader(self): - return DataLoader( - dataset=self.mnist_trainset, - batch_size=self.batch_size, - shuffle=self.shuffle, - ) - - def get_test_loader(self): - return DataLoader( - dataset=self.mnist_testset, - batch_size=self.batch_size, - shuffle=False, - drop_last=True, - ) - - -class MLP(nn.Module): - def __init__(self): - super().__init__() - self.layer0 = nn.Linear(28, 28, bias=True) - self.layer1 = nn.Linear(28, 14, bias=True) - self.layer2 = nn.Linear(14, 7, bias=True) - self.layer3 = nn.Linear(7, 7, bias=True) - - def forward(self, x: torch.Tensor): - x = self.layer0(x) - x = torch.sigmoid(x) - x = self.layer1(x) - x = torch.sigmoid(x) - x = self.layer2(x) - x = torch.sigmoid(x) - x = self.layer3(x) - return x - - -def infer_iteration(model, images): - outputs = model(images) - return outputs - - -def infer(): - # Example Parameters - config = { - "batch_size": 64, - "learning_rate": 0.001, - "num_epochs": 10, - } - - custom_data_loader = MNISTDataLoader(config["batch_size"]) - test_loader = custom_data_loader.get_test_loader() - model = MLP() - test_opt = torch.compile(infer_iteration, backend="turbine_cpu") - for i, (images, labels) in enumerate(test_loader): - test_opt(model, images) - - -class ModelTests(unittest.TestCase): - def testMNISTEagerSimple(self): - infer() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/examples/llama2_inference/README.md b/core/examples/llama2_inference/README.md deleted file mode 100644 index 1d377727b..000000000 --- a/core/examples/llama2_inference/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# LLAMA 2 Inference - -This example require some extra dependencies. Here's an easy way to get it running on a fresh server. - -Don't forget to put in your huggingface token from https://huggingface.co/settings/tokens - -```bash -#!/bin/bash - - -# if you don't insert it, you will be prompted to log in later; -# you may need to rerun this script after logging in -YOUR_HF_TOKEN="insert token for headless" - -# clone and install dependencies -sudo apt install -y git -git clone https://github.com/nod-ai/SHARK-Turbine.git -cd SHARK-Turbine -pip install -r requirements.txt -pip install --update "huggingface_hub[cli]" transformers sentencepiece protobuf - -# do an editable install from the cloned SHARK-Turbine -pip install --editable . - -# Log in with Hugging Face CLI if token setup is required -if [[ $YOUR_HF_TOKEN == hf_* ]]; then - huggingface login --token $YOUR_HF_TOKEN - echo "Logged in with YOUR_HF_TOKEN." -elif [ -f ~/.cache/huggingface/token ]; then - # Read token from the file - TOKEN_CONTENT=$(cat ~/.cache/huggingface/token) - - # Check if the token starts with "hf_" - if [[ $TOKEN_CONTENT == hf_* ]]; then - echo "Already logged in with a Hugging Face token." - else - echo "Token in file does not start with 'hf_'. Please log into huggingface to download models." - huggingface-cli login - fi -else - echo "Please log into huggingface to download models." - huggingface-cli login -fi - -# Step 7: Run the Python script -python examples/llama2_inference/stateless_llama.py -``` diff --git a/core/examples/llama2_inference/llama2.ipynb b/core/examples/llama2_inference/llama2.ipynb deleted file mode 100644 index b008bbd20..000000000 --- a/core/examples/llama2_inference/llama2.ipynb +++ /dev/null @@ -1,503 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "c0c9f034-7af1-4dc2-bbfb-5bb9e27c07ca", - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoTokenizer, AutoModelForCausalLM\n", - "import torch\n", - "from torch.utils import _pytree as pytree\n", - "from shark_turbine.aot import *\n", - "from iree.compiler.ir import Context\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "4d92bb47-2b93-4f32-a445-c0ad2adc37ad", - "metadata": {}, - "outputs": [], - "source": [ - "#set some config values\n", - "\n", - "hf_auth_token = \"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\"\n", - "hf_model_name = \"meta-llama/Llama-2-7b-chat-hf\"\n", - "state_schema_path = \"llama2_state_schema.json\"\n", - "with open(state_schema_path, \"r+\") as f:\n", - " state_schema = pytree.treespec_loads(f.read())\n", - "prompt = \"\"\"\n", - "[INST] <>\n", - "Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? [/INST]\n", - "\"\"\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "d4664585-5e15-45c7-8c5c-c8eaf6381435", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/auto/tokenization_auto.py:640: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.\n", - " warnings.warn(\n", - "/home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py:479: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5e411acda19c4228b008ff622bdf110e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00.5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:26 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,234] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s1 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:72 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,409] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s2, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:118 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:33,707] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s3 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:189 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,845] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s3, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:228 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:33,878] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s4, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:235 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:34,188] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s5 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:306 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:34,326] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s5, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:345 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:34,359] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s6, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:352 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:34,661] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s7 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:423 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:34,800] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s7, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:462 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:34,832] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s8, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:469 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:35,130] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s9 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:540 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:35,271] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s9, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:579 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:35,305] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s10, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:586 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:35,611] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s11 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:657 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:35,762] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s11, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:696 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:35,795] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s12, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:703 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:36,107] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s13 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:774 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:36,249] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s13, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:813 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:36,282] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s14, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:820 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:36,589] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s15 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:891 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:36,734] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s15, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:930 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:36,768] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s16, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:937 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:37,105] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s17 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1008 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:37,249] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s17, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1047 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:37,286] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s18, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1054 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:37,595] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s19 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1125 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:37,744] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s19, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1164 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:37,778] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s20, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1171 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:38,090] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s21 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1242 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:38,238] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s21, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1281 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:38,272] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s22, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1288 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:38,584] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s23 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1359 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:38,734] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s23, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1398 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:38,768] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s24, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1405 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:39,086] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s25 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1476 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:39,239] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s25, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1515 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:39,274] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s26, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1522 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:39,597] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s27 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1593 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:39,759] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s27, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1632 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:39,812] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s28, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1639 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:40,330] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s29 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1710 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:40,534] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s29, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1749 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:40,582] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s30, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1756 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:41,068] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s31 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1827 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:41,242] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s31, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1866 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:41,280] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s32, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1873 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:41,686] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s33 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1944 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:41,968] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s33, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1983 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:42,004] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s34, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1990 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:42,419] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s35 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2061 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:42,580] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s35, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2100 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:42,618] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s36, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2107 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:43,002] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s37 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2178 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:43,174] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s37, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2217 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:43,215] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s38, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2224 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:43,566] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s39 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2295 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:43,738] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s39, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2334 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:43,776] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s40, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2341 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:44,116] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s41 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2412 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:44,281] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s41, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2451 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:44,320] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s42, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2458 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:44,656] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s43 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2529 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:44,822] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s43, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2568 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:44,860] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s44, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2575 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:45,218] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s45 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2646 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:45,387] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s45, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2685 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:45,426] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s46, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2692 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:45,772] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s47 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2763 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:45,943] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s47, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2802 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:45,983] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s48, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2809 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:46,376] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s49 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2880 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:46,563] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s49, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2919 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:46,605] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s50, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2926 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:46,962] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s51 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2997 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:47,136] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s51, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3036 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:47,176] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s52, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3043 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:47,540] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s53 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3114 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:47,718] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s53, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3153 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:47,758] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s54, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3160 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:48,125] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s55 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3231 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:48,308] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s55, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3270 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:48,349] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s56, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3277 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:48,715] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s57 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3348 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:48,897] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s57, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3387 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:48,937] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s58, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3394 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:49,317] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s59 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3465 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:49,499] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s59, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3504 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:49,540] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s60, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3511 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:49,915] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s61 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3582 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:50,113] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s61, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3621 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:50,155] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s62, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3628 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:50,515] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s63 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3699 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:50,697] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s63, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3738 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:50,737] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s64, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3745 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:53,791] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] produce_guards\n", - "[2023-10-09 18:49:54,155] torch.fx.experimental.symbolic_shapes: [WARNING] Ignored guard s0 + s1 > 4096 == False, this could result in accuracy problems\n", - "[2023-10-09 18:49:54,157] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s1 <= 4096 [guard added] (_decomp/decompositions.py:725 in slice_forward)\n" - ] - } - ], - "source": [ - "#Run the export pipeline\n", - "inst = StateUpdateModule(context=Context(), import_to=\"IMPORT\")\n", - "module_str = str(CompiledModule.get_mlir_module(inst))" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "bc04e1db-a8cc-4182-884d-ba3d8ae5adeb", - "metadata": {}, - "outputs": [], - "source": [ - "#Output a torch-ir mlir file\n", - "with open(\"llama2_torch.mlir\", \"w+\") as f:\n", - " f.write(module_str)\n", - "#TODO: run the rest of the compile pipeline and do inference" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/core/examples/llama2_inference/llama2_state_schema.json b/core/examples/llama2_inference/llama2_state_schema.json deleted file mode 100644 index e4e55dfd1..000000000 --- a/core/examples/llama2_inference/llama2_state_schema.json +++ /dev/null @@ -1 +0,0 @@ -[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}] \ No newline at end of file diff --git a/core/examples/llama2_inference/requirements.txt b/core/examples/llama2_inference/requirements.txt deleted file mode 100644 index acbc93ca3..000000000 --- a/core/examples/llama2_inference/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -protobuf -sentencepiece -shark_turbine -transformers @ git+https://github.com/huggingface/transformers.git@7d8ff3629b2725ec43ace99c1a6e87ac1978d433 diff --git a/core/examples/resnet-18/README.md b/core/examples/resnet-18/README.md deleted file mode 100644 index aaf8ce7bb..000000000 --- a/core/examples/resnet-18/README.md +++ /dev/null @@ -1,19 +0,0 @@ -# Dynamic AOT Resnet-18 Example - -This example AOT-compiles a Resnet-18 module for performing inference on a dynamic number of input images. - -To run this example (with Python3.11), you should clone the repository to your local device and install the requirements in a virtual environment. - -```bash -git clone https://github.com/nod-ai/SHARK-Turbine.git -cd SHARK-Turbine/examples/resnet-18 -python -m venv rn18_venv -source ./rn18_venv/bin/activate -pip install -r requirements.txt -``` - -Once the requirements are installed, you should be able to run the example. - -```bash -python resnet-18.py -``` \ No newline at end of file diff --git a/core/examples/resnet-18/requirements.txt b/core/examples/resnet-18/requirements.txt deleted file mode 100644 index a6ca4b760..000000000 --- a/core/examples/resnet-18/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -transformers -shark_turbine==0.9.2 \ No newline at end of file diff --git a/core/examples/resnet-18/resnet-18.py b/core/examples/resnet-18/resnet-18.py deleted file mode 100644 index 203400130..000000000 --- a/core/examples/resnet-18/resnet-18.py +++ /dev/null @@ -1,70 +0,0 @@ -from transformers import AutoFeatureExtractor, AutoModelForImageClassification -import torch -from shark_turbine.aot import * -import iree.runtime as rt - -# Loading feature extractor and pretrained model from huggingface -# extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-18") -model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18") - - -# define a function to do inference -# this will get passed to the compiled module as a jittable function -def forward(pixel_values_tensor: torch.Tensor): - with torch.no_grad(): - logits = model.forward(pixel_values_tensor).logits - predicted_id = torch.argmax(logits, -1) - return predicted_id - - -# a dynamic module for doing inference -# this will be compiled AOT to a memory buffer -class RN18(CompiledModule): - params = export_parameters(model) - - def forward(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)): - # set a constraint for the dynamic number of batches - # interestingly enough, it doesn't seem to limit BATCH_SIZE - const = [x.dynamic_dim(0) < 16] - return jittable(forward)(x, constraints=const) - - -# build an mlir module with 1-shot exporter -exported = export(RN18) -# compile exported module to a memory buffer -compiled_binary = exported.compile(save_to=None) - - -# return type is rt.array_interop.DeviceArray -# np.array of outputs can be accessed via to_host() method -def shark_infer(x): - config = rt.Config("local-task") - vmm = rt.load_vm_module( - rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()), - config, - ) - y = vmm.forward(x) - return y - - -# prints the text corresponding to output label codes -def print_labels(id): - for l in id: - print(model.config.id2label[l]) - - -# finds discrepancies between id0 and id1 -def compare_labels(id0, id1): - return (id0 != id1).nonzero(as_tuple=True) - - -# load some examples and check for discrepancies between -# compiled module and standard inference (forward function) - -x = torch.randn(10, 3, 224, 224) -y0 = shark_infer(x) -y1 = forward(x) -print_labels(y0) -print( - f"Found {compare_labels(y0,y1)[0].size()[0]} discrepancies between turbine and standard result" -) diff --git a/core/iree-requirements.txt b/core/iree-requirements.txt deleted file mode 100644 index eaa171b4e..000000000 --- a/core/iree-requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -iree-compiler==20240410.859 -iree-runtime==20240410.859 \ No newline at end of file diff --git a/core/iree/turbine/__init__.py b/core/iree/turbine/__init__.py deleted file mode 100644 index c59e85c2e..000000000 --- a/core/iree/turbine/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -The turbine package provides development tools for deploying PyTorch 2 machine -learning models to cloud and edge devices. -""" - -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# TODO: This redirection layer exists while we are migrating from the -# shark_turbine top-level package name to iree.turbine. It exports the -# public API but not the internal details. In a future switch, all code -# will be directly located here and the redirect will be done in the -# shark_turbine namespace. - -from shark_turbine import aot -from shark_turbine import dynamo -from shark_turbine import kernel -from shark_turbine import ops -from shark_turbine import runtime diff --git a/core/misc-requirements.txt b/core/misc-requirements.txt deleted file mode 100644 index becb775f9..000000000 --- a/core/misc-requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -numpy>=1.26.3 -onnx>=1.15.0 -pytest>=8.0.0 -pytest-xdist>=3.5.0 diff --git a/core/mypy.ini b/core/mypy.ini deleted file mode 100644 index ea0bb6890..000000000 --- a/core/mypy.ini +++ /dev/null @@ -1,21 +0,0 @@ -[mypy] - -explicit_package_bases = True -mypy_path = $MYPY_CONFIG_FILE_DIR -packages = shark_turbine - -# Missing typing stubs for iree.compiler. -[mypy-iree.compiler.*] -ignore_missing_imports = True - -# Missing typing stubs for iree.runtime. -[mypy-iree.runtime.*] -ignore_missing_imports = True - -# fx_importer needs to be fixed upstream. -[mypy-shark_turbine.importers.fx_importer.*] -ignore_errors = True - -# TODO: Fix all typing errors in TK. -[mypy-shark_turbine.kernel.*] -ignore_errors = True diff --git a/core/pyproject.toml b/core/pyproject.toml deleted file mode 100644 index 9787c3bdf..000000000 --- a/core/pyproject.toml +++ /dev/null @@ -1,3 +0,0 @@ -[build-system] -requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta" diff --git a/core/pytorch-cpu-requirements.txt b/core/pytorch-cpu-requirements.txt deleted file mode 100644 index e4fa5c795..000000000 --- a/core/pytorch-cpu-requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ ---pre ---index-url https://download.pytorch.org/whl/test/cpu --r pytorch-requirements.txt diff --git a/core/pytorch-requirements.txt b/core/pytorch-requirements.txt deleted file mode 100644 index 63fc21602..000000000 --- a/core/pytorch-requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -torch==2.3.0 -torchaudio -torchvision diff --git a/core/requirements.txt b/core/requirements.txt deleted file mode 100644 index 8fe18be63..000000000 --- a/core/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -# These requirements are here for one-stop development -# setup with pinned deps and will fulfill all install -# requirements for the package (which pins to minimum -# versions, not specific). --f https://iree.dev/pip-release-links.html - --r pytorch-requirements.txt --r iree-requirements.txt - -# From pyproject.toml. -setuptools -wheel diff --git a/core/setup.cfg b/core/setup.cfg deleted file mode 100644 index 358360671..000000000 --- a/core/setup.cfg +++ /dev/null @@ -1,6 +0,0 @@ -[tool:pytest] -testpaths = - ./tests -filterwarnings = - # TODO: Remove once flatbuffer 'imp' usage resolved. - ignore::DeprecationWarning diff --git a/core/setup.py b/core/setup.py deleted file mode 100644 index e8bb08f8d..000000000 --- a/core/setup.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import json -import os -import distutils.command.build -from pathlib import Path - -from setuptools import find_namespace_packages, setup - -THIS_DIR = os.path.realpath(os.path.dirname(__file__)) -REPO_DIR = os.path.dirname(THIS_DIR) -VERSION_INFO_FILE = os.path.join(REPO_DIR, "version_info.json") - -# Transitional as we migrate from shark-turbine -> iree-turbine. -TURBINE_PACKAGE_NAME = os.getenv("TURBINE_PACKAGE_NAME", "shark-turbine") - -with open( - os.path.join( - REPO_DIR, - "README.md", - ), - "rt", -) as f: - README = f.read() - - -def load_version_info(): - with open(VERSION_INFO_FILE, "rt") as f: - return json.load(f) - - -version_info = load_version_info() -PACKAGE_VERSION = version_info["core-version"] - -packages = find_namespace_packages( - include=[ - "iree.turbine", - "iree.turbine.*", - "shark_turbine", - "shark_turbine.*", - ], -) - -print("Found packages:", packages) - -# Lookup version pins from requirements files. -requirement_pins = {} - - -def load_requirement_pins(requirements_file: str): - with open(Path(THIS_DIR) / requirements_file, "rt") as f: - lines = f.readlines() - pin_pairs = [line.strip().split("==") for line in lines if "==" in line] - requirement_pins.update(dict(pin_pairs)) - - -load_requirement_pins("iree-requirements.txt") -load_requirement_pins("misc-requirements.txt") -load_requirement_pins("pytorch-cpu-requirements.txt") - - -def get_version_spec(dep: str): - if dep in requirement_pins: - return f">={requirement_pins[dep]}" - else: - return "" - - -# Override build command so that we can build into _python_build -# instead of the default "build". This avoids collisions with -# typical CMake incantations, which can produce all kinds of -# hilarity (like including the contents of the build/lib directory). -class BuildCommand(distutils.command.build.build): - def initialize_options(self): - distutils.command.build.build.initialize_options(self) - self.build_base = "_python_build" - - -setup( - name=f"{TURBINE_PACKAGE_NAME}", - version=f"{PACKAGE_VERSION}", - author="SHARK Authors", - author_email="stella@nod.ai", - description="SHARK Turbine Machine Learning Deployment Tools", - long_description=README, - long_description_content_type="text/markdown", - url="https://github.com/nod-ai/SHARK-Turbine", - license="Apache-2.0", - classifiers=[ - "Development Status :: 5 - Production/Stable", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3", - ], - packages=packages, - entry_points={ - "torch_dynamo_backends": [ - "turbine_cpu = shark_turbine.dynamo.backends.cpu:backend", - ], - }, - install_requires=[ - f"numpy{get_version_spec('numpy')}", - f"iree-compiler{get_version_spec('iree-compiler')}", - f"iree-runtime{get_version_spec('iree-runtime')}", - # Use the [torch-cpu-nightly] spec to get a more recent/specific version. - # Note that during the transition to torch 2.3.0 we technically support - # back to torch 2.1, which is why we pin here in this way. However, - # the CI tests on 2.3. - "torch>=2.1.0", - ], - extras_require={ - "torch-cpu-nightly": [f"torch{get_version_spec('torch')}"], - "onnx": [ - f"onnx{get_version_spec('onnx')}", - ], - "testing": [ - f"onnx{get_version_spec('onnx')}", - f"pytest{get_version_spec('pytest')}", - f"pytest-xdist{get_version_spec('pytest-xdist')}", - ], - }, - cmdclass={"build": BuildCommand}, -) diff --git a/core/shark_turbine/aot/__init__.py b/core/shark_turbine/aot/__init__.py deleted file mode 100644 index ceb95fe21..000000000 --- a/core/shark_turbine/aot/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Toolkit for ahead-of-time (AOT) compilation and export of PyTorch programs. -""" - -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from .builtins import * -from .compiled_module import * -from .decompositions import * -from .exporter import * -from .fx_programs import FxPrograms, FxProgramsBuilder -from .tensor_traits import * -from .params import * diff --git a/core/shark_turbine/aot/builtins/__init__.py b/core/shark_turbine/aot/builtins/__init__.py deleted file mode 100644 index 0d37f5d8d..000000000 --- a/core/shark_turbine/aot/builtins/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from .globals import * -from .jittable import jittable -from ..support.procedural import ( - AbstractBool, - AbstractF32, - AbstractF64, - AbstractI32, - AbstractI64, - AbstractIndex, - AbstractTensor, - abstractify, -) - -# Export the instantiated IREEEmitter as "IREE" -from ..support.procedural.iree_emitter import IREEEmitter as _IREEEmitter - -IREE = _IREEEmitter() -del _IREEEmitter - -__all__ = [ - "AbstractBool", - "AbstractF32", - "AbstractF64", - "AbstractI32", - "AbstractI64", - "AbstractIndex", - "AbstractTensor", - "IREE", - "abstractify", - "export_global", - "export_global_tree", - "export_parameters", - "export_buffers", - "jittable", -] diff --git a/core/shark_turbine/aot/builtins/globals.py b/core/shark_turbine/aot/builtins/globals.py deleted file mode 100644 index befe2bbe9..000000000 --- a/core/shark_turbine/aot/builtins/globals.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Any, Callable, Optional - -import torch.nn as nn - -from ..support.procedural import ( - AbstractTypedef, - Abstractifiable, - GlobalsDef, - TreeAbstractifiable, - abstractify_single_value, -) - -from ..support.ir_utils import ( - NameMapCallback, - GlobalAttributes, -) - -from torch.utils._pytree import ( - TreeSpec, - tree_flatten, - tree_map, -) - - -__all__ = [ - "export_global", - "export_global_tree", - "export_parameters", - "export_buffers", -] - - -class export_global(GlobalsDef, Abstractifiable): - """Exports a single global into a CompiledModule.""" - - __slots__ = ["_name", "_value", "_schema"] - - def __init__( - self, - value: Any, - *, - name: str = "global", - mutable: Optional[bool] = None, - external: Optional[bool] = None, - external_scope: Optional[str] = None, - name_mapper: Optional[NameMapCallback] = None, - uninitialized: Optional[bool] = None, - attrs: Optional[GlobalAttributes] = None, - ): - if attrs is None: - attrs = GlobalAttributes( - mutable=bool(mutable), - external=external, - external_scope=external_scope, - name_mapper=name_mapper, - uninitialized=uninitialized, - ) - super().__init__(attrs) - self._name = name - self._value = value - _, self._schema = tree_flatten(self._value) - - def items(self): - yield (self._name, self._value) - - def schema(self) -> TreeSpec: - return self._schema - - def abstractify(self) -> AbstractTypedef: - return abstractify_single_value(self._value) - - -class export_global_tree(GlobalsDef, Abstractifiable): - """Exports a tree of globals into a CompiledModule.""" - - def __init__( - self, - tree, - *, - mutable: Optional[bool] = None, - external: Optional[bool] = None, - external_scope: Optional[str] = None, - name_mapper: Optional[NameMapCallback] = None, - uninitialized: Optional[bool] = None, - attrs: Optional[GlobalAttributes] = None, - ): - if attrs is None: - attrs = GlobalAttributes( - mutable=bool(mutable), - external=external, - external_scope=external_scope, - name_mapper=name_mapper, - uninitialized=uninitialized, - ) - super().__init__(attrs) - self._tree = tree - self._items, self._schema = tree_flatten(tree) - self._names, _ = tree_flatten(_transform_tree_to_names("", tree)) - assert len(self._items) == len( - self._names - ), f"Name and value tree are different sizes: {len(self._items)} != {len(self._names)}" - - def items(self): - for name, value in zip(self._names, self._items): - yield name, value - - def schema(self) -> TreeSpec: - return self._schema - - def abstractify(self) -> AbstractTypedef: - return tree_map(abstractify_single_value, self._tree) - - -class export_parameters(GlobalsDef, TreeAbstractifiable): - """Exports parameters from an nn.Module. - - These are exposed to procedural programs as a dictionary of param/values. - """ - - __slots__ = [ - "_param_list", - "_schema", - "_tree", - ] - - def __init__( - self, - nn_module: nn.Module, - *, - mutable: Optional[bool] = None, - external: Optional[bool] = None, - external_scope: Optional[str] = None, - name_mapper: Optional[NameMapCallback] = None, - uninitialized: Optional[bool] = None, - attrs: Optional[GlobalAttributes] = None, - ): - if attrs is None: - attrs = GlobalAttributes( - mutable=bool(mutable), - external=external, - external_scope=external_scope, - name_mapper=name_mapper, - uninitialized=uninitialized, - ) - super().__init__(attrs) - self._param_list = list(nn_module.named_parameters()) - self._tree = dict(self._param_list) - _, self._schema = tree_flatten(self._tree) - - def items(self): - for name, value in self._param_list: - yield (name, value) - - def schema(self) -> TreeSpec: - return self._schema - - def abstractify_tree(self): - return tree_map(abstractify_single_value, self._tree) - - def __getitem__(self, key): - return self._tree[key] - - def __repr__(self): - names = [name for name, _ in self._param_list] - return f"" - - -class export_buffers(GlobalsDef, TreeAbstractifiable): - """Exports buffers from an nn.Module. - - These are exposed to procedural programs as a dictionary of param/values. - """ - - __slots__ = [ - "_buffer_list", - "_schema", - "_tree", - ] - - def __init__( - self, - nn_module: nn.Module, - *, - mutable: Optional[bool] = None, - external: Optional[bool] = None, - external_scope: Optional[str] = None, - name_mapper: Optional[NameMapCallback] = None, - uninitialized: Optional[bool] = None, - attrs: Optional[GlobalAttributes] = None, - ): - if attrs is None: - attrs = GlobalAttributes( - mutable=bool(mutable), - external=external, - external_scope=external_scope, - name_mapper=name_mapper, - uninitialized=uninitialized, - ) - super().__init__(attrs) - self._buffer_list = list(nn_module.named_buffers()) - self._tree = dict(self._buffer_list) - _, self._schema = tree_flatten(self._tree) - - def items(self): - for name, value in self._buffer_list: - yield (name, value) - - def schema(self) -> TreeSpec: - return self._schema - - def abstractify_tree(self): - return tree_map(abstractify_single_value, self._tree) - - def __getitem__(self, key): - return self._tree[key] - - def __repr__(self): - names = [name for name, _ in self._param_list] - return f"" - - -def _transform_tree_to_names(prefix: str, tree): - """Produces a topologically similar tree but where each value is a fully qualified name.""" - join = lambda key: f"{prefix}.{key}" if prefix else key - # No need to check for cycles as pytree already did something with it and - # validates. - if isinstance(tree, dict): - return tree.__class__( - (k, _transform_tree_to_names(join(k), v)) for k, v in tree.items() - ) - elif isinstance(tree, (list, tuple)): - return tree.__class__( - _transform_tree_to_names(join(str(index)), v) - for index, v in enumerate(tree) - ) - else: - return prefix diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py deleted file mode 100644 index 29a90617b..000000000 --- a/core/shark_turbine/aot/builtins/jittable.py +++ /dev/null @@ -1,475 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Tracing builtins.""" - -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union - -import warnings - -import torch -from torch._decomp import get_decompositions -import torch._dynamo as dynamo -from torch.fx import ( - GraphModule, -) -from torch.utils._pytree import ( - tree_flatten, - tree_unflatten, -) - -from iree.compiler.extras.fx_importer import ( - GraphNodeImporter, - FxImporter, - FxImporterHooks, -) - -from ...support.ir_imports import ( - FlatSymbolRefAttr, - FunctionType, - Operation, - StringAttr, - SymbolTable, - TypeAttr, - Value, - func_d, - util_d, -) - -from ...support.logging import aot_logger as logger - -from ..decompositions import current_aot_decompositions -from ..passes import ( - functorch_functionalize, -) - -from ..support.ir_utils import ( - ModuleBuilder, -) - -from ..support.procedural import ( - CallableIntrinsic, - IrImmediateTensor, - IrTensor, - IrTrace, - MaterializedGlobal, -) - -StringAttrOrStr = Union[StringAttr, str] - - -class _Hooks(FxImporterHooks): - __slots__ = [ - "cloned_global_symbols", - "module_builder", - ] - - def __init__(self, module_builder: ModuleBuilder): - self.module_builder = module_builder - # When we first encounter a global during import, we have to pull it - # into the local module being populated by the GraphNodeImporter. This - # will exactly match the global in the target module we are merging into - # and exists so that the IR is valid during Fx import. We keep the set of - # symbols we have done this to here. - self.cloned_global_symbols: set[str] = set() - - def resolve_literal(self, gni: GraphNodeImporter, literal: Any) -> Optional[Value]: - module_builder = self.module_builder - cloned_global_symbols = self.cloned_global_symbols - - # We support resolution of tracked reference types. Currently this - # only includes Tensors. All others we let the importer do what it - # is going to do. - if not isinstance(literal, torch.Tensor): - return None - - # See if we know about it. - mapping = module_builder.global_ref_tracker.track(literal) - if mapping.is_empty: - # If it is unknown, just let the default importer take it on. - return None - - # Already materialized. - logger.debug("Resolved defined global for literal %r", mapping) - materialized_global: MaterializedGlobal = mapping.value # type: ignore - - # Clone the global into the import module (so that our symbol refs are - # legal). Note that the merger will ignore these since they already - # exist in the target module. - if materialized_global.symbol_name not in cloned_global_symbols: - materialized_global.global_op.operation.clone(ip=gni.fx_importer._m_ip) - cloned_global_symbols.add(materialized_global.symbol_name) - - # Emit a global load and conversion. - vtensor_type = gni._cc.tensor_to_vtensor_type(literal) - loaded_value = util_d.GlobalLoadOp( - materialized_global.ir_type, materialized_global.symbol_name - ).result - converted_value = Operation.create( - "torch_c.from_builtin_tensor", - results=[vtensor_type], - operands=[loaded_value], - ).result - return converted_value - - -ALL_PASSES: Set[str] = set(["functorch_functionalize"]) -DEFAULT_PASSES: Tuple[str, ...] = ("functorch_functionalize",) - - -class jittable(CallableIntrinsic): - """Decorator which takes a PyTorch function and makes it callable from tracing. - - It will be internally JIT-ed and exported into the module as needed. - """ - - __slots__ = [ - "constraints", - "decomposition_table", - "wrapped_f", - "function_name", - "_passes", - ] - - def __init__( - self, - wrapped_f, - *, - decompose_ops: Optional[List[Any]] = None, - decomposition_table: Optional[Dict[Any, Callable[..., Any]]] = None, - constraints: Optional[List[Any]] = None, - function_name: Optional[str] = None, - passes: Sequence[str] = DEFAULT_PASSES, - ): - if decomposition_table is None: - decomposition_table = current_aot_decompositions() - if decompose_ops: - decomposition_table.update(get_decompositions(decompose_ops)) - - self.constraints = constraints - self.decomposition_table = decomposition_table - self.wrapped_f = wrapped_f - self.function_name = function_name if function_name else wrapped_f.__name__ - self._passes = set(passes) - for p in passes: - if p not in ALL_PASSES: - raise ValueError(f"Pass is unknown: {p}") - - def __repr__(self): - return f"" - - def resolve_call( - self, - proc_trace: IrTrace, - *py_args, - constraints: Optional[List[Any]] = None, - **py_kwargs, - ): - type_converter = proc_trace.module_builder.native_type_converter - # Accumulate all constraints into a new list. - if constraints is None: - constraints = [] - else: - constraints = list(constraints) - if self.constraints is not None: - constraints.extend(self.constraints) - - export_kwargs = {} - if len(constraints) > 0: - warnings.warn( - "Compiling program with the old PyTorch constraints system " - "for dynamic shapes is deprecated and will break on PyTorch " - "nightlies after the 2.3 release cut (expect either a PyTorch " - "warning or exception to follow)", - DeprecationWarning, - ) - export_kwargs["constraints"] = constraints - - # Convert procedural trace values to things that Dynamo can handle. - flat_py_args, args_tree = tree_flatten((py_args, py_kwargs)) - flat_pytorch_args = [] - flat_ir_args = [] - for py_arg in flat_py_args: - ir_arg, pytorch_arg = self._split_py_arg(py_arg, constraints=constraints) - flat_ir_args.append(ir_arg) - flat_pytorch_args.append(pytorch_arg) - - # We have to do a bit of a contortion to preserve the ability for torch.export - # to rewrite output signatures in a way that is useful for us, and some passes - # clobber them or don't support structured arguments. So we split the difference - # and operate on linearized inputs (which is what we are working to get to and - # have already captured the schema above) and structured outputs, only using - # output clobbering passes as pre-processors. These kind of jagged - # composability constraints kind of suck, but seem to be where we are... - def flat_wrapped_f(*args): - pytorch_args, pytorch_kwargs = tree_unflatten(args, args_tree) - return self.wrapped_f(*pytorch_args, **pytorch_kwargs) - - # Run pre-processing passes. - transformed_f = flat_wrapped_f - if "functorch_functionalize" in self._passes: - transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args) - - # Ask dynamo to give us an aten graph. - # TODO: Cache this for repeated calls. - logger.debug("Performing dynamo.export(constraints=%r)", constraints) - exported_f = dynamo.export( - transformed_f, - aten_graph=True, - decomposition_table=self.decomposition_table, # type: ignore - assume_static_by_default=True, - **export_kwargs, # type: ignore - ) - logger.debug("Invoking dynamo trace") - gm, guards = exported_f(*flat_pytorch_args) - logger.debug("Dynamo trace complete") - - # TODO: Add debug logging for the exported graph module. - # gm.print_readable() - - # We capture metadata about the results from the raw graph so that we can - # pass it along in the trace (since the IREE type system is a partial erasure - # of the PyTorch type system and we need the fidelity). - # This could be done by the importer but the API gets twisty so just - # doing it here since it isn't clear anyone else would ever want this. - out_spec, result_tensor_infos = _extract_graph_output_metadata(gm) - - # Import the FX graph to MLIR in a new module. - fx_importer = FxImporter( - context=proc_trace.context, - config_check=False, - hooks=_Hooks(proc_trace.module_builder), - py_attr_tracker=proc_trace.module_builder.fx_py_attr_tracker, - ) - fx_importer.import_stateless_graph(gm.graph, func_name=self.function_name) - - # TODO: Real debugging options - # print(fx_importer.module, file=sys.stderr) - - # Splice the converted module into the main module by taking advantage - # of what we know about the conversion module: - # 1. There will be a public function of `self.function_name` that we - # want to preserve a handle to. - # 2. The function may symbolically refer to other functions and - # globals. - # 3. There is no concept of a module initializer. - # 4. When allocating tracked globals, we set them up in both the conversion - # module and the main module. We note them in the conversion module with - # the attribute `util.import_as_symbol` so that we can re-associate here. - merger = _Merger( - proc_trace.module_builder, - fx_importer.module.operation, - fx_importer.symbol_table, - self.function_name, - ) - target_op = merger.merge() - assert target_op, "Could not find target op in merged module" - - # Uncomment to print the final module. - # TODO: Real debugging options. - # print(target_op, file=sys.stderr) - - # TODO: Debug upstream why iteration over children isn't creating a typed view. - # This should just be `target_op.function_type` - target_ftype = FunctionType( - TypeAttr(target_op.attributes["function_type"]).value - ) - target_symbol_ref = FlatSymbolRefAttr.get( - StringAttr(target_op.attributes["sym_name"]).value - ) - - assert len(flat_ir_args) == len(target_ftype.inputs), ( - f"Mismatched number of IR call args vs function decl: " - f"{len(flat_ir_args)} vs {len(target_ftype.inputs)}\n" - f" For call to: {target_ftype}" - ) - - # Since the target function is defined on torch types, we must do - # a cast on each from native->torch. - flat_ir_args = [ - type_converter.materialize_native_to_torch(v, torch_type) - for v, torch_type in zip(flat_ir_args, target_ftype.inputs) - ] - - with proc_trace.ip, proc_trace.loc: - flat_ir_results = func_d.CallOp( - target_ftype.results, target_symbol_ref, flat_ir_args - ).results - - assert len(flat_ir_results) == len(result_tensor_infos) - flat_py_results = [] - for ir_result, result_tensor_info in zip(flat_ir_results, result_tensor_infos): - assert result_tensor_info is not None - (dtype,) = result_tensor_info - native_ir_result = type_converter.materialize_torch_to_native(ir_result) - if dtype is not None: - flat_py_results.append(IrImmediateTensor(native_ir_result, dtype)) - else: - raise TypeError( - f"Unknown PyTorch->IREE value mapping for jittable result: {result_tensor_info}->{native_ir_result}" - ) - - tree_py_results = tree_unflatten(flat_py_results, out_spec) - return tree_py_results - - def _split_py_arg(self, arg, constraints: List[Any]) -> Tuple[Value, Any]: - if isinstance(arg, IrTensor): - meta_tensor, meta_constraints = arg._to_meta_tensor() - constraints.extend(meta_constraints) - return arg.ir_value, meta_tensor - - raise TypeError(f"Unsupported argument to jittable: {arg}") - - -class _Merger: - __slots__ = [ - "context", - "to_module_builder", - "from_module_op", - "from_symbol_table", - "import_function_name", - "rename_map", - "nested_symbol_ops", - "nested_symbol_table_ops", - "private_attr", - ] - - def __init__( - self, - to_module_builder: ModuleBuilder, - from_module_op: Operation, - from_symbol_table: SymbolTable, - import_function_name: str, - ): - self.context = from_module_op.context - self.to_module_builder = to_module_builder - self.from_module_op = from_module_op - self.from_symbol_table = from_symbol_table - self.import_function_name = import_function_name - - self.rename_map: Dict[StringAttr, StringAttr] = {} - self.nested_symbol_ops: List[Operation] = [] - self.nested_symbol_table_ops: List[Operation] = [] - self.private_attr = StringAttr.get("private", self.context) - - def merge(self) -> Optional[Operation]: - # The needle we are looking for. - imported_func_op: Optional[Operation] = None - - # Import functions. - func_ops = _get_top_level_ops(self.from_module_op, func_d.FuncOp.OPERATION_NAME) - for func_op in func_ops: - # Pre-rename, check if it is the one we are looking for. - func_name = _get_symbol_name(func_op) - if func_name == self.import_function_name: - imported_func_op = func_op - # All functions become private. - func_op.attributes["sym_visibility"] = self.private_attr - self.import_symbol_op(func_op) - self.nested_symbol_table_ops.append(func_op) - - # Go back through to nested symbol table ops and RAUW. - for sym_operation in self.nested_symbol_table_ops: - for from_symbol, to_symbol in self.rename_map.items(): - from_name = StringAttr(from_symbol).value - to_name = StringAttr(to_symbol).value - SymbolTable.replace_all_symbol_uses(from_name, to_name, sym_operation) - - return imported_func_op - - def import_symbol_op(self, symbol_op): - target_symbol_table = self.to_module_builder.symbol_table - symbol_op = symbol_op.detach_from_parent() - orig_symbol = SymbolTable.get_symbol_name(symbol_op) - orig_symbol_name = StringAttr(orig_symbol).value - # Make sure it is unique. - new_symbol_name = _uniqueify_name(orig_symbol_name, target_symbol_table) - if new_symbol_name != orig_symbol_name: - SymbolTable.set_symbol_name(symbol_op, new_symbol_name) - self._rename(orig_symbol, new_symbol_name) - - self.to_module_builder.body.append(symbol_op) - self.nested_symbol_ops.append(symbol_op) - target_symbol_table.insert(symbol_op) - - def _rename(self, from_symbol: StringAttrOrStr, to_symbol: StringAttrOrStr): - from_symbol = self._make_string_attr(from_symbol) - to_symbol = self._make_string_attr(to_symbol) - if from_symbol != to_symbol: - self.rename_map[from_symbol] = to_symbol - - def _make_string_attr(self, string_attr_or_str: StringAttrOrStr): - if isinstance(string_attr_or_str, str): - with self.context: - return StringAttr.get(string_attr_or_str) - else: - return StringAttr(string_attr_or_str) - - -def _get_top_level_ops(module_op: Operation, *op_names: str) -> Sequence[Operation]: - results = [] - for op_view in module_op.regions[0].blocks[0]: - op = op_view.operation - if op.name in op_names: - results.append(op) - return results - - -def _get_symbol_name(op: Operation) -> str: - return StringAttr(op.attributes["sym_name"]).value - - -def _uniqueify_name(local_name: str, st: SymbolTable) -> str: - index = -1 - while True: - index += 1 - full_name = local_name - if index > 0: - full_name += f"${index}" - if full_name not in st: - return full_name - - -ResultTensorInfo = Optional[Tuple[torch.dtype]] - - -def _extract_graph_output_metadata( - gm: GraphModule, -) -> Tuple[Any, List[ResultTensorInfo]]: - # In "preserve signatures" mode, there will only be one output and its arguments - # will be the flat list of results that can be unflattened against the _out_spec - # on the graph module. There is a bit of archaelogy going on here but the idea - # is to extract an output tree spec and a tensor dtype (or None) for each flat - # tensor return value. We need this in order to propagate the actual tensor dtype - # on the procedural side. - output_metadata: List[ResultTensorInfo] = [] - try: - out_spec = gm._out_spec - except AttributeError: - raise AssertionError( - "Expected PyTorch to add an _out_spec attribute to the GraphModule" - ) - - output_nodes = [] - for node in gm.graph.nodes: - if node.op == "output": - output_nodes.append(node) - - assert ( - len(output_nodes) == 1 - ), "Expected PyTorch to produce a graph with one output node" - for flat_output_list in output_nodes[0].args: - for flat_output_node in flat_output_list: - tensor_meta = flat_output_node.meta.get("tensor_meta") - fake_val = flat_output_node.meta.get("val") - dtype = None - if tensor_meta is not None: - dtype = tensor_meta.dtype - elif fake_val is not None: - dtype = fake_val.dtype - output_metadata.append((dtype,) if dtype is not None else None) - return out_spec, output_metadata diff --git a/core/shark_turbine/aot/compiled_module.py b/core/shark_turbine/aot/compiled_module.py deleted file mode 100644 index 3f44c8b94..000000000 --- a/core/shark_turbine/aot/compiled_module.py +++ /dev/null @@ -1,687 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union - -import enum -import inspect -import logging -from pathlib import Path -import re -import weakref -import sys - -from torch.export import ExportedProgram - -from . import builtins - -from ..support.ir_imports import ( - Context, - Location, - MLIRError, - Module, - Operation, - PassManager, - StringAttr, -) -from ..support.logging import aot_logger as logger -from ..transforms.general.custom_op_expansion import ExpandCustomOpsPass - -from .support.procedural import ( - GlobalsDef, - ProcedureTrace, - current_ir_trace, -) - -from .support.procedural.exported_program import import_exported_program - -from .support.ir_utils import ( - ModuleBuilder, -) - - -__all__ = [ - "CompiledModule", -] - -################################################################################ -# Data structures -################################################################################ - - -class ImportPhase(enum.IntEnum): - # Imports to torch dialect IR. - TORCH_IR = 0 - - # Performs custom op expansion and post processing for known custom ops. - CUSTOM_OP_EXPANSION = 1 - - # Compiles to valid MLIR that IREE can ingest as an input with the - # input-type of torch. - IMPORT = CUSTOM_OP_EXPANSION - - # Runs the IREE input pipeline to compile to internal form. - IREE_INTERNAL = 2 - - # The full import pipeline (this is an alias for another enum value). - FULL = IREE_INTERNAL - - @staticmethod - def parse(spec: Union[str, None, "ImportPhase"]) -> "ImportPhase": - if spec is None: - return ImportPhase.IMPORT - if isinstance(spec, ImportPhase): - return spec - spec = spec.upper().replace("-", "_") - if spec not in ImportPhase.__members__: - raise ValueError( - f"For import_phase= argument, expected one of: " - f"{', '.join(ImportPhase.__members__.keys())}" - ) - return ImportPhase[spec] - - def __str__(self): - return self.name - - -class PyOnlyDef: - """Exportable that does not export but can be resolved in Python.""" - - __slots__ = ["py_value"] - - def __init__(self, py_value): - self.py_value = py_value - - def __str__(self): - return str(self.py_value) - - def __repr__(self): - return repr(self.py_value) - - def __call__(self, *args, **kwargs): - return self.py_value(*args, **kwargs) - - -class ExportProcDef: - __slots__ = [ - "callable", - "export_name", - "signature", - "file_line_loc", - ] - - def __init__( - self, - export_name: str, - callable: Callable, - *, - signature, - file_line_loc: Optional[Tuple[str, int]] = None, - ): - self.export_name = export_name - self.callable = callable - self.signature = signature - self.file_line_loc = file_line_loc - - def copy(self) -> "ExportProcDef": - return ExportProcDef(self.export_name, self.callable, signature=self.signature) - - def __repr__(self): - return f"" - - -class ExportedProgramDef: - def __init__( - self, - ep: ExportedProgram, - *, - export_name: Optional[str] = None, - public: bool = False, - ): - self.export_name = export_name - self.exported_program = ep - self.public = public - - def copy(self) -> "ExportedProgramDef": - return ExportedProgramDef( - self.exported_program, export_name=self.export_name, public=self.public - ) - - def __repr__(self): - return f"" - - -Exportable = Union[ExportProcDef, ExportedProgramDef, PyOnlyDef, GlobalsDef] - - -class CompiledModuleClassInfo: - __slots__ = [ - "all_exports", - "ir_module_name", - ] - - def __init__(self, *, ir_module_name: str): - self.ir_module_name = ir_module_name - self.all_exports: Dict[str, Exportable] = dict() - - def add_export(self, key: str, value: Exportable): - if key in self.all_exports: - raise TypeError(f"Cannot export attribute more than once: {key}") - self.all_exports[key] = value - - @property - def export_procs(self) -> Generator[Tuple[str, ExportProcDef], None, None]: - return filter( - lambda kv_tuple: isinstance(kv_tuple[1], ExportProcDef), - self.all_exports.items(), - ) # type: ignore - - @property - def exported_programs( - self, - ) -> Generator[Tuple[str, ExportedProgramDef], None, None]: - return filter( - lambda kv_tuple: isinstance(kv_tuple[1], ExportedProgramDef), - self.all_exports.items(), - ) # type: ignore - - @property - def py_only_defs(self) -> Generator[Tuple[str, PyOnlyDef], None, None]: - return filter( - lambda kv_tuple: isinstance(kv_tuple[1], PyOnlyDef), - self.all_exports.items(), - ) # type: ignore - - @property - def globals_defs(self) -> Generator[Tuple[str, GlobalsDef], None, None]: - return filter( - lambda kv_tuple: isinstance(kv_tuple[1], GlobalsDef), - self.all_exports.items(), - ) # type: ignore - - def def_attribute(self, key, value): - # Some decorators, the only thing we do is convert them to PyOnlyDef. - # Do that first so the generic descriptor code below handles them. - if isinstance(value, builtins.jittable): - value = PyOnlyDef(value) - - # Promote a torch ExportedProgram to an ExportedProgramDef. - if isinstance(value, ExportedProgram): - value = ExportedProgramDef( - value, export_name=key, public=not key.startswith("_") - ) - - # Detect our own descriptors. - if isinstance(value, GlobalsDef): - logging.debug("DEFINE GLOBALS: %s = %r", key, value) - self.add_export(key, value) - return value - if isinstance(value, ExportProcDef): - value = value.copy() - if value.export_name is None: - value.export_name = key - self.add_export(key, value) - return value - if isinstance(value, PyOnlyDef): - logging.debug("DEFINE PY_ONLY: %s = %r", key, value) - self.add_export(key, value) - return value - if isinstance(value, ExportedProgramDef): - if value.export_name is None: - value = value.copy() - value.export_name = key - logging.debug("DEFINE EXPORTED_PROGRAM: %r", value.export_name) - self.add_export(key, value) - return value - - # Infer if it is an exported function. - if callable(value) and inspect.isfunction(value): - return self.def_export_proc(key, value) - - raise TypeError( - f"cannot set arbitrary Python value '{key}' on " - f"compiled module: {value!r}" - ) - - def def_export_proc(self, name, f) -> ExportProcDef: - logging.debug("DEFINE EXPORT: %s = %r", name, f) - # Get a reasonable location. - file_line_loc = None - try: - sourcefile = inspect.getsourcefile(f) - _, linenum = sourcelines = inspect.getsourcelines(f) - except OSError: - ... - else: - file_line_loc = (sourcefile or "", linenum) - - sig = inspect.signature(f) - if len(sig.parameters) < 1: - raise TypeError( - f"export proc '{name}' is expected to have a 'self' parameter" - ) - - # By default, we discover signature details from default values - # on the function. But we should also source from an annotation. - input_sig = [] - parameter_list = list(sig.parameters.values()) - # TODO: Reconstitute a pytree so as to handle kwargs? - # See: https://github.com/nod-ai/SHARK-Turbine/issues/128 - for param in parameter_list[1:]: - if ( - param.kind != inspect.Parameter.POSITIONAL_ONLY - and param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - raise TypeError( - f"exported functions only support positional parameters" - ) - param_desc = param.default - if param_desc is inspect.Parameter.empty: - # TODO: Merge from a decorator? - # See: https://github.com/nod-ai/SHARK-Turbine/issues/126 - raise TypeError( - f"export function {name} missing required default value annotation " - f"for '{param.name}'" - ) - input_sig.append(param_desc) - - info = ExportProcDef(name, f, signature=input_sig, file_line_loc=file_line_loc) - self.add_export(name, info) - return info - - -class CompiledModuleInstanceInfo: - """Info class for compiled module instances.""" - - __slots__ = [ - "class_info", - "module_builder", - "shadow_dict", - "current_import_phase", - ] - - def __init__( - self, class_info: CompiledModuleClassInfo, module_builder: ModuleBuilder - ): - self.class_info = class_info - self.module_builder = module_builder - # The shadow dict holds instance attributes. We stash them here and the - # Program instance itself arbitrates access via getattr/setattr. - self.shadow_dict: dict[str, Any] = dict() - self.current_import_phase = ImportPhase.TORCH_IR - - -################################################################################ -# Live reference accounting -################################################################################ - -_all_compiled_module_class_infos: weakref.WeakKeyDictionary[ - "CompiledModuleMeta", CompiledModuleClassInfo -] = weakref.WeakKeyDictionary() -_all_compiled_module_instance_infos: weakref.WeakKeyDictionary[ - "CompiledModule", CompiledModuleInstanceInfo -] = weakref.WeakKeyDictionary() - - -################################################################################ -# CompiledModule and metaclass -################################################################################ - -# Gate that is set to True once metaclass setup is complete. -_metaclass_setup_complete = False - - -@property # type: ignore -def _blackhole_instance_attribute(self): - # We're not here. - raise AttributeError - - -def _uncallable_public_export(*args, **kwargs): - raise RuntimeError(f"Calls to exported functions not yet supported") - - -_COMPILED_MODULE_API_ATTRIBUTES = [ - "create_from_dict", - "expand_custom_ops", - "export_global", - "get_class_info", - "get_info", - "get_module_builder", - "get_mlir_module", - "jittable", - "run_import", - "run_pass_pipeline", - "save_mlir", -] - - -class CompiledModuleMeta(type): - """Metaclass for all CompiledModule subclasses. - - Do not use directly. - """ - - # __new__ on a metaclass is called when a new subclass is constructed. - # It is passed the dictionary of declared attributes and any keyword - # arguments from the class declaration: - # class Foo(Bar, kwarg="you probably just learned this is possible"): - def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None): - if not _metaclass_setup_complete: - return type.__new__(mcls, name, bases, dct) - - ir_module_name = _derive_ir_module_name(name, export_name) - logger.debug("Create new CompiledModule: %s", ir_module_name) - info = CompiledModuleClassInfo(ir_module_name=ir_module_name) - - # Process that attributes that were set as part of class definition. - # Any attributes that we decide are part of the compiled module - # are removed and appropriately transferred to the backing info - # hierarchy. - del_attr_keys = set() - for key, value in dct.items(): - if key.startswith("__") and key.endswith("__"): - continue - del_attr_keys.add(key) - info.def_attribute(key, value) - - for key in del_attr_keys: - del dct[key] - - # The CompiledModule exports a number of its own API methods, which - # we explicitly hide on subclasses and instances. - for key in _COMPILED_MODULE_API_ATTRIBUTES: - if key not in dct: - dct[key] = _blackhole_instance_attribute - - # Inheriting methods, globals, and export from parent class. - # Use case such as building a child-class to StatelessLlama. - for base in bases: - if base is CompiledModule: - continue - base_exports = _all_compiled_module_class_infos[base].all_exports - for export_name in base_exports: - if export_name in info.all_exports: - continue - info.all_exports[export_name] = base_exports[export_name] - - # Finish construction. - new_class = type.__new__(mcls, name, bases, dct) - _all_compiled_module_class_infos[new_class] = info - return new_class - - # Gets unresolved attributes on classes of this meta-class. - def __getattr__(cls, key): - # CompiledModule does not expose anything else. - if cls is CompiledModule: - raise AttributeError(f"CompiledModule.{key}") - info = CompiledModule.get_class_info(cls) - try: - return info.all_exports[key] - except KeyError: - raise AttributeError - - -class CompiledModule(metaclass=CompiledModuleMeta): - """Base class for all staged modules.""" - - @classmethod - def create_from_dict( - cls: CompiledModuleMeta, - name: str, - dct: dict, - *, - export_name: Optional[str] = None, - ) -> CompiledModuleMeta: - """Creates a CompiledModule subclass with an explicit dictionary of members. - - This is the unsugared form of: - - ``` - class Foo(CompiledModule, export_name="bar"): - def member(): ... - ``` - """ - return CompiledModuleMeta(name, (cls,), dct, export_name=export_name) - - @staticmethod - def get_class_info(cls: CompiledModuleMeta) -> CompiledModuleClassInfo: - return _all_compiled_module_class_infos[cls] - - @staticmethod - def get_info(inst: "CompiledModule") -> CompiledModuleInstanceInfo: - return _all_compiled_module_instance_infos[inst] - - @staticmethod - def get_module_builder(inst: "CompiledModule") -> Operation: - if not isinstance(inst, CompiledModule): - raise ValueError( - f"Expected a CompiledModule instance but got: {inst.__class__}" - ) - info = CompiledModule.get_info(inst) - return info.module_builder - - @staticmethod - def get_mlir_module(inst: "CompiledModule") -> Operation: - return CompiledModule.get_module_builder(inst).module_op - - @staticmethod - def run_import( - inst: "CompiledModule", import_to: Union[ImportPhase, str, None] = "import" - ): - import_to = ImportPhase.parse(import_to) - info = CompiledModule.get_info(inst) - for phase in [ - ImportPhase.TORCH_IR, - ImportPhase.CUSTOM_OP_EXPANSION, - ImportPhase.IREE_INTERNAL, - ]: - if phase > import_to: - logger.debug("Stopped import at phase %s", info.current_import_phase) - break - if info.current_import_phase >= phase: - continue - logger.debug("Run import phase %s", phase) - if phase == ImportPhase.TORCH_IR: - # Starting phase. Do nothing. - ... - elif phase == ImportPhase.CUSTOM_OP_EXPANSION: - CompiledModule.expand_custom_ops(inst) - elif phase == ImportPhase.IREE_INTERNAL: - CompiledModule.run_pass_pipeline(inst, "builtin.module(torch-to-iree)") - else: - assert False, f"Phase {phase} not handled in switch" - info.current_import_phase = phase - - @staticmethod - def expand_custom_ops(inst: "CompiledModule"): - """Performs custom torch.operator expansion for known custom ops.""" - logger.debug("Expand known torch.operator custom ops") - module_op = CompiledModule.get_mlir_module(inst) - p = ExpandCustomOpsPass(module_op) - p.run() - - @staticmethod - def run_pass_pipeline( - inst: "CompiledModule", pipeline: str, enable_ir_printing: bool = False - ): - """Runs an arbitrary pass pipeline against the current IR. - - Args: - pipeline: The text format pass pipeline as supported by PassManager.parse. - enable_ir_printing: Enables print-after-all to stderr. - """ - logger.debug("Run pass pipeline: %s", pipeline) - module_op = CompiledModule.get_mlir_module(inst) - with module_op.context: - pm = PassManager.parse(pipeline) - if enable_ir_printing: - module_op.context.enable_multithreading(False) - pm.enable_ir_printing() - try: - pm.run(module_op) - except MLIRError: - # TODO: Better error handling. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/127 - print(module_op, file=sys.stderr) - raise - - @staticmethod - def save_mlir(inst: "CompiledModule", path: Union[Path, str]): - """Saves a snapshot of the MLIR module in this CompiledModule to a file. - - This is a convenience wrapper around the facilities of the underlying - API and does not expose all features. - - Args: - path: The file path to write to. If the extension is ".mlirbc", it - will be written as bytecode. - """ - path = Path(path) - bytecode = path.suffix == ".mlirbc" - module_op = CompiledModule.get_mlir_module(inst) - with open(path, "wb") as f: - if bytecode: - module_op.write_bytecode(f) - else: - module_op.print(f, binary=True) - - jittable = staticmethod(builtins.jittable) - - def __getattr__(self, name): - info = CompiledModule.get_info(self) - try: - return info.shadow_dict[name] - except KeyError: - raise AttributeError(f"Attribute {name} not defined") - - def __setattr__(self, name, value): - info = CompiledModule.get_info(self) - try: - descriptor = info.shadow_dict[name] - except KeyError: - raise AttributeError(f"Attribute {name} cannot be set") - current_ir_trace().handle_assignment(self, descriptor, value) - - def __new__( - cls, - *, - context: Optional[Context] = None, - module_op: Optional[Operation] = None, - import_to: Union[ImportPhase, None, str] = "import", - ): - import_to = ImportPhase.parse(import_to) - self = super().__new__(cls) - class_info = CompiledModule.get_class_info(cls) - if context and module_op: - raise ValueError("Only one of context= or module_op= can be specified") - if not context and not module_op: - try: - context = Context.current - except ValueError: - pass - - if not context: - context = Context() - - if not module_op: - with context: - loc = Location.unknown(context=context) - module = Module.create(loc) - module_op = module.operation - module_op.attributes["sym_name"] = StringAttr.get( - class_info.ir_module_name, context=context - ) - module_builder = ModuleBuilder(module_op) - info = CompiledModuleInstanceInfo(class_info, module_builder=module_builder) - _all_compiled_module_instance_infos[self] = info - - # Instantiate globals - for key, globals_def in info.class_info.globals_defs: - info.shadow_dict[key] = globals_def.track(module_builder, key) - - # Make PyOnly defs visible. - for key, py_def in info.class_info.py_only_defs: - info.shadow_dict[key] = py_def.py_value - - # Instantiate exported programs. - # TODO: This should be done in two phases along with export_procs - # in order to enable dependence. - for key, ep_def in info.class_info.exported_programs: - info.shadow_dict[key] = import_exported_program( - module_builder, - ep_def.exported_program, - symbol_name=ep_def.export_name or "main", - symbol_visibility=None if ep_def.public else "private", - ) - - # Instantiate procs. - # TODO: This should be done in two phases, first binding the symbols - # and then defining them, enabling dependence. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/129 - for key, proc_def in info.class_info.export_procs: - - def do_export(proc_def: ExportProcDef): - def invoke_with_self(*args, **kwargs): - return proc_def.callable(self, *args, **kwargs) - - logger.debug("Generating procedural function: %s", key) - if proc_def.file_line_loc: - loc = Location.file( - proc_def.file_line_loc[0], - proc_def.file_line_loc[1], - col=0, - context=module_builder.context, - ) - else: - loc = Location.unknown(context=module_builder.context) - trace = ProcedureTrace.define_func( - module_builder, - symbol_name=proc_def.export_name, - posargs=proc_def.signature, - kwargs={}, # TODO(#128): kwargs - loc=loc, - ) - trace.trace_py_func(invoke_with_self) - info.shadow_dict[key] = _uncallable_public_export - - do_export(proc_def) - - module_builder.finalize_construct() - CompiledModule.run_import(self, import_to) - return self - - -_metaclass_setup_complete = True - -################################################################################ -# Utilities -################################################################################ - - -def _derive_ir_module_name(class_name: str, explicit_name: Optional[str]): - """Returns an appropriate module export name given a class name and override. - - If an explicit_name is given, that is used as is. Otherwise, the class name - is mangled by: - * Removing and "Module" suffix. - * Converting camel case to snake case. - """ - if explicit_name: - return explicit_name - return _to_snake_case(_strip_suffix(class_name, "Module")) - - -def _to_snake_case(s: str) -> str: - return re.sub(r"(? str: - if s.endswith(optional_suffix): - return s[0 : len(s) - len(optional_suffix)] - else: - return s diff --git a/core/shark_turbine/aot/decompositions.py b/core/shark_turbine/aot/decompositions.py deleted file mode 100644 index 29f723453..000000000 --- a/core/shark_turbine/aot/decompositions.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import contextlib -from typing import Optional - -import torch - -from ..dynamo.decompositions import ( - _current, - _extend_context_manager, - DecompositionOpsList, - DecompositionTable, -) - -__all__ = [ - "current_aot_decompositions", - "extend_aot_decompositions", -] - - -def current_aot_decompositions() -> DecompositionTable: - """Gets the current decomposition table for AOT.""" - return _current("aot") - - -def extend_aot_decompositions( - *, - from_current: bool = True, - add_ops: Optional[DecompositionOpsList] = None, - remove_ops: Optional[DecompositionOpsList] = None -): - """Context manager which extends the list of decompositions used for AOT.""" - return _extend_context_manager( - "aot", from_current=from_current, add_ops=add_ops, remove_ops=remove_ops - ) - - -############################################################################### -# Workarounds -############################################################################### - - -def _patch_op_dispatch(op): - if torch.__version__ >= "2.3.0" and torch.__version__ < "2.4": - # Around the torch 2.3.0 release cut, there was a regression such that - # running decompositions in a functionalized context did not work - # with Python registered ops. The issue is that they have an incomplete - # list of mode handler registrations and cannot handle the - # FunctionalTensorMode. Since we only have a handful of these, and - # since we can assume that for the sake of expediency, functional - # dispatch is basically the same as fake tensor dispatch, we just - # take the fake tensor registration and dup it onto the functional - # registration. - # Note that the torch._higher_order_ops.auto_functionalize is registered - # in Python and is itself broken, it needs to be monkey patched. - # See: https://github.com/pytorch/pytorch/issues/122752 - from torch._subclasses.fake_tensor import FakeTensorMode - from torch._subclasses.functional_tensor import FunctionalTensorMode - - t = op.python_key_mode_table - if FunctionalTensorMode not in t: - handler = t[FakeTensorMode] - t[FunctionalTensorMode] = handler - - -_patched_op_dispatch_for_export = False - - -def _patch_op_dispatch_for_export(): - global _patched_op_dispatch_for_export - if _patched_op_dispatch_for_export: - return - _patched_op_dispatch_for_export = True - import torch._higher_order_ops.auto_functionalize - - _patch_op_dispatch(torch._higher_order_ops.auto_functionalize.auto_functionalized) diff --git a/core/shark_turbine/aot/exporter.py b/core/shark_turbine/aot/exporter.py deleted file mode 100644 index 0aa317a99..000000000 --- a/core/shark_turbine/aot/exporter.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import overload, Any, Dict, List, Optional, Sequence, Tuple, Type, Union -import io -from pathlib import Path -import platform -import warnings - -import torch - -from iree.compiler.api import ( - Session, - Source, - Output, -) - -from ..support.ir_imports import ( - Context, - Operation, -) - -from .builtins import * -from .compiled_module import ( - CompiledModule, - ImportPhase, -) -from .fx_programs import FxPrograms -from . import decompositions - -__all__ = [ - "export", - "ExportOutput", -] - -_is_windows = platform.system() == "Windows" - - -ModuleLike = Union[ - torch.nn.Module, - Type[CompiledModule], - torch.export.ExportedProgram, - FxPrograms, -] -SaveableTarget = Union[str, Path, None, Output] - - -class ExportOutput: - """Wrapper around a CompiledModule produced by `export`.""" - - def __init__( - self, - session: Session, - compiled_module: CompiledModule, - *, - importer_uses_session: bool = False, - ): - self.session = session - self.session.set_flags("--iree-input-type=torch") - self.compiled_module = compiled_module - self._importer_uses_session = importer_uses_session - - @property - def mlir_module(self) -> Operation: - """Gets the MLIR module resulting from the last compilation phase.""" - return CompiledModule.get_mlir_module(self.compiled_module) - - def verify(self): - """Runs the verifier on the module, raising an exception on failure.""" - self.mlir_module.verify() - - def print_readable(self, large_elements_limit: int = 50): - """Prints a human readable version of the current compilation IR.""" - self.mlir_module.print(large_elements_limit=large_elements_limit) - - def save_mlir(self, file_path: Union[str, Path]): - """Saves the current compilation IR to a path on disk. - - Args: - file_path: Path to save the file. If it has a ".mlirbc" - extension, it will be saved as bytecode. Otherwise as - text. - """ - file_path = Path(file_path) - with open(file_path, "wb") as f: - if file_path.suffix == ".mlirbc": - self.mlir_module.write_bytecode(f) - else: - self.mlir_module.print(file=f, binary=True) - - def import_to(self, import_to: Union[ImportPhase, str]): - """Compiles the modules to a mnemonic import phase. - - This is a no-op if already compiled to this phase. - """ - CompiledModule.run_import(self.compiled_module, import_to) - - def compile( - self, - save_to: SaveableTarget, - *, - target_backends: Union[str, Sequence[str]] = ("llvm-cpu",), - ) -> Optional[memoryview]: - """Compiles the exported program to an executable binary. - - Args: - save_to: Where to save the compiled binary. Can be one of: - None: outputs to a memory buffer and return the API Output. - (str, Path): Outputs to a file - Output: Raw compiler API Output object to save to. - target_backends: A comma-delimitted string of IREE target backends or - a sequence of strings. - Returns: - None unless if `save_to=None`, in which case, we return the backing compiler API - Ouptut object. It can be queried for its backing memory via its `map_memory()` - method. - """ - return_memory_view = False - if save_to is None: - output = Output.open_membuffer() - return_memory_view = True - elif isinstance(save_to, (str, Path)): - save_to = Path(save_to) - output = Output.open_file(str(save_to)) - else: - output = save_to - assert isinstance(output, Output) - - target_backends = ( - target_backends - if isinstance(target_backends, str) - else ",".join(target_backends) - ) - inv = self.session.invocation() - if self._importer_uses_session: - inv.import_module(self.mlir_module) - else: - # Some platforms can't share the context across the importer and - # session (cough: Windows). Round-trip in this case. - buffer_io = io.BytesIO() - self.mlir_module.write_bytecode(buffer_io) - buffer = buffer_io.getvalue() - source = Source.wrap_buffer(self.session, buffer) - inv.parse_source(source) - inv.enable_console_diagnostics() - - # TODO: Don't use flags to set the target backends: set module attributes. - self.session.set_flags(f"--iree-hal-target-backends={target_backends}") - if not inv.execute(): - raise RuntimeError("Compilation failed: See diagnostics") - - inv.output_vm_bytecode(output) - output.keep() - if return_memory_view: - return output - else: - return None - - -@overload -def export( - module: torch.nn.Module, - /, - *, - args: Optional[tuple] = None, - kwargs: Optional[Dict[str, Any]] = None, - dynamic_shapes: Dict[str, Any] | Tuple[Any] | List[Any] | None = None, - module_name: Optional[str] = None, - function_name: Optional[str] = None, -) -> ExportOutput: - """Exports a torch.nn.Module. - - This is done by first invoking torch.export.export with args, kwargs, - and dynamic_shapes. - """ - ... - - -@overload -def export(compiled_module: Type[CompiledModule], /) -> ExportOutput: - """Exports a CompiledModule and returns the output.""" - ... - - -@overload -def export( - exported_program: torch.export.ExportedProgram, - /, - *, - module_name: Optional[str] = None, - function_name: Optional[str] = None, -) -> ExportOutput: - """Exports a single entry-point module consisting of an ExportedProgram.""" - ... - - -@overload -def export( - exported_programs: FxPrograms, - /, - *, - module_name: Optional[str] = None, -) -> ExportOutput: - """Exports a multi entry-point ExportedProgram.""" - ... - - -def export( - mdl: ModuleLike, - /, - *example_args: torch.Tensor, - args: Optional[tuple] = None, - kwargs: Optional[Dict[str, Any]] = None, - dynamic_shapes: Dict[str, Any] | Tuple[Any] | List[Any] | None = None, - module_name: Optional[str] = None, - function_name: Optional[str] = None, -) -> ExportOutput: - """Generic export of supported entities. - - See a more specific overload for accepted forms. - - This function behaves differently based on the type of the `mdl` argument: - - * nn.Module: The module is traced with torch.export.export passing it - `args`, `kwargs`, and `dynamic_shapes`. - * CompiledModule: The module is imported to IR. Additional arguments are - illegal in this case. - * torch.export.ExportedProgram: A pre-exported program can be passed and - it will be used to construct a single-entrypoint module. - - Args: - mdl: The nn.Module to export. - *example_args: Example tensors. - args: Example arguments to torch.export (if present, then *example_args - must be empty. - kwargs: Example keyword arguments. - dynamic_shapes: Dynamic shape specs to pass to torch.export. - - Returns: - An ExportOutput object that wraps the compilation and provides - easy access. - """ - if len(example_args) > 0: - warnings.warn( - DeprecationWarning( - "extra `example_args` positional parameters are deprecated: pass `args=tuple(...)` instead." - ) - ) - - TransformedModule: Any - current_decomps = decompositions.current_aot_decompositions() - if isinstance(mdl, torch.export.ExportedProgram): - TransformedModule = CompiledModule.create_from_dict( - "LambdaCompiledModule", - {(function_name or "main"): mdl}, - export_name=module_name or "module", - ) - - elif isinstance(mdl, FxPrograms): - TransformedModule = CompiledModule.create_from_dict( - "LambdaCompiledModule", mdl.programs, export_name=module_name or "module" - ) - elif isinstance(mdl, torch.nn.Module): - # Normalize arguments for torch.export. - if args is None: - args = example_args - elif len(example_args) > 0: - raise ValueError( - "Cannot pass args= and positional example_args at the same time" - ) - nn_module = mdl - exported_program = torch.export.export( - nn_module, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes - ) - if current_decomps: - from .decompositions import _patch_op_dispatch_for_export - - _patch_op_dispatch_for_export() - exported_program = exported_program.run_decompositions(current_decomps) - - TransformedModule = CompiledModule.create_from_dict( - "LambdaCompiledModule", - {(function_name or "main"): exported_program}, - export_name=module_name or "module", - ) - elif issubclass(mdl, CompiledModule): - TransformedModule = mdl - else: - raise TypeError(f"mdl argument (type: {type(mdl)}) is not a supported type") - - session = Session() - # There are some bugs with respect to Session/context interop that we - # haven't squashed yet. For now, default everyone to round-tripping - # via bytecode vs sharing the context between the importer/compiler. - importer_uses_session = False and not _is_windows - if importer_uses_session: - context = session.context - else: - context = Context() - - cm = TransformedModule(context=context, import_to="import") - return ExportOutput(session, cm, importer_uses_session=importer_uses_session) diff --git a/core/shark_turbine/aot/fx_programs.py b/core/shark_turbine/aot/fx_programs.py deleted file mode 100644 index b0f869204..000000000 --- a/core/shark_turbine/aot/fx_programs.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Helper classes for assembling sets of FX modules that can be compiled. - -This uses the `torch.export` machinery. However, it provides some extra -services for handling multiple modules, save/load, and state management. -""" - -import json -import os -from pathlib import Path -from typing import Any, Optional, Union - -import functools - -import torch -import torch.nn as nn - -from .decompositions import current_aot_decompositions - -# The dynamic_shapes support showed up in the Torch 2.3 timeframe. -_supports_dynamic_shapes = hasattr(torch.export, "Dim") - - -class FxPrograms: - """Represents a named set of ExportedPrograms. - - This facility works around a design flaw in Torch where they conflated - ExportedPrograms as representing a single entry-point while also having - each instance persist its own state_dict and constants. How many times, - in how many frameworks, do we have to fight this design flaw? Apparently - once more. - - This base class represents the set of programs, either loaded from storage - or built live. The tricky part it is managing is to do all of this while - aliasing state and captured constants. Having those be physically shared - is an essential optimization. - - In order to manage saving/loading of the set of things, we manually splice - the state_dict and constants dict such that while saving, we only persist - the first encountered instance of any reference. Any subsequent instances - are replaced with a SharedStateTensor, which on load can be re-associated. - - As this is primarily targeted at being able to decouple FX tracing from - further manipulation (which for reasons unknown, is competing with the - race of entropy to the heat death of the universe in terms of performance), - we don't take a lot of pains to be optimized for distribution or storage of - the resulting artifacts. - - In the future, this same technique could be employed to elide parameters - that we know we are going to resolve symbolically later, keeping them from - being loaded and consuming memory during model export and compilation. - - We have faith that in the fullness of time, the design flaws in Torch that - require this kind of thing to exist will be resolved, and we then won't - need this hack. - """ - - def __init__(self): - self.programs: dict[str, torch.export.ExportedProgram] = {} - - def save(self, path: Union[str, os.PathLike]) -> int: - """Saves the set of exported programs to a descriptor file. - - Returns the number of tensors deduped (for debugging/testing). - """ - path = Path(path).resolve() - - def permute_path(name): - return path.parent / f"{path.stem}_{name}.pt2" - - # Assemble descriptor. - program_files = {name: str(permute_path(name)) for name in self.programs.keys()} - descriptor = { - "load_order": list(program_files.keys()), - "program_files": program_files, - } - - # Accumulate shared state as we go. - shared_state_dict: dict[str, Any] = {} - shared_constants: dict[str, Any] = {} - count_deduped = 0 - - # Save each. - for program_name, ep in self.programs.items(): - # First validate the ep with normal rules, which we will then - # disable since we are violating the spec. - ep._validate() - orig_state_dict = dict(ep.state_dict) - constants_dict = _get_optional_constants(ep) - orig_constants = dict(constants_dict) - - try: - # Now unmerge the state_dict and constants by knocking it up against - # our running shared state dict. - count_deduped += _sharify_state_dict(shared_state_dict, ep.state_dict) - count_deduped += _sharify_state_dict(shared_constants, constants_dict) - - # And save our hacked program. - save_path = program_files[program_name] - torch.export.save(ep, save_path) - finally: - ep.state_dict.clear() - ep.state_dict.update(orig_state_dict) - constants_dict.clear() - constants_dict.update(orig_constants) - - # Save the descriptor. - with open(path, "wt") as f: - json.dump(descriptor, f) - return count_deduped - - @staticmethod - def load(path: Union[str, os.PathLike]) -> "FxPrograms": - instance = FxPrograms() - path = Path(path).resolve() - with open(path, "rb") as f: - descriptor = json.load(f) - - shared_state_dict: dict[str, Any] = {} - shared_constants: dict[str, Any] = {} - - for program_name in descriptor["load_order"]: - program_file_name = descriptor["program_files"][program_name] - ep = torch.export.load(path.parent / program_file_name) - _unsharify_state_dict(shared_state_dict, ep.state_dict) - _unsharify_state_dict(shared_constants, _get_optional_constants(ep)) - instance.programs[program_name] = ep - return instance - - -class FxProgramsBuilder(FxPrograms): - """Builds a new set of exported programs that are all variations of the - same root nn.Module. - - This can be used to construct multi-entrypoint sets of ExportedPrograms - in a way that alias information is preserved for lifted tensors. - - Usage: - - ``` - class MyModule(nn.Module): - ... - - fxb = FxProgramBuilder(MyModule()) - - @fxb.export_program(args=example_args) - def entrypoint(m, x, y): - return m.forward(x, y) - - fxb.save("/some/path.json") - ``` - """ - - def __init__(self, root_module: nn.Module): - super().__init__() - self.root_module = root_module - - def export_program( - fx_builder, - f=None, - *, - args=None, - kwargs=None, - dynamic_shapes=None, - name: Optional[str] = None, - ): - if f is None: - return functools.partial( - fx_builder.export_program, - args=args, - kwargs=kwargs, - dynamic_shapes=dynamic_shapes, - name=name, - ) - - if name is None: - name = f.__name__ - if name in fx_builder.programs: - raise ValueError(f"Attempt to export program '{name}' multiple times") - - class LambdaModule(nn.Module): - def __init__(self): - super().__init__() - self.add_module("root", fx_builder.root_module) - - # Here we do a tricky thing: The free-function that we take has - # signature: - # def free_function(root_module, arg1, *, kwarg1) - # Since the export machinery expects to be able to inspect and query - # based on user-specified argument names ("arg1", "kwarg1" above), - # we use the usual @functools.wraps to copy metadata. Because we wrap - # it before adding it to the class, the first-arg of the free function - # ("root_module" above) lines up with the usual "self" arg of a method - # attached to a class. When instantiated and created, this synthetic - # 'forward' method will inspect as only taking the user-specified - # argument names (i.e. "arg1", "kwarg1") because the class machinery - # swallowed the first, which is exactly the one we wanted to elide - # from Dynamo's view anyway. - # If we weren't doing this, we would need to munge the signature - # descriptors to line up because the export machinery needs to see - # the user-specified function arguments, not our "pseudo-self" root - # module argument that we always pass. - # Note that to keep Dynamo happy, we are careful to only access - # names and attributes in the module tree (vs from the surrounding - # closure, which goes down less well-trodden paths). - @functools.wraps(f) - def new_forward(self, *forward_args, **forward_kwargs): - return f(self.root, *forward_args, **forward_kwargs) - - setattr(LambdaModule, "forward", new_forward) - lambda_module = LambdaModule() - - # Export our franken-module. - extra_kwargs = {} - if dynamic_shapes: - if not _supports_dynamic_shapes: - raise ValueError( - f"torch.export with dynamic_shapes= not supported for this version of torch" - ) - extra_kwargs["dynamic_shapes"] = dynamic_shapes - program = torch.export.export( - lambda_module, args=args, kwargs=kwargs, **extra_kwargs - ) - current_decomps = current_aot_decompositions() - if current_decomps: - from .decompositions import _patch_op_dispatch_for_export - - _patch_op_dispatch_for_export() - program = program.run_decompositions(current_decomps) - fx_builder.programs[name] = program - return program - - -class SharedStateTensor(torch.Tensor): - """A fake tensor that we shove into ExportedProgram state to share.""" - - @staticmethod - def __new__( - cls, - size, - dtype, - shared_state_dict_key: str, - is_param: bool, - requires_grad=False, - ): - # Using a meta tensor as the wrapped gives us shape and dtype - # propagation. - return torch.Tensor._make_subclass( - cls, - torch.empty(size, dtype=dtype, device="meta"), - require_grad=requires_grad, - ) - - def __init__( - self, - size, - dtype, - shared_state_dict_key: str, - is_param: bool, - requires_grad=False, - ): - self.shared_state_dict_key = shared_state_dict_key - # Magic attribute that makes isinstance(t, Parameter) True. - # See torch.nn.Parameter. - self._is_param = is_param - - -def _create_shared_state_tensor( - like: torch.Tensor, shared_state_dict_key: str -) -> SharedStateTensor: - t = SharedStateTensor( - like.size(), - like.dtype, - shared_state_dict_key=shared_state_dict_key, - is_param=isinstance(like, torch.nn.Parameter), - requires_grad=like.requires_grad, - ) - return t - - -def _sharify_state_dict(shared_dict: dict, local_dict: dict) -> int: - count_deduped = 0 - for key, local_value in local_dict.items(): - if not isinstance(local_value, torch.Tensor): - continue - if key in shared_dict: - shared_value = shared_dict[key] - assert ( - shared_value is local_value - ), f"State dict key collision results in different instances ({key})!" - local_dict[key] = _create_shared_state_tensor(local_value, key) - count_deduped += 1 - else: - # Remember the original for the next time. - shared_dict[key] = local_value - return count_deduped - - -def _unsharify_state_dict(shared_dict: dict, local_dict: dict): - for key, local_value in local_dict.items(): - if not isinstance(local_value, torch.Tensor): - continue - if isinstance(local_value, SharedStateTensor): - # Replace shared state tensor. - shared_key = local_value.shared_state_dict_key - try: - shared_value = shared_dict[shared_key] - except KeyError as e: - raise KeyError( - f"Shared tensor not found during deserialization. Corrupt metadata? " - f"{shared_key}" - ) - local_dict[key] = shared_value - else: - # Remember this one for later. - shared_dict[key] = local_value - - -def _get_optional_constants(ep: torch.export.ExportedProgram) -> dict[str, Any]: - """Constants showed up in early 2.3 timeframe. - - Returns an empty dict if not supported. - """ - try: - return ep.constants # type: ignore - except AttributeError: - assert torch.__version__ < "2.3.dev1", "Constants should be available" - return dict() diff --git a/core/shark_turbine/aot/params.py b/core/shark_turbine/aot/params.py deleted file mode 100644 index 548586484..000000000 --- a/core/shark_turbine/aot/params.py +++ /dev/null @@ -1,306 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Iterator, List, Optional, Set, Tuple, Union - -import json -from pathlib import Path -import warnings - -import numpy as np -import torch -import torch.nn as nn - -from iree.runtime import ( - ParameterIndex, - ParameterIndexEntry, -) - -from .tensor_traits import ( - ExternalTensorTrait, -) - - -__all__ = [ - "externalize_module_parameters", - "save_module_parameters", - "ParameterArchive", - "ParameterArchiveEntry", - "ParameterArchiveBuilder", -] - -################################################################################ -# Parameter externalization -################################################################################ - - -def externalize_module_parameters( - module: nn.Module, *, external_scope: str = "", prefix: str = "" -): - """Externalizes parameters and persistent buffers in a module by name.""" - - for tensor_name, tensor in _yield_saveable_tensors(module, prefix=prefix): - trait = ExternalTensorTrait( - external_scope=external_scope, external_name=tensor_name - ) - trait.set(tensor) - - -################################################################################ -# Metadata -################################################################################ - -_dtype_to_name: dict[torch.dtype, str] = { - torch.float32: "float32", - torch.float64: "float64", - torch.complex64: "complex64", - torch.complex128: "complex128", - torch.float16: "float16", - torch.bfloat16: "bfloat16", - torch.int8: "int8", - torch.int16: "int16", - torch.int32: "int32", - torch.int64: "int64", - torch.uint8: "uint8", - torch.bool: "bool", -} - - -# Deal with datatypes not yet added in all versions of Torch. -def _add_optional_dtype(name: str): - try: - dtype = getattr(torch, name) - except AttributeError: - return - _dtype_to_name[dtype] = name - - -_add_optional_dtype("float8_e4m3fn") -_add_optional_dtype("float8_e4m3fnuz") -_add_optional_dtype("float8_e5m2") -_add_optional_dtype("float8_e5m2fnuz") -_add_optional_dtype("uint16") -_add_optional_dtype("uint32") -_add_optional_dtype("uint64") - - -_name_to_dtype: dict[str, torch.dtype] = {v: k for k, v in _dtype_to_name.items()} - -_metadata_prefix = "PYTORCH:" - - -def _make_tensor_metadata(t: torch.Tensor) -> str: - """Makes a tensor metadata blob that can be used to reconstruct the tensor.""" - dtype = t.dtype - try: - dtype_name = _dtype_to_name[dtype] - except KeyError: - dtype_name = "unknown" - warnings.warn( - f"Unknown dtype saving params: {dtype} (missing entry in params._dtype_to_name)" - ) - dtype_desc = { - "class_name": type(dtype).__name__, - "is_complex": dtype.is_complex, - "is_floating_point": dtype.is_floating_point, - "is_signed": dtype.is_signed, - "itemsize": dtype.itemsize, - } - d = { - "type": "Tensor", - "dtype": dtype_name, - "shape": list(t.shape), - "dtype_desc": dtype_desc, - } - encoded = f"{_metadata_prefix}{json.dumps(d)}" - return encoded - - -################################################################################ -# Parameter archives save/load -################################################################################ - - -def save_module_parameters( - file_path: Union[str, Path], module: nn.Module, *, prefix: str = "" -): - """One shot save of parameters and persistent buffers on a module. - - More options are available by using a ParameterArchiveBuilder. - """ - builder = ParameterArchiveBuilder() - builder.add_module(module, prefix=prefix) - builder.save(file_path) - - -class ParameterArchiveEntry: - """Wraps a raw ParameterIndexEntry with additional helpers.""" - - def __init__(self, raw: ParameterIndexEntry): - self.raw = raw - - @property - def key(self) -> str: - return self.raw.key - - def as_flat_tensor(self) -> torch.Tensor: - """Accesses the contents as a uint8 flat tensor. - - If it is a splat, then the tensor will be a view of the splat pattern. - - Raises a ValueError on unsupported entries. - """ - if self.raw.is_file: - wrapper = np.array(self.raw.file_view, copy=False) - elif self.raw.is_splat: - wrapper = np.array(self.raw.splat_pattern, copy=True) - else: - raise ValueError(f"Unsupported ParameterIndexEntry: {self.raw}") - - return torch.from_numpy(wrapper) - - def as_tensor(self) -> torch.Tensor: - """Returns a tensor viewed with appropriate shape/dtype from metadata. - - Raises a ValueError if unsupported. - """ - # Decode metadata. - metadata = self.raw.metadata.decode() - if not metadata.startswith(_metadata_prefix): - raise ValueError( - f"No metadata for parameter entry {self.key}: Cannot convert to tensor" - ) - metadata = metadata[len(_metadata_prefix) :] - d = json.loads(metadata) - try: - type_name = d["type"] - if d["type"] != "Tensor": - raise ValueError( - f"Metadata for parameter entry {self.key} is not a Tensor ('{type_name}')" - ) - dtype_name = d["dtype"] - shape = d["shape"] - except KeyError as e: - raise ValueError(f"Bad metadata for parameter entry {self.key}") from e - - # Unpack/validate. - try: - dtype = _name_to_dtype[dtype_name] - except KeyError: - raise ValueError(f"Unknown dtype name '{dtype_name}'") - try: - shape = [int(d) for d in shape] - except ValueError as e: - raise ValueError(f"Illegal shape for parameter entry {self.key}") from e - - t = self.as_flat_tensor() - return t.view(dtype=dtype).view(shape) - - def __repr__(self): - return f"ParameterArchiveEntry({self.raw}, metadata={self.raw.metadata})" - - -class ParameterArchive: - """Allows access to a parameter archive as CPU tensors. - - TODO: Add more helpers for reading tensors once we get upstream versions that - have that integrated. - """ - - def __init__( - self, - file_path: Optional[Union[str, Path]] = None, - *, - mmap: bool = True, - readable: bool = True, - writable: bool = False, - ): - self._index = ParameterIndex() - if file_path is not None: - self.load(file_path, mmap=mmap, readable=readable, writable=writable) - - def load( - self, - file_path: Union[str, Path], - *, - mmap: bool = True, - readable: bool = True, - writable: bool = False, - ): - """Loads index entries from a file adding them to the in-memory archive.""" - self._index.load( - str(file_path), mmap=mmap, readable=readable, writable=writable - ) - - @property - def index(self) -> ParameterIndex: - return self._index - - def items(self) -> List[Tuple[str, ParameterArchiveEntry]]: - """Returns the items in the archive. - - Note that there can be duplicates if the archive was constructed that way. - """ - return [(k, ParameterArchiveEntry(v)) for k, v in self._index.items()] - - def __repr__(self): - return repr(self._index) - - -class ParameterArchiveBuilder: - """Helper for building parameter archives from live modules.""" - - def __init__(self): - self._index = ParameterIndex() - - def save(self, file_path: Union[str, Path]): - """Saves the archive.""" - self._index.create_archive_file(str(file_path)) - - def add_tensor(self, name: str, tensor: torch.Tensor): - """Adds an named tensor to the archive.""" - flat_array = tensor.detach().flatten().contiguous().cpu().view(torch.uint8) - host_array = flat_array.numpy() - self._index.add_buffer(name, host_array, metadata=_make_tensor_metadata(tensor)) - - def add_module(self, module: nn.Module, *, prefix: str = ""): - """Adds all parameters and persistent buffers from a module hierarchy.""" - for name, t in _yield_saveable_tensors(module, prefix=prefix): - self.add_tensor(name, t) - - def add_blob(self, key: str, blob): - """Adds a raw blob to the index. - - The blob must be interpretable as a buffer. - """ - self._index.add_buffer(key, blob) - - -def _yield_saveable_tensors( - module: nn.Module, *, prefix: str = "" -) -> Iterator[Tuple[str, torch.Tensor]]: - """Yields tuple of name/tensor for all saveable tensors in a module. - - This includes parameters and persistent buffers. - """ - memo: Set[str] = set() - for sub_name, sub_module in module.named_modules(prefix=prefix): - state_dict = sub_module.state_dict() - for param_name, param in sub_module.named_parameters(recurse=False): - full_param_name = f"{sub_name}.{param_name}" if sub_name else param_name - if full_param_name in memo: - continue - memo.add(full_param_name) - yield full_param_name, param - for buffer_name, buffer in sub_module.named_buffers(recurse=False): - full_buffer_name = f"{sub_name}.{buffer_name}" if sub_name else buffer_name - if full_buffer_name in memo: - continue - memo.add(full_buffer_name) - if buffer_name not in state_dict: - # Non persistent - continue - yield full_buffer_name, buffer diff --git a/core/shark_turbine/aot/passes/__init__.py b/core/shark_turbine/aot/passes/__init__.py deleted file mode 100644 index 167b8b886..000000000 --- a/core/shark_turbine/aot/passes/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from .functorch import functorch_functionalize diff --git a/core/shark_turbine/aot/passes/functorch.py b/core/shark_turbine/aot/passes/functorch.py deleted file mode 100644 index 06967ecfa..000000000 --- a/core/shark_turbine/aot/passes/functorch.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Any, Callable - -import torch -from torch.fx import ( - GraphModule, -) -from torch.fx.experimental import proxy_tensor -from torch.utils import _pytree as pytree - - -# Use the functorch `functionalize()` helper. That cannot be used directly -# because it does not correctly handle fake tensor tracing. But we use -# the underlying dispatcher controls to enable/disable it and perform -# the transform. The approach was lifted from what ONNX is doing and a -# number of issues. In its present form it has a number of issues: -# 1. Cannot trace structured inputs and will drop output signature -# rewrites usually done by torch.export, rending structured -# results a non-starter if used as a transform after torch.export. -# 2. Will not play nicely with an enclosing, user specified fake mode. -# There is a lot of code on the ONNX side to enable this, but I -# don't have test cases for it and don't want to just blindly -# adapt dead code. -# 3. Loses backtrace information. The ONNX side has a helper that -# re-associates this, but it wasn't obvious it would work in our -# exact scenario. -# Further, it is not clear at all why this is using such heavy-weight -# facilities to do a simple graph transformation. I expect that we just -# need to write a pure FX pass to do the functionalization transform to -# our liking and shoot this into the sun. If we spend any time at all -# debugging the issues that can arise from all of this layering, we -# should just do that. -# -# For the reasons above, we only use this as a *pre-export* transformation, -# since that does not result in load bearing information loss. Note that -# ONNX applies this post export, which suffers from the loss of output -# destructuring rewrites that torch.export does. -def functorch_functionalize(gm_callable: Any, *args) -> GraphModule: - functionalized_callable = _functionalize_callabale(gm_callable) - # TODO: There is more of a dance needed if the user has entered with a fake_mode. - with proxy_tensor.maybe_disable_fake_tensor_mode(): - new_gm = proxy_tensor.make_fx( - functionalized_callable, - decomposition_table={}, - tracing_mode="symbolic", - _allow_non_fake_inputs=True, - _allow_fake_constant=False, - )(*args) - - return new_gm - - -def _functionalize_callabale(function: Callable) -> Callable: - def wrapped(*args): - args_functional = pytree.tree_map_only( - torch.Tensor, torch._to_functional_tensor, args - ) - torch._enable_functionalization(reapply_views=True) - try: - out = function(*args_functional) - finally: - torch._disable_functionalization() - # Do a dance to re-associate inputs. - flat_inputs, _ = pytree.tree_flatten(args) - flat_inputs_functional, _ = pytree.tree_flatten(args_functional) - for input_raw, input_functional in zip(flat_inputs, flat_inputs_functional): - if isinstance(input_functional, torch.Tensor): - torch._sync(input_functional) - torch._from_functional_tensor(input_functional) - pytree.tree_map_only(torch.Tensor, torch._sync, out) - out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out) - return out_unwrapped - - return wrapped diff --git a/core/shark_turbine/aot/support/ir_utils.py b/core/shark_turbine/aot/support/ir_utils.py deleted file mode 100644 index 813b7f4bd..000000000 --- a/core/shark_turbine/aot/support/ir_utils.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Callable, Dict, Optional, Sequence, Tuple - -from pathlib import Path -import tempfile - -import numpy as np -import torch - -from iree.compiler.extras.fx_importer import ( - ContextCache, - Empty, - EmptyType, - RefTracker, -) - -from ...dynamo.type_conversion import ( - NativeTypeConverter, -) - -from ...support.ir_imports import ( - AsmState, - Attribute, - BF16Type, - DenseElementsAttr, - DenseResourceElementsAttr, - F16Type, - F32Type, - F64Type, - FloatAttr, - FunctionType, - IndexType, - InsertionPoint, - IntegerAttr, - IntegerType, - IrType, - Location, - MLIRError, - Operation, - RankedTensorType, - StringAttr, - SymbolTable, - TypeAttr, - UnitAttr, - Value, - arith_d, - func_d, - tensor_d, -) - -from ...support.conversions import ( - TORCH_DTYPE_TO_IREE_TYPE, -) - -from ...support.logging import aot_logger as logger - -from ..tensor_traits import ( - ExternalTensorTrait, -) - -############################################################################### -# Configuration -############################################################################### - -# Maps a name to an altered name. If returns None, then the original -# name is used (this lets dict.get serve as a NameMapCallback). -NameMapCallback = Callable[[str], Optional[str]] - - -class GlobalAttributes: - """Settings for how to initialize the global.""" - - __slots__ = [ - "mutable", - "external", - "external_scope", - "name_mapper", - "noinline", - "uninitialized", - ] - - def __init__( - self, - mutable: bool = False, - external: Optional[bool] = None, - external_scope: Optional[str] = None, - name_mapper: Optional[NameMapCallback] = None, - noinline: bool = False, - uninitialized: Optional[bool] = None, - ): - if external and uninitialized: - raise ValueError( - f"Globals with external=True cannot also have uninitialized=True" - ) - if uninitialized and not mutable: - raise ValueError( - f"Globals with uninitialized=True must also be mutable=True" - ) - self.mutable = mutable - self.external = external - self.external_scope = external_scope - self.name_mapper = name_mapper - self.noinline = noinline - self.uninitialized = uninitialized - - def map_name(self, name: str) -> str: - if self.name_mapper: - new_name = self.name_mapper(name) - if new_name is not None: - return new_name - return name - - def infer_external_from_tensor( - self, t: torch.Tensor - ) -> Tuple[bool, Optional[str], Optional[str]]: - """If externality is not specified, infers it from the tensor.""" - # We check for the first item in a list because this lets us in the - # future extend the list by unwrapping. - check_tensors = [t] - for check_t in check_tensors: - trait = ExternalTensorTrait.get(check_t) - if trait is None: - continue - try: - external_scope = trait.external_scope - external_name = trait.external_name - except AttributeError as e: - raise AttributeError( - f"Tensor defines _is_turbine_external_tensor but not other fields: {type(t)} = {t}" - ) - return ( - True, - external_scope if self.external_scope is None else self.external_scope, - external_name, - ) - - return bool(self.external), self.external_scope, None - - -############################################################################### -# Builders -############################################################################### - - -class ModuleBuilder: - """Wrapper around module and IR accounting for a module being built.""" - - __slots__ = [ - "body", - "cache", - "context", - "fx_py_attr_tracker", - "global_ip", - "ip", - "module_op", - "symbol_table", - "global_ref_tracker", - "native_type_converter", - "_auto_symbol_counts", - ] - - def __init__(self, module_op: Operation): - self.module_op = module_op - self.context = module_op.context - self.body = module_op.regions[0].blocks[0] - self.symbol_table = SymbolTable(module_op) - self.global_ip = InsertionPoint.at_block_begin(self.body) - self.ip = InsertionPoint(self.body) - self.cache = ContextCache(self.context) - # Tracks global references to a MaterializedGlobal. - self.global_ref_tracker = RefTracker() - # Usually the FxImporter makes a new ref tracker for each invocation, - # but we want to preserve it across individual JIT evaluations so - # as to better intern tensors to attributes. - self.fx_py_attr_tracker = RefTracker() - self.native_type_converter = NativeTypeConverter(self.context) - self._auto_symbol_counts: Dict[str, int] = {} - - def unique_auto_symbol(self, requested_name: str) -> str: - if requested_name not in self._auto_symbol_counts: - self._auto_symbol_counts[requested_name] = 0 - return requested_name - count = self._auto_symbol_counts[requested_name] + 1 - self._auto_symbol_counts[requested_name] = count - return f"{requested_name}${count}" - - def handle_mlir_error(self, op: Operation, e: MLIRError, message: str): - # TODO: Replace with a real dumping facility. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/136 - dump_path = Path(tempfile.gettempdir()) / "turbine_module_builder_error.mlir" - logger.exception(f"{message} (dumping to {dump_path})") - try: - with open(dump_path, "wb") as f: - op.print( - file=f, - binary=True, - print_generic_op_form=True, - large_elements_limit=100, - ) - logger.debug(f"Dump complete to {dump_path}") - except Exception: - logger.exception("Error generating dump file") - - def finalize_construct(self): - try: - self.module_op.verify() - except MLIRError as e: - self.handle_mlir_error(self.module_op, e, "module failed to verify") - raise - - def create_func_op( - self, - symbol_name: str, - argument_types: Sequence[IrType], - is_public: bool = True, - add_entry_block: bool = True, - ) -> Tuple[str, func_d.FuncOp]: - with self.ip: - ftype = FunctionType.get(argument_types, []) - func_op = func_d.FuncOp(symbol_name, ftype) - if not is_public: - func_op.attributes["sym_visibility"] = StringAttr.get("private") - if add_entry_block: - func_op.add_entry_block() - self.symbol_table.insert(func_op) - actual_symbol_name = StringAttr(func_op.attributes["sym_name"]).value - return actual_symbol_name, func_op - - def torch_dtype_to_iree_type(self, dtype: torch.dtype) -> IrType: - try: - with self.context: - return TORCH_DTYPE_TO_IREE_TYPE[dtype]() - except KeyError: - raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type") - - def create_tensor_global( - self, - symbol_name: str, - t: torch.Tensor, - *, - attrs: GlobalAttributes, - logical_name: Optional[str] = None, - ) -> Tuple[str, Operation, IrType]: - element_type = self.torch_dtype_to_iree_type(t.dtype) - external, external_scope, external_name = attrs.infer_external_from_tensor(t) - - with self.global_ip, Location.unknown(): - tensor_type = RankedTensorType.get(list(t.shape), element_type) - ir_attrs = { - "sym_name": StringAttr.get(symbol_name), - "sym_visibility": StringAttr.get("private"), - "type": TypeAttr.get(tensor_type), - } - if attrs.noinline: - ir_attrs["noinline"] = UnitAttr.get() - if attrs.mutable: - ir_attrs["is_mutable"] = UnitAttr.get() - if external: - # Emit named external reference. - external_scope_attr = StringAttr.get(external_scope or "model") - external_name = ( - external_name - if external_name is not None - else attrs.map_name( - logical_name if logical_name is not None else symbol_name - ) - ) - external_name_attr = StringAttr.get(external_name) - # TODO: Have real Python builders for this. - ir_attrs["initial_value"] = Attribute.parse( - f"#stream.parameter.named<{external_scope_attr}::{external_name_attr}> : {tensor_type}" - ) - elif attrs.uninitialized: - # Emit unitialized initial_value to signal that the memory - # is valid but has undefined contents. - # TODO: Have real Python builders for this. - ir_attrs["initial_value"] = Attribute.parse( - f"#util.uninitialized : {tensor_type}" - ) - else: - # Emit inline initialized. - detached_tensor = t.detach().contiguous().cpu() - array = np.array(detached_tensor) - # We know that a Numpy array is a ReadableBuffer so ignore type error. - contents = memoryview(array) # type: ignore - blob_name = symbol_name - elements_attr = DenseResourceElementsAttr.get_from_buffer( - contents, blob_name, tensor_type - ) - ir_attrs["initial_value"] = elements_attr - - global_op = Operation.create("util.global", attributes=ir_attrs) - self.symbol_table.insert(global_op) - actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value - return actual_symbol_name, global_op, tensor_type - - def create_typed_global( - self, - symbol_name: str, - global_type: IrType, - *, - attrs: GlobalAttributes, - logical_name: Optional[str] = None, - ) -> Tuple[str, Operation]: - with self.global_ip, Location.unknown(): - ir_attrs = { - "sym_name": StringAttr.get(symbol_name), - "sym_visibility": StringAttr.get("private"), - "type": TypeAttr.get(global_type), - } - if attrs.noinline: - ir_attrs["noinline"] = UnitAttr.get() - if attrs.mutable: - ir_attrs["is_mutable"] = UnitAttr.get() - if attrs.uninitialized: - # Emit unitialized initial_value to signal that the memory - # is valid but has undefined contents. - # TODO: Have real Python builders for this. - ir_attrs["initial_value"] = Attribute.parse( - f"#util.uninitialized : {global_type}" - ) - else: - # Initialized by default. - ir_attrs["initial_value"] = self._create_initial_value_for_type( - global_type - ) - global_op = Operation.create("util.global", attributes=ir_attrs) - self.symbol_table.insert(global_op) - actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value - return actual_symbol_name, global_op - - def _create_initial_value_for_type(self, t: IrType) -> Attribute: - # TODO(#169): Implement something upstream for this (it exists in the C++ API) - # and use it. - if RankedTensorType.isinstance(t): - rtt = RankedTensorType(t) - if not rtt.has_static_shape: - raise ValueError( - "Cannot create initialization value for dynamic shaped tensor" - ) - element_attr = self._create_initial_value_for_type(rtt.element_type) - return DenseElementsAttr.get_splat(t, element_attr) - elif IntegerType.isinstance(t): - return IntegerAttr.get(t, 0) - elif F32Type.isinstance(t) or F64Type.isinstance(t) or F16Type.isinstance(t): - # TODO(#170): There should be a common way to check if a FloatType. - return FloatAttr.get(t, 0.0) - elif IndexType.isinstance(t): - return IntegerAttr.get(IndexType.get(), 0) - else: - raise ValueError( - f"Cannot create a default initialization value for type {t}" - ) - - -class FunctionBuilder: - """Helpers for building function bodies.""" - - __slots__ = [ - "module_builder", - "func_op", - "context", - "ip", - "return_types", - "loc", - ] - - def __init__( - self, - *, - module_builder: ModuleBuilder, - func_op: func_d.FuncOp, - ): - self.module_builder = module_builder - self.func_op = func_op - self.context = func_op.context - self.ip = InsertionPoint(self.func_op.entry_block) - self.return_types: Optional[Sequence[IrType]] = None - self.loc = self.func_op.location - - def emit_return(self, *ir_values: Value): - with self.loc, self.ip: - func_d.ReturnOp(ir_values) - # Check or rewrite the function return type. - value_types = [v.type for v in ir_values] - if self.return_types: - if value_types != self.return_types: - raise ValueError( - f"Multi-return function must return same types. " - f"{value_types} vs {self.return_types}" - ) - return - self.return_types = value_types - ftype = self.func_op.type - ftype = FunctionType.get(ftype.inputs, value_types) - self.func_op.attributes["function_type"] = TypeAttr.get(ftype) - try: - self.func_op.verify() - except MLIRError as e: - self.module_builder.handle_mlir_error( - self.func_op, e, "created function does not verify" - ) - raise - - -############################################################################### -# Helpers -############################################################################### - - -def build_index_attribute(value: int) -> IntegerAttr: - return IntegerAttr.get(IndexType.get(), value) - - -def build_index_value( - value: int, constant_cache: Optional[dict[int, Value]] = None -) -> Value: - if constant_cache is not None and value in constant_cache: - return constant_cache[value] - index_value = arith_d.ConstantOp(IndexType.get(), value).result - if constant_cache is not None: - constant_cache[value] = index_value - return index_value - - -def build_tensor_dim_value( - t: Value, dim: int, constant_cache: Optional[dict[int, Value]] = None -) -> Value: - dim_value = build_index_value(dim, constant_cache=constant_cache) - return tensor_d.DimOp(t, dim_value).result - - -# API name inspired by mlir/python/mlir/dialects/_arith_ops_ext.py -def _is_float_type(type): - return isinstance(type, (BF16Type, F16Type, F32Type, F64Type)) - - -def _is_integer_like_type(type): - return isinstance(type, (IntegerType, IndexType)) diff --git a/core/shark_turbine/aot/support/procedural/__init__.py b/core/shark_turbine/aot/support/procedural/__init__.py deleted file mode 100644 index f78c2d438..000000000 --- a/core/shark_turbine/aot/support/procedural/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# The procedural package has circular dependencies due to its -# nature. In an effort to modularize the code, we do allow circular -# imports and when used, they must be coherent with the load -# order here and must perform the import at the end of the module. - -from .base import * -from .iree_emitter import IREEEmitter -from .primitives import * -from .globals import * -from .tracer import * diff --git a/core/shark_turbine/aot/support/procedural/base.py b/core/shark_turbine/aot/support/procedural/base.py deleted file mode 100644 index b47932022..000000000 --- a/core/shark_turbine/aot/support/procedural/base.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import ( - Any, - Callable, - List, - Optional, - Sequence, -) - -from contextlib import contextmanager -import threading - -import torch -from torch.utils._pytree import tree_map - -from ....support.ir_imports import ( - F32Type, - F64Type, - IndexType, - IntegerType, - IrType, - Location, - RankedTensorType, - ShapedType, - Value, -) - -from ..ir_utils import ( - FunctionBuilder, - ModuleBuilder, -) - -ShapedTypeDynamicSizeSentinel = ShapedType.get_dynamic_size() -_thread_state = threading.local() - -############################################################################### -# Tracing intrinsics -############################################################################### - - -class ProcedureTraceError(Exception): - def __init__(self, message: str): - super().__init__(message) - - -class IrTrace(FunctionBuilder): - """Gets callbacks for tracing events.""" - - __slots__ = [] - - def finalize(self): - """Called when the trace is finished (popped off the stack).""" - pass - - def handle_call(self, target: "Intrinsic", args, kwargs): - raise NotImplementedError(f"The current trace scope does not support calls") - - def handle_assignment(self, scope, target, updated_value): - raise NotImplementedError( - f"The current trace scope does not support assignment" - ) - - -def _trace_scopes() -> List[IrTrace]: - try: - trace_scopes = _thread_state.trace_scopes - except AttributeError: - trace_scopes = _thread_state.trace_scopes = [] - return trace_scopes - - -@contextmanager -def new_ir_trace_scope(ir_trace: IrTrace): - trace_scopes = _trace_scopes() - trace_scopes.append(ir_trace) - try: - yield ir_trace - finally: - ir_trace.finalize() - del trace_scopes[-1] - - -def current_ir_trace() -> IrTrace: - return _trace_scopes()[-1] - - -class Intrinsic: - """Objects which interact natively with the tracing system implement this.""" - - __slots__: List[str] = [] - - def resolve_ir_values(self, proc_trace: "IrTrace") -> Sequence[Value]: - raise NotImplementedError( - f"Cannot use {self} as an expression in a procedural function" - ) - - def resolve_call(self, proc_trace: "IrTrace", *args, **kwargs): - raise NotImplementedError( - f"Cannot use {self} as the target of a call in a procedural function" - ) - - def resolve_assignment(self, proc_trace: "IrTrace", ir_values: Sequence[Value]): - raise NotImplementedError( - f"Cannot use {self} as the target of an assignment in a procedural function" - ) - - # Helpers for accessing the ir_value within the current trace. - @property - def ir_values(self) -> Sequence[Value]: - return self.resolve_ir_values(current_ir_trace()) - - @property - def ir_value(self) -> Value: - values = self.ir_values - assert len(values) == 1, "Expected arity one intrinsic" - return values[0] - - -class CallableIntrinsic(Intrinsic): - """Intrinsic subclass that supports calls. - - This is separate so as to make error handling better (i.e. does not support - calls) for intrinsics that are not callable. - """ - - __slots__ = [] - - def __call__(self, *args, **kwargs): - return current_ir_trace().handle_call(self, args, kwargs) - - -class AbstractIntrinsic: - """Base class for descriptor types that can be converted to Python proxies.""" - - __slots__: List[str] = [] - - def create_intrinsic(self, value: Value) -> Intrinsic: - """Creates a proxy object that can flow through a procedural trace.""" - raise NotImplementedError - - def get_ir_type(self, builder: ModuleBuilder) -> IrType: - """Gets the corresponding IR type.""" - raise NotImplementedError - - -############################################################################### -# Abstract types -############################################################################### - - -class AbstractTypedef: - """Base class for instances which declare some form of public arg/result type definition.""" - - def get_ir_type(self, builder: ModuleBuilder) -> IrType: - raise NotImplementedError - - -class Abstractifiable: - """Indicates that a type knows how to abstractify itself.""" - - def abstractify(self) -> AbstractTypedef: - raise NotImplementedError - - -class TreeAbstractifiable: - """Indicates that a type decomposes into a tree that can be abstractified.""" - - def abstractify_tree(self) -> Any: - raise NotImplementedError - - -class AbstractTensor(AbstractIntrinsic, AbstractTypedef): - """Represents a tensor of known rank and dtype.""" - - __slots__ = [ - "size", - "dtype", - ] - - def __init__(self, *size: Optional[int], dtype: torch.dtype = torch.float32): - self.size = tuple(size) - self.dtype = dtype - - def __repr__(self): - return f"AbstractTensor({', '.join(str(s) for s in self.size)}, dtype={self.dtype})" - - def create_intrinsic(self, ir_value: Value) -> Intrinsic: - return IrImmediateTensor(ir_value, self.dtype) - - def get_ir_type(self, builder: ModuleBuilder) -> IrType: - element_type = builder.torch_dtype_to_iree_type(self.dtype) - with Location.unknown(builder.context): - tensor_type = RankedTensorType.get( - [ - s if s is not None else ShapedTypeDynamicSizeSentinel - for s in self.size - ], - element_type, - ) - return tensor_type - - -class AbstractScalar(AbstractIntrinsic, AbstractTypedef): - """Represents a scalar value of some type.""" - - __slots__ = [ - "label", - "type_producer", - ] - - def __init__(self, label: str, type_producer: Callable[[], IrType]): - self.label = label - self.type_producer = type_producer - - def __repr__(self): - return f"AbstractScalar({self.label})" - - def create_intrinsic(self, ir_value: Value) -> Intrinsic: - return IrImmediateScalar(ir_value) - - def get_ir_type(self, builder: ModuleBuilder) -> IrType: - with builder.context: - return self.type_producer() - - -# Concrete scalar types. -AbstractIndex = AbstractScalar("index", lambda: IndexType.get()) -AbstractF32 = AbstractScalar("f32", lambda: F32Type.get()) -AbstractF64 = AbstractScalar("f64", lambda: F64Type.get()) -AbstractBool = AbstractScalar("bool", lambda: IntegerType.get_signless(1)) -AbstractI32 = AbstractScalar("i32", lambda: IntegerType.get_signless(32)) -AbstractI64 = AbstractScalar("i64", lambda: IntegerType.get_signless(64)) - - -def abstractify_single_value(value) -> AbstractTypedef: - if isinstance(value, AbstractTypedef): - return value - if isinstance(value, Abstractifiable): - return value.abstractify() - if isinstance(value, torch.Tensor): - return AbstractTensor(*value.shape, dtype=value.dtype) - raise TypeError( - f"Cannot convert type {value.__class__} to an abstract type: {value}" - ) - - -def abstractify(tree): - if isinstance(tree, TreeAbstractifiable): - return tree.abstractify_tree() - return tree_map(abstractify_single_value, tree) - - -# Circular iports. -from .primitives import ( - IrImmediateScalar, - IrImmediateTensor, -) diff --git a/core/shark_turbine/aot/support/procedural/exported_program.py b/core/shark_turbine/aot/support/procedural/exported_program.py deleted file mode 100644 index 4f70059a2..000000000 --- a/core/shark_turbine/aot/support/procedural/exported_program.py +++ /dev/null @@ -1,377 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Any, Dict, List, Optional - -import inspect -import math - -import torch - -from torch.utils._pytree import ( - tree_flatten, - tree_unflatten, -) - -try: - from torch.utils._pytree import treespec_pprint -except ImportError: - # torch < 2.3 does not include this. - treespec_pprint = lambda x: repr(x) # type: ignore - -from iree.compiler.extras.fx_importer import ( - FxImporter, - FxImporterHooks, - GraphNodeImporter, - InputInfo, -) - -from ....support.logging import aot_logger as logger - -from ....support.ir_imports import ( - func_d, - util_d, - FlatSymbolRefAttr, - FunctionType, - IrType, - Operation, - StringAttr, - TypeAttr, - Value, -) - -from ...tensor_traits import ( - ExternalTensorTrait, -) - -from ..ir_utils import ( - GlobalAttributes, - ModuleBuilder, -) - -from .base import ( - CallableIntrinsic, -) - -from .globals import ( - GlobalsDef, - MaterializedGlobal, -) - -from .primitives import ( - IrImmediateTensor, - IrTensor, -) - -from .tracer import ( - IrTrace, -) - -# Limit of tensor volumes. Over this limit, otherwise uncategorized tensor -# constants will be emitted out-of-line. Under the limit, inline. -INLINE_TENSOR_VOLUME_LIMIT = 1024 - - -class ExportedProgramIntrinsic(CallableIntrinsic): - def __init__( - self, - entry_func_op: Operation, - entry_sig: torch.export.ModuleCallSignature, - user_output_dtypes: List[Optional[torch.dtype]], - ): - self.entry_func_op = entry_func_op - self.entry_sig = entry_sig - self.user_output_dtypes = user_output_dtypes - - @property - def function_type(self) -> FunctionType: - return TypeAttr(self.entry_func_op.attributes["function_type"]).value - - @property - def function_symbol(self) -> StringAttr: - return StringAttr(self.entry_func_op.attributes["sym_name"]) - - @property - def function_visibility(self) -> StringAttr: - return StringAttr(self.entry_func_op.attributes["sym_visibility"]) - - def resolve_call( - self, - proc_trace: IrTrace, - *py_args, - **py_kwargs, - ): - visibility = self.function_visibility - if visibility.value != "private": - raise ValueError( - f"Currently, only private ExportedPrograms can be called: " - f"{self.function_symbol} is {visibility}" - ) - - # Flatten and convert py args to torch IR values by converting to - # the canonical tree structure for args - # (tuple of list of args, dict of kwargs). - flat_py_args, args_tree = tree_flatten(((list(py_args),), py_kwargs)) - if args_tree != self.entry_sig.in_spec: - raise ValueError( - f"Mismatched arguments to exported program. \n" - f" Got: {treespec_pprint(args_tree)}\n" - f" Expected: {treespec_pprint(self.entry_sig.in_spec)} " - ) - function_type = self.function_type - flat_ir_args = [ - self._py_to_torch_ir(proc_trace, py_arg, torch_type) - for py_arg, torch_type in zip(flat_py_args, function_type.inputs) - ] - - # Call. - with proc_trace.ip, proc_trace.loc: - flat_ir_results = func_d.CallOp( - function_type.results, - FlatSymbolRefAttr.get(self.function_symbol.value), - flat_ir_args, - ).results - - # Convert torch IR values to python. - flat_py_results = [ - self._torch_ir_to_py(proc_trace, ir_value, dtype) - for ir_value, dtype in zip(flat_ir_results, self.user_output_dtypes) - ] - - return tree_unflatten(flat_py_results, self.entry_sig.out_spec) - - def _py_to_torch_ir( - self, proc_trace: IrTrace, py_value, torch_type: IrType - ) -> Value: - type_converter = proc_trace.module_builder.native_type_converter - if isinstance(py_value, IrTensor): - # TODO: Allow certain static info casts. - return type_converter.materialize_native_to_torch( - py_value.ir_value, torch_type - ) - else: - raise ValueError( - f"Unsupported type in arguments of call to ExportedProgram: " - f"{type(py_value)}: {py_value}" - ) - - def _torch_ir_to_py( - self, proc_trace: IrTrace, ir_value: Value, dtype: Optional[torch.dtype] - ): - type_converter = proc_trace.module_builder.native_type_converter - native_ir_value = type_converter.materialize_torch_to_native(ir_value) - if dtype is not None: - return IrImmediateTensor(native_ir_value, dtype) - else: - raise TypeError( - f"Unknown PyTorch->IREE value mapping for ExportedProgram output: " - f"{native_ir_value}" - ) - - -def import_exported_program( - module_builder: ModuleBuilder, - exported_program: torch.export.ExportedProgram, - symbol_name: str, - symbol_visibility: Optional[str], -) -> ExportedProgramIntrinsic: - fx_importer = _create_fx_importer(module_builder) - entry_func_op = fx_importer.import_program( - exported_program, func_name=symbol_name, func_visibility=symbol_visibility - ) - - module_call_graph = exported_program.module_call_graph - assert len(module_call_graph) >= 1, "Expected at least one module call signature" - entry_module_call_entry = module_call_graph[0] - assert ( - entry_module_call_entry.fqn == "" - ), "Expected first module call entry to be unnamed" - - # We want additional torch-level metadata about any user outputs. - # This will help us create a true python fake without loss of information. - # TODO: It is unclear how much switchiness is actually needed here as - # modern use is pretty constrained. Potentially streamline the body of - # the for loop once done with full test cases available. - user_output_dtypes: list[Optional[torch.dtype]] = [] - node_map: Dict[str, torch.fx.Node] = { - n.name: n for n in exported_program.graph.nodes - } - for user_output in exported_program.graph_signature.user_outputs: - output_node = node_map[user_output] - tensor_meta = output_node.meta.get("tensor_meta") - fake_val = output_node.meta.get("val") - dtype = None - if tensor_meta is not None: - dtype = tensor_meta.dtype - elif fake_val is not None: - dtype = fake_val.dtype - user_output_dtypes.append(dtype) - - return ExportedProgramIntrinsic( - entry_func_op, entry_module_call_entry.signature, user_output_dtypes - ) - - -class _Hooks(FxImporterHooks): - def __init__(self, module_builder: ModuleBuilder): - self.module_builder = module_builder - - def store_produced_value( - self, - gni: GraphNodeImporter, - py_value: Any, - produced_ir_value: Any, - info: InputInfo, - ): - module_builder = self.module_builder - # See if we know about it. - mapping = module_builder.global_ref_tracker.track(py_value) - if mapping.is_empty: - raise ValueError(f"Cannot store value to unmapped global for: {info}") - logger.debug("Resolved global for store %r", mapping) - materialized_global: MaterializedGlobal = mapping.value # type: ignore - converted_value = Operation.create( - "torch_c.to_builtin_tensor", - results=[materialized_global.ir_type], - operands=[produced_ir_value], - ).result - util_d.GlobalStoreOp(converted_value, materialized_global.symbol_name) - - def resolve_literal(self, gni: GraphNodeImporter, literal: Any) -> Optional[Value]: - # We support resolution of tracked reference types. Currently this - # only includes Tensors. All others we let the importer do what it - # is going to do. - if not isinstance(literal, torch.Tensor): - return None - - # See if we know about it. - materialized_global = self._lift_tensor_to_global(literal) - if not materialized_global: - # If it is unknown, just let the default importer take it on. - return None - - # Emit a global load and conversion. - vtensor_type = gni._cc.tensor_to_vtensor_type(literal) - loaded_value = util_d.GlobalLoadOp( - materialized_global.ir_type, materialized_global.symbol_name - ).result - converted_value = Operation.create( - "torch_c.from_builtin_tensor", - results=[vtensor_type], - operands=[loaded_value], - ).result - return converted_value - - def _lift_tensor_to_global( - self, literal: torch.Tensor - ) -> Optional[MaterializedGlobal]: - module_builder = self.module_builder - mapping = module_builder.global_ref_tracker.track(literal) - if not mapping.is_empty: - # Already materialized. - logger.debug("Resolved defined global for literal %r", mapping) - materialized_global: MaterializedGlobal = mapping.value # type: ignore - return materialized_global - - # Policy check: Should we auto-import? Generally, we keep "small" - # tensors as inline as they can be optimized. - external_trait = ExternalTensorTrait.get(literal) - if not self._should_lift_tensor_to_global(literal, external_trait): - return None - - # If it is a tensor we haven't seen yet, materialize it - # as a global and return. - if external_trait is not None: - # If it is an external tensor, we can generate a nicer - # symbol name. - name = external_trait.external_name - else: - # Otherwise, generate a name based on what we have. - shape_desc = "_".join([str(d) for d in literal.shape]) - name = f"constant_{shape_desc}_{str(literal.dtype)}" - - name = module_builder.unique_auto_symbol(name) - # TODO: We may want to unique this somehow in the module builder. - auto_def = AutoGlobalTensorDef(name, literal, GlobalAttributes()) - materialized_global = auto_def.track(module_builder, "_auto") - assert isinstance(materialized_global, MaterializedGlobal) - return materialized_global - - def _should_lift_tensor_to_global( - self, literal: torch.Tensor, external_trait: Optional[ExternalTensorTrait] - ) -> bool: - if external_trait is not None: - return True - volume = math.prod(literal.shape) - return volume > INLINE_TENSOR_VOLUME_LIMIT - - -class AutoGlobalTensorDef(GlobalsDef): - """Global definition that is used for arbitrary tensor literals encountered - during processing.""" - - __slots__ = [ - "_name", - "_value", - "_schema", - ] - - def __init__(self, name: str, value: torch.Tensor, attrs: GlobalAttributes): - super().__init__(attrs) - self._name = name - self._value = value - _, self._schema = tree_flatten(self._value) - - def items(self): - yield (self._name, self._value) - - def schema(self): - return self._schema - - -# In https://github.com/llvm/torch-mlir/pull/3046, the FxImporter was -# extended to accept a "module_op" as an Operation (vs a Module). Switch for -# compatibility. -_fx_importer_accepts_module_op = ( - "module_op" in inspect.getfullargspec(FxImporter).kwonlyargs -) - - -def _create_fx_importer(module_builder: ModuleBuilder) -> FxImporter: - hooks = _Hooks(module_builder) - if _fx_importer_accepts_module_op: - # New path. - return FxImporter( - module_op=module_builder.module_op, - config_check=False, - py_attr_tracker=module_builder.fx_py_attr_tracker, - hooks=hooks, - ) - else: - # Legacy path. - class FakeModule: - def __init__(self, op): - self._op = module_builder.module_op - - @property - def context(self): - return self._op.context - - @property - def operation(self): - return self._op - - @property - def body(self): - return self._op.regions[0].blocks[0] - - return FxImporter( - module=FakeModule(module_builder.module_op), - config_check=False, - py_attr_tracker=module_builder.fx_py_attr_tracker, - hooks=hooks, - ) diff --git a/core/shark_turbine/aot/support/procedural/globals.py b/core/shark_turbine/aot/support/procedural/globals.py deleted file mode 100644 index 250a8d600..000000000 --- a/core/shark_turbine/aot/support/procedural/globals.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Global references in a module. - -from typing import ( - Any, - Callable, - Dict, - Generator, - Optional, - Sequence, - Tuple, -) - -import torch - -from torch.utils._pytree import ( - TreeSpec, - tree_unflatten, -) - -from ....support.ir_imports import ( - IrType, - Operation, - Value, - util_d, -) - -from ....support.logging import aot_logger as logger - -from ..ir_utils import ( - GlobalAttributes, - ModuleBuilder, -) - -from .base import ( - AbstractScalar, - AbstractTensor, - Intrinsic, - IrTrace, - current_ir_trace, -) - -from .primitives import ( - IrScalar, - IrTensor, -) - -############################################################################### -# Globals -############################################################################### - - -class LiveGlobalCollectionProxy: - """Proxy object around a collection which knows how to redirect setitem.""" - - __slots__ = ["_raw_collection"] - - def __init__(self, raw_collection): - self._raw_collection = raw_collection - - def __getitem__(self, key: str): - actual = self._raw_collection[key] - if isinstance(actual, MaterializedGlobal): - return actual - else: - return LiveGlobalCollectionProxy(actual) - - def __setitem__(self, key, value): - item = self._raw_collection[key] - if isinstance(item, MaterializedGlobal): - current_ir_trace().handle_assignment(self, item, value) - else: - raise AttributeError( - f"Globals collection {self._raw_collection.__class__} only supports assignment of leaves" - ) - - def __len__(self): - return len(self._raw_collection) - - def __repr__(self): - return f"LiveGlobalsProxy({self._raw_collection})" - - -class GlobalsDef: - """Base class for all exporting descriptors.""" - - __slots__ = [ - "_attrs", - ] - - def __init__(self, attrs: GlobalAttributes): - self._attrs = attrs - - def items(self) -> Generator[Tuple[str, Any], None, None]: - """Yields tuples of name/value exports.""" - raise NotImplementedError - - def schema(self) -> TreeSpec: - """A schema used to unflatten for access from Python.""" - raise NotImplementedError - - def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any: - """Track the given pack of globals, returning a struct that can be used to access them.""" - flat_globals = [] - for name, value in self.items(): - # Switch on types we support. - fq_name = f"{export_namespace}.{name}" - if isinstance(value, torch.Tensor): - mapping = module_builder.global_ref_tracker.track(value) - if not mapping.is_empty: - logger.debug( - "IGNORE EXISTING TRACKED TENSOR(%s): %r", fq_name, mapping - ) - flat_globals.append(mapping.value) - continue - ( - actual_symbol_name, - global_op, - global_type, - ) = module_builder.create_tensor_global( - f"_{fq_name}", - value, - attrs=self._attrs, - logical_name=fq_name, - ) - mapping.value = IrGlobalTensor( - fq_name, - self, - symbol_name=actual_symbol_name, - global_op=global_op, - global_type=global_type, - dtype=value.dtype, - ) - logger.debug("TRACK NEW TENSOR(%s): %r", fq_name, mapping) - flat_globals.append(mapping.value) - continue - elif isinstance(value, AbstractTensor): - global_type = value.get_ir_type(module_builder) - ( - actual_symbol_name, - global_op, - ) = module_builder.create_typed_global( - f"_{fq_name}", - global_type, - attrs=self._attrs, - logical_name=fq_name, - ) - flat_globals.append( - IrGlobalTensor( - fq_name, - self, - symbol_name=actual_symbol_name, - global_op=global_op, - global_type=global_type, - dtype=value.dtype, - ) - ) - continue - elif isinstance(value, AbstractScalar): - global_type = value.get_ir_type(module_builder) - ( - actual_symbol_name, - global_op, - ) = module_builder.create_typed_global( - f"_{fq_name}", - global_type, - attrs=self._attrs, - logical_name=fq_name, - ) - flat_globals.append( - IrGlobalScalar( - fq_name, - self, - symbol_name=actual_symbol_name, - global_op=global_op, - global_type=global_type, - ) - ) - continue - - raise TypeError(f"Unsupported global type: {value.__class__}") - tree_globals = tree_unflatten(flat_globals, self.schema()) - if isinstance(tree_globals, MaterializedGlobal): - return tree_globals - else: - return LiveGlobalCollectionProxy(tree_globals) - - -class MaterializedGlobal: - """Tags an Ir* that is duck-typed as a global.""" - - ir_type: IrType - symbol_name: str - global_op: Operation - global_type: IrType - - -class IrGlobalScalar(IrScalar, MaterializedGlobal): - """An IrScalar that is loaded from a global and associated with its aggregate.""" - - __slots__ = [ - "global_op", - "global_type", - "info", - "export_name", - "symbol_name", - ] - - def __init__( - self, - export_name: str, - info: GlobalsDef, - *, - symbol_name: str, - global_op: Operation, - global_type: IrType, - ): - super().__init__(global_type) - self.info = info - self.export_name = export_name - self.symbol_name = symbol_name - self.global_op = global_op - - def resolve_ir_values(self, trace: IrTrace) -> Sequence[Value]: - with trace.loc, trace.ip: - value = util_d.GlobalLoadOp(self.ir_type, self.symbol_name).result - return [value] - - def resolve_assignment(self, proc_trace: "IrTrace", ir_values: Sequence[Value]): - if len(ir_values) != 1: - raise ValueError( - f"Can only assign a single value to a global. Got {len(ir_values)}" - ) - source_ir_type = ir_values[0].type - if source_ir_type != self.ir_type: - raise TypeError( - f"Cannot assign to a global with a different type: {self.ir_type} != {source_ir_type}" - ) - with proc_trace.loc, proc_trace.ip: - util_d.GlobalStoreOp(ir_values[0], self.symbol_name) - - def set(self, other): - t = current_ir_trace() - self.resolve_assignment(t, super().set(other).ir_values) - - def __repr__(self): - return ( - f"" - ) - - -class IrGlobalTensor(IrTensor, MaterializedGlobal): - """An IrScalar that is loaded from a global and associated with its aggregate.""" - - __slots__ = [ - "global_op", - "info", - "export_name", - "symbol_name", - ] - - def __init__( - self, - export_name: str, - info: GlobalsDef, - *, - symbol_name: str, - global_op: Operation, - global_type: IrType, - dtype: torch.dtype, - ): - super().__init__(global_type, dtype) - self.info = info - self.export_name = export_name - self.symbol_name = symbol_name - self.global_op = global_op - - def resolve_ir_values(self, trace: IrTrace) -> Sequence[Value]: - with trace.loc, trace.ip: - value = util_d.GlobalLoadOp(self.ir_type, self.symbol_name).result - return [value] - - def resolve_assignment(self, proc_trace: "IrTrace", ir_values: Sequence[Value]): - if len(ir_values) != 1: - raise ValueError( - f"Can only assign a single value to a global. Got {len(ir_values)}" - ) - source_ir_type = ir_values[0].type - if source_ir_type != self.ir_type: - raise TypeError( - f"Cannot assign to a global with a different type: {self.ir_type} != {source_ir_type}" - ) - with proc_trace.loc, proc_trace.ip: - util_d.GlobalStoreOp(ir_values[0], self.symbol_name) - - def __repr__(self): - return f"" diff --git a/core/shark_turbine/aot/support/procedural/iree_emitter.py b/core/shark_turbine/aot/support/procedural/iree_emitter.py deleted file mode 100644 index dbfd4ea2b..000000000 --- a/core/shark_turbine/aot/support/procedural/iree_emitter.py +++ /dev/null @@ -1,356 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Python API for IREE's high-level tensor dialects.""" - -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union - -import functools - -import torch - -from ....support.ir_imports import ( - IndexType, - IntegerType, - IrType, - RankedTensorType, - StringAttr, - Value, - arith_d, - flow_d, -) - -from ....support.conversions import ( - TORCH_DTYPE_TO_IREE_TYPE, -) - -from ..ir_utils import ( - build_index_value, -) - -from .base import ( - Intrinsic, - current_ir_trace, - ShapedTypeDynamicSizeSentinel, -) - -from .primitives import ( - IrScalar, - IrImmediateScalar, - IrTensor, - IrImmediateTensor, -) - -BuildableScalarValue = Union[IrScalar, Value] -BuildableTensorDimDecl = Union[int, Value] -BuildableTensorType = IrTensor -BuildableIndexType = Union[BuildableScalarValue, int] -BuildableIndexLengthType = Union[ - BuildableTensorDimDecl, Tuple[BuildableTensorDimDecl, BuildableTensorDimDecl] -] -BuildableSliceType = Sequence[BuildableIndexLengthType] -StaticIndexType = int - - -def cast_scalar_value(x: BuildableScalarValue) -> Value: - x = unwrap_intrinsic_value(x) - if not isinstance(x, Value): - raise ValueError(f"Expected a scalar value but got {x}") - return x - - -def cast_tensor_value(x: BuildableTensorType) -> IrTensor: - assert isinstance(x, IrTensor), f"Expected a tensor but got {type(x)}" - return x - - -def cast_index_value( - x: BuildableIndexType, *, constant_cache: Optional[Dict[int, Value]] = None -) -> Value: - x = unwrap_intrinsic_value(x) - if isinstance(x, int): - return build_index_value(x, constant_cache=constant_cache) - else: - return x - - -def cast_static_bounded_index(x: int, min_value: int, max_value: int) -> int: - if not isinstance(x, int): - raise ValueError(f"Expected int but got {type(x)}") - if x < min_value or x > max_value: - raise ValueError( - f"Expected int in range [{min_value}, {max_value}] but got {x}" - ) - return x - - -def cast_tensor_dim_decl( - xs: Sequence[BuildableTensorDimDecl], -) -> Tuple[Sequence[int], Sequence[Value]]: - """Casts a sequence of tensor declaration dimensions to dims suitable - for construction of a TensorType and a sequence of dynamic dim values.""" - dim_decls: List[int] = [] - dynamic_dim_values: List[Value] = [] - for x in xs: - x = unwrap_intrinsic_value(x) - if isinstance(x, Value): - assert_value_is_index(x) - dim_decls.append(ShapedTypeDynamicSizeSentinel) - dynamic_dim_values.append(x) - elif isinstance(x, int) and x >= 0: - dim_decls.append(x) - else: - raise ValueError( - f"Expected a tensor dimension as a positive integer or None but got {x}" - ) - return dim_decls, dynamic_dim_values - - -def cast_scalar_to_element_type(scalar: Value, element_type: IrType) -> Value: - scalar_type = scalar.type - # Support cast from Index -> Integer. - if scalar_type == IndexType.get() and IntegerType.isinstance(element_type): - return arith_d.IndexCastUIOp(element_type, scalar).result - raise ValueError( - f"Provided splat value ({scalar_type}) does not match dtype {element_type} (and cannot be cast)" - ) - - -def assert_value_is_index(x: Value): - t = x.type - if not IndexType.isinstance(t): - raise ValueError(f"Expected an index value but got {t}") - - -def unwrap_intrinsic_value(x) -> Any: - if isinstance(x, Intrinsic): - x, *rest = x.resolve_ir_values(current_ir_trace()) - if rest: - raise ValueError( - f"Expected a value that has an arity of one component but for {len(rest) + 1}" - ) - return x - - -def emitter(f): - @functools.wraps(f) - def wrapper(*args, **kwargs): - t = current_ir_trace() - with t.loc, t.ip: - return f(*args, **kwargs) - - return wrapper - - -class IREEEmitter: - @emitter - def tensor_dim( - self, - source: BuildableTensorType, - index: int, - *, - dtype: Optional[torch.dtype] = None, - ) -> "IrScalar": - """Gets the dimension size of a tensor at a static position.""" - source = cast_tensor_value(source) - index = cast_static_bounded_index(index, 0, source.rank - 1) - dim_value = source.get_dim_value(index) - if dtype is not None: - try: - cast_type = TORCH_DTYPE_TO_IREE_TYPE[dtype]() - except KeyError: - raise ValueError(f"Could not map Torch dtype {dtype} to an IREE type") - dim_value = arith_d.IndexCastUIOp(cast_type, dim_value).result - return IrImmediateScalar(dim_value) - - @emitter - def tensor_empty( - self, *dims: BuildableTensorDimDecl, dtype: torch.dtype = torch.float32 - ) -> IrTensor: - """Constructs a tensor with uninitialized values. - - TODO: Support an IREE/raw element type in addition to the torch dtype. - See: https://github.com/nod-ai/SHARK-Turbine/issues/130 - """ - dim_decls, dyn_dim_values = cast_tensor_dim_decl(dims) - try: - element_type = TORCH_DTYPE_TO_IREE_TYPE[dtype]() - except KeyError: - raise ValueError(f"Could not map Torch dtype {dtype} to an IREE type") - tensor_type = RankedTensorType.get(dim_decls, element_type) - raw_tensor = flow_d.TensorEmptyOp(tensor_type, dyn_dim_values).result - result = IrImmediateTensor(raw_tensor, dtype=dtype) - result.set_dynamic_dim_values(dyn_dim_values) - return result - - @emitter - def tensor_reshape( - self, source: BuildableTensorType, *result_dims: BuildableTensorDimDecl - ) -> "IrTensor": - constant_cache: Dict[int, Value] = {} - source = cast_tensor_value(source) - result_dim_decls, result_dynamic_dims = cast_tensor_dim_decl(result_dims) - result_type = RankedTensorType.get( - result_dim_decls, source.ir_type.element_type - ) - result_value = flow_d.TensorReshapeOp( - result_type, - source.ir_value, - source.get_only_dynamic_dim_values(constant_cache=constant_cache), - result_dynamic_dims, - ).result - result = IrImmediateTensor(result_value, dtype=source.dtype) - result.set_dynamic_dim_values(result_dynamic_dims) - return result - - @emitter - def tensor_slice( - self, source: BuildableTensorType, *indices: BuildableSliceType - ) -> "IrTensor": - """Extracts a slice of a tensor. - - The given indices must match the rank of the source and each index is - interpreted as `(start_index[, length])`, where the `length` is taken - to be 1 if only a single value is given for an index. - """ - source = cast_tensor_value(source) - source_value = source.ir_value - rank = source.rank - if len(indices) != rank: - raise ValueError( - f"Slice indices must match the source rank. Got {len(indices)}, expected {rank}" - ) - # Unpack start_indices and lengths. - start_indices: List[BuildableIndexType] = [] - lengths: List[BuildableIndexType] = [] - for index_pack in indices: - if isinstance(index_pack, (tuple, list)): - if len(index_pack) == 2: - start_indices.append(index_pack[0]) - lengths.append(index_pack[1]) - continue - else: - start_indices.append(index_pack) - lengths.append(1) - continue - raise ValueError( - f"Slice indices expected to be a single value or a 2-tuple. Got {index_pack}" - ) - - # Process the lengths into a result shape and input length. - index_value_cache: Dict[int, Value] = {} - length_values: List[Value] = [] - result_shape: List[int] = [] - result_dynamic_dims: List[Value] = [] - for raw_length in lengths: - if isinstance(raw_length, int): - # Static. - result_shape.append(raw_length) - if raw_length in index_value_cache: - # Cached. - length_values.append(index_value_cache[raw_length]) - else: - # Not cached. - length_value = cast_index_value(raw_length) - index_value_cache[raw_length] = length_value - length_values.append(length_value) - else: - # Dynamic. - result_shape.append(ShapedTypeDynamicSizeSentinel) - length_value = cast_index_value(raw_length) - length_values.append(length_value) - result_dynamic_dims.append(length_value) - assert len(length_values) == rank - assert result_shape.count(ShapedTypeDynamicSizeSentinel) == len( - result_dynamic_dims - ) - - # Process start indices. - start_index_values = [cast_index_value(idx) for idx in start_indices] - # Emit. - result_type = RankedTensorType.get(result_shape, source.ir_type.element_type) - constant_cache: Dict[int, Value] = {} - result_value = flow_d.TensorSliceOp( - result_type, - source_value, - source.get_only_dynamic_dim_values(constant_cache=constant_cache), - start_index_values, - length_values, - result_dynamic_dims, - ).result - result = IrImmediateTensor(result_value, dtype=source.dtype) - result.set_dynamic_dim_values(result_dynamic_dims) - return result - - @emitter - def tensor_update( - self, - target: BuildableTensorType, - update: BuildableTensorType, - *start_indices: BuildableIndexType, - ) -> "IrTensor": - """Applies an update to a target at start_indices and returns the mutated target.""" - constant_cache: Dict[int, Value] = {} - target = cast_tensor_value(target) - target_dynamic_dims = target.get_only_dynamic_dim_values( - constant_cache=constant_cache - ) - update = cast_tensor_value(update) - update_dynamic_dims = update.get_only_dynamic_dim_values( - constant_cache=constant_cache - ) - start_index_dim_values = [ - cast_index_value(idx, constant_cache=constant_cache) - for idx in start_indices - ] - result_value = flow_d.TensorUpdateOp( - target.ir_value, - target_dynamic_dims, - start_index_dim_values, - update.ir_value, - update_dynamic_dims, - ).result - result = IrImmediateTensor(result_value, target.dtype) - result.set_dynamic_dim_values(target_dynamic_dims) - return result - - @emitter - def tensor_splat( - self, - *dims: BuildableTensorDimDecl, - value: BuildableScalarValue, - dtype: torch.dtype, - ) -> "IrTensor": - # TODO: Type infer the dtype if missing. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/125 - dim_decls, dyn_dim_values = cast_tensor_dim_decl(dims) - try: - element_type = TORCH_DTYPE_TO_IREE_TYPE[dtype]() - except KeyError: - raise ValueError(f"Could not map Torch dtype {dtype} to an IREE type") - value = cast_scalar_value(value) - if value.type != element_type: - value = cast_scalar_to_element_type(value, element_type) - tensor_type = RankedTensorType.get(dim_decls, element_type) - raw_tensor = flow_d.TensorSplatOp(tensor_type, value, dyn_dim_values).result - result = IrImmediateTensor(raw_tensor, dtype=dtype) - result.set_dynamic_dim_values(dyn_dim_values) - return result - - @emitter - def tensor_trace(self, key: str, *ts: BuildableTensorType): - dynamic_dims = [] - for t in ts: - dynamic_dims.extend(t.get_only_dynamic_dim_values()) - ts = tuple(cast_tensor_value(t).ir_value for t in ts) - flow_d.TensorTraceOp(StringAttr.get(key), ts, dynamic_dims) - - -# Circular imports to resolve typing. -from .primitives import ( - IrScalar, - IrTensor, -) diff --git a/core/shark_turbine/aot/support/procedural/primitives.py b/core/shark_turbine/aot/support/procedural/primitives.py deleted file mode 100644 index ad406c877..000000000 --- a/core/shark_turbine/aot/support/procedural/primitives.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Live types during runtime of a procedure trace. User code will -# operate on instances of these. - -from typing import ( - cast, - Dict, - List, - Optional, - Sequence, - Tuple, - Union, -) - -import torch - -from torch.export import ( - Constraint, - dynamic_dim, -) - -from ....support.ir_imports import ( - F32Type, - IrType, - RankedTensorType, - Value, - arith_d, -) - -from ..ir_utils import ( - build_tensor_dim_value, - _is_float_type, - _is_integer_like_type, - Empty, - EmptyType, -) - -from .base import ( - Intrinsic, - IrTrace, - ShapedTypeDynamicSizeSentinel, - current_ir_trace, -) - -############################################################################### -# Tensors and scalars -############################################################################### - - -class IrScalar(Intrinsic): - """An intrinsic that represents a scalar value. - - Subclasses are responsible for providing either value or load semantics. - """ - - __slots__ = [ - "ir_type", - ] - - def __init__(self, ir_type: IrType): - self.ir_type = ir_type - - def set(self, other): - t = current_ir_trace() - with t.ip, t.loc: - # Type check and promotion. - # TODO: Add more comprehensive type promotion hiearchy. - lhs = self.ir_value - rhs = None - if isinstance(other, IrScalar): - # Assumes when both are Value, they have same type. - rhs = other.ir_value - elif isinstance(other, (int, bool)) and _is_integer_like_type(self.ir_type): - rhs = arith_d.ConstantOp(lhs.type, other).result - elif isinstance(other, (float)) and _is_float_type(self.ir_type): - rhs = arith_d.ConstantOp(lhs.type, other).result - if rhs is None or lhs.type != rhs.type: - raise ValueError( - f"Cannot handle src type of {self.ir_type} to dst python type of {type(other)}." - ) - return IrImmediateScalar(rhs) - - def __add__(self, other): - t = current_ir_trace() - with t.ip, t.loc: - # Type check and promotion. - # TODO: Add more comprehensive type promotion hiearchy as seen in - # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html - # See: https://github.com/nod-ai/SHARK-Turbine/issues/132 - lhs = self.ir_value - if isinstance(other, IrScalar): - # Assumes when both are Value, they have same type. - rhs = other.ir_value - elif isinstance(other, (int, bool)): - rhs = arith_d.ConstantOp(lhs.type, other).result - elif isinstance(other, float) and _is_integer_like_type(self.ir_type): - lhs = arith_d.SIToFPOp(F32Type.get(), lhs).result - rhs = arith_d.ConstantOp(F32Type.get(), other).result - - # Checks that lhs and rhs has same type. - if lhs.type != rhs.type: - raise ValueError("Mismatch type between lhs and rhs.") - - # Emit computation. - if _is_integer_like_type(lhs.type): - return IrImmediateScalar(arith_d.AddIOp(lhs, rhs).result) - elif _is_float_type(lhs.type): - return IrImmediateScalar(arith_d.AddFOp(lhs, rhs).result) - else: - raise ValueError( - f"Expected operand to be either Int or Float but got {self.ir_type} instead." - ) - - -class IrImmediateScalar(IrScalar): - """Represents an IR scalar value.""" - - __slots__ = [ - "_ir_value", - ] - - def __init__(self, ir_value: Value): - super().__init__(ir_value.type) - assert isinstance(ir_value, Value) - self._ir_value = ir_value - - def resolve_ir_values(self, proc_trace: IrTrace) -> Sequence[Value]: - return (self._ir_value,) - - -class IrTensor(Intrinsic): - """An intrinsic that represents a tensor value. - - Carries additional metadata needed to resolve dimensions and original - PyTorch attributes. - """ - - __slots__ = [ - "ir_type", - "dtype", - "_cached_dim_values", - "_dynamic_dims", - "_shape", - "_meta_tensor", - "_meta_tensor_constraints", - ] - - def __init__(self, ir_type: IrType, dtype: torch.dtype): - assert isinstance(dtype, torch.dtype) - ranked_ir_type = RankedTensorType(ir_type) - self.ir_type = ranked_ir_type - self.dtype = dtype - # We always cache the meta tensor once asked for since it is used - # to anchor constraints. The constraints list is the same size as - # the rank and has a non-None dynamic_dim constraint for each - # dynamic dimension in the type. - self._meta_tensor: Optional[torch.Tensor] = None - self._meta_tensor_constraints: Optional[List[Constraint]] = None - - # Figure dynamic dims. - # _dynamic_dims is either Empty if static, or Value/None if dynamic. - self._shape = ranked_ir_type.shape - self._dynamic_dims: List[Union[EmptyType, Value, None]] = [ - None if d == ShapedTypeDynamicSizeSentinel else Empty for d in self._shape - ] - - # If we computed a dim, then stash it here for later use. - self._cached_dim_values: List[Optional[Value]] = [None] * len( - self._dynamic_dims - ) - - def dynamic_dim(self, i: int) -> Constraint: - """Access the dynamic_dim constraint for the i'th dimension.""" - mt, constraints = self._get_meta_tensor_constraints() - c = constraints[i] - if c is None: - raise TypeError( - f"Requested dynamic_dim constraint for dimension {i} of {self.ir_type} which is not dynamic" - ) - return c - - @property - def rank(self) -> int: - return len(self._shape) - - @property - def dynamic_dim_count(self) -> int: - return len(self._dynamic_dims) - self._dynamic_dims.count(Empty) - - def set_dim_value(self, index: int, value: Optional[Value]): - """Sets the value of a dynamic dim. - - Raises ValueError if the dimension is not dynamic. - """ - if self._dynamic_dims is Empty: - raise ValueError(f"Dimension {index} of {self} is not dynamic") - self._dynamic_dims[index] = value - - def set_dynamic_dim_values(self, values: Sequence[Value]): - """Sets all dynamic dim values.""" - dd = self._dynamic_dims - input_index = 0 - for pos in range(len(dd)): - if dd[pos] is Empty: - # Static - continue - assert input_index < len(values), "Mismatched static/dynamic dims" - assert isinstance(values[input_index], Value) - dd[pos] = values[input_index] - input_index += 1 - assert input_index == len(values), "Mismatched static/dynamic dims" - - def get_dim_value( - self, - index: int, - *, - constant_cache: Optional[Dict[int, Value]] = None, - resolved_ir_value: Optional[Value] = None, - ) -> Value: - """Gets a dimension as an Index value. - - Requires that an InsertionPoint and Location are on the context stack. - - This will cache the dim value, returning the cached value later if - requested. - """ - cached_dim = self._cached_dim_values[index] - if cached_dim: - return cached_dim - dynamic_dim = self._dynamic_dims[index] - if dynamic_dim is Empty or dynamic_dim is None: - if resolved_ir_value is None: - resolved_ir_value = self.ir_value - # Construct a static dimension. - # TODO: Add MLIR API support for creating an insertion point after - # an operation and use that to set the InsertionPoint to the - # earliest point. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/133 - dim_value = build_tensor_dim_value( - resolved_ir_value, index, constant_cache=constant_cache - ) - self._cached_dim_values[index] = dim_value - return dim_value - else: - # Dynamic dim is known. - return dynamic_dim - - def get_only_dynamic_dim_values( - self, - *, - constant_cache: Optional[Dict[int, Value]] = None, - resolved_ir_value: Optional[Value] = None, - ) -> List[Value]: - """Returns a list of *only* the dynamic dim Values.""" - values: List[Value] = [] - for i, sentinel in enumerate(self._dynamic_dims): - if sentinel is not Empty: - # Cache IR value so we don't materialize for each - # dynamic dim. - if resolved_ir_value is None: - resolved_ir_value = self.ir_value - values.append( - self.get_dim_value( - i, - constant_cache=constant_cache, - resolved_ir_value=resolved_ir_value, - ) - ) - return values - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - return NotImplemented - - def _get_meta_tensor_constraints(self) -> tuple[torch.Tensor, list[Constraint]]: - if self._meta_tensor is not None and self._meta_tensor_constraints is not None: - return self._meta_tensor, self._meta_tensor_constraints - - ir_tensor_type = self.ir_type - shape = ir_tensor_type.shape - # TODO: We shouldn't need to create a real tensor here, as Dynamo will - # immediately convert it to fake. However, it will also set up the shape - # environment and asserts that any fake tensor inputs are from its - # internal FakeMode. There should be a way but needs more investigation. - # TODO: This tensor needs a device that matches the model being exported. - # We just create these on the CPU because that is common. - # Note that in Dynamo's modeling of dynamic shapes, 0/1 are specialized and - # cannot be dynamic, and we must use a >= 2 dimension value to represent - # a dynamic quantity. We therefore adjust the shape in this way and - # add a dynamic_dim constraint. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/134 - extents = [2 if d < 0 else d for d in shape] - mt = self._meta_tensor = torch.empty(extents, dtype=self.dtype) - # Generate constraints that are aligned with any dynamic dimensions or None - # if static. - self._meta_tensor_constraints = constraints = [ - dynamic_dim(mt, i) if d < 0 else None for i, d in enumerate(shape) - ] - return mt, constraints - - def _to_meta_tensor(self) -> Tuple[torch.Tensor, List[Constraint]]: - """Converts to a fake Tensor that dynamo can handle.""" - mt, constraints = self._get_meta_tensor_constraints() - return mt, [c for c in constraints if c is not None] - - -class IrImmediateTensor(IrTensor): - """Represents a Value in the IR under construction during procedural tracing.""" - - __slots__ = [ - "_ir_value", - ] - - def __init__(self, ir_value: Value, dtype: torch.dtype): - super().__init__(ir_value.type, dtype) - self._ir_value = ir_value - - def __repr__(self): - return f"IrValueTensor(@{self.ir_value})" - - def resolve_ir_values(self, proc_trace: IrTrace) -> Sequence[Value]: - return (self._ir_value,) diff --git a/core/shark_turbine/aot/support/procedural/tracer.py b/core/shark_turbine/aot/support/procedural/tracer.py deleted file mode 100644 index 942d36d9e..000000000 --- a/core/shark_turbine/aot/support/procedural/tracer.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Concrete tracer for running buildable code. - -from typing import ( - Any, - Callable, - List, - Sequence, -) - -from torch.utils._pytree import ( - tree_flatten, - tree_unflatten, - treespec_dumps, -) - -from ....support.ir_imports import ( - Location, - StringAttr, - Value, - func_d, -) - -from ....support.logging import aot_logger as logger - -from ..ir_utils import ( - ModuleBuilder, -) - -from .base import ( - AbstractIntrinsic, - Intrinsic, - IrTrace, - ProcedureTraceError, - new_ir_trace_scope, -) - -from .globals import ( - LiveGlobalCollectionProxy, -) - -############################################################################### -# Concrete procedure building IrTracer. -############################################################################### - - -class ProcedureTrace(IrTrace): - """Captures execution of a Python func into IR.""" - - __slots__ = [ - "proxy_posargs", - "proxy_kwargs", - ] - - def __init__( - self, - *, - module_builder: ModuleBuilder, - func_op: func_d.FuncOp, - proxy_posargs, - proxy_kwargs, - ): - super().__init__(module_builder=module_builder, func_op=func_op) - self.proxy_posargs = proxy_posargs - self.proxy_kwargs = proxy_kwargs - - @staticmethod - def define_func( - module_builder: ModuleBuilder, - *, - symbol_name: str, - posargs: Sequence, - kwargs: dict, - loc: Location, - ) -> "ProcedureTrace": - # Unpack arguments. - arguments_flat, arguments_tree_def = tree_flatten((posargs, kwargs)) - argument_ir_types = [] - for arg in arguments_flat: - if not isinstance(arg, AbstractIntrinsic): - raise ProcedureTraceError(f"Expected a AbstractIntrinsic but got {arg}") - argument_ir_types.append(arg.get_ir_type(module_builder)) - - with loc: - _, func_op = module_builder.create_func_op(symbol_name, argument_ir_types) - - # Bind proxy arguments to an IR value. - ir_proxy_arguments_flat = [] - for ir_value, arg_proxy_type in zip( - func_op.body.blocks[0].arguments, arguments_flat - ): - ir_proxy_arguments_flat.append(arg_proxy_type.create_intrinsic(ir_value)) - - # Unflatten. - proxy_posargs, proxy_kwargs = tree_unflatten( - ir_proxy_arguments_flat, arguments_tree_def - ) - - # Metadata. - if arguments_flat: - func_op.attributes["torch.args_schema"] = StringAttr.get( - treespec_dumps(arguments_tree_def), context=module_builder.context - ) - - return ProcedureTrace( - module_builder=module_builder, - func_op=func_op, - proxy_posargs=proxy_posargs, - proxy_kwargs=proxy_kwargs, - ) - - def trace_py_func(self, py_f: Callable): - with new_ir_trace_scope(self) as t: - # TODO: Create IR proxies for python arguments. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/135 - return_py_value = _unproxy(py_f(*self.proxy_posargs, **self.proxy_kwargs)) - if return_py_value is None: - self.emit_return() - else: - flat_return_py_values, schema = tree_flatten(return_py_value) - flat_return_ir_values: List[Value] = [] - for py_value in flat_return_py_values: - flat_return_ir_values.extend(convert_py_value_to_ir(self, py_value)) - self.func_op.attributes["torch.return_schema"] = StringAttr.get( - treespec_dumps(schema), context=self.context - ) - self.emit_return(*flat_return_ir_values) - - def handle_call(self, target: Intrinsic, args, kwargs): - """Implements calls to jittable functions.""" - with self.loc, self.ip: - return target.resolve_call(self, *args, **kwargs) - - def handle_assignment(self, scope, target, updated_value): - logger.debug( - "ASSIGN %r.%r = %r", scope.__class__, target.__class__, updated_value - ) - self._recursive_assign(target, updated_value, set()) - - def _recursive_assign(self, target, source, encountered_set): - target = _unproxy(target) - source = _unproxy(source) - - # Check for cycles. - target_id = id(target) - if target_id in encountered_set: - raise TypeError(f"Cycle in tree assignment target") - encountered_set.add(target_id) - - # Leaves/terminals. - if isinstance(target, Intrinsic): - if not isinstance(source, Intrinsic): - raise TypeError( - f"Cannot assign mismatched leaf types in a tree: " - f"{target.__class__} vs {source.__class__}" - ) - leaf_values = source.resolve_ir_values(self) - target.resolve_assignment(self, leaf_values) - return - - # Zip across dicts. - if isinstance(target, dict): - if not isinstance(source, dict): - raise TypeError( - f"Mismatched dict assignment in a tree: {target.__class__} vs {source.__class__}" - ) - target_keys = target.keys() - source_keys = source.keys() - if target_keys != source_keys: - raise TypeError( - f"Mismatched dict keys in tree assignment: {target_keys} vs {source_keys}" - ) - for k in target_keys: - target_child = target[k] - source_child = source[k] - self._recursive_assign(target_child, source_child, encountered_set) - return - - # Zip across lists/tuples (we let them be used interchangeably at the source). - if isinstance(target, list): - if not isinstance(source, (list, tuple)): - if len(target) != len(source): - raise TypeError( - f"Mismatched sequence length in tree assignment: {len(target)} vs {len(source)}" - ) - for target_child, source_child in zip(target, source): - self._recursive_assign(target_child, source_child, encountered_set) - return - - raise TypeError( - f"Cannot recursively assign through a container of {target.__class__}" - ) - - -def convert_py_value_to_ir( - proc_trace: ProcedureTrace, py_value: Any -) -> Sequence[Value]: - """Given procedurally traced python values, type check and convert to IR.""" - if isinstance(py_value, Intrinsic): - return py_value.resolve_ir_values(proc_trace) - if isinstance(py_value, Value): - return [py_value] - raise TypeError( - f"Illegal type passed in procedural trace: {py_value.__class__} ({py_value})" - ) - - -def _unproxy(value): - if isinstance(value, LiveGlobalCollectionProxy): - return value._raw_collection - return value diff --git a/core/shark_turbine/aot/tensor_traits.py b/core/shark_turbine/aot/tensor_traits.py deleted file mode 100644 index bb7a52809..000000000 --- a/core/shark_turbine/aot/tensor_traits.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Optional -from dataclasses import dataclass - -import torch - - -__all__ = [ - "ExternalTensorTrait", -] - - -@dataclass -class ExternalTensorTrait: - """Represents a 'trait' that can be applied to a Tensor to signal that - it is to be loaded by name from an external archive at AOT execution time. - """ - - external_scope: str - external_name: str - - @staticmethod - def get(from_tensor: torch.Tensor) -> Optional["ExternalTensorTrait"]: - existing = getattr(from_tensor, "_turbine_external_tensor_trait", None) - if existing is None: - return None - assert isinstance(existing, ExternalTensorTrait) - return existing - - def set(self, to_tensor: torch.Tensor): - to_tensor._turbine_external_tensor_trait = self # type: ignore diff --git a/core/shark_turbine/dynamo/__init__.py b/core/shark_turbine/dynamo/__init__.py deleted file mode 100644 index b122f4621..000000000 --- a/core/shark_turbine/dynamo/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - - -from .tensor import ( - enable, - TurbineMode, - DeviceTensor, -) diff --git a/core/shark_turbine/dynamo/backends/cpu.py b/core/shark_turbine/dynamo/backends/cpu.py deleted file mode 100644 index 521d03808..000000000 --- a/core/shark_turbine/dynamo/backends/cpu.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import functools -import sys - -from ...runtime.device import ( - DeviceState, -) - -from ..executor import ( - SpecializedExecutable, -) - -from iree.compiler.api import ( - Invocation, - Session, - Source, - Output, -) - -from iree.compiler.ir import ( - Context, -) -from iree.compiler.passmanager import ( - PassManager, -) - -from iree.runtime import ( - VmModule, -) - -from iree.compiler.extras.fx_importer import FxImporter - -import torch -from torch._dynamo.backends.common import aot_autograd -from ..passes import turbine_cpu_pass_pipeline - -DEFAULT_COMPILER_FLAGS = ("--iree-input-type=torch",) - - -def _base_backend(gm: torch.fx.GraphModule, example_inputs): - # Set up the session, context and invocation. - # Note that we do this on one in-memory module in a few phases: - # 1. Build it from the FX graph. - # 2. Run torch MLIR passes to lower it to a suitable form for - # input. - # 3. Run IREE's main compiler. - # 4. Output to an mmap buffer. - session = Session() - session.set_flags(*DEFAULT_COMPILER_FLAGS) - session.set_flags("--iree-hal-target-backends=llvm-cpu") - context = session.context - importer = FxImporter(context=context) - module = importer.module - inv = session.invocation() - # TODO: Should capture diagnostics. - inv.enable_console_diagnostics() - inv.import_module(module.operation) - - # Apply decompositions. - gm = turbine_cpu_pass_pipeline(gm, example_inputs) - - # Import phase. - importer.import_graph_module(gm) - print(module, file=sys.stderr) - with context: - pm = PassManager.parse("builtin.module(torch-to-iree)") - pm.run(module.operation) - print(module, file=sys.stderr) - - # IREE compilation phase. - inv.execute() - - # Output phase. - output = Output.open_membuffer() - inv.output_vm_bytecode(output) - - # Set up for runtime. - device_state = _get_device_state() - # TODO: Switch to wrap_buffer once https://github.com/openxla/iree/issues/14926 - # is fixed. - # vmfb_module = VmModule.wrap_buffer( - # device_state.instance, - # output.map_memory(), - # destroy_callback=output.close, - # ) - vmfb_module = VmModule.copy_buffer( - device_state.instance, - output.map_memory(), - ) - output.close() - - return SpecializedExecutable(vmfb_module, device_state) - - -backend = aot_autograd(fw_compiler=_base_backend) - - -# IREE runtime globals. For the CPU right now, there is no device selection, -# so it is easy. -@functools.lru_cache(maxsize=None) -def _get_device_state() -> DeviceState: - return DeviceState(driver="local-task") diff --git a/core/shark_turbine/dynamo/decompositions.py b/core/shark_turbine/dynamo/decompositions.py deleted file mode 100644 index 84f630c23..000000000 --- a/core/shark_turbine/dynamo/decompositions.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Callable, Dict, List, Optional, Sequence, Union - -import contextlib -import threading - -import torch -from torch._decomp import get_decompositions, remove_decompositions - -DecompositionTable = Dict[torch._ops.OperatorBase, Callable] -DecompositionOpsList = Sequence[ - Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket] -] - -# Manages "scopes" for decompositions used. Each unique scope is an attribute on -# the _decomp_local. If the attribute is missing, then the default -# decompositions are used. The scope "aot" is used for all AOT cases. -_decomp_local = threading.local() - - -def _get_decomp_stack(scope: str) -> List[DecompositionTable]: - try: - return getattr(_decomp_local, scope) - except AttributeError: - stack: List[DecompositionTable] = [] - setattr(_decomp_local, scope, stack) - return stack - - -def _current(scope: str) -> DecompositionTable: - """Gets the current decomposition table (which may be the default).""" - stack = _get_decomp_stack(scope) - if stack: - return dict(stack[-1]) - else: - return dict(DEFAULT_DECOMPOSITION_TABLE) - - -@contextlib.contextmanager -def _extend_context_manager( - scope: str, - *, - from_current: bool = True, - add_ops: Optional[DecompositionOpsList] = None, - remove_ops: Optional[DecompositionOpsList] = None -): - table: DecompositionTable - if from_current: - table = dict(_current(scope)) - else: - table = {} - if add_ops: - table.update(get_decompositions(add_ops)) - if remove_ops: - remove_decompositions(table, remove_ops) # type: ignore - stack = _get_decomp_stack(scope) - stack.append(table) - try: - yield table - finally: - popped = stack.pop() - assert ( - popped is table - ), "contextmanager unbalanced: popped different that pushed" - - -def _get_default_decomposition_ops() -> DecompositionOpsList: - aten = torch.ops.aten - # default decompositions pulled from SHARK / torch._decomp - return [ - aten.embedding_dense_backward, - aten.native_layer_norm_backward, - aten.slice_backward, - aten.select_backward, - aten.norm.ScalarOpt_dim, - aten.native_group_norm, - aten.upsample_bilinear2d.vec, - aten.split.Tensor, - aten.split_with_sizes, - aten.native_layer_norm, - aten.masked_fill.Tensor, - aten.masked_fill.Scalar, - aten.t, - aten.addmm, - # decompositions that aid us in handling nn.BatchNorm2d - aten._native_batch_norm_legit_functional, - aten._native_batch_norm_legit_no_training, - aten._native_batch_norm_legit, - aten._native_batch_norm_legit.no_stats, - aten.squeeze.dims, - # decompositions for miscellaneous ops that are not handled in torch-mlir but have available decompositions - aten.soft_margin_loss, - aten.im2col, - aten._euclidean_dist, - aten.index_copy, - aten.index_copy_, - aten.grid_sampler_2d, - aten.log_sigmoid_forward, - aten.unsafe_split.Tensor, - aten.binary_cross_entropy, - aten.dot, - aten._adaptive_avg_pool2d, - aten._prelu_kernel, - aten.full, - aten._log_softmax, - aten.nll_loss_forward, - aten.nll_loss_backward, - aten._to_copy, - aten._log_softmax_backward_data, - aten.lift_fresh_copy.default, - aten._unsafe_index.Tensor, - aten.unbind.int, - ] - - -# Some older APIs still use an op list instead of a table. -DEFAULT_DECOMPOSITIONS: DecompositionOpsList = _get_default_decomposition_ops() - -# The table of default decompositions. -DEFAULT_DECOMPOSITION_TABLE: DecompositionTable = get_decompositions( - DEFAULT_DECOMPOSITIONS -) diff --git a/core/shark_turbine/dynamo/executor.py b/core/shark_turbine/dynamo/executor.py deleted file mode 100644 index 561de515a..000000000 --- a/core/shark_turbine/dynamo/executor.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import functools -import os -from typing import List, Optional, Sequence, Union -from dataclasses import dataclass -from iree.runtime import ( - asdevicearray, - create_hal_module, - HalBuffer, - HalBufferView, - HalFence, - HalElementType, - DeviceArray, - get_driver, - VmContext, - HalDevice, - HalDriver, - VmInstance, - VmModule, - VmVariantList, -) - -import torch -from torch import ( - from_numpy as torch_from_numpy, -) - -from ..runtime.device import Device, DeviceState - - -@functools.lru_cache(maxsize=None) -def get_vm_instance() -> VmInstance: - return VmInstance() - - -_ELEMENT_TYPE_TO_DTYPE = { - HalElementType.FLOAT_16: torch.float16, - HalElementType.BFLOAT_16: torch.bfloat16, - HalElementType.FLOAT_32: torch.float32, - HalElementType.FLOAT_64: torch.float64, - HalElementType.UINT_8: torch.uint8, - HalElementType.SINT_8: torch.int8, - HalElementType.SINT_16: torch.int16, - HalElementType.SINT_32: torch.int32, - HalElementType.SINT_64: torch.int64, - HalElementType.BOOL_8: torch.bool, - HalElementType.OPAQUE_8: torch.qint8, - HalElementType.OPAQUE_8: torch.quint8, - HalElementType.COMPLEX_64: torch.complex64, - HalElementType.COMPLEX_128: torch.complex128, -} - - -class SpecializedExecutable: - """A concrete executable that has been specialized in some way.""" - - __slots__ = [ - "device_state", - "entry_function", - "user_module", - "vm_context", - ] - - def __init__( - self, - user_module: VmModule, - device_state: DeviceState, - entry_name: str = "main", - ): - self.user_module = user_module - self.vm_context = VmContext( - device_state.instance, - ( - create_hal_module(device_state.instance, device_state.device), - user_module, - ), - ) - self.device_state = device_state - self.entry_function = self.user_module.lookup_function(entry_name) - - def __call__(self, *inputs): - arg_list = VmVariantList(len(inputs)) - ret_list = VmVariantList( - 1 - ) # TODO: Get the number of results from the descriptor. - - # Move inputs to the device and add to arguments. - self._inputs_to_device(inputs, arg_list) - # TODO: Append semaphores for async execution. - - # Invoke. - self.vm_context.invoke(self.entry_function, arg_list, ret_list) - return self._returns_to_user(ret_list) - - def _inputs_to_device(self, inputs: list, arg_list: VmVariantList): - # TODO: We are assuming the worst case here which is that we have unknown Torch - # tensors that we send to the CPU and make continguous. Ideally, we would have - # fast paths for our own backends and interop. - for input in inputs: - input_cpu = input.cpu().contiguous() - # Since this is already a fallback case, just use the numpy array interop. - # It isn't great, but meh... fallback case. - device_array = asdevicearray(self.device_state.device, input_cpu) - arg_list.push_ref(device_array._buffer_view) - - def _returns_to_user(self, ret_list: VmVariantList): - # TODO: This is also not good that we are moving back to the CPU like this. - # We should be returning a custom Tensor implementation which represents - # our device data and has synchronization hooks for accessing it. - device = self.device_state.device - num_returns = len(ret_list) - user_returns = [None] * num_returns - for i in range(num_returns): - device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i)) - device_array = DeviceArray(device, device_buffer_view) - host_array = device_array.to_host() - user_returns[i] = torch_from_numpy(host_array) # type: ignore - - return user_returns - - -@dataclass -class EagerExecResult: - buffer: HalBuffer - size: int - dtype: torch.dtype - signal: Optional[HalFence] = None - - -def _element_type_to_dtype(element_type) -> torch.dtype: - try: - return _ELEMENT_TYPE_TO_DTYPE[element_type] - except KeyError: - raise ValueError(f"Unable to map {element_type} to torch dtype.") - - -class EagerSpecializedExecutable: - """A concrete executable that has been specialized in some way.""" - - __slots__ = [ - "device_state", - "entry_function", - "user_module", - "vm_context", - ] - - def __init__( - self, - user_module: VmModule, - device_state: DeviceState, - entry_name: str = "main$async", - ): - self.user_module = user_module - self.vm_context = VmContext( - device_state.instance, - ( - create_hal_module(device_state.instance, device_state.device), - user_module, - ), - ) - self.device_state = device_state - self.entry_function = self.user_module.lookup_function(entry_name) - - def __call__(self, *inputs): - arg_list = VmVariantList(len(inputs)) - ret_list = VmVariantList( - 1 - ) # TODO: Get the number of results from the descriptor. - - # Initialize wait and signal fence if not async mode. - device = inputs[0]._storage.device - wait_fence, signal_fence = self._initialize_fences(device, inputs, arg_list) - - # Move inputs to the device and add to arguments. - self._inputs_to_device(inputs, arg_list, wait_fence, signal_fence) - - # Invoke. - self.vm_context.invoke(self.entry_function, arg_list, ret_list) - return self._returns_to_user(ret_list, signal_fence) - - def _inputs_to_device( - self, - inputs: list, - arg_list: VmVariantList, - wait_fence: HalFence = None, - signal_fence: HalFence = None, - ): - # TODO: We are assuming the worst case here which is that we have unknown Torch - # tensors that we send to the CPU and make continguous. Ideally, we would have - # fast paths for our own backends and interop. - for input in inputs: - arg_list.push_ref(input.buffer_view) - wait_fence.extend(input._storage.ready_fence) - - # Append fences into list. - arg_list.push_ref(wait_fence) - arg_list.push_ref(signal_fence) - - def _returns_to_user(self, ret_list: VmVariantList, signal: HalFence = None): - # TODO: This is also not good that we are moving back to the CPU like this. - # We should be returning a custom Tensor implementation which represents - # our device data and has synchronization hooks for accessing it. - device = self.device_state.device - num_returns = len(ret_list) - user_returns = [None] * num_returns - for i in range(num_returns): - device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i)) - dtype = _element_type_to_dtype(device_buffer_view.element_type) - size = torch.Size(device_buffer_view.shape) - device_buffer = device_buffer_view.get_buffer() - user_returns[i] = EagerExecResult(device_buffer, size, dtype, signal) # type: ignore - return user_returns - - def _initialize_fences(self, device: Device, inputs: list, arg_list: VmVariantList): - fence_capacity = device._fence_capacity - tx_semaphore = device._tx_timeline - current_tx_timepoint = device._tx_timepoint - - # Create wait semaphore and fence. - wait_semaphores = (tx_semaphore, current_tx_timepoint) - wait_fence = HalFence(fence_capacity) - wait_fence.insert(*wait_semaphores) - - # Create signal semaphore and fence. - device._tx_timepoint += 1 - signals_semaphore = (tx_semaphore, current_tx_timepoint + 1) - signal_fence = HalFence(fence_capacity) - signal_fence.insert(*signals_semaphore) - - # Add fences into arg_list for async exec. - return wait_fence, signal_fence diff --git a/core/shark_turbine/dynamo/passes.py b/core/shark_turbine/dynamo/passes.py deleted file mode 100644 index 23078a834..000000000 --- a/core/shark_turbine/dynamo/passes.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -from torch.fx.experimental.proxy_tensor import make_fx -from torch._decomp import get_decompositions -from torch.func import functionalize -from typing import List, Optional - -from .decompositions import DEFAULT_DECOMPOSITIONS - - -def apply_decompositions( - gm: torch.fx.GraphModule, - example_inputs, - decompose_ops: Optional[List[torch._ops.OpOverload]] = None, -): - if decompose_ops is None: - return gm - - decompositions = get_decompositions(decompose_ops) - gm = make_fx( - functionalize(gm), - decomposition_table=decompositions, - )(*example_inputs) - - return gm - - -def turbine_cpu_pass_pipeline(gm: torch.fx.GraphModule, example_inputs): - decompose_ops = DEFAULT_DECOMPOSITIONS - return apply_decompositions(gm, example_inputs, decompose_ops) # type: ignore diff --git a/core/shark_turbine/dynamo/tensor.py b/core/shark_turbine/dynamo/tensor.py deleted file mode 100644 index 85913953a..000000000 --- a/core/shark_turbine/dynamo/tensor.py +++ /dev/null @@ -1,658 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""A Turbine tensor. - -This implementation is adapted from a variety of sources, most notably the subclass -zoo: https://github.com/albanD/subclass_zoo/blob/main/new_device.py -""" - -from typing import Any, Optional, Sequence - -import functools -import atexit -import numpy as np -from types import BuiltinFunctionType - -import torch -import torch._dynamo as dynamo -from torch.overrides import TorchFunctionMode - -from ..runtime.device import ( - Device, - DeviceState, -) - -from ..support.conversions import ( - DTYPE_TO_ELEMENT_TYPE, - dtype_to_element_type, - torch_dtype_to_numpy, -) - -from .executor import EagerSpecializedExecutable - -from ..support import ( - ApiSequencingError, - UnknownDTypeError, -) - -from iree.runtime import ( - HalBuffer, - HalBufferView, - HalCommandBuffer, - HalElementType, - HalFence, - VmModule, -) - -from iree.compiler.api import Session, Output -from iree.compiler.passmanager import PassManager - -from iree.compiler.extras.fx_importer import FxImporter - -DEFAULT_COMPILER_FLAGS = ("--iree-input-type=torch",) - -############################################################################### -# Factories and device enablement -############################################################################### - - -class TurbineMode(TorchFunctionMode): - """Enables PyTorch tensor device= support for Tensor factory functions. - - This can be used in a `with` block to dynamically scope enablement, or - it can be enabled globally via the `enable()` function. - """ - - IMPLEMENTATIONS: dict = {} - CACHED_IMPLEMENTATIONS: dict = {} - COMPUTE_METHODS = set((torch.add, torch.sub, torch.mul, torch.abs)) - - def __torch_function__(self, func, types, args=(), kwargs=None): - def super_fn(*args, **kwargs): - # Disable torch_function by hand because we don't want the wrapping behavior of - # the super() impl - with torch._C.DisableTorchFunction(): - return func(*args, **kwargs) - - if func in self.IMPLEMENTATIONS: - if func in self.COMPUTE_METHODS: - args += (func,) - return self.IMPLEMENTATIONS[func](super_fn, *args, **kwargs or {}) - - # This is just a no-op for all the non-factory functions: - return super_fn(*args, **kwargs or {}) - - -def enable(): - """Enables PyTorch tensor device= support for Turbine permanently.""" - TurbineMode().__enter__() - Device("local-task").set() - atexit.register(disable) - - -def disable(): - Device.current().clear() - TurbineMode().__exit__(None, None, None) - - -# Convenient wrapper to register functions -def raw_factory(func): - """Decorator to register an unconditional factory function.""" - - def _inner_fn(impl): - TurbineMode.IMPLEMENTATIONS[func] = impl - return impl - - return _inner_fn - - -# Convenient wrapper to register functions -def compute_factory(func): - """Decorator to register an unconditional factory function.""" - - def _inner_fn(impl): - TurbineMode.IMPLEMENTATIONS[func] = impl - TurbineMode.COMPUTE_METHODS.add(func) - return impl - - return _inner_fn - - -def device_factory(func): - """Decorator to invoke the user provided factory for our devices. - - Wrap a function like this: - - @device_factory(torch.zeros) - def _zeros(*args, device: Device, **kwargs): - ... - """ - - def _inner_fn(impl): - def _filter_impl(super_fn, *args, **kwargs): - device: Optional[Device] = None - device_spec = kwargs.get("device", None) - if device_spec: - device = _parse_device(device_spec) - if device: - del kwargs["device"] - return impl(*args, device=device, **kwargs) - return super_fn(*args, **kwargs) - - TurbineMode.IMPLEMENTATIONS[func] = _filter_impl - - return _inner_fn - - -_TURBINE_PREFIX = "turbine-" - - -def _parse_device(device_arg) -> Optional[Device]: - if isinstance(device_arg, Device): - return device_arg - elif isinstance(device_arg, str): - if device_arg == "turbine": - return Device.current() - elif device_arg.startswith(_TURBINE_PREFIX): - return Device(device_arg[len(_TURBINE_PREFIX) :]) - return None - - -############################################################################### -# Turbine storage -############################################################################### - - -class Storage: - __slots__ = [ - "buffer", - "device", - "ready_fence", - ] - - def __init__(self, device: Device, buffer: HalBuffer): - fence_capacity = device._fence_capacity - self.buffer = buffer - self.device = device - # Signalled when the buffer is ready to be consumed. Consumers should - # join this fence and wait on it. It must be advanced when dependencies - # are queued. - self.ready_fence = HalFence(fence_capacity) - - def sync(self): - """Stops the world and waits for all scheduled mutations to complete.""" - self.ready_fence.wait() - - def execute_transfer(self, cb: HalCommandBuffer): - """Executes a transfer command buffer that has no external dependencies.""" - device = self.device - hal_device = device.hal_device - device._tx_timepoint += 1 - signal_sem = (device._tx_timeline, device._tx_timepoint) - hal_device.queue_execute( - [cb], wait_semaphores=self.ready_fence, signal_semaphores=[signal_sem] - ) - self.ready_fence.insert(*signal_sem) - - def kill(self): - """Kills the device memory associated with this storage.""" - if not self.buffer: - raise ApiSequencingError("Storage.kill() called on a non-live instance") - device = self.device - hal_device = device.hal_device - hal_device.queue_dealloca(self.buffer, self.ready_fence, []) - self.buffer = None - self.device = None - - def __del__(self): - if self.buffer: - self.kill() - - -############################################################################### -# Tensor class and support -############################################################################### - - -class DeviceTensor(torch.Tensor): - """A Tensor accessing memory on a Turbine device.""" - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs={}): - # Now, we check the function to determine how to handle it. If it's - # aten.add, then we call aten.sub. Otherwise, we pass through to - # the original function - args += (func,) - return compute_method(func, *args, **kwargs) - - def _to_meta_tensor(self): - return torch.empty(self.shape, dtype=self.dtype) - - @staticmethod - def __new__(cls, size, dtype, raw_data=None, requires_grad=False): - # Using a meta tensor as the wrapped gives us shape and dtype - # propagation. - return torch.Tensor._make_subclass( - cls, - torch.empty(size, dtype=dtype, device="meta"), - require_grad=requires_grad, - ) - - def __init__(self, size, dtype, raw_data=None, requires_grad=False): - if isinstance(raw_data, Storage): - self._storage = raw_data - self._bv = None - else: - if raw_data is not None: - raise NotImplementedError( - f"raw_data= not implemented for DeviceTensor ({raw_data.__class__})" - ) - - @staticmethod - def from_torch(input_tensor: torch.Tensor): - if isinstance(input_tensor, torch.Tensor): - dev_tensor = DeviceTensor._async_create_empty( - input_tensor.size(), Device("local-task"), input_tensor.dtype - ) - dev_tensor._async_copy_from_host(input_tensor.numpy()) - return dev_tensor - else: - if input_tensor is not None: - raise ValueError("Expected input to be of type torch.Tensor.") - - @property - def buffer_view(self) -> HalBufferView: - if self._bv is None: - self._bv = HalBufferView( - self._storage.buffer, - shape=self.size(), - element_type=dtype_to_element_type(self.dtype), - ) - return self._bv - - def cpu(self): - return self.to("cpu") - - @property - def device(self): - return self._storage.device - - def __repr__(self): - hal_device = self._storage.device.hal_device - try: - return f"" - except UnknownDTypeError: - return f"" - - @staticmethod - def _async_create_empty( - size: Sequence[int], device: Device, dtype: torch.dtype - ) -> "DeviceTensor": - """Creates an uninitialized tensor with a given size and dtype.""" - alloc_size = _calculate_c_contig_size(size, dtype) - hal_device = device.hal_device - # Async allocate a buffer, waiting for the device (tx_timeline, tx_timepoint) - # and signalling tx_timepoint + 1. Because we are just creating an empty - # (uninitialized) tensor, it is ready when allocation completes. - tx_semaphore = device._tx_timeline - current_tx_timepoint = device._tx_timepoint - wait_semaphores = [(tx_semaphore, current_tx_timepoint)] - alloca_complete_semaphore = (tx_semaphore, current_tx_timepoint + 1) - signal_semaphores = [alloca_complete_semaphore] - device._tx_timepoint += 1 - buffer = hal_device.queue_alloca(alloc_size, wait_semaphores, signal_semaphores) - storage = Storage(device, buffer) - storage.ready_fence.insert(*alloca_complete_semaphore) - return DeviceTensor(size, dtype, raw_data=storage) - - @staticmethod - def _from_buffer( - buffer: HalBuffer, - size: Sequence[int], - dtype: torch.dtype, - device: Device, - signal: HalFence, - ) -> "DeviceTensor": - """Creates an uninitialized tensor with a given size and dtype.""" - storage = Storage(device, buffer) - if signal is not None: - storage.ready_fence = signal - return DeviceTensor(size, dtype, raw_data=storage) - - def _async_fill_py_value(self, value): - """Fills a value in all elements of the tensor. - - The value is interpreted relative to the tensor's dtype and is suitable for integer - values like 0, 1, etc. Anything more complicated should use a lower-level API to - set up a fill pattern. - """ - storage = self._storage - hal_device = storage.device.hal_device - cb = HalCommandBuffer(hal_device) - pattern = _create_pattern_for_dtype(self.dtype, value) - cb.fill(storage.buffer, pattern, end=True) - storage.execute_transfer(cb) - - def _async_copy_from_host(self, host_data): - """Copies from arbitrary host data of unknown providence. - - Note that this is pretty much the worst way to get data onto the device as - the default path for many devices involves either host copies or expensive - device synchronization in order to setup memory mappings. However, as a - general purpose fallback, its utility cannot be denied. - """ - storage = self._storage - hal_device = storage.device.hal_device - staging_buffer = hal_device.allocator.allocate_host_staging_buffer_copy( - hal_device, host_data - ) - cb = HalCommandBuffer(hal_device) - cb.copy(staging_buffer, storage.buffer, end=True) - storage.execute_transfer(cb) - - -def _normalize_size(size_or_nested) -> Sequence[int]: - if len(size_or_nested) == 1 and not isinstance(size_or_nested[0], int): - return size_or_nested[0] - else: - return size_or_nested - - -def _calculate_c_contig_size(size: Sequence[int], dtype: torch.dtype) -> int: - """Calculates a C-contiguous buffer size in bytes for torch size and dtype.""" - accum = _DTYPE_TO_ELEMENT_SIZE[dtype] - for s in size: - accum *= s - return accum - - -# And some factory functions -# By hand -@raw_factory(torch.Tensor.to) -def to(super_fn, self, device, dtype=None, non_blocking=None): - # Note that we only implement a subset of .to() here - turbine_device = _parse_device(device) - if turbine_device: - # To turbine. - # For now, falling back to a copy via CPU. - new_t = DeviceTensor._async_create_empty( - self.size(), turbine_device, self.dtype - ) - new_t._async_copy_from_host(self.numpy()) - return new_t - elif isinstance(self, DeviceTensor): - # From turbine. - # TODO: We can handle certain catwalk cases from/to specific device classes - # before just falling back to transferring through the CPU. - # Stop the world and transfer to CPU. - storage = self._storage - storage.sync() - bv = self.buffer_view - dtype_descr = HalElementType.map_to_dtype(bv.element_type) - memory = storage.buffer.map() - np_array = memory.asarray(self.size(), dtype_descr) - return torch.from_numpy(np_array) - else: - return super_fn(self, device) - - -@raw_factory(torch._C._nn._parse_to) -def _parse_to(super_fn, *args, **kwargs): - if "turbine" in args: - # TODO: Parse through args and kwargs for correct params. - device = "turbine" - dtype = None - non_blocking = False - convert_to_format = None - return device, dtype, non_blocking, convert_to_format - else: - return super_fn(self, device) - - -@device_factory(torch.empty) -def _empty(*size, device: Device, dtype=torch.float32): - # Turbine empty. - norm_size = _normalize_size(size) - return DeviceTensor._async_create_empty(norm_size, device=device, dtype=dtype) - - -@device_factory(torch.zeros) -def _zeros(*size, device: Device, dtype=torch.float32): - t = DeviceTensor._async_create_empty(_normalize_size(size), device, dtype) - t._async_fill_py_value(0) - return t - - -@device_factory(torch.ones) -def _ones(*size, device: Device, dtype=torch.float32): - t = DeviceTensor._async_create_empty(_normalize_size(size), device, dtype) - t._async_fill_py_value(1) - return t - - -def cpu_tensor_constructor(cpu_func): - """For our devices, calls a user function which returns a CPU tensor. - - The returned CPU tensor will - The contents of the array will be copied to a new empty tensor. - While not terribly efficient, this can be used to fill in bulk-factory - functions that have not yet been optimized to run completely on device. - """ - - def inner(*args, device: Device, **kwargs): - cpu_t = cpu_func(*args, **kwargs) - dev_t = DeviceTensor._async_create_empty(cpu_t.size(), device, cpu_t.dtype) - dev_t._async_copy_from_host(cpu_t.numpy()) - return dev_t - - return inner - - -@device_factory(torch.arange) -@cpu_tensor_constructor -def _arange(*args, dtype=None): - if dtype is not None: - dtype = torch_dtype_to_numpy(dtype) - return torch.from_numpy(np.arange(*args, dtype=dtype)) - - -@device_factory(torch.rand) -@cpu_tensor_constructor -def _rand(*args, dtype=None): - t = torch.from_numpy(np.random.rand(*args)) - if dtype: - t = t.to(dtype) - return t - - -@functools.lru_cache(maxsize=None) -def _get_device_state() -> DeviceState: - return DeviceState(driver="local-task") - - -# Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/shark_turbine/aot/builtins/jittable.py#L212-L237 -# and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/dynamo/backends/cpu.py -# TODO: Try to generalize for other devices. -def compute_method(super_fn, *args, **kwargs): - # Compute factory fns reserve the last arg as src_op - # Requires src_op rather than super_fn, because super_fn - # is often wrapped by DisableTorchFunction. - init_py_args = args[:-1] - src_op = args[-1] - - any_turbine_tensor = False - devices_set = set() - arg_shape_dtype_encode = [] - py_args = [] - for arg_idx, py_arg in enumerate(init_py_args): - ret_val = py_arg - if isinstance(py_arg, DeviceTensor): - any_turbine_tensor = True - if isinstance(py_arg, (int, float)): - ret_val = DeviceTensor.from_torch(torch.tensor(py_arg)) - devices_set.add(ret_val.device) - arg_shape_dtype_encode.append(str(ret_val.shape) + str(ret_val.dtype)) - py_args.append(ret_val) - - # Check if turbine device exist. If doesn't run regular fn. - if not any_turbine_tensor: - super_fn(*py_args, **kwargs) - - # Do not support interop between Turbine and other devices. - if len(devices_set) > 1: - raise ValueError("Turbine do not support mixed device!") - cur_device = py_args[0].device - # Get a unique encoding to identify computation/dispatch using opCode, input shapes, and dtypes. - if isinstance(src_op, torch._ops.OpOverload): - src_op_name = src_op.name() - elif isinstance(src_op, BuiltinFunctionType): - src_op_name = src_op.__name__ - else: - raise ValueError("Expected srcOp to be torchOp or builtinFn.") - compute_id_encode = src_op_name + "".join(arg_shape_dtype_encode) - compute_hash = hash(compute_id_encode) - if compute_hash in TurbineMode.CACHED_IMPLEMENTATIONS: - # TODO: Handle multiple output. - exec_res = TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash](*py_args, **kwargs)[ - 0 - ] - res_buf = DeviceTensor._from_buffer( - exec_res.buffer, exec_res.size, exec_res.dtype, cur_device, exec_res.signal - ) - return res_buf - - # Preprocess func and generate into FX. - flat_pytorch_args = [py_arg._to_meta_tensor() for py_arg in py_args] - - # TODO: Replace all the below with torch.compile, although currently seems like - # the problem lies in it will try to generate DeviceTensor, but it would be missing - # _storage and causes error. - def func_src_op(*args, **kwargs): - return src_op(*args, **kwargs) - - exported_f = dynamo.export( - func_src_op, - aten_graph=True, - decomposition_table={}, - constraints={}, - assume_static_by_default=True, - ) - gm, guards = exported_f(*flat_pytorch_args) - - # Setup mlir compilation pipeline. - session = Session() - session.set_flags(*DEFAULT_COMPILER_FLAGS) - session.set_flags("--iree-hal-target-backends=llvm-cpu") - context = session.context - - # Generate MLIR from FX. - importer = FxImporter(context=context) - module = importer.module - inv = session.invocation() - # TODO: Should capture diagnostics. - inv.enable_console_diagnostics() - inv.import_module(module.operation) - importer.import_graph_module(gm) - - # Compile MLIR to vmfb. - inv.execute() - output = Output.open_membuffer() - inv.output_vm_bytecode(output) - - # Map VMFB to buffer. - device_state = _get_device_state() - vmfb_module = VmModule.wrap_buffer( - device_state.instance, - output.map_memory(), - destroy_callback=output.close, - ) - - # Load and execute VMFB file. - exec = EagerSpecializedExecutable(vmfb_module, device_state) - exec_results = exec(*py_args) - if len(exec_results) != 1: - raise ValueError("Currently only support one output for now.") - exec_res = exec_results[0] - - TurbineMode.CACHED_IMPLEMENTATIONS[compute_hash] = exec - - # Rewrap torch tensor into DeviceTensor and return. - # TODO: Handle multiple output. - dev_res = DeviceTensor._from_buffer( - exec_res.buffer, exec_res.size, exec_res.dtype, cur_device, exec_res.signal - ) - return dev_res - - -############################################################################### -# Conversions -############################################################################### - -_ELEMENT_TYPE_TO_NUMPY_DTYPE = { - HalElementType.FLOAT_16: np.float16, - HalElementType.FLOAT_32: np.float32, - HalElementType.FLOAT_64: np.float64, - HalElementType.UINT_8: np.uint8, - HalElementType.SINT_8: np.int8, - HalElementType.SINT_16: np.int16, - HalElementType.SINT_32: np.int32, - HalElementType.SINT_64: np.int64, - HalElementType.BOOL_8: np.bool_, - HalElementType.COMPLEX_64: np.complex64, - HalElementType.COMPLEX_128: np.complex128, -} - - -def _element_type_to_numpy_dtype(element_type: HalElementType) -> Any: - try: - return DTYPE_TO_ELEMENT_TYPE[element_type] - except KeyError: - raise UnknownDTypeError(element_type) - - -def _create_pattern_for_dtype(dtype: torch.dtype, x): - ctor = _simple_pattern_ctors.get(dtype, None) - if ctor: - return ctor(x) - else: - raise UnknownDTypeError(dtype) - - -_simple_pattern_ctors = { - torch.float16: lambda x: np.float16(float(x)), - torch.float32: lambda x: np.float32(float(x)), - torch.float64: lambda x: np.float64(float(x)), - torch.uint8: lambda x: np.uint8(int(x)), - torch.int8: lambda x: np.int8(int(x)), - torch.int16: lambda x: np.int16(int(x)), - torch.int32: lambda x: np.int32(int(x)), - torch.int64: lambda x: np.int64(int(x)), - torch.bool: lambda x: np.bool_(bool(x)), - torch.complex64: lambda x: np.complex64(complex(x)), - torch.complex128: lambda x: np.complex128(complex(x)), -} - - -# returns the torch datatype element size in bytes -_DTYPE_TO_ELEMENT_SIZE = { - torch.quint4x2: 1, - torch.uint8: 1, - torch.int8: 1, - torch.quint8: 1, - torch.qint8: 1, - torch.int16: 2, - torch.float16: 2, - torch.bfloat16: 2, - torch.int32: 4, - torch.qint32: 4, - torch.float32: 4, - torch.complex32: 4, - torch.int64: 8, - torch.float64: 8, - torch.complex64: 8, - torch.complex128: 16, -} diff --git a/core/shark_turbine/dynamo/type_conversion.py b/core/shark_turbine/dynamo/type_conversion.py deleted file mode 100644 index 8206e10f6..000000000 --- a/core/shark_turbine/dynamo/type_conversion.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Converters to/from torch types. - -Note that there are ad-hoc type conversions spread around a bit, and we -should consolidate them here. -""" -from typing import List, Optional - -import functools -import re - -from ..support.ir_imports import ( - tensor_d, - Context, - F64Type, - IntegerType, - RankedTensorType, - ShapedType, - IrType, - Location, - Operation, - Value, -) - - -# Match an overall torch type declaration. Groups: -# 1. Local name (int, float, vtensor) -# 2. Parameter block ("<...>"), including the delimitters -# 3. Inner parameter block (no delimitters) -DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch.([^<]+)(<([^>]*)>)?$") - -# Decomposes a vtensor parameter block into a dimension list and dtype. Groups: -# 1. Dimension list -# 2. Dtype -DECOMPOSE_TENSOR_PARAMS_PATTERN = re.compile(r"\[([^\]]*)\],([^,]+)$") - - -class NativeTypeConverter: - def __init__(self, context: Context): - self._context = context - # Cache per instance. - self.torch_type_to_native = functools.lru_cache(maxsize=None)( # type: ignore[method-assign] - self.torch_type_to_native - ) - - def torch_type_to_native(self, torch_type: IrType, signless: bool = True) -> IrType: - """Converts a presumed torch type to a corresponding native type. - - This mirrors the type conversion in torch-mlir's BackendTypeConversion.cpp. - - As an example: - !torch.int -> i64 - !torch.float -> f64 - !torch.bool -> i1 - !torch.vtensor -> tensor - - If `signless=False`, then integer types will retain their signs. - """ - # We don't presently have API support for introspecting torch type, - # and even if we did, it is likely that this is more efficient. - m = re.match(DECOMPOSE_TORCH_TYPE_PATTERN, str(torch_type)) - if m: - name, _, params_str = m.groups() - with self._context: - if name == "bool": - return IntegerType.get_signless(1) - if name == "int": - return ( - IntegerType.get_signless(64) - if signless - else IntegerType.get_signed(64) - ) - elif name == "float": - return F64Type.get() - elif name == "vtensor": - tm = re.match(DECOMPOSE_TENSOR_PARAMS_PATTERN, params_str) - assert tm, f"Could not parse !torch.vtensor params: {params_str}" - dim_list_str, dtype_str = tm.groups() - dim_list = parse_tensor_dim_list(dim_list_str) - dtype = self.convert_torch_element_type_to_native( - IrType.parse(dtype_str), signless=signless - ) - # TODO: Eliminate RankedTensorType dependence on Location. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/145 - with Location.unknown(): - return RankedTensorType.get(dim_list, dtype) - raise TypeError(f"Unsupported torch type conversion for {torch_type}") - - def convert_torch_element_type_to_native( - self, torch_type: IrType, signless: bool = True - ) -> IrType: - # Torch uses the builtin type hierarchy of IntegerType and FloatType - # to represent dtypes. These are mostly the same, but it always uses - # signed IntegerTypes which we must convert to signless for the native - # type system. - if signless: - if IntegerType.isinstance(torch_type): - signed_int_type = IntegerType(torch_type) - return IntegerType.get_signless(signed_int_type.width) - return torch_type - - def materialize_native_to_torch( - self, native_value: Value, torch_type: IrType, *, static_info_cast: bool = False - ) -> Value: - native_type = native_value.type - if RankedTensorType.isinstance(native_type): - # Convert to vtensor. - if static_info_cast: - required_native_type = self.torch_type_to_native(torch_type) - if required_native_type != native_type: - native_value = tensor_d.cast(required_native_type, native_value) - return Operation.create( - "torch_c.from_builtin_tensor", - results=[torch_type], - operands=[native_value], - ).result - elif IntegerType.isinstance(native_type): - # Convert to !torch.int - int_type = IntegerType(native_type) - width = int_type.width - if width == 1: - op_name = "torch_c.from_i1" - elif width == 64: - op_name = "torch_c.from_i64" - else: - raise TypeError( - f"Unsupported integer bit width for native->torch ABI: {int_type}" - ) - return Operation.create( - op_name, results=[torch_type], operands=[native_value] - ).result - elif F64Type.isinstance(native_type): - # Convert to !torch.float - return Operation.create( - "torch_c.from_f64", results=[torch_type], operands=[native_type] - ).result - else: - raise TypeError( - f"Unsupported native->torch ABI type conversion: {native_type} -> {torch_type}" - ) - - def materialize_torch_to_native( - self, torch_value: Value, *, static_info_cast_to: Optional[IrType] = None - ) -> Value: - native_type = self.torch_type_to_native(torch_value.type) - if RankedTensorType.isinstance(native_type): - # Convert to vtensor. - builtin_tensor_value = Operation.create( - "torch_c.to_builtin_tensor", - results=[native_type], - operands=[torch_value], - ).result - # Detect type difference and assume a static cast is needed. - if static_info_cast_to is not None and static_info_cast_to != native_type: - builtin_tensor_value = tensor_d.cast( - static_info_cast_to, builtin_tensor_value - ) - return builtin_tensor_value - elif IntegerType.isinstance(native_type): - # Convert to !torch.int - int_type = IntegerType(native_type) - width = int_type.width - if width == 1: - op_name = "torch_c.to_i1" - elif width == 64: - op_name = "torch_c.to_i64" - else: - raise TypeError( - f"Unsupported integer bit width for torch->native ABI: {int_type}" - ) - return Operation.create( - op_name, results=[native_type], operands=[torch_value] - ).result - elif F64Type.isinstance(native_type): - # Convert to !torch.float - return Operation.create( - "torch_c.to_f64", results=[native_type], operands=[torch_value] - ).result - else: - raise TypeError( - f"Unsupported torch->native ABI type conversion: {native_type} -> {native_type}" - ) - - -ShapedTypeDynamicSizeSentinel = ShapedType.get_dynamic_size() - - -def parse_tensor_dim_list(dim_list_str: str) -> List[int]: - if not dim_list_str: - return [] - comps = dim_list_str.split(",") - return [ShapedTypeDynamicSizeSentinel if d == "?" else int(d) for d in comps] diff --git a/core/shark_turbine/importers/README.md b/core/shark_turbine/importers/README.md deleted file mode 100644 index 9e47ca469..000000000 --- a/core/shark_turbine/importers/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Importers from various systems - -This directory is self-contained and intended to be shared with other -projects with its source-of-truth in torch-mlir. - -All MLIR API dependencies must route through the relative `ir.py`, which -it is expected that sub-projects will customize accordingly. diff --git a/core/shark_turbine/importers/ir.py b/core/shark_turbine/importers/ir.py deleted file mode 100644 index 4b4f94d74..000000000 --- a/core/shark_turbine/importers/ir.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from iree.compiler.ir import ( - ArrayAttr, - Attribute as Attribute, - Block, - Context, - DenseElementsAttr, - DenseResourceElementsAttr, - DictAttr, - FloatAttr, - BF16Type, - ComplexType, - F16Type, - F32Type, - F64Type, - Float8E4M3FNType, - Float8E5M2FNUZType, - Float8E5M2Type, - FunctionType, - InsertionPoint, - IntegerAttr, - IntegerType, - MLIRError, - RankedTensorType, - Location, - Module, - Operation, - StringAttr, - SymbolTable, - Type as IrType, - Value, -) - -from iree.compiler.dialects import ( - func as func_dialect, -) diff --git a/core/shark_turbine/importers/utils.py b/core/shark_turbine/importers/utils.py deleted file mode 100644 index a2047d544..000000000 --- a/core/shark_turbine/importers/utils.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Any, Dict, List, Tuple - -import weakref - - -class TypeSubclassMap: - """Mapping of super-types to values. - - Maintains a cache of actual types seen and uses that instead of a linear - scan. - """ - - __slots__ = [ - "_cache", - "_mapping", - ] - - def __init__(self): - # The linear list of converters. - self._mapping: List[Tuple[type, Any]] = [] - # When there is a hit on the linear mapping, memoize it here. - self._cache: Dict[type, Any] = {} - - def map(self, t: type, value: Any): - self._mapping.append((t, value)) - self._cache[t] = value - - def lookup(self, t: type) -> Any: - try: - return self._cache[t] - except KeyError: - pass - for t_super, value in self._mapping: - if issubclass(t, t_super): - self._cache[t] = value - return value - else: - self._cache[t] = None - return None diff --git a/core/shark_turbine/kernel/__init__.py b/core/shark_turbine/kernel/__init__.py deleted file mode 100644 index 333dce24f..000000000 --- a/core/shark_turbine/kernel/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from . import gen -from . import lang - - -# Helpers that are good to have in the global scope. -def __getattr__(name): - if name == "DEBUG": - return lang.is_debug() - raise AttributeError(f"module '{__name__}' has no attribute '{name}'") - - -# Dynamic attributes so that IDEs see them. -DEBUG: bool diff --git a/core/shark_turbine/kernel/_support/context.py b/core/shark_turbine/kernel/_support/context.py deleted file mode 100644 index 31aec520c..000000000 --- a/core/shark_turbine/kernel/_support/context.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Optional, Type, TypeVar - -import threading - -_tls = threading.local() - -T = TypeVar("T") - - -def push(context_type: Type[T], instance: T) -> T: - """Pushes an instance onto a thread-local context stack. - - The context type must define an attribute __tk_context_idname__ which is - a valid/unique identifier. - """ - assert isinstance(instance, context_type) - key = context_type.__tk_context_idname__ - try: - stack: list = getattr(_tls, key) - except AttributeError: - stack = [] - setattr(_tls, key, stack) - stack.append(instance) - return instance - - -def pop(context_type: Type[T], expected: Optional[T] = None): - """Pops the current context off of the stack. - - Raises IndexError if no current. - """ - stack: list = getattr(_tls, context_type.__tk_context_idname__) - instance = stack.pop() - assert ( - expected is None or expected is instance - ), f"mismatched context push/pop for {context_type}" - - -def current(context_type: Type[T]) -> T: - """Returns the current context from the stack. - - Raises IndexError on failure. - """ - try: - stack: list = getattr(_tls, context_type.__tk_context_idname__) - except AttributeError: - raise IndexError(f"No current context for {context_type}") - try: - instance = stack[-1] - except IndexError: - raise IndexError(f"No current context for {context_type}") - assert isinstance(instance, context_type) - return instance diff --git a/core/shark_turbine/kernel/_support/dtype.py b/core/shark_turbine/kernel/_support/dtype.py deleted file mode 100644 index 2b56012ef..000000000 --- a/core/shark_turbine/kernel/_support/dtype.py +++ /dev/null @@ -1,59 +0,0 @@ -__all__ = [ - "DataType", - "bool", - "i4", - "i8", - "i16", - "i32", - "i64", - "f16", - "f32", - "f64", - "index", -] - -_INT_TYPES = ["i1", "i4", "i8", "i16", "i32", "i64"] -_FLOAT_TYPES = ["f16", "f32", "f64"] -_INDEX_TYPES = ["index"] - - -# TODO: this should really be a type. -class DataType: - _name: str - _ir_type_asm: str - - def __init__(self, name, ir_type_asm=None): - self._name = name - self._ir_type_asm = ir_type_asm if ir_type_asm else name - - def ir_type_asm(self): - return self._ir_type_asm - - def __str__(self): - return self._name - - def __repr__(self): - return f"DataType({self._ir_type_asm})" - - def is_int_asm(self): - return self._name in _INT_TYPES - - def is_float_asm(self): - return self._name in _FLOAT_TYPES - - def is_index_asm(self): - return self._name in _INDEX_TYPES - - -bool = DataType("bool", "i1") -i4 = DataType("i4") -i8 = DataType("i8") -i16 = DataType("i16") -i32 = DataType("i32") -i64 = DataType("i64") -f32 = DataType("f32") -f64 = DataType("f64") -f16 = DataType("f16") -f32 = DataType("f32") -f64 = DataType("f64") -index = DataType("index") diff --git a/core/shark_turbine/kernel/_support/indexing.py b/core/shark_turbine/kernel/_support/indexing.py deleted file mode 100644 index 88e8584f5..000000000 --- a/core/shark_turbine/kernel/_support/indexing.py +++ /dev/null @@ -1,387 +0,0 @@ -from typing import Any, ClassVar, Optional, Type, TypeVar, Union - -from abc import ABC -from dataclasses import dataclass - -import sympy - -from . import context -from . import dtype -from .shaped_type import ShapedType, ShapedDataType - -__all__ = [ - "backed_sym_index_type", - "sym", - "BoundedRelation", - "EqualRelation", - "IndexingContext", - "IndexRelation", - "IndexExpr", - "IndexSymbol", - "SymIndex", -] - -DataType = dtype.DataType -DefaultDataType = dtype.f32 - - -class NotSetType: - ... - - -NotSet = NotSetType() - -SubtypeT = TypeVar("SubtypeT") - -############################################################################### -# Index symbols and expressions -# These are just light-weight helpers around sympy symbols and expressions. -############################################################################### - -IndexSymbol = sympy.core.Symbol -IndexExpr = sympy.core.Expr - - -def index_symbol(name: str) -> IndexSymbol: - """Returns a named symbol, assumed to be a non-negative integer.""" - return sympy.Symbol(name, integer=True, nonnegative=True) - - -def index_expr(value: Any) -> IndexExpr: - expr = sympy.sympify(value) - return expr - - -class _IndexSymbolExpando: - def __getattr__(self, n): - return index_symbol(n) - - -sym = _IndexSymbolExpando() - -############################################################################### -# Shape expressions -############################################################################### - -SymbolicDimable = Union[str, IndexExpr] -SymbolicShapeable = tuple[SymbolicDimable] -SymbolicShapeExpr = tuple[IndexExpr] -Dims = list[Union[None, IndexSymbol, int]] - -############################################################################### -# IndexingContext -############################################################################### - - -@dataclass(slots=True) -class _ShapedBinding: - # The instance of shaped_type. Can be anything. We resolve dimension values - # against this. - instance: Any - - # Shaped type that backes the instance. - shaped_type: ShapedType - - # The symbolic shape (tuple of index expressions). - symbolic_shape: list[IndexExpr] - - # Concrete dimensions instantiated with. Each is an integer or a dynamic - # dim symbol. It can also be None if the value is not dynamic and must be - # inferred from context. - dims: Dims - - -class IndexingContext: - """The indexing context is responsible handling the binding of indexed - symbols to concrete values. - """ - - __slots__ = [ - "subs", - "shaped_bindings", - "dyn_dims", - "frozen_subs", - "unbacked_symbols", - ] - - __tk_context_idname__ = "IndexingContext" - - def __init__(self): - self.subs: dict[IndexSymbol, int] = {} - # Indexed by .instance - self.shaped_bindings: dict[Any, _ShapedBinding] = {} - self.dyn_dims: list[IndexSymbol] = [] - self.frozen_subs: list[tuple[IndexSymbol, int]] = [] - self.unbacked_symbols: list[IndexSymbol] = [] - - def next_dyn_dim(self) -> IndexSymbol: - s = index_symbol(f"D{len(self.dyn_dims)}") - self.dyn_dims.append(s) - return s - - def new_unbacked_symbol(self) -> IndexSymbol: - s = index_symbol(f"_S{len(self.unbacked_symbols)}") - self.unbacked_symbols.append(s) - return s - - def bind_shaped(self, instance: Any, shaped_type: ShapedType, dims: Dims) -> None: - if instance in self.shaped_bindings: - raise ValueError(f"Argument binding {instance} is already bound") - symbolic_shape = shaped_type.symbolic_shape - rank = shaped_type.rank - if rank != len(dims): - raise ValueError( - f"For {shaped_type} mismatched symbolic shape vs dim arity: {symbolic_shape} vs {dims}" - ) - binding = _ShapedBinding( - instance, shaped_type, list(symbolic_shape), list(dims) - ) - self.shaped_bindings[instance] = binding - - def bind_constant(self, sym: IndexSymbol, value: int) -> None: - try: - self._bind_symbol(sym, value) - except ValueError: - raise ValueError( - f"Attempt to bind symbol {sym}={value} conflicts with previous " - f"{self.subs[sym]}" - ) - - def _bind_symbol(self, symbol: IndexSymbol, value: int): - existing = self.subs.get(symbol) - if existing is not None and existing != value: - raise ValueError - self.subs[symbol] = value - - def finalize(self): - assert len(self.frozen_subs) == 0 - # Go over everything we know and bind all free symbols. - for _sb in self.shaped_bindings.values(): - for i in range(_sb.shaped_type.rank): - dim_expr = _sb.symbolic_shape[i] - dim_value = _sb.dims[i] - if dim_value is not None: - if isinstance(dim_expr, IndexSymbol): - try: - self._bind_symbol(dim_expr, dim_value) - except ValueError as e: - raise ValueError( - f"For {_sb.instance} of {_sb.shaped_type} attempt to bind dim " - f"{dim_expr}={dim_value} conflicts with previous " - f"{self.subs[dim_expr]}" - ) - - # Note: At this point, we could solve the set of equation based - # bindings and maybe elicit some additional information, but for now - # we do forward-only inference. - frozen_subs = self.frozen_subs - frozen_subs.extend(self.subs.items()) - - # Check any equation based dims. - errors = [] - for _sb in self.shaped_bindings.values(): - for i in range(_sb.shaped_type.rank): - dim_expr = _sb.symbolic_shape[i] - dim_value = _sb.dims[i] - dim_expr = dim_expr.subs(frozen_subs).simplify() - _sb.symbolic_shape[i] = dim_expr - if dim_value is None: - # Ensure resolves to a known value. - if not isinstance(dim_expr, sympy.Integer): - errors.append( - f" {_sb.instance} of {_sb.shaped_type}[{i}]={dim_expr} did not " - f"resolve to a known value" - ) - continue - # Notate the inferred dim. - _sb.dims[i] = int(dim_expr) - elif isinstance(dim_expr, sympy.Integer): - dim_expr_value = int(dim_expr) - if isinstance(dim_value, IndexExpr): - # If dynamic, then it turns out we have enough static information, - # so replace. - _sb.dims[i] = dim_expr_value - else: - # If static, make sure it matches the runtime value. - if dim_value is not None and dim_expr_value != dim_value: - errors.append( - f" {_sb.instance} of {_sb.shaped_type}[{i}]={dim_expr} was initialized with a " - f"mismatched runtime value of {dim_value}" - ) - continue - - # Error check. - if errors: - joined = "\n".join(errors) - raise ValueError(f"Indexing mismatches were encountered:\n{joined}") - - def eval_dim(self, instance: Any, shaped_type: ShapedType, pos: int) -> IndexExpr: - # TODO: Could see if shaped_type is in self.shaped_bindings: it has some - # precomputed values that may save cycles to use. - symbolic_shape = shaped_type.symbolic_shape - try: - expr = symbolic_shape[pos] - except IndexError: - raise IndexError(f"Attempt to access out of range {shaped_type}[{pos}]") - return expr.subs(self.frozen_subs).simplify() - - def eval_static_dim( - self, instance: Any, shaped_type: ShapedType, pos: int - ) -> Optional[int]: - expr = self.eval_dim(instance, shaped_type, pos) - try: - return int(expr) - except TypeError: - return None - - def simplify_expr(self, expr: IndexExpr) -> IndexExpr: - return expr.subs(self.frozen_subs).simplify() - - def get_static_value(self, expr: IndexExpr) -> Optional[int]: - expr = self.simplify_expr(expr) - try: - return int(expr) - except TypeError: - return None - - ##### Context management. - @staticmethod - def current() -> "IndexingContext": - return context.current(IndexingContext) - - def __enter__(self) -> "IndexingContext": - return context.push(IndexingContext, self) - - def __exit__(self, exc_type, exc_val, exc_tb): - context.pop(IndexingContext, self) - - -############################################################################### -# Symbolic index value type. -# TODO: We think we want to remove this in the next rev, in favor of doing -# relationship verification as part of a pass. -############################################################################### - - -class IndexRelation(ABC): - """ABC for assumptions that can be made about an index value.""" - - __slots__ = [] - - -class EqualRelation(IndexRelation): - """An index assumption that can take a single symbolic value.""" - - __slots__ = ["eq_expr"] - - def __init__(self, eq_expr: IndexExpr): - self.eq_expr = eq_expr - - def __eq__(self, other): - if not isinstance(other, EqualRelation): - return False - return self.eq_expr == other.eq_expr - - def __repr__(self): - expr = self.eq_expr - if isinstance(expr, IndexSymbol): - return f"=={expr}" - else: - return f"==({expr})" - - -class BoundedRelation(IndexRelation): - """An index assumption that can take any value in a range.""" - - __slots__ = [ - "lower_expr", - "lower_inclusive", - "upper_expr", - "upper_inclusive", - ] - - def __init__( - self, - lower_expr: Any, - upper_expr: Any, - *, - lower_inclusive: bool = True, - upper_inclusive: bool = True, - ): - self.lower_expr = index_expr(lower_expr) - self.lower_inclusive = lower_inclusive - self.upper_expr = index_expr(upper_expr) - self.upper_inclusive = upper_inclusive - - def __eq__(self, other): - if not isinstance(other, BoundedRelation): - return False - return ( - self.lower_inclusive == other.lower_inclusive - and self.upper_inclusive == other.upper_inclusive - and self.lower_expr == other.lower_expr - and self.upper_expr == other.upper_expr - ) - - def __repr__(self): - return ( - f"∈{'[' if self.lower_inclusive else '('}" - f"{self.lower_expr}, {self.upper_expr}" - f"{']' if self.upper_inclusive else ')'}" - ) - - -class _SymIndexMeta(type): - """Meta-class for a concrete symbolic index value.""" - - def __new__( - mcls, - name: str, - bases, - dct, - *, - assumption: Optional[IndexRelation], - ): - new_class = type.__new__(mcls, name, bases, dct) - new_class.assumption = assumption - new_class.__qualname__ = repr(new_class) - return new_class - - def __repr__(self): - if self.assumption: - return f"SymIndex{self.assumption}" - else: - return "UnbackedSymIndex" - - -class SymIndex(metaclass=_SymIndexMeta, assumption=None): - """Symbolic index value defined for an assumption. - - The base type is unbacked (None assumption). - """ - - __slots__ = [ - "symbol", - ] - - assumption: ClassVar[Optional[IndexRelation]] - - def __init__(self, symbol: IndexSymbol): - self.symbol = symbol - - def __repr__(self): - return f"<'{self.symbol}' over {type(self)}>" - - def cast(self, cast: Type["SymIndex"]) -> "SymIndex": - """Cast the SymIndex to a new type, typically to further constrain it. - - The new instance shares the symbol. - """ - return cast(self.symbol) - - -def backed_sym_index_type(assumption: IndexRelation) -> Type[SymIndex]: - class BackedSymIndex(SymIndex, assumption=assumption): - ... - - return BackedSymIndex diff --git a/core/shark_turbine/kernel/_support/regions.py b/core/shark_turbine/kernel/_support/regions.py deleted file mode 100644 index d39b15a8c..000000000 --- a/core/shark_turbine/kernel/_support/regions.py +++ /dev/null @@ -1,160 +0,0 @@ -from typing import ( - Optional, - TypeVar, - Callable, - Type, - cast, - List, - Dict, - Tuple, -) -import random -import contextlib - -import torch.fx as fx -import torch.utils._pytree as pytree - - -class RegionGraph: - def __init__(self): - self.tracers: List["SubgraphTracer"] = [] - self.subgraphs: Dict[str, fx.Graph] = dict() - self.inner_freevars: Dict[fx.Graph, List[fx.Proxy]] = dict() - - @property - def root_tracer(self) -> "SubgraphTracer": - return self.tracers[0] - - @property - def current_tracer(self) -> "SubgraphTracer": - return self.tracers[-1] - - def create_proxy(self, *args, **kwargs): - return self.current_tracer.create_proxy(*args, **kwargs) - - def create_node(self, *args, **kwargs): - return self.current_tracer.create_node(*args, **kwargs) - - def create_arg(self, *args, **kwargs): - return self.current_tracer.create_arg(*args, **kwargs) - - def new_subtracer( - self, region_graph: "RegionGraph", parent: Optional["SubgraphTracer"] = None - ) -> "SubgraphTracer": - ... - - ### ======================================================================== - ### Subgraph Tracing - ### ======================================================================== - def add_subgraph( - self, name: str, graph: fx.Graph, inner_freevars: List[fx.Proxy] - ) -> str: - i = 0 - while True: - candidate_name = f"{name}_{i}" - i += 1 - if candidate_name not in self.subgraphs: - self.subgraphs[candidate_name] = graph - self.inner_freevars[graph] = inner_freevars - return candidate_name - - @contextlib.contextmanager - def subtracer(self): - if self.tracers: - new_tracer = self.new_subtracer(self, self.current_tracer) - else: - new_tracer = self.new_subtracer(self) - self.tracers.append(new_tracer) - yield new_tracer - self.tracers.pop() - - def __str__(self): - out = "" - for name, subgraph in self.subgraphs.items(): - out += f"{name}:" - out += str(subgraph) - out += "\n" - return out - - -class SubgraphTracer(fx.Tracer): - def __init__( - self, region_graph: RegionGraph, parent: Optional["SubgraphTracer"] = None - ): - super().__init__() - self.graph = fx.Graph() - self.region_graph = region_graph - self.parent = parent - self.lifted_freevars: Dict[fx.Proxy, fx.Proxy] = {} - - def trace(self, *args, **kwargs) -> Tuple[str, List[fx.Proxy]]: - traced = super().trace(*args, **kwargs) - inner_freevars = list(self.lifted_freevars.values()) - implicit_capture = list(self.lifted_freevars.keys()) - subgraph_name = self.region_graph.add_subgraph("region", traced, inner_freevars) - return subgraph_name, implicit_capture - - def _create_graph_input(self, name: str, type_expr=None) -> fx.Proxy: - proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr) - # Can use this to check where the freevar has been lifted from. - proxy.node.meta["lifted"] = None - return proxy - - def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy): - # It makes no sense for the root graph to have free variables - assert self.parent is not None, "Cannot lift freevars to input in root tracer" - - # If the freevar has already been lifted, return the lifted version. - if proxy in self.lifted_freevars: - return self.lifted_freevars[proxy] - - # Otherwise, create a new input and store it. - new_proxy = self._create_graph_input(proxy.node.name, proxy.node.type) - self.lifted_freevars[proxy] = new_proxy - - # Propagate freevar usage upwards. - if self.parent is not None and proxy.tracer != self.parent: - self.parent._lift_tracked_freevar_to_input(proxy) - return new_proxy - - def _maybe_lift_tracked_freevar_to_input(self, arg): - """ - If arg is a free variable, then lift it to be an input. - Returns the new lifted arg (if lifted), else the original arg. - """ - if not isinstance(arg, fx.Proxy): - return arg - elif arg.tracer == self: - return arg - else: - return self._lift_tracked_freevar_to_input(arg) - - def create_proxy( - self, - kind, - target, - args, - kwargs, - name=None, - type_expr=None, - proxy_factory_fn=None, - ): - if self.parent is not None: - flat_args, tree_spec = pytree.tree_flatten((args, kwargs)) - new_flat_args = [] - for arg in flat_args: - maybe_new_arg = self._maybe_lift_tracked_freevar_to_input(arg) - new_flat_args.append(maybe_new_arg) - args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec) - - rv = super().create_proxy( - kind, - target, - args, - kwargs, - name, - type_expr, - proxy_factory_fn, - ) - - return rv diff --git a/core/shark_turbine/kernel/_support/shaped_type.py b/core/shark_turbine/kernel/_support/shaped_type.py deleted file mode 100644 index f5bd25cf4..000000000 --- a/core/shark_turbine/kernel/_support/shaped_type.py +++ /dev/null @@ -1,136 +0,0 @@ -import typing -from typing import Optional, Type, TypeVar, cast - -from .dtype import DataType - -if typing.TYPE_CHECKING: - from .indexing import IndexExpr - -SymbolicShapeExpr = tuple["IndexExpr", ...] - -SubtypeT = TypeVar("SubtypeT") - -############################################################################### -# Shaped Type -############################################################################### - - -def _shaped_data_type_repr( - name: str, - *, - symbolic_shape: Optional[SymbolicShapeExpr], - dtype: Optional[DataType] = None, -) -> str: - stem = name - if symbolic_shape: - stem += f"[{', '.join(repr(s) for s in symbolic_shape)}]" - if dtype: - stem += f".of({dtype})" - return stem - - -class ShapedType(type): - """A shaped type. - - This lets us specialize with symbolic shape information. - """ - - symbolic_shape: Optional[SymbolicShapeExpr] = None - rank: Optional[int] - - def __new__(mcls, name: str, bases, dct): - symbolic_shape = dct.get("symbolic_shape") - if symbolic_shape is not None: - rank = len(symbolic_shape) - dct["rank"] = rank - - # TODO: I don't know a better way to do this. Ask Stella for better way. - if "__qualname__" not in dct: - dct["__qualname__"] = _shaped_data_type_repr( - name, symbolic_shape=symbolic_shape - ) - - new_class = type.__new__(mcls, name, bases, dct) - return new_class - - def new_shaped_subtype( - cls: Type[SubtypeT], - *, - symbolic_shape: SymbolicShapeExpr, - ) -> Type[SubtypeT]: - init_symbolic_shape = symbolic_shape - - class Subtype(cls): - symbolic_shape = init_symbolic_shape - rank = len(init_symbolic_shape) - - Subtype.__name__ = cls.__name__ - - return cast(Type[SubtypeT], Subtype) - - def __str__(cls): - return repr(cls) - - def __repr__(cls): - return _shaped_data_type_repr(cls.__name__, symbolic_shape=cls.symbolic_shape) - - -############################################################################### -# Shaped Data Type -############################################################################### - - -class ShapedDataType(ShapedType): - """A shaped type containing data of a specific element type. - - This lets us specialize with symbolic shape information. - """ - - dtype: Optional[DataType] = None - - def __new__( - mcls, - name: str, - bases, - dct, - ): - shaped_type = dct.get("shaped_type") - dtype = dct.get("dtype") - - if "__qualname__" not in dct: - dct["__qualname__"] = _shaped_data_type_repr( - name, - symbolic_shape=shaped_type, - dtype=dtype, - ) - - new_class = type.__new__(mcls, name, bases, dct) - return new_class - - def new_shaped_data_subtype( - cls: Type[SubtypeT], - *, - symbolic_shape: SymbolicShapeExpr, - dtype: DataType, - ) -> Type[SubtypeT]: - init_symbolic_shape = symbolic_shape - init_dtype = dtype - - class Subtype(cls): - symbolic_shape = init_symbolic_shape - rank = len(init_symbolic_shape) - dtype = init_dtype - - Subtype.__name__ = cls.__name__ - - return cast(Type[SubtypeT], Subtype) - - def __str__(cls): - return repr(cls) - - def __repr__(cls): - return _shaped_data_type_repr( - cls.__name__, - symbolic_shape=cls.symbolic_shape, - dtype=cls.dtype, - ) diff --git a/core/shark_turbine/kernel/_support/tracing.py b/core/shark_turbine/kernel/_support/tracing.py deleted file mode 100644 index b89818829..000000000 --- a/core/shark_turbine/kernel/_support/tracing.py +++ /dev/null @@ -1,498 +0,0 @@ -from abc import ABC, abstractmethod -from typing import ( - Optional, - TypeVar, - Callable, - Type, - cast, - Dict, - Tuple, -) - -from ..compiler.ir import Operation - -import functools -import warnings - -import torch.fx as fx - -from .indexing import ( - backed_sym_index_type, - BoundedRelation, - IndexExpr, - IndexSymbol, - IndexingContext, -) - -from ..lang.kernel_buffer import KernelBuffer -from ..lang.grid import Grid - -from ..lang.types import ( - Index, -) - -from .regions import RegionGraph, SubgraphTracer - -from .. import ops -from ..ops.base import ( - OpDispatcher, -) - -from . import context -from .dtype import DataType - -try: - from typing import assert_type -except ImportError: - # No-op if not supported. Introduced in Python 3.11. - def assert_type(a, b): - pass - - -TCallable = TypeVar("TCallable", bound=Callable) - -############################################################################### -# Kernel Region Graph -############################################################################### - - -class KernelRegionGraph(RegionGraph): - def new_subtracer( - self, - region_graph: "RegionGraph", - parent: Optional["SubgraphTracer"] = None, - ) -> "KernelTracer": - return KernelTracer(region_graph, parent=parent) - - -############################################################################### -# Tracing machinery -############################################################################### - - -class KernelBufferProxy(fx.Proxy): - """Custom proxy for KernelBuffer so that we can override special methods.""" - - def __init__( - self, - node: fx.Node, - tracer: "KernelTracer", - orig_type: Type[KernelBuffer], - ): - super().__init__(node, tracer) - self._orig_type = orig_type - # The shape and rank are statically available (not proxied). - self.symbolic_shape = orig_type.symbolic_shape - self.rank = orig_type.rank - - def __getitem__(self, key): - return ops.kernel_buffer_getitem(self, key) - - def __setitem__(self, key, item): - ops.kernel_buffer_setitem(self, key, item) - - -class KernelTracer(SubgraphTracer): - """Custom Tracer for generating a trace of a kernel computation.""" - - # Register our custom proxies. - def proxy(self, node: fx.Node) -> fx.Proxy: - t = node.type - if t is not None: - if issubclass(t, KernelBuffer): - return KernelBufferProxy(node, self, t) - return super().proxy(node) - - def create_arg(self, a): - # Let IndexExpr persist as arguments. - if isinstance(a, IndexExpr): - return a - # Let DataType persist as arguments. - if isinstance(a, DataType): - return a - return super().create_arg(a) - - -class CapturedTrace: - def __init__(self, region_graph: RegionGraph, root_graph: str): - self.region_graph = region_graph - self.root_graph = root_graph - - def get_subgraph(self, name: str) -> fx.Graph: - return self.region_graph.subgraphs[name] - - def get_root_graph(self) -> fx.Graph: - return self.get_subgraph(self.root_graph) - - -############################################################################### -# Execution context. -# A valid BaseContext derived instance (EagerContext or CompiledContext) must -# be active for any evaluation of a generated/traced function. -############################################################################### - - -class BaseContext(OpDispatcher): - __tk_context_idname__ = "ExecutionContext" - - def __init__(self, *, eager: bool): - self.eager = eager - - @staticmethod - def current() -> "BaseContext": - return context.current(BaseContext) - - def __enter__(self) -> "BaseContext": - context.push(OpDispatcher, self) - return context.push(BaseContext, self) - - def __exit__(self, exc_type, exc_val, exc_tb): - context.pop(OpDispatcher, self) - context.pop(BaseContext, self) - - -class EagerContext(BaseContext): - def __init__(self, rank: int = 0): - super().__init__(eager=True) - self.rank = rank - self.current_thread: list[int] = rank * [0] - - def handle_thread_program_id(self, op, axis: int) -> int: - assert axis >= 0 and axis < self.rank - return Index(self.current_thread[axis]) - - def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key): - return kernel_buffer._tensor.__getitem__(key) - - def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item): - kernel_buffer._tensor.__setitem__(key, item) - - -class CompiledContext(BaseContext): - def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]): - super().__init__(eager=False) - self.region_graph = region_graph - self.grid_type = grid_type - self.current_thread_types = [ - backed_sym_index_type(BoundedRelation(0, n, upper_inclusive=False)) - for n in grid_type.symbolic_shape - ] - - ### ======================================================================== - ### Core Operations - ### ======================================================================== - - def handle_thread_program_id(self, op, axis: int) -> Index: - grid_types = self.current_thread_types - if axis < 0 or axis >= len(grid_types): - raise IndexError( - f"Illegal index into grid of rank {len(grid_types)}: {axis}" - ) - - proxy = self.region_graph.create_proxy( - "call_function", - op, - args=(axis,), - kwargs={}, - type_expr=grid_types[axis], - ) - return proxy - - def handle_to_dtype(self, op, val, dtype): - return self.region_graph.create_proxy( - "call_function", - op, - args=(val, dtype), - kwargs={}, - ) - - def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key): - return self.region_graph.create_proxy( - "call_function", - op, - args=(kernel_buffer, key), - kwargs={}, - ) - - def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item): - self.region_graph.create_proxy( - "call_function", - target=op, - args=(kernel_buffer, key, item), - kwargs={}, - ) - - ### ======================================================================== - ### Memory Operations - ### ======================================================================== - def handle_kernel_buffer_load(self, op, kernel_buffer, multi_index, shape): - return self.region_graph.create_proxy( - "call_function", - target=op, - args=(kernel_buffer, multi_index, shape), - kwargs={}, - ) - - def handle_kernel_buffer_store(self, op, kernel_buffer, multi_index, item): - self.region_graph.create_proxy( - "call_function", - target=op, - args=(kernel_buffer, multi_index, item), - kwargs={}, - ) - - ### ======================================================================== - ### Control Flow Operations - ### ======================================================================== - - def handle_for_loop(self, op, start, stop=None, step=None, init_args=[]): - if stop is None: - stop = start - start = 0 - if step is None: - step = 1 - - def wrapper(f): - with self.region_graph.subtracer() as subtracer: - subgraph_name, implicit_capture = subtracer.trace(f) - # Create a call to this subgraph - ret = self.region_graph.create_proxy( - "call_function", - target=op, - name="for_loop", - args=(start, stop, step, init_args), - kwargs={ - "subgraph": subgraph_name, - "implicit_capture": implicit_capture, - }, - ) - return ret - - return wrapper - - ### ======================================================================== - ### Math Operations - ### ======================================================================== - def handle_exp2(self, op, val): - return self.region_graph.create_proxy( - "call_function", - target=op, - args=(val,), - kwargs={}, - ) - - def handle_vector_constant( - self, op, shape: Tuple[int, ...], dtype, value: int | float - ): - return self.region_graph.create_proxy( - "call_function", - target=op, - args=(shape, dtype, value), - kwargs={}, - ) - - ### ======================================================================== - ### Reduction Operations - ### ======================================================================== - def handle_vector_max(self, op, vector, axis=None, acc=None): - return self.region_graph.create_proxy( - "call_function", - target=op, - args=(vector, axis, acc), - kwargs={}, - ) - - def handle_vector_sum(self, op, vector, axis=None, acc=None): - return self.region_graph.create_proxy( - "call_function", - target=op, - args=(vector, axis, acc), - kwargs={}, - ) - - def handle_vector_dot(self, op, lhs, rhs, acc=None): - return self.region_graph.create_proxy( - "call_function", - target=op, - args=(lhs, rhs, acc), - kwargs={}, - ) - - ### ======================================================================== - ### Shape Manipulation Operations - ### ======================================================================== - def handle_vector_broadcast(self, op, vector, leading_sizes): - return self.region_graph.create_proxy( - "call_function", - target=op, - args=(vector, leading_sizes), - kwargs={}, - ) - - def handle_vector_broadcast_in_dim(self, op, vector, shape, broadcast_dimensions): - # Currently, we do not have a corressponding op in MLIR, so - # we trace this to broadcast + transpose. - # TODO: Add a vector dialect op for this in MLIR. - - # Remove broadcast_dimensions from shape. - shape_with_leading = tuple( - dim for i, dim in enumerate(shape) if i not in broadcast_dimensions - ) - - # Broadcast - broadcasted_vector = self.region_graph.create_proxy( - "call_function", - target=ops.vector_broadcast, - args=(vector, shape_with_leading), - kwargs={}, - ) - - # Get the permutation for the transpose. - permutation = tuple( - i for i in range(len(shape)) if i not in broadcast_dimensions - ) - permutation = permutation + tuple(broadcast_dimensions) - - # Transpose - return self.region_graph.create_proxy( - "call_function", - target=ops.vector_transpose, - args=(broadcasted_vector, permutation), - kwargs={}, - ) - - def handle_vector_transpose(self, op, vector, permutation): - return self.region_graph.create_proxy( - "call_function", - target=op, - args=(vector, permutation), - kwargs={}, - ) - - -############################################################################### -# Launch context -# The launch context controls how the call into a kernel is dispatched. -# This can either be to run it eagerly for debugging or some higher order -# integration. -############################################################################### - - -class Launchable(ABC): - """Base class for objects which behave like a kernel launch when called.""" - - def __init__(self, eager_function: Callable): - self._eager_function = eager_function - - def __call__(self, *args, **kwargs): - launch_context = LaunchContext.current() - return launch_context.launch(self, args, kwargs) - - @abstractmethod - def eager_execute(self, args, kwargs): - ... - - def aot_execute(self, args, kwargs): - ... - - def test_execute(self, args, kwargs): - ... - - -class LaunchContext(ABC): - __tk_context_idname__ = "ExecutionContext" - - def __init__(self, constant_bindings: Dict[IndexSymbol, int] = {}): - self.constant_bindings = constant_bindings - - @staticmethod - def current() -> "LaunchContext": - try: - return context.current(LaunchContext) - except IndexError: - warnings.warn( - "defaulting to debug/eager execution of tk kernel launch " - "because no launch context has been established" - ) - return DebugLaunchContext() - - def __enter__(self) -> "LaunchContext": - # Push an indexing context with the constand bindings for this launch - # context in it. - # TODO: Is creating a IndexingContext as part of LaunchContext the - # correct layering? - idxc = IndexingContext() - context.push(IndexingContext, idxc) - for s, val in self.constant_bindings.items(): - idxc.bind_constant(s, val) - return context.push(LaunchContext, self) - - def __exit__(self, exc_type, exc_val, exc_tb): - # Pop the indexing context created as part of this launch. - # TODO: Is creating a IndexingContext as part of LaunchContext the - # correct layering? - context.pop(IndexingContext, IndexingContext().current()) - context.pop(LaunchContext, self) - - @abstractmethod - def launch(self, launchable: Launchable, args, kwargs): - ... - - -class DebugLaunchContext(LaunchContext): - def launch(self, launchable: Launchable, args, kwargs): - return launchable.eager_execute(args, kwargs) - - -class TestLaunchContext(LaunchContext): - def launch(self, launchable: Launchable, args, kwargs): - return launchable.test_execute(args, kwargs) - - -class AOTLaunchContext(LaunchContext): - module: Operation - - def __init__( - self, module: Operation, constant_bindings: Dict[IndexSymbol, int] = {} - ): - self.module = module - super().__init__(constant_bindings) - - def launch(self, launchable: Launchable, args, kwargs): - return launchable.aot_execute(args, kwargs) - - -############################################################################### -# Helpers -############################################################################### - - -def eager_context() -> EagerContext: - context = BaseContext.current() - assert context.eager, "Expected to be executed against an EagerContext" - assert_type(context, EagerContext) - return context - - -def custom_primitive_fn( - f: Optional[TCallable] = None, *, compiled: Callable -) -> TCallable: - """Decorator for a primitive function with a custom callback for tracing. - - The wrapped function will be invoked as-is when executing eagerly. When - tracing, the `compiled` callback will be invoked with the same signature - but with the `CompiledContext` added as a first postional argument. - """ - if f is None: - return functools.partial(custom_primitive_fn, compiled=compiled) - - @functools.wraps(f) - def wrapper(*args, **kwargs): # type: ignore - context = BaseContext.current() - if context.eager: - return f(*args, **kwargs) - else: - assert_type(context, CompiledContext) - return compiled(context, *args, **kwargs) - - return cast(TCallable, wrapper) diff --git a/core/shark_turbine/kernel/compiler/base.py b/core/shark_turbine/kernel/compiler/base.py deleted file mode 100644 index 6af9ec02a..000000000 --- a/core/shark_turbine/kernel/compiler/base.py +++ /dev/null @@ -1,9 +0,0 @@ -NDEBUG = False - - -class CodegenError(Exception): - ... - - -class ValidationError(CodegenError): - ... diff --git a/core/shark_turbine/kernel/compiler/builder.py b/core/shark_turbine/kernel/compiler/builder.py deleted file mode 100644 index c7231250b..000000000 --- a/core/shark_turbine/kernel/compiler/builder.py +++ /dev/null @@ -1,310 +0,0 @@ -from typing import Any, Optional, Union - -from .._support.indexing import ( - IndexExpr, - SymIndex, -) - -from .base import ( - CodegenError, - NDEBUG, -) - -from .ir import ( - Attribute, - Context, - FloatAttr, - IndexType, - IntegerAttr, - IntegerType, - DenseElementsAttr, - IrType, - Location, - Operation, - SymbolTable, - Value, - VectorType, - arith_d, - math_d, - builtin_d, - F16Type, - F32Type, - F64Type, -) - -# TODO: Use FloatType from upstream when available. -FLOAT_BITWIDTHS = { - "bf16": 16, - "f16": 16, - "f32": 32, - "f64": 64, - # TODO: FP8 types. -} - - -class IRProxyValue: - """Wrapper around an (ir.Value, py_value) for handling notionally python - proxies that are associated with an IR Value. - """ - - __slots__ = [ - "ir_value", - "py_value", - ] - - def __init__(self, ir_value: Value, py_value: Any = None): - self.ir_value = ir_value - self.py_value = py_value - assert NDEBUG or self.validate() - - def validate(self): - assert isinstance(self.ir_value, Value), f"Got {type(self.ir_value)}" - return True - - def __repr__(self): - return f"" - - -class ModuleBuilder: - def __init__( - self, - *, - context: Optional[Context] = None, - module_op: Optional[Operation] = None, - ): - if module_op: - self.module_op = module_op - self.body_block = module_op.regions[0].blocks[0] - else: - if not context: - context = Context() - self.module_op = builtin_d.ModuleOp(loc=Location.unknown(context)) - self.body_block = self.module_op.body - self.context = self.module_op.context - self.unknown_loc = Location.unknown(self.context) - self.symbol_table = SymbolTable(self.module_op) - - -class _ScalarBuilder: - def is_floating_point_type(self, t: IrType) -> bool: - # TODO: Use FloatType from upstream when available. - return str(t) in FLOAT_BITWIDTHS - - def is_integer_type(self, t: IrType) -> bool: - return IntegerType.isinstance(t) - - def is_index_type(self, t: IrType) -> bool: - return IndexType.isinstance(t) - - def get_typeclass(self, t: IrType, index_same_as_integer=False) -> str: - # If this is a vector type, get the element type. - if isinstance(t, VectorType): - t = t.element_type - if self.is_floating_point_type(t): - return "float" - if self.is_integer_type(t): - return "integer" - if self.is_index_type(t): - return "integer" if index_same_as_integer else "index" - raise CodegenError(f"Unknown typeclass for type `{t}`") - - def get_float_bitwidth(self, t: IrType) -> int: - # If this is a vector type, get the element type. - if isinstance(t, VectorType): - t = t.element_type - return FLOAT_BITWIDTHS[str(t)] - - def to_dtype(self, value: IRProxyValue, dtype: IrType) -> IRProxyValue: - value_type = value.ir_value.type - # Create a vector type for dtype if value is a vector. - to_type = dtype - if isinstance(value_type, VectorType): - to_type = VectorType.get(value_type.shape, dtype) - - # Short-circuit if already the right type. - if value_type == to_type: - return value - - value_typeclass = self.get_typeclass(value_type) - to_typeclass = self.get_typeclass(dtype) - attr_name = f"to_dtype_{value_typeclass}_to_{to_typeclass}" - try: - handler = getattr(self, attr_name) - except AttributeError: - raise CodegenError( - f"No implemented path to implicitly promote scalar `{value_type}` to `{to_type}` (tried '{attr_name}')" - ) - return IRProxyValue(handler(value.ir_value, to_type)) - - def constant_attr(self, val: int | float, element_type: IrType) -> Attribute: - if self.is_integer_type(element_type) or self.is_index_type(element_type): - if not isinstance(val, int): - raise TypeError(f"Expected an integer value, got {val}") - return IntegerAttr.get(element_type, val) - - if self.is_floating_point_type(element_type): - if not isinstance(val, float): - raise TypeError(f"Expected a float value, got {val}") - return FloatAttr.get(element_type, val) - - raise CodegenError( - f"Cannot create a constant attribute for type `{element_type}`" - ) - - def zero_attr(self, t: IrType) -> Attribute: - if self.is_integer_type(t) or self.is_index_type(t): - return self.constant_attr(0, t) - if self.is_floating_point_type(t): - return self.constant_attr(0.0, t) - raise CodegenError(f"Cannot create a zero attribute for type `{t}`") - - def constant(self, py_value, element_type: IrType) -> IRProxyValue: - attr = self.constant_attr(py_value, element_type) - return IRProxyValue(arith_d.constant(element_type, attr)) - - def constant_vector(self, py_value, shape, element_type: IrType) -> IRProxyValue: - attr = self.constant_attr(py_value, element_type) - vector_type = VectorType.get(shape, element_type) - splat = DenseElementsAttr.get_splat(vector_type, attr) - return IRProxyValue(arith_d.constant(vector_type, splat)) - - def binary_arithmetic( - self, op: str, lhs: IRProxyValue, rhs: IRProxyValue - ) -> IRProxyValue: - lhs_ir_type = lhs.ir_value.type - rhs_ir_type = rhs.ir_value.type - - if lhs_ir_type != rhs_ir_type: - raise CodegenError( - f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir_type} and {rhs_ir_type} due to element type mismatch" - ) - - typeclass = self.get_typeclass(lhs_ir_type, True) - attr_name = f"binary_{op}_{typeclass}" - try: - handler = getattr(self, attr_name) - except AttributeError: - raise CodegenError( - f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir_type} and {rhs_ir_type} (tried '{attr_name}')" - ) - return handler(lhs, rhs) - - def binary_vector_arithmetic( - self, op: str, lhs: IRProxyValue, rhs: IRProxyValue - ) -> IRProxyValue: - lhs_ir = lhs.ir_value - rhs_ir = rhs.ir_value - lhs_element_type = VectorType(lhs_ir.type).element_type - rhs_element_type = VectorType(rhs_ir.type).element_type - - if lhs_element_type != rhs_element_type: - raise CodegenError( - f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir.type} and {rhs_ir.type} due to element type mismatch" - ) - - typeclass = self.get_typeclass(lhs_element_type, True) - attr_name = f"binary_{op}_{typeclass}" - try: - handler = getattr(self, attr_name) - except AttributeError: - raise CodegenError( - f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir.type} and {rhs_ir.type} (tried '{attr_name}')" - ) - return handler(lhs, rhs) - - def unary_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue: - val_ir_type = val.ir_value.type - typeclass = self.get_typeclass(val_ir_type, True) - attr_name = f"unary_{op}_{typeclass}" - try: - handler = getattr(self, attr_name) - except AttributeError: - raise CodegenError( - f"Cannot perform unary arithmetic operation '{op}' on {val_ir_type} (tried '{attr_name}')" - ) - return handler(val) - - def unary_vector_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue: - val_ir = val.ir_value - val_element_type = VectorType(val_ir.type).element_type - typeclass = self.get_typeclass(val_element_type, True) - attr_name = f"unary_{op}_{typeclass}" - try: - handler = getattr(self, attr_name) - except AttributeError: - raise CodegenError( - f"Cannot perform unary arithmetic operation '{op}' on {val_ir.type} (tried '{attr_name}')" - ) - return handler(val) - - ### Specializations - - # Casting - def to_dtype_index_to_integer(self, value: Value, to_type: IrType) -> Value: - return arith_d.index_cast(to_type, value) - - def to_dtype_index_to_float(self, value: Value, to_type: IrType) -> Value: - # Cast index to integer, and then ask for a integer to float cast. - # TODO: I don't really know how to query the machine bitwidth here, - # so using 64. - casted_to_int = arith_d.index_cast(IntegerType.get_signless(64), value) - return self.to_dtype(IRProxyValue(casted_to_int), to_type).ir_value - - def to_dtype_integer_to_float(self, value: Value, to_type: IrType) -> Value: - # sitofp - casted_to_float = arith_d.sitofp(to_type, value) - return self.to_dtype(IRProxyValue(casted_to_float), to_type).ir_value - - def to_dtype_float_to_float(self, value: Value, to_type: IrType) -> Value: - # Check bitwidth to determine if we need to extend or narrow - from_type = value.type - from_bitwidth = self.get_float_bitwidth(from_type) - to_bitwidth = self.get_float_bitwidth(to_type) - if from_bitwidth < to_bitwidth: - return arith_d.extf(to_type, value) - elif from_bitwidth > to_bitwidth: - return arith_d.truncf(to_type, value) - else: - raise CodegenError(f"NYI: Cast from {from_type} to {to_type}") - - # Binary integer/integer arithmetic. - def binary_add_integer(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: - return IRProxyValue(arith_d.addi(lhs.ir_value, rhs.ir_value)) - - def binary_mul_integer(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: - return IRProxyValue(arith_d.muli(lhs.ir_value, rhs.ir_value)) - - def binary_sub_integer(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: - return IRProxyValue(arith_d.subi(lhs.ir_value, rhs.ir_value)) - - def binary_mod_integer(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: - return IRProxyValue(arith_d.remsi(lhs.ir_value, rhs.ir_value)) - - def binary_floordiv_integer( - self, lhs: IRProxyValue, rhs: IRProxyValue - ) -> IRProxyValue: - return IRProxyValue(arith_d.floordivsi(lhs.ir_value, rhs.ir_value)) - - # Binary float arithmetic - def binary_add_float(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: - return IRProxyValue(arith_d.addf(lhs.ir_value, rhs.ir_value)) - - def binary_mul_float(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: - return IRProxyValue(arith_d.mulf(lhs.ir_value, rhs.ir_value)) - - def binary_sub_float(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: - return IRProxyValue(arith_d.subf(lhs.ir_value, rhs.ir_value)) - - def binary_mod_float(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue: - return IRProxyValue(arith_d.remf(lhs.ir_value, rhs.ir_value)) - - def binary_truediv_float( - self, lhs: IRProxyValue, rhs: IRProxyValue - ) -> IRProxyValue: - return IRProxyValue(arith_d.divf(lhs.ir_value, rhs.ir_value)) - - def unary_exp2_float(self, val: IRProxyValue) -> IRProxyValue: - return IRProxyValue(math_d.exp2(val.ir_value)) - - -ScalarBuilder = _ScalarBuilder() diff --git a/core/shark_turbine/kernel/compiler/dispatch_codegen.py b/core/shark_turbine/kernel/compiler/dispatch_codegen.py deleted file mode 100644 index 737ecd800..000000000 --- a/core/shark_turbine/kernel/compiler/dispatch_codegen.py +++ /dev/null @@ -1,252 +0,0 @@ -"""Code generation support for top-level IREE dispatch constructs. - -This assumes that you have some form of code generation for the -"inside" of some kernels, as this layer is responsible for -embedding and generating the calls/dispatches. -""" - -from typing import Any, Callable, Optional, Type - -from .._support.indexing import ( - IndexingContext, -) - -from .base import ( - CodegenError, - ValidationError, -) - -from .builder import ( - ModuleBuilder, -) - -from .ir import ( - Block, - FunctionType, - IndexType, - InsertionPoint, - IntegerAttr, - IrType, - Location, - Operation, - StringAttr, - Value, - arith_d, - flow_d, - func_d, - stream_d, -) - -from .kernel_codegen import ( - BindingDesc, - BindingType, - BoundKernelSignature, - KernelSignature, -) - -from ..lang.grid import Grid - - -class StreamExecutable: - """Encapsulates a 'stream' compilable executable which can be dispatched to. - - This corresponds to a `stream.executable`, consisting of one or more exported - dispatch functions. - """ - - __slots__ = [ - "_mb", - "_exe_op", - "_exe_block", - "_loc", - "sym_name", - "def_module", - ] - - def __init__( - self, - mb: ModuleBuilder, - *, - loc: Optional[Location] = None, - name: str = "__executable", - ): - self._mb = mb - if not loc: - loc = mb.unknown_loc - self._loc = loc - - # Construct the executable. - with loc: - with InsertionPoint(mb.body_block): - self._exe_op = exe_op = stream_d.ExecutableOp( - name, sym_visibility="private" - ) - exe_block = exe_op.body.blocks.append() - self._exe_block: Block = exe_block - stream_d.ExecutableEndOp(ip=InsertionPoint(exe_block)) - mb.symbol_table.insert(exe_op) - self.sym_name: StringAttr = exe_op.sym_name - - # Construct the inner definitions module. - with InsertionPoint.at_block_begin(exe_block): - self.def_module = ModuleBuilder(context=mb.context) - - def define_entrypoint( - self, - name: str, - sig: KernelSignature, - grid: Grid, - ) -> "DispatchEntrypoint": - """Defines a dispatch function with a signature like: - - ``` - func.func @name(%in0 : !stream.binding, %in1 : !stream.binding, - %workload0 : index, %workload1 : index, - %result0 : !stream.binding, %result1 : !stream.binding) - ``` - - Also adds an export with workgroup function like: - - ``` - stream.executable.export private @name(%workload0 : index, %workload1 : index) -> (index, [[grid_arity...]]) { - - } - ``` - - The given name is not uniqued (must be unique as given by the caller). - """ - kb_input_bindings = sig.kernel_buffer_input_bindings - kb_temp_bindings = sig.kernel_buffer_temporary_bindings - kb_output_bindings = sig.kernel_buffer_output_bindings - # TODO: The way we are doing grid bindings is wrong. The Grid type - # should be paramerized with special grid axis symbols which are - # algebraically related to concrete shape dim symbols. For now, we are - # just assuming that the grid dims can be resolved to constants , when - # in reality, we should pass the workload and parameterize the grid - # dims on the workloads. - workload_axis_bindings = [] - - # Input bindings are always user specified. - # Grid/workgroup bindings are in the inputs section but are implied. - # Temp bindings are a special kind of output bindings. - # Output bindings are the real outputs. - linear_bindings = ( - kb_input_bindings - + workload_axis_bindings - + kb_temp_bindings - + kb_output_bindings - ) - - # TODO: This is sloppy. This assert will hit on some user errors for - # unsupported type combinations and is just a last resort right now. - # TODO: This is currently disabled because the grid_bindings don't match - # workload bindings. - # assert len(linear_bindings) == len( - # sig.bindings - # ), f"Not all bindings converted: {linear_bindings} vs {sig.bindings}" - - with self._loc: - binding_type = IrType.parse("!stream.binding") - index_type = IndexType.get() - - # Define the dispatch function. - def abi_type(binding: BindingDesc): - if binding.binding_type == BindingType.KERNEL_BUFFER: - return binding_type - return binding.as_mlir_type() - - def_ftype = FunctionType.get( - [abi_type(b) for b in linear_bindings], - [], - ) - with InsertionPoint(self.def_module.body_block): - def_func_op = func_d.FuncOp(name, def_ftype) - def_func_block = def_func_op.add_entry_block() - def_func_args = list(def_func_block.arguments) - - # Define the export. - with InsertionPoint.at_block_begin(self._exe_block): - export_op = stream_d.ExecutableExportOp(name, name) - export_block = export_op.workgroup_count.blocks.append( - *([b.as_mlir_type() for b in workload_axis_bindings]) - ) - - workgroup_builder = WorkgroupBuilder( - export_block, lambda vs: stream_d.ReturnOp(vs) - ) - - # TODO: Support passing workload to the dispatch function. - with InsertionPoint(workgroup_builder.entry_block): - result_type = IndexType.get() - workgroup_values = [ - arith_d.constant(result_type, IntegerAttr.get(result_type, dim)) - for dim in grid.dims - ] - - while len(workgroup_values) < 3: - workgroup_values.append( - arith_d.constant(result_type, IntegerAttr.get(result_type, 1)) - ) - workgroup_builder.terminate(workgroup_values) - - return DispatchEntrypoint(sig, def_func_block, linear_bindings) - - -class WorkgroupBuilder: - """Builder for a workgroup calculation block.""" - - __slots__ = [ - "entry_block", - "workload", - "_term_ctor", - ] - - def __init__(self, entry_block: Block, term_ctor: Callable[[list[Value]], None]): - self.entry_block = entry_block - self.workload = list(entry_block.arguments) - self._term_ctor = term_ctor - - @property - def location(self) -> Location: - return self.entry_block.owner.location - - def terminate(self, returns: list[Value]): - entry_block = self.entry_block - with entry_block.owner.location, InsertionPoint(entry_block): - self._term_ctor(returns) - - -class DispatchEntrypoint(BoundKernelSignature): - def __init__( - self, - sig: KernelSignature, - entry_block: Block, - linear_bindings: list[BindingDesc], - ): - super().__init__(sig, entry_block) - self._abi_value_by_reference: dict[tuple[str, Any], Value] = { - b.reference: value - for value, b in zip(entry_block.arguments, linear_bindings) - } - - def resolve(self, binding: BindingDesc) -> Value: - ref_type, ref_value = binding.reference - if ref_type == "grid": - return stream_d.dispatch_workgroup_id( - IntegerAttr.get(IndexType.get(), ref_value) - ) - - if binding.binding_type == BindingType.KERNEL_BUFFER: - # Issue a subspan to get into the memref domain. - result_type = IndexType.get() - zero_value = arith_d.constant(result_type, IntegerAttr.get(result_type, 0)) - linear_arg_value = self._abi_value_by_reference[binding.reference] - # TODO: Need to also look up dynamic symbol values. - return stream_d.binding_subspan( - binding.as_mlir_type(), - linear_arg_value, - byte_offset=zero_value, - dynamic_dims=[], - ) - - raise ValidationError(f"Unhandled binding type: {binding}") diff --git a/core/shark_turbine/kernel/compiler/host_codegen.py b/core/shark_turbine/kernel/compiler/host_codegen.py deleted file mode 100644 index 9225d8314..000000000 --- a/core/shark_turbine/kernel/compiler/host_codegen.py +++ /dev/null @@ -1,58 +0,0 @@ -from .kernel_codegen import KernelSignature -from .dispatch_codegen import StreamExecutable - -from .builder import ( - ModuleBuilder, -) - -from .ir import ( - Block, - FunctionType, - InsertionPoint, - IrType, - Location, - ArrayAttr, - SymbolRefAttr, - MemRefType, - RankedTensorType, - flow_d, - func_d, -) - - -def memref_to_tensor(memrefs: list[IrType]): - tensors = [] - for m in memrefs: - assert isinstance(m, MemRefType) - t = RankedTensorType.get(m.shape, m.element_type) - tensors.append(t) - return tensors - - -def isolated_test_call( - mb: ModuleBuilder, exe: StreamExecutable, sig: KernelSignature, entrypoint: str -): - with InsertionPoint(mb.body_block), Location.unknown(): - input_types = [b.as_mlir_type() for b in sig.kernel_buffer_input_bindings] - input_tensors = memref_to_tensor(input_types) - output_types = [b.as_mlir_type() for b in sig.kernel_buffer_output_bindings] - output_tensors = memref_to_tensor(output_types) - - ftype = FunctionType.get(input_tensors, output_tensors) - func_op = func_d.FuncOp("isolated_benchmark", ftype) - arg_locs = [ - (Location.name(b.name) if b.name is not None else Location.unknown()) - for b in sig.kernel_buffer_input_bindings - ] - entry_block = func_op.add_entry_block(arg_locs) - with InsertionPoint(entry_block): - assert isinstance(entry_block, Block) - # Create a flow.dispatch op to the kernel - dispatch = SymbolRefAttr.get([exe.sym_name.value, entrypoint]) - entrypoints = ArrayAttr.get([dispatch]) - - out = flow_d.DispatchOp( - output_tensors, [], entrypoints, entry_block.arguments, [], [] - ) - - func_d.ReturnOp(out) diff --git a/core/shark_turbine/kernel/compiler/ir.py b/core/shark_turbine/kernel/compiler/ir.py deleted file mode 100644 index 560b85cd2..000000000 --- a/core/shark_turbine/kernel/compiler/ir.py +++ /dev/null @@ -1,43 +0,0 @@ -from iree.compiler.ir import ( - AffineConstantExpr, - AffineExpr, - AffineMap, - FlatSymbolRefAttr, - SymbolRefAttr, - AffineMapAttr, - Attribute, - RankedTensorType, - ArrayAttr, - Block, - Context, - DenseElementsAttr, - F16Type, - F32Type, - F64Type, - FloatAttr, - FunctionType, - IndexType, - InsertionPoint, - IntegerAttr, - IntegerType, - Location, - Operation, - MemRefType, - ShapedType, - StringAttr, - SymbolTable, - Type as IrType, - Value, - VectorType, -) - -from iree.compiler.dialects import ( - arith as arith_d, - builtin as builtin_d, - flow as flow_d, - func as func_d, - math as math_d, - stream as stream_d, - vector as vector_d, - scf as scf_d, -) diff --git a/core/shark_turbine/kernel/compiler/kernel_codegen.py b/core/shark_turbine/kernel/compiler/kernel_codegen.py deleted file mode 100644 index 5a6805dfa..000000000 --- a/core/shark_turbine/kernel/compiler/kernel_codegen.py +++ /dev/null @@ -1,264 +0,0 @@ -"""Code generation support for kernel entry-points. - -In a typical code generation stack, there are three elements: - -1. Dispatch code-generation: Embeds executables into some overall - program and coordinates launches. -2. Kernel code-generation: Handles device-side kernel signatures - and global marshalling physical kernel inputs to logical - kernel inputs and grid functions. -3. Low-level code-generation: Generates DMAs and compute operations - based on a logical program. - -This level handles #2. -""" - -from typing import Any, Optional, Type - -from abc import ABC, abstractmethod -from enum import Enum -from dataclasses import dataclass - -import torch.fx as fx - -from .._support.indexing import ( - IndexingContext, - IndexSymbol, -) - -from ..lang.kernel_buffer import ( - KernelBuffer, - KernelBufferUsage, - is_kernel_buffer_meta_derived, -) -from ..lang.grid import Grid - -from .base import ( - CodegenError, -) - -from .builder import ( - ModuleBuilder, -) - -from .ir import ( - Block, - FunctionType, - IndexType, - InsertionPoint, - IrType, - Location, - Operation, - Value, - func_d, -) - - -class BindingType(Enum): - KERNEL_BUFFER = 0 - INDEX_VALUE = 1 - SYMBOL_VALUE = 2 - - -@dataclass -class BindingDesc: - # The unique reference object that this is derived from. This will - # be different for each kind of argument: - # FX node placeholders: ('node', fx.Node) - # Grid indices: ('grid', n) - reference: tuple[str, Any] - - # Discrimnator type of this argument. - binding_type: BindingType - - # Debug name derived from the source, if available. - name: Optional[str] = None - - # If an INPUT_BUFFER, OUTPUT_BUFFER, or TEMPORARY_BUFFER, then this - # is the backing KernelBuffer type. - kernel_buffer_type: Optional[Type[KernelBuffer]] = None - - # If a SYMBOL_VALUE, then this is the corresponding IndexSymbol. - symbol_type: Optional[Type[IndexSymbol]] = None - - def as_mlir_type(self) -> IrType: - idx_context = IndexingContext.current() - - def sym_to_dim_asm(s: IndexSymbol) -> str: - static_value = idx_context.get_static_value(s) - return "?" if static_value is None else str(static_value) - - binding_type = self.binding_type - if binding_type == BindingType.KERNEL_BUFFER: - kb_t = self.kernel_buffer_type # type: KernelBuffer - element_type_asm = kb_t.dtype.ir_type_asm() - symbolic_shape = kb_t.symbolic_shape - if symbolic_shape is not None: - shape_asm = "x".join(sym_to_dim_asm(s) for s in kb_t.symbolic_shape) - spec_asm = f"{shape_asm}x{element_type_asm}" - else: - # Unranked. Not well supported, but for completeness. - spec_asm = element_type_asm - memref_asm = f"memref<{spec_asm}>" - return IrType.parse(memref_asm) - elif binding_type == BindingType.INDEX_VALUE: - return IndexType.get() - elif binding_type == BindingType.SYMBOL_VALUE: - return IndexType.get() - else: - raise AssertionError("Unhandled switch BindingType") - - -class KernelSignature: - def __init__(self): - self.bindings: list[BindingDesc] = [] - - @property - def grid_bindings(self) -> list[BindingDesc]: - """Gets all grid axis bindings.""" - return [b for b in self.bindings if b.reference[0] == "grid"] - - @property - def kernel_buffer_input_bindings(self) -> list[BindingDesc]: - """Gets all kernel buffer bindings with input usage.""" - return [ - b - for b in self.bindings - if b.binding_type == BindingType.KERNEL_BUFFER - and b.kernel_buffer_type.usage == KernelBufferUsage.INPUT - ] - - @property - def kernel_buffer_output_bindings(self) -> list[BindingDesc]: - """Gets all kernel buffer bindings with input usage.""" - return [ - b - for b in self.bindings - if b.binding_type == BindingType.KERNEL_BUFFER - and b.kernel_buffer_type.usage == KernelBufferUsage.OUTPUT - ] - - @property - def kernel_buffer_temporary_bindings(self) -> list[BindingDesc]: - """Gets all kernel buffer bindings with input usage.""" - return [ - b - for b in self.bindings - if b.binding_type == BindingType.KERNEL_BUFFER - and b.kernel_buffer_type.usage == KernelBufferUsage.TEMPORARY - ] - - def add_from_graph_placeholders(self, graph: fx.Graph): - placeholder_nodes: list[fx.Node] = [] - for node in graph.nodes: - if node.op != "placeholder": - continue - placeholder_nodes.append(node) - - for node in placeholder_nodes: - t = node.type - if is_kernel_buffer_meta_derived(t): - self.bindings.append( - BindingDesc( - ("node", node), - BindingType.KERNEL_BUFFER, - name=node.target, - kernel_buffer_type=t, - ) - ) - elif issubclass(t, IndexSymbol): - self.bindings.append( - BindingDesc( - ("node", node), - BindingType.SYMBOL_VALUE, - name=node.target, - symbol_type=t, - ) - ) - else: - raise ValueError( - f"Unsupported placeholder node type: {t} (for node {node})" - ) - - def add_grid(self, grid: Type[Grid]): - assert grid.symbolic_shape, "code emission requires a symbolically shaped grid" - for index, s in enumerate(grid.symbolic_shape): - self.bindings.append( - BindingDesc( - ("grid", index), BindingType.INDEX_VALUE, name=f"grid{index}" - ) - ) - - def __repr__(self): - parts = [] - for b in self.bindings: - part = repr(b.reference) - if b.name: - part = f"{b.name}: {part}" - parts.append(part) - return f"Signature({', '.join(parts)})" - - -class BoundKernelSignature(ABC): - """Represents a KernelSignature bound to a concrete IR structure.""" - - def __init__(self, sig: KernelSignature, entry_block: Block): - self.sig = sig - self.entry_block = entry_block - self._bindings_by_reference: dict[Any, BindingDesc] = { - b.reference: b for b in sig.bindings - } - - def resolve_by_reference(self, reference: Any) -> Value: - binding = self._bindings_by_reference[reference] - return self.resolve(binding) - - @abstractmethod - def resolve(self, binding: BindingDesc) -> Value: - """Resolves a binding to a concrete Value. - - Note that for some implementations, this may involve creating IR. It - is recommended to cache it. - """ - ... - - -class FunctionalKernelSignature(BoundKernelSignature): - """Simple BoundKernelSignature which maps all bindings to function args. - - Arguments are represented in binding order. - """ - - def __init__(self, sig: KernelSignature, entry_block: Block): - super().__init__(sig, entry_block) - block_args = list(entry_block.arguments) - bindings = sig.bindings - assert len(block_args) == len( - bindings - ), "Mismatched signature vs block arguments" - self._mapping: dict[Any, Value] = { - binding.reference: arg_value - for binding, arg_value in zip(bindings, block_args) - } - - def resolve(self, binding: BindingDesc) -> Value: - try: - return self._mapping[binding.reference] - except KeyError: - raise CodegenError(f"Binding {binding.reference} is not bound") - - @staticmethod - def create( - sig: KernelSignature, mb: ModuleBuilder, name: str = "kernel" - ) -> tuple["FunctionalKernelSignature", Operation]: - """Create a function in the module, returning the bound signature and the function.""" - with InsertionPoint(mb.body_block), Location.unknown(): - input_types = [b.as_mlir_type() for b in sig.bindings] - ftype = FunctionType.get(input_types, []) - func_op = func_d.FuncOp(name, ftype) - arg_locs = [ - (Location.name(b.name) if b.name is not None else Location.unknown()) - for b in sig.bindings - ] - entry_block = func_op.add_entry_block(arg_locs) - return FunctionalKernelSignature(sig, entry_block), func_op.operation diff --git a/core/shark_turbine/kernel/compiler/op_matchers.py b/core/shark_turbine/kernel/compiler/op_matchers.py deleted file mode 100644 index 81f738a5e..000000000 --- a/core/shark_turbine/kernel/compiler/op_matchers.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Optional - -import torch -from torch import Tensor - -import functools -import inspect - - -def signature_matcher(f=None, *, arity: Optional[int] = None, original_name: str = ""): - """Transforms a function into a signature matcher. - - The transfored function takes the same args/kwargs as the original, but - it will return an inspect.BoundArguments.arguments when invoked. - - Optional overload selectors can be specified, and if not met, None - will be returned (versus raising an error). - - On argument mismatch, a TypeError will be raised. - """ - if f is None: - return functools.partial( - signature_matcher, arity=arity, original_name=original_name - ) - - sig = inspect.signature(f) - - def wrapped(*args, **kwargs) -> Optional[inspect.BoundArguments]: - if arity is not None and arity != (len(args) + len(kwargs)): - return None - try: - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - return bound_args.arguments - except TypeError as e: - reported_name = original_name or f.__name__ - raise TypeError(f"{reported_name}() {str(e)}") - - return wrapped - - -@signature_matcher(original_name="torch.exp") -def torch_exp(input: Tensor) -> Tensor: - ... - - -@signature_matcher(arity=1, original_name="torch.max") -def torch_max_unary(input: Tensor) -> Tensor: - ... - - -@signature_matcher(original_name="torch.max") -def torch_max(input: Tensor, dim: int, keepdim: bool = False): - ... - - -@signature_matcher(arity=1, original_name="torch.sum") -def torch_sum_unary(input: Tensor) -> Tensor: - ... - - -@signature_matcher(original_name="torch.sum") -def torch_sum(input: Tensor, dim: int, keepdim: bool = False): - ... diff --git a/core/shark_turbine/kernel/compiler/vector_codegen.py b/core/shark_turbine/kernel/compiler/vector_codegen.py deleted file mode 100644 index 5ecb8ef93..000000000 --- a/core/shark_turbine/kernel/compiler/vector_codegen.py +++ /dev/null @@ -1,1056 +0,0 @@ -"""Code generation for generating vector-dialect based kernels. - -Such kernels operate on global memory at the boundary, scheduling -actual loads/stores/computes to local vectors using PyTorch tensor -level operations executed as threads over a grid. -""" -from typing import Any, Callable, Type, Optional, Sequence, Union, List -import types - -from dataclasses import dataclass -import inspect -import operator as py_operator - -import torch -import torch.fx as fx -import torch.utils._pytree as pytree - -from .._support.indexing import ( - IndexExpr, - IndexingContext, - IndexSymbol, - SymIndex, - index_expr, -) - -from ..lang.kernel_buffer import KernelBuffer - -from .._support import dtype - -from .._support.tracing import CapturedTrace - -from .. import lang as tkl - -from ..lang import ( - Index, -) - -from .. import ops - -from .builder import ( - IRProxyValue, - ScalarBuilder, -) - -from .base import ( - CodegenError, - NDEBUG, - ValidationError, -) - -from .ir import ( - AffineMap, - Attribute, - AffineExpr, - AffineMapAttr, - ArrayAttr, - FunctionType, - VectorType, - DenseElementsAttr, - F32Type, - IndexType, - FloatAttr, - InsertionPoint, - IrType, - Location, - MemRefType, - ShapedType, - Value, - VectorType, - arith_d, - func_d, - math_d, - vector_d, - scf_d, -) - -from .kernel_codegen import ( - BoundKernelSignature, -) - -from . import op_matchers - -ArgTypeUnion = Union[IndexSymbol, Type[KernelBuffer]] - - -@dataclass -class NodeAttrs: - # By default, integers are assumed signed. We propagate unsigned as graph - # node attrs. - unsigned: bool = False - - @staticmethod - def load(py_value) -> "NodeAttrs": - if isinstance(py_value, fx.Node): - return NodeAttrs(unsigned=bool(py_value.meta.get("unsigned"))) - return NodeAttrs() - - def store(self, node: fx.Node): - node.meta["unsigned"] = self.unsigned - - -class ThreadEmitter: - """Emits a 'thread function' as a `func` with a signature derived from the gm.""" - - OP_HANDLERS: dict[Any, Callable[["ThreadEmitter", fx.Node], None]] = {} - - def __init__(self, root_sig: BoundKernelSignature, trace: CapturedTrace): - self._node_values: dict[fx.Node, List[IRProxyValue]] = {} - self._grid_axis_values: dict[int, IRProxyValue] = {} - self._root_sig = root_sig - self.trace = trace - self.ip = InsertionPoint(root_sig.entry_block) - - def lookup_node_values(self, node: fx.Node) -> List[Value]: - assert NDEBUG or isinstance(node, fx.Node) - values = self._node_values.get(node) - if values is None: - values = [self._root_sig.resolve_by_reference(("node", node))] - self._node_values[node] = values - return values - - def lookup_grid_axis_value(self, grid_axis: int) -> IRProxyValue: - assert NDEBUG or isinstance(grid_axis, int) - value = self._grid_axis_values.get(grid_axis) - if value is None: - try: - ir_value = self._root_sig.resolve_by_reference(("grid", grid_axis)) - except KeyError: - raise CodegenError(f"Grid axis {grid_axis} out of bounds") - sym_index = SymIndex(IndexingContext.current().new_unbacked_symbol()) - value = IRProxyValue(ir_value, sym_index) - self._grid_axis_values[grid_axis] = value - return value - - def bind_node_proxy( - self, node: fx.Node, proxy: IRProxyValue, *, attrs: Optional[NodeAttrs] = None - ): - """Binds a node's result to a Python/IR proxy object.""" - assert NDEBUG or (isinstance(node, fx.Node) and isinstance(proxy, IRProxyValue)) - assert ( - node not in self._node_values - ), f"Cannot rebind node {node}: already bound" - if attrs is not None: - attrs.store(node) - self._node_values[node] = [proxy] - - def bind_node_proxies( - self, - node: fx.Node, - proxies: list[IRProxyValue], - *, - attrs: Optional[NodeAttrs] = None, - ): - """Binds a node's result to a list of Python/IR proxy object.""" - assert NDEBUG or ( - all(isinstance(p, IRProxyValue) for p in proxies) - and isinstance(node, fx.Node) - ) - assert ( - node not in self._node_values - ), f"Cannot rebind node {node}: already bound" - if attrs is not None: - attrs.store(node) - self._node_values[node] = proxies - - def emit(self): - with self.ip, Location.unknown(): - self.emit_graph(self.trace.get_root_graph()) - - def emit_function_call_node(self, node: fx.Node): - target_op = node.target - try: - handler = self.OP_HANDLERS[target_op] - except KeyError: - raise CodegenError(f"No handler registered for op {target_op}") - handler(self, node) - # dump - - def emit_graph(self, graph: fx.Graph): - """Emits the given graph at the current insertion point.""" - for node in graph.nodes: - if node.op == "call_function": - self.emit_function_call_node(node) - if node.op == "output": - return node.args - - def emit_subgraph(self, subgraph: fx.Graph, implicit_capture: list[fx.Node]): - # Map subgraph freevars -> implicit_capture - freevars = self.trace.region_graph.inner_freevars[subgraph] - assert len(freevars) == len( - implicit_capture - ), f"Expected {len(freevars)} implicit capture args, got {len(implicit_capture)}" - for freevar, arg in zip(freevars, implicit_capture): - self._node_values[freevar.node] = self.lookup_node_values(arg) - - # Emit subgraph - return self.emit_graph(subgraph) - - def finish(self): - with self.ip, Location.unknown(): - func_d.ReturnOp([]) - - -def handle_op(op): - def decorator(f: Callable[["ThreadEmitter", fx.Node], None]): - ThreadEmitter.OP_HANDLERS[op] = f - return None - - return decorator - - -############################################################################### -# Python/scalar ops -############################################################################### - - -@handle_op(py_operator.getitem) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - proxy, index = node.args - except ValueError as e: - raise ValidationError("Malformed arguements") from e - - if not isinstance(proxy, fx.Node): - raise CodegenError(f"Expected fx.Node") - node_values = emitter.lookup_node_values(proxy) - emitter.bind_node_proxy(node, node_values[index]) - - -BINARY_ARITHMETIC_OPS = [ - (py_operator.add, "add"), - (py_operator.mul, "mul"), - (py_operator.sub, "sub"), - (py_operator.mod, "mod"), - (py_operator.floordiv, "floordiv"), - (py_operator.truediv, "truediv"), -] - -UNARY_ARITHMETIC_OPS = [ - (tkl.exp2, "exp2"), -] - - -def binary_broadcast( - lhs: IRProxyValue, rhs: IRProxyValue -) -> tuple[bool, IRProxyValue, IRProxyValue]: - assert NDEBUG or (isinstance(lhs, IRProxyValue) and isinstance(rhs, IRProxyValue)) - lhs_type = lhs.ir_value.type - rhs_type = rhs.ir_value.type - lhs_is_vector = VectorType.isinstance(lhs_type) - rhs_is_vector = VectorType.isinstance(rhs_type) - if not lhs_is_vector and not rhs_is_vector: - # Not vectors: return as-is. - return False, lhs, rhs - - # Promote to vector. - if not lhs_is_vector: - lhs = IRProxyValue(vector_d.splat(VectorType.get([], lhs_type), lhs.ir_value)) - if not rhs_is_vector: - rhs = IRProxyValue(vector_d.splat(VectorType.get([], rhs_type), rhs.ir_value)) - lhs_type = VectorType(lhs.ir_value.type) - rhs_type = VectorType(rhs.ir_value.type) - - broadcast_shape = lhs_type.shape - rhs_shape = rhs_type.shape - rank = max(len(broadcast_shape), len(rhs_shape)) - while len(broadcast_shape) < rank: - broadcast_shape.insert(0, 1) - while len(rhs_shape) < rank: - rhs_shape.insert(0, 1) - - for i in range(rank): - a = broadcast_shape[i] - b = rhs_shape[i] - if a != b: - if a != 1 and b != 1: - raise CodegenError( - f"Binary operands are not broadcast compatible: {lhs_type}, {rhs_type}" - ) - broadcast_shape[i] = rhs_shape[i] = max(a, b) - - lhs_type = VectorType.get(broadcast_shape, lhs_type.element_type) - rhs_type = VectorType.get(broadcast_shape, rhs_type.element_type) - if lhs_type != lhs.ir_value.type: - lhs = IRProxyValue(vector_d.broadcast(lhs_type, lhs.ir_value)) - if rhs_type != rhs.ir_value.type: - rhs = IRProxyValue(vector_d.broadcast(rhs_type, rhs.ir_value)) - return True, lhs, rhs - - -def _define_arithmetic_handlers(): - def register_binary_op(op, mnemonic): - @handle_op(op) - def _(emitter: ThreadEmitter, node: fx.Node): - try: - lhs, rhs = node.args - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - lhs = cast_py_value(emitter, lhs) - rhs = cast_py_value(emitter, rhs) - is_vector, lhs, rhs = binary_broadcast(lhs, rhs) - if is_vector: - result = ScalarBuilder.binary_vector_arithmetic(mnemonic, lhs, rhs) - else: - result = ScalarBuilder.binary_arithmetic(mnemonic, lhs, rhs) - emitter.bind_node_proxy(node, result) - - def register_unary_op(op, mnemonic): - @handle_op(op) - def _(emitter: ThreadEmitter, node: fx.Node): - try: - (val,) = node.args - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - val = cast_py_value(emitter, val) - is_vector = VectorType.isinstance(val.ir_value.type) - if is_vector: - result = ScalarBuilder.unary_vector_arithmetic(mnemonic, val) - else: - result = ScalarBuilder.unary_arithmetic(mnemonic, val) - emitter.bind_node_proxy(node, result) - - for op, mnemonic in BINARY_ARITHMETIC_OPS: - # Need to capture these per iteration, not just final value, - # so call a function. - register_binary_op(op, mnemonic) - - for op, mnemonic in UNARY_ARITHMETIC_OPS: - register_unary_op(op, mnemonic) - - -_define_arithmetic_handlers() - -############################################################################### -# Core data movement and indexing ops -############################################################################### - - -@handle_op(ops.thread_program_id) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - (axis,) = node.args - axis = Index(axis) - except ValueError as e: - raise ValidationError("Malformed arguments") from e - proxy_value = emitter.lookup_grid_axis_value(axis) - # The value we get back is just an unbacked SymInt. Since we have the - # type information to make a bounded instance, create that sharing the - # symbol. - sym_index_type = node.type - assert issubclass(sym_index_type, SymIndex) - emitter.bind_node_proxy( - node, - IRProxyValue(proxy_value.ir_value, proxy_value.py_value.cast(sym_index_type)), - ) - - -@handle_op(tkl.to_dtype) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - (val, dtype) = node.args - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - ir_type = cast_dtype(emitter, dtype) - casted = cast_vector(emitter, val, element_type=ir_type) - emitter.bind_node_proxy(node, IRProxyValue(casted)) - - -@handle_op(ops.kernel_buffer_getitem) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - kb, slice_spec = node.args - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - kb_src, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, kb) - ref_shape = kb_py_type.symbolic_shape - slice_spec = cast_slice_spec(emitter, ref_shape, slice_spec) - start_indices = extract_slice_starts(emitter, ref_shape, slice_spec) - vector_shape = extract_static_slice_shape(emitter, ref_shape, slice_spec) - element_type = kb_ir_type.element_type - vector_type = VectorType.get(vector_shape, element_type) - pad_attr = ScalarBuilder.zero_attr(element_type) - pad_value = arith_d.constant(element_type, pad_attr) - result = vector_d.transfer_read( - vector_type, - kb_src, - start_indices, - AffineMap.get_identity(len(start_indices)), - pad_value, - ) - emitter.bind_node_proxy(node, IRProxyValue(result)) - - -@handle_op(ops.kernel_buffer_setitem) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - kb, slice_spec, item = node.args - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - kb_dest, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, kb) - dest_rank = kb_ir_type.rank - ref_shape = kb_py_type.symbolic_shape - slice_spec = cast_slice_spec(emitter, ref_shape, slice_spec) - start_indices = extract_slice_starts(emitter, ref_shape, slice_spec) - if dest_rank != len(start_indices): - raise CodegenError( - f"Mismatched slice assignment: Expected rank {dest_rank}, got {len(start_indices)}" - ) - insert_vector = cast_vector(emitter, item, element_type=kb_ir_type.element_type) - insert_type = VectorType(insert_vector.type) - - # Special case rank-0 broadcast. - if insert_type.rank == 0: - broadcast_type = VectorType.get(dest_rank * [1], kb_ir_type.element_type) - insert_vector = vector_d.broadcast(broadcast_type, insert_vector) - - permutation_map = AffineMap.get_identity(dest_rank) - vector_d.transfer_write( - None, - insert_vector, - kb_dest, - start_indices, - AffineMapAttr.get(permutation_map), - ) - - -############################################################################### -# Memory Ops -############################################################################### - - -@handle_op(tkl.load) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - kb, multi_index, vector_shape = node.args - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - vector_shape = cast_py_literal(emitter, vector_shape) - kb_src, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, kb) - ref_shape = kb_py_type.symbolic_shape - slice_spec = cast_slice_spec(emitter, ref_shape, multi_index) - start_indices = extract_slice_starts(emitter, ref_shape, slice_spec) - element_type = kb_ir_type.element_type - vector_type = VectorType.get(vector_shape, element_type) - pad_attr = ScalarBuilder.zero_attr(element_type) - pad_value = arith_d.constant(element_type, pad_attr) - result = vector_d.transfer_read( - vector_type, - kb_src, - start_indices, - AffineMap.get_minor_identity(len(ref_shape), len(vector_shape)), - pad_value, - ) - emitter.bind_node_proxy(node, IRProxyValue(result)) - - -@handle_op(tkl.store) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - kb, multi_index, item = node.args - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - kb_dest, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, kb) - dest_rank = kb_ir_type.rank - ref_shape = kb_py_type.symbolic_shape - slice_spec = cast_slice_spec(emitter, ref_shape, multi_index) - start_indices = extract_slice_starts(emitter, ref_shape, slice_spec) - if dest_rank != len(start_indices): - raise CodegenError( - f"Mismatched slice assignment: Expected rank {dest_rank}, got {len(start_indices)}" - ) - insert_vector = cast_vector(emitter, item, element_type=kb_ir_type.element_type) - insert_type = VectorType(insert_vector.type) - insert_rank = insert_type.rank - - # Special case rank-0 broadcast. - if insert_rank == 0: - broadcast_type = VectorType.get(dest_rank * [1], kb_ir_type.element_type) - insert_vector = vector_d.broadcast(broadcast_type, insert_vector) - - permutation_map = AffineMap.get_minor_identity(dest_rank, insert_rank) - vector_d.transfer_write( - None, - insert_vector, - kb_dest, - start_indices, - AffineMapAttr.get(permutation_map), - ) - - -############################################################################### -# Math Ops -############################################################################### - - -@handle_op(tkl.constant) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - shape, dtype, value = node.args - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - shape = cast_py_literal(emitter, shape) - dtype = cast_dtype(emitter, dtype) - constant = ScalarBuilder.constant_vector(value, shape, dtype) - emitter.bind_node_proxy(node, constant) - - -############################################################################### -# Reduction Ops -############################################################################### - - -@handle_op(tkl.dot) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - lhs, rhs, acc = node.args - lhs = cast_vector(emitter, lhs) - rhs = cast_vector(emitter, rhs) - acc = cast_vector(emitter, acc) - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - vector_type = VectorType(lhs.type) - element_type = vector_type.element_type - rank = vector_type.rank - - n, m, k = ( - AffineExpr.get_dim(0), - AffineExpr.get_dim(1), - AffineExpr.get_dim(2), - ) - indexing_maps = [ - AffineMap.get(3, 0, [n, k]), - AffineMap.get(3, 0, [k, m]), - AffineMap.get(3, 0, [n, m]), - ] - indexing_maps_attr = [AffineMapAttr.get(map) for map in indexing_maps] - # TODO: Bad hack, please fix. - iterator_types = ArrayAttr.get( - [ - Attribute.parse("#vector.iterator_type"), - Attribute.parse("#vector.iterator_type"), - Attribute.parse("#vector.iterator_type"), - ] - ) - result = vector_d.ContractionOp( - acc.type, - lhs, - rhs, - acc, - indexing_maps_attr, - iterator_types, - ).result - emitter.bind_node_proxy(node, IRProxyValue(result)) - - -def register_reduction(op): - def decorator(f: Callable[[IrType, NodeAttrs], vector_d.CombiningKind]): - @handle_op(op) - def _(emitter: ThreadEmitter, node: fx.Node): - try: - vector, axis, acc = node.args - except ValueError as e: - raise ValidationError("Malformed arguements") from e - - axis = cast_py_literal(emitter, axis) - emit_reduction(emitter, node, vector, axis, acc, f) - - return decorator - - -def emit_reduction( - emitter: ThreadEmitter, - node: fx.Node, - raw_input, - axis: int, - raw_acc, - combiner_callback: Callable[[IrType, NodeAttrs], vector_d.CombiningKind], -): - # Setup. - attrs = NodeAttrs.load(raw_input) - input = cast_vector(emitter, raw_input) - vector_type = VectorType(input.type) - element_type = vector_type.element_type - rank = vector_type.rank - - if raw_acc: - acc = cast_vector(emitter, raw_acc) - else: - acc = arith_d.constant(element_type, ScalarBuilder.zero_attr(element_type)) - - combiner = combiner_callback(element_type, attrs) - - if not axis: - # Reduce to scalar. - scalar_result = vector_d.multi_reduction( - combiner, input, acc, list(range(rank)) - ) - result = vector_d.splat(VectorType.get([], element_type), scalar_result) - emitter.bind_node_proxy(node, IRProxyValue(result), attrs=attrs) - else: - # Reduce to vector. - vector_result = vector_d.multi_reduction(combiner, input, acc, [axis]) - emitter.bind_node_proxy(node, IRProxyValue(vector_result), attrs=attrs) - - -@register_reduction(tkl.max) -def _(element_type: IrType, attrs: NodeAttrs) -> vector_d.CombiningKind: - if ScalarBuilder.is_floating_point_type(element_type): - # Non-NaN propagating. - # TODO: Carry a "fastmath" flag on the emitter and choose between this - # and MAXIMUMF? - return vector_d.CombiningKind.MAXNUMF - elif ScalarBuilder.is_integer_type(element_type): - return ( - vector_d.CombiningKind.MAXUI - if attrs.unsigned - else vector_d.CombiningKind.MAXSI - ) - - raise CodegenError(f"No max reduction for type {element_type}") - - -@register_reduction(tkl.sum) -def _(element_type: IrType, attrs: NodeAttrs) -> vector_d.CombiningKind: - return vector_d.CombiningKind.ADD - - -############################################################################### -# Control Flow ops -############################################################################### - - -@handle_op(tkl.for_loop) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - start, end, step, init_args = node.args - subgraph = node.kwargs["subgraph"] - implicit_capture = node.kwargs["implicit_capture"] - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - # Check if init_args is a flattened list of values. - for arg in init_args: - if len(emitter.lookup_node_values(arg)) != 1: - raise CodegenError(f"NYI: For loop init args must be flattened") - - # Get IR values mapping to the node args. - start = cast_py_value(emitter, start) - end = cast_py_value(emitter, end) - step = cast_py_value(emitter, step) - - # Flatten init_args and get IR values for each of them. - flat_init_args, init_args_spec = pytree.tree_flatten((init_args)) - flat_init_args = [cast_py_value(emitter, arg) for arg in flat_init_args] - - # Get the subgraph for body of the loop. - assert isinstance(subgraph, str) - subgraph = emitter.trace.get_subgraph(subgraph) - - # Create scf.for operation. - forOp = scf_d.ForOp( - start.ir_value, - end.ir_value, - step.ir_value, - [a.ir_value for a in flat_init_args], - ) - # Enter body of for loop. - with InsertionPoint(forOp.body): - # TODO: Flatten subgraph args here. - subgraph_args = [ - node - for node in subgraph.nodes - if node.op == "placeholder" and "lifted" not in node.meta - ] - # Add mapping for induction variable argument. - emitter.bind_node_proxy( - subgraph_args[0], IRProxyValue(forOp.induction_variable) - ) - # Add mapping for iter_args. - for i, v in enumerate(forOp.inner_iter_args): - emitter.bind_node_proxy(subgraph_args[i + 1], IRProxyValue(v)) - - ret = emitter.emit_subgraph(subgraph, implicit_capture) - # Use ret in terminatory of body - # TODO: Flatten return values here. - flat_ret_values, ret_spec = pytree.tree_flatten((ret)) - flat_ret_values = [ - cast_py_value(emitter, value).ir_value for value in flat_ret_values - ] - scf_d.YieldOp(flat_ret_values) - - results = forOp.results_ - emitter.bind_node_proxies(node, [IRProxyValue(v) for v in results]) - - -############################################################################### -# Shape Manipulation Ops -############################################################################### - - -@handle_op(tkl.broadcast) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - vector, leading_sizes = node.args - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - vector = cast_vector(emitter, vector) - leading_sizes = cast_py_literal(emitter, leading_sizes) - - old_shape = vector.type.shape - broadcasted_shape = list(leading_sizes) + old_shape - broadcasted_type = VectorType.get(broadcasted_shape, vector.type.element_type) - result = vector_d.broadcast(broadcasted_type, vector) - emitter.bind_node_proxy(node, IRProxyValue(result)) - - -@handle_op(tkl.transpose) -def _(emitter: ThreadEmitter, node: fx.Node): - try: - vector, permutation = node.args - except ValueError as e: - raise ValidationError("Malformed arguments") from e - - vector = cast_vector(emitter, vector) - permutation = cast_py_literal(emitter, permutation) - new_shape = [vector.type.shape[i] for i in permutation] - result_type = VectorType.get(new_shape, vector.type.element_type) - - result = vector_d.transpose(result_type, vector, permutation) - emitter.bind_node_proxy(node, IRProxyValue(result)) - - -############################################################################### -# Conversion utilities -############################################################################### - - -def cast_py_literal(emitter: ThreadEmitter, value) -> Any: - """Treats the given value as a Python literal. - - An exception will be raised if it cannot be computed statically. - """ - if isinstance(value, IndexExpr): - simplified = IndexingContext.current().simplify_expr(value) - try: - return int(simplified) - except TypeError as e: - raise CodegenError( - f"Literal value required but got symbolic value requiring " - f"dynamic resolution: {simplified}" - ) from e - elif isinstance(value, tuple): - return tuple(cast_py_literal(emitter, v) for v in value) - elif isinstance(value, list): - return [cast_py_literal(emitter, v) for v in value] - elif isinstance(value, dict): - return { - cast_py_literal(emitter, k): cast_py_literal(emitter, v) - for k, v in value.items() - } - elif isinstance(value, (int, float, str)): - return value - - -def cast_py_value(emitter: ThreadEmitter, value) -> IRProxyValue: - """ - Converts the given value to an IR Value. - If the value is a fx.Node, the result of the fx.Node should corresspond to - exactly one IR Value. - If the value is a constant, a constant value will be built for it. - """ - if isinstance(value, fx.Node): - try: - node_values = emitter.lookup_node_values(value) - assert len(node_values) == 1, f"Expected exactly one value for node {value}" - return node_values[0] - except KeyError: - raise CodegenError(f"Producer node `{value}` has no IR Value") - elif isinstance(value, IndexExpr): - simplified = IndexingContext.current().simplify_expr(value) - try: - value = int(simplified) - except TypeError as e: - raise CodegenError( - f"Dynamically resolved symbolic values not yet implemented. Got: " - f"{simplified}" - ) from e - return ScalarBuilder.constant(value, IndexType.get()) - - -def cast_py_lvalue(emitter: ThreadEmitter, py_value: fx.Node) -> tuple[Value, fx.Node]: - if isinstance(py_value, fx.Node): - try: - node_values = emitter.lookup_node_values(py_value) - assert ( - len(node_values) == 1 - ), f"Expected exactly one value for node {py_value}" - return node_values[0], py_value - except KeyError: - raise CodegenError(f"Producer node `{py_value}` has no IR Value") - else: - raise CodegenError( - f"Required a traced node in the graph. Got: {py_value} (type {type(py_value)})" - ) - - -def cast_kernel_buffer( - emitter: ThreadEmitter, kb -) -> tuple[Value, MemRefType, Type[KernelBuffer]]: - """Casts a Python value of type KernelBuffer, which lowers to a MemRefType'd value.""" - value, node = cast_py_lvalue(emitter, kb) - ir_type = value.type - py_type = node.type - - if not MemRefType.isinstance(ir_type): - raise CodegenError( - f"Expected a KernelBuffer (aka. `memref`) but got `{ir_type}`" - ) - - if not issubclass(py_type, KernelBuffer): - raise CodegenError( - f"Expected an lvalue of type KernelBuffer but got '{py_type}' for node {node}" - ) - - return value, MemRefType(ir_type), py_type - - -def cast_vector( - emitter: ThreadEmitter, value, *, element_type: Optional[IrType] = None -): - proxy_value = cast_py_value(emitter, value) - - # Cast scalar types correctly first. - if element_type and not ShapedType.isinstance(proxy_value.ir_value.type): - # Implicit scalar type promotion. - proxy_value = ScalarBuilder.to_dtype(proxy_value, element_type) - - value = proxy_value.ir_value - - # After scalar promotion, promote to vector. - if VectorType.isinstance(value.type): - # Already a vector. Coerce or return. - if element_type is not None: - value = ScalarBuilder.to_dtype(proxy_value, element_type).ir_value - # No target element_type. - return value - else: - # Scalar -> vector. - element_type = value.type - vector_type = VectorType.get([], element_type) - return vector_d.splat(vector_type, value) - - -def cast_dtype(emitter: ThreadEmitter, dtype: dtype.DataType) -> IrType: - try: - ir_dtype = IrType.parse(dtype.ir_type_asm()) - except CodegenError as e: - raise CodegenError(f"Failed to convert dtype {dtype} to IR type") from e - - return ir_dtype - - -############################################################################### -# Slice and indexing -############################################################################### - -SliceAtom = Union[slice, None, IndexExpr, IRProxyValue] - - -def cast_slice_spec( - emitter: ThreadEmitter, ref_shape: tuple[IndexExpr], py_slice_spec -) -> list[SliceAtom]: - """Casts a node argument to a 'slice spec', normalizing it in the process. - - A 'slice spec' is what can go in the `[]` of an array access. It is either - a tuple of slice atoms or a single slice atom. A slice atom is one of: - * `slice` object - * `None` indicating dimension insertion. - * elippsis (...) to indicate a space filling `slice()` - * `IndexExpr` for a constant index value. - * `IRProxyValue` containing a `SymIndex` for a dynamic index value. - - The numpy page has a good description here: - https://numpy.org/doc/1.26/user/basics.indexing.html - - As part of casting, this implementation will replace any ellipsis with a - rank filling number of None values. - """ - rank = len(ref_shape) - if not isinstance(py_slice_spec, tuple): - py_slice_spec = (py_slice_spec,) - - # Rank normalize. - none_count = py_slice_spec.count(None) - ellipsis_count = py_slice_spec.count(...) - if ellipsis_count == 1: - # Expand by the original list of slices less any unit dim insertions. - # If negative, this does nothing and will be caught later upon - # rank validation. - expand_index = py_slice_spec.index(...) - del py_slice_spec[expand_index] - expansion_count = (rank + none_count) - len(py_slice_spec) - for _ in range(expansion_count): - py_slice_spec.insert(expand_index, slice(None)) - elif ellipsis_count > 1: - raise IndexError( - f"Cannot index into a rank expanding referrent with multiple `...` values" - ) - - return [ - cast_slice_atom(emitter, ref_shape[i], py_slice_spec[i]) for i in range(rank) - ] - - -def cast_slice_atom( - emitter: ThreadEmitter, dim_size: IndexExpr, py_slice_atom -) -> SliceAtom: - """Casts a single 'atom' in a slice spec. See cast_slice_spec.""" - if py_slice_atom is None: - # Pass-through. - return py_slice_atom - if isinstance(py_slice_atom, slice): - # Re-compose. - idxc = IndexingContext.current() - start = py_slice_atom.start - stop = py_slice_atom.stop - step = py_slice_atom.step - - # Apply start defaults. - if start is None: - start = index_expr(0) - else: - start = cast_index_value(emitter, start) - # Apply stop defaults. - if stop is None: - # Stop defaults to the dim size. - stop = idxc.simplify_expr(dim_size) - else: - # Cast it. - stop = cast_index_value(emitter, stop) - # Apply step defaults. - if step is None: - step = index_expr(1) - else: - step = cast_index_value(emitter, step) - - return slice(start, stop, step) - else: - return cast_index_value(emitter, py_slice_atom) - - -def cast_index_value( - emitter: ThreadEmitter, py_index -) -> Union[IRProxyValue, IndexExpr]: - """Casts an arbitrary py_index value to either a static or dynamic index. - - Static indices are of type IndexExpr and can be completely defined in terms - of sympy expressions on symbols. Dynamic are computed in the IR in some - fashion and are IRProxyValue with an py_value of type SymIndex. - """ - # Static IndexExpr - if isinstance(py_index, int): - return index_expr(py_index) - if isinstance(py_index, IndexExpr): - return IndexingContext.current().simplify_expr(py_index) - - # fx.Node -> IRProxyValue. - if isinstance(py_index, fx.Node): - # Cast index value. - try: - node_values = emitter.lookup_node_values(py_index) - assert ( - len(node_values) == 1 - ), f"Expected exactly one value for node {py_index}" - py_index = node_values[0] - except KeyError: - raise CodegenError(f"Producer node `{py_index}` has no IR Value") - - if not isinstance(py_index.py_value, (SymIndex, types.NoneType)): - raise CodegenError(f"Expected dynamic index value but got {py_index}") - return py_index - - -def cast_dynamic_index_value(emitter: ThreadEmitter, py_index) -> IRProxyValue: - """Casts an arbitrary py_index value to a dynamic index. - - If it was a static index, it will be materialized. - """ - py_index = cast_index_value(emitter, py_index) - if isinstance(py_index, IRProxyValue): - return py_index - - # Materialize. - try: - int_value = int(py_index) - except TypeError: - # Need to materialize the expression. - raise CodegenError(f"NYI: Materialized index expression {py_index}") - return ScalarBuilder.constant(int_value, IndexType.get()) - - -def extract_slice_starts( - emitter: ThreadEmitter, - ref_shape: tuple[IndexExpr, ...], - slice_spec: list[SliceAtom], -) -> list[Value]: - def _extract(i): - atom = slice_spec[i] - if atom is None: - return ScalarBuilder.constant(0, IndexType.get()) - elif isinstance(atom, slice): - return cast_dynamic_index_value(emitter, atom.start).ir_value - else: - return cast_dynamic_index_value(emitter, atom).ir_value - - return [_extract(i) for i in range(len(ref_shape))] - - -def extract_static_slice_shape( - emitter: ThreadEmitter, - ref_shape: tuple[IndexExpr, ...], - slice_spec: list[SliceAtom], -) -> list[int]: - rank = len(ref_shape) - shape = [0] * rank - idxc = IndexingContext.current() - for i in range(rank): - atom = slice_spec[i] - if atom is None: - # Insert 1 dim. - shape[i] = 1 - elif isinstance(atom, slice): - # Compute from slice. - if atom.step != 1: - raise CodegenError(f"NYI: Step != 1") - try: - expr = idxc.simplify_expr(atom.stop - atom.start) - shape[i] = int(expr) - except TypeError: - raise CodegenError( - f"A static shape was required but got: {slice_spec}[{i}] = {expr}" - ) - else: - # Index a single value. - shape[i] = 1 - return shape diff --git a/core/shark_turbine/kernel/gen/__init__.py b/core/shark_turbine/kernel/gen/__init__.py deleted file mode 100644 index 68db21b76..000000000 --- a/core/shark_turbine/kernel/gen/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .thread import * -from .kernel import * - -from .._support.tracing import TestLaunchContext diff --git a/core/shark_turbine/kernel/gen/kernel.py b/core/shark_turbine/kernel/gen/kernel.py deleted file mode 100644 index 1ff9134c1..000000000 --- a/core/shark_turbine/kernel/gen/kernel.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Custom op registeration for TK""" - -import inspect - -import torch - -from typing import Callable, Any - -from ..lang.kernel_buffer import is_kernel_buffer_meta_derived - -from ..lang import ( - InputBuffer, - OutputBuffer, - Grid, - IndexExpr, -) - -from .thread import LaunchableThread - -from ..compiler.ir import ( - SymbolRefAttr, - ArrayAttr, - flow_d, - IrType, -) - -from ...runtime.op_reg import ( - def_library, - CustomOp, - KernelBuilder, - KernelSelection, - TensorArg, -) - -from .._support.tracing import AOTLaunchContext -from .._support.indexing import IndexingContext - -TK_LIBRARY = def_library("tk") - - -__all__ = [ - "kernel", -] - - -def kernel(*symbolic_shape: IndexExpr): - def decorator(f: Callable): - # Convert all InputBuffer to inputs and OutputBuffers to outputs - sig = inspect.signature(f) - params = sig.parameters - inputs: list[tuple[str, Any]] = [] - outputs: list[tuple[str, Any]] = [] - for arg_name, param in params.items(): - # TODO: Implement more input arguements. - if not is_kernel_buffer_meta_derived(param.annotation): - raise NotImplementedError( - "Only KernelBuffer is supported as input for now" - ) - - if param.annotation.usage == InputBuffer.usage: - inputs.append((arg_name, param.annotation)) - elif param.annotation.usage == OutputBuffer.usage: - outputs.append((arg_name, param.annotation)) - - name_spec = f"kernel_{f.__name__}__@UNIQUE@" - input_signature = ["Tensor " + name for name, _ in inputs] - output_signature = ["Tensor " + name for name, _ in outputs] - - @CustomOp.register(library=TK_LIBRARY) - class TKCustomOp(CustomOp): - signature = ( - f"{name_spec}({', '.join(input_signature)}) -> " - f"({', '.join(output_signature)})" - ) - - def select(self, ksel: KernelSelection): - # Infer the result tensor based on the input tensor - idxc = IndexingContext() - - i = 0 - for arg_name, arg in inputs: - if is_kernel_buffer_meta_derived(arg): - x = ksel.arg_tensor(i) - # We currently only do static dimensions. - # TODO: Support dynamic dimensions. - x.spec_dims = list(x.t.shape) - assert isinstance(x, TensorArg) - idxc.bind_shaped(arg_name, arg, list(x.t.shape)) - i += 1 - else: - raise NotImplementedError( - "Only KernelBuffer is supported as input for now" - ) - - idxc.finalize() - - i = 0 - for _, arg in outputs: - if is_kernel_buffer_meta_derived(arg): - shape = arg.symbolic_shape - static_shape = [idxc.get_static_value(x) for x in shape] - x = torch.empty(*static_shape) - ksel.return_tensor(x) - # TODO: Support dynamic dimensions. - # Set spec_dims for output so that we can infer the - # type of the output tensor. - ksel.result_descs[i].spec_dims = list(x.shape) - i += 1 - else: - raise NotImplementedError( - "Only KernelBuffer is supported as input for now" - ) - - def generate(self, ksel: KernelSelection, kb: KernelBuilder): - entrypoint = f"tk_{self.name}" - # Create a flow.dispatch op to the kernel - dispatch = SymbolRefAttr.get([entrypoint, entrypoint]) - entrypoints = ArrayAttr.get([dispatch]) - - result_types = [ - IrType.parse(x.mlir_type_asm) for x in ksel.result_descs - ] - - out = flow_d.DispatchOp( - result_types, [], entrypoints, kb.arg_bindings, [], [] - ) - - kb.yield_results(*out.results_) - - # Build the kernel as a stream executable. - args = [] - for arg in ksel.arg_descs: - if isinstance(arg, TensorArg): - args.append(arg.t) - else: - raise NotImplementedError("Non TensorArg arg binding") - - for res in ksel.result_descs: - if isinstance(res, TensorArg): - args.append(res.t) - else: - raise NotImplementedError("Non TensorArg result binding") - - launchable = LaunchableThread(Grid[symbolic_shape], entrypoint, f) - with AOTLaunchContext(kb.module_body.owner) as launch_ctx: - launch_ctx.launch(launchable, args, {}) - - return TKCustomOp - - return decorator diff --git a/core/shark_turbine/kernel/gen/thread.py b/core/shark_turbine/kernel/gen/thread.py deleted file mode 100644 index 44ad976e1..000000000 --- a/core/shark_turbine/kernel/gen/thread.py +++ /dev/null @@ -1,166 +0,0 @@ -from typing import ( - Type, - Callable, - Optional, -) - -import inspect -import math - -import torch - -from ..lang import ( - KernelBuffer, - Grid, - IndexExpr, -) - -from .._support.tracing import ( - CapturedTrace, - CompiledContext, - EagerContext, - Launchable, - KernelRegionGraph, - LaunchContext, - AOTLaunchContext, -) - -from .._support.indexing import IndexingContext - -from ..compiler import ( - kernel_codegen, - dispatch_codegen, - builder, - vector_codegen, - host_codegen, -) - -from ..compiler.ir import ( - Context, - Operation, -) - -__all__ = [ - "thread", -] - - -def thread(*symbolic_shape: IndexExpr): - GridType = Grid[symbolic_shape] - - def decorator(f: Callable) -> "LaunchableThread": - return LaunchableThread(GridType, f.__name__, f) - - return decorator - - -class LaunchableThread(Launchable): - def __init__( - self, - grid_type: Type[Grid], - name: str, - eager_function: Callable, - ): - super().__init__(eager_function) - self.grid_type = grid_type - self._name = name - self._f = eager_function - self._sig = inspect.signature(eager_function) - - def _trace(self) -> CapturedTrace: - region_graph = KernelRegionGraph() - with CompiledContext(region_graph, grid_type=self.grid_type) as context: - with region_graph.subtracer() as subtracer: - root_name, _ = subtracer.trace(self._f) - trace = CapturedTrace(region_graph, root_name) - return trace - - def eager_execute(self, args, kwargs): - grid = self.grid_type() - rank = grid.rank - with EagerContext(rank=rank) as context: - sig = self._sig - bound = sig.bind(*args, *kwargs) - bound.apply_defaults() - # Transform args to KernelBuffers. - for arg_name in list(bound.arguments.keys()): - arg_value = bound.arguments[arg_name] - param = sig.parameters[arg_name] - param_type = param.annotation - if isinstance(param_type, type) and issubclass( - param_type, KernelBuffer - ): - kernel_buffer = param_type(arg_value) - bound.arguments[arg_name] = kernel_buffer - volume = math.prod(grid) - current_thread = context.current_thread - for it in range(volume): - for i in range(rank - 1): - current_thread[i] = it // grid[i] - it = it % grid[i] - current_thread[-1] = it - self._eager_function(*bound.args, **bound.kwargs) - - def _trace_and_get_kernel_signature( - self, - args, - kwargs, - context: Optional[Context] = None, - module_op: Optional[Operation] = None, - ): - # Trace the function. - trace = self._trace() - idxc = IndexingContext.current() - - sig = self._sig - bound = sig.bind(*args, *kwargs) - bound.apply_defaults() - - for arg_name in list(bound.arguments.keys()): - arg_value = bound.arguments[arg_name] - param = sig.parameters[arg_name] - param_type = param.annotation - if isinstance(param_type, type) and issubclass(param_type, KernelBuffer): - assert isinstance(arg_value, torch.Tensor) - idxc.bind_shaped(arg_name, param_type, list(arg_value.shape)) - - idxc.finalize() - - kernel_sig = kernel_codegen.KernelSignature() - kernel_sig.add_from_graph_placeholders(trace.get_root_graph()) - kernel_sig.add_grid(self.grid_type) - - grid = self.grid_type() - - mb = builder.ModuleBuilder(context=context, module_op=module_op) - entrypoint_name = self._name - exe = dispatch_codegen.StreamExecutable(mb, name=entrypoint_name) - dispatch_entrypoint = exe.define_entrypoint(entrypoint_name, kernel_sig, grid) - emitter = vector_codegen.ThreadEmitter(dispatch_entrypoint, trace) - emitter.emit() - emitter.finish() - - mb.module_op.verify() - - return mb, exe, kernel_sig, entrypoint_name - - def test_execute(self, args, kwargs): - mb, exe, kernel_sig, entrypoint_name = self._trace_and_get_kernel_signature( - args, kwargs - ) - host_codegen.isolated_test_call(mb, exe, kernel_sig, entrypoint_name) - - print(mb.module_op.get_asm()) - - def aot_execute(self, args, kwargs): - launch_context = LaunchContext.current() - assert isinstance(launch_context, AOTLaunchContext) - - module = launch_context.module - - mb, exe, kernel_sig, entrypoint_name = self._trace_and_get_kernel_signature( - args, kwargs, context=module.context, module_op=module.operation - ) - - def __repr__(self): - return f"tk.gen.thread @{self._name}[{self.grid_type}]" diff --git a/core/shark_turbine/kernel/lang/__init__.py b/core/shark_turbine/kernel/lang/__init__.py deleted file mode 100644 index ad1ed5aac..000000000 --- a/core/shark_turbine/kernel/lang/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from .prims import * -from .types import * -from .kernel_buffer import * -from .grid import * - -# Include publics from the _support library. -from .._support.indexing import ( - IndexExpr, - IndexSymbol, - sym, -) - -from .._support.dtype import ( - bool, - i4, - i8, - i16, - i32, - i64, - f16, - f32, - f64, - index, -) diff --git a/core/shark_turbine/kernel/lang/grid.py b/core/shark_turbine/kernel/lang/grid.py deleted file mode 100644 index 6c21c6903..000000000 --- a/core/shark_turbine/kernel/lang/grid.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import cast, Type, ClassVar - -from .._support.shaped_type import ShapedType -from .._support.indexing import IndexingContext, IndexExpr - -__all__ = [ - "Grid", -] - - -class Grid(metaclass=ShapedType): - """Grid with bounding symbolic shape information in the type.""" - - symbolic_shape: ClassVar[tuple[IndexExpr, ...]] - rank: ClassVar[int] - dims: list[int] - - def __init__(self): - # Resolve the symbolic shape to concrete values. - idxc = IndexingContext.current() - if self.symbolic_shape: - dims = [idxc.get_static_value(dim) for dim in self.symbolic_shape] - if None in dims: - raise ValueError(f"NYI: Dynamic dims in Grid") - self.dims = cast(list[int], dims) - else: - self.dims = [] - - def __class_getitem__( - cls, symbolic_shape: tuple[IndexExpr, ...] | IndexExpr - ) -> Type["Grid"]: - if not isinstance(symbolic_shape, tuple): - symbolic_shape = (symbolic_shape,) - - return cls.new_shaped_subtype(symbolic_shape=symbolic_shape) - - @property - def shape(self) -> tuple[int, ...]: - return tuple(self.dims) - - def __repr__(self): - return f"{repr(type(self))}({', '.join(str(i) for i in self.dims)})" - - def __getitem__(self, index: int) -> int: - return self.dims[index] - - def __len__(self) -> int: - return len(self.dims) - - def __iter__(self): - return iter(self.dims) diff --git a/core/shark_turbine/kernel/lang/kernel_buffer.py b/core/shark_turbine/kernel/lang/kernel_buffer.py deleted file mode 100644 index 9cbcd5dec..000000000 --- a/core/shark_turbine/kernel/lang/kernel_buffer.py +++ /dev/null @@ -1,150 +0,0 @@ -from typing import Type, TypeVar, cast, ClassVar - -from enum import Enum - -import torch - -from .._support.indexing import IndexExpr -from .._support.shaped_type import ShapedDataType -from .._support.dtype import DataType, f32 -from .. import ops - -__all__ = [ - "KernelBuffer", - "InputBuffer", - "OutputBuffer", - "TemporaryBuffer", - "is_kernel_buffer_meta_derived", -] - -SubtypeT = TypeVar("SubtypeT") - - -class NotSetType: - ... - - -NotSet = NotSetType() - - -class KernelBufferUsage(Enum): - NONE = 0 - INPUT = 1 - OUTPUT = 2 - TEMPORARY = 3 - - @staticmethod - def _type_name(v) -> str: - if v == KernelBufferUsage.NONE: - return "KernelBuffer" - elif v == KernelBufferUsage.INPUT: - return "InputBuffer" - elif v == KernelBufferUsage.OUTPUT: - return "OutputBuffer" - elif v == KernelBufferUsage.TEMPORARY: - return "TemporaryBuffer" - else: - raise AssertionError(f"uncovered KernelBufferUsage enum ({v})") - - -class _KernelBufferMeta(ShapedDataType): - usage: KernelBufferUsage = KernelBufferUsage.NONE - - def new_subtype( - cls: Type[SubtypeT], - *, - symbolic_shape: tuple[IndexExpr, ...] | NotSetType = NotSet, - dtype: DataType | NotSetType = NotSet, - usage: KernelBufferUsage | NotSetType = NotSet, - ) -> Type[SubtypeT]: - init_symbolic_shape = symbolic_shape if symbolic_shape is not NotSet else cls.symbolic_shape # type: ignore - init_dtype = dtype if dtype is not NotSet else cls.dtype # type: ignore - init_usage = usage if usage is not NotSet else cls.usage # type: ignore - - class SubType(cls): - symbolic_shape = init_symbolic_shape - rank = len(init_symbolic_shape) # type: ignore - dtype = init_dtype - usage = init_usage - - SubType.__name__ = KernelBufferUsage._type_name(init_usage) - - return cast(Type[SubtypeT], SubType) - - -class KernelBuffer(metaclass=_KernelBufferMeta): - """Represents a buffer in global memory. - - Top level kernels always operate on global memory via these - buffers, and the primary operations that can be performed on - them are loads/stores and DMAs to some form of compute - capable local buffer. - - When executing eagerly, these are backed by a normal torch - Tensor. When compiling, an appropriate duck-typed proxy - is used. - """ - - symbolic_shape: ClassVar[tuple[IndexExpr, ...]] - rank: ClassVar[int] - dtype: ClassVar[DataType] - - def __init__(self, tensor: torch.Tensor): - assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}" - type_rank = type(self).rank - tensor_rank = len(tensor.shape) - if type_rank is not None and type_rank != tensor_rank: - raise ValueError( - f"Cannot create {type(self)}(tensor({tensor.shape})): mismatched symbolic rank" - ) - self._tensor = tensor - - def __class_getitem__( - cls, shape_and_dtype: tuple[IndexExpr | DataType, ...] - ) -> Type["KernelBuffer"]: - """Syntax: `KernelBuffer[shape1, shape2, ..., shapeN, dtype]`""" - - if not isinstance(shape_and_dtype, tuple) or len(shape_and_dtype) < 2: - raise TypeError(f"Expected at least 2 arguments, got: {shape_and_dtype}") - - shape = shape_and_dtype[:-1] - dtype = shape_and_dtype[-1] - - if not all(isinstance(s, IndexExpr) for s in shape): - raise TypeError(f"Expected shape to be a tuple of IndexExpr, got {shape}") - if not isinstance(dtype, DataType): - raise TypeError(f"Expected dtype to be a DataType, got {dtype}") - - shape = cast(tuple[IndexExpr, ...], shape) - dtype = cast(DataType, dtype) - - return cls.new_subtype(symbolic_shape=shape, dtype=dtype) - - def __repr__(self): - return f"{type(self)}({self._tensor})" - - def __setitem__(self, key, item): - ops.kernel_buffer_setitem(self, key, item) - - def __getitem__(self, key): - return ops.kernel_buffer_getitem(self, key) - - @property - def shape(self) -> tuple[int, ...]: - return self._tensor.shape - - -class InputBuffer(KernelBuffer): - usage = KernelBufferUsage.INPUT - - -class OutputBuffer(KernelBuffer): - usage = KernelBufferUsage.OUTPUT - - -class TemporaryBuffer(KernelBuffer): - usage = KernelBufferUsage.TEMPORARY - - -def is_kernel_buffer_meta_derived(t: type) -> bool: - return isinstance(t, _KernelBufferMeta) diff --git a/core/shark_turbine/kernel/lang/prims.py b/core/shark_turbine/kernel/lang/prims.py deleted file mode 100644 index b9f163085..000000000 --- a/core/shark_turbine/kernel/lang/prims.py +++ /dev/null @@ -1,56 +0,0 @@ -from .. import ops - -from .._support.tracing import ( - BaseContext, - CompiledContext, - custom_primitive_fn, - eager_context, -) - -__all__ = [ - "is_debug", - "program_id", - "constant", - "exp2", - "max", - "sum", - "dot", - "for_loop", - "load", - "store", - "broadcast", - "broadcast_in_dim", - "transpose", - "to_dtype", -] - - -def is_debug() -> bool: - """Returns whether we are currently executing a kernel in eager debug mode.""" - return BaseContext.current().eager - - -# Core language operations -program_id = ops.thread_program_id -to_dtype = ops.to_dtype - -# Math Operations -exp2 = ops.exp2 -constant = ops.vector_constant - -# Reduction Operations -max = ops.vector_max -sum = ops.vector_sum -dot = ops.vector_dot - -# Control Flow Operations -for_loop = ops.for_loop - -# Memory Operations -load = ops.kernel_buffer_load -store = ops.kernel_buffer_store - -# Shape Manipulation operations -broadcast = ops.vector_broadcast -broadcast_in_dim = ops.vector_broadcast_in_dim -transpose = ops.vector_transpose diff --git a/core/shark_turbine/kernel/lang/types.py b/core/shark_turbine/kernel/lang/types.py deleted file mode 100644 index 1f42fdb11..000000000 --- a/core/shark_turbine/kernel/lang/types.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Type - -__all__ = [ - "Index", - "Vector", -] - -############################################################################### -# Index and specific sized integer types -############################################################################### - - -def _impl_fixed_int(t: Type[int]): - """Mixes in dunder functions for integer math to an `int` derived type. - - The result of the computation will be cast to type `t` before returning. - """ - t.__add__ = lambda a, b: t(super(t, a).__add__(b)) - t.__sub__ = lambda a, b: t(super(t, a).__sub__(b)) - t.__mul__ = lambda a, b: t(super(t, a).__mul__(b)) - t.__floordiv__ = lambda a, b: t(super(t, a).__floordiv__(b)) - t.__mod__ = lambda a, b: t(super(t, a).__mod__(b)) - t.__pow__ = lambda a, b, modulo=None: t(super(t, a).__pow__(b, modulo)) - t.__pos__ = lambda a: t(super(t, a).__pos__()) - t.__neg__ = lambda a: t(super(t, a).__neg__()) - return t - - -@_impl_fixed_int -class Index(int): - """An index type that is isomorphic to MLIR `index`. - - At the Python level, this is just an int. - """ - - ... - - -class Vector: - """A tensor like type that is isomorphic to MLIR `vector`. - - A vector has value semantics and allows computation over it. - """ - - # TODO: Implement operator overloading once math ops are added. - ... diff --git a/core/shark_turbine/kernel/ops/__init__.py b/core/shark_turbine/kernel/ops/__init__.py deleted file mode 100644 index c022248f2..000000000 --- a/core/shark_turbine/kernel/ops/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .core import * -from .math import * -from .reduction import * -from .control_flow import * -from .memory import * -from .shape_manipulation import * diff --git a/core/shark_turbine/kernel/ops/base.py b/core/shark_turbine/kernel/ops/base.py deleted file mode 100644 index a726c2546..000000000 --- a/core/shark_turbine/kernel/ops/base.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Support for defining the op library and dispatch.""" - -from typing import Callable, Type, TypeVar - -import functools - -from .._support import context - -T = TypeVar("T") - - -class OpDispatcher: - """Handles dispatch of operations by their idname. - - Operations are dispatched by looking up a function on the dispatcher like: - def handle_{idname}(self, operator, *args, **kwargs) - """ - - __tk_context_idname__ = "OpDispatcher" - - @staticmethod - def current() -> "OpDispatcher": - return context.current(OpDispatcher) - - def __enter__(self) -> "OpDispatcher": - return context.push(OpDispatcher, self) - - def __exit__(self, exc_type, exc_val, exc_tb): - context.pop(OpDispatcher, self) - - -def define_op(f: T) -> T: - idname = f.__name__ - - @functools.wraps(f) - def wrapped(*args, **kwargs): - dispatcher = OpDispatcher.current() - try: - handler = getattr(dispatcher, f"handle_{idname}") - except AttributeError: - raise AttributeError( - f"The current OpDispatcher ({dispatcher}) does not register a handler for {idname}" - ) - return handler(wrapped, *args, **kwargs) - - wrapped.__tk_op_idname__ = idname - return wrapped diff --git a/core/shark_turbine/kernel/ops/control_flow.py b/core/shark_turbine/kernel/ops/control_flow.py deleted file mode 100644 index beb94c772..000000000 --- a/core/shark_turbine/kernel/ops/control_flow.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import ( - Any, - List, - Tuple, - Optional, - Iterator, - overload, - Callable, - Tuple, -) -import typing - -if typing.TYPE_CHECKING: - from ..lang.types import Index - -from .base import ( - define_op, -) - -__all__ = [ - "for_loop", -] - - -@define_op -def for_loop( - start: "Index", - stop: Optional["Index"] = None, - step: Optional["Index"] = None, - init_args: List[Any] = [], -) -> Callable[[Callable[["Index", List[Any]], Optional[Tuple]]], List[Any]]: - # TODO: The output type signature should also allow a single element return - # instead of a List for better programming experience. - ... diff --git a/core/shark_turbine/kernel/ops/core.py b/core/shark_turbine/kernel/ops/core.py deleted file mode 100644 index 0121c69c0..000000000 --- a/core/shark_turbine/kernel/ops/core.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Any, TypeVar -import typing - -if typing.TYPE_CHECKING: - from ..lang.types import Index, Vector - -from .base import define_op -from .._support.dtype import DataType - -__all__ = [ - "kernel_buffer_getitem", - "kernel_buffer_setitem", - "thread_program_id", - "to_dtype", -] - - -@define_op -def kernel_buffer_getitem(kernel_buffer, key) -> "Vector": - ... - - -@define_op -def kernel_buffer_setitem(kernel_buffer, key, item) -> None: - ... - - -@define_op -def thread_program_id(axis: int) -> "Index": - ... - - -@define_op -def to_dtype(val, dtype: DataType): - ... diff --git a/core/shark_turbine/kernel/ops/math.py b/core/shark_turbine/kernel/ops/math.py deleted file mode 100644 index 0b617baa5..000000000 --- a/core/shark_turbine/kernel/ops/math.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Tuple -import typing - -if typing.TYPE_CHECKING: - from ..lang.types import Vector - -from .base import ( - define_op, -) - -__all__ = [ - "exp2", - "vector_constant", -] - - -@define_op -def exp2(val): - ... - - -@define_op -def vector_constant(shape: Tuple[int, ...], dtype, value: int | float) -> "Vector": - ... diff --git a/core/shark_turbine/kernel/ops/memory.py b/core/shark_turbine/kernel/ops/memory.py deleted file mode 100644 index a4cb073dc..000000000 --- a/core/shark_turbine/kernel/ops/memory.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import ( - Any, - List, - Tuple, - Optional, - Iterator, - overload, - Callable, - Tuple, -) -import typing - -if typing.TYPE_CHECKING: - from ..lang.types import Index, Vector - -from .base import ( - define_op, -) - -__all__ = ["kernel_buffer_load", "kernel_buffer_store"] - - -@define_op -def kernel_buffer_load( - kernel_buffer, - multi_index: Tuple["Index", ...], - shape: Tuple[int, ...], -) -> "Vector": - ... - - -@define_op -def kernel_buffer_store( - kernel_buffer, - multi_index: Tuple["Index", ...], - item: "Vector", -) -> None: - ... diff --git a/core/shark_turbine/kernel/ops/reduction.py b/core/shark_turbine/kernel/ops/reduction.py deleted file mode 100644 index 3a97057bb..000000000 --- a/core/shark_turbine/kernel/ops/reduction.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Any, List, Optional -import typing - -if typing.TYPE_CHECKING: - from ..lang.types import Vector - -from .base import ( - define_op, -) - -__all__ = [ - "vector_max", - "vector_sum", - "vector_dot", -] - - -@define_op -def vector_max(vector: "Vector", axis=None, acc=None) -> "Vector": - ... - - -@define_op -def vector_sum(vector: "Vector", axis=None, acc=None) -> "Vector": - ... - - -@define_op -def vector_dot(lhs: "Vector", rhs: "Vector", acc=None) -> "Vector": - ... diff --git a/core/shark_turbine/kernel/ops/shape_manipulation.py b/core/shark_turbine/kernel/ops/shape_manipulation.py deleted file mode 100644 index 9f7285bd3..000000000 --- a/core/shark_turbine/kernel/ops/shape_manipulation.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Tuple -import typing - -if typing.TYPE_CHECKING: - from ..lang.types import Vector - -from .base import ( - define_op, -) - -__all__ = [ - "vector_broadcast", - "vector_broadcast_in_dim", - "vector_transpose", -] - - -@define_op -def vector_broadcast(v: "Vector", leading_sizes: Tuple[int]) -> "Vector": - ... - - -@define_op -def vector_broadcast_in_dim( - v: "Vector", shape: Tuple[int], broadcast_dimensions: Tuple[int] -) -> "Vector": - ... - - -@define_op -def vector_transpose(v: "Vector", permutation: Tuple[int]) -> "Vector": - ... diff --git a/core/shark_turbine/ops/__init__.py b/core/shark_turbine/ops/__init__.py deleted file mode 100644 index 3f4a4554e..000000000 --- a/core/shark_turbine/ops/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from . import iree diff --git a/core/shark_turbine/ops/iree.py b/core/shark_turbine/ops/iree.py deleted file mode 100644 index e28826db8..000000000 --- a/core/shark_turbine/ops/iree.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Custom ops for built-in IREE functionality.""" -from typing import cast - -from ..support.ir_imports import ( - RankedTensorType, - StringAttr, - Value, - flow_d, - tensor_d, -) - -from ..runtime.op_reg import ( - CustomOp, - KernelBuilder, - KernelSelection, - AttrArg, - def_library, -) - -__all__ = [ - "trace", -] - -IREE_LIBRARY = def_library("iree") - - -################################################################################ -# trace_tensor / trace_tensors -# See the flow.tensor_trace op for details. In essence: -# * trace_key is a name to label tensors with (intended for log filtering) -# * tensor or tensors are values to log a value for -################################################################################ - - -def _emit_tensor_trace(kb: KernelBuilder, key: str, ts: list[Value]): - dynamic_dims = [] - for t in ts: - rtt = RankedTensorType(t.type) - for i in range(rtt.rank): - if rtt.is_dynamic_dim(i): - dynamic_dims.append(tensor_d.dim(t, kb.constant_index(i))) - flow_d.TensorTraceOp(StringAttr.get(key), ts, dynamic_dims) - - -@CustomOp.register(library=IREE_LIBRARY) -class trace_tensor(CustomOp): - signature = "trace_tensor(str trace_key, Tensor(a!) tensor) -> ()" - - def select(self, ksel: KernelSelection): - ksel.attr_str(0) - ksel.arg_tensor(1, inplace_tied=True) - - def generate(self, ksel: KernelSelection, kb: KernelBuilder): - key = cast(AttrArg, ksel.arg_descs[0]) - _emit_tensor_trace(kb, cast(str, key.v), [kb.arg_bindings[1]]) - kb.yield_results(kb.arg_bindings[1]) diff --git a/core/shark_turbine/runtime/__init__.py b/core/shark_turbine/runtime/__init__.py deleted file mode 100644 index 29434c268..000000000 --- a/core/shark_turbine/runtime/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from .device import * -from . import op_reg diff --git a/core/shark_turbine/runtime/device.py b/core/shark_turbine/runtime/device.py deleted file mode 100644 index a294525df..000000000 --- a/core/shark_turbine/runtime/device.py +++ /dev/null @@ -1,380 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from functools import lru_cache -from typing import Callable, Optional, Union -from threading import local, Lock - -import torch - -from iree.runtime import ( - BufferUsage, - HalBufferView, - HalDevice, - HalDriver, - MemoryType, - VmInstance, - VmModule, - create_hal_module, - get_driver, -) - -from ..support.conversions import ( - dtype_to_element_type, - torch_dtype_to_numpy, -) - -from ..support.exceptions import ( - NoCurrentDeviceError, - MismatchedDeviceSetClearError, - UnsupportedTorchDeviceError, -) - -from ..support.logging import runtime_logger as logger - -__all__ = [ - "get_vm_instance", - "Device", - "DeviceState", -] - -_CONFIG_LOCK = Lock() -_GLOBAL_VM_INSTANCE: Optional[VmInstance] = None -_CURRENT_THREAD = local() - -############################################################################### -# DeviceState ande Device classes. -# These associated shared VmInstance and HalDrivers with a concrete HalDevice. -# The Device class also adds other accounting needed for interop in PyTorch's -# eager environment (i.e. transfer and compute queue counters, etc). -############################################################################### - - -def get_vm_instance() -> VmInstance: - global _GLOBAL_VM_INSTANCE - if not _GLOBAL_VM_INSTANCE: - with _CONFIG_LOCK: - if not _GLOBAL_VM_INSTANCE: - _GLOBAL_VM_INSTANCE = VmInstance() - return _GLOBAL_VM_INSTANCE - - -class DeviceState: - """State for an instantiated HAL device. - - Note that the IREE runtime internally manages a global cache of drivers for - standard named-access (not custom-constructed) drivers. - """ - - __slots__ = [ - "device", - "driver", - "instance", - ] - - def __init__( - self, - *, - driver: Union[str, HalDriver], - device: Optional[HalDevice] = None, - vm_instance: Optional[VmInstance] = None, - ): - self.instance = vm_instance or get_vm_instance() - self.driver = driver if isinstance(driver, HalDriver) else get_driver(driver) - self.device = device if device else self.driver.create_default_device() - - @staticmethod - @lru_cache(maxsize=None) - def from_uri(uri: str) -> "DeviceState": - driver = get_driver(uri) - return DeviceState(driver=driver, device=driver.create_device_by_uri(uri)) - - -class Device: - """Represents a low-level device (HalDriver/HalDevice) and scheduling data. - - This is the type that user's interact with as a 'Device'. Devices can be handled - loose-leaf or bound to a thread with a context manager. - """ - - __slots__ = [ - "_s", - "_main_timeline", - "_main_timepoint", - "_tx_timeline", - "_tx_timepoint", - "_fence_capacity", - "compile_target_flags", - "export_torch_tensor", - "import_torch_tensor", - "instance_cache_key", - "type_cache_key", - ] - - _s: DeviceState - - # Each device will have a function attached to import a torch.tensor - # *that is already on that device* directly from device memory. - # This is unsafe and relatively unchecked. If criss-crossing devices, - # it is undefined behavior. - import_torch_tensor: Callable[[torch.Tensor], HalBufferView] - - # Devices can also export a torch tensor from a HalBufferView, given - # a meta tensor that describes it. - export_torch_tensor: Callable[[HalBufferView, torch.Tensor], torch.Tensor] - - # Cache key that uniquely identifies this device. - instance_cache_key: str - - # Cache key that uniquely identifies this type of device (currently - # based on its driver). - type_cache_key: str - - # Compiler flags to use to target this device. - # TODO: We should replace this with a target attribute but need an API - # to derive that. - compile_target_flags: tuple[str, ...] - - def __new__( - cls, uri: Optional[str] = None, *, device_state: Optional[DeviceState] = None - ): - if uri is not None: - # Construction by URI is cached on the thread. - assert not device_state, "device_state= cannot be given with explicit URI" - try: - existing = _CURRENT_THREAD.device_by_uri[uri] - except (AttributeError, KeyError): - ... - else: - return existing - - # New instance. - device_state = DeviceState.from_uri(uri) - new_inst = super().__new__(cls) - new_inst._s = device_state - try: - _CURRENT_THREAD.device_by_uri[uri] = new_inst - except AttributeError: - _CURRENT_THREAD.device_by_uri = {uri: new_inst} - new_inst._initialize() - return new_inst - else: - # Explicit construction with a device_state is assumed that you know what you - # are doing and an uncached instance will be returned. This will be unsychronized - # relative to any cached instance. - assert device_state, "device_state= must be given if URI ommitted" - new_inst = super().__new__(cls) - new_inst._s = device_state - new_inst._initialize() - return new_inst - - def _initialize(self): - d = self._s.device - self._main_timeline = d.create_semaphore(0) - self._main_timepoint = 0 - self._tx_timeline = d.create_semaphore(0) - self._tx_timepoint = 0 - # Maximum number of semaphores the device uses. Can be increased if doing out of the - # ordinary scheduling. - self._fence_capacity = 2 - - # Perform driver specific augmentations. - # TODO: Add a HalDriver.id property to get the driver name instead of parsing - # the device repr. - driver_id = repr(d) - colon_pos = driver_id.find(":") - if colon_pos >= 0: - driver_id = driver_id[0:colon_pos] - try: - import_fn = TORCH_TENSOR_IMPORTERS[driver_id] - export_fn = TORCH_TENSOR_EXPORTERS[driver_id] - self.import_torch_tensor = lambda t: import_fn(self, t) - self.export_torch_tensor = lambda bv, t: export_fn(self, bv, t) - self.compile_target_flags = DEVICE_TARGET_COMPILE_FLAGS[driver_id] - except KeyError as e: - raise AssertionError( - f"Unsupported TORCH_TENSOR_IMPORTERS for iree driver '{driver_id}'" - ) from e - - # Cache keys. - # TODO: The type cache key should actually be based on the driver id - # and device characteristics hash. - self.instance_cache_key = repr(d) - self.type_cache_key = driver_id - - @property - def hal_device(self) -> HalDevice: - return self._s.device - - @property - def vm_instance(self) -> VmInstance: - return self._s.instance - - def create_hal_module(self) -> VmModule: - s = self._s - return create_hal_module(s.instance, s.device) - - @staticmethod - def current() -> "Device": - try: - return _CURRENT_THREAD.stack[-1] - except (AttributeError, IndexError): - raise NoCurrentDeviceError() - - def set(self) -> "Device": - """Sets this device as the current device without a context manager.""" - try: - _CURRENT_THREAD.stack.append(self) - except AttributeError: - _CURRENT_THREAD.stack = [self] - return self - - def clear(self): - """Clears the current device without a context manager.""" - try: - c = _CURRENT_THREAD.stack[-1] - if _CURRENT_THREAD.stack[-1] is self: - _CURRENT_THREAD.stack.pop() - return - except (AttributeError, IndexError): - ... - raise MismatchedDeviceSetClearError() - - def __repr__(self): - return f"" - - def __enter__(self): - try: - _CURRENT_THREAD.stack.append(self) - except AttributeError: - _CURRENT_THREAD.stack = [self] - - def __exit__(self, type, value, traceback): - _CURRENT_THREAD.stack.pop() - - -def _device_import_torch_tensor_cpu(device: Device, t: torch.Tensor) -> HalBufferView: - hal_device = device.hal_device - element_type = dtype_to_element_type(t.dtype) - # TODO: In this case, we should be importing the raw buffer, but this is not - # generically exposed to Python in the IREE runtime. - bv = device.hal_device.allocator.allocate_buffer_copy( - memory_type=MemoryType.DEVICE_LOCAL, - allowed_usage=BufferUsage.DEFAULT, - device=hal_device, - buffer=t.detach().numpy(), - element_type=element_type, - ) - return bv - - -def _device_export_torch_tensor_cpu( - device: Device, bv: HalBufferView, like: torch.Tensor -) -> torch.Tensor: - # TODO: Similar to import, we know that the buffer is in local CPU memory - # and could export it if we had Python API support for that. Until we have - # that, we do this very torturous indirection. - mapped_memory = bv.map() - shape = list(like.shape) - np_dtype = torch_dtype_to_numpy(like.dtype) - mapped_array = mapped_memory.asarray(shape, np_dtype) - return torch.from_numpy(mapped_array) - - -# Mapping of torch tensor importers keyed by driver name. -TORCH_TENSOR_IMPORTERS: dict[str, Callable[[Device, torch.Tensor], HalBufferView]] = { - "local-sync": _device_import_torch_tensor_cpu, - "local-task": _device_import_torch_tensor_cpu, -} - -TORCH_TENSOR_EXPORTERS: dict[ - str, Callable[[Device, HalBufferView, torch.Tensor], torch.Tensor] -] = { - "local-sync": _device_export_torch_tensor_cpu, - "local-task": _device_export_torch_tensor_cpu, -} - -DEVICE_TARGET_COMPILE_FLAGS: dict[str, tuple[str, ...]] = { - "local-task": ( - "--iree-hal-target-backends=llvm-cpu", - "--iree-llvmcpu-target-cpu-features=host", - ), -} - -# Aliases. -DEVICE_TARGET_COMPILE_FLAGS["local-sync"] = DEVICE_TARGET_COMPILE_FLAGS["local-task"] - -# Make sure all tables have the same keys. -assert ( - TORCH_TENSOR_IMPORTERS.keys() == DEVICE_TARGET_COMPILE_FLAGS.keys() -), "Not all devices have the same configs" - -assert ( - TORCH_TENSOR_IMPORTERS.keys() == TORCH_TENSOR_EXPORTERS.keys() -), "Not all devices have the same configs" - -############################################################################### -# torch.device to Device mapping -############################################################################### - - -def lookup_device_from_torch( - torch_device: torch.device, *, create: bool = True -) -> Optional[Device]: - """Gets a shared Device corresponding to the given torch.device. - - This will return None if the device is wholly unsupported or if - create=False. Otherwise, faults in setting up the device are - reported as an appropriate exception. - """ - try: - mapping = _CURRENT_THREAD.device_by_torch_device - except AttributeError: - _CURRENT_THREAD.device_by_torch_device = mapping = {} - device = mapping.get(torch_device) - if device is not None or not create: - return device - logger.debug("Creating turbine device for torch.device = %r", torch_device) - device = _create_device_from_torch(torch_device) - if device is not None: - mapping[torch_device] = device - return device - - -def get_device_from_torch(torch_device: torch.device) -> Device: - """Gets a shared Device corresponding to the given torch.device. - - Raises an exception if the device cannot be created. - """ - device = lookup_device_from_torch(torch_device) - if device is None: - raise UnsupportedTorchDeviceError(torch_device) - return device - - -def _create_device_from_torch(torch_device: torch.device) -> Optional[Device]: - torch_type = torch_device.type - uri = None - if torch_type == "cpu": - uri = "local-task" - - if uri is None: - return None - - return Device(uri) - - -############################################################################### -# Utilities -############################################################################### - -# The nanobind leak checker doesn't interop well with the way that -# global state is managed for PyTorch. It isn't clear that this -# is a fully correctable state of affairs, so we just disable it -# for now. RIP nice things :( -from iree.runtime._binding import disable_leak_checker - -disable_leak_checker() diff --git a/core/shark_turbine/runtime/op_reg/__init__.py b/core/shark_turbine/runtime/op_reg/__init__.py deleted file mode 100644 index c18790d33..000000000 --- a/core/shark_turbine/runtime/op_reg/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from .base import * diff --git a/core/shark_turbine/runtime/op_reg/base.py b/core/shark_turbine/runtime/op_reg/base.py deleted file mode 100644 index 2fcf7863f..000000000 --- a/core/shark_turbine/runtime/op_reg/base.py +++ /dev/null @@ -1,873 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Base classes for registering custom operations with the PyTorch -dispatcher. -""" - -from typing import Any, Callable, List, Optional, Sequence, Type, Union, cast - -from abc import ABC, abstractmethod -import functools -import logging -import re -import textwrap -import threading - -import torch -from torch import Tensor - -from ...support.ir_imports import ( - Block, - Context, - FunctionType, - IndexType, - InsertionPoint, - IntegerAttr, - Location, - StringAttr, - SymbolTable, - IrType, - Value, - arith_d, - builtin_d, - func_d, -) - -from ...support.logging import runtime_logger as logger - -from ...support.conversions import ( - TORCH_DTYPE_TO_IREE_TYPE_ASM, -) - -__all__ = [ - "ArgDescriptor", - "AttrArg", - "CustomOp", - "FreeFuncKernelBuilder", - "IntArg", - "KernelBuilder", - "KernelSelection", - "TensorArg", - "def_library", -] - - -############################################################################### -# Op library management -############################################################################### - -_CONFIG_LOCK = threading.Lock() - - -def def_library(ns) -> torch.library.Library: - """Creates a new 'DEF' library which contains custom ops. - - It is necessary to create such custom op libraries in this way since - the library is registered with the compiler in such a way that it can - operate over all known custom ops. - """ - return torch.library.Library(ns, "DEF") - - -def default_dispatch_keys() -> list[str]: - # TODO: Dynamically determine what devices to register against. - # Note that we have to register against specific keys instead of the - # fallback, as fallback is too broad and breaks certain elements of - # fx tracing. - return ["CPU"] - - -# All such custom kernels are registered in the 'turbine' library/namespace. -# We also allow extending existing libraries outside of this, but that is -# the non default case. -TURBINE_LIBRARY = def_library("turbine") - -# Set of all programmatically registered op names in libraries we manage. -# This is used to detect name collisions eagerly and providing name uniqueing. -# Keys are (Library.ns, name) -DEFINED_OP_NAMES: set[tuple[str, str]] = set() - -# Mapping of (Library.ns, name_spec) to an integer counter used to unique it. -UNIQUE_OP_NAME_COUNTER: dict[tuple[str, str], int] = {} - - -class CustomOp(ABC): - """Users subclass this in order to register a turbine custom op.""" - - @staticmethod - def register( - op_class: Optional[Type["CustomOp"]] = None, - *, - library: torch.library.Library = TURBINE_LIBRARY, - dispatch_key: Union[str, Sequence[str], None] = None, - register_meta: bool = True, - register_impl: bool = True, - ) -> Callable: - """Class decorator for `CustomOp` implementations. - - The decorator will instantiate the class and then replace it with - the callable operation that can be used to invoke the kernel. - - Typical usage: - - ``` - @CustomOp.register - class identity(CustomOp): - ... - - result = identity(torch.tensor(1, 2, 3)) - ``` - """ - if not op_class: - return functools.partial( - CustomOp.register, - library=library, - dispatch_key=dispatch_key, - register_meta=register_meta, - register_impl=register_impl, - ) - instance = op_class( - library=library, - dispatch_key=dispatch_key, - register_meta=register_meta, - register_impl=register_impl, - ) - return instance.op - - def __init__( - self, - *, - library: torch.library.Library, - dispatch_key: Union[str, Sequence[str], None], - register_meta: bool, - register_impl: bool, - ): - self.name = name = _define_signature_in_library(library, self.signature) - self.library = library - self.cache_key_base = f"{library.ns}.{library.kind}::{name}" - self.op = _get_library_op(library, name) - - # The meta kernel can be provided by the selection machinery and - # does not require a tie-in to the kernel generator, which layers - # on top. - if register_meta: - library.impl(name, _get_meta_impl(self), "Meta") - - if register_impl: - if dispatch_key is None: - dispatch_key = default_dispatch_keys() - elif isinstance(dispatch_key, str): - dispatch_key = [dispatch_key] - for k in dispatch_key: - library.impl(name, _create_impl_trampoline(self), k) - - fq_name = f"{library.ns}.{name}" - ALL_CUSTOM_OP_REGS[fq_name] = self - - @property - @abstractmethod - def signature(self) -> str: - """PyTorch function signature. - - This is in the normal PyTorch kernel registration form. For example: - - ``` - my_op(Tensor t) -> Tensor - ``` - - The signature can have some special tokens in the name part: - - * "@UNIQUE@": Generates a name-specific numeric value and replaces it. - """ - ... - - @abstractmethod - def select(self, sel: "KernelSelection"): - """Performs kernel selection. - - This method has three purposes: - - 1. Selects which kernel specialization is needed based on - arguments. - 2. Returns the meta tensor results of the operation, effectively - completing the transfer function from argument types to - result types. - 3. Sets additional metadata that the generate method can use. - - The `device="meta"` kernel implementation is composed completely by - invoking `select`. For implementation devices, `select` is called - for each invocation. The `generate` will be called subsequently if - the kernel needs to be generated. - """ - ... - - @abstractmethod - def generate(self, ksel: "KernelSelection", kb: "KernelBuilder"): - """Generates a kernel based on the `KernelSelection`. - - This method should generate IR into the given `KernelBuilder`. It - can do so by consulting any state set on the `KernelSelection`. - Each `KernelSelection.args` corresponds to `KernelBuilder.args`. - Unless if the argument was set as `ir_arity=0`, the argument - will be a `Value`. Otherwise, it will be `None`. It is recommended - to use `KernelBuilder.arg(n)` to access. - - Generation should conclude with a call to `KernelBuilder.yield_results`. - """ - ... - - -# All instantiated CustomOp instances, keyed by fully qualified name. This is -# used by the AOT compiler to expand custom ops that were captured in a trace. -ALL_CUSTOM_OP_REGS: dict[str, CustomOp] = {} - - -class KernelSelection(ABC): - """Represents a selected kernel based on a concrete signature. - - The `CustomOp.select` method must yield an instance of this, and - it will be done for every invocation. At this point, the kernel - has not yet been generated, but we have selected a generation - strategy based on a concrete signature. - - This mechanism also serves as the means for servicing `meta` - registrations because it implicitly computes everything needed - (i.e. shapes, etc). - """ - - __slots__ = [ - "arg_descs", - "inplace_tied_arg_descs", - "op", - "result_descs", - "variant", - ] - - def __init__(self, op: CustomOp, arg_arity: int): - self.op = op - self.arg_descs = cast(list[Optional[ArgDescriptor]], arg_arity * [None]) - self.inplace_tied_arg_descs: list[ArgDescriptor] = [] - self.result_descs: list[ArgDescriptor] = [] - self.variant: str = "default" - - def __repr__(self): - lines = [ - "KernelSelection<", - f" op = '{self.op.name}',", - f" variant = '{self.variant}',", - " arg_descs = [", - ] - for arg_desc in self.arg_descs: - lines.append(f" {arg_desc},") - lines.append(" ],") - lines.append(" result_descs = [") - for result_desc in self.result_descs: - lines.append(f" {result_desc},") - lines.append(" ]") - lines.append(">") - return "\n".join(lines) - - def generate_meta_returns(self) -> Any: - results = [d.generate_meta() for d in self.result_descs] - arity = len(results) - if arity == 1: - return results[0] - elif arity == 0: - return None - else: - return tuple(results) - - @property - def spec_key(self) -> str: - try: - arg_keys = ",".join( - d.spec_key if d is not None else "None" for d in self.arg_descs - ) - return_keys = ",".join( - d.spec_key if d is not None else "None" for d in self.result_descs - ) - return ( - f"{self.op.cache_key_base}::{self.variant}({arg_keys})->({return_keys})" - ) - except Exception as e: - raise AssertionError( - f"Error generating spec_key from:\n{textwrap.indent(repr(self), ' ')}" - ) from e - - @abstractmethod - def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> "TensorArg": - """Declares an argument to allow any ranked tensor and to specialize for each rank - and dtype. - - Returns the argument descriptor, which can be used to further inspect or constrain - the selection. It will default to allowing all dimensions to be dynamic. - - If inplace_tied is True, then this argument participates in in-place - semantics. The kernel must yield the result-mutated after all normal - results in the order declared. - """ - ... - - @abstractmethod - def arg_tensor_list(self, arg: int) -> "TensorListArg": - """Declares an argument to accept a list of tensors which will be specialized - for the list size and each rank/dtype. - - Returns the argument descriptor, which can be used to further inspect or constrain - the selection. It will default to allowing all dimensions to be dynamic. - """ - ... - - @abstractmethod - def arg_int(self, arg: int) -> "IntArg": - """Declares an argument to be an integer value that can take any value. - - Returns the argument descriptor, which can be used to further inspect or constrain - the selection. - """ - ... - - @abstractmethod - def attr_str(self, arg: int) -> "AttrArg": - """Declares an argument to be a string attribute. - - Such arguments are not materialized in the IR as Values but may be used to - generate the IR. In AOT contexts, they must be derived from static values. - """ - ... - - @abstractmethod - def return_tensor(self, t: Tensor) -> "TensorArg": - """Marks the next return value as a Tensor. - - By default, it will be rank and dtype specialized but have completely dynamic - dimensions. Dimensions can be further constrained by modifying the returned - descriptor. - """ - ... - - def return_new_tensor(self, size: list, dtype: torch.dtype) -> "TensorArg": - """Constructs a new symbolic tensor and marks the next result as returning it. - - This delegates to `return_tensor` but takes care of some easy to mess - up boiler plate for dynamic shapes. - """ - return self.return_tensor(torch.empty(size, dtype=dtype, device="meta")) - - -class EagerKernelSelection(KernelSelection): - """Kernel selection specialized for eager arguments.""" - - __slots__ = [ - "args", - ] - - def __init__(self, op: CustomOp, args: list[Any]): - super().__init__(op, len(args)) - self.args = args - - def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> "TensorArg": - arg_descs = self.arg_descs - arg_value = self.args[arg] - assert arg_descs[arg] is None, f"Already constrained argument {arg}" - assert isinstance( - arg_value, Tensor - ), f"Argument type mismatch from Torch for {arg}: Expected tensor, got {type(arg_value)}" - arg_descs[arg] = desc = TensorArg(arg_value) - if inplace_tied: - self.inplace_tied_arg_descs.append(desc) - return desc - - def arg_tensor_list(self, arg: int) -> "TensorListArg": - arg_descs = self.arg_descs - arg_value = self.args[arg] - assert arg_descs[arg] is None, f"Already constrained argument {arg}" - assert isinstance( - arg_value, list - ), f"Argument type mismatch from Torch for {arg}: Expected list, got {type(arg_value)}" - arg_descs[arg] = desc = TensorListArg(arg_value) - return desc - - def arg_int(self, arg: int) -> "IntArg": - arg_descs = self.arg_descs - arg_value = self.args[arg] - assert arg_descs[arg] is None, f"Already constrained argument {arg}" - assert isinstance( - arg_value, int - ), f"Argument type mismatch from Torch for {arg}: Expected int, got {type(arg_value)}" - arg_descs[arg] = desc = IntArg(arg_value) - return desc - - def attr_str(self, arg: int) -> "AttrArg": - arg_descs = self.arg_descs - arg_value = self.args[arg] - assert arg_descs[arg] is None, f"Already constrained argument {arg}" - assert isinstance( - arg_value, str - ), f"Argument type mismatch from Torch for {arg}: Expected int, got {type(arg_value)}" - arg_descs[arg] = desc = AttrArg(arg_value) - return desc - - def return_tensor(self, t: Tensor) -> "TensorArg": - desc = TensorArg(t) - self.result_descs.append(desc) - return desc - - -class AttrArg: - ir_arity: int = 0 - maybe_tensor_value: Optional[Tensor] = None - is_list: bool = False - - __slots__ = [ - "v", - "spec_value", - ] - - def __init__(self, v: object): - self.v = v - # We specialize on every distinct value. - self.spec_value: Optional[Any] = v - - def __repr__(self): - return f"AttrArg(<{self.spec_value}>)" - - def generate_meta(self) -> object: - return self.v - - @property - def spec_key(self) -> str: - """Generates a key that will be the same for all specializations.""" - return f"attr<{self.spec_value}>" - - @property - def mlir_type_asm(self) -> str: - raise AssertionError("Cannot resolve `mlir_type_asm` for an AttrArg") - - -class IntArg: - __slots__ = [ - "ir_arity", - "spec_value", - "v", - ] - - # All descriptors have an attribute to indicate their value - # as a tensor, and those that aren't are fixated to None. - # This is to enable fast lookup in the hot path of determining - # how to dispatch. - maybe_tensor_value: Optional[Tensor] = None - is_list: bool = False - - def __init__(self, v: int): - self.v = v - self.spec_value: Optional[Any] = None - self.ir_arity: int = 1 - - def __repr__(self): - return f"IntArg({self.v}, spec_value={self.spec_value}, is_ir_arg={self.is_ir_arg})" - - def generate_meta(self) -> int: - return self.v - - @property - def spec_key(self) -> str: - """Generates a key that will be the same for all specializations.""" - return f"int<{self.spec_value}>" - - @property - def mlir_type_asm(self) -> str: - # TODO: We can have individual kernels constrain this to a narrower - # type. - return "i64" - - -_NoneInt: Optional[int] = None - - -class TensorArg: - __slots__ = [ - "t", - "spec_dims", - "maybe_tensor_value", - ] - - ir_arity: int = 1 - is_list: bool = False - - def __init__(self, t: Tensor): - self.t = t - # Any static dims that we are specializing. Defaults to all dynamic. - self.spec_dims = len(t.shape) * [_NoneInt] - # All descriptors have an attribute to indicate their value - # as a tensor, and those that aren't are fixated to None. - # This is to enable fast lookup in the hot path of determining - # how to dispatch. - self.maybe_tensor_value: Tensor = t - - def specialize_all_dims(self): - """Marks all dimensions as specialized.""" - self.spec_dims = list(self.t.shape) - - def specialize_dims(self, *indices: int): - """Specializes individual dimensions. - - `i` can have negative indexing. - """ - for i in indices: - self.spec_dims[i] = self.t.size(i) - - def __repr__(self): - return ( - f"TensorArg(shape={self.t.shape}, dtype={self.t.dtype}, " - f"spec_dims={self.spec_dims})" - ) - - def generate_meta(self) -> Tensor: - t = self.t - if t.device == "meta": - return t - else: - return t.clone().detach().to("meta") - - @property - def spec_key(self) -> str: - """Generates a key that will be the same for all specializations.""" - t = self.t - return f"tensor[{len(t.shape)}:{str(t.dtype)}]<{self.spec_dims}>" - - @property - def mlir_type_asm(self) -> str: - t = self.t - try: - dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[t.dtype] - except KeyError as e: - raise KeyError( - f"Unknown mapping of torch dtype {t.dtype} to MLIR " - f"(possibly missing in TORCH_DTYPE_TO_IREE_TYPE_ASM table)" - ) from e - dim_asm = "x".join(["?" if d is None else str(d) for d in self.spec_dims]) - spec = f"{dim_asm}x{dtype_asm}" if dim_asm else dtype_asm - return f"tensor<{spec}>" - - -class TensorListArg: - __slots__ = [ - "ts", - "spec_dims", - "ir_arity", - "maybe_tensor_value", - ] - - is_list: bool = True - - def __init__(self, ts: list[Tensor]): - self.ts = ts - self.ir_arity = len(ts) - # Any static dims that we are specializing. Defaults to all dynamic. - self.spec_dims: list[list[Optional[int]]] = [len(t.shape) * [None] for t in ts] # type: ignore - # All descriptors have an attribute to indicate their value - # as a tensor, and those that aren't are fixated to None. - # This is to enable fast lookup in the hot path of determining - # how to dispatch. - self.maybe_tensor_value: list[Tensor] = ts - - def __repr__(self): - return ( - f"TensorListArg(shape={[t.shape for t in self.ts]}, " - f"dtype={[t.dtype for t in self.ts]}, " - f"spec_dims={self.spec_dims}, ir_arity={self.ir_arity})" - ) - - def generate_meta(self) -> list[Tensor]: - metas = [] - for t in self.ts: - if t.device == "meta": - metas.append(t) - else: - metas.append(t.clone().detach().to("meta")) - return metas - - @property - def spec_key(self) -> str: - """Generates a key that will be the same for all specializations.""" - return ( - f"tensor[{[len(t.shape) for t in self.ts]}" - f":{[str(t.dtype) for t in self.ts]}]<{self.spec_dims}>" - ) - - @property - def mlir_type_asm(self) -> list[str]: - asms = [] - for t, spec_dims in zip(self.ts, self.spec_dims): - try: - dtype_asm = TORCH_DTYPE_TO_IREE_TYPE_ASM[t.dtype] - except KeyError as e: - raise KeyError( - f"Unknown mapping of torch dtype {t.dtype} to MLIR " - f"(possibly missing in TORCH_DTYPE_TO_IREE_TYPE_ASM table)" - ) from e - dim_asm = "x".join(["?" if d is None else str(d) for d in spec_dims]) - spec = f"{dim_asm}x{dtype_asm}" if dim_asm else dtype_asm - asms.append(f"tensor<{spec}>") - return asms - - -ArgDescriptor = Union[AttrArg, IntArg, TensorArg, TensorListArg] - -############################################################################### -# KernelBuilder -# Helper object for constructing IR -############################################################################### - - -class KernelBuilder(ABC): - """Support class for building a kernel.""" - - def __init__( - self, - ksel: KernelSelection, - arg_bindings: list[Union[Value, list[Value]]], - *, - ip: InsertionPoint, - module_body: Block, - symbol_table: SymbolTable, - ): - self.ksel = ksel - self.arg_bindings = arg_bindings - self.ip = ip - self.module_body = module_body - self.context = module_body.owner.context - self.symbol_table = symbol_table - self.yielded = False - - def arg_value(self, index: int) -> Union[list[Value], Value]: - """Gets the concrete IR `Value` for the argument at `index`. - - This will assert if the corresponding argument was set as `ir_arity=0` - during kernel selection. - """ - try: - v = self.arg_bindings[index] - except IndexError as e: - raise AssertionError( - f"Out of range access to kernel arg. Expected 0..{len(self.arg_bindings)}. Got {index}" - ) from e - assert ( - v is not None - ), f"No `Value` is available for arg {index}: it was marked as `is_ir_arg=False` during kernel selection." - return v - - @abstractmethod - def yield_results(self, *results: Value): - """Yields results of the kernel computation.""" - ... - - def constant_index(self, i: int) -> Value: - """Builds a constant index value.""" - return arith_d.constant(IndexType.get(), IntegerAttr.get(IndexType.get(), i)) - - -class FreeFuncKernelBuilder(KernelBuilder): - """Kernel builder that emits the body of the kernel into a free function. - - This is intended to be used when compiling a standalone module that will - be directly invoked by the runtime. Further variants exist that generate - into a func but also emit a call into another local context. - """ - - def __init__( - self, - ksel: KernelSelection, - *, - module_body: Block, - symbol_table: SymbolTable, - func_name: Optional[str] = None, - is_public: bool = True, - ): - self.module_op = module_body.owner - context = self.module_op.context - if func_name is None: - func_name = ksel.op.name - with context, Location.unknown(), InsertionPoint(module_body): - # Assemble arg types. - arg_types = [] - for d in ksel.arg_descs: - assert d is not None, "NYI: None arguments" - arity = d.ir_arity - if not d.is_list: - if arity == 1: - arg_types.append(IrType.parse(d.mlir_type_asm)) - else: - continue - else: - for i in range(arity): - arg_types.append(IrType.parse(d.mlir_type_asm[i])) - - # Assemble result types. - result_types = [] - for d in (*ksel.result_descs, *ksel.inplace_tied_arg_descs): - if not d.is_list: - if d.ir_arity == 1: - result_types.append(IrType.parse(d.mlir_type_asm)) - else: - continue - else: - raise AssertionError("NYI: arity > 1 results") - - # Create the func. - ftype = FunctionType.get(arg_types, result_types) - func_op = func_d.FuncOp(func_name, ftype) - if not is_public: - func_op.attributes["sym_visibility"] = StringAttr.get("private") - entry_block: Block = func_op.add_entry_block() - symbol_table.insert(func_op) - - # Map inputs to arg bindings, lining up with arguments that are elided. - block_arguments = list(entry_block.arguments) - block_arg_index = 0 - arg_bindings: list[Optional[Value]] = [] - for desc in ksel.arg_descs: - assert desc is not None, "NYI: None arguments" - arity = desc.ir_arity - if not desc.is_list: - if arity == 1: - arg_bindings.append(block_arguments[block_arg_index]) - block_arg_index += 1 - else: - arg_bindings.append(None) - else: - arg_bindings.append( - block_arguments[block_arg_index : block_arg_index + arity] - ) - block_arg_index += arity - - super().__init__( - ksel, - arg_bindings, - ip=InsertionPoint(entry_block), - module_body=module_body, - symbol_table=symbol_table, - ) - - @staticmethod - def create_module( - ksel: KernelSelection, - *, - context: Optional[Context] = None, - func_name: Optional[str] = None, - is_public: bool = True, - ) -> "FreeFuncKernelBuilder": - """Short-cut to create a new module with a single function in one shot.""" - if context is None: - context = Context() - with context, Location.unknown(): - module_op = builtin_d.ModuleOp() - return FreeFuncKernelBuilder( - ksel, - module_body=module_op.body, - symbol_table=SymbolTable(module_op), - func_name=func_name, - is_public=is_public, - ) - - def yield_results(self, *results: Value): - """Yields results of the kernel computation.""" - assert not self.yielded, "yield_results has already been called" - ksel = self.ksel - expected_count = len(ksel.result_descs) + len(ksel.inplace_tied_arg_descs) - assert ( - len(results) == expected_count - ), f"Mismatched yielded results and declared+inplace: Expected={expected_count}, Got={len(results)}" - with self.ip, Location.unknown(): - func_d.ReturnOp(results) - self.yielded = True - - -############################################################################### -# Private utilities -############################################################################### - - -def _get_library_op(library: torch.library.Library, name: str) -> Any: - ns = getattr(torch.ops, library.ns) - return getattr(ns, name) - - -def _get_meta_impl(op: CustomOp): - def meta(*args): - sel = EagerKernelSelection(op, args) - op.select(sel) - if logger.isEnabledFor(logging.DEBUG): - logging.debug( - "Meta dispatch on %s for specialization %s", op.name, sel.spec_key - ) - return sel.generate_meta_returns() - - return meta - - -def _create_impl_trampoline(op: CustomOp): - # Import lazily when an implementation trampoline is requested to avoid - # circular dependency between base objects and eager runtime goo. - from .eager import ( - eager_dispatch, - ) - - def handler(*args): - ksel = EagerKernelSelection(op, args) - op.select(ksel) - if logger.isEnabledFor(logging.DEBUG): - logging.debug( - "Dispatch on %s for specialization %s", op.name, ksel.spec_key - ) - return eager_dispatch(ksel) - - return handler - - -def _define_signature_in_library(lib: torch.library.Library, signature: str) -> str: - """Helper to define a schema in the library. - - This handles the interlocked process of uniqueing, reserving the name, - and calling `lib.define` on the resulting schema. - """ - ns = lib.ns - with _CONFIG_LOCK: - name, call_args = _split_signature(signature) - - # Unique the name. - if "@UNIQUE@" in name: - # Uniqueify. - unique_key = (ns, name) - counter = UNIQUE_OP_NAME_COUNTER.get(unique_key, 0) - counter += 1 - name = name.replace("@UNIQUE@", str(counter)) - UNIQUE_OP_NAME_COUNTER[unique_key] = counter - - # Define it, recording in the defined op names. - key = (lib.ns, name) - schema = f"{name}{call_args}" - if key in DEFINED_OP_NAMES: - raise RuntimeError( - f"Duplicate turbine custom op registration: library={lib.ns}, " - f"name={name}" - ) - lib.define(schema) - DEFINED_OP_NAMES.add(key) - return name - - -_SIGNATURE_NAME_PATTERN = re.compile(r"^([^(]+)(\(.+)$") - - -def _split_signature(sig: str) -> tuple[str, str]: - """Splits a signature into name and call-args parts.""" - m = re.match(_SIGNATURE_NAME_PATTERN, sig) - if not m: - raise ValueError(f"Expected signature of form `name(...) -> type. Got: {sig}") - return m.group(1), m.group(2) diff --git a/core/shark_turbine/runtime/op_reg/compiler.py b/core/shark_turbine/runtime/op_reg/compiler.py deleted file mode 100644 index 154070e4b..000000000 --- a/core/shark_turbine/runtime/op_reg/compiler.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from dataclasses import dataclass -from timeit import default_timer -from typing import Any, Optional - -from iree.compiler.api import ( - Session, - Source, - Output, -) - -from iree.runtime import ( - VmContext, - VmFunction, - VmModule, -) - -from ...support.exceptions import ( - GeneralError, -) - -from ...support.ir_imports import ( - Location, -) - -from ...support.logging import ( - runtime_logger as logger, -) - -from ..device import ( - Device, -) - -from ..tracing import tracer - -from .base import ( - FreeFuncKernelBuilder, - KernelSelection, -) - - -@dataclass(slots=True) -class KernelCompileConfig: - # Unique key for this kernel. - key: str - - # Compiler flags to pass. - flags: list[str] - - # Use the in-process compiler (default). Some compiler options are only - # available when invoked standalone/out-of-process, so this is allowed. - # Out-of-process can also be a useful debugging feature and may be - # globally controlled. - in_process: bool = True - - # Whether compiled for async invocations. - async_invocations: bool = False - - # Whether we compiled with layout specialization and can handle certain - # permutations of strided tensors. This is currently not supported but will - # be at some point. Having the option lets us annotate code paths that are - # NYI. - layout_specialized: bool = False - - # Arbitrary objects to keep alive as part of this config. This can include - # things like unbacked memory mappings, etc. - keep_alive: Any = None - - # If tracing is enabled, this may contain a sanitized key that can be - # used to log additional information against the kernel. - tracing_key: Optional[str] = None - - -# TODO: The cache should be more than just a simple dict. Can be persistent -KERNEL_CACHE: dict[str, tuple[VmContext, VmFunction, KernelCompileConfig]] = {} - - -def _testing_get_cache_size() -> int: - return len(KERNEL_CACHE) - - -def compile_standalone_kernel( - device: Device, ksel: KernelSelection, func_name: str = "main" -) -> tuple[VmContext, VmFunction, KernelCompileConfig]: - # Early exit on cache hit. - cache_key = f"{ksel.spec_key}::{device.type_cache_key}" - cache_hit = KERNEL_CACHE.get(cache_key) - if cache_hit is not None: - return cache_hit - - # Cache miss. - start = default_timer() - config = KernelCompileConfig(cache_key, list(device.compile_target_flags)) - kb = FreeFuncKernelBuilder.create_module(ksel, func_name=func_name) - with kb.ip, Location.unknown(): - ksel.op.generate(ksel, kb) - kb.module_op.verify() - module_asm = kb.module_op.get_asm( - binary=True, enable_debug_info=True, print_generic_op_form=False - ) - generation_time = default_timer() - start - - if not config.in_process: - raise NotImplementedError("Out-of-process compilation not yet supported") - - # TODO: We could be caching the session per device type key. - # TODO: Create the source and get the module to build into from that vs - # reserializing (once issues are worked out for that). - start = default_timer() - session = Session() - session.set_flags(*config.flags) - inv = session.invocation() - source = Source.wrap_buffer(session, module_asm) - output = Output.open_membuffer() - inv.enable_console_diagnostics() - inv.parse_source(source) - if not inv.execute(): - # TODO: Capture diagnostics and report. - raise GeneralError(f"Kernel compilation failed. See diagnostics.") - inv.output_vm_bytecode(output) - mapped_memory = output.map_memory() - compilation_time = default_timer() - start - - # Load. - vm_instance = device.vm_instance - vm_module = VmModule.copy_buffer(vm_instance, mapped_memory) - # TODO: We should be able to wrap the buffer as below but there are some - # subtle ref-counting/shutdown sequencing issues that need to be resolved. - # vm_module = VmModule.wrap_buffer(vm_instance, mapped_memory) - vm_context = VmContext(vm_instance, [device.create_hal_module(), vm_module]) - main_function = vm_module.lookup_function("main") - - if tracer.enabled: - config.tracing_key = tracer.save_jit_kernel_artifacts( - cache_key=cache_key, module_asm=module_asm, binary=mapped_memory - ) - tracer.log_structured( - tag="COMPILE", - msg=f"Compiled kernel {config.tracing_key}, cache_key={cache_key}", - columns=[ - config.tracing_key, - main_function.name, - len(module_asm), - len(mapped_memory), - generation_time * 1000, - compilation_time * 1000, - " ".join(session.get_flags(non_default_only=True)), - ], - ) - cache_hit = (vm_context, main_function, config) - KERNEL_CACHE[cache_key] = cache_hit - return cache_hit diff --git a/core/shark_turbine/runtime/op_reg/eager.py b/core/shark_turbine/runtime/op_reg/eager.py deleted file mode 100644 index e32ebeb9c..000000000 --- a/core/shark_turbine/runtime/op_reg/eager.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Custom op integration into the eager executor.""" - -from timeit import default_timer -from typing import Optional - -import torch - -from iree.runtime import ( - HalBufferView, - HalElementType, - VmRef, - VmVariantList, -) - -from ...support.exceptions import ( - UnsupportedTypeError, -) - -from ...support.logging import ( - runtime_logger as logger, -) - -from ..device import ( - Device, - lookup_device_from_torch, -) - -from ..tracing import tracer - -from .base import ( - AttrArg, - IntArg, - KernelSelection, -) - -from .compiler import ( - compile_standalone_kernel, - KernelCompileConfig, -) - -__all__ = [ - "eager_dispatch", -] - - -def eager_dispatch(ksel: KernelSelection): - """Main entry-point for handling dispatch of a selected kernel via a generator.""" - # Scan arg descs and decide on a compute device. - # For now, we compute on the first device that we support. - # This is very simplisitic and will need to be extended for multi-device, etc. - device: Optional[Device] = None - torch_device: Optional[torch.device] = None - for arg_desc in ksel.arg_descs: - assert arg_desc is not None, "NYI: None arguments" - if not arg_desc.is_list: - if arg_desc.ir_arity == 1: - # One arg has maybe_tensor_value as a single element (common case). - tensor_arg = arg_desc.maybe_tensor_value - if tensor_arg is None: - continue - assert isinstance(tensor_arg, torch.Tensor) - torch_device = tensor_arg.device - device = lookup_device_from_torch(torch_device) - if device is not None: - break - else: - # Optional arg omitted. - assert arg_desc.ir_arity == 0 - continue - else: - # List. maybe_tensor_value is a list. Uncommon case. - assert isinstance(arg_desc.maybe_tensor_value, list) - for tensor_arg in arg_desc.maybe_tensor_value: - if tensor_arg is None: - continue - torch_device = tensor_arg.device - device = lookup_device_from_torch(torch_device) - if device is not None: - break - - # Default to CPU. - if device is None: - logger.debug("Fallback to CPU device due to no supported device in arguments") - torch_device = torch.device("cpu") - device = lookup_device_from_torch(torch_device) - assert ( - device is not None - ), "Could not resolve lookup_device_from_torch for argument" - - # Compile. - # TODO: We can do compilation asynchronously with the device movement - vm_context, vm_f, config = compile_standalone_kernel(device, ksel) - - # Build the concrete args, issuing device movement as necessary. - arg_list = VmVariantList(len(ksel.arg_descs)) - - def push_scalar(scalar_value): - if isinstance(scalar_value, int): - arg_list.push_int(scalar_value) - elif isinstance(scalar_value, float): - arg_list.push_float(scalar_value) - else: - raise UnsupportedTypeError(type(scalar_value)) - - def push_tensor(tensor_arg): - if tensor_arg.device != torch_device: - # TODO: If the source and target device are both known to us, - # we can do this "in house" vs asking torch to do it. - tensor_arg = tensor_arg.to(torch_device) - if not tensor_arg.is_contiguous(): - if config.layout_specialized: - raise NotImplementedError( - "Layout specialized kernels are not yet implemented" - ) - tensor_arg = tensor_arg.contiguous() - # Since we know we are on the same device, we can use the unsafe - # import_torch_tensor. - arg_list.push_ref(device.import_torch_tensor(tensor_arg)) - - for arg_desc in ksel.arg_descs: - assert arg_desc is not None, "NYI: None arguments" - arity = arg_desc.ir_arity - if not arg_desc.is_list: - # Non-list. - if arity == 1: - tensor_arg = arg_desc.maybe_tensor_value - if tensor_arg is not None: - push_tensor(tensor_arg) - else: - assert isinstance(arg_desc, (IntArg, AttrArg)) - push_scalar(arg_desc.v) - else: - continue - else: - # List. Uncommon case. - tensor_arg = arg_desc.maybe_tensor_value - if tensor_arg is not None: - for i in range(arity): - push_tensor(tensor_arg[i]) - else: - for i in range(arity): - assert isinstance(arg_desc, (IntArg, AttrArg)) - list_arg = arg_desc.v - assert isinstance(list_arg, list) - push_scalar(list_arg[i]) - - if config.async_invocations: - raise NotImplementedError("Async execution not yet implemented") - - # Invoke. - ret_list = VmVariantList(len(ksel.result_descs)) - start = default_timer() - vm_context.invoke(vm_f, arg_list, ret_list) - invoke_time = default_timer() - start - if tracer.enabled: - _log_eager_dispatch(config, arg_list, invoke_time * 1000) - - # Unpack results. - results = [] - for i, result_desc in enumerate(ksel.result_descs): - arity = result_desc.ir_arity - meta_tensor_value = result_desc.maybe_tensor_value - if meta_tensor_value is None: - # Scalar return. - raise NotImplementedError("CustomOp scalar return") - assert isinstance( - meta_tensor_value, torch.Tensor - ), "NYI: Optional and result lists" - - # Tensor return. The meta tensor value already has the correct torch - # dtype and shape, so we just need to export and return it for the - # appropriate device. - bv: HalBufferView = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i)) - results.append(device.export_torch_tensor(bv, meta_tensor_value)) - - if len(results) == 1: - return results[0] - elif len(results) == 0: - return None - else: - return tuple(results) - - -def _log_eager_dispatch( - config: KernelCompileConfig, arg_list: VmVariantList, invoke_time_millis: float -): - args = [] - try: - for i in range(arg_list.size): - variant = arg_list.get_variant(i) - if isinstance(variant, VmRef): - if variant.isinstance(HalBufferView): - args.append(_log_format_buffer_view(variant.deref(HalBufferView))) - continue - args.append(variant) - except: - tracer.exception("Exception while pretty-printing arguments") - - msg = "" - tracer.log_structured( - tag="INVOKE_KERNEL", - msg=msg, - columns=[config.tracing_key, invoke_time_millis] + args, - ) - - -def _log_format_buffer_view(bv: HalBufferView) -> str: - # TODO: We should expose this as a method on HalBufferView upstream instead - # of half doing it here. - shape = "x".join(str(i) for i in bv.shape) - dtype_desc = _LOG_HAL_ELEMENT_TYPE_DESC.get(bv.element_type) - if dtype_desc is None: - dtype_desc = f"<{bv.element_type}>" - return f"{shape}x{dtype_desc}" - - -_LOG_HAL_ELEMENT_TYPE_DESC = { - HalElementType.BFLOAT_16: "bf16", - HalElementType.BOOL_8: "i1", - HalElementType.COMPLEX_64: "cf64", - HalElementType.COMPLEX_128: "cf128", - HalElementType.FLOAT_16: "f16", - HalElementType.FLOAT_32: "f32", - HalElementType.FLOAT_64: "f64", - HalElementType.INT_4: "i4", - HalElementType.INT_8: "i8", - HalElementType.INT_16: "i16", - HalElementType.INT_32: "i32", - HalElementType.INT_64: "i64", - HalElementType.SINT_4: "si4", - HalElementType.SINT_8: "si8", - HalElementType.SINT_16: "si16", - HalElementType.SINT_32: "si32", - HalElementType.SINT_64: "si64", - HalElementType.UINT_4: "ui4", - HalElementType.UINT_8: "ui8", - HalElementType.UINT_16: "ui16", - HalElementType.UINT_32: "ui32", - HalElementType.UINT_64: "ui64", -} diff --git a/core/shark_turbine/runtime/tracing.py b/core/shark_turbine/runtime/tracing.py deleted file mode 100644 index abe908ed5..000000000 --- a/core/shark_turbine/runtime/tracing.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import hashlib -import os -from pathlib import Path -import logging - -from ..support.debugging import flags -from ..support.logging import get_logger, DefaultFormatter - -logger = get_logger("turbine.runtime") - - -class RuntimeTracer: - """Supports fine grained tracing of runtime interactions. - - The default implementation no-ops. - """ - - __slots__ = ["enabled"] - - def __init__(self): - self.enabled: bool = False - - def save_jit_kernel_artifacts( - self, *, cache_key: str, module_asm: bytes, binary: memoryview - ) -> str: - return cache_key - - def info(self, msg, *args, **kwargs): - ... - - def error(self, msg, *args, **kwargs): - ... - - def exception(self, msg, *args, **kwargs): - ... - - def log_structured(self, *, tag: str, msg: str, columns: list): - ... - - -class DirectoryTracer(RuntimeTracer): - __slots__ = [ - "dir", - "logger", - ] - - def __init__(self, dir: Path): - self.dir = dir - self.enabled = True - # Configure a root logger that outputs what we want. - trace_logger = self.logger = logging.getLogger("turbine.runtime.tracer") - log_file = dir / "runtime.log" - trace_logger.setLevel(logging.DEBUG) - handler = logging.FileHandler(log_file) - handler.setFormatter(DefaultFormatter()) - trace_logger.addHandler(handler) - trace_logger.propagate = False - logger.info(f"Set up turbine runtime tracing to %s", log_file) - trace_logger.info("Started process %d", os.getpid()) - - def save_jit_kernel_artifacts( - self, *, cache_key: str, module_asm: bytes, binary: memoryview - ) -> str: - hasher = hashlib.sha1(cache_key.encode(), usedforsecurity=False) - tracing_key = hasher.digest().hex() - try: - with open(self.dir / f"{tracing_key}.mlir", "wb") as f: - f.write(module_asm) - with open(self.dir / f"{tracing_key}.vmfb", "wb") as f: - f.write(binary) - except IOError: - self.logger.exception(f"Error saving artifact for {tracing_key}") - finally: - self.logger.info(f"Saved artifacts for {tracing_key}") - return tracing_key - - def info(self, msg, *args, **kwargs): - self.logger.info(msg, *args, **kwargs) - - def error(self, msg, *args, **kwargs): - self.logger.error(msg, *args, **kwargs) - - def exception(self, msg, *args, **kwargs): - self.logger.exception(msg, *args, **kwargs, stacklevel=2) - - def log_structured(self, *, tag: str, msg: str, columns: list): - columns_joined = "\t".join(str(c) for c in columns) - self.logger.info("%s\n::%s\t%s", msg, tag, columns_joined) - - -# Determine whether configured to do real tracing. -def _setup_default_tracer() -> RuntimeTracer: - if flags.runtime_trace_dir: - try: - trace_dir = Path(flags.runtime_trace_dir) - trace_dir.mkdir(parents=True, exist_ok=True) - return DirectoryTracer(trace_dir) - except IOError: - logger.exception("Error configuring runtime tracing to: %s", trace_dir) - return RuntimeTracer() - - return RuntimeTracer() - - -tracer: RuntimeTracer = _setup_default_tracer() diff --git a/core/shark_turbine/support/__init__.py b/core/shark_turbine/support/__init__.py deleted file mode 100644 index bf935cb0c..000000000 --- a/core/shark_turbine/support/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Debugging must be loaded first as other low level things depend on it. -from .debugging import * -from .exceptions import * diff --git a/core/shark_turbine/support/conversions.py b/core/shark_turbine/support/conversions.py deleted file mode 100644 index 41adeec11..000000000 --- a/core/shark_turbine/support/conversions.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Any, Callable - -import numpy as np -import torch - -from iree.runtime import ( - HalElementType, -) - -from iree.compiler.extras.fx_importer import ( - TORCH_DTYPE_TO_MLIR_TYPE_ASM, -) - -from .exceptions import ( - UnknownDTypeError, -) - -from .ir_imports import ( - BF16Type, - ComplexType, - F16Type, - F32Type, - F64Type, - IntegerType, - IrType, -) - -# We need the inverse of the TORCH_DTYPE_TO_MLIR_TYPE_ASM table. -MLIR_TYPE_ASM_TO_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_TO_MLIR_TYPE_ASM.items()} - -# When emitting constants, we have to create native IREE types. -TORCH_DTYPE_TO_IREE_TYPE: dict[torch.dtype, Callable[[], IrType]] = { - torch.float16: lambda: F16Type.get(), - torch.bfloat16: lambda: BF16Type.get(), - torch.float32: lambda: F32Type.get(), - torch.float64: lambda: F64Type.get(), - torch.uint8: lambda: IntegerType.get_signless(8), - torch.int8: lambda: IntegerType.get_signless(8), - torch.int16: lambda: IntegerType.get_signless(16), - torch.int32: lambda: IntegerType.get_signless(32), - torch.int64: lambda: IntegerType.get_signless(64), - torch.bool: lambda: IntegerType.get_signless(1), - torch.qint8: lambda: IntegerType.get_signless(8), - torch.quint8: lambda: IntegerType.get_signless(8), - torch.complex32: lambda: ComplexType.get(F16Type.get()), - torch.complex64: lambda: ComplexType.get(F32Type.get()), - torch.complex128: lambda: ComplexType.get(F64Type.get()), -} - -TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM = { - torch.float16: "f16", - torch.bfloat16: "bf16", - torch.float32: "f32", - torch.float64: "f64", - torch.uint8: "ui8", - torch.int8: "si8", - torch.int16: "si16", - torch.int32: "si32", - torch.int64: "si64", - torch.bool: "i1", - torch.qint8: "si8", - torch.quint8: "ui8", - torch.complex32: "complex", - torch.complex64: "complex", - torch.complex128: "complex", -} - -SIGNED_MLIR_TYPE_ASM_TO_TORCH_DTYPE = dict( - (v, k) for k, v in TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM.items() -) - -TORCH_DTYPE_TO_IREE_TYPE_ASM = { - torch.float16: "f16", - torch.bfloat16: "bf16", - torch.float32: "f32", - torch.float64: "f64", - torch.uint8: "i8", - torch.int8: "i8", - torch.int16: "i16", - torch.int32: "i32", - torch.int64: "i64", - torch.bool: "i1", - torch.qint8: "i8", - torch.quint8: "i8", - torch.complex32: "complex", - torch.complex64: "complex", - torch.complex128: "complex", -} - -DTYPE_TO_ELEMENT_TYPE: dict[torch.dtype, HalElementType] = { - torch.float16: HalElementType.FLOAT_16, - torch.bfloat16: HalElementType.BFLOAT_16, - torch.float32: HalElementType.FLOAT_32, - torch.float64: HalElementType.FLOAT_64, - torch.uint8: HalElementType.UINT_8, - torch.int8: HalElementType.SINT_8, - torch.int16: HalElementType.SINT_16, - torch.int32: HalElementType.SINT_32, - torch.int64: HalElementType.SINT_64, - torch.bool: HalElementType.BOOL_8, - torch.qint8: HalElementType.OPAQUE_8, - torch.quint8: HalElementType.OPAQUE_8, - torch.complex64: HalElementType.COMPLEX_64, - torch.complex128: HalElementType.COMPLEX_128, -} - - -def dtype_to_element_type(dtype) -> HalElementType: - try: - return DTYPE_TO_ELEMENT_TYPE[dtype] - except KeyError: - raise UnknownDTypeError(dtype) - - -TORCH_DTYPE_TO_NUMPY = { - torch.float16: np.dtype("f2"), - torch.float32: np.dtype("f4"), - torch.float64: np.dtype("f8"), - torch.uint8: np.dtype("u1"), - torch.int8: np.dtype("i1"), - torch.int16: np.dtype("i2"), - torch.int32: np.dtype("i4"), - torch.int64: np.dtype("i8"), - torch.bool: np.dtype("?"), - torch.complex64: np.dtype("c8"), - torch.complex128: np.dtype("c16"), -} - - -def torch_dtype_to_numpy(torch_dtype: torch.dtype) -> Any: - try: - return TORCH_DTYPE_TO_NUMPY[torch_dtype] - except KeyError: - raise UnknownDTypeError(torch_dtype) diff --git a/core/shark_turbine/support/debugging.py b/core/shark_turbine/support/debugging.py deleted file mode 100644 index cf8475eb6..000000000 --- a/core/shark_turbine/support/debugging.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Debug flags and settings.""" - -from typing import Optional -from dataclasses import dataclass -import logging -import re -import os - -__all__ = [ - "flags", - "NDEBUG", -] - -# We use the native logging vs our .logging setup because our logging depends -# on this module. It will spew to stderr with issues. -logger = logging.getLogger("turbine.bootstrap") - -# The TURBINE_DEBUG environment variable is a comma separated list of settings -# of the form "(-)?name[=value]". -# Available settings: -# log_level: A log level name to enable. -# asserts: Whether to enable all assertions (defaults to enabled). -FLAGS_ENV_NAME = "TURBINE_DEBUG" -SETTING_PART_PATTERN = re.compile(r"""^([\\+\\-])?([^=]+)(=(.*))?$""") - -# Some settings can also be set in dedicated environment variables. Those are -# mapped here. -ENV_SETTINGS_MAP = { - "TURBINE_LOG_LEVEL": "log_level", -} - -# Whether debug/prolific assertions are disabled. -NDEBUG: bool = False - - -@dataclass -class DebugFlags: - log_level: int = logging.WARNING - asserts: bool = False - runtime_trace_dir: Optional[str] = None - - def set(self, part: str): - m = re.match(SETTING_PART_PATTERN, part) - if not m: - logger.warn("Syntax error in %s flag: '%s'", FLAGS_ENV_NAME, part) - return - name = m.group(2) - value = m.group(4) - if value: - logical_sense = value.upper() not in ["FALSE", "OFF", "0"] - else: - logical_sense = m.group(1) != "-" - - if name == "log_level": - log_level_mapping = logging.getLevelNamesMapping() - try: - self.log_level = log_level_mapping[value.upper()] - except KeyError: - logger.warn("Log level '%s' unknown (ignored)", value) - elif name == "asserts": - self.asserts = logical_sense - global NDEBUG - NDEBUG = not logical_sense - elif name == "runtime_trace_dir": - self.runtime_trace_dir = value - else: - logger.warn("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name) - - @staticmethod - def parse(settings: str) -> "DebugFlags": - new_flags = DebugFlags() - parts = settings.split(",") - for part in parts: - part = part.strip() - if not part: - continue - new_flags.set(part) - return new_flags - - @staticmethod - def parse_from_env() -> "DebugFlags": - settings = os.getenv(FLAGS_ENV_NAME) - if settings is None: - new_flags = DebugFlags() - else: - new_flags = DebugFlags.parse(settings) - for env_name, setting_name in ENV_SETTINGS_MAP.items(): - env_value = os.getenv(env_name) - if env_value is not None: - new_flags.set(f"{setting_name}={env_value}") - logger.debug("Parsed debug flags from env %s: %r", FLAGS_ENV_NAME, new_flags) - return new_flags - - -flags = DebugFlags.parse_from_env() diff --git a/core/shark_turbine/support/exceptions.py b/core/shark_turbine/support/exceptions.py deleted file mode 100644 index be2c2a633..000000000 --- a/core/shark_turbine/support/exceptions.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - - -class GeneralError(Exception): - ... - - -class MismatchedDeviceSetClearError(AssertionError): - def __init__(self): - super().__init__("Calls to Device.set()/clear() are mismatched or unbalanced.") - - -class NoCurrentDeviceError(Exception): - def __init__(self): - super().__init__( - "You accessed a method which requires a current device but none was set on this thread. " - "Either pass an explicit 'device=' or set a current device via " - "`with device:`" - ) - - -class UnsupportedTorchDeviceError(Exception): - def __init__(self, torch_device): - super().__init__( - f"Attempt to use turbine with a torch.device that is not supported by this build: {torch_device}" - ) - - -class UnsupportedTypeError(Exception): - def __init__(self, t: type, usage: str): - super().__init__(f"Python type {t} is not supported for {usage}") - - -class ApiSequencingError(Exception): - ... - - -class UnknownDTypeError(ValueError): - def __init__(self, dtype): - self.dtype = dtype - super().__init__(f"Unable to map torch dtype {dtype} to Turbine") diff --git a/core/shark_turbine/support/ir_imports.py b/core/shark_turbine/support/ir_imports.py deleted file mode 100644 index d2285d4f1..000000000 --- a/core/shark_turbine/support/ir_imports.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Unifies all imports of iree.compiler.ir into one place.""" - -from iree.compiler.ir import ( - AsmState, - Attribute, - Block, - BlockArgument, - Context, - DenseElementsAttr, - DenseResourceElementsAttr, - FlatSymbolRefAttr, - FloatAttr, - FunctionType, - IndexType, - InsertionPoint, - IntegerAttr, - Location, - MLIRError, - Module, - OpResult, - Operation, - RankedTensorType, - ShapedType, - StringAttr, - SymbolTable, - Type as IrType, - TypeAttr, - UnitAttr, - # Types. - ComplexType, - BF16Type, - F16Type, - F32Type, - F64Type, - IntegerType, - RankedTensorType, - Value, -) - -from iree.compiler.passmanager import ( - PassManager, -) - -from iree.compiler.dialects import ( - builtin as builtin_d, - flow as flow_d, - func as func_d, - util as util_d, - arith as arith_d, - tensor as tensor_d, -) diff --git a/core/shark_turbine/support/logging.py b/core/shark_turbine/support/logging.py deleted file mode 100644 index 2bb205eae..000000000 --- a/core/shark_turbine/support/logging.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import sys - -from .debugging import flags - - -class DefaultFormatter(logging.Formatter): - def __init__(self): - super().__init__( - "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s", - "%m-%d %H:%M:%S", - ) - - -def _setup_logger(): - root_logger = logging.getLogger("turbine") - root_logger.setLevel(flags.log_level) - default_handler = logging.StreamHandler(sys.stderr) - default_handler.flush = sys.stderr.flush - default_handler.setLevel(flags.log_level) - default_handler.setFormatter(DefaultFormatter()) - root_logger.addHandler(default_handler) - root_logger.propagate = False - return root_logger, default_handler - - -root_logger, default_handler = _setup_logger() - - -def get_logger(name: str): - logger = logging.getLogger(name) - logger.setLevel(flags.log_level) - logger.addHandler(default_handler) - logger.propagate = False - return logger - - -aot_logger = get_logger("turbine.aot") -runtime_logger = get_logger("turbine.runtime") diff --git a/core/shark_turbine/transforms/builder.py b/core/shark_turbine/transforms/builder.py deleted file mode 100644 index eab07a24c..000000000 --- a/core/shark_turbine/transforms/builder.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import List, Optional, Sequence - -from ..dynamo import type_conversion - -from iree.compiler.ir import ( - Context, - RankedTensorType, - Type as IrType, - Value, -) - -from iree.compiler.dialects import ( - func as func_d, -) - -__all__ = [ - "Builder", -] - - -class Builder: - def __init__(self, context: Context = None): - if not context: - context = Context.current - self.context = context - self.native_type_conversion = type_conversion.NativeTypeConverter(self.context) - - def to_native_type(self, t: IrType) -> IrType: - return self.native_type_conversion.torch_type_to_native(t) - - def to_native_tensor_type(self, t: IrType) -> RankedTensorType: - if not RankedTensorType.isinstance(t): - try: - return RankedTensorType(self.to_native_type(t)) - except Exception as e: - raise ValueError(f"Could not convert to tensor type ({t})") from e - return RankedTensorType(t) - - def get_tensor_dims(self, tensor_type: IrType) -> List[Optional[int]]: - rt = self.to_native_tensor_type(tensor_type) - return [ - None if rt.is_dynamic_dim(axis) else rt.get_dim_size(axis) - for axis in range(rt.rank) - ] - - def get_tensor_element_type(self, tensor_type: IrType) -> IrType: - rt = self.to_native_tensor_type(tensor_type) - return rt.element_type - - def call_native( - self, callee_name: str, result_types: Sequence[IrType], *operands: Value - ) -> Sequence[Value]: - """Calls a function on native types, adding conversions as needed.""" - native_result_types = [ - self.native_type_conversion.torch_type_to_native(t) for t in result_types - ] - native_operands = [ - self.native_type_conversion.materialize_torch_to_native(v) for v in operands - ] - native_results = func_d.CallOp( - native_result_types, callee_name, native_operands - ).results - return [ - self.native_type_conversion.materialize_native_to_torch(v, t) - for t, v in zip(result_types, native_results) - ] diff --git a/core/shark_turbine/transforms/general/custom_op_expansion.py b/core/shark_turbine/transforms/general/custom_op_expansion.py deleted file mode 100644 index 53c9fec17..000000000 --- a/core/shark_turbine/transforms/general/custom_op_expansion.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Callable - -import torch -from torch import Tensor -from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.experimental.symbolic_shapes import ShapeEnv - -from ...dynamo.type_conversion import ( - NativeTypeConverter, -) - -from ...runtime.op_reg.base import ( - ALL_CUSTOM_OP_REGS, - AttrArg, - IntArg, - CustomOp, - KernelBuilder, - KernelSelection, - TensorArg, - TensorListArg, -) - -from ...support.conversions import ( - MLIR_TYPE_ASM_TO_TORCH_DTYPE, -) - -from ...support.ir_imports import ( - Block, - IrType, - InsertionPoint, - OpResult, - Operation, - RankedTensorType, - StringAttr, - SymbolTable, - Value, -) - -from ..rewriter import ( - Pass, -) - - -class ExpandCustomOpsPass(Pass): - def __init__( - self, root_op: Operation, reg: dict[str, CustomOp] = ALL_CUSTOM_OP_REGS - ): - super().__init__(root_op) - self.reg = reg - # Track pending deletions in a dict to preserve order and unique. - self.ops_to_delete: dict[Operation, None] = {} - self.type_converter = NativeTypeConverter(root_op.context) - self.symbol_table = SymbolTable(root_op) - self.shape_env = ShapeEnv() - self.fake_mode = FakeTensorMode(shape_env=self.shape_env) - - def delete_op(self, op): - self.ops_to_delete[op.operation] = None - - def run(self): - for mr in self.funcs: - self.expand_func(mr.op) - for op in self.ops_to_delete.keys(): - self.erase_unused_op(op) - - def expand_func(self, func_op: Operation): - """Expands custom ops in a traced torch function. - - This finds operations of the form: - %0 = torch.operator "torch.ns.op" - - And looks them up in the reg dict. If it originated from one of those - registered ops, then it will be expanded in place. - """ - name_prefix = "torch." - - for block in func_op.regions[0].blocks: - for op in block.operations: - if op.name == "torch.operator": - custom_op_name = StringAttr(op.attributes["name"]).value - if custom_op_name.startswith(name_prefix): - local_name = custom_op_name[len(name_prefix) :] - custom_op_reg = self.reg.get(local_name) - if custom_op_reg is not None: - self.expand_custom_op(custom_op_reg, op) - - def expand_custom_op(self, op_reg: CustomOp, op: Operation): - original_operands: list[Value] = list(op.operands) - ksel = AOTKernelSelection( - op_reg, - original_operands, - list(op.results), - self.type_converter, - self.shape_env, - ) - with self.fake_mode: - op_reg.select(ksel) - ksel._run_validators() - - module_body = self.root_op.regions[0].blocks[0] - kb = InlineKernelBuilder( - ksel, - op, - type_converter=self.type_converter, - module_body=module_body, - symbol_table=self.symbol_table, - ) - with kb.ip, kb.location: - op_reg.generate(ksel, kb) - assert kb.yielded, "Custom op generation did not yield_results()" - - self.delete_op(op) - - -class AOTKernelSelection(KernelSelection): - __slots__ = [ - "operands", - "results", - "type_converter", - "shape_env", - "_validators", - ] - - def __init__( - self, - op: CustomOp, - operands: list[Value], - results: list[Value], - type_converter: NativeTypeConverter, - shape_env: ShapeEnv, - ): - super().__init__(op, len(operands)) - self.operands = operands - self.results = results - self.type_converter = type_converter - self.shape_env = shape_env - self._validators: list[Callable] = [] - - def _run_validators(self): - for v in self._validators: - v() - - def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> TensorArg: - # This is annoying: We have to go from the Torch MLIR type system to the - # original torch.tensor Python type system. We do this by way of the native - # type converter because it has the mapping pathway we need. This is one of the - # only places in the code where we have to go this way to preserve the facade. - # Everywhere else is going from Torch -> IREE native. - arg_descs = self.arg_descs - assert arg_descs[arg] is None, f"Already constrained argument {arg}" - operand = self.operands[arg] - signed_native_type = self.type_converter.torch_type_to_native( - operand.type, signless=False - ) - try: - rtt = RankedTensorType(signed_native_type) - except TypeError as e: - raise TypeError( - f"Argument type mismatch from Torch IR for arg {arg}: Expected ranked tensor, got {signed_native_type}" - ) from e - element_type_asm = str(rtt.element_type) - try: - dtype = MLIR_TYPE_ASM_TO_TORCH_DTYPE[element_type_asm] - except KeyError as e: - raise AssertionError( - f"Could not find dtype mapping for {element_type_asm} in MLIR_TYPE_ASM_TO_TORCH_DTYPE" - ) - - # Because we are operating in fake_mode, replace MLIR dyn dims with - # symints for the PyTorch type system. - shape_env = self.shape_env - sym_shape = [ - d if d >= 0 else shape_env.create_unbacked_symint() for d in rtt.shape - ] - t = torch.empty(sym_shape, dtype=dtype) - arg_descs[arg] = desc = TensorArg(t) - if inplace_tied: - self.inplace_tied_arg_descs.append(desc) - - def validator(): - rank = rtt.rank - for i in range(rank): - spec_dim = desc.spec_dims[i] - if rtt.is_dynamic_dim(i): - # Make sure that it wasn't specialized. - if spec_dim is not None: - raise ValueError( - f"Custom op {self.op}, arg {arg} requires a static dim " - f"at index {i} but it is dynamic: {rtt}" - ) - else: - # Make sure specialized dim matches. - actual_dim = rtt.get_dim_size(i) - if spec_dim is not None and actual_dim != spec_dim: - raise ValueError( - f"Custom op {self.op}, arg {arg} has a mismatched static " - f"dim at index {i}: actual = {actual_dim}, expected = {spec_dim}" - ) - - self._validators.append(validator) - return desc - - def arg_tensor_list(self, arg: int) -> TensorListArg: - raise NotImplementedError("NYI: AOT arg_tensor_list") - - def arg_int(self, arg: int) -> IntArg: - raise NotImplementedError("NYI: AOT arg_int") - - def attr_str(self, arg: int) -> AttrArg: - arg_descs = self.arg_descs - assert arg_descs[arg] is None, f"Already constrained argument {arg}" - operand = self.operands[arg] - ty = operand.type - assert ( - str(ty) == "!torch.str" - ), f"Argument type mismatch from Torch IR for {arg}: Expected !torch.str, got {ty}" - str_value = _get_constant_str_from_value(operand) - arg_descs[arg] = desc = AttrArg(str_value) - return desc - - def return_tensor(self, t: Tensor) -> TensorArg: - desc = TensorArg(t) - self.result_descs.append(desc) - return desc - - -def _get_constant_str_from_value(v: Value) -> str: - """Given a constant str producer, return the str. - - Example: %str = torch.constant.str "TEST" - """ - constant_op = OpResult(v).owner - assert ( - constant_op.name == "torch.constant.str" - ), f"Expected constant !torch.str to be produced by a torch.constant.str op but got: {constant_op}" - return StringAttr(constant_op.attributes["value"]).value - - -class InlineKernelBuilder(KernelBuilder): - def __init__( - self, - ksel: KernelSelection, - torch_op: Operation, - *, - type_converter: NativeTypeConverter, - module_body: Block, - symbol_table: SymbolTable, - ): - location = torch_op.location - ip = InsertionPoint(torch_op) - with ip, location: - operands = list(torch_op.operands) - arg_bindings = [] - for desc, operand in zip(ksel.arg_descs, operands): - assert desc is not None, "NYI: None arguments" - arity = desc.ir_arity - if not desc.is_list: - if arity == 1: - arg_bindings.append( - type_converter.materialize_torch_to_native( - operand, - static_info_cast_to=IrType.parse(desc.mlir_type_asm), - ) - ) - else: - arg_bindings.append(None) - else: - # arg_bindings.extend(native_operands) - raise NotImplementedError("NYI: AOT custom op list arguments") - - super().__init__( - ksel, - arg_bindings=arg_bindings, - ip=ip, - module_body=module_body, - symbol_table=symbol_table, - ) - self.location = location - self.torch_op = torch_op - self.type_converter = type_converter - - def yield_results(self, *results: Value): - """Yields results of the kernel computation.""" - assert not self.yielded, "yield_results has already been called" - ksel = self.ksel - expected_count = len(ksel.result_descs) + len(ksel.inplace_tied_arg_descs) - assert ( - len(results) == expected_count - ), f"Mismatched yielded results and declared+inplace: Expected={expected_count}, Got={len(results)}" - with self.ip, self.location: - torch_op_results: list[Value] = list(self.torch_op.results) - assert len(results) == len( - torch_op_results - ), f"Mismatched yield_results with custom op results" - for new_result, old_result in zip(results, torch_op_results): - torch_type = old_result.type - new_result = self.type_converter.materialize_native_to_torch( - new_result, - torch_type, - static_info_cast=True, - ) - old_result.replace_all_uses_with(new_result) - self.yielded = True diff --git a/core/shark_turbine/transforms/general/rename_parameters.py b/core/shark_turbine/transforms/general/rename_parameters.py deleted file mode 100644 index d74de0e4b..000000000 --- a/core/shark_turbine/transforms/general/rename_parameters.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""This pass will rename any #stream.parameter.named<> attributes on globals. - -It can either be used as-is or by sub-classing (i.e. in a model specific -subclass that renames from A->B, etc). - -By default, no attributes are touched unless: - -* rename_map= has an exact match -* rename_callback= returns a change -""" - -from typing import Callable, Dict, List, Optional, Tuple, Union - -import re - -from iree.compiler.ir import ( - Attribute, - Operation, - StringAttr, -) - -from ..rewriter import * -from iree.compiler.ir import Context - -ScopedName = Tuple[Optional[str], str] -MaybeScopedName = Union[str, ScopedName] - - -class RenameParametersPass(Pass): - def __init__( - self, - root_op: Operation, - *, - rename_map: Optional[Dict[MaybeScopedName, MaybeScopedName]] = None, - rename_callback: Callable[ - [Optional[str], str], Optional[ScopedName] - ] = lambda scope, name: None, - ): - super().__init__(root_op) - self.context = root_op.context - self.rename_map = rename_map or {} - self.rename_callback = rename_callback - with self.context: - # Make a prototype named parameter attribute so we can get its - # typeid. - self.parameter_attr_typeid = Attribute.parse( - '#stream.parameter.named<""::"">' - ).typeid - - def run(self): - globals = self.globals - for _, global_op in self.globals.items(): - attrs = global_op.op.attributes - try: - initial_value = attrs["initial_value"] - except KeyError: - continue - - if initial_value.typeid == self.parameter_attr_typeid: - updated_value = self.remap(initial_value) - if updated_value != initial_value: - attrs["initial_value"] = updated_value - - def remap(self, parameter_attr: Attribute) -> Attribute: - comps = _parse_parameter_attr(parameter_attr) - if not comps: - return parameter_attr - if len(comps) == 1: - orig_scope = None - name = comps[0] - else: - orig_scope, name = comps - - def norm_map_result(result: MaybeScopedName) -> ScopedName: - if isinstance(result, str): - return orig_scope, result - if len(result) == 1: - return orig_scope, result[0] - else: - return result[0], result[1] - - def make_attr(scoped_name: ScopedName) -> Attribute: - if scoped_name[0] is None: - name = StringAttr.get(scoped_name[1]) - return Attribute.parse( - f"#stream.parameter.named<{name}> : {parameter_attr.type}" - ) - else: - scope = StringAttr.get(scoped_name[0]) - name = StringAttr.get(scoped_name[1]) - return Attribute.parse( - f"#stream.parameter.named<{scope}::{name}> : {parameter_attr.type}" - ) - - # Check the rename map. - # Check with a fully-qualified name. - result = self.rename_map.get((orig_scope, name)) - if result is not None: - return make_attr(norm_map_result(result)) - # Check with just the - result = self.rename_map.get(name) - if result is not None: - return make_attr(norm_map_result(result)) - - # Check the callback. - result = self.rename_callback(orig_scope, name) - if result is not None: - return make_attr(result) - - return parameter_attr - - -def _parse_parameter_attr(attr: Attribute) -> Optional[List[str]]: - # Returns one of: - # None if failed to parse of not a simple named parameter without attributes. - # [name] for names with a default scope - # [scope, name] for scoped names - # TODO: Burn this with fire. We should add Python bindings for these attributes - # vs string parsing them. - # TODO: The parameter attribute correctly parses C escapes but prints unescaped :( - asm = str(attr) - PREFIX = "#stream.parameter.named<" - STR_PATTERN = re.compile(r'"(.*?)(?!\\)"') - if asm.startswith(PREFIX): - asm = asm[len(PREFIX) :] - results = [] - # Parse a str - m = STR_PATTERN.search(asm) - if not m or m.start() != 0: - return None - results.append(m.group(1)) - asm = asm[m.end() :] - # Parse :: - if asm.startswith("::"): - asm = asm[2:] - else: - return results - # Parse a str - m = STR_PATTERN.search(asm) - if not m or m.start() != 0: - return None - results.append(m.group(1)) - asm = asm[m.end() :] - if not asm.startswith(">"): - return None - return results - - -if __name__ == "__main__": - pass_main(RenameParametersPass) diff --git a/core/shark_turbine/transforms/merger.py b/core/shark_turbine/transforms/merger.py deleted file mode 100644 index a1b1be558..000000000 --- a/core/shark_turbine/transforms/merger.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Any, Dict, List, Optional, Sequence, Union - -from iree.compiler.ir import ( - Attribute, - Block, - InsertionPoint, - Operation, - StringAttr, - SymbolTable, - Context, -) - - -__all__ = [ - "Merger", -] - - -def null_logger(msg): - pass - - -def get_top_level_ops(module_op: Operation, *op_names: str) -> Sequence[Operation]: - results = [] - for op_view in module_op.regions[0].blocks[0]: - op = op_view.operation - if op.name in op_names: - results.append(op) - return results - - -def is_global_immutable_initialized(global_op: Operation): - return ( - "is_mutable" not in global_op.attributes - and "initial_value" in global_op.attributes - ) - - -def uniqueify_name(local_name: str, st: SymbolTable) -> str: - index = -1 - while True: - index += 1 - full_name = local_name - if index > 0: - full_name += f"${index}" - if full_name not in st: - return full_name - - -class Merger: - """Merges the contents of one module into another module. - - This performs an opinionated merge that: - * Applies a heuristic to determine whether to merge/rename a global or - keep the existing. - * Moves functions to the target, renaming on collision. - * Moves initializers to the target. - - Globals are handled specially according to the following rules: - * If mutable or not inline initialized, they will be copied from source - to target with a renamed symbol on collision. - * Similar if immutable and inline initialized to a value that is different - than the existing. - * Otherwise, the existing will be used. - - Note that this is a destructive operation on the source module as its contents - are mutated and moved into the target module. - """ - - def __init__( - self, - source_module: Operation, - target_module: Operation, - *, - user_rename_map: Optional[Dict[str, str]] = None, - target_symbol_table: Optional[SymbolTable] = None, - _logger=None, - ): - self._context = source_module.context - self.source_module = source_module - self.target_module = target_module - self._target_body = self.target_module.regions[0].blocks[0] - self.user_rename_map = user_rename_map if user_rename_map is not None else {} - self._logger = _logger if _logger else null_logger - self._source_symbol_table = SymbolTable(self.source_module) - self._target_symbol_table = ( - target_symbol_table - if target_symbol_table is not None - else SymbolTable(self.target_module) - ) - self._rename_map: Dict[StringAttr, StringAttr] = {} - - self._nested_symbol_ops: List[Operation] = [] - self._nested_symbol_table_ops: List[Operation] = [] - self._top_ip = InsertionPoint.at_block_begin(self._target_body) - - # Map of value attributes to global operation. - self._initialized_globals: Dict[Attribute, Operation] = {} - target_globals = get_top_level_ops(self.target_module, "util.global") - for global_op in target_globals: - if not is_global_immutable_initialized(global_op): - continue - self._initialized_globals[global_op.attributes["initial_value"]] = global_op - - def merge(self): - """Performs the merge.""" - # Merge globals. - source_globals = get_top_level_ops(self.source_module, "util.global") - for global_op in source_globals: - if not is_global_immutable_initialized(global_op): - self._import_symbol_op(global_op, append=False) - continue - global_value = global_op.attributes["initial_value"] - alias_global_op = self._initialized_globals.get(global_value) - if alias_global_op: - # Don't import the global, just note the rename. - alias_from = SymbolTable.get_symbol_name(global_op) - alias_to = SymbolTable.get_symbol_name(alias_global_op) - self._logger( - f"Aliasing imported global {StringAttr(alias_from).value} -> {StringAttr(alias_to).value}" - ) - self._rename(alias_from, alias_to) - else: - # Import the global. - self._import_symbol_op(global_op, append=False) - - # Merge initializers. - initializers = get_top_level_ops(self.source_module, "util.initializer") - for init_op in initializers: - init_op.detach_from_parent() - self._nested_symbol_table_ops.append(init_op) - self._target_body.append(init_op) - - # Merge external dispatches. - sources = get_top_level_ops(self.source_module, "hal.executable.source") - for source in sources: - source.detach_from_parent() - self._nested_symbol_table_ops.append(source) - self._target_body.append(source) - - # Merge functions. - funcs = get_top_level_ops(self.source_module, "func.func", "util.func") - for func_op in funcs: - self._import_symbol_op(func_op) - self._nested_symbol_table_ops.append(func_op) - - self._logger(f"The following symbol renames will be made: {self._rename_map}") - - # Go back through to nested symbol table ops and RAUW. - for sym_operation in self._nested_symbol_table_ops: - for from_symbol, to_symbol in self._rename_map.items(): - from_name = StringAttr(from_symbol).value - to_name = StringAttr(to_symbol).value - SymbolTable.replace_all_symbol_uses(from_name, to_name, sym_operation) - - def translate_symbol(self, source_symbol_name: str) -> str: - """Looks up the actual name of a source symbol after merge into the target.""" - source_symbol_attr = StringAttr.get(source_symbol_name, context=self._context) - rename_symbol_attr = self._rename_map.get(source_symbol_attr) - if rename_symbol_attr is None: - return source_symbol_name - else: - return rename_symbol_attr.value - - def _import_symbol_op(self, symbol_op, *, append: bool = True): - symbol_op = symbol_op.detach_from_parent() - orig_symbol = SymbolTable.get_symbol_name(symbol_op) - orig_symbol_name = StringAttr(orig_symbol).value - requested_symbol = self.user_rename_map.get(orig_symbol_name) - if requested_symbol: - # Has a user mapping. - if requested_symbol in self._target_symbol_table: - raise ValueError( - f"Requested symbol rename {requested_symbol} exists in the target" - ) - self._logger(f"Requested rename {orig_symbol_name} -> {requested_symbol}") - SymbolTable.set_symbol_name(symbol_op, requested_symbol) - self._rename(orig_symbol, requested_symbol) - else: - # No user mapping - make sure it is unique. - new_symbol_name = uniqueify_name( - orig_symbol_name, self._target_symbol_table - ) - if new_symbol_name != orig_symbol_name: - self._logger( - f"Implicit rename of conflicting symbol: {orig_symbol_name} -> {new_symbol_name}" - ) - SymbolTable.set_symbol_name(symbol_op, new_symbol_name) - self._rename(orig_symbol, new_symbol_name) - - if append: - self._target_body.append(symbol_op) - else: - self._top_ip.insert(symbol_op) - self._nested_symbol_ops.append(symbol_op) - self._target_symbol_table.insert(symbol_op) - - def _rename(self, from_symbol, to_symbol): - from_symbol = self._make_string_attr(from_symbol) - to_symbol = self._make_string_attr(to_symbol) - if from_symbol != to_symbol: - self._rename_map[from_symbol] = to_symbol - - def _make_string_attr(self, string_attr_or_str): - if isinstance(string_attr_or_str, str): - with self._context: - return StringAttr.get(string_attr_or_str) - else: - return StringAttr(string_attr_or_str) diff --git a/core/shark_turbine/transforms/quantization/mm_group_quant.py b/core/shark_turbine/transforms/quantization/mm_group_quant.py deleted file mode 100644 index bb4d96a85..000000000 --- a/core/shark_turbine/transforms/quantization/mm_group_quant.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Optional, cast - -from iree.compiler.ir import ( - InsertionPoint, - Operation, - Type as IrType, -) - -from ..rewriter import * -from iree.compiler.ir import Context - - -class TransposedMMResult(OpMatchResult): - def __init__( - self, - op: Operation, - *, - weight_global: Operation, - param_name: str, - m: Optional[int], - n: Optional[int], - k: Optional[int], - element_type: IrType, - ): - super().__init__(op) - self.weight_global = weight_global - self.param_name = param_name - self.m = m - self.n = n - self.k = k - self.element_type = element_type - - def __repr__(self): - return f"TransposedMM(weight={self.param_name}, m={self.m}, n={self.n}, k={self.k}, element_type={self.element_type})" - - -class TransposedMMMatcher(NamedOpMatcher[TransposedMMResult]): - def __init__(self, globals: GlobalsDict, builder: Builder): - super().__init__("torch.aten.mm") - self.globals = globals - self.builder = builder - - def match(self, op: Operation): - weight_transpose = Transpose2DMatcher()(op.operands[1]) - if not weight_transpose: - return None - weight_load = GlobalLoadMatcher(self.globals)(weight_transpose.input) - if not weight_load or not weight_load.resolved_global: - return None - - m, n = self.builder.get_tensor_dims(op.operands[0].type) - _, k = self.builder.get_tensor_dims(op.operands[1].type) - return TransposedMMResult( - op, - weight_global=weight_load.resolved_global, - param_name=weight_load.global_ref, - m=m, - n=n, - k=k, - element_type=self.builder.get_tensor_element_type(op.operands[0].type), - ) - - -# TODO (ian): Make more generalizable using RenameParametersPass. Currently hardcoded for brevitas quantization -GROUP_MATMUL_TEMPLATE = r""" -module {{ - util.global private @{param_name} {{noinline}} = #stream.parameter.named<"model"::"{param_name}"> : tensor<{k}x{n_div}xi8> - util.global private @{param_name}.quant.scale {{noinline}} = #stream.parameter.named<"model"::"{param_name}_scale"> : tensor<{k}x{group0}x{element_type}> - util.global private @{param_name}.quant.zero_point {{noinline}} = #stream.parameter.named<"model"::"{param_name}_zp"> : tensor<{k}x{group0}x{element_type}> - - func.func private @compute_mm_group_quant(%a : tensor<{m}x{n}x{element_type}>) -> tensor<{m}x{k}x{element_type}> {{ - %c0 = arith.constant 0 : index - %weight_raw = util.global.load @{param_name} : tensor<{k}x{n_div}xi8> - %m = tensor.dim %a, %c0 : tensor<{m}x{n}x{element_type}> - %k = tensor.dim %weight_raw, %c0 : tensor<{k}x{n_div}xi8> - %scale = util.global.load @{param_name}.quant.scale : tensor<{k}x{group0}x{element_type}> - %zp = util.global.load @{param_name}.quant.zero_point : tensor<{k}x{group0}x{element_type}> - %weight = flow.tensor.bitcast %weight_raw : tensor<{k}x{n_div}xi8> -> tensor<{k}x{n}x{lowp_type}> - %a_exp = tensor.expand_shape %a [[0], [1, 2]] : tensor<{m}x{n}x{element_type}> into tensor<{m}x{group0}x{group1}x{element_type}> - %weight_exp = tensor.expand_shape %weight [[0], [1, 2]] : tensor<{k}x{n}x{lowp_type}> into tensor<{k}x{group0}x{group1}x{lowp_type}> - %empty_0 = tensor.empty() : tensor<{k}x{group0}x{group1}x{element_type}> - %weight_cast = linalg.generic {{ - indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1, d2)>], - iterator_types = ["parallel", "parallel", "parallel"] }} - ins(%weight_exp, %scale, %zp : tensor<{k}x{group0}x{group1}x{lowp_type}>, tensor<{k}x{group0}x{element_type}>, tensor<{k}x{group0}x{element_type}>) - outs(%empty_0 : tensor<{k}x{group0}x{group1}x{element_type}>) {{ - ^bb0(%in: {lowp_type}, %in_1: {element_type}, %in_2: {element_type}, %out: {element_type}): - %16 = arith.extui %in : {lowp_type} to i32 - %17 = arith.uitofp %16 : i32 to {element_type} - %18 = arith.subf %17, %in_2 : {element_type} - %19 = arith.mulf %18, %in_1 : {element_type} - linalg.yield %19 : {element_type} - }} -> tensor<{k}x{group0}x{group1}x{element_type}> - %cst = arith.constant 0.000000e+00 : {element_type} - %empty_1_dyn = tensor.empty(%m, %k) : tensor - %empty_1 = tensor.cast %empty_1_dyn : tensor to tensor<{m}x{k}x{element_type}> - %zero_init = linalg.fill ins(%cst : {element_type}) outs(%empty_1 : tensor<{m}x{k}x{element_type}>) -> tensor<{m}x{k}x{element_type}> - %result = linalg.generic {{ - indexing_maps = [ - affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, - affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, - affine_map<(d0, d1, d2, d3) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction", "reduction"] }} - ins(%a_exp, %weight_cast : tensor<{m}x{group0}x{group1}x{element_type}>, tensor<{k}x{group0}x{group1}x{element_type}>) - outs(%zero_init : tensor<{m}x{k}x{element_type}>) {{ - ^bb0(%in: {element_type}, %in_1: {element_type}, %out: {element_type}): - %16 = arith.mulf %in, %in_1 : {element_type} - %17 = arith.addf %16, %out : {element_type} - linalg.yield %17 : {element_type} - }} -> tensor<{m}x{k}x{element_type}> - return %result : tensor<{m}x{k}x{element_type}> - }} -}} -""" - - -class MMGroupQuantRewriterPass(Pass): - def __init__(self, root_op: Operation, *, group_size: int = 128): - super().__init__(root_op) - self.group_size = group_size - self.context = root_op.context - - def run(self): - globals = self.globals - mms = match_children(self.funcs, TransposedMMMatcher(globals, self.builder)) - - for mr in mms: - if mr.k is None or mr.n is None: - continue - if (mr.k % self.group_size) != 0: - continue - self.rewrite(mr) - - self.inline() - self.cleanup() - - def rewrite(self, mr: TransposedMMResult): - none_to_q = lambda x: "?" if x is None else x - static_n = mr.n - if static_n is None: - return - # TODO (ian): make generalizable and not specific for brevitas - if "lm_head.weight" not in mr.param_name: - inline_module_asm = GROUP_MATMUL_TEMPLATE.format( - # TODO (ian): Fix skipping the "_params." portion of the name to match safetensor format with RenameParametersPass - param_name=mr.param_name[8:], - lowp_type="i4", - m=none_to_q(mr.m), - n=none_to_q(mr.n), - k=none_to_q(mr.k), - n_div=static_n // 2, - group0=static_n // self.group_size, - group1=self.group_size, - element_type=mr.element_type, - ) - - inline_module = Operation.parse(inline_module_asm, context=self.context) - actual_callee_name = self.merge_module(inline_module).translate_symbol( - "compute_mm_group_quant" - ) - with InsertionPoint(mr.op), mr.op.location: - results = self.builder.call_native( - actual_callee_name, [mr.op.result.type], mr.op.operands[0] - ) - self.replace_op(mr.op, *results) - - -if __name__ == "__main__": - pass_main(MMGroupQuantRewriterPass) diff --git a/core/shark_turbine/transforms/rewriter.py b/core/shark_turbine/transforms/rewriter.py deleted file mode 100644 index 90205e67e..000000000 --- a/core/shark_turbine/transforms/rewriter.py +++ /dev/null @@ -1,369 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Dict, Generic, List, Optional, Set, Union, Type, TypeVar, cast - -import argparse -import sys - -from iree.compiler.ir import ( - Block, - BlockArgument, - Context, - FlatSymbolRefAttr, - IntegerAttr, - Operation, - OpResult, - OpView, - Region, - StringAttr, - Value, -) - -from iree.compiler.passmanager import ( - PassManager, -) - -from .builder import Builder -from .merger import Merger - -__all__ = [ - "Builder", - "GlobalLoadMatcher", - "GlobalsDict", - "NamedOpMatcher", - "OpMatchResult", - "Pass", - "Transpose2DMatcher", - "match_children", - "pass_main", -] - -############################################################################### -# Matching -############################################################################### - -OpMatchT = TypeVar("OpMatchT", bound=Operation) - - -class OpMatchResult(Generic[OpMatchT]): - def __init__(self, op: OpMatchT): - self.op = op - - def __repr__(self): - return f"{type(self).__name__}({self.op})" - - -OpMatchResultT = TypeVar("OpMatchResultT", bound=OpMatchResult) -OperationParent = Union[None, Operation, OpView, Region, Block, OpMatchResult] -OperationParentOrList = Union[OperationParent, List[OperationParent]] -MaybeOperation = Union[None, Value, OpMatchResult, Operation, OpView] - - -class OpMatcher(Generic[OpMatchResultT]): - """Base class for things that match an operation.""" - - def __call__(self, maybe_op: MaybeOperation) -> Optional[OpMatchResultT]: - if maybe_op is None: - return None - if isinstance(maybe_op, OpMatchResult): - op = maybe_op.op - elif isinstance(maybe_op, Operation): - op = maybe_op - elif isinstance(maybe_op, OpView): - op = maybe_op.operation - elif isinstance(maybe_op, Value): - if OpResult.isinstance(maybe_op): - op = _op_as_operation(OpResult(maybe_op).owner) - elif BlockArgument.isinstance(maybe_op): - return None - else: - raise ValueError(f"Unexpected OpMatcher input: {type(maybe_op)}") - - return self._match(op) - - def _match(self, op: Operation) -> Optional[OpMatchResultT]: - raise NotImplementedError - - -class NamedOpMatcher(OpMatcher[OpMatchResultT]): - """Matches operations by name.""" - - def __init__(self, *op_names: str): - self.op_names = op_names - - def _match(self, op: Operation) -> Optional[OpMatchResultT]: - if op.name in self.op_names: - return self.match(op) - return None - - def match(self, op: Operation) -> Optional[OpMatchResultT]: - return OpMatchResult(op) # type: ignore - - -def get_child_blocks(of: OperationParentOrList) -> List[Block]: - """Gets all child blocks of an Operation, Region, or Block (self).""" - blocks: List[Block] = [] - if of is None: - return blocks - - if isinstance(of, OpMatchResult): - of = of.op - - if isinstance(of, (Operation, OpView)): - for r in of.regions: - for b in r.blocks: - blocks.append(b) - elif isinstance(of, Region): - for b in of.blocks: - blocks.append(b) - elif isinstance(of, Block): - blocks.append(of) - elif isinstance(of, List): - for p in of: - blocks.extend(get_child_blocks(p)) - else: - raise ValueError(f"Must be an Operation, Region, or Block. Got: {type(of)}") - return blocks - - -def match_children( - of: OperationParentOrList, *matchers: OpMatcher -) -> List[OpMatchResult]: - """Matches children of a parent. - - For any child, the match result from the first matcher which matches - will be added to the result list. - """ - results = [] - blocks = get_child_blocks(of) - for b in blocks: - for op in b.operations: - for m in matchers: - result = m(op.operation) - if result: - results.append(result) - break - return results - - -############################################################################### -# Specific op matchers -############################################################################### - - -class FuncOpMatcher(NamedOpMatcher): - """Matches func.func functions.""" - - def __init__(self): - super().__init__("func.func") - - -class GlobalOpResult(OpMatchResult): - @property - def sym_name(self) -> str: - return StringAttr(self.op.attributes["sym_name"]).value - - -class GlobalOpMatcher(NamedOpMatcher[GlobalOpResult]): - """Matches global operations.""" - - def __init__(self): - super().__init__("util.global") - - def match(self, op: Operation) -> Optional[GlobalOpResult]: - return GlobalOpResult(op) - - -class Transpose2DResult(OpMatchResult): - @property - def input(self) -> Value: - return self.op.operands[0] - - -class Transpose2DMatcher(NamedOpMatcher[Transpose2DResult]): - def __init__(self): - super().__init__("torch.aten.transpose.int") - - def match(self, op: Operation) -> Optional[Transpose2DResult]: - result = Transpose2DResult(op) - if not ConstantIntMatcher(0)(op.operands[1]) or not ConstantIntMatcher(1)( - op.operands[2] - ): - return None - return result - - -class ConstantIntMatcher(NamedOpMatcher): - def __init__(self, value: int): - super().__init__("torch.constant.int") - self.value = value - - def match(self, op: Operation): - value_attr = IntegerAttr(op.attributes["value"]) - if value_attr.value != self.value: - return None - return OpMatchResult(op) - - -class GlobalLoadResult(OpMatchResult): - def __init__(self, op: Operation): - super().__init__(op) - self.resolved_global: Optional[GlobalOpResult] = None - - @property - def global_ref(self) -> str: - return FlatSymbolRefAttr(self.op.attributes["global"]).value - - -class GlobalLoadMatcher(NamedOpMatcher[GlobalLoadResult]): - def __init__(self, globals: Optional["GlobalsDict"] = None): - super().__init__("util.global.load", "torch_c.from_builtin_tensor") - self.globals = globals - - def match(self, op: Operation) -> Optional[GlobalLoadResult]: - # Skip over any builtin tensor conversion. - if op.name == "torch_c.from_builtin_tensor": - op = _value_as_op_or_none(op.operands[0]) - if not op: - return None - - result = GlobalLoadResult(op) - if self.globals: - result.resolved_global = self.globals.get(result.global_ref) - return result - - -############################################################################### -# Passes -############################################################################### - -GlobalsDict = Dict[str, GlobalOpResult] - - -class Pass: - """Callable which performs some mutation on the IR.""" - - def __init__(self, root_op: Operation): - self.root_op = root_op - self.builder = Builder(root_op.context) - - def run(self): - raise NotImplementedError - - @property - def funcs(self) -> List[OpMatchResult]: - return match_children(self.root_op, FuncOpMatcher()) - - @property - def globals(self) -> GlobalsDict: - results = match_children(self.root_op, GlobalOpMatcher()) - return {r.sym_name: r for r in cast(list[GlobalOpResult], results)} - - def merge_module(self, source_module: Operation) -> Merger: - """Merges the given source module into the root. - - See documentation for the Merger for more information. - """ - merger = Merger(source_module, self.root_op) - merger.merge() - return merger - - def inline(self): - """Runs the inliner.""" - with self.root_op.context: - pm = PassManager.parse("builtin.module(inline)") - pm.run(self.root_op) - - def cleanup(self): - """Runs module cleanup passes.""" - with self.root_op.context: - pm = PassManager.parse("builtin.module(canonicalize, symbol-dce)") - pm.run(self.root_op) - - def replace_op(self, old_op: Operation, *new_results: Value): - old_results = old_op.results - assert len(old_results) == len( - new_results - ), "Can only replace_op with the same arity" - for old_result, new_result in zip(old_results, new_results): - old_result.replace_all_uses_with(new_result) - self.erase_unused_op(old_op) - - def erase_unused_op(self, op: Operation): - """Recursively erases any unused torch ops, starting with op. - - Torch ops generally are not erased automatically, but as part of - pattern matching, when we know we want to replace them, we can do - this ourself. - """ - worklist: Set[Operation] = set() - worklist.add(op) - while worklist: - ops = worklist - worklist = set() - for op in ops: - if not _is_erasable_value_op(op): - continue - if not _op_is_live(op): - for operand in op.operands: - if OpResult.isinstance(operand): - worklist.add(operand.owner) - op.erase() - - -def pass_main(pass_class: Type[Pass], *, argv=None): - """Simple main entry-point which reads a file, runs a callback and outputs.""" - parser = argparse.ArgumentParser(description="Rewrite driver") - parser.add_argument("input_file", help="File to process") - parser.add_argument("-o", dest="output_file", help="Output file") - args, _ = parser.parse_known_args(argv) - - with Context() as context: - with open(args.input_file, "r") as f: - module_op = Operation.parse(f.read(), source_name=args.input_file) - - p = pass_class(module_op) - p.run() - - if args.output_file: - with open(args.output_file, "wb") as f: - module_op.print(file=f, binary=True) - else: - module_op.print(file=sys.stdout) - - -############################################################################### -# Utilities -############################################################################### - - -def _value_as_op_or_none(value: Value) -> Optional[Operation]: - if OpResult.isinstance(value): - return _op_as_operation(OpResult(value).owner) - return None - - -def _op_as_operation(op: Union[Operation, OpView]) -> Operation: - if isinstance(op, OpView): - return op.operation - else: - return op - - -def _op_is_live(op: Operation) -> bool: - for r in op.results: - try: - next(r.uses) - return True - except StopIteration: - pass - return False - - -def _is_erasable_value_op(op: Operation): - name = op.name - return name.startswith("torch.") or name.startswith("torch_c.") diff --git a/core/tests/aot/api_test.py b/core/tests/aot/api_test.py deleted file mode 100644 index e038704d6..000000000 --- a/core/tests/aot/api_test.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -from iree.compiler.ir import ( - Context, -) - -from shark_turbine.aot import * - -import torch -import torch.nn as nn - - -class GeneralAPI(unittest.TestCase): - def testTypedefs(self): - self.assertEqual( - "AbstractTensor(3, 2, dtype=torch.float16)", - repr(AbstractTensor(3, 2, dtype=torch.float16)), - ) - - -class CompiledModuleAPI(unittest.TestCase): - def testBasic(self): - class BasicModule(CompiledModule): - ... - - inst = BasicModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("module @basic", module_str) - - def testExplicitName(self): - class BasicModule(CompiledModule, export_name="explicit"): - ... - - inst = BasicModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("module @explicit", module_str) - - def testJittableFunc(self): - class BasicModule(CompiledModule): - @CompiledModule.jittable - def mul(x, y): - return x * y - - inst = BasicModule(context=Context()) - self.assertIsInstance(inst.mul, builtins.jittable) - - def testBareJittableFunc(self): - class BasicModule(CompiledModule): - @jittable - def mul(x, y): - return x * y - - inst = BasicModule(context=Context()) - self.assertIsInstance(inst.mul, builtins.jittable) - - def testExportedProc(self): - class ExportedProcModule(CompiledModule): - def foobar(self): - ... - - inst = ExportedProcModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - - -class ExportAPI(unittest.TestCase): - def testStaticNNModule(self): - mdl = SimpleParams() - exported = export(mdl, args=(torch.empty([128, 20]),)) - exported.print_readable() - asm = str(exported.mlir_module) - self.assertIn("dense_resource", asm) - - def testDynamicNNModule(self): - mdl = SimpleParams() - batch = torch.export.Dim("batch") - exported = export( - mdl, args=(torch.empty([128, 20]),), dynamic_shapes={"x": {0: batch}} - ) - exported.print_readable() - asm = str(exported.mlir_module) - self.assertIn( - "func.func @main(%arg0: !torch.vtensor<[?,20],f32>) -> !torch.vtensor<[?,30],f32>", - asm, - ) - - def testExternalParamsNNModule(self): - mdl = SimpleParams() - externalize_module_parameters(mdl) - exported = export(mdl, args=(torch.empty([128, 20]),)) - exported.print_readable() - asm = str(exported.mlir_module) - self.assertNotIn("dense_resource", asm) - self.assertIn("util.global.load", asm) - - def testTorchExportedProgram(self): - mdl = SimpleParams() - externalize_module_parameters(mdl) - prg = torch.export.export(mdl, args=(torch.empty([128, 20]),)) - exported = export(prg) - exported.print_readable() - asm = str(exported.mlir_module) - self.assertNotIn("dense_resource", asm) - self.assertIn( - 'util.global private @__auto.classifier.weight = #stream.parameter.named<"model"::"classifier.weight">', - asm, - ) - self.assertIn( - 'util.global private @__auto.classifier.bias = #stream.parameter.named<"model"::"classifier.bias">', - asm, - ) - - def testCompiledModuleExportedProgram(self): - class BasicModule(CompiledModule): - ... - - exported = export(BasicModule) - module_str = str(exported.mlir_module) - print(module_str) - self.assertIn("module @basic", module_str) - - def testUnsupportedExportedProgram(self): - class UnsupportedExportType: - ... - - with self.assertRaises(TypeError): - export(UnsupportedExportType) - - -class SimpleParams(nn.Module): - def __init__(self): - super().__init__() - self.classifier = nn.Linear(20, 30) - - def forward(self, x): - return self.classifier(x) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/aot/args_test.py b/core/tests/aot/args_test.py deleted file mode 100644 index d7ec458da..000000000 --- a/core/tests/aot/args_test.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -from iree.compiler.ir import ( - Context, -) - -from shark_turbine.aot import * - - -class ArgsTest(unittest.TestCase): - def testProcArgs(self): - class ProcArgsModule(CompiledModule): - def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): - return b, a - - inst = ProcArgsModule(context=Context(), import_to="full") - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "util.func public @foobar$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.fence, %arg3: !hal.fence) -> (!hal.buffer_view, !hal.buffer_view)", - module_str, - ) - - def testProcToJitArgs(self): - class testProcToJitArgs(CompiledModule): - def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): - return self.compute(a, b) - - @jittable - def compute(a, b): - return a + b - - inst = testProcToJitArgs(context=Context(), import_to="full") - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "linalg.generic", - module_str, - ) - - def testProcToJitArgsMultiCall(self): - class ProcArgsModule(CompiledModule): - def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): - x = self.compute(a, b) - y = self.compute(x, a) - return y - - @jittable - def compute(a, b): - return a + b - - inst = ProcArgsModule(context=Context(), import_to="full") - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertEqual( - 2, - module_str.count("linalg.generic"), - msg=f"Did not find two linalg.generics in module: module_str", - ) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/aot/compiled_exported_program_test.py b/core/tests/aot/compiled_exported_program_test.py deleted file mode 100644 index 4bf3fa157..000000000 --- a/core/tests/aot/compiled_exported_program_test.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -import torch -import torch.nn as nn - -from iree.compiler.ir import ( - Context, -) - -from shark_turbine.aot import * -from shark_turbine.aot.builtins import * - - -class TorchExportTests(unittest.TestCase): - def testImportPhases(self): - class MyModule(torch.nn.Module): - def forward(self): - ... - - fxb = FxProgramsBuilder(MyModule()) - - @fxb.export_program( - args=([torch.empty([3, 2]), torch.empty([1, 2])],), - kwargs={"foobar": torch.empty([3, 1])}, - ) - def compute(module, inputs, *, foobar): - t1 = inputs[0] - t2 = inputs[1] - t3 = t1 + t2 + foobar - return [t3 * t3, foobar] - - class ExportedProcModule(CompiledModule): - _compute = compute - - def foobar( - self, - t1=AbstractTensor(3, 2), - t2=AbstractTensor(1, 2), - t3=AbstractTensor(3, 1), - ): - return self._compute(t1, t2, foobar=t3) - - inst = ExportedProcModule(context=Context(), import_to="import") - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("func.func private @_compute", module_str) - self.assertIn("func.func @foobar", module_str) - - def testMultiPublic(self): - class MyModule(torch.nn.Module): - def forward(self): - ... - - fxb = FxProgramsBuilder(MyModule()) - - @fxb.export_program( - args=([torch.empty([3, 2]), torch.empty([1, 2])],), - kwargs={"foobar": torch.empty([3, 1])}, - ) - def _compute1(module, inputs, *, foobar): - t1 = inputs[0] - t2 = inputs[1] - t3 = t1 + t2 + foobar - return [t3 * t3, foobar] - - @fxb.export_program( - args=([torch.empty([5]), torch.empty([5])],), - kwargs={"foobar": torch.empty([5])}, - ) - def _compute2(module, inputs, *, foobar): - t1 = inputs[0] - t2 = inputs[1] - t3 = t1 + t2 + foobar - return [t3 * t3, foobar] - - class ExportedPublicModule(CompiledModule): - compute1 = _compute1 - compute2 = _compute2 - - inst = ExportedPublicModule(context=Context(), import_to="import") - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("func.func @compute1", module_str) - self.assertIn("func.func @compute2", module_str) - - def testParametersAsGlobals(self): - fxb = FxProgramsBuilder(SimpleParams()) - - @fxb.export_program( - args=(torch.empty([128, 20]),), - ) - def _compute1(module, x): - return module.forward(x) - - class ParamsAsGlobalsModule(CompiledModule): - params = export_parameters(fxb.root_module) - compute1 = _compute1 - compute2 = _compute1 - - inst = ParamsAsGlobalsModule(context=Context(), import_to="import") - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("util.global private @_params.classifier.weight", module_str) - self.assertIn("util.global private @_params.classifier.bias", module_str) - # Should only be two. - self.assertEqual(2, module_str.count("util.global private")) - # And two loads each loads. - self.assertEqual( - 2, module_str.count("util.global.load @_params.classifier.weight") - ) - self.assertEqual( - 2, module_str.count("util.global.load @_params.classifier.bias") - ) - - def testBuffersAsGlobals(self): - fxb = FxProgramsBuilder(SimpleBuffers()) - - @fxb.export_program(args=(torch.empty([128]),)) - def _compute1(module, x): - return module.forward(x) - - class BuffersAsGlobalsModule(CompiledModule): - buffers = export_buffers(fxb.root_module, mutable=True) - compute1 = _compute1 - - inst = BuffersAsGlobalsModule(context=Context(), import_to="import") - module_str = str(CompiledModule.get_mlir_module(inst)) - self.assertIn("util.global private mutable @_buffers.buf", module_str) - self.assertIn("%_buffers.buf = util.global.load @_buffers.buf", module_str) - self.assertIn("util.global.store", module_str) - - -class SimpleParams(nn.Module): - def __init__(self): - super().__init__() - self.classifier = nn.Linear(20, 30) - - def forward(self, x): - return self.classifier(x) - - -class SimpleBuffers(nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("buf", torch.randn(1)) - - def forward(self, x: torch.Tensor): - sumx = (x).sum() - output = x * self.buf - self.buf.copy_(sumx) - return output - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/aot/decompositions_test.py b/core/tests/aot/decompositions_test.py deleted file mode 100644 index baf96604c..000000000 --- a/core/tests/aot/decompositions_test.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import torch - -import logging -import unittest - -from shark_turbine.aot import decompositions - - -class DecompTest(unittest.TestCase): - def testDefault(self): - table = decompositions.current_aot_decompositions() - self.assertTrue(table) - - def testExtendToEmpty(self): - with decompositions.extend_aot_decompositions(from_current=False) as t: - self.assertFalse(t, msg=f"{t}") - current_table = decompositions.current_aot_decompositions() - self.assertFalse(current_table, msg=f"{current_table}") - - def testNestedExtend(self): - initial_table = decompositions.current_aot_decompositions() - with decompositions.extend_aot_decompositions(from_current=False) as empty_t: - with decompositions.extend_aot_decompositions( - add_ops=[ - torch.ops.aten.masked_fill.Tensor, - torch.ops.aten.masked_fill.Scalar, - ] - ): - current_table = decompositions.current_aot_decompositions() - self.assertEqual(2, len(current_table), msg=f"{current_table}") - with decompositions.extend_aot_decompositions( - remove_ops=[ - torch.ops.aten.masked_fill.Tensor, - ] - ): - current_table = decompositions.current_aot_decompositions() - self.assertEqual(1, len(current_table), msg=f"{current_table}") - current_table = decompositions.current_aot_decompositions() - self.assertDictEqual(current_table, initial_table) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/aot/functionalize_test.py b/core/tests/aot/functionalize_test.py deleted file mode 100644 index bbd15ad83..000000000 --- a/core/tests/aot/functionalize_test.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -import torch - -from iree.compiler.ir import ( - Context, -) - -from shark_turbine.aot import * - - -class FunctionalizeTests(unittest.TestCase): - def testImportPhases(self): - class ExportedProcModule(CompiledModule): - def foobar(self): - return self.compute(), self.compute() - - @CompiledModule.jittable - def compute(): - offset = torch.ones(2, 2) - t1 = torch.ones(2, 2) - t1.add_(offset) - return t1 * t1 - - inst = ExportedProcModule(context=Context(), import_to="import") - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertNotIn("add_", module_str) - - def testDynamicDims(self): - class ProcArgsModule(CompiledModule): - def dynamic_dim(self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)): - return self.compute( - a, - b, - constraints=[ - a.dynamic_dim(0) == b.dynamic_dim(0), - ], - ) - - @jittable - def compute(a, b): - a.mul_(b) - return a - - inst = ProcArgsModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertNotIn("mul_", module_str) - - def testCallWithStructure(self): - class ProcArgsModule(CompiledModule): - def call_with_dicts(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): - intermediate = self.compute({"a": a, "b": b}) - return self.compute(intermediate) - - @jittable - def compute(struct): - a = struct["a"] - b = struct["b"] - a.add_(b) - return {"a": a, "b": b} - - inst = ProcArgsModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertNotIn("add_", module_str) - - def testCallWithArgsKwargs(self): - class ProcArgsModule(CompiledModule): - def call_with_kwargs(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): - intermediate = self.compute(**{"a": a, "b": b}) - return self.compute(**intermediate) - - @jittable - def compute(*, a, b): - a.add_(b) - return {"a": a, "b": b} - - inst = ProcArgsModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertNotIn("add_", module_str) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/aot/fx_programs_test.py b/core/tests/aot/fx_programs_test.py deleted file mode 100644 index c54f1851b..000000000 --- a/core/tests/aot/fx_programs_test.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from pathlib import Path -import tempfile - -import pytest -import torch - -from shark_turbine.aot import ( - FxPrograms, - FxProgramsBuilder, -) - - -def test_save_load_dynamic_shapes(): - if torch.__version__ < "2.3.0.dev1": - pytest.skip("Unsupported PyTorch version") - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU()) - self.branch2 = torch.nn.Sequential( - torch.nn.Linear(128, 64), torch.nn.ReLU() - ) - self.buffer = torch.ones(32) - - def forward(self, x1, x2): - out1 = self.branch1(x1) - out2 = self.branch2(x2) - return (out1 + self.buffer, out2) - - example_args = (torch.randn(32, 64), torch.randn(32, 128)) - - # Create a dynamic batch size - batch = torch.export.Dim("batch") - # Specify that the first dimension of each input is that batch size - dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} - - fxb = FxProgramsBuilder(M()) - - @fxb.export_program(args=example_args, dynamic_shapes=dynamic_shapes) - def dynamic_batch(module: M, x1, x2): - return module.forward(x1, x2) - - @fxb.export_program(args=example_args) - def bs32(module: M, x1, x2): - return module.forward(x1, x2) - - with tempfile.TemporaryDirectory() as d: - p = Path(d) / "branchy.json" - dedup_count = fxb.save(p) - assert dedup_count == 5 # Two sets of weights/bias and one constant - new_programs = FxPrograms.load(p) - - prog_0 = new_programs.programs["dynamic_batch"] - prog_1 = new_programs.programs["bs32"] - - for key, value_0 in prog_0.state_dict.items(): - value_1 = prog_1.state_dict[key] - assert value_0 is value_1, f"State dict item {key} was not aliased on load" - - for key, value_0 in prog_0.constants.items(): - value_1 = prog_1.constants[key] - assert value_0 is value_1, f"Constant item {key} was not aliased on load" diff --git a/core/tests/aot/globals_test.py b/core/tests/aot/globals_test.py deleted file mode 100644 index 26bab1a61..000000000 --- a/core/tests/aot/globals_test.py +++ /dev/null @@ -1,430 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -from iree.compiler.ir import ( - Context, -) - -from shark_turbine.aot import * - -import torch -import torch.nn as nn - - -class SimpleParams(nn.Module): - def __init__(self): - super().__init__() - self.classifier = nn.Linear(20, 30) - - def forward(self, x): - return self.classifier(x) - - -class GlobalsTest(unittest.TestCase): - def testGlobalParameters(self): - m = SimpleParams() - - class GlobalModule(CompiledModule): - params = export_parameters(m) - compute = jittable(m.forward) - - def run(self, x=AbstractTensor(128, 20)): - return self.compute(x) - - inst = GlobalModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("util.global private @_params.classifier.weight", module_str) - self.assertIn("util.global private @_params.classifier.bias", module_str) - - def testGlobalLoadFromPyTree(self): - m = SimpleParams() - - class GlobalModule(CompiledModule): - params = export_parameters(m) - - def read_params(self): - return self.params - - inst = GlobalModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "%_params.classifier.weight = util.global.load @_params.classifier.weight", - module_str, - ) - self.assertIn( - "%_params.classifier.bias = util.global.load @_params.classifier.bias", - module_str, - ) - - def testGlobalLoadFromPyLeaf(self): - m = SimpleParams() - - class GlobalModule(CompiledModule): - params = export_parameters(m) - - def read_weight(self): - return self.params["classifier.weight"] - - inst = GlobalModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "%_params.classifier.weight = util.global.load @_params.classifier.weight", - module_str, - ) - - def testGlobalStoreFromPyTree(self): - m = SimpleParams() - - class GlobalModule(CompiledModule): - params = export_parameters(m, mutable=True) - - def update_params(me, updates=abstractify(params)): - self.assertIn("classifier.weight", updates) - self.assertIn("classifier.bias", updates) - me.params = updates - - inst = GlobalModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertRegex( - module_str, "util.global.store %.*, @_params.classifier.weight" - ) - self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias") - - def testGlobalStoreFromLeaf(self): - m = SimpleParams() - - class GlobalModule(CompiledModule): - params = export_parameters(m, mutable=True) - - def update_bias(self, new_bias=abstractify(params["classifier.bias"])): - self.params["classifier.bias"] = new_bias - - inst = GlobalModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias") - - def testExportSingleGlobalTensor(self): - state_example = torch.randn(3, 11) - - class SingleState(CompiledModule): - state0 = export_global(state_example, name="global") - - def read_state(self): - return self.state0 - - inst = SingleState(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("util.global private @_state0.global", module_str) - self.assertIn("%_state0.global = util.global.load @_state0.global", module_str) - - def testExportTreeGlobalTensors(self): - state_example = { - "data": torch.randn(3, 11), - "seq": [ - torch.randn(1), - torch.randn(2), - torch.randn(3), - ], - } - - class SingleState(CompiledModule): - state0 = export_global_tree(state_example) - - def read_state(self): - return self.state0 - - inst = SingleState(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("util.global private @_state0.seq.0", module_str) - self.assertIn("util.global private @_state0.seq.1", module_str) - self.assertIn("util.global private @_state0.seq.2", module_str) - self.assertIn("util.global private @_state0.data", module_str) - self.assertIn("%_state0.data = util.global.load @_state0.data", module_str) - self.assertIn("%_state0.seq.0 = util.global.load @_state0.seq.0", module_str) - self.assertIn("%_state0.seq.1 = util.global.load @_state0.seq.1", module_str) - self.assertIn("%_state0.seq.2 = util.global.load @_state0.seq.2", module_str) - - def testExportGlobalScalars(self): - class ScalarState(CompiledModule): - state_index = export_global(AbstractIndex, mutable=True) - state_f32 = export_global(AbstractF32, mutable=True) - state_f64 = export_global(AbstractF64, mutable=True) - state_i32 = export_global(AbstractI32, mutable=True) - state_i64 = export_global(AbstractI64, mutable=True) - state_bool = export_global(AbstractBool, mutable=True) - - def read(self): - return ( - self.state_index, - self.state_f32, - self.state_f64, - self.state_i32, - self.state_i64, - self.state_bool, - ) - - inst = ScalarState(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("@_state_index.global = 0 : index", module_str) - self.assertIn("@_state_f32.global = 0.000000e+00 : f32", module_str) - self.assertIn("@_state_f64.global = 0.000000e+00 : f64", module_str) - self.assertIn("@_state_i32.global = 0 : i32", module_str) - self.assertIn("@_state_i64.global = 0 : i64", module_str) - self.assertIn("@_state_bool.global = false", module_str) - - def testInheritExportScalars(self): - class BaseState(CompiledModule): - state_index = export_global(AbstractIndex, mutable=True) - state_f32 = export_global(AbstractF32, mutable=True) - - def read(self): - return (self.state_index, self.state_f32) - - class DerivedState(BaseState): - pass - - inst = DerivedState(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("@_state_index.global = 0 : index", module_str) - self.assertIn("@_state_f32.global = 0.000000e+00 : f32", module_str) - - def testInheritOverrideBase(self): - class BaseState(CompiledModule): - state_index = export_global(AbstractIndex, mutable=True) - state_f32 = export_global(AbstractF32, mutable=True) - - def read(self): - return (self.state_index, self.state_f32) - - class DerivedState(BaseState): - def read(self): - return self.state_index - - inst = DerivedState(context=Context(), import_to="full") - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("@_state_index.global = 0 : index", module_str) - self.assertNotIn("@_state_f32.global = 0.000000e+00 : f32", module_str) - self.assertIn("return %_state_index.global : index", module_str) - - def testInheritExportModules(self): - m = SimpleParams() - - class BaseModule(CompiledModule): - params = export_parameters(m, mutable=True) - - def update_params(me, updates=abstractify(params)): - self.assertIn("classifier.weight", updates) - self.assertIn("classifier.bias", updates) - me.params = updates - - class DerivedModule(BaseModule): - pass - - inst = DerivedModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertRegex( - module_str, "util.global.store %.*, @_params.classifier.weight" - ) - self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias") - - def testUpdateGlobalStateTree(self): - state_example = { - "data": torch.randn(3, 11), - "seq": [ - torch.randn(1).to(torch.int32), - torch.randn(2).to(torch.float64), - torch.randn(3).to(torch.int64), - ], - } - - class SingleState(CompiledModule): - state0 = export_global_tree(abstractify(state_example), mutable=True) - - def read_state(self, updates=abstractify(state_example)): - self.state0 = updates - - inst = SingleState(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "util.global private mutable @_state0.seq.0 = dense<0> : tensor<1xi32>", - module_str, - ) - self.assertIn( - "util.global private mutable @_state0.seq.1 = dense<0.000000e+00> : tensor<2xf64>", - module_str, - ) - self.assertIn( - "util.global private mutable @_state0.seq.2 = dense<0> : tensor<3xi64>", - module_str, - ) - self.assertIn("util.global private mutable @_state0.data", module_str) - self.assertRegex(module_str, "util.global.store %.*, @_state0.data") - self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.0") - self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.1") - self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.2") - - def testTensorUpdateGlobal(self): - state_example = torch.randn(5, 20) - update_example = torch.randn(1, 20) - - class UpdateState(CompiledModule): - state0 = export_global(state_example, mutable=True) - - def tensor_update_state(self, update=abstractify(update_example)): - return IREE.tensor_update(self.state0, update, 0, 0) - - inst = UpdateState(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertRegex( - module_str, - "flow.tensor.update %.*, %_state0.global\\[%c0, %c0\\] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>", - ) - - def testTensorUpdateGlobalReturnNone(self): - state_example = torch.randn(5, 20, 4) - update_example = torch.randn(1, 1, 4) - - class UpdateState(CompiledModule): - state0 = export_global(state_example, mutable=True) - - def tensor_update_state(self, update=abstractify(update_example)): - thing = [] - self.state0 = IREE.tensor_update(self.state0, update, 4, 0, 0) - return None - - inst = UpdateState(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("flow.tensor.update", module_str) - - def testExternalGlobalParametersDefaults(self): - m = SimpleParams() - - class GlobalModule( - CompiledModule, export_name="external_global_parameters_defaults" - ): - params = export_parameters(m, external=True) - compute = jittable(m.forward) - - def run(self, x=AbstractTensor(128, 20)): - return self.compute(x) - - inst = GlobalModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - '#stream.parameter.named<"model"::"params.classifier.weight"> : tensor<30x20xf32>', - module_str, - ) - self.assertIn( - '#stream.parameter.named<"model"::"params.classifier.bias"> : tensor<30xf32>', - module_str, - ) - - def testExternalGlobalParametersExplicit(self): - m = SimpleParams() - - class GlobalModule( - CompiledModule, export_name="external_global_parameters_explicit" - ): - params = export_parameters( - m, external=True, external_scope="foo", name_mapper=lambda s: s.upper() - ) - compute = jittable(m.forward) - - def run(self, x=AbstractTensor(128, 20)): - return self.compute(x) - - inst = GlobalModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - '#stream.parameter.named<"foo"::"PARAMS.CLASSIFIER.WEIGHT"> : tensor<30x20xf32>', - module_str, - ) - self.assertIn( - '#stream.parameter.named<"foo"::"PARAMS.CLASSIFIER.BIAS"> : tensor<30xf32>', - module_str, - ) - - def testExternalGlobalParametersMapDict(self): - m = SimpleParams() - mapper = { - "params.classifier.weight": "WEIGHT", - } - - class GlobalModule( - CompiledModule, export_name="external_global_parameters_map_dict" - ): - params = export_parameters( - m, external=True, external_scope="foo", name_mapper=mapper.get - ) - compute = jittable(m.forward) - - def run(self, x=AbstractTensor(128, 20)): - return self.compute(x) - - inst = GlobalModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - '#stream.parameter.named<"foo"::"WEIGHT"> : tensor<30x20xf32>', - module_str, - ) - self.assertIn( - '#stream.parameter.named<"foo"::"params.classifier.bias"> : tensor<30xf32>', - module_str, - ) - - def testUninitializedParameters(self): - m = SimpleParams() - - class GlobalModule(CompiledModule, export_name="uninitialized_parameters"): - params = export_parameters(m, uninitialized=True, mutable=True) - y = export_global(AbstractF32, uninitialized=True, mutable=True) - compute = jittable(m.forward) - - def run(self, x=AbstractTensor(128, 20)): - return self.compute(x), self.y - - inst = GlobalModule(context=Context()) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "#util.uninitialized : tensor<30x20xf32>", - module_str, - ) - self.assertIn( - "#util.uninitialized : f32", - module_str, - ) - - def testUnsupportedCombinations(self): - with self.assertRaisesRegex(ValueError, "mutable=True"): - export_global(AbstractF32, uninitialized=True) - with self.assertRaisesRegex(ValueError, "external=True"): - export_global(AbstractF32, external=True, uninitialized=True) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/aot/iree_procedural_test.py b/core/tests/aot/iree_procedural_test.py deleted file mode 100644 index 9f4799210..000000000 --- a/core/tests/aot/iree_procedural_test.py +++ /dev/null @@ -1,273 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -import torch - -from iree.compiler.ir import ( - Context, -) - -from shark_turbine.aot import * - - -class CompiledModuleAPI(unittest.TestCase): - def testTensorDim(self): - class BasicModule(CompiledModule): - def foobar(self, a=AbstractTensor(None, 3)): - return IREE.tensor_dim(a, 0) - - inst = BasicModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("%c0 = arith.constant 0", module_str) - self.assertIn("%dim = tensor.dim %arg0, %c0", module_str) - self.assertIn("return %dim", module_str) - - def testTensorDimAsDtype(self): - class BasicModule(CompiledModule): - def foobar(self, a=AbstractTensor(None, 3)): - return IREE.tensor_dim(a, 0, dtype=torch.int32) - - inst = BasicModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("%c0 = arith.constant 0", module_str) - self.assertIn("%dim = tensor.dim %arg0, %c0", module_str) - self.assertIn("%0 = arith.index_castui %dim : index to i32", module_str) - self.assertIn("return %0", module_str) - - def testTensorEmpty(self): - class BasicModule(CompiledModule): - def foobar(self, x=AbstractIndex): - empty = IREE.tensor_empty(x, 16) - dim0 = IREE.tensor_dim(empty, 0) - return empty, dim0 - - inst = BasicModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("%0 = flow.tensor.empty : tensor{%arg0}", module_str) - # NOTE: We are testing below that the dynamic dimension is associated - # and used from the input vs being recalculated. - self.assertIn("return %0, %arg0 : tensor, index", module_str) - - def testTensorSplat(self): - class BasicModule(CompiledModule): - def foobar(self, x=AbstractIndex, y=AbstractF32): - empty = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32) - dim0 = IREE.tensor_dim(empty, 0) - return empty, dim0 - - inst = BasicModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "%0 = flow.tensor.splat %arg1 : tensor{%arg0}", module_str - ) - # NOTE: We are testing below that the dynamic dimension is associated - # and used from the input vs being recalculated. - self.assertIn("return %0, %arg0 : tensor, index", module_str) - - def testTensorSplatCasting(self): - class BasicModule(CompiledModule): - def foobar(self, x=AbstractIndex, y=AbstractIndex): - empty = IREE.tensor_splat(x, 34, value=y, dtype=torch.int32) - dim0 = IREE.tensor_dim(empty, 0) - return empty, dim0 - - inst = BasicModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("%0 = arith.index_castui %arg1 : index to i32", module_str) - self.assertIn("%1 = flow.tensor.splat %0 : tensor{%arg0}", module_str) - - def testTensorTrace(self): - class BasicModule(CompiledModule): - def foobar(self, x=AbstractTensor(None), y=AbstractTensor(3)): - IREE.tensor_trace("DEBUG", x, y) - - inst = BasicModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - 'flow.tensor.trace "DEBUG" = [%arg0 : tensor{%dim}, %arg1 : tensor<3xf32>]', - module_str, - ) - - def testStoreDynamic(self): - class BasicModule(CompiledModule): - x = export_global(AbstractTensor(None, 34), mutable=True) - - def foobar(self, x=AbstractIndex, y=AbstractF32): - splat = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32) - self.x = splat - - # TODO(#171): It is not exactly clear how we want to support dynamic shaped - # globals at this level. - with self.assertRaisesRegex( - ValueError, "Cannot create initialization value for dynamic shaped tensor" - ): - inst = BasicModule(context=Context(), import_to=None) - - def testTensorSliceStatic(self): - class BasicModule(CompiledModule): - def foobar(self, x=AbstractTensor(3, 4)): - return IREE.tensor_slice(x, 0, (1, 3)) - - inst = BasicModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "flow.tensor.slice %arg0[%c0, %c1_0 for %c1, %c3] : tensor<3x4xf32> -> tensor<1x3xf32>", - module_str, - ) - - def testTensorSliceDynamicIndex(self): - class SliceDynamicIndex(CompiledModule): - def foobar(self, x=AbstractIndex): - empty = IREE.tensor_empty(x, 16) - return IREE.tensor_slice(empty, x, 4) - - inst = SliceDynamicIndex(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "flow.tensor.slice %0[%arg0, %c4 for %c1, %c1] : tensor{%arg0} -> tensor<1x1xf32>", - module_str, - ) - - def testTensorSliceDynamicLength(self): - class SliceDynamicIndex(CompiledModule): - def foobar(self, x=AbstractIndex, y=AbstractIndex): - empty = IREE.tensor_empty(x, 16) - return IREE.tensor_slice(empty, (x, y), 4) - - inst = SliceDynamicIndex(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "flow.tensor.slice %0[%arg0, %c4 for %arg1, %c1] : tensor{%arg0} -> tensor{%arg1}", - module_str, - ) - - def testTensorUpdateStatic(self): - class UpdateStatic(CompiledModule): - def foobar( - self, - target=AbstractTensor(4, 4), - update=AbstractTensor(2, 2), - i=AbstractIndex, - j=AbstractIndex, - ): - return IREE.tensor_update(target, update, i, j) - - inst = UpdateStatic(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "flow.tensor.update %arg1, %arg0[%arg2, %arg3] : tensor<2x2xf32> -> %arg0 as tensor<4x4xf32>", - module_str, - ) - - def testTensorUpdateDynamic(self): - class UpdateDynamic(CompiledModule): - def foobar( - self, - x=AbstractIndex, - y=AbstractIndex, - i=AbstractIndex, - j=AbstractIndex, - value=AbstractF32, - ): - target = IREE.tensor_empty(x, y) - update = IREE.tensor_splat(i, j, value=value, dtype=torch.float32) - return IREE.tensor_update(target, update, 2, 2) - - inst = UpdateDynamic(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "flow.tensor.update %1, %0[%c2, %c2] : tensor{%arg2, %arg3} -> %0 as tensor{%arg0, %arg1}", - module_str, - ) - - def testTensorReshape(self): - class ReshapeModule(CompiledModule): - def foobar(self, x=AbstractIndex, y=AbstractIndex): - empty = IREE.tensor_empty(x, 16) - reshaped = IREE.tensor_reshape(empty, 1, y, y) - return reshaped - - inst = ReshapeModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn( - "flow.tensor.reshape %0 : tensor{%arg0} -> tensor<1x?x?xf32>{%arg1, %arg1}", - module_str, - ) - - def testScalarAddInt(self): - class ArithModule(CompiledModule): - def foobar(self, a=AbstractI32, b=AbstractI32): - return a + b - - inst = ArithModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - self.assertIn("arith.addi %arg0, %arg1 : i32", module_str) - - def testScalarAddFloat(self): - class ArithModule(CompiledModule): - def foobar(self, a=AbstractF32, b=AbstractF32): - return a + b - - inst = ArithModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - self.assertIn("arith.addf %arg0, %arg1 : f32", module_str) - - def testScalarAddLiteral(self): - class ArithModule(CompiledModule): - def foobar(self, a=AbstractI32): - return a + 1 - - inst = ArithModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - self.assertIn("%c1_i32 = arith.constant 1 : i32", module_str) - self.assertIn("arith.addi %arg0, %c1_i32 : i32", module_str) - - def testScalarAddLiteralMixedType(self): - class ArithModule(CompiledModule): - def foobar(self, a=AbstractI32): - return a + 3.23 - - inst = ArithModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - self.assertIn("%0 = arith.sitofp %arg0 : i32 to f32", module_str) - self.assertIn("%cst = arith.constant 3.230000e+00 : f32", module_str) - self.assertIn("arith.addf %0, %cst : f32", module_str) - - def testSetScalarState(self): - class ArithModule(CompiledModule): - state_index = export_global(AbstractIndex, mutable=True) - state_f32 = export_global(AbstractF32, mutable=True) - - def foobar(self): - self.state_index.set(5) - self.state_f32.set(5.5) - - inst = ArithModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertIn("util.global.store %c5, @_state_index.global : index", module_str) - self.assertIn("%cst = arith.constant 5.500000e+00 : f32", module_str) - self.assertIn("util.global.store %cst, @_state_f32.global", module_str) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/aot/jittable_test.py b/core/tests/aot/jittable_test.py deleted file mode 100644 index 6419c0bd4..000000000 --- a/core/tests/aot/jittable_test.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -import torch - -from iree.compiler.ir import ( - Context, -) - -from shark_turbine.aot import * - - -class JittableTests(unittest.TestCase): - def testImportPhases(self): - class ExportedProcModule(CompiledModule): - def foobar(self): - return self.compute(), self.compute() - - @CompiledModule.jittable - def compute(): - t1 = torch.ones(2, 2) - t2 = t1 + t1 - return t2 * t2 - - inst = ExportedProcModule(context=Context(), import_to="import") - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - # Functions should still be on torch types. - self.assertIn( - "func private @compute() -> !torch.vtensor<[2,2],f32>", module_str - ) - CompiledModule.run_import(inst, import_to="full") - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - self.assertNotIn("!torch.vtensor", module_str) - - def testCallWithStructure(self): - class ProcArgsModule(CompiledModule): - def call_with_dicts(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): - intermediate = self.compute({"a": a, "b": b}) - return self.compute(intermediate) - - @jittable - def compute(struct): - a = struct["a"] - b = struct["b"] - result = a + b - return {"a": result, "b": b} - - inst = ProcArgsModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - - def testCallWithArgsKwargs(self): - class ProcArgsModule(CompiledModule): - def call_with_kwargs(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): - intermediate = self.compute(**{"a": a, "b": b}) - return self.compute(**intermediate) - - @jittable - def compute(*, a, b): - result = a + b - return {"a": result, "b": b} - - inst = ProcArgsModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - - def testDynamicDims(self): - class DynamicDimsModule(CompiledModule): - def dynamic_dim(self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)): - return self.compute( - a, - b, - constraints=[ - a.dynamic_dim(0) == b.dynamic_dim(0), - ], - ) - - @jittable - def compute(a, b): - return a * b - - inst = DynamicDimsModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - - def testIntTensors(self): - class ProcArgsModule(CompiledModule): - def dynamic_dim( - self, - a=AbstractTensor(2, 2, dtype=torch.int64), - b=AbstractTensor(1, 1, dtype=torch.int64), - ): - return self.compute(a, b) - - @jittable - def compute(a, b): - return a * b - - inst = ProcArgsModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - - def testIrImmediateTensorAsInputToDynamicDims(self): - class ProcArgsModule(CompiledModule): - def dynamic_dim(self, x=AbstractIndex): - a = IREE.tensor_empty(x, 4) - b = IREE.tensor_empty(x, 4) - return self.compute( - a, b, constraints=[a.dynamic_dim(0) == b.dynamic_dim(0)] - ) - - @jittable - def compute(a, b): - return a * b - - inst = ProcArgsModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - - def testImplicitTensorsImportedAndDeduped(self): - implicit = torch.tensor([1, 2, 3], dtype=torch.int32) - - class ProcArgsModule(CompiledModule): - def implicit( - self, - a=AbstractTensor(3, dtype=torch.int32), - ): - return self.compute(a), self.compute(a) - - @jittable - def compute(a): - return a * implicit - - inst = ProcArgsModule(context=Context(), import_to=None) - module_str = str(CompiledModule.get_mlir_module(inst)) - print(module_str) - # This is testing machinery which ensures that not only are - # implicit captured tensors emitted as resources, but that - # multiple subsequent references to the same tensor is - # only captured once. - resource_string = ( - r'''torch_tensor_3_torch.int32: "0x04000000010000000200000003000000"''' - ) - self.assertIn(resource_string, module_str) - self.assertEqual( - 1, - module_str.count(resource_string), - f"Expected to find exactly one of '{resource_string}' in '{module_str}'", - ) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/aot/params_test.py b/core/tests/aot/params_test.py deleted file mode 100644 index 7ddaac5ca..000000000 --- a/core/tests/aot/params_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -from pathlib import Path -import tempfile -import unittest - -import torch -import torch.nn as nn - -from shark_turbine.aot import ( - export, - externalize_module_parameters, - save_module_parameters, - ExternalTensorTrait, - ParameterArchive, -) - - -class SimpleParamsModule(nn.Module): - def __init__(self): - super().__init__() - self.classifier = nn.Linear(20, 30) - self.large_tensor = torch.rand([30, 50]) - self.dup_large_tensor = torch.rand([30, 50]) - - def forward(self, x): - result = self.classifier(x) + torch.tensor([1.0], dtype=torch.float32) - result = torch.matmul(result, self.large_tensor + self.dup_large_tensor) - return result - - -class ParamsTest(unittest.TestCase): - def testCreateArchive(self): - with tempfile.NamedTemporaryFile("wb", delete=False, suffix=".irpa") as f: - file_path = Path(f.name) - try: - m = SimpleParamsModule() - save_module_parameters(file_path, m) - # mmap=False is a bit nicer for tests on Windows because it doesn't - # lock the file for an arbitrary duration. - archive = ParameterArchive(file_path, mmap=False) - items = dict(archive.items()) - weight = items["classifier.weight"].as_tensor() - bias = items["classifier.bias"].as_tensor() - torch.testing.assert_close(weight, m.classifier.weight) - torch.testing.assert_close(bias, m.classifier.bias) - finally: - file_path.unlink() - - def testCreateArchiveWithPrefixScope(self): - with tempfile.NamedTemporaryFile("wb", delete=False, suffix=".irpa") as f: - file_path = Path(f.name) - try: - m = SimpleParamsModule() - save_module_parameters(file_path, m, prefix="foobar.model") - # mmap=False is a bit nicer for tests on Windows because it doesn't - # lock the file for an arbitrary duration. - archive = ParameterArchive(file_path, mmap=False) - items = dict(archive.items()) - weight = items["foobar.model.classifier.weight"].as_tensor() - bias = items["foobar.model.classifier.bias"].as_tensor() - torch.testing.assert_close(weight, m.classifier.weight) - torch.testing.assert_close(bias, m.classifier.bias) - finally: - file_path.unlink() - - def testExportExternalized(self): - m = SimpleParamsModule() - externalize_module_parameters(m) - output = export(m, args=(torch.empty([128, 20]),)) - asm = str(output.mlir_module) - self.assertIn( - 'util.global private @__auto.classifier.weight = #stream.parameter.named<"model"::"classifier.weight">', - asm, - ) - self.assertIn( - 'util.global private @__auto.classifier.bias = #stream.parameter.named<"model"::"classifier.bias">', - asm, - ) - # Verify that the small tensor is inlined. - self.assertIn("torch.vtensor.literal(dense<1.000000e+00> : tensor<1xf32>)", asm) - # Verify that the large tensors are named uniquely and lifted. - self.assertIn("@__auto.constant_30_50_torch.float32 =", asm) - self.assertIn("@__auto.constant_30_50_torch.float32$1 =", asm) - - -class ExternalTensorTest(unittest.TestCase): - def testExternalTensorTrait(self): - t = torch.ones([2, 3], dtype=torch.float32) - trait = ExternalTensorTrait(external_name="foobar", external_scope="test") - self.assertIsNone(trait.get(t)) - trait.set(t) - self.assertIs(ExternalTensorTrait.get(t), trait) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/dynamo/backend_smoke_test.py b/core/tests/dynamo/backend_smoke_test.py deleted file mode 100644 index 479b8713d..000000000 --- a/core/tests/dynamo/backend_smoke_test.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import torch - - -def test_basic(): - def foo(x, y): - a = torch.sin(x) - b = torch.cos(y) - return a + b - - opt_foo1 = torch.compile(foo, backend="turbine_cpu") - print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10))) diff --git a/core/tests/dynamo/importer_backward_test.py b/core/tests/dynamo/importer_backward_test.py deleted file mode 100644 index e12689b28..000000000 --- a/core/tests/dynamo/importer_backward_test.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from testutils import * - - -class ImportTests(unittest.TestCase): - def testImportCustomLossModule(self): - def foo(x, y): - loss = ((0.5 * x - y) ** 2).mean() - loss.backward() - return loss - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo(torch.randn(10), torch.randn(10, requires_grad=True)) - - # TODO: using func.grad for backward test - - # TODO: MNIST Classifier using LeNet for backward test - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/dynamo/importer_basic_test.py b/core/tests/dynamo/importer_basic_test.py deleted file mode 100644 index 07ccb7d9b..000000000 --- a/core/tests/dynamo/importer_basic_test.py +++ /dev/null @@ -1,358 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import torch - -from testutils import * - - -class ImportTests(unittest.TestCase): - def testImportStateless(self): - a = torch.randn(3, 4) - backend = create_backend() - - @dynamo.optimize(backend) - def basic(x): - return torch.tanh(x) * a - - basic(torch.randn(3, 4)) - - def testImportDtype(self): - def foo(x): - o = x.to(torch.complex32) - o = o.to(torch.float32) - o = o.to(torch.float64) - o = o.to(torch.float16) - o = o.to(torch.int64) - o = o.to(torch.int32) - o = o.to(torch.int16) - o = o.to(torch.int8) - o = o.to(torch.uint8) - o = o.to(torch.complex64) - o = o.to(torch.bool) - # o = o.to(torch.qint8) # we do not currently support quantized dtypes - # o = o.to(torch.quint8) - o = o.to(torch.bfloat16) - return o - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo(torch.ones(10)) - - def testImportDevice(self): - def foo(x): - return torch.arange(x, device="cpu") - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo(10) - - def testImportLayout(self): - def foo(x): - # sparse layouts are not currently supported as they can not be created on the 'meta' device - return torch.ones_like(x, layout=torch.strided) - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo(torch.randn(10)) - - def testImportMemoryFormat(self): - def foo(): - x = torch.ones_like(torch.randn(10), memory_format=torch.contiguous_format) - x = torch.ones_like(torch.randn(10), memory_format=torch.preserve_format) - x = torch.ones_like( - torch.randn(1, 1, 1, 1), memory_format=torch.channels_last - ) - x = torch.ones_like( - torch.randn(1, 1, 1, 1, 1), memory_format=torch.channels_last_3d - ) - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo() - - def testImportListArgs(self): - def foo(): - return torch.randn((4, 5, 6)) - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo() - - def testImportListNodeArgs(self): - def foo(x, y): - return torch.cat((x, y), 0) - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo(torch.randn(10), torch.randn(10)) - - def testImportOptionalListArgs(self): - """ - Upsample triggers aten.index.Tensor with an 'indices' argument of the form List[Optional[Tensor]], this case tests - whether we handle these cases properly in _import_list_argument - """ - - def foo(x): - up = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) - return up(x) - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo(torch.randn(4, 4, 4, 4)) - - def testScalarLiteralConversion(self): - """ - Test whether scalar tensors are appropriately converted to literals - """ - - def foo(): - a = torch.tensor(0, dtype=torch.int32) - b = torch.tensor(0, dtype=torch.int64) - c = torch.tensor(0, dtype=torch.float32) - d = torch.tensor(0, dtype=torch.float64) - e = torch.tensor(0, dtype=torch.complex64) - f = torch.tensor(0, dtype=torch.complex128) - g = torch.tensor(0, dtype=torch.bool) - h = torch.tensor(0, dtype=torch.uint8) - i = torch.tensor(0, dtype=torch.int8) - j = torch.tensor(0, dtype=torch.int16) - return a, b, c, d, e, f, g, h, i, j - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo() - print(opt_foo()) - - def testSingleElementTensor(self): - """ - Test whether single element tensors are properly converted to scalars - """ - - def foo(): - a = torch.tensor([0], dtype=torch.int32) - b = torch.tensor([0], dtype=torch.int64) - c = torch.tensor([0], dtype=torch.float32) - d = torch.tensor([0], dtype=torch.float64) - e = torch.tensor([0], dtype=torch.complex64) - f = torch.tensor([0], dtype=torch.complex128) - g = torch.tensor([0], dtype=torch.bool) - h = torch.tensor([0], dtype=torch.uint8) - i = torch.tensor([0], dtype=torch.int8) - j = torch.tensor([0], dtype=torch.int16) - return a[0], b[0], c[0], d[0], e[0], f[0], g[0], h[0], i[0], j[0] - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo() - - def testPromoteScalarTensor(self): - """ - Test whether scalar arguments are properly promoted to 0-rank Tensors for torch ops with no Scalar equivalent - """ - - def foo(x): - return torch.ops.aten.div.Tensor_mode(x, 14, rounding_mode="floor") - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo(torch.randn(4, 4, 4, 4)) - - def testImportDecomposeChunk(self): - def foo_chunk(x): - return torch.chunk(x, 2, dim=-1) - - opt = torch.compile( - foo_chunk, - backend=create_backend( - decompose_ops=[ - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes, - ] - ), - ) - t = torch.randn([4, 4, 4, 4]) - opt(t) - - def testImportDecomposeBatchNorm2D(self): - def foo_bn(x): - return torch.nn.BatchNorm2d(4)(x) - - opt = torch.compile( - foo_bn, - backend=create_backend( - decompose_ops=[ - torch.ops.aten._native_batch_norm_legit_functional, - torch.ops.aten.squeeze.dims, - ] - ), - ) - t = torch.randn([4, 4, 4, 4]) - opt(t) - - def testLiftFreshCopy(self): - def foo(): - w = torch.tensor([[1, 2], [3, 4]], dtype=torch.uint8) - x = torch.tensor([[1, 2], [3, 4]], dtype=torch.int32) - y = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) - # TODO: Figure out why f64 is throwing a verification error - # z = torch.tensor([[1, 2], [3, 4]], dtype=torch.float64) - return w, x, y - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo() - - def testLiftFreshCopyComplex(self): - def foo(): - x = torch.tensor([[1, 2], [3, 4]], dtype=torch.complex64) - y = torch.tensor([[1, 2], [3, 4]], dtype=torch.complex128) - return x, y - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo() - - def testDenseResourceIntegerTypes(self): - def foo(): - b = torch.tensor([True, False], dtype=torch.bool) - ui8 = torch.tensor([[1, 2], [3, -4]], dtype=torch.uint8) - i16 = torch.tensor([[1, 2], [-3, 4]], dtype=torch.int16) - i32 = torch.tensor([[1, -2], [3, 4]], dtype=torch.int32) - i64 = torch.tensor([[-1, 2], [3, 4]], dtype=torch.int64) - return b, ui8, i16, i32, i64 - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo() - - def testDenseResourceFloatTypes(self): - def foo(): - f16 = torch.tensor([1.1, 2.2, 3.3, 4.4], dtype=torch.float16) - f32 = torch.tensor([1.1, 2.2, 3.3, 4.4], dtype=torch.float32) - return f16, f32 - - opt_foo = torch.compile(foo, backend=create_backend()) - opt_foo() - - def testImportVisionModule(self): - from torch import nn - import torch.nn.functional as F - - class ConvBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1): - super(ConvBlock, self).__init__() - self.stride = stride - self.channel_pad = out_channels - in_channels - padding = (kernel_size - 1) // 2 - self.convs = nn.Sequential( - nn.Conv2d( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - groups=in_channels, - bias=True, - ), - nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - stride=1, - padding=0, - bias=True, - ), - ) - self.act = nn.ReLU(inplace=True) - - def forward(self, x): - h = x - if self.channel_pad > 0: - x = F.pad(x, (0, 0, 0, 0, 0, self.channel_pad), "constant", 0) - return self.act(self.convs(h) + x) - - mod = ConvBlock(3, 5) - opt_mod = torch.compile(mod, backend=create_backend()) - opt_mod(torch.randn(1, 3, 256, 256)) - - def testMultiHeadAttentionModule(self): - import torch.nn as nn - import torch.nn.functional as F - - class ScaledDotProductAttention(nn.Module): - def __init__(self): - super(ScaledDotProductAttention, self).__init__() - - def forward(self, Q, K, V, scale=None): - attention = torch.matmul(Q, K.permute(0, 2, 1)) - if scale: - attention = attention * scale - attention = F.softmax(attention, dim=-1) - context = torch.matmul(attention, V) - return context - - class MultiHeadAttention(nn.Module): - def __init__(self, dim_model, num_head, dropout=0.0): - super(MultiHeadAttention, self).__init__() - self.num_head = num_head - assert dim_model % num_head == 0 - self.dim_head = dim_model // self.num_head - self.fc_Q = nn.Linear(dim_model, num_head * self.dim_head) - self.fc_K = nn.Linear(dim_model, num_head * self.dim_head) - self.fc_V = nn.Linear(dim_model, num_head * self.dim_head) - self.attention = ScaledDotProductAttention() - self.fc = nn.Linear(num_head * self.dim_head, dim_model) - self.dropout = nn.Dropout(dropout) - self.layer_norm = nn.LayerNorm(dim_model) - - def forward(self, x): - batch_size = x.size(0) - Q = self.fc_Q(x) - K = self.fc_K(x) - V = self.fc_V(x) - Q = Q.view(batch_size * self.num_head, -1, self.dim_head) - K = K.view(batch_size * self.num_head, -1, self.dim_head) - V = V.view(batch_size * self.num_head, -1, self.dim_head) - scale = K.size(-1) ** -0.5 - context = self.attention(Q, K, V, scale) - context = context.view(batch_size, -1, self.dim_head * self.num_head) - out = self.fc(context) - out = self.dropout(out) - out = out + x - out = self.layer_norm(out) - return out - - mod = MultiHeadAttention(256, 4) - opt = torch.compile(mod, backend=create_backend()) - opt(torch.randn(1, 1, 256, 256)) - - def testImportAtenFull(self): - def foo(x): - return torch.full(x.size(), fill_value=float("-inf")) - - opt_foo = torch.compile(foo, backend="turbine_cpu") - opt_foo(torch.randn(2, 3)) - - def _create_model(self, bias): - import torch.nn as nn - - class SimpleModel(nn.Module): - def __init__(self, input_size, output_size, bias=False): - super().__init__() - self.classifier = torch.nn.Linear(input_size, output_size, bias=bias) - - def forward(self, x): - return self.classifier(x) - - return SimpleModel(20, 30, bias) - - def test_model_no_bias(self): - model_no_bias = self._create_model(bias=False) - output_no_bias = model_no_bias(torch.randn(128, 20)) - print("\nOutput without bias:") - print(output_no_bias) - opt_foo = torch.compile(model_no_bias, backend="turbine_cpu") - opt_foo(torch.randn(128, 20)) - - def test_model_with_bias(self): - model_with_bias = self._create_model(bias=True) - output_with_bias = model_with_bias(torch.randn(128, 20)) - print("\nOutput with bias:") - print(output_with_bias) - opt_foo = torch.compile(model_with_bias, backend="turbine_cpu") - opt_foo(torch.randn(128, 20)) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/dynamo/importer_dynamic_test.py b/core/tests/dynamo/importer_dynamic_test.py deleted file mode 100644 index a47171a97..000000000 --- a/core/tests/dynamo/importer_dynamic_test.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import sys -import unittest - -import torch -import torch._dynamo as dynamo -from torch._export import dynamic_dim - -# from torch._export.constraints import constrain_as_size, constrain_as_value -from iree.compiler.extras.fx_importer import FxImporter -from shark_turbine.dynamo.passes import turbine_cpu_pass_pipeline -import torch -import torch._dynamo as dynamo -from torch._dynamo.backends.common import aot_autograd -from torch.fx.experimental.proxy_tensor import make_fx -from torch._decomp import get_decompositions -from torch.func import functionalize -from torch.fx import ( - GraphModule, -) -from iree.compiler.api import ( - Invocation, - Session, - Source, - Output, -) - -from iree.compiler.passmanager import ( - PassManager, -) - - -DEFAULT_COMPILER_FLAGS = ( - # Enable asynchronous calling convention. - # TODO: Enable async execution mode. - # "--iree-execution-model=async-external", - "--iree-input-type=tm_tensor", -) - - -def import_compiler(gm: GraphModule, example_inputs, decompose_ops=None): - session = Session() - session.set_flags(*DEFAULT_COMPILER_FLAGS) - session.set_flags("--iree-hal-target-backends=llvm-cpu") - context = session.context - imp = FxImporter(context=context) - module = imp.module - - inv = session.invocation() - # TODO: Should capture diagnostics. - inv.enable_console_diagnostics() - inv.import_module(module.operation) - - if decompose_ops is not None: - gm = make_fx( - functionalize(gm), - decomposition_table=get_decompositions(decompose_ops), - )(*example_inputs) - - gm.print_readable() - try: - imp.import_graph_module(gm) - print(module, file=sys.stderr) - with context: - with open("/tmp/module.mlir", "w") as file: - file.write(str(module)) - pm = PassManager.parse("builtin.module(torch-to-iree)") - pm.run(module.operation) - - finally: - print(module, file=sys.stderr) - module.operation.verify() - return gm - - -class DynamicBMM(torch.nn.Module): - def __init__(self, n, k): - super().__init__() - self.weight0 = torch.nn.Parameter(torch.rand(n, k)) - - def forward(self, inp, *, bias): - mm = torch.matmul(inp, self.weight0) - biased = mm + bias - return {"result": biased} - - -class DynamicBuiltinOps(torch.nn.Module): - def forward(self, inp): - x = inp.size()[1] - inp.size()[2] - x = x * inp.size()[1] - 34.2 - g = x / 32 - return {"result": g} - - -class DynamicShapeStridedModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a): - dynamic_shape = [a.size(0), a.size(1), a.size(2)] - x = torch.ops.aten.empty_strided( - dynamic_shape, stride=[12, 4, 1] - ) # Default stride = [12, 4, 1] - y = x.copy_(a) - return y - - -class ImportSmokeTests(unittest.TestCase): - def testStaticExport(self): - """ - 'tensor.collapse_shape' op expected dimension 0 of collapsed type to be static value of 1 - """ - model = DynamicBMM(12, 19) - inp_example = torch.rand(1, 2, 12) - bias_example = torch.rand(19) - f = dynamo.export( - model.forward, - aten_graph=True, - same_signature=False, - assume_static_by_default=True, - constraints=[ - dynamic_dim(inp_example, 1) >= 2, - ], - ) - g, guards = f(inp=inp_example, bias=bias_example) - g = import_compiler(g, [inp_example, bias_example]) - - def testStaticExportSameSignatureTrue(self): - """ - 'tensor.collapse_shape' op expected dimension 0 of collapsed type to be static value of 1 - """ - model = DynamicBMM(12, 19) - inp_example = torch.rand(1, 2, 12) - bias_example = torch.rand(19) - f = dynamo.export( - model.forward, - aten_graph=True, - same_signature=True, - assume_static_by_default=True, - constraints=[ - dynamic_dim(inp_example, 1) >= 2, - ], - ) - g, guards = f(inp=inp_example, bias=bias_example) - g = import_compiler(g, [inp_example, bias_example]) - - def testStaticExportBuiltinOps(self): - model = DynamicBuiltinOps() - inp_example = torch.rand(1, 2, 12) - f = dynamo.export( - model.forward, - aten_graph=True, - same_signature=True, - assume_static_by_default=True, - constraints=[ - dynamic_dim(inp_example, 1) >= 2, - ], - ) - g, guards = f(inp=inp_example) - g = import_compiler(g, [inp_example]) - - @unittest.expectedFailure - def testDynamicShapeStrided(self): - """ - Regardless of default stride=[12, 4, 1] provided, we get the following error. - failed to legalize operation 'torch.constant.int' - By Dumping IR, you get the following. - /tmp/module.mlir:7:10: error: 'tensor.collapse_shape' op expected dimension 0 of collapsed type to be static value of 1 - %2 = torch.aten.view %arg0, %1 : !torch.vtensor<[1,?,12],f32>, !torch.list -> !torch.vtensor<[?,12],f32> - """ - model = DynamicShapeStridedModule() - # inp_example = torch.rand(5, 7, 9) - inp_example = torch.randn(2, 3, 4) # input for default stride - f = dynamo.export( - model.forward, - aten_graph=True, - same_signature=True, - assume_static_by_default=True, - constraints=[ - dynamic_dim(inp_example, 0) >= 0, - ], - ) - g, guards = f(a=inp_example) - g = import_compiler(g, [inp_example]) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/dynamo/llama_test.py b/core/tests/dynamo/llama_test.py deleted file mode 100644 index 657502771..000000000 --- a/core/tests/dynamo/llama_test.py +++ /dev/null @@ -1,326 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging - - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import math -import unittest -import pytest -from dataclasses import dataclass -from typing import Any, Optional, Tuple - -import torch -import torch.nn.functional as F -from torch import nn - - -@dataclass -class ModelArgs: - dim: int = 4096 - n_layers: int = 32 - n_heads: int = 32 - n_kv_heads: Optional[int] = None - vocab_size: int = -1 # defined later by tokenizer - multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None - norm_eps: float = 1e-5 - - max_batch_size: int = 32 - max_seq_len: int = 2048 - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - # model_parallel_size = fs_init.get_model_parallel_world_size() - # print("MODELPARALLELSIZE", model_parallel_size) - model_parallel_size = 1 - self.n_local_heads = args.n_heads // model_parallel_size - self.n_local_kv_heads = self.n_kv_heads // model_parallel_size - self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.dim // args.n_heads - - self.wq = nn.Linear( - args.dim, - args.n_heads * self.head_dim, - bias=False, - ) - self.wk = nn.Linear( - args.dim, - self.n_kv_heads * self.head_dim, - bias=False, - ) - self.wv = nn.Linear( - args.dim, - self.n_kv_heads * self.head_dim, - bias=False, - ) - self.wo = nn.Linear( - args.n_heads * self.head_dim, - args.dim, - bias=False, - ) - - self.cache_k = torch.zeros( - ( - args.max_batch_size, - args.max_seq_len, - self.n_local_kv_heads, - self.head_dim, - ) - ) - self.cache_v = torch.zeros( - ( - args.max_batch_size, - args.max_seq_len, - self.n_local_kv_heads, - self.head_dim, - ) - ) - - def forward( - self, - x: torch.Tensor, - start_pos: int, - freqs_cis: torch.Tensor, - mask: Optional[torch.Tensor], - ): - bsz, seqlen, _ = x.shape - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - - self.cache_k = self.cache_k.to(xq) - self.cache_v = self.cache_v.to(xq) - - self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk - self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv - - keys = self.cache_k[:bsz, : start_pos + seqlen] - values = self.cache_v[:bsz, : start_pos + seqlen] - - # repeat k/v heads if n_kv_heads < n_heads - keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) - if mask is not None: - scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) - output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) - return self.wo(output) - - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = nn.Linear( - dim, - hidden_dim, - bias=False, - ) - self.w2 = nn.Linear( - hidden_dim, - dim, - bias=False, - ) - self.w3 = nn.Linear( - dim, - hidden_dim, - bias=False, - ) - - def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, layer_id: int, args: ModelArgs): - super().__init__() - self.n_heads = args.n_heads - self.dim = args.dim - self.head_dim = args.dim // args.n_heads - self.attention = Attention(args) - self.feed_forward = FeedForward( - dim=args.dim, - hidden_dim=4 * args.dim, - multiple_of=args.multiple_of, - ffn_dim_multiplier=args.ffn_dim_multiplier, - ) - self.layer_id = layer_id - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) - - def forward( - self, - x: torch.Tensor, - start_pos: int, - freqs_cis: torch.Tensor, - mask: Optional[torch.Tensor], - ): - h = x + self.attention.forward( - self.attention_norm(x), start_pos, freqs_cis, mask - ) - out = h + self.feed_forward.forward(self.ffn_norm(h)) - return out - - -class Transformer(nn.Module): - def __init__(self, params: ModelArgs): - super().__init__() - self.params = params - self.vocab_size = params.vocab_size - self.n_layers = params.n_layers - - self.tok_embeddings = nn.Embedding( - params.vocab_size, - params.dim, - ) - - self.layers = torch.nn.ModuleList() - for layer_id in range(params.n_layers): - self.layers.append(TransformerBlock(layer_id, params)) - - self.norm = RMSNorm(params.dim, eps=params.norm_eps) - self.output = nn.Linear( - params.dim, - params.vocab_size, - bias=False, - ) - - self.freqs_cis = precompute_freqs_cis( - self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 - ) - - # @torch.inference_mode() - def forward(self, tokens: torch.Tensor, start_pos: int): - _bsz, seqlen = tokens.shape - h = self.tok_embeddings(tokens) - self.freqs_cis = self.freqs_cis.to(h.device) - freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] - - mask = None - if seqlen > 1: - mask = torch.full( - (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device - ) - mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) - - for layer in self.layers: - h = layer(h, start_pos, freqs_cis, mask) - h = self.norm(h) - output = self.output(h).float() - return output - - -def main(): - # Example input values - batch_size = 2 - seq_len = 8 - vocab_size = 16 - n_layers = 2 - n_heads = 4 - dim = 64 - start_pos = 0 - example_tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) - print(example_tokens.dtype) - - args = ModelArgs(vocab_size=vocab_size, n_layers=n_layers, n_heads=n_heads, dim=dim) - mod = Transformer(args) - # mod(example_tokens, start_pos) - opt = torch.compile(mod, backend="turbine_cpu") - opt(example_tokens, start_pos) - - -@pytest.mark.xfail(reason="https://github.com/nod-ai/SHARK-Turbine/issues/221") -class ModelTests(unittest.TestCase): - def testLLama(self): - main() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - main() diff --git a/core/tests/dynamo/mninst_test.py b/core/tests/dynamo/mninst_test.py deleted file mode 100644 index 88742e2bd..000000000 --- a/core/tests/dynamo/mninst_test.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -import torch -from torch import nn -from torch.utils.data import DataLoader -import torchvision.transforms as transforms -import torchvision.datasets as datasets - - -import torch._dynamo.config - -torch._dynamo.config.dynamic_shapes = ( - False # TODO: https://github.com/nod-ai/SHARK-Turbine/issues/93 -) - - -class MNISTDataLoader: - def __init__(self, batch_size, shuffle=True): - self.batch_size = batch_size - self.shuffle = shuffle - - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] - ) - - self.mnist_trainset = datasets.MNIST( - root="../data", train=True, download=True, transform=transform - ) - self.mnist_testset = datasets.MNIST( - root="../data", train=False, download=True, transform=transform - ) - - def get_train_loader(self): - return DataLoader( - dataset=self.mnist_trainset, - batch_size=self.batch_size, - shuffle=self.shuffle, - ) - - def get_test_loader(self): - return DataLoader( - dataset=self.mnist_testset, - batch_size=self.batch_size, - shuffle=False, - drop_last=True, - ) - - -class MLP(nn.Module): - def __init__(self): - super().__init__() - self.layer0 = nn.Linear(28, 28, bias=True) - self.layer1 = nn.Linear(28, 14, bias=True) - self.layer2 = nn.Linear(14, 7, bias=True) - self.layer3 = nn.Linear(7, 7, bias=True) - - def forward(self, x: torch.Tensor): - x = self.layer0(x) - x = torch.sigmoid(x) - x = self.layer1(x) - x = torch.sigmoid(x) - x = self.layer2(x) - x = torch.sigmoid(x) - x = self.layer3(x) - return x - - -def infer_iteration(model, images): - outputs = model(images) - return outputs - - -def infer(): - # Example Parameters - config = { - "batch_size": 64, - "learning_rate": 0.001, - "num_epochs": 10, - } - - custom_data_loader = MNISTDataLoader(config["batch_size"]) - test_loader = custom_data_loader.get_test_loader() - - model = MLP() - test_opt = torch.compile(infer_iteration, backend="turbine_cpu") - - for i, (images, labels) in enumerate(test_loader): - test_opt(model, images) - - -class ModelTests(unittest.TestCase): - def testMNISTEagerSimple(self): - infer() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/dynamo/tensor_scalar_op_conversion_importer_test.py b/core/tests/dynamo/tensor_scalar_op_conversion_importer_test.py deleted file mode 100644 index a7d9ea63e..000000000 --- a/core/tests/dynamo/tensor_scalar_op_conversion_importer_test.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from testutils import * - - -class TensorScalarOpConversionImportModule(unittest.TestCase): - def setUp(self): - self.t = torch.randn(2, 2) - - def add(self, x): - return x + 8.2 - - def sub(self, x): - return x - 1.6 - - def mul(self, x): - return x * 3.2 - - def div(self, x): - return x / 2.1 - - def floor_div(self, x): - return x // 4.2 - - def testAdd(self): - opt_torch_scalar_convert = torch.compile(self.add, backend=create_backend()) - result = opt_torch_scalar_convert(self.t) - expected_result = self.add(self.t) - self.assertTrue(torch.allclose(result, expected_result), "broken") - - def testSub(self): - opt_torch_scalar_convert = torch.compile(self.sub, backend=create_backend()) - result = opt_torch_scalar_convert(self.t) - expected_result = self.sub(self.t) - self.assertTrue(torch.allclose(result, expected_result), "broken") - - def testMul(self): - opt_torch_scalar_convert = torch.compile(self.mul, backend=create_backend()) - result = opt_torch_scalar_convert(self.t) - expected_result = self.mul(self.t) - self.assertTrue(torch.allclose(result, expected_result), "broken") - - def testDiv(self): - opt_torch_scalar_convert = torch.compile(self.div, backend=create_backend()) - result = opt_torch_scalar_convert(self.t) - expected_result = self.div(self.t) - self.assertTrue(torch.allclose(result, expected_result), "broken") - - def testFloorDiv(self): - """ - This op isn't successfully created by IREE due to partial implementation of floor_div op in torch-mlir - However, the importer works successfully. - """ - opt_torch_scalar_convert = torch.compile( - self.floor_div, backend=create_backend() - ) - result = opt_torch_scalar_convert(self.t) - expected_result = self.floor_div(self.t) - self.assertTrue(torch.allclose(result, expected_result), "broken") - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/dynamo/tensor_test.py b/core/tests/dynamo/tensor_test.py deleted file mode 100644 index fcd406608..000000000 --- a/core/tests/dynamo/tensor_test.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import time -import unittest - -import numpy as np -import torch - -# Public API imports. -from shark_turbine.runtime import Device -from shark_turbine.dynamo import TurbineMode, DeviceTensor - - -class TensorTest(unittest.TestCase): - def setUp(self): - self.mode = TurbineMode() - self.mode.__enter__() - Device("local-task").set() - - def tearDown(self) -> None: - Device.current().clear() - self.mode.__exit__(None, None, None) - - @unittest.expectedFailure - def test_explicit_construct(self): - size = (2, 2) - t1 = DeviceTensor(size, torch.float32, np.ones(size)) - t2 = DeviceTensor( - size, torch.float32, np.arange(size[0] * size[1]).reshape(size) - ) - print("Inputs:") - print(t1) - print(t2) - - def test_async_copy_from_host(self): - t1 = torch.empty(4, device="turbine") - ar = np.arange(4, dtype=np.float32) - t1._async_copy_from_host(ar) - np.testing.assert_array_equal(t1.cpu(), ar) - - def test_cpu_to(self): - t_cpu = torch.arange(4).cpu() - t_t = t_cpu.to(device="turbine") - np.testing.assert_array_equal(t_cpu, t_t.cpu()) - - def test_factory_function_empty(self): - # Factory functions - t1 = torch.empty(4, device="turbine") - print("Empty Tensor (un-initialized memory!):") - print(t1) - - def test_factory_function_empty_tuple_size(self): - # TODO: Test some invariants vs just printing. - t1 = torch.empty((4, 4), device="turbine") - print("Empty Tensor (un-initialized memory!):") - print(t1) - print(t1.buffer_view) - print(t1.to("cpu")) - print(t1.cpu()) - - def test_factory_function_zeros(self): - t1 = torch.zeros(2, 3, device="turbine") - np.testing.assert_array_equal(t1.cpu(), [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) - - def test_factory_function_ones(self): - t1 = torch.ones(2, 3, device="turbine") - np.testing.assert_array_equal(t1.cpu(), [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) - - def test_factory_arange(self): - t1 = torch.arange(4, device="turbine", dtype=torch.float32) - ar = np.arange(4, dtype=np.float32) - np.testing.assert_array_equal(t1.cpu(), ar) - - def test_factory_rand(self): - t1 = torch.rand(4, device="turbine", dtype=torch.float32) - print(t1.cpu()) - - def test_binary_op(self): - t1 = 5.3 * torch.ones(2, 3).to(device="turbine") - t2 = 2.3 * torch.ones(2, 3).to(device="turbine") - t3 = t1 * t2 - np.testing.assert_allclose( - t3.cpu(), [[12.19, 12.19, 12.19], [12.19, 12.19, 12.19]] - ) - - def test_unary_op(self): - t1 = -5.3 * torch.ones(2, 3).to(device="turbine") - t2 = torch.abs(t1) - np.testing.assert_allclose(t2.cpu(), [[5.3, 5.3, 5.3], [5.3, 5.3, 5.3]]) - - def test_nn_linear(self): - m = torch.nn.Linear(20, 30) - input = torch.randn(128, 20) - ref_output = m(input) - m.to("turbine") - input = input.to("turbine") - turbine_output = m(input) - np.testing.assert_allclose( - turbine_output.cpu(), ref_output.detach().numpy(), atol=1e-6 - ) - - def test_nn_MLP(self): - class MLP(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer0 = torch.nn.Linear(64, 32, bias=True) - self.layer1 = torch.nn.Linear(32, 16, bias=True) - self.layer2 = torch.nn.Linear(16, 7, bias=True) - self.layer3 = torch.nn.Linear(7, 7, bias=True) - - def forward(self, x: torch.Tensor): - x = self.layer0(x) - x = torch.sigmoid(x) - x = self.layer1(x) - x = torch.sigmoid(x) - x = self.layer2(x) - x = torch.sigmoid(x) - x = self.layer3(x) - return x - - m = MLP() - input = torch.randn(16, 64) - ref_output = m(input) - m.to("turbine") - input = input.to("turbine") - turbine_output = m(input) - np.testing.assert_allclose( - turbine_output.cpu(), ref_output.detach().numpy(), atol=1e-6 - ) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/dynamo/testutils.py b/core/tests/dynamo/testutils.py deleted file mode 100644 index 2a89ef92e..000000000 --- a/core/tests/dynamo/testutils.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest -from typing import List - -from iree.compiler.extras.fx_importer import FxImporter -import torch -import torch._dynamo as dynamo -from torch._dynamo.backends.common import aot_autograd -from torch.fx.experimental.proxy_tensor import make_fx -from torch._decomp import get_decompositions -from torch.func import functionalize -from torch.fx import ( - GraphModule, -) - - -def create_backend(decompose_ops: List[torch._ops.OpOverloadPacket] = None): - imp = FxImporter() - - def import_compiler(gm: GraphModule, example_inputs): - if decompose_ops is not None: - gm = make_fx( - functionalize(gm), - decomposition_table=get_decompositions(decompose_ops), - )(*example_inputs) - - gm.print_readable() - try: - imp.import_graph_module(gm) - finally: - print(imp.module) - imp.module.operation.verify() - return gm - - backend = import_compiler - backend = aot_autograd(fw_compiler=backend) - return backend diff --git a/core/tests/dynamo/type_conversion_test.py b/core/tests/dynamo/type_conversion_test.py deleted file mode 100644 index dfc3de25b..000000000 --- a/core/tests/dynamo/type_conversion_test.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -from iree.compiler.ir import ( - Context, - Type as IrType, -) - -import shark_turbine.dynamo.type_conversion as tc - - -class TypeConversionTest(unittest.TestCase): - def setUp(self) -> None: - self.conv = tc.NativeTypeConverter(Context()) - - def testPrimitives(self): - self._compareNative("!torch.bool", "i1") - self._compareNative("!torch.int", "i64") - self._compareNative("!torch.float", "f64") - - def testSigned(self): - self._compareNative("!torch.bool", "i1", signless=False) - self._compareNative("!torch.int", "si64", signless=False) - - def testValueTensors(self): - self._compareNative("!torch.vtensor<[2, 2],f32>", "tensor<2x2xf32>") - self._compareNative("!torch.vtensor<[?, ?],f32>", "tensor") - self._compareNative("!torch.vtensor<[],f32>", "tensor") - - def _compareNative(self, torch_str: str, native_str: str, *, signless: bool = True): - with self.conv._context: - torch_type = IrType.parse(torch_str) - native_type = self.conv.torch_type_to_native(torch_type, signless=signless) - self.assertEqual(str(native_type), native_str) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/examples/aot_mlp_test.py b/core/tests/examples/aot_mlp_test.py deleted file mode 100644 index c4266a4af..000000000 --- a/core/tests/examples/aot_mlp_test.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -from pathlib import Path -import sys -import subprocess -import unittest - -REPO_DIR = Path(__file__).resolve().parent.parent.parent - - -def _run(local_path: str): - path = REPO_DIR / local_path - subprocess.check_call([sys.executable, str(path)]) - - -class AOTMLPTest(unittest.TestCase): - def testMLPExportSimple(self): - _run("examples/aot_mlp/mlp_export_simple.py") - - def testMLPExportSimple(self): - _run("examples/aot_mlp/mlp_export_dynamic.py") - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/generated/evaluate.py b/core/tests/generated/evaluate.py deleted file mode 100644 index 3184930d8..000000000 --- a/core/tests/generated/evaluate.py +++ /dev/null @@ -1,43 +0,0 @@ -from stats import ErrorAggregatorDict -import logging - -from iree.compiler.extras.fx_importer import FxImporter -from shark_turbine.dynamo.passes import turbine_cpu_pass_pipeline -import torch -import torch._dynamo as dynamo -from torch._dynamo.backends.common import aot_autograd -from torch.fx import ( - GraphModule, -) - - -def create_backend(): - imp = FxImporter() - - def import_compiler(gm: GraphModule, example_inputs): - gm = turbine_cpu_pass_pipeline(gm, example_inputs) - - try: - imp.import_graph_module(gm) - finally: - pass - imp.module.operation.verify() - return gm - - backend = import_compiler - backend = aot_autograd(fw_compiler=backend) - return backend - - -def evaluate_importer(nn_cls, get_init_args, get_forward_args, test_identifier): - log = logging.getLogger("turbine-test") - try: - args, kwargs = get_init_args() - nn_module = nn_cls(*args, **kwargs) - opt_mod = torch.compile(nn_module, backend=create_backend()) - - fargs, fkwargs = get_forward_args() - opt_mod(*fargs, **fkwargs) - except Exception as e: - err = ErrorAggregatorDict.single(str(e), test_identifier) - return err diff --git a/core/tests/generated/extract_unimpl_ops.sh b/core/tests/generated/extract_unimpl_ops.sh deleted file mode 100644 index ace9134eb..000000000 --- a/core/tests/generated/extract_unimpl_ops.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -python main.py --limit 500 -j 8 | grep "NotImplementedError: Unimplemented torch op in the IREE compiler" | grep -o "'[^']*'" | sed "s/'//g" > unimplemented_torch_ops.txt \ No newline at end of file diff --git a/core/tests/generated/main.py b/core/tests/generated/main.py deleted file mode 100644 index 98f15febd..000000000 --- a/core/tests/generated/main.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -import sys -import argparse - -from testutils import evaluate_all - -# magic: makes importing this module work in the testsuite -import torch._inductor.config - -import logging - -log = logging.getLogger("turbine-test") -logging.basicConfig(level=logging.INFO) - -ENV_FILE = "JITPARITYBENCH_PATH.txt" - - -def get_args(raw_args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--jobs", - "-j", - type=int, - default=4, - help="Number of threads in our threadpool, ignored if --sequential is set", - ) - parser.add_argument( - "--offset", - type=int, - default=0, - help="Pick files starting from this offset. Together with --limit, we can run through all files in multiple separate runs", - ) - parser.add_argument("--limit", "-l", type=int, help="only run the first N files") - parser.add_argument( - "--filter", "-f", "-k", help="only run module containing given name" - ) - parser.add_argument("--skips", type=str) - parser.add_argument( - "--tests-dir", - default=None, - help="jit-paritybench location (i.e. /path/to/pytorch-jit-paritybench)", - ) - parser.add_argument( - "--sequential", - action="store_true", - help="Set to run tests sequentially without threading, this can help resolve hanging or long runtimes due to low memory", - ) - # parser.add_argument("--device", default="cuda", type=str, help="evaluate modules using cuda or cpu") # excluded for now as we only have turbine-cpu, can use this later - - args = parser.parse_args(raw_args) - return args - - -def write_path(path: str): - with open(ENV_FILE, "w") as f: - f.write(path) - - -def read_path() -> str: - with open(ENV_FILE, "r") as f: - path = f.read() - return path - - -if __name__ == "__main__": - args = get_args() - - if args.tests_dir is not None: - pb = args.tests_dir - write_path(pb) # store this path for next time - log.info(f"Using test directory from CLI: {pb}") - elif os.path.exists(ENV_FILE): - pb = read_path() - log.info(f"Using test directory from {ENV_FILE}: {pb}") - else: - raise RuntimeError( - f"Must either pass 'tests-dir' or set {ENV_FILE} in order to run tests" - ) - - # enables finding necessary modules in jit-paritybench - pb_gen = pb + "/generated" - sys.path.append(pb) - sys.path.append(pb_gen) - - evaluate_all(args, pb_gen, offset=args.offset, limit=args.limit, jobs=args.jobs) diff --git a/core/tests/generated/running_tests.md b/core/tests/generated/running_tests.md deleted file mode 100644 index 6c42cc31f..000000000 --- a/core/tests/generated/running_tests.md +++ /dev/null @@ -1,51 +0,0 @@ -# Running Tests - -## Set Up -This test suite requires a local clone of the `pytorch-jit-paritybench` repository: -```shell -# somewhere ... -git clone https://github.com/jansel/pytorch-jit-paritybench.git -cd pytorch-jit-paritybench -pip install -r requirements.txt -conda install cpuonly -c pytorch-nightly -``` - -Note that we are not exactly following the setup described in the above repo, mainly to avoid issues with dependencies between `conda` and `pip` versions of relevant packages (see 'Known Issues' below). - -There may be some additional packages to install that are not in the requirements.txt in order to successfully run the tests, one that comes to mind is `expecttest` - -## Running -Once everything is set up in your conda environment we can run the test suite using `python/test/generated/main.py`. Initially you have to pass the location of your `pytorch-jit-paritybench` repo to the script: `python main.py --tests-dir /path/to/pytorch-jit-paritybench`. After the first run, the script will save the path for future use in a local text file for convenience and you do not need to pass it again. - -To speed up iteration on tests it's recommended to make use of the `offset`, `limit`, and `filter` arguments as running the full test suite can take some time. - -## Unimplemented Torch Ops -Many of the errors in our test suite will arise due to unimplemented ops in torch-mlir, we can use the `extract_unimpl_ops.sh` script to extract a list of these ops: -```bash -python main.py --limit 500 -j 8 | grep "NotImplementedError: Unimplemented torch op in the IREE compiler" | grep -o "'[^']*'" | sed "s/'//g" > unimplemented_torch_ops.txt -``` - - -## Help -``` -usage: main.py [-h] [--jobs JOBS] [--offset OFFSET] [--limit LIMIT] [--filter FILTER] [--skips SKIPS] [--tests-dir TESTS_DIR] - -options: - -h, --help show this help message and exit - --jobs JOBS, -j JOBS Number of threads in our threadpool, jobs=1 is essentially sequential execution - --offset OFFSET Pick files starting from this offset. Together with --limit, we can run through all files in multiple separate runs - --limit LIMIT, -l LIMIT - only run the first N files - --filter FILTER, -f FILTER, -k FILTER - only run module containing given name - --skips SKIPS - --tests-dir TESTS_DIR - jit-paritybench location (i.e. /path/to/pytorch-jit-paritybench) - -``` - - -# Known Issues -On Mac, setting resource limits via a python shell is finicky/not allowed this can cause issues with the jit parity-bench tests as they utilize the `resource` package to set an optional resource limit [here](https://github.com/jansel/pytorch-jit-paritybench/blob/7e55a422588c1d1e00f35a3d3a3ff896cce59e18/paritybench/utils.py#L57) - the simplest fix is to simply comment those lines out in your local `jit-paritybench` repo. - -Getting unknown symbol errors associated with a shared library (commonly a torch library) often occurs because of mixing conda installed dependencies with pip installed dependencies because of possible differences in how certain shared libraries are linked, if possible use pip for all of your dependencies (especially when they have a dependence on one another like torch, torchvision, and torchaudio) \ No newline at end of file diff --git a/core/tests/generated/stats.py b/core/tests/generated/stats.py deleted file mode 100644 index e39795bbd..000000000 --- a/core/tests/generated/stats.py +++ /dev/null @@ -1,94 +0,0 @@ -import csv -import logging -import os -import random -import re -from collections import Counter, defaultdict -from typing import List - -log = logging.getLogger("turbine-test") - - -class Stats(Counter): - """ - Collect and group error messages for a debug report at the end - """ - - def __str__(self): - stats_keys = [ - "PASSED", - "FAILED", - "SKIPPED", - "XFAILED", - "NO_FWD", - "TIMEOUT", - "CRASHED", - "TOTAL", - ] - - return str([(k, self[k]) for k in stats_keys if k in self]) - - -class ErrorAggregatorDict(object): - """ - Collect and group error messages for a debug report at the end - """ - - def __init__(self): - super(ErrorAggregatorDict, self).__init__() - self.errors = defaultdict(list) - - def __len__(self): - return len(self.errors) - - def __getitem__(self, item: str): - return self.errors[item] - - def __iadd__(self, other): - self.update(other) - return self - - @classmethod - def single(cls, error, name): - obj = ErrorAggregatorDict() - obj.insert(error, name) - return obj - - def items(self): - return self.errors - - # insert into dict with error string and test name - def insert(self, error: str, name: str): - self.errors[error].append(name) - - def update(self, other): - if not len(other): - return - - other_dict = other.items().items() - for error, names in other_dict: - self.errors[error] += names - - def print_report(self): - if not len(self.errors): - log.info("No exceptions") - return - - print("\n" + "".join("*" * 80) + "\n" + "EXCEPTIONS" + "\n" + "".join("*" * 80)) - for error, test_names in sorted(self.errors.items(), key=lambda e: len(e[1])): - extra = "" - n_tests = len(test_names) - if n_tests > 15: - extra = ", ..." - test_names = test_names[:15] - - print( - "\033[1;36m" - + str(n_tests) - + ": " - + ", ".join(test_names) - + extra - + "\033[0;0m" - ) - print("\033[1;31m" + error.strip() + "\033[0;0m" + "\n") - print("-" * 20) diff --git a/core/tests/generated/testutils.py b/core/tests/generated/testutils.py deleted file mode 100644 index 9131d1d7a..000000000 --- a/core/tests/generated/testutils.py +++ /dev/null @@ -1,209 +0,0 @@ -import time -import types -import os -import re -import sys -from functools import partial -import multiprocessing -from multiprocessing.pool import ThreadPool -import threading -import signal -import platform -import resource -import logging -from tqdm import * - -from stats import Stats, ErrorAggregatorDict -from evaluate import evaluate_importer - -log = logging.getLogger("turbine-test") - - -def call_with_timeout(fn, args, kwargs=None, timeout=10): - kwargs = kwargs or {} - parent_conn, child_conn = multiprocessing.Pipe() - start = time.time() - proc = multiprocessing.Process( - target=call_with_timeout_subproc, args=(fn, args, kwargs, child_conn) - ) - proc.start() - while proc.is_alive(): - if parent_conn.poll(1): - result = parent_conn.recv() - proc.join() - return result - if time.time() - start > timeout: - os.kill( - proc.pid, signal.SIGINT - ) # maybe generate a stack trace for debugging - time.sleep(1) - proc.terminate() - proc.join(10) - raise TimeoutError(f"took longer than {timeout} seconds") - - proc.join() - if proc.exitcode == 0: - return parent_conn.recv() - else: - raise OSError(f"exitcode should be 0, got {proc.exitcode}") - - -def call_with_timeout_subproc(fn, args, kwargs, return_pipe): - # use_rlimit = ( - # os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") // 1024 ** 3 < 1000 - # if platform.system() == "Linux" - # else True - # ) - # if use_rlimit: - # _, hard = resource.getrlimit(resource.RLIMIT_AS) - # # resource.setrlimit(resource.RLIMIT_AS, (int(os.environ.get("RLIMIT_AS_GB", 10)) * 1024 ** 3, hard)) - try: - result = fn(*args, *kwargs) - return_pipe.send(result) - except Exception: - log.exception("Error from subprocess") - sys.exit(1) - - -def subproc_wrapper(path: str, fn: callable, timeout: int = 900): - """ - A wrapper around call_with_timeout() adding a temp dir and error handling. - - :param path: path to code to test - :param fn: function to run in subprocess - :param timeout: seconds to wait - :return: errors, stats - """ - file = os.path.basename(path).split("/")[-1] - test_identifier = re.sub(r"\.py$", "", file) - - log.info(f"Running {path}") - try: - return call_with_timeout(fn, [path], {}, timeout=timeout) - except TimeoutError as e: - return ErrorAggregatorDict.single(str(e), test_identifier), Stats( - {"TIMEOUT": 1} - ) - except OSError as e: - return ErrorAggregatorDict.single(str(e), test_identifier), Stats( - {"CRASHED": 1} - ) - - -def import_file(path): - """ - :param path: to a *.py file - :return: a python module - """ - module = types.ModuleType(re.findall(r"test_[^.]+", path)[0]) - sys.modules[module.__name__] = module - exec( - compile(open(path).read(), filename=path, mode="exec"), - module.__dict__, - module.__dict__, - ) - if not hasattr(module, "TESTCASES"): - module.TESTCASES = [] - - return module - - -def evaluate_pyfile_subproc(path: str, args, eval_fn=evaluate_importer): - """ - Evaluate/test all the TESTCASES in path. - - :param path: *.py file to test - :return: errors, stats - """ - errors = ErrorAggregatorDict() - stats = Stats() - module = import_file(path) - - if not module.TESTCASES: - log.info(f"Skipping empty module: {module.__name__}") - stats["SKIPPED"] += 1 - return errors, stats - - index = -1 - for nn_cls, get_init_args, get_forward_args, compiles in module.TESTCASES: - index += 1 - stats["TOTAL"] += 1 - - if args.filter and args.filter not in nn_cls.__name__: - stats["SKIPPED"] += 1 - continue - - if args.skips and f"{nn_cls.__name__}" in args.skips: - stats["SKIPPED"] += 1 - continue - - # nn.module doesn't have `forward` function(e.g, has __call__ instead). - # dynamo doesn't plan to support it yet. - if nn_cls.forward.__name__ == "_forward_unimplemented": - stats["NO_FWD"] += 1 - continue - - repro = f"{nn_cls.__name__} # pytest {path} -k test_{index:03d}" - test_identifier = f"{module.__name__}__{index:03d}" - eval_args = [nn_cls, get_init_args, get_forward_args, test_identifier] - - try: - err_dict = eval_fn(*eval_args) - if err_dict and len(err_dict): - log.info(f"{test_identifier} - FAIL") - errors.update(err_dict) - stats["FAILED"] += 1 - else: - log.info(f"{test_identifier} - PASS") - stats["PASSED"] += 1 - except Exception as e: - log.info(f"{test_identifier} - FAIL (Exception)") - errors.insert(str(e), test_identifier) - - return errors, stats - - -def evaluate_all( - args, tests_dir: str = "./generated", offset: int = 0, limit: int = None, jobs=4 -): - """ - Generate a paritybench score, main entrypoint for this module. - - :param tests_dir: directory containing paritybench testcases - :param limit: optional maximum number of files to process - :param fn: inner function to run the tests - :param jobs: how many processes to run at once - """ - feval = partial(evaluate_pyfile_subproc, args=args) - fn = partial(subproc_wrapper, fn=feval) - start = time.time() - stats = Stats() - errors = ErrorAggregatorDict() - testfiles = [ - os.path.join(tests_dir, f) - for f in os.listdir(tests_dir) - if re.search(r"test_.*[.]py$", f) - ] - testfiles.sort() - - if limit: - testfiles = testfiles[offset : offset + limit] - - with tqdm(total=len(testfiles)) as pbar: - if args.sequential: - for file in testfiles: - errors_part, stats_part = fn(path=file) - errors.update(errors_part) - stats.update(stats_part) - pbar.update() - else: - pool = ThreadPool(jobs) - for errors_part, stats_part in pool.imap_unordered(fn, testfiles): - errors.update(errors_part) - stats.update(stats_part) - pbar.update() - pool.close() - - errors.print_report() - log.info(f"Total time: {time.time() - start:02f} s") - log.info(stats) diff --git a/core/tests/kernel/aot_kernel_test.py b/core/tests/kernel/aot_kernel_test.py deleted file mode 100644 index 690e366a8..000000000 --- a/core/tests/kernel/aot_kernel_test.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2024 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import re -import unittest - -import torch -from shark_turbine.aot import export -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl - - -def export_softmax_kernel(): - M = tkl.sym.M - N = tkl.sym.K - - @tk.gen.kernel(M) - def softmax( - input: tkl.InputBuffer[M, N, tkl.f16], output: tkl.OutputBuffer[M, N, tkl.f16] - ): - row_index = tkl.program_id(0) - row = tkl.load(input, (row_index, 0), (1, N)) - row_minus_max = row - tkl.max(row) - numerator = tkl.exp2(row_minus_max) - denominator = tkl.sum(numerator) - softmax_output = numerator / denominator - tkl.store(output, (row_index, 0), softmax_output) - - class NN(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(64, 64, dtype=torch.float16) - - def forward(self, x): - x = self.linear(x) - x = softmax(x) - return x - - model = NN() - a = torch.ones(64, 64, dtype=torch.float16) - exported = export(model, a) - return exported - - -class AotKernelTest(unittest.TestCase): - def test_unique_naming(self): - # We test it twice to ensure that local name collisions cannot happen, - # verifying that each run generates a uniquely named kernel. This is - # a by-product of the Torch namespace being global and every one of - # these that we define being a separate incarnation based on the - # same local function name. - unique_names = set() - for _ in range(2): - exported = export_softmax_kernel() - exported.print_readable() - ir_text = str(exported.mlir_module) - matches = re.findall( - r"flow.dispatch @(tk_kernel_softmax__([0-9]+))::", ir_text - ) - self.assertEqual(1, len(matches)) - match = matches[0] - print("NAME MATCH:", match) - self.assertNotIn(match, unique_names) - unique_names.add(match) - - -if __name__ == "__main__": - import logging - - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/kernel/arith_test.py b/core/tests/kernel/arith_test.py deleted file mode 100644 index 1ef3067f9..000000000 --- a/core/tests/kernel/arith_test.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -import unittest - -import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl - -from shark_turbine.kernel.compiler import ( - builder, - kernel_codegen, - vector_codegen, -) -from shark_turbine.kernel._support import ( - indexing, -) - -M = tkl.sym.M -K = tkl.sym.K - - -class Test(unittest.TestCase): - def testIotaFx(self): - @tk.gen.thread(M) - def iota_kernel(out: tkl.OutputBuffer[M, tkl.f32]): - # Integer types - for dtype in [ - tkl.bool, - tkl.i4, - tkl.i8, - tkl.i16, - tkl.i32, - tkl.i64, - tkl.index, - ]: - a = tkl.constant((17, 37, 19), dtype, 5) - b = tkl.constant((17, 37, 19), dtype, 10) - c = tkl.constant((17, 37, 19), dtype, 2) - c = (a * b) // c - c = c + a - b - - # Float types - for dtype in [tkl.f16, tkl.f32, tkl.f64]: - a = tkl.constant((17, 37, 19), dtype, 5.0) - b = tkl.constant((17, 37, 19), dtype, 10.0) - c = tkl.constant((17, 37, 19), dtype, 2.0) - c = (a * b) / c - c = c + a - b - - with tk.gen.TestLaunchContext(): - iota_kernel(torch.zeros(17)) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/kernel/dispatch_codegen_test.py b/core/tests/kernel/dispatch_codegen_test.py deleted file mode 100644 index 2ed50b6c2..000000000 --- a/core/tests/kernel/dispatch_codegen_test.py +++ /dev/null @@ -1,44 +0,0 @@ -import logging -import unittest - -import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl - -from shark_turbine.kernel.compiler import ( - builder, - dispatch_codegen, - kernel_codegen, - vector_codegen, -) -from shark_turbine.kernel._support import ( - indexing, -) - - -M = tk.lang.sym.M -K = tk.lang.sym.K - - -class Test(unittest.TestCase): - def testEmptyStreamExecutable(self): - @tk.gen.thread(M) - def softmax_kernel( - input: tk.lang.InputBuffer[M, K, tkl.f32], - output: tk.lang.OutputBuffer[M, K, tkl.f32], - ): - row_index = tk.lang.program_id(0) - input_row = input[row_index, :] - numerator = tkl.exp2(input_row - tkl.max(input_row)) - output_row = numerator / tkl.sum(numerator) - output[row_index, :] = output_row - - input = torch.randn(128, 64) - output = torch.zeros(128, 64) - with tk.gen.TestLaunchContext(): - softmax_kernel(input, output) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/kernel/fused_attention_test.py b/core/tests/kernel/fused_attention_test.py deleted file mode 100644 index a60d7edeb..000000000 --- a/core/tests/kernel/fused_attention_test.py +++ /dev/null @@ -1,89 +0,0 @@ -import logging -import unittest - -import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl - -BATCH = tkl.sym.BATCH -N_HEADS = tkl.sym.N_HEADS -N_CTX = tkl.sym.N_CTX -D_HEAD = tkl.sym.D_HEAD - -BLOCK_N = tkl.sym.BLOCK_N -BLOCK_M = tkl.sym.BLOCK_M - - -class Test(unittest.TestCase): - def testFusedAttention(self): - @tk.gen.thread(N_CTX // BLOCK_M, BATCH * N_HEADS) - def fused_attention( - Q: tkl.InputBuffer[BATCH, N_HEADS, N_CTX, D_HEAD, tkl.f16], - K: tkl.InputBuffer[BATCH, N_HEADS, N_CTX, D_HEAD, tkl.f16], - V: tkl.InputBuffer[BATCH, N_HEADS, N_CTX, D_HEAD, tkl.f16], - O: tkl.OutputBuffer[BATCH, N_HEADS, N_CTX, D_HEAD, tkl.f16], - ): - grid_n = tkl.program_id(0) - grid_m = tkl.program_id(1) - - batch = grid_m // N_HEADS - head = grid_m % N_HEADS - - q = tkl.load(Q, (batch, head, grid_n * BLOCK_M, 0), (BLOCK_M, D_HEAD)) - acc_init = tkl.constant((BLOCK_M, D_HEAD), tkl.f32, 0.0) - max_stat_init = tkl.constant((BLOCK_M,), tkl.f32, -1e9) - sum_stat_init = tkl.constant((BLOCK_M,), tkl.f32, 0.0) - - @tkl.for_loop( - 0, N_CTX, BLOCK_N, init_args=[max_stat_init, sum_stat_init, acc_init] - ) - def body(i, old_max, old_sum, old_acc): - k = tkl.load(K, (batch, head, i, 0), (BLOCK_N, D_HEAD)) - kT = tkl.transpose(k, (1, 0)) - - qkT = tkl.constant((BLOCK_M, BLOCK_N), tkl.f32, 0.0) - qkT = tkl.dot(q, kT, qkT) - - new_max = tkl.max(qkT, axis=1, acc=old_max) - broadcasted_max = tkl.broadcast_in_dim( - new_max, (BLOCK_M, BLOCK_N), (0,) - ) - partial_softmax = tkl.exp2(qkT - broadcasted_max) - scale_factor = tkl.exp2(old_max - new_max) - scaled_old_sum = scale_factor * old_sum - new_sum = tkl.sum(partial_softmax, axis=1, acc=scaled_old_sum) - broadcasted_scale_factor = tkl.broadcast_in_dim( - scale_factor, (BLOCK_M, D_HEAD), (0,) - ) - new_acc = old_acc * broadcasted_scale_factor - - v = tkl.load(V, (batch, head, i, 0), (BLOCK_N, D_HEAD)) - qkT16 = tkl.to_dtype(qkT, tkl.f16) - new_acc = tkl.dot(qkT16, v, new_acc) - - return (new_max, new_sum, new_acc) - - sum_stat = body[1] - result = body[2] - one = tkl.constant((BLOCK_M,), tkl.f32, 1.0) - one_by_sum = one / sum_stat - result = tkl.broadcast_in_dim(one_by_sum, (BLOCK_M, D_HEAD), (0,)) * result - tkl.store(O, (batch, head, grid_n * BLOCK_M, 0), result) - - Q = torch.randn(4, 48, 1024, 64) - K = torch.randn(4, 48, 1024, 64) - V = torch.randn(4, 48, 1024, 64) - O = torch.randn(4, 48, 1024, 64) - - with tk.gen.TestLaunchContext( - { - BLOCK_N: 128, - BLOCK_M: 256, - } - ): - fused_attention(Q, K, V, O) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/kernel/indexing_test.py b/core/tests/kernel/indexing_test.py deleted file mode 100644 index efd14b2f2..000000000 --- a/core/tests/kernel/indexing_test.py +++ /dev/null @@ -1,203 +0,0 @@ -import re -import unittest - -import torch - -from shark_turbine.kernel._support.indexing import * -from shark_turbine.kernel.lang import * - -M = sym.M -N = sym.N -K = sym.K - - -class TestTypes(unittest.TestCase): - def testGridRepr(self): - self.assertEqual("Grid", repr(Grid)) - self.assertEqual("Grid[M]", repr(Grid[M])) - self.assertEqual("Grid[M]", repr(Grid[M])) - self.assertEqual("Grid[M, N]", repr(Grid[sym.M, sym.N])) - self.assertEqual("Grid[M, M/2]", repr(Grid[M, M / 2])) - - def testGridAttrs(self): - T = Grid[M, N] - self.assertIs(T.symbolic_shape[0], M) - self.assertIs(T.symbolic_shape[1], N) - self.assertEqual(2, T.rank) - - def testShapedGridInstance(self): - G = Grid[M, N, K] - with IndexingContext() as idxc: - idxc.bind_constant(M, 1) - idxc.bind_constant(N, 2) - idxc.bind_constant(K, 3) - idxc.finalize() - g = G() - self.assertEqual(3, len(g)) - self.assertEqual(1, g[0]) - self.assertEqual([1, 2, 3], list(g)) - self.assertEqual(3, g.rank) - - def testKernelBufferRepr(self): - self.assertEqual("KernelBuffer", repr(KernelBuffer)) - self.assertEqual("KernelBuffer[M].of(f32)", repr(KernelBuffer[sym.M, f32])) - self.assertEqual("KernelBuffer[M, N].of(f32)", repr(KernelBuffer[M, N, f32])) - self.assertEqual("KernelBuffer[M, N].of(f32)", repr(KernelBuffer[M, N, f32])) - self.assertEqual( - "KernelBuffer[M, M/2].of(f32)", repr(KernelBuffer[M, M / 2, f32]) - ) - - def testKernelBufferAttrs(self): - T = KernelBuffer[M, N, f32] - self.assertIs(T.symbolic_shape[0], M) - self.assertIs(T.symbolic_shape[1], N) - self.assertEqual(2, T.rank) - - def testKernelBufferGenericInstance(self): - kb = KernelBuffer[N, M, f32](torch.empty((3, 4))) - self.assertEqual(2, kb.rank) - - def testKernelBufferInstance(self): - T1 = KernelBuffer[M, f32] - with self.assertRaisesRegex(ValueError, "mismatched symbolic rank"): - T1(torch.empty((3, 4))) - kb = T1(torch.empty((3,))) - self.assertEqual(1, kb.rank) - self.assertEqual((M,), kb.symbolic_shape) - - def testUsageAndElementTypeInstance(self): - T = InputBuffer[M, f16] - self.assertEqual("InputBuffer[M].of(f16)", repr(T)) - - -class ContextTest(unittest.TestCase): - def testConstant(self): - c = IndexingContext() - c.bind_constant(M, 4) - c.finalize() - - def testConstantConflict(self): - c = IndexingContext() - c.bind_constant(M, 4) - with self.assertRaisesRegex( - ValueError, - re.escape("Attempt to bind symbol M=5 conflicts with previous 4"), - ): - c.bind_constant(M, 5) - - def testKernelBuffers(self): - c = IndexingContext() - kb1 = KernelBuffer[M, N, f32] - c.bind_shaped(object(), kb1, (1, 2)) - c.finalize() - - def testDimConflict(self): - c = IndexingContext() - kb1 = KernelBuffer[M, M, f32] - c.bind_shaped(object(), kb1, (1, 2)) - with self.assertRaisesRegex( - ValueError, - re.escape( - "KernelBuffer[M, M].of(f32) attempt to bind dim M=2 conflicts with previous 1" - ), - ): - c.finalize() - - def testDimExprRequiredEquation(self): - c = IndexingContext() - inst = object() - kb1 = KernelBuffer[M, M / 2, f32] - c.bind_shaped(inst, kb1, (4, None)) - c.finalize() - self.assertEqual(c.eval_static_dim(inst, kb1, 0), 4) - self.assertEqual(c.eval_static_dim(inst, kb1, 1), 2) - - def testDimExprRequiredEquationNotSatisfied(self): - c = IndexingContext() - kb1 = KernelBuffer[M, N, f32] - c.bind_shaped(object(), kb1, (4, None)) - with self.assertRaisesRegex( - ValueError, - re.escape( - "KernelBuffer[M, N].of(f32)[1]=N did not resolve to a known value" - ), - ): - c.finalize() - - def testDimExprOptionalDynamicDim(self): - c = IndexingContext() - inst = object() - kb1 = KernelBuffer[M, N, f32] - c.bind_shaped(inst, kb1, (4, c.next_dyn_dim())) - c.finalize() - self.assertEqual(c.dyn_dims[0], c.eval_dim(inst, kb1, 1)) - - def testDynamicDimStaticInfoSufficient(self): - c = IndexingContext() - inst = object() - kb1 = KernelBuffer[M, M * 4, f32] - c.bind_shaped(inst, kb1, (4, c.next_dyn_dim())) - c.finalize() - self.assertEqual(16, c.eval_static_dim(inst, kb1, 1)) - - def testDimExpressionBackedDynamicDimInferenceMismatch(self): - c = IndexingContext() - kb1 = KernelBuffer[M, M / 2, f32] - c.bind_shaped(object(), kb1, (4, 3)) - with self.assertRaisesRegex( - ValueError, - re.escape( - "KernelBuffer[M, M/2].of(f32)[1]=2 was initialized with a mismatched runtime value of 3" - ), - ): - c.finalize() - - def testDependentDynamicDims(self): - c = IndexingContext() - inst = object() - kb1 = KernelBuffer[M, M * 4, f32] - c.bind_shaped(inst, kb1, (c.next_dyn_dim(), c.next_dyn_dim())) - c.finalize() - self.assertEqual(c.dyn_dims[0], c.eval_dim(inst, kb1, 0)) - self.assertEqual(c.dyn_dims[0] * 4, c.eval_dim(inst, kb1, 1)) - - -class SymIndexTest(unittest.TestCase): - def testUnbacked(self): - idxc = IndexingContext() - i = SymIndex(idxc) - self.assertEqual("UnbackedSymIndex", repr(type(i))) - - def testEqual(self): - idxc = IndexingContext() - idxc.bind_constant(M, 30) - idxc.finalize() - - t0 = backed_sym_index_type(EqualRelation(M)) - self.assertEqual("SymIndex==M", repr(t0)) - i0 = t0(idxc) - - t1 = backed_sym_index_type(EqualRelation(M + 1)) - self.assertEqual("SymIndex==(M + 1)", repr(t1)) - i1 = t1(idxc) - - def testBounded(self): - idxc = IndexingContext() - idxc.bind_constant(M, 30) - idxc.finalize() - - t = backed_sym_index_type(BoundedRelation(M, M + 1)) - self.assertEqual("SymIndex∈[M, M + 1]", repr(t)) - i = t(idxc) - - t = backed_sym_index_type( - BoundedRelation(M, M + 1, lower_inclusive=False, upper_inclusive=False) - ) - self.assertEqual("SymIndex∈(M, M + 1)", repr(t)) - - t = backed_sym_index_type(BoundedRelation(0, M, upper_inclusive=False)) - self.assertEqual("SymIndex∈[0, M)", repr(t)) - - -if __name__ == "__main__": - unittest.main() diff --git a/core/tests/kernel/simple_kernel_test.py b/core/tests/kernel/simple_kernel_test.py deleted file mode 100644 index 87cf3ed2e..000000000 --- a/core/tests/kernel/simple_kernel_test.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -import torch - -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl - -M = tk.lang.sym.M -K = tk.lang.sym.K - - -class Test(unittest.TestCase): - def testIotaEager(self): - @tk.gen.thread(M) - def iota_kernel(out: tk.lang.OutputBuffer[M, tkl.index]): - i = tk.lang.program_id(0) - out[i] = i - - out = torch.empty(8, dtype=torch.int32) - with tk.gen.TestLaunchContext(): - iota_kernel(out) - print(out) - - def testIotaFx(self): - @tk.gen.thread(M) - def iota_kernel(out: tk.lang.KernelBuffer[M, tkl.index]): - i = tk.lang.program_id(0) - out[i] = i - - print(iota_kernel._trace().region_graph) - # Prints: - # .graph(): - # %out : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=out] - # %program_id : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) - # %_global_buffer_setitem : [num_users=0] = call_function[target=shark_turbine.kernel._support.tracing._global_buffer_setitem](args = (%out, %program_id, %program_id), kwargs = {}) - # return None - - def testSoftmax(self): - @tk.gen.thread(M) - def softmax_kernel( - input: tk.lang.InputBuffer[M, K, tkl.f32], - output: tk.lang.OutputBuffer[M, K, tkl.f32], - ): - row_index = tk.lang.program_id(0) - input_row = input[row_index, :] - numerator = torch.exp(input_row - tk.lang.max(input_row)) - output_row = numerator / torch.sum(numerator) - output[row_index, :] = output_row - # Some debugging info if in debug mode and processing the first row. - if tk.DEBUG and row_index == 0: - print(f"*** Input: {input}") - print(f"*** Output: {output}") - print( - f"*** Input Row[{row_index}]: {type(output_row).__name__}({input_row.shape})" - ) - print( - f"*** Output Row: {type(output_row).__name__}({output_row.shape})" - ) - - def softmax(x): - y = torch.empty_like(x) - softmax_kernel[x.shape[0]](x, y) - return y - - input = torch.rand((128, 64)) - # generated = softmax(input) - # actual = torch.softmax(input, -1) - # torch.testing.assert_close(generated, actual) - print(softmax_kernel._trace().region_graph) - # Prints: - # graph(): - # %input_1 : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=input] - # %output : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=output] - # %program_id : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) - # %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%input_1, (%program_id, slice(None, None, None))), kwargs = {}) - # %max_1 : [num_users=1] = call_function[target=torch.max](args = (%getitem,), kwargs = {}) - # %sub : [num_users=1] = call_function[target=operator.sub](args = (%getitem, %max_1), kwargs = {}) - # %exp : [num_users=2] = call_function[target=torch.exp](args = (%sub,), kwargs = {}) - # %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%exp,), kwargs = {}) - # %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%exp, %sum_1), kwargs = {}) - # %program_id_1 : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) - # %_kernel_buffer_setitem : [num_users=0] = call_function[target=shark_turbine.kernel._support.tracing._kernel_buffer_setitem](args = (%output, (%program_id_1, slice(None, None, None)), %truediv), kwargs = {}) - # return None - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/kernel/types_test.py b/core/tests/kernel/types_test.py deleted file mode 100644 index 69619768c..000000000 --- a/core/tests/kernel/types_test.py +++ /dev/null @@ -1,30 +0,0 @@ -import logging -import unittest - -from shark_turbine.kernel.lang import ( - Index, -) - - -class IndexTypeTest(unittest.TestCase): - def testIndexType(self): - i = Index(5) - j = Index(-6) - self.assertIndexEqual(-1, i + j) - self.assertIndexEqual(11, i - j) - self.assertIndexEqual(-30, i * j) - self.assertIndexEqual(-1, i // j) - self.assertIndexEqual(2, Index(20) % Index(18)) - self.assertIndexEqual(16, Index(4) ** Index(2)) - self.assertIndexEqual(1, pow(Index(6), Index(8), Index(5))) - self.assertIndexEqual(-6, +j) - self.assertIndexEqual(6, -j) - - def assertIndexEqual(self, expected: int, actual): - self.assertEqual(expected, actual) - self.assertIsInstance(actual, Index) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/kernel/vector_codegen_test.py b/core/tests/kernel/vector_codegen_test.py deleted file mode 100644 index 0e5927c42..000000000 --- a/core/tests/kernel/vector_codegen_test.py +++ /dev/null @@ -1,99 +0,0 @@ -import logging -import unittest - -import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl - -M = tk.lang.sym.M -K = tk.lang.sym.K - - -class Test(unittest.TestCase): - # This test is using the compiler "the hard way" until we have all of the - # API layering in place. - def testIotaFx(self): - @tk.gen.thread(M) - def iota_kernel(out: tk.lang.OutputBuffer[M, tkl.index]): - i = tk.lang.program_id(0) - secret_value = ((i * (33 - i) + 4) % 8) // 2 - out[i] = secret_value - - with tk.gen.TestLaunchContext(): - out = torch.zeros(17, dtype=torch.int32) - - def testSoftmaxFx(self): - @tk.gen.thread(M) - def softmax_kernel( - input: tk.lang.InputBuffer[M, K, tkl.f32], - output: tk.lang.OutputBuffer[M, K, tkl.f32], - ): - row_index = tk.lang.program_id(0) - input_row = input[row_index, :] - numerator = tkl.exp2(input_row - tkl.max(input_row)) - output_row = numerator / tkl.sum(numerator) - output[row_index, :] = output_row - - with tk.gen.TestLaunchContext(): - input = torch.randn(128, 64, dtype=torch.float32) - output = torch.zeros(128, 64, dtype=torch.float32) - softmax_kernel(input, output) - - def testForLoopFx(self): - @tk.gen.thread(M) - def for_loop_kernel( - input: tk.lang.InputBuffer[M, K, tkl.f32], - output: tk.lang.OutputBuffer[M, K, tkl.f32], - ): - row_idx = tkl.program_id(0) - sum = input[row_idx, 0] - prefetch = input[row_idx, 1] - - @tkl.for_loop(2, 5, init_args=[sum, prefetch]) - def prefetch_sum(i, sum, prefetch): - new_sum = sum + prefetch - new_prefetch = input[row_idx, i] - return new_sum, new_prefetch - - output[row_idx, 0] = prefetch_sum[0] - - with tk.gen.TestLaunchContext(): - input = torch.randn(128, 64, dtype=torch.float32) - output = torch.zeros(128, 64, dtype=torch.float32) - for_loop_kernel(input, output) - - def testGemmFx(self): - N = tkl.sym.N - M = tkl.sym.M - K = tkl.sym.K - BLOCK_SIZE = tkl.sym.BLOCK_SIZE - - @tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE) - def gemm_kernel( - A: tkl.InputBuffer[N, K, tkl.f32], - B: tkl.InputBuffer[K, M, tkl.f32], - output: tkl.OutputBuffer[N, M, tkl.f32], - ): - grid_n = tkl.program_id(0) - grid_m = tkl.program_id(1) - - acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), tkl.f32, 0.0) - - @tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc]) - def body(i, c): - a = tkl.load(A, (grid_n, i * BLOCK_SIZE), (BLOCK_SIZE, BLOCK_SIZE)) - b = tkl.load(B, (i * BLOCK_SIZE, grid_m), (BLOCK_SIZE, BLOCK_SIZE)) - return (tkl.dot(a, b, c),) - - tkl.store(output, (grid_n, grid_m), body[0]) - - with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}): - A = torch.randn(512, 1024, dtype=torch.float32) - B = torch.randn(1024, 2048, dtype=torch.float32) - output = torch.zeros(512, 2048, dtype=torch.float32) - gemm_kernel(A, B, output) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/ops/iree_test.py b/core/tests/ops/iree_test.py deleted file mode 100644 index b41647d65..000000000 --- a/core/tests/ops/iree_test.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -import torch - -import shark_turbine.ops as ops - - -class KernelRegTest(unittest.TestCase): - def testTrace(self): - t = torch.randn(3, 4) - ops.iree.trace_tensor("TEST", t) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/runtime/device_test.py b/core/tests/runtime/device_test.py deleted file mode 100644 index c37750cca..000000000 --- a/core/tests/runtime/device_test.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest -import threading - -import torch - -from iree.runtime import HalElementType - -# Public API imports. -from shark_turbine.runtime import ( - Device, -) - -# Internals. -from shark_turbine.runtime.device import ( - _CURRENT_THREAD, - get_device_from_torch, -) - -from shark_turbine.support.exceptions import * - - -class DeviceTest(unittest.TestCase): - def test_create(self): - d = Device("local-task") - self.assertEqual(repr(d), "") - - def test_current_device(self): - with self.assertRaises(NoCurrentDeviceError): - Device.current() - - d1 = Device("local-task") - d2 = Device("local-sync") - with d1: - self.assertIs(Device.current(), d1) - - with d2: - self.assertIs(Device.current(), d2) - - self.assertIs(Device.current(), d1) - - with self.assertRaises(NoCurrentDeviceError): - Device.current() - - def test_set_clear(self): - d1 = Device("local-task") - d2 = Device("local-sync") - - with self.assertRaises(MismatchedDeviceSetClearError): - d1.clear() - try: - d1.set() - self.assertIs(Device.current(), d1) - with self.assertRaises(MismatchedDeviceSetClearError): - d2.clear() - d1.clear() - with self.assertRaises(NoCurrentDeviceError): - Device.current() - finally: - # Patch it back to the reset state for testing. - _CURRENT_THREAD.stack = [] - - def test_cached_devices_same_thread(self): - d1 = Device("local-task") - d2 = Device("local-task") - self.assertIs(d1, d2) - - def test_cached_device_diff_thread(self): - devices = [None, None] - - def run_t1(): - devices[0] = Device("local-task") - - def run_t2(): - devices[1] = Device("local-task") - - t1 = threading.Thread(target=run_t1) - t2 = threading.Thread(target=run_t2) - t1.start() - t2.start() - t1.join() - t2.join() - self.assertIsNotNone(devices[0]) - self.assertIsNotNone(devices[1]) - self.assertIsNot(devices[0], devices[1]) - - -# CPU is always available so we can enable this unconditionally. -class TorchCPUInterop(unittest.TestCase): - def testFromTorchDevice(self): - torch_device = torch.device("cpu") - device1 = get_device_from_torch(torch_device) - print(device1) - self.assertIsNotNone(device1) - device2 = get_device_from_torch(torch_device) - self.assertIs(device1, device2) - - def testCpuDeviceCacheKey(self): - d = get_device_from_torch(torch.device("cpu")) - self.assertEqual(d.instance_cache_key, "local-task") - self.assertEqual(d.type_cache_key, "local-task") - - def testImportExportTorchTensor(self): - d = get_device_from_torch(torch.device("cpu")) - cpu_tensor = torch.tensor([1, 2, 3], dtype=torch.int32, device="cpu") - bv = d.import_torch_tensor(cpu_tensor) - print(bv) - self.assertEqual(bv.shape, [3]) - self.assertEqual(bv.element_type, HalElementType.SINT_32) - meta_tensor = cpu_tensor.to(device="meta") - readback_tensor = d.export_torch_tensor(bv, meta_tensor) - torch.testing.assert_close(cpu_tensor, readback_tensor) - - def testCompilerFlags(self): - d = get_device_from_torch(torch.device("cpu")) - self.assertIn("--iree-hal-target-backends=llvm-cpu", d.compile_target_flags) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/runtime/op_reg/kernel_aot_test.py b/core/tests/runtime/op_reg/kernel_aot_test.py deleted file mode 100644 index 4aa04857a..000000000 --- a/core/tests/runtime/op_reg/kernel_aot_test.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -import torch -import torch.nn as nn - -import shark_turbine.aot as aot -import shark_turbine.ops as ops - -from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass - - -class MLP(nn.Module): - def __init__(self): - super().__init__() - self.layer0 = nn.Linear(8, 8, bias=True) - self.layer1 = nn.Linear(8, 4, bias=True) - self.layer2 = nn.Linear(4, 2, bias=True) - self.layer3 = nn.Linear(2, 2, bias=True) - - def forward(self, x: torch.Tensor): - x = self.layer0(x) - x = torch.sigmoid(x) - ops.iree.trace_tensor("LAYER0", x) - x = self.layer1(x) - x = torch.sigmoid(x) - ops.iree.trace_tensor("LAYER1", x) - x = self.layer2(x) - x = torch.sigmoid(x) - ops.iree.trace_tensor("LAYER2", x) - x = self.layer3(x) - ops.iree.trace_tensor("LAYER3", x) - return x - - -class KernelRegTest(unittest.TestCase): - def testTrace(self): - mlp = MLP() - - prog = aot.export(mlp, torch.empty(97, 8, dtype=torch.float32)) - p = ExpandCustomOpsPass(prog.mlir_module) - p.run() - - print("CUSTOM OP CONVERTED:") - module_asm = str(prog.mlir_module) - print(module_asm) - self.assertIn('flow.tensor.trace "LAYER0"', module_asm) - self.assertIn('flow.tensor.trace "LAYER1"', module_asm) - self.assertIn('flow.tensor.trace "LAYER3"', module_asm) - - def testEager(self): - mlp = MLP() - mlp.forward(torch.empty(97, 8, dtype=torch.float32)) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/runtime/op_reg/kernel_reg_test.py b/core/tests/runtime/op_reg/kernel_reg_test.py deleted file mode 100644 index 75554b046..000000000 --- a/core/tests/runtime/op_reg/kernel_reg_test.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - -import torch - -from shark_turbine.runtime.op_reg import * - -from shark_turbine.runtime.op_reg.compiler import _testing_get_cache_size - - -class KernelRegTest(unittest.TestCase): - def testRegistrationDispatchAndCache(self): - @CustomOp.register - class identity(CustomOp): - signature = "test_identity(Tensor self) -> Tensor" - - def select(self, ksel: KernelSelection): - x = ksel.arg_tensor(0) - ksel.return_tensor(x.t) - - def generate(self, ksel: KernelSelection, kb: KernelBuilder): - # This just yields the IR value of kernel input as the output. - # Effectively in eager mode, this is a `return` from the kernel - # function. - kb.yield_results(kb.arg_bindings[0]) - - self.assertIsNotNone(torch.ops.turbine.test_identity) - - start_compile_count = _testing_get_cache_size() - - # Make sure that the meta registration works. - t = torch.tensor([[1, 2, 3]], dtype=torch.int32, device="meta") - result = identity(t) - self.assertListEqual(list(result.shape), [1, 3]) - self.assertEqual(result.dtype, torch.int32) - self.assertEqual(t.device.type, "meta") - # Meta dispatch should not trigger compilation. - self.assertEqual(_testing_get_cache_size(), start_compile_count) - - # Make sure that CPU dispatch works. - t = torch.tensor([[1, 2, 3]], dtype=torch.int32) - result = identity(t) - print("CPU result:", result) - torch.testing.assert_close(result, t) - # Novel execution should compile a new kernel. - self.assertEqual(_testing_get_cache_size(), start_compile_count + 1) - - # Second run of the same kernel should serve out of cache. - result = identity(t) - torch.testing.assert_close(result, t) - # Repeated execution should use a cached kernel. - self.assertEqual(_testing_get_cache_size(), start_compile_count + 1) - - # It should recompile for different dtype. - t = torch.tensor([[1, 2, 3]], dtype=torch.int16) - result = identity(t) - print("CPU result:", result) - torch.testing.assert_close(result, t) - # Novel execution should compile a new kernel. - self.assertEqual(_testing_get_cache_size(), start_compile_count + 2) - - # It should recompile for different rank. - t = torch.tensor([1, 2, 3], dtype=torch.int16) - result = identity(t) - print("CPU result:", result) - torch.testing.assert_close(result, t) - # Novel execution should compile a new kernel. - self.assertEqual(_testing_get_cache_size(), start_compile_count + 3) - - # It should serve out of cache for same-rank but different dims. - t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int16) - result = identity(t) - print("CPU result:", result) - torch.testing.assert_close(result, t) - self.assertEqual(_testing_get_cache_size(), start_compile_count + 3) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/top_level_package_test.py b/core/tests/top_level_package_test.py deleted file mode 100644 index 52ea796bb..000000000 --- a/core/tests/top_level_package_test.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import unittest - - -class TopLevelPackageTest(unittest.TestCase): - def testIreeTurbineRedirect(self): - # We have a temporary redirect of the top-level API to the - # iree.turbine namespace. - from iree.turbine import aot, dynamo, kernel, ops, runtime - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/transforms/general/custom_op_expansion_test.py b/core/tests/transforms/general/custom_op_expansion_test.py deleted file mode 100644 index b94e2750a..000000000 --- a/core/tests/transforms/general/custom_op_expansion_test.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -from pathlib import Path -import torch -import unittest - -from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass -from shark_turbine.runtime.op_reg import ( - def_library, - CustomOp, - KernelBuilder, - KernelSelection, -) - -from shark_turbine.support.ir_imports import ( - Context, - Module, -) - - -class PassTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.lib = def_library("expand_custom_op_pass_test") - CustomOp.register(library=cls.lib)(IdentityOp) - CustomOp.register(library=cls.lib)(PrintStringAttrOp) - CustomOp.register(library=cls.lib)(IntArgOp) - - def testTensorArgReturn(self): - m = self.run_test_case("custom_op_simple.mlir") - m_asm = str(m) - print(m_asm) - self.assertNotIn("torch.operator", m_asm) - self.assertIn( - "%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[97,8],f32> -> tensor<97x8xf32>", - m_asm, - ) - # TODO: Upgrade to a FileCheck style test so we can pattern match that - # the casts are inserted properly. - self.assertIn( - "%1 = torch_c.from_builtin_tensor %cast_0 : tensor<97x8xf32> -> !torch.vtensor<[97,8],f32>", - m_asm, - ) - - def testStringAttrArg(self): - global _TEST_STRING_ATTR - _TEST_STRING_ATTR = "" - m = self.run_test_case("custom_op_string_attr.mlir") - m_asm = str(m) - self.assertEqual(_TEST_STRING_ATTR, "TEST_VALUE") - self.assertNotIn("torch.operator", m_asm) - print(m_asm) - - def testIntArg(self): - global _TEST_STRING_ATTR - _TEST_STRING_ATTR = "" - with self.assertRaisesRegex(NotImplementedError, "arg_int"): - self.run_test_case("custom_op_int_arg.mlir") - - def run_test_case(self, file_name: str): - p = Path(__file__).resolve().parent / "testdata" / file_name - contents = p.read_text() - with Context() as ctx: - m = Module.parse(contents) - p = ExpandCustomOpsPass(m.operation) - p.run() - print(f"TEST CASE {file_name}:\n{m}") - m.operation.verify() - return m - - -class IdentityOp(CustomOp): - signature = "identity_tensor(Tensor t) -> Tensor" - - def select(self, ksel: KernelSelection): - x = ksel.arg_tensor(0) - ksel.return_tensor(x.t) - - def generate(self, ksel: KernelSelection, kb: KernelBuilder): - kb.yield_results(kb.arg_bindings[0]) - - -class PrintStringAttrOp(CustomOp): - signature = "print_string_attr(str key) -> ()" - - def select(self, ksel: KernelSelection): - ksel.attr_str(0) - - def generate(self, ksel: KernelSelection, kb: KernelBuilder): - global _TEST_STRING_ATTR - _TEST_STRING_ATTR = str(ksel.arg_descs[0].v) - print("CAPTURED STRING ATTR:", _TEST_STRING_ATTR) - kb.yield_results() - - -class IntArgOp(CustomOp): - signature = "int_arg(int t) -> ()" - - def select(self, ksel: KernelSelection): - x = ksel.arg_int(0) - ksel.return_int() - - def generate(self, ksel: KernelSelection, kb: KernelBuilder): - kb.yield_results(kb.arg_bindings[0]) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/transforms/general/rename_parameters_test.py b/core/tests/transforms/general/rename_parameters_test.py deleted file mode 100644 index 203e6b455..000000000 --- a/core/tests/transforms/general/rename_parameters_test.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from pathlib import Path -import logging -import unittest - -from iree.compiler.ir import ( - Context, - Operation, -) - -from shark_turbine.transforms import rewriter -from shark_turbine.transforms.general import rename_parameters - -SIMPLE_GLOBALS_ASM = r""" -module { - util.global private @_params.classifier.default {noinline} = #stream.parameter.named<"default"> : tensor<30xf32> - util.global private @_params.classifier.weight {noinline} = #stream.parameter.named<"foo"::"WEIGHT"> : tensor<30x20xf32> - util.global private @_params.classifier.bias {noinline} = #stream.parameter.named<"foo"::"params.classifier.bias"> : tensor<30xf32> - util.global private @_params.classifier.other {noinline} = dense<0.0> : tensor<30xf32> - util.global private @_uninitialized {noinline} : tensor<30xf32> -} -""" - - -class RenameTest(unittest.TestCase): - def testBasic(self): - with Context() as context: - module_op = Operation.parse(SIMPLE_GLOBALS_ASM) - rename_parameters.RenameParametersPass( - module_op, - rename_map={ - "WEIGHT": "weight", - ("foo", "params.classifier.bias"): ("bar", "BIAS"), - }, - rename_callback=lambda scope, name: ("XXX", "YYY") - if name == "default" - else None, - ).run() - module_asm = str(module_op) - print(module_asm) - self.assertIn( - '@_params.classifier.default {noinline} = #stream.parameter.named<"XXX"::"YYY"> : tensor<30xf32>', - module_asm, - ) - self.assertIn( - '@_params.classifier.weight {noinline} = #stream.parameter.named<"foo"::"weight"> : tensor<30x20xf32>', - module_asm, - ) - self.assertIn( - '@_params.classifier.bias {noinline} = #stream.parameter.named<"bar"::"BIAS"> : tensor<30xf32>', - module_asm, - ) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/core/tests/transforms/general/testdata/custom_op_int_arg.mlir b/core/tests/transforms/general/testdata/custom_op_int_arg.mlir deleted file mode 100644 index 1a17b9a6a..000000000 --- a/core/tests/transforms/general/testdata/custom_op_int_arg.mlir +++ /dev/null @@ -1,9 +0,0 @@ -builtin.module { - -func.func @forward() { - %i = torch.constant.int 1000 - torch.operator "torch.expand_custom_op_pass_test.int_arg"(%i) : (!torch.int) -> () - return -} - -} diff --git a/core/tests/transforms/general/testdata/custom_op_simple.mlir b/core/tests/transforms/general/testdata/custom_op_simple.mlir deleted file mode 100644 index b0a879f9b..000000000 --- a/core/tests/transforms/general/testdata/custom_op_simple.mlir +++ /dev/null @@ -1,8 +0,0 @@ -builtin.module { - -func.func @forward(%arg0: !torch.vtensor<[97,8],f32>) -> !torch.vtensor<[97,8],f32> { - %0 = torch.operator "torch.expand_custom_op_pass_test.identity_tensor"(%arg0) : (!torch.vtensor<[97,8],f32>) -> (!torch.vtensor<[97,8],f32>) - return %0 : !torch.vtensor<[97,8],f32> -} - -} diff --git a/core/tests/transforms/general/testdata/custom_op_string_attr.mlir b/core/tests/transforms/general/testdata/custom_op_string_attr.mlir deleted file mode 100644 index c534a0745..000000000 --- a/core/tests/transforms/general/testdata/custom_op_string_attr.mlir +++ /dev/null @@ -1,9 +0,0 @@ -builtin.module { - -func.func @forward() { - %str = torch.constant.str "TEST_VALUE" - torch.operator "torch.expand_custom_op_pass_test.print_string_attr"(%str) : (!torch.str) -> () - return -} - -} diff --git a/core/tests/transforms/quantization/mm_f32_to_int4.mlir b/core/tests/transforms/quantization/mm_f32_to_int4.mlir deleted file mode 100644 index 09dd1ac2b..000000000 --- a/core/tests/transforms/quantization/mm_f32_to_int4.mlir +++ /dev/null @@ -1,12 +0,0 @@ -module @state_update { - util.global private @_params.model.layers.0.self_attn.q_proj.weight {noinline} : tensor<4096x4096xf32> - func.func @initialize(%arg0: !torch.vtensor<[?,4096],f32>) -> (!torch.vtensor<[?,4096],f32>) { - %_params.model.layers.0.self_attn.q_proj.weight = util.global.load @_params.model.layers.0.self_attn.q_proj.weight : tensor<4096x4096xf32> - %55 = torch_c.from_builtin_tensor %_params.model.layers.0.self_attn.q_proj.weight : tensor<4096x4096xf32> -> !torch.vtensor<[4096,4096],f32> - %int0_74 = torch.constant.int 0 - %int1_75 = torch.constant.int 1 - %56 = torch.aten.transpose.int %55, %int0_74, %int1_75 : !torch.vtensor<[4096,4096],f32>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f32> - %59 = torch.aten.mm %arg0, %56 : !torch.vtensor<[?,4096],f32>, !torch.vtensor<[4096,4096],f32> -> !torch.vtensor<[?,4096],f32> - return %59 : !torch.vtensor<[?,4096],f32> - } -} diff --git a/core/tests/transforms/quantization/mm_group_quant_test.py b/core/tests/transforms/quantization/mm_group_quant_test.py deleted file mode 100644 index c6870d2c3..000000000 --- a/core/tests/transforms/quantization/mm_group_quant_test.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# Portions Copyright 2022 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from pathlib import Path -import logging -import unittest - -from iree.compiler.ir import ( - Context, - Operation, -) - -from shark_turbine.transforms import rewriter -from shark_turbine.transforms.quantization import mm_group_quant - -MM_F32_TO_INT4_CONTENTS = ( - Path(__file__).resolve().parent / "mm_f32_to_int4.mlir" -).read_text() - - -class Int4Quant(unittest.TestCase): - def setUp(self): - self.MM_F32_TO_INT4_CONTENTS = ( - Path(__file__).resolve().parent / "mm_f32_to_int4.mlir" - ).read_text() - - def testBasic(self): - with Context() as context: - module_op = Operation.parse(self.MM_F32_TO_INT4_CONTENTS) - mm_group_quant.MMGroupQuantRewriterPass(module_op).run() - module_asm = str(module_op) - print(module_asm) - self.assertNotIn("torch.aten.mm", module_asm) - self.assertNotIn( - "@_params.model.layers.0.self_attn.q_proj.weight ", module_asm - ) - self.assertIn("linalg.generic", module_asm) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/docs/developer/known_issues.md b/docs/developer/known_issues.md deleted file mode 100644 index 05d045671..000000000 --- a/docs/developer/known_issues.md +++ /dev/null @@ -1,95 +0,0 @@ -# Known Issues in SHARK-Turbine - -## Dealing with functional variants of Torch Ops - -```py -import torch.nn.functional as F -def forward(self, x): - return F.max_pool2d(8, x) -``` -``` -# occuring in importer -> import_list_arguments -compiler_fn raised IndexError: list index out of range -``` - -Currently, we have issues dealing with functional variants of -torch operations that do not define meaningful defaults for their arguments. -Two common operations for which this issue arises are `F.avg_pool2d` and `F.max_pool2d`. -Taking `max_pool2d` as an example, the [functional version](https://pytorch.org/docs/stable/generated/torch.nn.functional.max_pool2d.html) sets `stride=None` by default (which returns an empty list to the importer), -however, the actual intended default setting is to set `stride=kernel_size`. This issue does not occur with the corresponding `nn.Module` wrapper `MaxPool2d` because -it actually [manually sets the intended default value](https://pytorch.org/docs/stable/_modules/torch/nn/modules/pooling.html#_MaxPoolNd). The same issue is at play in `avg_pool2d`. - -## Ephemeral Tensor objects from `aten.lift_fresh_copy.default` -```py -import torch -def forward(self): - return torch.tensor([1,2]) -``` -``` -# in importer -> import_argument -torch._dynamo.exc.BackendCompilerFailed: compiler_fn raised KeyError: (_tensor_constant0, 0) -torch._dynamo.exc.BackendCompilerFailed: compiler_fn raised AssertionError: Can not create literal tensor for unsupported datatype: torch.complex64 -``` -This error arises due to an odd case in the Fx Graph generation where the -graph module for our code generates a node `_tensor_constant0 = self._tensor_constant0` with no traceable origin within -the graph. Torch dynamo dynamically creates this attribute in the top level module object, hence this object is never -passed through our importer, meaning that our lookup for the appropriate MlirValue in the importer's `_v` table fails. This consistently -occurs when the graph generates an intermediate `aten.lift_fresh_copy` as in the case of creating a new tensor above. - -We now have a fix for this by directly instantiating the object using a reference to the top level graph module in the -importer, but this method does not support all torch datatypes - in particular it fails to support `bfloat16` and -complex datatypes. - - -## Assertion failure in `aten.lift` in the aot_eager, inductor, and turbine backend. -```python -import torch -def forward(self, x): - return torch.ops.aten.lift(x) -``` -``` -RuntimeError: !at::functionalization::impl::isFunctionalTensor(self) -INTERNAL ASSERT FAILED at "../aten/src/ATen/FunctionalizeFallbackKernel.cpp":167, please report a bug to PyTorch. -``` -[`aten.lift`](https://github.com/pytorch/pytorch/blob/3a3cf0e09d475df9237c95ebd14debf650e0c038/aten/src/ATen/native/native_functions.yaml#L7583) seems to fail the [functionalization stage](https://github.com/pytorch/pytorch/blob/3a3cf0e09d475df9237c95ebd14debf650e0c038/aten/src/ATen/FunctionalizeFallbackKernel.cpp#L176), -in particular it seems that the input tensor fails an [assertion](https://github.com/pytorch/pytorch/blob/3a3cf0e09d475df9237c95ebd14debf650e0c038/aten/src/ATen/FunctionalTensorWrapper.cpp#L575) that it is of functional form. - -[PyTorch Issue](https://github.com/pytorch/pytorch/issues/107961) - -## TorchDynamo failure in training backward due to `aten.scalar_tensor` output not wrapped as a fake tensor - -```python -class LinearModel(nn.Module): - def __init__(self, input_dim, output_dim): - super(LinearModel, self).__init__() - self.linear = nn.Linear(input_dim, output_dim) - - def forward(self, x): - x = x.view(x.size(0), -1) - out = self.linear(x) - return out -``` -During the training in backwards, -`aten.where.self` expects fake inputs, but `aten.scalar_tensor` output is not wrapped as a fake tensor. -``` -File "/home/brucekimrok/CLionProjects/SHARK-Turbine/tvenv3.11/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1632, in validate -raise Exception( -Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.where.self(FakeTensor(..., size=(64, 1), dtype=torch.bool), FakeTensor(..., size=(64, 1), dtype=torch.int64), tensor(0, size=())) -``` -https://github.com/pytorch/pytorch/blob/98c8550158a4a79c4d39533a5331c5953f6ea279/torch/_subclasses/fake_tensor.py#L1657-L1669 - -Relevant issue is raised here: [PyTorch Issue](https://github.com/pytorch/pytorch/issues/92941). -However, this case is about when DDP optimization + Dynamo + `aten.where` are invoked. -The [PR](https://github.com/pytorch/pytorch/pull/92986) to address this issue was made in `torch/_dynamo/optimizations/distributed.py`. -In our case, we do not use DDP optimization. - -## FX emitted as None due to bug in TorchScript in `aten.convolution_backward` - -When schema calls for a Tensor, sometimes None is emitted due to the way TS is maintained. -For `convolution_backward` op, TS has a problem of returning None when output_mask=[True, True, False]. -In eager mode, similar can happen. -https://github.com/pytorch/pytorch/issues/97524 - -Vivek [fixed movdedim](https://github.com/llvm/torch-mlir/pull/1773) to allow torch-mlir emitted when output_mask=[True, True, True] -So we should find a way to set Output_mask = [True, True, True] to fix this issue. - diff --git a/docs/releasing.md b/docs/releasing.md deleted file mode 100644 index d9bebf81e..000000000 --- a/docs/releasing.md +++ /dev/null @@ -1,75 +0,0 @@ -# Releasing SHARK-Turbine/core - -There are multiple release artifacts that are deployed from this project: - -* shark-turbine wheel (transitional while switching to iree-turbine) -* iree-turbine wheel -* iree-compiler wheels -* iree-runtime wheels - -Typically we deploy IREE compiler and runtime wheels along with a turbine -release, effectively promoting a nightly. - -## Building Artifacts - -Build a pre-release: - -``` -./build_tools/build_release.py --core-version 2.3.0 --core-pre-version=rcYYYYMMDD -``` - -Build an official release: - -``` -./build_tools/build_release.py --core-version 2.3.0 -``` - -This will download all deps, including wheels for all supported platforms and -Python versions for iree-compiler and iree-runtime. All wheels will be placed -in the `wheelhouse/` directory. - - -## Testing - -TODO: Write a script for this. - -``` -python -m venv wheelhouse/test.venv -source wheelhouse/test.venv/bin/activate -pip install -f wheelhouse iree-turbine[testing] -# Temp: tests require torchvision. -pip install -f wheelhouse torchvision -pytest core/tests -``` - -## Push - -From the testing venv, verify that everything is sane: - -``` -pip freeze -``` - -Push IREE deps (if needed/updated): - -``` -twine upload wheelhouse/iree_compiler-* wheelhouse/iree_runtime-* -``` - -Push built wheels: - -``` -twine upload wheelhouse/iree_turbine-* wheelhouse/shark_turbine-* -``` - -## Install from PyPI and Sanity Check - -TODO: Script this - -From the testing venv: - -``` -pip uninstall -y shark-turbine iree-turbine iree-compiler iree-runtime -pip install iree-turbine -pytest core/tests -``` diff --git a/docs/roadmap.md b/docs/roadmap.md deleted file mode 100644 index 394f29e7d..000000000 --- a/docs/roadmap.md +++ /dev/null @@ -1,24 +0,0 @@ -# Turbine Roadmap - -## Path to V1 - -In the short term, we are investing in Turbine to make it feature complete -and usable for: - -* *AOT export* (https://github.com/nod-ai/SHARK-Turbine/issues/125) -* *Serving SHARK Gen-AI Models* (https://github.com/nod-ai/SHARK-Turbine/issues/119) -* *Integration with Quantization Workflows* (https://github.com/nod-ai/SHARK-Turbine/issues/120) -* *General eager execution* (https://github.com/nod-ai/SHARK-Turbine/issues/105) -* *FX Importer Completion* (https://github.com/nod-ai/SHARK-Turbine/issues/139) - -We are also investing in project infrastructure: - -* *Releasing* (https://github.com/nod-ai/SHARK-Turbine/issues/121) -* *CI and Testing* (https://github.com/nod-ai/SHARK-Turbine/issues/122) -* *Upstreaming* (https://github.com/nod-ai/SHARK-Turbine/issues/123) -* *User Documentation and Samples* (https://github.com/nod-ai/SHARK-Turbine/issues/124) -* *Developer/Debugging Workflows* (https://github.com/nod-ai/SHARK-Turbine/issues/75) - -Roughly, we will aim to align delivery of these features with the next official -PyTorch release, at which point we will track PyTorch versions as our versioning -scheme.